diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/batch.rs | 200 | ||||
-rw-r--r-- | src/chain.rs | 292 | ||||
-rw-r--r-- | src/expr/bitwise.rs | 69 | ||||
-rw-r--r-- | src/expr/cmp.rs | 220 | ||||
-rw-r--r-- | src/expr/counter.rs | 46 | ||||
-rw-r--r-- | src/expr/ct.rs | 87 | ||||
-rw-r--r-- | src/expr/immediate.rs | 126 | ||||
-rw-r--r-- | src/expr/log.rs | 112 | ||||
-rw-r--r-- | src/expr/lookup.rs | 79 | ||||
-rw-r--r-- | src/expr/masquerade.rs | 24 | ||||
-rw-r--r-- | src/expr/meta.rs | 175 | ||||
-rw-r--r-- | src/expr/mod.rs | 242 | ||||
-rw-r--r-- | src/expr/nat.rs | 99 | ||||
-rw-r--r-- | src/expr/payload.rs | 531 | ||||
-rw-r--r-- | src/expr/register.rs | 34 | ||||
-rw-r--r-- | src/expr/reject.rs | 96 | ||||
-rw-r--r-- | src/expr/verdict.rs | 148 | ||||
-rw-r--r-- | src/expr/wrapper.rs | 60 | ||||
-rw-r--r-- | src/lib.rs | 205 | ||||
-rw-r--r-- | src/query.rs | 130 | ||||
-rw-r--r-- | src/rule.rs | 341 | ||||
-rw-r--r-- | src/set.rs | 273 | ||||
-rw-r--r-- | src/table.rs | 171 |
23 files changed, 3760 insertions, 0 deletions
diff --git a/src/batch.rs b/src/batch.rs new file mode 100644 index 0000000..c8ec5aa --- /dev/null +++ b/src/batch.rs @@ -0,0 +1,200 @@ +use crate::{MsgType, NlMsg}; +use crate::sys::{self as sys, libc}; +use std::ffi::c_void; +use std::os::raw::c_char; +use std::ptr; +use thiserror::Error; + +/// Error while communicating with netlink +#[derive(Error, Debug)] +#[error("Error while communicating with netlink")] +pub struct NetlinkError(()); + +#[cfg(feature = "query")] +/// Check if the kernel supports batched netlink messages to netfilter. +pub fn batch_is_supported() -> std::result::Result<bool, NetlinkError> { + match unsafe { sys::nftnl_batch_is_supported() } { + 1 => Ok(true), + 0 => Ok(false), + _ => Err(NetlinkError(())), + } +} + +/// A batch of netfilter messages to be performed in one atomic operation. Corresponds to +/// `nftnl_batch` in libnftnl. +pub struct Batch { + pub(crate) batch: *mut sys::nftnl_batch, + pub(crate) seq: u32, + pub(crate) is_empty: bool, +} + +impl Batch { + /// Creates a new nftnl batch with the [default page size]. + /// + /// [default page size]: fn.default_batch_page_size.html + pub fn new() -> Self { + Self::with_page_size(default_batch_page_size()) + } + + pub unsafe fn from_raw(batch: *mut sys::nftnl_batch, seq: u32) -> Self { + Batch { + batch, + seq, + // we assume this batch is not empty by default + is_empty: false, + } + } + + /// Creates a new nftnl batch with the given batch size. + pub fn with_page_size(batch_page_size: u32) -> Self { + let batch = try_alloc!(unsafe { + sys::nftnl_batch_alloc(batch_page_size, crate::nft_nlmsg_maxsize()) + }); + let mut this = Batch { + batch, + seq: 0, + is_empty: true, + }; + this.write_begin_msg(); + this + } + + /// Adds the given message to this batch. + pub fn add<T: NlMsg>(&mut self, msg: &T, msg_type: MsgType) { + trace!("Writing NlMsg with seq {} to batch", self.seq); + unsafe { msg.write(self.current(), self.seq, msg_type) }; + self.is_empty = false; + self.next() + } + + /// Adds all the messages in the given iterator to this batch. If any message fails to be added + /// the error for that failure is returned and all messages up until that message stays added + /// to the batch. + pub fn add_iter<T, I>(&mut self, msg_iter: I, msg_type: MsgType) + where + T: NlMsg, + I: Iterator<Item = T>, + { + for msg in msg_iter { + self.add(&msg, msg_type); + } + } + + /// Adds the final end message to the batch and returns a [`FinalizedBatch`] that can be used + /// to send the messages to netfilter. + /// + /// Return None if there is no object in the batch (this could block forever). + /// + /// [`FinalizedBatch`]: struct.FinalizedBatch.html + pub fn finalize(mut self) -> Option<FinalizedBatch> { + self.write_end_msg(); + if self.is_empty { + return None; + } + Some(FinalizedBatch { batch: self }) + } + + fn current(&self) -> *mut c_void { + unsafe { sys::nftnl_batch_buffer(self.batch) } + } + + fn next(&mut self) { + if unsafe { sys::nftnl_batch_update(self.batch) } < 0 { + // See try_alloc definition. + std::process::abort(); + } + self.seq += 1; + } + + fn write_begin_msg(&mut self) { + unsafe { sys::nftnl_batch_begin(self.current() as *mut c_char, self.seq) }; + self.next(); + } + + fn write_end_msg(&mut self) { + unsafe { sys::nftnl_batch_end(self.current() as *mut c_char, self.seq) }; + self.next(); + } + + #[cfg(feature = "unsafe-raw-handles")] + /// Returns the raw handle. + pub fn as_ptr(&self) -> *const sys::nftnl_batch { + self.batch as *const sys::nftnl_batch + } + + #[cfg(feature = "unsafe-raw-handles")] + /// Returns a mutable version of the raw handle. + pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_batch { + self.batch + } +} + +impl Drop for Batch { + fn drop(&mut self) { + unsafe { sys::nftnl_batch_free(self.batch) }; + } +} + +/// A wrapper over [`Batch`], guaranteed to start with a proper batch begin and end with a proper +/// batch end message. Created from [`Batch::finalize`]. +/// +/// Can be turned into an iterator of the byte buffers to send to netlink to execute this batch. +/// +/// [`Batch`]: struct.Batch.html +/// [`Batch::finalize`]: struct.Batch.html#method.finalize +pub struct FinalizedBatch { + batch: Batch, +} + +impl FinalizedBatch { + /// Returns the iterator over byte buffers to send to netlink. + pub fn iter(&mut self) -> Iter<'_> { + let num_pages = unsafe { sys::nftnl_batch_iovec_len(self.batch.batch) as usize }; + let mut iovecs = vec![ + libc::iovec { + iov_base: ptr::null_mut(), + iov_len: 0, + }; + num_pages + ]; + let iovecs_ptr = iovecs.as_mut_ptr(); + unsafe { + sys::nftnl_batch_iovec(self.batch.batch, iovecs_ptr, num_pages as u32); + } + Iter { + iovecs: iovecs.into_iter(), + _marker: ::std::marker::PhantomData, + } + } +} + +impl<'a> IntoIterator for &'a mut FinalizedBatch { + type Item = &'a [u8]; + type IntoIter = Iter<'a>; + + fn into_iter(self) -> Iter<'a> { + self.iter() + } +} + +pub struct Iter<'a> { + iovecs: ::std::vec::IntoIter<libc::iovec>, + _marker: ::std::marker::PhantomData<&'a ()>, +} + +impl<'a> Iterator for Iter<'a> { + type Item = &'a [u8]; + + fn next(&mut self) -> Option<&'a [u8]> { + self.iovecs.next().map(|iovec| unsafe { + ::std::slice::from_raw_parts(iovec.iov_base as *const u8, iovec.iov_len) + }) + } +} + +/// selected batch page is 256 Kbytes long to load ruleset of +/// half a million rules without hitting -EMSGSIZE due to large +/// iovec. +pub fn default_batch_page_size() -> u32 { + unsafe { libc::sysconf(libc::_SC_PAGESIZE) as u32 * 32 } +} diff --git a/src/chain.rs b/src/chain.rs new file mode 100644 index 0000000..20043ac --- /dev/null +++ b/src/chain.rs @@ -0,0 +1,292 @@ +use crate::{MsgType, Table}; +use crate::sys::{self as sys, libc}; +#[cfg(feature = "query")] +use std::convert::TryFrom; +use std::{ + ffi::{c_void, CStr, CString}, + fmt, + os::raw::c_char, + rc::Rc, +}; + +pub type Priority = i32; + +/// The netfilter event hooks a chain can register for. +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[repr(u16)] +pub enum Hook { + /// Hook into the pre-routing stage of netfilter. Corresponds to `NF_INET_PRE_ROUTING`. + PreRouting = libc::NF_INET_PRE_ROUTING as u16, + /// Hook into the input stage of netfilter. Corresponds to `NF_INET_LOCAL_IN`. + In = libc::NF_INET_LOCAL_IN as u16, + /// Hook into the forward stage of netfilter. Corresponds to `NF_INET_FORWARD`. + Forward = libc::NF_INET_FORWARD as u16, + /// Hook into the output stage of netfilter. Corresponds to `NF_INET_LOCAL_OUT`. + Out = libc::NF_INET_LOCAL_OUT as u16, + /// Hook into the post-routing stage of netfilter. Corresponds to `NF_INET_POST_ROUTING`. + PostRouting = libc::NF_INET_POST_ROUTING as u16, +} + +/// A chain policy. Decides what to do with a packet that was processed by the chain but did not +/// match any rules. +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[repr(u32)] +pub enum Policy { + /// Accept the packet. + Accept = libc::NF_ACCEPT as u32, + /// Drop the packet. + Drop = libc::NF_DROP as u32, +} + +/// Base chain type. +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +pub enum ChainType { + /// Used to filter packets. + /// Supported protocols: ip, ip6, inet, arp, and bridge tables. + Filter, + /// Used to reroute packets if IP headers or packet marks are modified. + /// Supported protocols: ip, and ip6 tables. + Route, + /// Used to perform NAT. + /// Supported protocols: ip, and ip6 tables. + Nat, +} + +impl ChainType { + fn as_c_str(&self) -> &'static [u8] { + match *self { + ChainType::Filter => b"filter\0", + ChainType::Route => b"route\0", + ChainType::Nat => b"nat\0", + } + } +} + +/// Abstraction of a `nftnl_chain`. Chains reside inside [`Table`]s and they hold [`Rule`]s. +/// +/// There are two types of chains, "base chain" and "regular chain". See [`set_hook`] for more +/// details. +/// +/// [`Table`]: struct.Table.html +/// [`Rule`]: struct.Rule.html +/// [`set_hook`]: #method.set_hook +pub struct Chain { + pub(crate) chain: *mut sys::nftnl_chain, + pub(crate) table: Rc<Table>, +} + +impl Chain { + /// Creates a new chain instance inside the given [`Table`] and with the given name. + /// + /// [`Table`]: struct.Table.html + pub fn new<T: AsRef<CStr>>(name: &T, table: Rc<Table>) -> Chain { + unsafe { + let chain = try_alloc!(sys::nftnl_chain_alloc()); + sys::nftnl_chain_set_u32( + chain, + sys::NFTNL_CHAIN_FAMILY as u16, + table.get_family() as u32, + ); + sys::nftnl_chain_set_str( + chain, + sys::NFTNL_CHAIN_TABLE as u16, + table.get_name().as_ptr(), + ); + sys::nftnl_chain_set_str(chain, sys::NFTNL_CHAIN_NAME as u16, name.as_ref().as_ptr()); + Chain { chain, table } + } + } + + pub unsafe fn from_raw(chain: *mut sys::nftnl_chain, table: Rc<Table>) -> Self { + Chain { chain, table } + } + + /// Sets the hook and priority for this chain. Without calling this method the chain well + /// become a "regular chain" without any hook and will thus not receive any traffic unless + /// some rule forward packets to it via goto or jump verdicts. + /// + /// By calling `set_hook` with a hook the chain that is created will be registered with that + /// hook and is thus a "base chain". A "base chain" is an entry point for packets from the + /// networking stack. + pub fn set_hook(&mut self, hook: Hook, priority: Priority) { + unsafe { + sys::nftnl_chain_set_u32(self.chain, sys::NFTNL_CHAIN_HOOKNUM as u16, hook as u32); + sys::nftnl_chain_set_s32(self.chain, sys::NFTNL_CHAIN_PRIO as u16, priority); + } + } + + /// Set the type of a base chain. This only applies if the chain has been registered + /// with a hook by calling `set_hook`. + pub fn set_type(&mut self, chain_type: ChainType) { + unsafe { + sys::nftnl_chain_set_str( + self.chain, + sys::NFTNL_CHAIN_TYPE as u16, + chain_type.as_c_str().as_ptr() as *const c_char, + ); + } + } + + /// Sets the default policy for this chain. That means what action netfilter will apply to + /// packets processed by this chain, but that did not match any rules in it. + pub fn set_policy(&mut self, policy: Policy) { + unsafe { + sys::nftnl_chain_set_u32(self.chain, sys::NFTNL_CHAIN_POLICY as u16, policy as u32); + } + } + + /// Returns the userdata of this chain. + pub fn get_userdata(&self) -> Option<&CStr> { + unsafe { + let ptr = sys::nftnl_chain_get_str(self.chain, sys::NFTNL_CHAIN_USERDATA as u16); + if ptr == std::ptr::null() { + return None; + } + Some(CStr::from_ptr(ptr)) + } + } + + /// Update the userdata of this chain. + pub fn set_userdata(&self, data: &CStr) { + unsafe { + sys::nftnl_chain_set_str(self.chain, sys::NFTNL_CHAIN_USERDATA as u16, data.as_ptr()); + } + } + + /// Returns the name of this chain. + pub fn get_name(&self) -> &CStr { + unsafe { + let ptr = sys::nftnl_chain_get_str(self.chain, sys::NFTNL_CHAIN_NAME as u16); + if ptr.is_null() { + panic!("Impossible situation: retrieving the name of a chain failed") + } else { + CStr::from_ptr(ptr) + } + } + } + + /// Returns a textual description of the chain. + pub fn get_str(&self) -> CString { + let mut descr_buf = vec![0i8; 4096]; + unsafe { + sys::nftnl_chain_snprintf( + descr_buf.as_mut_ptr(), + (descr_buf.len() - 1) as u64, + self.chain, + sys::NFTNL_OUTPUT_DEFAULT, + 0, + ); + CStr::from_ptr(descr_buf.as_ptr()).to_owned() + } + } + + /// Returns a reference to the [`Table`] this chain belongs to + /// + /// [`Table`]: struct.Table.html + pub fn get_table(&self) -> Rc<Table> { + self.table.clone() + } + + #[cfg(feature = "unsafe-raw-handles")] + /// Returns the raw handle. + pub fn as_ptr(&self) -> *const sys::nftnl_chain { + self.chain as *const sys::nftnl_chain + } + + #[cfg(feature = "unsafe-raw-handles")] + /// Returns a mutable version of the raw handle. + pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_chain { + self.chain + } +} + +impl fmt::Debug for Chain { + /// Return a string representation of the chain. + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "{:?}", self.get_str()) + } +} + +impl PartialEq for Chain { + fn eq(&self, other: &Self) -> bool { + self.get_table() == other.get_table() && self.get_name() == other.get_name() + } +} + +unsafe impl crate::NlMsg for Chain { + unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { + let raw_msg_type = match msg_type { + MsgType::Add => libc::NFT_MSG_NEWCHAIN, + MsgType::Del => libc::NFT_MSG_DELCHAIN, + }; + let flags: u16 = match msg_type { + MsgType::Add => (libc::NLM_F_ACK | libc::NLM_F_CREATE) as u16, + MsgType::Del => libc::NLM_F_ACK as u16, + } | libc::NLM_F_ACK as u16; + let header = sys::nftnl_nlmsg_build_hdr( + buf as *mut c_char, + raw_msg_type as u16, + self.table.get_family() as u16, + flags, + seq, + ); + sys::nftnl_chain_nlmsg_build_payload(header, self.chain); + } +} + +impl Drop for Chain { + fn drop(&mut self) { + unsafe { sys::nftnl_chain_free(self.chain) }; + } +} + +#[cfg(feature = "query")] +pub fn get_chains_cb<'a>( + header: &libc::nlmsghdr, + (table, chains): &mut (&Rc<Table>, &mut Vec<Chain>), +) -> libc::c_int { + unsafe { + let chain = sys::nftnl_chain_alloc(); + if chain == std::ptr::null_mut() { + return mnl::mnl_sys::MNL_CB_ERROR; + } + let err = sys::nftnl_chain_nlmsg_parse(header, chain); + if err < 0 { + error!("Failed to parse nelink chain message - {}", err); + sys::nftnl_chain_free(chain); + return err; + } + + let table_name = CStr::from_ptr(sys::nftnl_chain_get_str( + chain, + sys::NFTNL_CHAIN_TABLE as u16, + )); + let family = sys::nftnl_chain_get_u32(chain, sys::NFTNL_CHAIN_FAMILY as u16); + let family = match crate::ProtoFamily::try_from(family as i32) { + Ok(family) => family, + Err(crate::InvalidProtocolFamily) => { + error!("The netlink table didn't have a valid protocol family !?"); + sys::nftnl_chain_free(chain); + return mnl::mnl_sys::MNL_CB_ERROR; + } + }; + + if table_name != table.get_name() { + sys::nftnl_chain_free(chain); + return mnl::mnl_sys::MNL_CB_OK; + } + + if family != crate::ProtoFamily::Unspec && family != table.get_family() { + sys::nftnl_chain_free(chain); + return mnl::mnl_sys::MNL_CB_OK; + } + + chains.push(Chain::from_raw(chain, table.clone())); + } + mnl::mnl_sys::MNL_CB_OK +} + +#[cfg(feature = "query")] +pub fn list_chains_for_table(table: Rc<Table>) -> Result<Vec<Chain>, crate::query::Error> { + crate::query::list_objects_with_data(libc::NFT_MSG_GETCHAIN as u16, get_chains_cb, &table, None) +} diff --git a/src/expr/bitwise.rs b/src/expr/bitwise.rs new file mode 100644 index 0000000..59ef41b --- /dev/null +++ b/src/expr/bitwise.rs @@ -0,0 +1,69 @@ +use super::{Expression, Rule, ToSlice}; +use crate::sys::{self, libc}; +use std::ffi::c_void; +use std::os::raw::c_char; + +/// Expression for performing bitwise masking and XOR on the data in a register. +pub struct Bitwise<M: ToSlice, X: ToSlice> { + mask: M, + xor: X, +} + +impl<M: ToSlice, X: ToSlice> Bitwise<M, X> { + /// Returns a new `Bitwise` instance that first masks the value it's applied to with `mask` + /// and then performs xor with the value in `xor`. + pub fn new(mask: M, xor: X) -> Self { + Self { mask, xor } + } +} + +impl<M: ToSlice, X: ToSlice> Expression for Bitwise<M, X> { + fn get_raw_name() -> *const c_char { + b"bitwise\0" as *const _ as *const c_char + } + + fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { + unsafe { + let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); + + let mask = self.mask.to_slice(); + let xor = self.xor.to_slice(); + assert!(mask.len() == xor.len()); + let len = mask.len() as u32; + + sys::nftnl_expr_set_u32( + expr, + sys::NFTNL_EXPR_BITWISE_SREG as u16, + libc::NFT_REG_1 as u32, + ); + sys::nftnl_expr_set_u32( + expr, + sys::NFTNL_EXPR_BITWISE_DREG as u16, + libc::NFT_REG_1 as u32, + ); + sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_BITWISE_LEN as u16, len); + + sys::nftnl_expr_set( + expr, + sys::NFTNL_EXPR_BITWISE_MASK as u16, + mask.as_ref() as *const _ as *const c_void, + len, + ); + sys::nftnl_expr_set( + expr, + sys::NFTNL_EXPR_BITWISE_XOR as u16, + xor.as_ref() as *const _ as *const c_void, + len, + ); + + expr + } + } +} + +#[macro_export] +macro_rules! nft_expr_bitwise { + (mask $mask:expr,xor $xor:expr) => { + $crate::expr::Bitwise::new($mask, $xor) + }; +} diff --git a/src/expr/cmp.rs b/src/expr/cmp.rs new file mode 100644 index 0000000..384f0b4 --- /dev/null +++ b/src/expr/cmp.rs @@ -0,0 +1,220 @@ +use super::{DeserializationError, Expression, Rule, ToSlice}; +use crate::sys::{self, libc}; +use std::{ + borrow::Cow, + ffi::{c_void, CString}, + os::raw::c_char, +}; + +/// Comparison operator. +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum CmpOp { + /// Equals. + Eq, + /// Not equal. + Neq, + /// Less than. + Lt, + /// Less than, or equal. + Lte, + /// Greater than. + Gt, + /// Greater than, or equal. + Gte, +} + +impl CmpOp { + /// Returns the corresponding `NFT_*` constant for this comparison operation. + pub fn to_raw(self) -> u32 { + use self::CmpOp::*; + match self { + Eq => libc::NFT_CMP_EQ as u32, + Neq => libc::NFT_CMP_NEQ as u32, + Lt => libc::NFT_CMP_LT as u32, + Lte => libc::NFT_CMP_LTE as u32, + Gt => libc::NFT_CMP_GT as u32, + Gte => libc::NFT_CMP_GTE as u32, + } + } + + pub fn from_raw(val: u32) -> Result<Self, DeserializationError> { + use self::CmpOp::*; + match val as i32 { + libc::NFT_CMP_EQ => Ok(Eq), + libc::NFT_CMP_NEQ => Ok(Neq), + libc::NFT_CMP_LT => Ok(Lt), + libc::NFT_CMP_LTE => Ok(Lte), + libc::NFT_CMP_GT => Ok(Gt), + libc::NFT_CMP_GTE => Ok(Gte), + _ => Err(DeserializationError::InvalidValue), + } + } +} + +/// Comparator expression. Allows comparing the content of the netfilter register with any value. +#[derive(Debug, PartialEq)] +pub struct Cmp<T> { + op: CmpOp, + data: T, +} + +impl<T: ToSlice> Cmp<T> { + /// Returns a new comparison expression comparing the value loaded in the register with the + /// data in `data` using the comparison operator `op`. + pub fn new(op: CmpOp, data: T) -> Self { + Cmp { op, data } + } +} + +impl<T: ToSlice> Expression for Cmp<T> { + fn get_raw_name() -> *const c_char { + b"cmp\0" as *const _ as *const c_char + } + + fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { + unsafe { + let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); + + let data = self.data.to_slice(); + trace!("Creating a cmp expr comparing with data {:?}", data); + + sys::nftnl_expr_set_u32( + expr, + sys::NFTNL_EXPR_CMP_SREG as u16, + libc::NFT_REG_1 as u32, + ); + sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_CMP_OP as u16, self.op.to_raw()); + sys::nftnl_expr_set( + expr, + sys::NFTNL_EXPR_CMP_DATA as u16, + data.as_ptr() as *const c_void, + data.len() as u32, + ); + + expr + } + } +} + +impl<const N: usize> Expression for Cmp<[u8; N]> { + fn get_raw_name() -> *const c_char { + Cmp::<u8>::get_raw_name() + } + + /// The raw data contained inside `Cmp` expressions can only be deserialized to + /// arrays of bytes, to ensure that the memory layout of retrieved data cannot be + /// violated. It is your responsibility to provide the correct length of the byte + /// data. If the data size is invalid, you will get the error + /// `DeserializationError::InvalidDataSize`. + /// + /// Example (warning, no error checking!): + /// ```rust + /// use std::ffi::CString; + /// use std::net::Ipv4Addr; + /// use std::rc::Rc; + /// + /// use rustables::{Chain, expr::{Cmp, CmpOp}, ProtoFamily, Rule, Table}; + /// + /// let table = Rc::new(Table::new(&CString::new("mytable").unwrap(), ProtoFamily::Inet)); + /// let chain = Rc::new(Chain::new(&CString::new("mychain").unwrap(), table)); + /// let mut rule = Rule::new(chain); + /// rule.add_expr(&Cmp::new(CmpOp::Eq, 1337u16)); + /// for expr in Rc::new(rule).get_exprs() { + /// println!("{:?}", expr.decode_expr::<Cmp<[u8; 2]>>().unwrap()); + /// } + /// ``` + /// These limitations occur because casting bytes to any type of the same size + /// as the raw input would be *extremely* dangerous in terms of memory safety. + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { + unsafe { + let ref_len = std::mem::size_of::<[u8; N]>() as u32; + let mut data_len = 0; + let data = sys::nftnl_expr_get( + expr, + sys::NFTNL_EXPR_CMP_DATA as u16, + &mut data_len as *mut u32, + ); + + if data.is_null() { + return Err(DeserializationError::NullPointer); + } else if data_len != ref_len { + return Err(DeserializationError::InvalidDataSize); + } + + let data = *(data as *const [u8; N]); + + let op = CmpOp::from_raw(sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_CMP_OP as u16))?; + Ok(Cmp { op, data }) + } + } + + // call to the other implementation to generate the expression + fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr { + Cmp { + data: &self.data as &[u8], + op: self.op, + } + .to_expr(rule) + } +} + +#[macro_export(local_inner_macros)] +macro_rules! nft_expr_cmp { + (@cmp_op ==) => { + $crate::expr::CmpOp::Eq + }; + (@cmp_op !=) => { + $crate::expr::CmpOp::Neq + }; + (@cmp_op <) => { + $crate::expr::CmpOp::Lt + }; + (@cmp_op <=) => { + $crate::expr::CmpOp::Lte + }; + (@cmp_op >) => { + $crate::expr::CmpOp::Gt + }; + (@cmp_op >=) => { + $crate::expr::CmpOp::Gte + }; + ($op:tt $data:expr) => { + $crate::expr::Cmp::new(nft_expr_cmp!(@cmp_op $op), $data) + }; +} + +/// Can be used to compare the value loaded by [`Meta::IifName`] and [`Meta::OifName`]. Please +/// note that it is faster to check interface index than name. +/// +/// [`Meta::IifName`]: enum.Meta.html#variant.IifName +/// [`Meta::OifName`]: enum.Meta.html#variant.OifName +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub enum InterfaceName { + /// Interface name must be exactly the value of the `CString`. + Exact(CString), + /// Interface name must start with the value of the `CString`. + /// + /// `InterfaceName::StartingWith("eth")` will look like `eth*` when printed and match against + /// `eth0`, `eth1`, ..., `eth99` and so on. + StartingWith(CString), +} + +impl ToSlice for InterfaceName { + fn to_slice(&self) -> Cow<'_, [u8]> { + let bytes = match *self { + InterfaceName::Exact(ref name) => name.as_bytes_with_nul(), + InterfaceName::StartingWith(ref name) => name.as_bytes(), + }; + Cow::from(bytes) + } +} + +impl<'a> ToSlice for &'a InterfaceName { + fn to_slice(&self) -> Cow<'_, [u8]> { + let bytes = match *self { + InterfaceName::Exact(ref name) => name.as_bytes_with_nul(), + InterfaceName::StartingWith(ref name) => name.as_bytes(), + }; + Cow::from(bytes) + } +} diff --git a/src/expr/counter.rs b/src/expr/counter.rs new file mode 100644 index 0000000..71064df --- /dev/null +++ b/src/expr/counter.rs @@ -0,0 +1,46 @@ +use super::{DeserializationError, Expression, Rule}; +use crate::sys; +use std::os::raw::c_char; + +/// A counter expression adds a counter to the rule that is incremented to count number of packets +/// and number of bytes for all packets that has matched the rule. +#[derive(Debug, PartialEq)] +pub struct Counter { + pub nb_bytes: u64, + pub nb_packets: u64, +} + +impl Counter { + pub fn new() -> Self { + Self { + nb_bytes: 0, + nb_packets: 0, + } + } +} + +impl Expression for Counter { + fn get_raw_name() -> *const c_char { + b"counter\0" as *const _ as *const c_char + } + + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { + unsafe { + let nb_bytes = sys::nftnl_expr_get_u64(expr, sys::NFTNL_EXPR_CTR_BYTES as u16); + let nb_packets = sys::nftnl_expr_get_u64(expr, sys::NFTNL_EXPR_CTR_PACKETS as u16); + Ok(Counter { + nb_bytes, + nb_packets, + }) + } + } + + fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { + unsafe { + let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); + sys::nftnl_expr_set_u64(expr, sys::NFTNL_EXPR_CTR_BYTES as u16, self.nb_bytes); + sys::nftnl_expr_set_u64(expr, sys::NFTNL_EXPR_CTR_PACKETS as u16, self.nb_packets); + expr + } + } +} diff --git a/src/expr/ct.rs b/src/expr/ct.rs new file mode 100644 index 0000000..7d6614c --- /dev/null +++ b/src/expr/ct.rs @@ -0,0 +1,87 @@ +use super::{DeserializationError, Expression, Rule}; +use crate::sys::{self, libc}; +use std::os::raw::c_char; + +bitflags::bitflags! { + pub struct States: u32 { + const INVALID = 1; + const ESTABLISHED = 2; + const RELATED = 4; + const NEW = 8; + const UNTRACKED = 64; + } +} + +pub enum Conntrack { + State, + Mark { set: bool }, +} + +impl Conntrack { + fn raw_key(&self) -> u32 { + match *self { + Conntrack::State => libc::NFT_CT_STATE as u32, + Conntrack::Mark { .. } => libc::NFT_CT_MARK as u32, + } + } +} + +impl Expression for Conntrack { + fn get_raw_name() -> *const c_char { + b"ct\0" as *const _ as *const c_char + } + + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> + where + Self: Sized, + { + unsafe { + let ct_key = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_CT_KEY as u16); + let ct_sreg_is_set = sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_CT_SREG as u16); + + match ct_key as i32 { + libc::NFT_CT_STATE => Ok(Conntrack::State), + libc::NFT_CT_MARK => Ok(Conntrack::Mark { + set: ct_sreg_is_set, + }), + _ => Err(DeserializationError::InvalidValue), + } + } + } + + fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { + unsafe { + let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); + + if let Conntrack::Mark { set: true } = self { + sys::nftnl_expr_set_u32( + expr, + sys::NFTNL_EXPR_CT_SREG as u16, + libc::NFT_REG_1 as u32, + ); + } else { + sys::nftnl_expr_set_u32( + expr, + sys::NFTNL_EXPR_CT_DREG as u16, + libc::NFT_REG_1 as u32, + ); + } + sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_CT_KEY as u16, self.raw_key()); + + expr + } + } +} + +#[macro_export] +macro_rules! nft_expr_ct { + (state) => { + $crate::expr::Conntrack::State + }; + (mark set) => { + $crate::expr::Conntrack::Mark { set: true } + }; + (mark) => { + $crate::expr::Conntrack::Mark { set: false } + }; +} diff --git a/src/expr/immediate.rs b/src/expr/immediate.rs new file mode 100644 index 0000000..0787e06 --- /dev/null +++ b/src/expr/immediate.rs @@ -0,0 +1,126 @@ +use super::{DeserializationError, Expression, Register, Rule, ToSlice}; +use crate::sys; +use std::ffi::c_void; +use std::os::raw::c_char; + +/// An immediate expression. Used to set immediate data. +/// Verdicts are handled separately by [crate::expr::Verdict]. +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct Immediate<T> { + pub data: T, + pub register: Register, +} + +impl<T> Immediate<T> { + pub fn new(data: T, register: Register) -> Self { + Self { data, register } + } +} + +impl<T: ToSlice> Expression for Immediate<T> { + fn get_raw_name() -> *const c_char { + b"immediate\0" as *const _ as *const c_char + } + + fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { + unsafe { + let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); + + sys::nftnl_expr_set_u32( + expr, + sys::NFTNL_EXPR_IMM_DREG as u16, + self.register.to_raw(), + ); + + let data = self.data.to_slice(); + sys::nftnl_expr_set( + expr, + sys::NFTNL_EXPR_IMM_DATA as u16, + data.as_ptr() as *const c_void, + data.len() as u32, + ); + + expr + } + } +} + +impl<const N: usize> Expression for Immediate<[u8; N]> { + fn get_raw_name() -> *const c_char { + Immediate::<u8>::get_raw_name() + } + + /// The raw data contained inside `Immediate` expressions can only be deserialized to + /// arrays of bytes, to ensure that the memory layout of retrieved data cannot be + /// violated. It is your responsibility to provide the correct length of the byte + /// data. If the data size is invalid, you will get the error + /// `DeserializationError::InvalidDataSize`. + /// + /// Example (warning, no error checking!): + /// ```rust + /// use std::ffi::CString; + /// use std::net::Ipv4Addr; + /// use std::rc::Rc; + /// + /// use rustables::{Chain, expr::{Immediate, Register}, ProtoFamily, Rule, Table}; + /// + /// let table = Rc::new(Table::new(&CString::new("mytable").unwrap(), ProtoFamily::Inet)); + /// let chain = Rc::new(Chain::new(&CString::new("mychain").unwrap(), table)); + /// let mut rule = Rule::new(chain); + /// rule.add_expr(&Immediate::new(42u8, Register::Reg1)); + /// for expr in Rc::new(rule).get_exprs() { + /// println!("{:?}", expr.decode_expr::<Immediate<[u8; 1]>>().unwrap()); + /// } + /// ``` + /// These limitations occur because casting bytes to any type of the same size + /// as the raw input would be *extremely* dangerous in terms of memory safety. + // As casting bytes to any type of the same size as the input would + // be *extremely* dangerous in terms of memory safety, + // rustables only accept to deserialize expressions with variable-size data + // to arrays of bytes, so that the memory layout cannot be invalid. + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { + unsafe { + let ref_len = std::mem::size_of::<[u8; N]>() as u32; + let mut data_len = 0; + let data = sys::nftnl_expr_get( + expr, + sys::NFTNL_EXPR_IMM_DATA as u16, + &mut data_len as *mut u32, + ); + + if data.is_null() { + return Err(DeserializationError::NullPointer); + } else if data_len != ref_len { + return Err(DeserializationError::InvalidDataSize); + } + + let data = *(data as *const [u8; N]); + + let register = Register::from_raw(sys::nftnl_expr_get_u32( + expr, + sys::NFTNL_EXPR_IMM_DREG as u16, + ))?; + + Ok(Immediate { data, register }) + } + } + + // call to the other implementation to generate the expression + fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr { + Immediate { + register: self.register, + data: &self.data as &[u8], + } + .to_expr(rule) + } +} + +#[macro_export] +macro_rules! nft_expr_immediate { + (data $value:expr) => { + $crate::expr::Immediate { + data: $value, + register: $crate::expr::Register::Reg1, + } + }; +} diff --git a/src/expr/log.rs b/src/expr/log.rs new file mode 100644 index 0000000..5c06897 --- /dev/null +++ b/src/expr/log.rs @@ -0,0 +1,112 @@ +use super::{DeserializationError, Expression, Rule}; +use crate::sys; +use std::ffi::{CStr, CString}; +use std::os::raw::c_char; +use thiserror::Error; + +/// A Log expression will log all packets that match the rule. +#[derive(Debug, PartialEq)] +pub struct Log { + pub group: Option<LogGroup>, + pub prefix: Option<LogPrefix>, +} + +impl Expression for Log { + fn get_raw_name() -> *const sys::libc::c_char { + b"log\0" as *const _ as *const c_char + } + + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> + where + Self: Sized, + { + unsafe { + let mut group = None; + if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_LOG_GROUP as u16) { + group = Some(LogGroup(sys::nftnl_expr_get_u32( + expr, + sys::NFTNL_EXPR_LOG_GROUP as u16, + ) as u16)); + } + let mut prefix = None; + if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16) { + let raw_prefix = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16); + if raw_prefix.is_null() { + return Err(DeserializationError::NullPointer); + } else { + prefix = Some(LogPrefix(CStr::from_ptr(raw_prefix).to_owned())); + } + } + Ok(Log { group, prefix }) + } + } + + fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { + unsafe { + let expr = try_alloc!(sys::nftnl_expr_alloc(b"log\0" as *const _ as *const c_char)); + if let Some(log_group) = self.group { + sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOG_GROUP as u16, log_group.0 as u32); + }; + if let Some(LogPrefix(prefix)) = &self.prefix { + sys::nftnl_expr_set_str(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16, prefix.as_ptr()); + }; + + expr + } + } +} + +#[derive(Error, Debug)] +pub enum LogPrefixError { + #[error("The log prefix string is more than 128 characters long")] + TooLongPrefix, + #[error("The log prefix string contains an invalid Nul character.")] + PrefixContainsANul(#[from] std::ffi::NulError), +} + +/// The NFLOG group that will be assigned to each log line. +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] +pub struct LogGroup(pub u16); + +/// A prefix that will get prepended to each log line. +#[derive(Debug, Clone, PartialEq)] +pub struct LogPrefix(CString); + +impl LogPrefix { + /// Create a new LogPrefix from a String. Converts it to CString as needed by nftnl. Note + /// that LogPrefix should not be more than 127 characters long. + pub fn new(prefix: &str) -> Result<Self, LogPrefixError> { + if prefix.chars().count() > 127 { + return Err(LogPrefixError::TooLongPrefix); + } + Ok(LogPrefix(CString::new(prefix)?)) + } +} + +#[macro_export] +macro_rules! nft_expr_log { + (group $group:ident prefix $prefix:expr) => { + $crate::expr::Log { + group: $group, + prefix: $prefix, + } + }; + (prefix $prefix:expr) => { + $crate::expr::Log { + group: None, + prefix: $prefix, + } + }; + (group $group:ident) => { + $crate::expr::Log { + group: $group, + prefix: None, + } + }; + () => { + $crate::expr::Log { + group: None, + prefix: None, + } + }; +} diff --git a/src/expr/lookup.rs b/src/expr/lookup.rs new file mode 100644 index 0000000..8e288a0 --- /dev/null +++ b/src/expr/lookup.rs @@ -0,0 +1,79 @@ +use super::{DeserializationError, Expression, Rule}; +use crate::set::Set; +use crate::sys::{self, libc}; +use std::ffi::{CStr, CString}; +use std::os::raw::c_char; + +#[derive(Debug, PartialEq)] +pub struct Lookup { + set_name: CString, + set_id: u32, +} + +impl Lookup { + /// Creates a new lookup entry. + /// May return None if the set have no name. + pub fn new<K>(set: &Set<'_, K>) -> Option<Self> { + set.get_name().map(|set_name| Lookup { + set_name: set_name.to_owned(), + set_id: set.get_id(), + }) + } +} + +impl Expression for Lookup { + fn get_raw_name() -> *const libc::c_char { + b"lookup\0" as *const _ as *const c_char + } + + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> + where + Self: Sized, + { + unsafe { + let set_name = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_LOOKUP_SET as u16); + let set_id = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_LOOKUP_SET_ID as u16); + + if set_name.is_null() { + return Err(DeserializationError::NullPointer); + } + + let set_name = CStr::from_ptr(set_name).to_owned(); + + Ok(Lookup { set_id, set_name }) + } + } + + fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { + unsafe { + let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); + + sys::nftnl_expr_set_u32( + expr, + sys::NFTNL_EXPR_LOOKUP_SREG as u16, + libc::NFT_REG_1 as u32, + ); + sys::nftnl_expr_set_str( + expr, + sys::NFTNL_EXPR_LOOKUP_SET as u16, + self.set_name.as_ptr() as *const _ as *const c_char, + ); + sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOOKUP_SET_ID as u16, self.set_id); + + // This code is left here since it's quite likely we need it again when we get further + // if self.reverse { + // sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOOKUP_FLAGS as u16, + // libc::NFT_LOOKUP_F_INV as u32); + // } + + expr + } + } +} + +#[macro_export] +macro_rules! nft_expr_lookup { + ($set:expr) => { + $crate::expr::Lookup::new($set) + }; +} diff --git a/src/expr/masquerade.rs b/src/expr/masquerade.rs new file mode 100644 index 0000000..c1a06de --- /dev/null +++ b/src/expr/masquerade.rs @@ -0,0 +1,24 @@ +use super::{DeserializationError, Expression, Rule}; +use crate::sys; +use std::os::raw::c_char; + +/// Sets the source IP to that of the output interface. +#[derive(Debug, PartialEq)] +pub struct Masquerade; + +impl Expression for Masquerade { + fn get_raw_name() -> *const sys::libc::c_char { + b"masq\0" as *const _ as *const c_char + } + + fn from_expr(_expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> + where + Self: Sized, + { + Ok(Masquerade) + } + + fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { + try_alloc!(unsafe { sys::nftnl_expr_alloc(Self::get_raw_name()) }) + } +} diff --git a/src/expr/meta.rs b/src/expr/meta.rs new file mode 100644 index 0000000..bf77774 --- /dev/null +++ b/src/expr/meta.rs @@ -0,0 +1,175 @@ +use super::{DeserializationError, Expression, Rule}; +use crate::sys::{self, libc}; +use std::os::raw::c_char; + +/// A meta expression refers to meta data associated with a packet. +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[non_exhaustive] +pub enum Meta { + /// Packet ethertype protocol (skb->protocol), invalid in OUTPUT. + Protocol, + /// Packet mark. + Mark { set: bool }, + /// Packet input interface index (dev->ifindex). + Iif, + /// Packet output interface index (dev->ifindex). + Oif, + /// Packet input interface name (dev->name) + IifName, + /// Packet output interface name (dev->name). + OifName, + /// Packet input interface type (dev->type). + IifType, + /// Packet output interface type (dev->type). + OifType, + /// Originating socket UID (fsuid). + SkUid, + /// Originating socket GID (fsgid). + SkGid, + /// Netfilter protocol (Transport layer protocol). + NfProto, + /// Layer 4 protocol number. + L4Proto, + /// Socket control group (skb->sk->sk_classid). + Cgroup, + /// A 32bit pseudo-random number + PRandom, +} + +impl Meta { + /// Returns the corresponding `NFT_*` constant for this meta expression. + pub fn to_raw_key(&self) -> u32 { + use Meta::*; + match *self { + Protocol => libc::NFT_META_PROTOCOL as u32, + Mark { .. } => libc::NFT_META_MARK as u32, + Iif => libc::NFT_META_IIF as u32, + Oif => libc::NFT_META_OIF as u32, + IifName => libc::NFT_META_IIFNAME as u32, + OifName => libc::NFT_META_OIFNAME as u32, + IifType => libc::NFT_META_IIFTYPE as u32, + OifType => libc::NFT_META_OIFTYPE as u32, + SkUid => libc::NFT_META_SKUID as u32, + SkGid => libc::NFT_META_SKGID as u32, + NfProto => libc::NFT_META_NFPROTO as u32, + L4Proto => libc::NFT_META_L4PROTO as u32, + Cgroup => libc::NFT_META_CGROUP as u32, + PRandom => libc::NFT_META_PRANDOM as u32, + } + } + + fn from_raw(val: u32) -> Result<Self, DeserializationError> { + match val as i32 { + libc::NFT_META_PROTOCOL => Ok(Self::Protocol), + libc::NFT_META_MARK => Ok(Self::Mark { set: false }), + libc::NFT_META_IIF => Ok(Self::Iif), + libc::NFT_META_OIF => Ok(Self::Oif), + libc::NFT_META_IIFNAME => Ok(Self::IifName), + libc::NFT_META_OIFNAME => Ok(Self::OifName), + libc::NFT_META_IIFTYPE => Ok(Self::IifType), + libc::NFT_META_OIFTYPE => Ok(Self::OifType), + libc::NFT_META_SKUID => Ok(Self::SkUid), + libc::NFT_META_SKGID => Ok(Self::SkGid), + libc::NFT_META_NFPROTO => Ok(Self::NfProto), + libc::NFT_META_L4PROTO => Ok(Self::L4Proto), + libc::NFT_META_CGROUP => Ok(Self::Cgroup), + libc::NFT_META_PRANDOM => Ok(Self::PRandom), + _ => Err(DeserializationError::InvalidValue), + } + } +} + +impl Expression for Meta { + fn get_raw_name() -> *const libc::c_char { + b"meta\0" as *const _ as *const c_char + } + + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> + where + Self: Sized, + { + unsafe { + let mut ret = Self::from_raw(sys::nftnl_expr_get_u32( + expr, + sys::NFTNL_EXPR_META_KEY as u16, + ))?; + + if let Self::Mark { ref mut set } = ret { + *set = sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_META_SREG as u16); + } + + Ok(ret) + } + } + + fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { + unsafe { + let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); + + if let Meta::Mark { set: true } = self { + sys::nftnl_expr_set_u32( + expr, + sys::NFTNL_EXPR_META_SREG as u16, + libc::NFT_REG_1 as u32, + ); + } else { + sys::nftnl_expr_set_u32( + expr, + sys::NFTNL_EXPR_META_DREG as u16, + libc::NFT_REG_1 as u32, + ); + } + sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_META_KEY as u16, self.to_raw_key()); + expr + } + } +} + +#[macro_export] +macro_rules! nft_expr_meta { + (proto) => { + $crate::expr::Meta::Protocol + }; + (mark set) => { + $crate::expr::Meta::Mark { set: true } + }; + (mark) => { + $crate::expr::Meta::Mark { set: false } + }; + (iif) => { + $crate::expr::Meta::Iif + }; + (oif) => { + $crate::expr::Meta::Oif + }; + (iifname) => { + $crate::expr::Meta::IifName + }; + (oifname) => { + $crate::expr::Meta::OifName + }; + (iiftype) => { + $crate::expr::Meta::IifType + }; + (oiftype) => { + $crate::expr::Meta::OifType + }; + (skuid) => { + $crate::expr::Meta::SkUid + }; + (skgid) => { + $crate::expr::Meta::SkGid + }; + (nfproto) => { + $crate::expr::Meta::NfProto + }; + (l4proto) => { + $crate::expr::Meta::L4Proto + }; + (cgroup) => { + $crate::expr::Meta::Cgroup + }; + (random) => { + $crate::expr::Meta::PRandom + }; +} diff --git a/src/expr/mod.rs b/src/expr/mod.rs new file mode 100644 index 0000000..fbf49d6 --- /dev/null +++ b/src/expr/mod.rs @@ -0,0 +1,242 @@ +//! A module with all the nftables expressions that can be added to [`Rule`]s to build up how +//! they match against packets. +//! +//! [`Rule`]: struct.Rule.html + +use std::borrow::Cow; +use std::net::IpAddr; +use std::net::Ipv4Addr; +use std::net::Ipv6Addr; + +use super::rule::Rule; +use crate::sys::{self, libc}; +use thiserror::Error; + +mod bitwise; +pub use self::bitwise::*; + +mod cmp; +pub use self::cmp::*; + +mod counter; +pub use self::counter::*; + +pub mod ct; +pub use self::ct::*; + +mod immediate; +pub use self::immediate::*; + +mod log; +pub use self::log::*; + +mod lookup; +pub use self::lookup::*; + +mod masquerade; +pub use self::masquerade::*; + +mod meta; +pub use self::meta::*; + +mod nat; +pub use self::nat::*; + +mod payload; +pub use self::payload::*; + +mod reject; +pub use self::reject::{IcmpCode, Reject}; + +mod register; +pub use self::register::Register; + +mod verdict; +pub use self::verdict::*; + +mod wrapper; +pub use self::wrapper::ExpressionWrapper; + +#[derive(Debug, Error)] +pub enum DeserializationError { + #[error("The expected expression type doesn't match the name of the raw expression")] + /// The expected expression type doesn't match the name of the raw expression + InvalidExpressionKind, + + #[error("Deserializing the requested type isn't implemented yet")] + /// Deserializing the requested type isn't implemented yet + NotImplemented, + + #[error("The expression value cannot be deserialized to the requested type")] + /// The expression value cannot be deserialized to the requested type + InvalidValue, + + #[error("A pointer was null while a non-null pointer was expected")] + /// A pointer was null while a non-null pointer was expected + NullPointer, + + #[error( + "The size of a raw value was incoherent with the expected type of the deserialized value" + )] + /// The size of a raw value was incoherent with the expected type of the deserialized value + InvalidDataSize, + + #[error(transparent)] + /// Couldn't find a matching protocol + InvalidProtolFamily(#[from] super::InvalidProtocolFamily), +} + +/// Trait for every safe wrapper of an nftables expression. +pub trait Expression { + /// Returns the raw name used by nftables to identify the rule. + fn get_raw_name() -> *const libc::c_char; + + /// Try to parse the expression from a raw nftables expression, + /// returning a [DeserializationError] if the attempted parsing failed. + fn from_expr(_expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> + where + Self: Sized, + { + Err(DeserializationError::NotImplemented) + } + + /// Allocates and returns the low level `nftnl_expr` representation of this expression. + /// The caller to this method is responsible for freeing the expression. + fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr; +} + +/// A type that can be converted into a byte buffer. +pub trait ToSlice { + /// Returns the data this type represents. + fn to_slice(&self) -> Cow<'_, [u8]>; +} + +impl<'a> ToSlice for &'a [u8] { + fn to_slice(&self) -> Cow<'_, [u8]> { + Cow::Borrowed(self) + } +} + +impl<'a> ToSlice for &'a [u16] { + fn to_slice(&self) -> Cow<'_, [u8]> { + let ptr = self.as_ptr() as *const u8; + let len = self.len() * 2; + Cow::Borrowed(unsafe { std::slice::from_raw_parts(ptr, len) }) + } +} + +impl ToSlice for IpAddr { + fn to_slice(&self) -> Cow<'_, [u8]> { + match *self { + IpAddr::V4(ref addr) => addr.to_slice(), + IpAddr::V6(ref addr) => addr.to_slice(), + } + } +} + +impl ToSlice for Ipv4Addr { + fn to_slice(&self) -> Cow<'_, [u8]> { + Cow::Owned(self.octets().to_vec()) + } +} + +impl ToSlice for Ipv6Addr { + fn to_slice(&self) -> Cow<'_, [u8]> { + Cow::Owned(self.octets().to_vec()) + } +} + +impl ToSlice for u8 { + fn to_slice(&self) -> Cow<'_, [u8]> { + Cow::Owned(vec![*self]) + } +} + +impl ToSlice for u16 { + fn to_slice(&self) -> Cow<'_, [u8]> { + let b0 = (*self & 0x00ff) as u8; + let b1 = (*self >> 8) as u8; + Cow::Owned(vec![b0, b1]) + } +} + +impl ToSlice for u32 { + fn to_slice(&self) -> Cow<'_, [u8]> { + let b0 = *self as u8; + let b1 = (*self >> 8) as u8; + let b2 = (*self >> 16) as u8; + let b3 = (*self >> 24) as u8; + Cow::Owned(vec![b0, b1, b2, b3]) + } +} + +impl ToSlice for i32 { + fn to_slice(&self) -> Cow<'_, [u8]> { + let b0 = *self as u8; + let b1 = (*self >> 8) as u8; + let b2 = (*self >> 16) as u8; + let b3 = (*self >> 24) as u8; + Cow::Owned(vec![b0, b1, b2, b3]) + } +} + +impl<'a> ToSlice for &'a str { + fn to_slice(&self) -> Cow<'_, [u8]> { + Cow::from(self.as_bytes()) + } +} + +#[macro_export(local_inner_macros)] +macro_rules! nft_expr { + (bitwise mask $mask:expr,xor $xor:expr) => { + nft_expr_bitwise!(mask $mask, xor $xor) + }; + (cmp $op:tt $data:expr) => { + nft_expr_cmp!($op $data) + }; + (counter) => { + $crate::expr::Counter { nb_bytes: 0, nb_packets: 0} + }; + (ct $key:ident set) => { + nft_expr_ct!($key set) + }; + (ct $key:ident) => { + nft_expr_ct!($key) + }; + (immediate $expr:ident $value:expr) => { + nft_expr_immediate!($expr $value) + }; + (log group $group:ident prefix $prefix:expr) => { + nft_expr_log!(group $group prefix $prefix) + }; + (log group $group:ident) => { + nft_expr_log!(group $group) + }; + (log prefix $prefix:expr) => { + nft_expr_log!(prefix $prefix) + }; + (log) => { + nft_expr_log!() + }; + (lookup $set:expr) => { + nft_expr_lookup!($set) + }; + (masquerade) => { + $crate::expr::Masquerade + }; + (meta $expr:ident set) => { + nft_expr_meta!($expr set) + }; + (meta $expr:ident) => { + nft_expr_meta!($expr) + }; + (payload $proto:ident $field:ident) => { + nft_expr_payload!($proto $field) + }; + (verdict $verdict:ident) => { + nft_expr_verdict!($verdict) + }; + (verdict $verdict:ident $chain:expr) => { + nft_expr_verdict!($verdict $chain) + }; +} diff --git a/src/expr/nat.rs b/src/expr/nat.rs new file mode 100644 index 0000000..8beaa30 --- /dev/null +++ b/src/expr/nat.rs @@ -0,0 +1,99 @@ +use super::{DeserializationError, Expression, Register, Rule}; +use crate::ProtoFamily; +use crate::sys::{self, libc}; +use std::{convert::TryFrom, os::raw::c_char}; + +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[repr(i32)] +pub enum NatType { + /// Source NAT. Changes the source address of a packet + SNat = libc::NFT_NAT_SNAT, + /// Destination NAT. Changeth the destination address of a packet + DNat = libc::NFT_NAT_DNAT, +} + +impl NatType { + fn from_raw(val: u32) -> Result<Self, DeserializationError> { + match val as i32 { + libc::NFT_NAT_SNAT => Ok(NatType::SNat), + libc::NFT_NAT_DNAT => Ok(NatType::DNat), + _ => Err(DeserializationError::InvalidValue), + } + } +} + +/// A source or destination NAT statement. Modifies the source or destination address +/// (and possibly port) of packets. +#[derive(Debug, PartialEq)] +pub struct Nat { + pub nat_type: NatType, + pub family: ProtoFamily, + pub ip_register: Register, + pub port_register: Option<Register>, +} + +impl Expression for Nat { + fn get_raw_name() -> *const libc::c_char { + b"nat\0" as *const _ as *const c_char + } + + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> + where + Self: Sized, + { + unsafe { + let nat_type = NatType::from_raw(sys::nftnl_expr_get_u32( + expr, + sys::NFTNL_EXPR_NAT_TYPE as u16, + ))?; + + let family = ProtoFamily::try_from(sys::nftnl_expr_get_u32( + expr, + sys::NFTNL_EXPR_NAT_FAMILY as u16, + ) as i32)?; + + let ip_register = Register::from_raw(sys::nftnl_expr_get_u32( + expr, + sys::NFTNL_EXPR_NAT_REG_ADDR_MIN as u16, + ))?; + + let mut port_register = None; + if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_NAT_REG_PROTO_MIN as u16) { + port_register = Some(Register::from_raw(sys::nftnl_expr_get_u32( + expr, + sys::NFTNL_EXPR_NAT_REG_PROTO_MIN as u16, + ))?); + } + + Ok(Nat { + ip_register, + nat_type, + family, + port_register, + }) + } + } + + fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { + let expr = try_alloc!(unsafe { sys::nftnl_expr_alloc(Self::get_raw_name()) }); + + unsafe { + sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_NAT_TYPE as u16, self.nat_type as u32); + sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_NAT_FAMILY as u16, self.family as u32); + sys::nftnl_expr_set_u32( + expr, + sys::NFTNL_EXPR_NAT_REG_ADDR_MIN as u16, + self.ip_register.to_raw(), + ); + if let Some(port_register) = self.port_register { + sys::nftnl_expr_set_u32( + expr, + sys::NFTNL_EXPR_NAT_REG_PROTO_MIN as u16, + port_register.to_raw(), + ); + } + } + + expr + } +} diff --git a/src/expr/payload.rs b/src/expr/payload.rs new file mode 100644 index 0000000..7612fd9 --- /dev/null +++ b/src/expr/payload.rs @@ -0,0 +1,531 @@ +use super::{DeserializationError, Expression, Rule}; +use crate::sys::{self, libc}; +use std::os::raw::c_char; + +pub trait HeaderField { + fn offset(&self) -> u32; + fn len(&self) -> u32; +} + +/// Payload expressions refer to data from the packet's payload. +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum Payload { + LinkLayer(LLHeaderField), + Network(NetworkHeaderField), + Transport(TransportHeaderField), +} + +impl Payload { + pub fn build(&self) -> RawPayload { + match *self { + Payload::LinkLayer(ref f) => RawPayload::LinkLayer(RawPayloadData { + offset: f.offset(), + len: f.len(), + }), + Payload::Network(ref f) => RawPayload::Network(RawPayloadData { + offset: f.offset(), + len: f.len(), + }), + Payload::Transport(ref f) => RawPayload::Transport(RawPayloadData { + offset: f.offset(), + len: f.offset(), + }), + } + } +} + +impl Expression for Payload { + fn get_raw_name() -> *const libc::c_char { + RawPayload::get_raw_name() + } + + fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr { + self.build().to_expr(rule) + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct RawPayloadData { + offset: u32, + len: u32, +} + +/// Because deserializing a `Payload` expression is not possible (there is not enough information +/// in the expression itself, this enum should be used to deserialize payloads. +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum RawPayload { + LinkLayer(RawPayloadData), + Network(RawPayloadData), + Transport(RawPayloadData), +} + +impl RawPayload { + fn base(&self) -> u32 { + match self { + Self::LinkLayer(_) => libc::NFT_PAYLOAD_LL_HEADER as u32, + Self::Network(_) => libc::NFT_PAYLOAD_NETWORK_HEADER as u32, + Self::Transport(_) => libc::NFT_PAYLOAD_TRANSPORT_HEADER as u32, + } + } +} + +impl HeaderField for RawPayload { + fn offset(&self) -> u32 { + match self { + Self::LinkLayer(ref f) | Self::Network(ref f) | Self::Transport(ref f) => f.offset, + } + } + + fn len(&self) -> u32 { + match self { + Self::LinkLayer(ref f) | Self::Network(ref f) | Self::Transport(ref f) => f.len, + } + } +} + +impl Expression for RawPayload { + fn get_raw_name() -> *const libc::c_char { + b"payload\0" as *const _ as *const c_char + } + + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { + unsafe { + let base = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_BASE as u16); + let offset = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_OFFSET as u16); + let len = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_LEN as u16); + match base as i32 { + libc::NFT_PAYLOAD_LL_HEADER => Ok(Self::LinkLayer(RawPayloadData { offset, len })), + libc::NFT_PAYLOAD_NETWORK_HEADER => { + Ok(Self::Network(RawPayloadData { offset, len })) + } + libc::NFT_PAYLOAD_TRANSPORT_HEADER => { + Ok(Self::Transport(RawPayloadData { offset, len })) + } + + _ => return Err(DeserializationError::InvalidValue), + } + } + } + + fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { + unsafe { + let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); + + sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_PAYLOAD_BASE as u16, self.base()); + sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_PAYLOAD_OFFSET as u16, self.offset()); + sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_PAYLOAD_LEN as u16, self.len()); + sys::nftnl_expr_set_u32( + expr, + sys::NFTNL_EXPR_PAYLOAD_DREG as u16, + libc::NFT_REG_1 as u32, + ); + + expr + } + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[non_exhaustive] +pub enum LLHeaderField { + Daddr, + Saddr, + EtherType, +} + +impl HeaderField for LLHeaderField { + fn offset(&self) -> u32 { + use self::LLHeaderField::*; + match *self { + Daddr => 0, + Saddr => 6, + EtherType => 12, + } + } + + fn len(&self) -> u32 { + use self::LLHeaderField::*; + match *self { + Daddr => 6, + Saddr => 6, + EtherType => 2, + } + } +} + +impl LLHeaderField { + pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { + let off = data.offset; + let len = data.len; + + if off == 0 && len == 6 { + Ok(Self::Daddr) + } else if off == 6 && len == 6 { + Ok(Self::Saddr) + } else if off == 12 && len == 2 { + Ok(Self::EtherType) + } else { + Err(DeserializationError::InvalidValue) + } + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum NetworkHeaderField { + Ipv4(Ipv4HeaderField), + Ipv6(Ipv6HeaderField), +} + +impl HeaderField for NetworkHeaderField { + fn offset(&self) -> u32 { + use self::NetworkHeaderField::*; + match *self { + Ipv4(ref f) => f.offset(), + Ipv6(ref f) => f.offset(), + } + } + + fn len(&self) -> u32 { + use self::NetworkHeaderField::*; + match *self { + Ipv4(ref f) => f.len(), + Ipv6(ref f) => f.len(), + } + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[non_exhaustive] +pub enum Ipv4HeaderField { + Ttl, + Protocol, + Saddr, + Daddr, +} + +impl HeaderField for Ipv4HeaderField { + fn offset(&self) -> u32 { + use self::Ipv4HeaderField::*; + match *self { + Ttl => 8, + Protocol => 9, + Saddr => 12, + Daddr => 16, + } + } + + fn len(&self) -> u32 { + use self::Ipv4HeaderField::*; + match *self { + Ttl => 1, + Protocol => 1, + Saddr => 4, + Daddr => 4, + } + } +} + +impl Ipv4HeaderField { + pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { + let off = data.offset; + let len = data.len; + + if off == 8 && len == 1 { + Ok(Self::Ttl) + } else if off == 9 && len == 1 { + Ok(Self::Protocol) + } else if off == 12 && len == 4 { + Ok(Self::Saddr) + } else if off == 16 && len == 4 { + Ok(Self::Daddr) + } else { + Err(DeserializationError::InvalidValue) + } + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[non_exhaustive] +pub enum Ipv6HeaderField { + NextHeader, + HopLimit, + Saddr, + Daddr, +} + +impl HeaderField for Ipv6HeaderField { + fn offset(&self) -> u32 { + use self::Ipv6HeaderField::*; + match *self { + NextHeader => 6, + HopLimit => 7, + Saddr => 8, + Daddr => 24, + } + } + + fn len(&self) -> u32 { + use self::Ipv6HeaderField::*; + match *self { + NextHeader => 1, + HopLimit => 1, + Saddr => 16, + Daddr => 16, + } + } +} + +impl Ipv6HeaderField { + pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { + let off = data.offset; + let len = data.len; + + if off == 6 && len == 1 { + Ok(Self::NextHeader) + } else if off == 7 && len == 1 { + Ok(Self::HopLimit) + } else if off == 8 && len == 16 { + Ok(Self::Saddr) + } else if off == 24 && len == 16 { + Ok(Self::Daddr) + } else { + Err(DeserializationError::InvalidValue) + } + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[non_exhaustive] +pub enum TransportHeaderField { + Tcp(TcpHeaderField), + Udp(UdpHeaderField), + Icmpv6(Icmpv6HeaderField), +} + +impl HeaderField for TransportHeaderField { + fn offset(&self) -> u32 { + use self::TransportHeaderField::*; + match *self { + Tcp(ref f) => f.offset(), + Udp(ref f) => f.offset(), + Icmpv6(ref f) => f.offset(), + } + } + + fn len(&self) -> u32 { + use self::TransportHeaderField::*; + match *self { + Tcp(ref f) => f.len(), + Udp(ref f) => f.len(), + Icmpv6(ref f) => f.len(), + } + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[non_exhaustive] +pub enum TcpHeaderField { + Sport, + Dport, +} + +impl HeaderField for TcpHeaderField { + fn offset(&self) -> u32 { + use self::TcpHeaderField::*; + match *self { + Sport => 0, + Dport => 2, + } + } + + fn len(&self) -> u32 { + use self::TcpHeaderField::*; + match *self { + Sport => 2, + Dport => 2, + } + } +} + +impl TcpHeaderField { + pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { + let off = data.offset; + let len = data.len; + + if off == 0 && len == 2 { + Ok(Self::Sport) + } else if off == 2 && len == 2 { + Ok(Self::Dport) + } else { + Err(DeserializationError::InvalidValue) + } + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[non_exhaustive] +pub enum UdpHeaderField { + Sport, + Dport, + Len, +} + +impl HeaderField for UdpHeaderField { + fn offset(&self) -> u32 { + use self::UdpHeaderField::*; + match *self { + Sport => 0, + Dport => 2, + Len => 4, + } + } + + fn len(&self) -> u32 { + use self::UdpHeaderField::*; + match *self { + Sport => 2, + Dport => 2, + Len => 2, + } + } +} + +impl UdpHeaderField { + pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { + let off = data.offset; + let len = data.len; + + if off == 0 && len == 2 { + Ok(Self::Sport) + } else if off == 2 && len == 2 { + Ok(Self::Dport) + } else if off == 4 && len == 2 { + Ok(Self::Len) + } else { + Err(DeserializationError::InvalidValue) + } + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[non_exhaustive] +pub enum Icmpv6HeaderField { + Type, + Code, + Checksum, +} + +impl HeaderField for Icmpv6HeaderField { + fn offset(&self) -> u32 { + use self::Icmpv6HeaderField::*; + match *self { + Type => 0, + Code => 1, + Checksum => 2, + } + } + + fn len(&self) -> u32 { + use self::Icmpv6HeaderField::*; + match *self { + Type => 1, + Code => 1, + Checksum => 2, + } + } +} + +impl Icmpv6HeaderField { + pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { + let off = data.offset; + let len = data.len; + + if off == 0 && len == 1 { + Ok(Self::Type) + } else if off == 1 && len == 1 { + Ok(Self::Code) + } else if off == 2 && len == 2 { + Ok(Self::Checksum) + } else { + Err(DeserializationError::InvalidValue) + } + } +} + +#[macro_export(local_inner_macros)] +macro_rules! nft_expr_payload { + (@ipv4_field ttl) => { + $crate::expr::Ipv4HeaderField::Ttl + }; + (@ipv4_field protocol) => { + $crate::expr::Ipv4HeaderField::Protocol + }; + (@ipv4_field saddr) => { + $crate::expr::Ipv4HeaderField::Saddr + }; + (@ipv4_field daddr) => { + $crate::expr::Ipv4HeaderField::Daddr + }; + + (@ipv6_field nextheader) => { + $crate::expr::Ipv6HeaderField::NextHeader + }; + (@ipv6_field hoplimit) => { + $crate::expr::Ipv6HeaderField::HopLimit + }; + (@ipv6_field saddr) => { + $crate::expr::Ipv6HeaderField::Saddr + }; + (@ipv6_field daddr) => { + $crate::expr::Ipv6HeaderField::Daddr + }; + + (@tcp_field sport) => { + $crate::expr::TcpHeaderField::Sport + }; + (@tcp_field dport) => { + $crate::expr::TcpHeaderField::Dport + }; + + (@udp_field sport) => { + $crate::expr::UdpHeaderField::Sport + }; + (@udp_field dport) => { + $crate::expr::UdpHeaderField::Dport + }; + (@udp_field len) => { + $crate::expr::UdpHeaderField::Len + }; + + (ethernet daddr) => { + $crate::expr::Payload::LinkLayer($crate::expr::LLHeaderField::Daddr) + }; + (ethernet saddr) => { + $crate::expr::Payload::LinkLayer($crate::expr::LLHeaderField::Saddr) + }; + (ethernet ethertype) => { + $crate::expr::Payload::LinkLayer($crate::expr::LLHeaderField::EtherType) + }; + + (ipv4 $field:ident) => { + $crate::expr::Payload::Network($crate::expr::NetworkHeaderField::Ipv4( + nft_expr_payload!(@ipv4_field $field), + )) + }; + (ipv6 $field:ident) => { + $crate::expr::Payload::Network($crate::expr::NetworkHeaderField::Ipv6( + nft_expr_payload!(@ipv6_field $field), + )) + }; + + (tcp $field:ident) => { + $crate::expr::Payload::Transport($crate::expr::TransportHeaderField::Tcp( + nft_expr_payload!(@tcp_field $field), + )) + }; + (udp $field:ident) => { + $crate::expr::Payload::Transport($crate::expr::TransportHeaderField::Udp( + nft_expr_payload!(@udp_field $field), + )) + }; +} diff --git a/src/expr/register.rs b/src/expr/register.rs new file mode 100644 index 0000000..f0aed94 --- /dev/null +++ b/src/expr/register.rs @@ -0,0 +1,34 @@ +use std::fmt::Debug; + +use crate::sys::libc; + +use super::DeserializationError; + +/// A netfilter data register. The expressions store and read data to and from these +/// when evaluating rule statements. +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[repr(i32)] +pub enum Register { + Verdict = libc::NFT_REG_VERDICT, + Reg1 = libc::NFT_REG_1, + Reg2 = libc::NFT_REG_2, + Reg3 = libc::NFT_REG_3, + Reg4 = libc::NFT_REG_4, +} + +impl Register { + pub fn to_raw(self) -> u32 { + self as u32 + } + + pub fn from_raw(val: u32) -> Result<Self, DeserializationError> { + match val as i32 { + libc::NFT_REG_VERDICT => Ok(Self::Verdict), + libc::NFT_REG_1 => Ok(Self::Reg1), + libc::NFT_REG_2 => Ok(Self::Reg2), + libc::NFT_REG_3 => Ok(Self::Reg3), + libc::NFT_REG_4 => Ok(Self::Reg4), + _ => Err(DeserializationError::InvalidValue), + } + } +} diff --git a/src/expr/reject.rs b/src/expr/reject.rs new file mode 100644 index 0000000..2ea0cbf --- /dev/null +++ b/src/expr/reject.rs @@ -0,0 +1,96 @@ +use super::{DeserializationError, Expression, Rule}; +use crate::ProtoFamily; +use crate::sys::{self, libc::{self, c_char}}; + +/// A reject expression that defines the type of rejection message sent +/// when discarding a packet. +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] +pub enum Reject { + /// Return an ICMP unreachable packet + Icmp(IcmpCode), + /// Reject by sending a TCP RST packet + TcpRst, +} + +impl Reject { + fn to_raw(&self, family: ProtoFamily) -> u32 { + use libc::*; + let value = match *self { + Self::Icmp(..) => match family { + ProtoFamily::Bridge | ProtoFamily::Inet => NFT_REJECT_ICMPX_UNREACH, + _ => NFT_REJECT_ICMP_UNREACH, + }, + Self::TcpRst => NFT_REJECT_TCP_RST, + }; + value as u32 + } +} + +impl Expression for Reject { + fn get_raw_name() -> *const libc::c_char { + b"reject\0" as *const _ as *const c_char + } + + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> + where + Self: Sized, + { + unsafe { + if sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_REJECT_TYPE as u16) + == libc::NFT_REJECT_TCP_RST as u32 + { + Ok(Self::TcpRst) + } else { + Ok(Self::Icmp(IcmpCode::from_raw(sys::nftnl_expr_get_u8( + expr, + sys::NFTNL_EXPR_REJECT_CODE as u16, + ))?)) + } + } + } + + fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr { + let family = rule.get_chain().get_table().get_family(); + + unsafe { + let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); + + sys::nftnl_expr_set_u32( + expr, + sys::NFTNL_EXPR_REJECT_TYPE as u16, + self.to_raw(family), + ); + + let reject_code = match *self { + Reject::Icmp(code) => code as u8, + Reject::TcpRst => 0, + }; + + sys::nftnl_expr_set_u8(expr, sys::NFTNL_EXPR_REJECT_CODE as u16, reject_code); + + expr + } + } +} + +/// An ICMP reject code. +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] +#[repr(u8)] +pub enum IcmpCode { + NoRoute = libc::NFT_REJECT_ICMPX_NO_ROUTE as u8, + PortUnreach = libc::NFT_REJECT_ICMPX_PORT_UNREACH as u8, + HostUnreach = libc::NFT_REJECT_ICMPX_HOST_UNREACH as u8, + AdminProhibited = libc::NFT_REJECT_ICMPX_ADMIN_PROHIBITED as u8, +} + +impl IcmpCode { + fn from_raw(code: u8) -> Result<Self, DeserializationError> { + match code as i32 { + libc::NFT_REJECT_ICMPX_NO_ROUTE => Ok(Self::NoRoute), + libc::NFT_REJECT_ICMPX_PORT_UNREACH => Ok(Self::PortUnreach), + libc::NFT_REJECT_ICMPX_HOST_UNREACH => Ok(Self::HostUnreach), + libc::NFT_REJECT_ICMPX_ADMIN_PROHIBITED => Ok(Self::AdminProhibited), + _ => Err(DeserializationError::InvalidValue), + } + } +} diff --git a/src/expr/verdict.rs b/src/expr/verdict.rs new file mode 100644 index 0000000..3c4c374 --- /dev/null +++ b/src/expr/verdict.rs @@ -0,0 +1,148 @@ +use super::{DeserializationError, Expression, Rule}; +use crate::sys::{self, libc::{self, c_char}}; +use std::ffi::{CStr, CString}; + +/// A verdict expression. In the background, this is usually an "Immediate" expression in nftnl +/// terms, but here it is simplified to only represent a verdict. +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub enum Verdict { + /// Silently drop the packet. + Drop, + /// Accept the packet and let it pass. + Accept, + Queue, + Continue, + Break, + Jump { + chain: CString, + }, + Goto { + chain: CString, + }, + Return, +} + +impl Verdict { + fn chain(&self) -> Option<&CStr> { + match *self { + Verdict::Jump { ref chain } => Some(chain.as_c_str()), + Verdict::Goto { ref chain } => Some(chain.as_c_str()), + _ => None, + } + } +} + +impl Expression for Verdict { + fn get_raw_name() -> *const libc::c_char { + b"immediate\0" as *const _ as *const c_char + } + + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { + unsafe { + let mut chain = None; + if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16) { + let raw_chain = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16); + + if raw_chain.is_null() { + return Err(DeserializationError::NullPointer); + } + chain = Some(CStr::from_ptr(raw_chain).to_owned()); + } + + let verdict = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_IMM_VERDICT as u16); + + match verdict as i32 { + libc::NF_DROP => Ok(Verdict::Drop), + libc::NF_ACCEPT => Ok(Verdict::Accept), + libc::NF_QUEUE => Ok(Verdict::Queue), + libc::NFT_CONTINUE => Ok(Verdict::Continue), + libc::NFT_BREAK => Ok(Verdict::Break), + libc::NFT_JUMP => { + if let Some(chain) = chain { + Ok(Verdict::Jump { chain }) + } else { + Err(DeserializationError::InvalidValue) + } + } + libc::NFT_GOTO => { + if let Some(chain) = chain { + Ok(Verdict::Goto { chain }) + } else { + Err(DeserializationError::InvalidValue) + } + } + libc::NFT_RETURN => Ok(Verdict::Return), + _ => Err(DeserializationError::InvalidValue), + } + } + } + + fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { + let immediate_const = match *self { + Verdict::Drop => libc::NF_DROP, + Verdict::Accept => libc::NF_ACCEPT, + Verdict::Queue => libc::NF_QUEUE, + Verdict::Continue => libc::NFT_CONTINUE, + Verdict::Break => libc::NFT_BREAK, + Verdict::Jump { .. } => libc::NFT_JUMP, + Verdict::Goto { .. } => libc::NFT_GOTO, + Verdict::Return => libc::NFT_RETURN, + }; + unsafe { + let expr = try_alloc!(sys::nftnl_expr_alloc( + b"immediate\0" as *const _ as *const c_char + )); + + sys::nftnl_expr_set_u32( + expr, + sys::NFTNL_EXPR_IMM_DREG as u16, + libc::NFT_REG_VERDICT as u32, + ); + + if let Some(chain) = self.chain() { + sys::nftnl_expr_set_str(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16, chain.as_ptr()); + } + sys::nftnl_expr_set_u32( + expr, + sys::NFTNL_EXPR_IMM_VERDICT as u16, + immediate_const as u32, + ); + + expr + } + } +} + +#[macro_export] +macro_rules! nft_expr_verdict { + (drop) => { + $crate::expr::Verdict::Drop + }; + (accept) => { + $crate::expr::Verdict::Accept + }; + (reject icmp $code:expr) => { + $crate::expr::Verdict::Reject(RejectionType::Icmp($code)) + }; + (reject tcp-rst) => { + $crate::expr::Verdict::Reject(RejectionType::TcpRst) + }; + (queue) => { + $crate::expr::Verdict::Queue + }; + (continue) => { + $crate::expr::Verdict::Continue + }; + (break) => { + $crate::expr::Verdict::Break + }; + (jump $chain:expr) => { + $crate::expr::Verdict::Jump { chain: $chain } + }; + (goto $chain:expr) => { + $crate::expr::Verdict::Goto { chain: $chain } + }; + (return) => { + $crate::expr::Verdict::Return + }; +} diff --git a/src/expr/wrapper.rs b/src/expr/wrapper.rs new file mode 100644 index 0000000..1bcc520 --- /dev/null +++ b/src/expr/wrapper.rs @@ -0,0 +1,60 @@ +use std::ffi::CStr; +use std::ffi::CString; +use std::fmt::Debug; +use std::rc::Rc; + +use super::{DeserializationError, Expression}; +use crate::{sys, Rule}; + +pub struct ExpressionWrapper { + pub(crate) expr: *const sys::nftnl_expr, + // we also need the rule here to ensure that the rule lives as long as the `expr` pointer + #[allow(dead_code)] + pub(crate) rule: Rc<Rule>, +} + +impl Debug for ExpressionWrapper { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.get_str()) + } +} + +impl ExpressionWrapper { + /// Retrieves a textual description of the expression. + pub fn get_str(&self) -> CString { + let mut descr_buf = vec![0i8; 4096]; + unsafe { + sys::nftnl_expr_snprintf( + descr_buf.as_mut_ptr(), + (descr_buf.len() - 1) as u64, + self.expr, + sys::NFTNL_OUTPUT_DEFAULT, + 0, + ); + CStr::from_ptr(descr_buf.as_ptr()).to_owned() + } + } + + /// Retrieves the type of expression ("log", "counter", ...). + pub fn get_kind(&self) -> Option<&CStr> { + unsafe { + let ptr = sys::nftnl_expr_get_str(self.expr, sys::NFTNL_EXPR_NAME as u16); + if !ptr.is_null() { + Some(CStr::from_ptr(ptr)) + } else { + None + } + } + } + + /// Attempt to decode the expression as the type T. + pub fn decode_expr<T: Expression>(&self) -> Result<T, DeserializationError> { + if let Some(kind) = self.get_kind() { + let raw_name = unsafe { CStr::from_ptr(T::get_raw_name()) }; + if kind == raw_name { + return T::from_expr(self.expr); + } + } + Err(DeserializationError::InvalidExpressionKind) + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..6eedf9f --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,205 @@ +// Copyryght (c) 2021 GPL lafleur@boum.org and Simon Thoby +// +// This file is free software: you may copy, redistribute and/or modify it +// under the terms of the GNU General Public License as published by the +// Free Software Foundation, either version 3 of the License, or (at your +// option) any later version. +// +// This file is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see the LICENSE file. +// +// This file incorporates work covered by the following copyright and +// permission notice: +// +// Copyright 2018 Amagicom AB. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Safe abstraction for [`libnftnl`]. Provides low-level userspace access to the in-kernel +//! nf_tables subsystem. See [`rustables-sys`] for the low level FFI bindings to the C library. +//! +//! Can be used to create and remove tables, chains, sets and rules from the nftables firewall, +//! the successor to iptables. +//! +//! This library currently has quite rough edges and does not make adding and removing netfilter +//! entries super easy and elegant. That is partly because the library needs more work, but also +//! partly because nftables is super low level and extremely customizable, making it hard, and +//! probably wrong, to try and create a too simple/limited wrapper. See examples for inspiration. +//! One can also look at how the original project this crate was developed to support uses it: +//! [Mullvad VPN app](https://github.com/mullvad/mullvadvpn-app) +//! +//! Understanding how to use [`libnftnl`] and implementing this crate has mostly been done by +//! reading the source code for the [`nftables`] program and attaching debuggers to the `nft` +//! binary. Since the implementation is mostly based on trial and error, there might of course be +//! a number of places where the underlying library is used in an invalid or not intended way. +//! Large portions of [`libnftnl`] are also not covered yet. Contributions are welcome! +//! +//! # Selecting version of `libnftnl` +//! +//! See the documentation for the corresponding sys crate for details: [`rustables-sys`]. +//! This crate has the same features as the sys crate, and selecting version works the same. +//! +//! # Access to raw handles +//! +//! Retrieving raw handles is considered unsafe and should only ever be enabled if you absoluetely +//! need it. It is disabled by default and hidden behind the feature gate `unsafe-raw-handles`. +//! The reason for that special treatment is we cannot guarantee the lack of aliasing. For example, +//! a program using a const handle to a object in a thread and writing through a mutable handle +//! in another could reach all kind of undefined (and dangerous!) behaviors. +//! By enabling that feature flag, you acknowledge that guaranteeing the respect of safety +//! invariants is now your responsibility! +//! Despite these shortcomings, that feature is still available because it may allow you to perform +//! manipulations that this library doesn't currently expose. If that is your case, we would +//! be very happy to hear from you and maybe help you get the necessary functionality upstream. +//! +//! Our current lack of confidence in our availability to provide a safe abstraction over the +//! use of raw handles in the face of concurrency is the reason we decided to settly on `Rc` +//! pointers instead of `Arc` (besides, this should gives us some nice performance boost, not +//! that it matters much of course) and why we do not declare the types exposes by the library +//! as `Send` nor `Sync`. +//! +//! [`libnftnl`]: https://netfilter.org/projects/libnftnl/ +//! [`nftables`]: https://netfilter.org/projects/nftables/ +//! [`rustables-sys`]: https://crates.io/crates/rustables-sys + +use thiserror::Error; + +#[macro_use] +extern crate log; + +pub mod sys; +use sys::libc; +use std::{convert::TryFrom, ffi::c_void, ops::Deref}; + +macro_rules! try_alloc { + ($e:expr) => {{ + let ptr = $e; + if ptr.is_null() { + // OOM, and the tried allocation was likely very small, + // so we are in a very tight situation. We do what libstd does, aborts. + std::process::abort(); + } + ptr + }}; +} + +mod batch; +#[cfg(feature = "query")] +pub use batch::{batch_is_supported, default_batch_page_size}; +pub use batch::{Batch, FinalizedBatch, NetlinkError}; + +pub mod expr; + +pub mod table; +pub use table::Table; +#[cfg(feature = "query")] +pub use table::{get_tables_cb, list_tables}; + +mod chain; +#[cfg(feature = "query")] +pub use chain::{get_chains_cb, list_chains_for_table}; +pub use chain::{Chain, ChainType, Hook, Policy, Priority}; + +pub mod query; + +mod rule; +pub use rule::Rule; +#[cfg(feature = "query")] +pub use rule::{get_rules_cb, list_rules_for_chain}; + +pub mod set; + +/// The type of the message as it's sent to netfilter. A message consists of an object, such as a +/// [`Table`], [`Chain`] or [`Rule`] for example, and a [`MsgType`] to describe what to do with +/// that object. If a [`Table`] object is sent with `MsgType::Add` then that table will be added +/// to netfilter, if sent with `MsgType::Del` it will be removed. +/// +/// [`Table`]: struct.Table.html +/// [`Chain`]: struct.Chain.html +/// [`Rule`]: struct.Rule.html +/// [`MsgType`]: enum.MsgType.html +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +pub enum MsgType { + /// Add the object to netfilter. + Add, + /// Remove the object from netfilter. + Del, +} + +/// Denotes a protocol. Used to specify which protocol a table or set belongs to. +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[repr(u16)] +pub enum ProtoFamily { + Unspec = libc::NFPROTO_UNSPEC as u16, + /// Inet - Means both IPv4 and IPv6 + Inet = libc::NFPROTO_INET as u16, + Ipv4 = libc::NFPROTO_IPV4 as u16, + Arp = libc::NFPROTO_ARP as u16, + NetDev = libc::NFPROTO_NETDEV as u16, + Bridge = libc::NFPROTO_BRIDGE as u16, + Ipv6 = libc::NFPROTO_IPV6 as u16, + DecNet = libc::NFPROTO_DECNET as u16, +} +#[derive(Error, Debug)] +#[error("Couldn't find a matching protocol")] +pub struct InvalidProtocolFamily; + +impl TryFrom<i32> for ProtoFamily { + type Error = InvalidProtocolFamily; + fn try_from(value: i32) -> Result<Self, Self::Error> { + match value { + libc::NFPROTO_UNSPEC => Ok(ProtoFamily::Unspec), + libc::NFPROTO_INET => Ok(ProtoFamily::Inet), + libc::NFPROTO_IPV4 => Ok(ProtoFamily::Ipv4), + libc::NFPROTO_ARP => Ok(ProtoFamily::Arp), + libc::NFPROTO_NETDEV => Ok(ProtoFamily::NetDev), + libc::NFPROTO_BRIDGE => Ok(ProtoFamily::Bridge), + libc::NFPROTO_IPV6 => Ok(ProtoFamily::Ipv6), + libc::NFPROTO_DECNET => Ok(ProtoFamily::DecNet), + _ => Err(InvalidProtocolFamily), + } + } +} + +/// Trait for all types in this crate that can serialize to a Netlink message. +/// +/// # Unsafe +/// +/// This trait is unsafe to implement because it must never serialize to anything larger than the +/// largest possible netlink message. Internally the `nft_nlmsg_maxsize()` function is used to make +/// sure the `buf` pointer passed to `write` always has room for the largest possible Netlink +/// message. +pub unsafe trait NlMsg { + /// Serializes the Netlink message to the buffer at `buf`. `buf` must have space for at least + /// `nft_nlmsg_maxsize()` bytes. This is not checked by the compiler, which is why this method + /// is unsafe. + unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType); +} + +unsafe impl<T, R> NlMsg for T +where + T: Deref<Target = R>, + R: NlMsg, +{ + unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { + self.deref().write(buf, seq, msg_type); + } +} + +/// The largest nf_tables netlink message is the set element message, which +/// contains the NFTA_SET_ELEM_LIST_ELEMENTS attribute. This attribute is +/// a nest that describes the set elements. Given that the netlink attribute +/// length (nla_len) is 16 bits, the largest message is a bit larger than +/// 64 KBytes. +pub fn nft_nlmsg_maxsize() -> u32 { + u32::from(::std::u16::MAX) + unsafe { libc::sysconf(libc::_SC_PAGESIZE) } as u32 +} diff --git a/src/query.rs b/src/query.rs new file mode 100644 index 0000000..02c4082 --- /dev/null +++ b/src/query.rs @@ -0,0 +1,130 @@ +use crate::{nft_nlmsg_maxsize, sys, ProtoFamily}; +use sys::libc; + +/// Returns a buffer containing a netlink message which requests a list of all the netfilter +/// matching objects (e.g. tables, chains, rules, ...). +/// Supply the type of objects to retrieve (e.g. libc::NFT_MSG_GETTABLE), and optionally +/// a callback to execute on the header, to set parameters for example. +/// To pass arbitrary data inside that callback, please use a closure. +pub fn get_list_of_objects<Error>( + seq: u32, + target: u16, + setup_cb: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>, +) -> Result<Vec<u8>, Error> { + let mut buffer = vec![0; nft_nlmsg_maxsize() as usize]; + let hdr = unsafe { + &mut *sys::nftnl_nlmsg_build_hdr( + buffer.as_mut_ptr() as *mut libc::c_char, + target, + ProtoFamily::Unspec as u16, + (libc::NLM_F_ROOT | libc::NLM_F_MATCH) as u16, + seq, + ) + }; + if let Some(cb) = setup_cb { + cb(hdr)?; + } + Ok(buffer) +} + +#[cfg(feature = "query")] +mod inner { + use crate::FinalizedBatch; + + use super::*; + + #[derive(thiserror::Error, Debug)] + pub enum Error { + #[error("Unable to open netlink socket to netfilter")] + NetlinkOpenError(#[source] std::io::Error), + + #[error("Unable to send netlink command to netfilter")] + NetlinkSendError(#[source] std::io::Error), + + #[error("Error while reading from netlink socket")] + NetlinkRecvError(#[source] std::io::Error), + + #[error("Error while processing an incoming netlink message")] + ProcessNetlinkError(#[source] std::io::Error), + + #[error("Custom error when customizing the query")] + InitError(#[from] Box<dyn std::error::Error + 'static>), + + #[error("Couldn't allocate a netlink object, out of memory ?")] + NetlinkAllocationFailed, + } + + /// List objects of a certain type (e.g. libc::NFT_MSG_GETTABLE) with the help of an helper + /// function called by mnl::cb_run2. + /// The callback expect a tuple of additional data (supplied as an argument to + /// this function) and of the output vector, to which it should append the parsed + /// object it received. + pub fn list_objects_with_data<'a, A, T>( + data_type: u16, + cb: fn(&libc::nlmsghdr, &mut (&'a A, &mut Vec<T>)) -> libc::c_int, + additional_data: &'a A, + req_hdr_customize: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>, + ) -> Result<Vec<T>, Error> + where + T: 'a, + { + debug!("listing objects of kind {}", data_type); + let socket = mnl::Socket::new(mnl::Bus::Netfilter).map_err(Error::NetlinkOpenError)?; + + let seq = 0; + let portid = 0; + + let chains_buf = get_list_of_objects(seq, data_type, req_hdr_customize)?; + socket.send(&chains_buf).map_err(Error::NetlinkSendError)?; + + let mut res = Vec::new(); + + let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize]; + while socket + .recv(&mut msg_buffer) + .map_err(Error::NetlinkRecvError)? + > 0 + { + if let mnl::CbResult::Stop = mnl::cb_run2( + &msg_buffer, + seq, + portid, + cb, + &mut (additional_data, &mut res), + ) + .map_err(Error::ProcessNetlinkError)? + { + break; + } + } + + Ok(res) + } + + pub fn send_batch(batch: &mut FinalizedBatch) -> Result<(), Error> { + let socket = mnl::Socket::new(mnl::Bus::Netfilter).map_err(Error::NetlinkOpenError)?; + + let seq = 0; + let portid = socket.portid(); + + socket.send_all(batch).map_err(Error::NetlinkSendError)?; + debug!("sent"); + + let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize]; + while socket + .recv(&mut msg_buffer) + .map_err(Error::NetlinkRecvError)? + > 0 + { + if let mnl::CbResult::Stop = + mnl::cb_run(&msg_buffer, seq, portid).map_err(Error::ProcessNetlinkError)? + { + break; + } + } + Ok(()) + } +} + +#[cfg(feature = "query")] +pub use inner::*; diff --git a/src/rule.rs b/src/rule.rs new file mode 100644 index 0000000..c8cb90d --- /dev/null +++ b/src/rule.rs @@ -0,0 +1,341 @@ +use crate::expr::ExpressionWrapper; +use crate::{chain::Chain, expr::Expression, MsgType}; +use crate::sys::{self, libc}; +use std::ffi::{c_void, CStr, CString}; +use std::fmt::Debug; +use std::os::raw::c_char; +use std::rc::Rc; + +/// A nftables firewall rule. +pub struct Rule { + pub(crate) rule: *mut sys::nftnl_rule, + pub(crate) chain: Rc<Chain>, +} + +impl Rule { + /// Creates a new rule object in the given [`Chain`]. + /// + /// [`Chain`]: struct.Chain.html + pub fn new(chain: Rc<Chain>) -> Rule { + unsafe { + let rule = try_alloc!(sys::nftnl_rule_alloc()); + sys::nftnl_rule_set_u32( + rule, + sys::NFTNL_RULE_FAMILY as u16, + chain.get_table().get_family() as u32, + ); + sys::nftnl_rule_set_str( + rule, + sys::NFTNL_RULE_TABLE as u16, + chain.get_table().get_name().as_ptr(), + ); + sys::nftnl_rule_set_str( + rule, + sys::NFTNL_RULE_CHAIN as u16, + chain.get_name().as_ptr(), + ); + + Rule { rule, chain } + } + } + + pub unsafe fn from_raw(rule: *mut sys::nftnl_rule, chain: Rc<Chain>) -> Self { + Rule { rule, chain } + } + + pub fn get_position(&self) -> u64 { + unsafe { sys::nftnl_rule_get_u64(self.rule, sys::NFTNL_RULE_POSITION as u16) } + } + + /// Sets the position of this rule within the chain it lives in. By default a new rule is added + /// to the end of the chain. + pub fn set_position(&mut self, position: u64) { + unsafe { + sys::nftnl_rule_set_u64(self.rule, sys::NFTNL_RULE_POSITION as u16, position); + } + } + + pub fn get_handle(&self) -> u64 { + unsafe { sys::nftnl_rule_get_u64(self.rule, sys::NFTNL_RULE_HANDLE as u16) } + } + + pub fn set_handle(&mut self, handle: u64) { + unsafe { + sys::nftnl_rule_set_u64(self.rule, sys::NFTNL_RULE_HANDLE as u16, handle); + } + } + + /// Adds an expression to this rule. Expressions are evaluated from first to last added. + /// As soon as an expression does not match the packet it's being evaluated for, evaluation + /// stops and the packet is evaluated against the next rule in the chain. + pub fn add_expr(&mut self, expr: &impl Expression) { + unsafe { sys::nftnl_rule_add_expr(self.rule, expr.to_expr(self)) } + } + + /// Returns a reference to the [`Chain`] this rule lives in. + /// + /// [`Chain`]: struct.Chain.html + pub fn get_chain(&self) -> Rc<Chain> { + self.chain.clone() + } + + /// Returns the userdata of this chain. + pub fn get_userdata(&self) -> Option<&CStr> { + unsafe { + let ptr = sys::nftnl_rule_get_str(self.rule, sys::NFTNL_RULE_USERDATA as u16); + if !ptr.is_null() { + Some(CStr::from_ptr(ptr)) + } else { + None + } + } + } + + /// Updates the userdata of this chain. + pub fn set_userdata(&self, data: &CStr) { + unsafe { + sys::nftnl_rule_set_str(self.rule, sys::NFTNL_RULE_USERDATA as u16, data.as_ptr()); + } + } + + /// Returns a textual description of the rule. + pub fn get_str(&self) -> CString { + let mut descr_buf = vec![0i8; 4096]; + unsafe { + sys::nftnl_rule_snprintf( + descr_buf.as_mut_ptr(), + (descr_buf.len() - 1) as u64, + self.rule, + sys::NFTNL_OUTPUT_DEFAULT, + 0, + ); + CStr::from_ptr(descr_buf.as_ptr()).to_owned() + } + } + + /// Retrieves an iterator to loop over the expressions of the rule + pub fn get_exprs(self: &Rc<Self>) -> RuleExprsIter { + RuleExprsIter::new(self.clone()) + } + + #[cfg(feature = "unsafe-raw-handles")] + /// Returns the raw handle. + pub fn as_ptr(&self) -> *const sys::nftnl_rule { + self.rule as *const sys::nftnl_rule + } + + #[cfg(feature = "unsafe-raw-handles")] + /// Returns a mutable version of the raw handle. + pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_rule { + self.rule + } + + /// Perform a deep comparizon of rules, by checking they have the same expressions inside. + /// This is not enabled by default in our PartialEq implementation because of the + /// difficulty to compare an expression generated by the library with the expressions returned + /// by the kernel when iterating over the currently in-use rules. The kernel-returned + /// expressions may have additional attributes despite being generated from the same rule. + /// This is particularly true for the 'nat' expression). + pub fn deep_eq(&self, other: &Self) -> bool { + if self != other { + return false; + } + + let self_exprs = + try_alloc!(unsafe { sys::nftnl_expr_iter_create(self.rule as *const sys::nftnl_rule) }); + let other_exprs = try_alloc!(unsafe { + sys::nftnl_expr_iter_create(other.rule as *const sys::nftnl_rule) + }); + + loop { + let self_next = unsafe { sys::nftnl_expr_iter_next(self_exprs) }; + let other_next = unsafe { sys::nftnl_expr_iter_next(other_exprs) }; + if self_next.is_null() && other_next.is_null() { + return true; + } else if self_next.is_null() || other_next.is_null() { + return false; + } + + // we are falling back on comparing the strings, because there is no easy mechanism to + // perform a memcmp() between the two expressions :/ + let mut self_str = [0; 256]; + let mut other_str = [0; 256]; + unsafe { + sys::nftnl_expr_snprintf( + self_str.as_mut_ptr(), + (self_str.len() - 1) as u64, + self_next, + sys::NFTNL_OUTPUT_DEFAULT, + 0, + ); + sys::nftnl_expr_snprintf( + other_str.as_mut_ptr(), + (other_str.len() - 1) as u64, + other_next, + sys::NFTNL_OUTPUT_DEFAULT, + 0, + ); + } + + if self_str != other_str { + return false; + } + } + } +} + +impl Debug for Rule { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.get_str()) + } +} + +impl PartialEq for Rule { + fn eq(&self, other: &Self) -> bool { + if self.get_chain() != other.get_chain() { + return false; + } + + unsafe { + if sys::nftnl_rule_is_set(self.rule, sys::NFTNL_RULE_HANDLE as u16) + && sys::nftnl_rule_is_set(other.rule, sys::NFTNL_RULE_HANDLE as u16) + { + if self.get_handle() != other.get_handle() { + return false; + } + } + if sys::nftnl_rule_is_set(self.rule, sys::NFTNL_RULE_POSITION as u16) + && sys::nftnl_rule_is_set(other.rule, sys::NFTNL_RULE_POSITION as u16) + { + if self.get_position() != other.get_position() { + return false; + } + } + } + + return false; + } +} + +unsafe impl crate::NlMsg for Rule { + unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { + let type_ = match msg_type { + MsgType::Add => libc::NFT_MSG_NEWRULE, + MsgType::Del => libc::NFT_MSG_DELRULE, + }; + let flags: u16 = match msg_type { + MsgType::Add => (libc::NLM_F_CREATE | libc::NLM_F_APPEND | libc::NLM_F_EXCL) as u16, + MsgType::Del => 0u16, + } | libc::NLM_F_ACK as u16; + let header = sys::nftnl_nlmsg_build_hdr( + buf as *mut c_char, + type_ as u16, + self.chain.get_table().get_family() as u16, + flags, + seq, + ); + sys::nftnl_rule_nlmsg_build_payload(header, self.rule); + } +} + +impl Drop for Rule { + fn drop(&mut self) { + unsafe { sys::nftnl_rule_free(self.rule) }; + } +} + +pub struct RuleExprsIter { + rule: Rc<Rule>, + iter: *mut sys::nftnl_expr_iter, +} + +impl RuleExprsIter { + fn new(rule: Rc<Rule>) -> Self { + let iter = + try_alloc!(unsafe { sys::nftnl_expr_iter_create(rule.rule as *const sys::nftnl_rule) }); + RuleExprsIter { rule, iter } + } +} + +impl Iterator for RuleExprsIter { + type Item = ExpressionWrapper; + + fn next(&mut self) -> Option<Self::Item> { + let next = unsafe { sys::nftnl_expr_iter_next(self.iter) }; + if next.is_null() { + trace!("RulesExprsIter iterator ending"); + None + } else { + trace!("RulesExprsIter returning new expression"); + Some(ExpressionWrapper { + expr: next, + rule: self.rule.clone(), + }) + } + } +} + +impl Drop for RuleExprsIter { + fn drop(&mut self) { + unsafe { sys::nftnl_expr_iter_destroy(self.iter) }; + } +} + +#[cfg(feature = "query")] +pub fn get_rules_cb( + header: &libc::nlmsghdr, + (chain, rules): &mut (&Rc<Chain>, &mut Vec<Rule>), +) -> libc::c_int { + unsafe { + let rule = sys::nftnl_rule_alloc(); + if rule == std::ptr::null_mut() { + return mnl::mnl_sys::MNL_CB_ERROR; + } + let err = sys::nftnl_rule_nlmsg_parse(header, rule); + if err < 0 { + error!("Failed to parse nelink rule message - {}", err); + sys::nftnl_rule_free(rule); + return err; + } + + rules.push(Rule::from_raw(rule, chain.clone())); + } + mnl::mnl_sys::MNL_CB_OK +} + +#[cfg(feature = "query")] +pub fn list_rules_for_chain(chain: &Rc<Chain>) -> Result<Vec<Rule>, crate::query::Error> { + crate::query::list_objects_with_data( + libc::NFT_MSG_GETRULE as u16, + get_rules_cb, + &chain, + // only retrieve rules from the currently targetted chain + Some(&|hdr| unsafe { + let rule = sys::nftnl_rule_alloc(); + if rule as *const _ == std::ptr::null() { + return Err(crate::query::Error::NetlinkAllocationFailed); + } + + sys::nftnl_rule_set_str( + rule, + sys::NFTNL_RULE_TABLE as u16, + chain.get_table().get_name().as_ptr(), + ); + sys::nftnl_rule_set_u32( + rule, + sys::NFTNL_RULE_FAMILY as u16, + chain.get_table().get_family() as u32, + ); + sys::nftnl_rule_set_str( + rule, + sys::NFTNL_RULE_CHAIN as u16, + chain.get_name().as_ptr(), + ); + + sys::nftnl_rule_nlmsg_build_payload(hdr, rule); + + sys::nftnl_rule_free(rule); + Ok(()) + }), + ) +} diff --git a/src/set.rs b/src/set.rs new file mode 100644 index 0000000..d6b9514 --- /dev/null +++ b/src/set.rs @@ -0,0 +1,273 @@ +use crate::{table::Table, MsgType, ProtoFamily}; +use crate::sys::{self, libc}; +use std::{ + cell::Cell, + ffi::{c_void, CStr, CString}, + fmt::Debug, + net::{Ipv4Addr, Ipv6Addr}, + os::raw::c_char, + rc::Rc, +}; + +#[macro_export] +macro_rules! nft_set { + ($name:expr, $id:expr, $table:expr, $family:expr) => { + $crate::set::Set::new($name, $id, $table, $family) + }; + ($name:expr, $id:expr, $table:expr, $family:expr; [ ]) => { + nft_set!($name, $id, $table, $family) + }; + ($name:expr, $id:expr, $table:expr, $family:expr; [ $($value:expr,)* ]) => {{ + let mut set = nft_set!($name, $id, $table, $family).expect("Set allocation failed"); + $( + set.add($value).expect(stringify!(Unable to add $value to set $name)); + )* + set + }}; +} + +pub struct Set<'a, K> { + pub(crate) set: *mut sys::nftnl_set, + pub(crate) table: &'a Table, + pub(crate) family: ProtoFamily, + _marker: ::std::marker::PhantomData<K>, +} + +impl<'a, K> Set<'a, K> { + pub fn new(name: &CStr, id: u32, table: &'a Table, family: ProtoFamily) -> Self + where + K: SetKey, + { + unsafe { + let set = try_alloc!(sys::nftnl_set_alloc()); + + sys::nftnl_set_set_u32(set, sys::NFTNL_SET_FAMILY as u16, family as u32); + sys::nftnl_set_set_str(set, sys::NFTNL_SET_TABLE as u16, table.get_name().as_ptr()); + sys::nftnl_set_set_str(set, sys::NFTNL_SET_NAME as u16, name.as_ptr()); + sys::nftnl_set_set_u32(set, sys::NFTNL_SET_ID as u16, id); + + sys::nftnl_set_set_u32( + set, + sys::NFTNL_SET_FLAGS as u16, + (libc::NFT_SET_ANONYMOUS | libc::NFT_SET_CONSTANT) as u32, + ); + sys::nftnl_set_set_u32(set, sys::NFTNL_SET_KEY_TYPE as u16, K::TYPE); + sys::nftnl_set_set_u32(set, sys::NFTNL_SET_KEY_LEN as u16, K::LEN); + + Set { + set, + table, + family, + _marker: ::std::marker::PhantomData, + } + } + } + + pub unsafe fn from_raw(set: *mut sys::nftnl_set, table: &'a Table, family: ProtoFamily) -> Self + where + K: SetKey, + { + Set { + set, + table, + family, + _marker: ::std::marker::PhantomData, + } + } + + pub fn add(&mut self, key: &K) + where + K: SetKey, + { + unsafe { + let elem = try_alloc!(sys::nftnl_set_elem_alloc()); + + let data = key.data(); + let data_len = data.len() as u32; + trace!("Adding key {:?} with len {}", data, data_len); + sys::nftnl_set_elem_set( + elem, + sys::NFTNL_SET_ELEM_KEY as u16, + data.as_ref() as *const _ as *const c_void, + data_len, + ); + sys::nftnl_set_elem_add(self.set, elem); + } + } + + pub fn elems_iter(&'a self) -> SetElemsIter<'a, K> { + SetElemsIter::new(self) + } + + #[cfg(feature = "unsafe-raw-handles")] + /// Returns the raw handle. + pub fn as_ptr(&self) -> *const sys::nftnl_set { + self.set as *const sys::nftnl_set + } + + #[cfg(feature = "unsafe-raw-handles")] + /// Returns a mutable version of the raw handle. + pub fn as_mut_ptr(&self) -> *mut sys::nftnl_set { + self.set + } + + pub fn get_family(&self) -> ProtoFamily { + self.family + } + + /// Returns a textual description of the set. + pub fn get_str(&self) -> CString { + let mut descr_buf = vec![0i8; 4096]; + unsafe { + sys::nftnl_set_snprintf( + descr_buf.as_mut_ptr(), + (descr_buf.len() - 1) as u64, + self.set, + sys::NFTNL_OUTPUT_DEFAULT, + 0, + ); + CStr::from_ptr(descr_buf.as_ptr()).to_owned() + } + } + + pub fn get_name(&self) -> Option<&CStr> { + unsafe { + let ptr = sys::nftnl_set_get_str(self.set, sys::NFTNL_SET_NAME as u16); + if !ptr.is_null() { + Some(CStr::from_ptr(ptr)) + } else { + None + } + } + } + + pub fn get_id(&self) -> u32 { + unsafe { sys::nftnl_set_get_u32(self.set, sys::NFTNL_SET_ID as u16) } + } +} + +impl<'a, K> Debug for Set<'a, K> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.get_str()) + } +} + +unsafe impl<'a, K> crate::NlMsg for Set<'a, K> { + unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { + let type_ = match msg_type { + MsgType::Add => libc::NFT_MSG_NEWSET, + MsgType::Del => libc::NFT_MSG_DELSET, + }; + let header = sys::nftnl_nlmsg_build_hdr( + buf as *mut c_char, + type_ as u16, + self.table.get_family() as u16, + (libc::NLM_F_APPEND | libc::NLM_F_CREATE | libc::NLM_F_ACK) as u16, + seq, + ); + sys::nftnl_set_nlmsg_build_payload(header, self.set); + } +} + +impl<'a, K> Drop for Set<'a, K> { + fn drop(&mut self) { + unsafe { sys::nftnl_set_free(self.set) }; + } +} + +pub struct SetElemsIter<'a, K> { + set: &'a Set<'a, K>, + iter: *mut sys::nftnl_set_elems_iter, + ret: Rc<Cell<i32>>, +} + +impl<'a, K> SetElemsIter<'a, K> { + fn new(set: &'a Set<'a, K>) -> Self { + let iter = try_alloc!(unsafe { + sys::nftnl_set_elems_iter_create(set.set as *const sys::nftnl_set) + }); + SetElemsIter { + set, + iter, + ret: Rc::new(Cell::new(1)), + } + } +} + +impl<'a, K: 'a> Iterator for SetElemsIter<'a, K> { + type Item = SetElemsMsg<'a, K>; + + fn next(&mut self) -> Option<Self::Item> { + if self.ret.get() <= 0 || unsafe { sys::nftnl_set_elems_iter_cur(self.iter).is_null() } { + trace!("SetElemsIter iterator ending"); + None + } else { + trace!("SetElemsIter returning new SetElemsMsg"); + Some(SetElemsMsg { + set: self.set, + iter: self.iter, + ret: self.ret.clone(), + }) + } + } +} + +impl<'a, K> Drop for SetElemsIter<'a, K> { + fn drop(&mut self) { + unsafe { sys::nftnl_set_elems_iter_destroy(self.iter) }; + } +} + +pub struct SetElemsMsg<'a, K> { + set: &'a Set<'a, K>, + iter: *mut sys::nftnl_set_elems_iter, + ret: Rc<Cell<i32>>, +} + +unsafe impl<'a, K> crate::NlMsg for SetElemsMsg<'a, K> { + unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { + trace!("Writing SetElemsMsg to NlMsg"); + let (type_, flags) = match msg_type { + MsgType::Add => ( + libc::NFT_MSG_NEWSETELEM, + libc::NLM_F_CREATE | libc::NLM_F_EXCL | libc::NLM_F_ACK, + ), + MsgType::Del => (libc::NFT_MSG_DELSETELEM, libc::NLM_F_ACK), + }; + let header = sys::nftnl_nlmsg_build_hdr( + buf as *mut c_char, + type_ as u16, + self.set.get_family() as u16, + flags as u16, + seq, + ); + self.ret.set(sys::nftnl_set_elems_nlmsg_build_payload_iter( + header, self.iter, + )); + } +} + +pub trait SetKey { + const TYPE: u32; + const LEN: u32; + + fn data(&self) -> Box<[u8]>; +} + +impl SetKey for Ipv4Addr { + const TYPE: u32 = 7; + const LEN: u32 = 4; + + fn data(&self) -> Box<[u8]> { + self.octets().to_vec().into_boxed_slice() + } +} + +impl SetKey for Ipv6Addr { + const TYPE: u32 = 8; + const LEN: u32 = 16; + + fn data(&self) -> Box<[u8]> { + self.octets().to_vec().into_boxed_slice() + } +} diff --git a/src/table.rs b/src/table.rs new file mode 100644 index 0000000..2f21453 --- /dev/null +++ b/src/table.rs @@ -0,0 +1,171 @@ +use crate::{MsgType, ProtoFamily}; +use crate::sys::{self, libc}; +#[cfg(feature = "query")] +use std::convert::TryFrom; +use std::{ + ffi::{c_void, CStr, CString}, + fmt::Debug, + os::raw::c_char, +}; + +/// Abstraction of `nftnl_table`. The top level container in netfilter. A table has a protocol +/// family and contain [`Chain`]s that in turn hold the rules. +/// +/// [`Chain`]: struct.Chain.html +pub struct Table { + table: *mut sys::nftnl_table, + family: ProtoFamily, +} + +impl Table { + /// Creates a new table instance with the given name and protocol family. + pub fn new<T: AsRef<CStr>>(name: &T, family: ProtoFamily) -> Table { + unsafe { + let table = try_alloc!(sys::nftnl_table_alloc()); + + sys::nftnl_table_set_u32(table, sys::NFTNL_TABLE_FAMILY as u16, family as u32); + sys::nftnl_table_set_str(table, sys::NFTNL_TABLE_NAME as u16, name.as_ref().as_ptr()); + sys::nftnl_table_set_u32(table, sys::NFTNL_TABLE_FLAGS as u16, 0u32); + Table { table, family } + } + } + + pub unsafe fn from_raw(table: *mut sys::nftnl_table, family: ProtoFamily) -> Self { + Table { table, family } + } + + /// Returns the name of this table. + pub fn get_name(&self) -> &CStr { + unsafe { + let ptr = sys::nftnl_table_get_str(self.table, sys::NFTNL_TABLE_NAME as u16); + if ptr.is_null() { + panic!("Impossible situation: retrieving the name of a chain failed") + } else { + CStr::from_ptr(ptr) + } + } + } + + /// Returns a textual description of the table. + pub fn get_str(&self) -> CString { + let mut descr_buf = vec![0i8; 4096]; + unsafe { + sys::nftnl_table_snprintf( + descr_buf.as_mut_ptr(), + (descr_buf.len() - 1) as u64, + self.table, + sys::NFTNL_OUTPUT_DEFAULT, + 0, + ); + CStr::from_ptr(descr_buf.as_ptr()).to_owned() + } + } + + /// Returns the protocol family for this table. + pub fn get_family(&self) -> ProtoFamily { + self.family + } + + /// Returns the userdata of this chain. + pub fn get_userdata(&self) -> Option<&CStr> { + unsafe { + let ptr = sys::nftnl_table_get_str(self.table, sys::NFTNL_TABLE_USERDATA as u16); + if !ptr.is_null() { + Some(CStr::from_ptr(ptr)) + } else { + None + } + } + } + + /// Update the userdata of this chain. + pub fn set_userdata(&self, data: &CStr) { + unsafe { + sys::nftnl_table_set_str(self.table, sys::NFTNL_TABLE_USERDATA as u16, data.as_ptr()); + } + } + + #[cfg(feature = "unsafe-raw-handles")] + /// Returns the raw handle. + pub fn as_ptr(&self) -> *const sys::nftnl_table { + self.table as *const sys::nftnl_table + } + + #[cfg(feature = "unsafe-raw-handles")] + /// Returns a mutable version of the raw handle. + pub fn as_mut_ptr(&self) -> *mut sys::nftnl_table { + self.table + } +} + +impl PartialEq for Table { + fn eq(&self, other: &Self) -> bool { + self.get_name() == other.get_name() && self.get_family() == other.get_family() + } +} + +impl Debug for Table { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.get_str()) + } +} + +unsafe impl crate::NlMsg for Table { + unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { + let raw_msg_type = match msg_type { + MsgType::Add => libc::NFT_MSG_NEWTABLE, + MsgType::Del => libc::NFT_MSG_DELTABLE, + }; + let header = sys::nftnl_nlmsg_build_hdr( + buf as *mut c_char, + raw_msg_type as u16, + self.family as u16, + libc::NLM_F_ACK as u16, + seq, + ); + sys::nftnl_table_nlmsg_build_payload(header, self.table); + } +} + +impl Drop for Table { + fn drop(&mut self) { + unsafe { sys::nftnl_table_free(self.table) }; + } +} + +#[cfg(feature = "query")] +/// A callback to parse the response for messages created with `get_tables_nlmsg`. +pub fn get_tables_cb( + header: &libc::nlmsghdr, + (_, tables): &mut (&(), &mut Vec<Table>), +) -> libc::c_int { + unsafe { + let table = sys::nftnl_table_alloc(); + if table == std::ptr::null_mut() { + return mnl::mnl_sys::MNL_CB_ERROR; + } + let err = sys::nftnl_table_nlmsg_parse(header, table); + if err < 0 { + error!("Failed to parse nelink table message - {}", err); + sys::nftnl_table_free(table); + return err; + } + let family = sys::nftnl_table_get_u32(table, sys::NFTNL_TABLE_FAMILY as u16); + match crate::ProtoFamily::try_from(family as i32) { + Ok(family) => { + tables.push(Table::from_raw(table, family)); + mnl::mnl_sys::MNL_CB_OK + } + Err(crate::InvalidProtocolFamily) => { + error!("The netlink table didn't have a valid protocol family !?"); + sys::nftnl_table_free(table); + mnl::mnl_sys::MNL_CB_ERROR + } + } + } +} + +#[cfg(feature = "query")] +pub fn list_tables() -> Result<Vec<Table>, crate::query::Error> { + crate::query::list_objects_with_data(libc::NFT_MSG_GETTABLE as u16, get_tables_cb, &(), None) +} |