aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSimon THOBY <git@nightmared.fr>2023-01-09 18:54:11 +0000
committerSimon THOBY <git@nightmared.fr>2023-01-09 18:54:11 +0000
commitd5b9ec5185a27414286ee303eb3d21ce3069db09 (patch)
tree369eb90e8a2da307d7cd8f0b15a3318bbdba0003 /src
parent3e48e7efa516183d623f80d2e4e393cecc2acde9 (diff)
parentc3e3773cccd01f80f2d72a7691e0654d304e6b2d (diff)
Merge branch 'no_mnl' into 'master'
experimental support for a full-rust rewrite of the codebase (no libnftnl/libmnl anymore) See merge request rustwall/rustables!16
Diffstat (limited to 'src')
-rw-r--r--src/batch.rs237
-rw-r--r--src/chain.rs381
-rw-r--r--src/chain_methods.rs40
-rw-r--r--src/data_type.rs42
-rw-r--r--src/error.rs180
-rw-r--r--src/expr/bitwise.rs100
-rw-r--r--src/expr/cmp.rs204
-rw-r--r--src/expr/counter.rs43
-rw-r--r--src/expr/ct.rs108
-rw-r--r--src/expr/immediate.rs154
-rw-r--r--src/expr/log.rs127
-rw-r--r--src/expr/lookup.rs94
-rw-r--r--src/expr/masquerade.rs28
-rw-r--r--src/expr/meta.rs183
-rw-r--r--src/expr/mod.rs314
-rw-r--r--src/expr/nat.rs102
-rw-r--r--src/expr/payload.rs443
-rw-r--r--src/expr/register.rs33
-rw-r--r--src/expr/reject.rs109
-rw-r--r--src/expr/verdict.rs169
-rw-r--r--src/expr/wrapper.rs61
-rw-r--r--src/lib.rs179
-rw-r--r--src/nlmsg.rs182
-rw-r--r--src/parser.rs216
-rw-r--r--src/parser_impls.rs243
-rw-r--r--src/query.rs277
-rw-r--r--src/rule.rs392
-rw-r--r--src/rule_methods.rs355
-rw-r--r--src/set.rs337
-rw-r--r--src/sys.rs3
-rw-r--r--src/table.rs197
-rw-r--r--src/tests/batch.rs96
-rw-r--r--src/tests/chain.rs120
-rw-r--r--src/tests/expr.rs591
-rw-r--r--src/tests/mod.rs193
-rw-r--r--src/tests/rule.rs132
-rw-r--r--src/tests/set.rs119
-rw-r--r--src/tests/table.rs67
38 files changed, 3668 insertions, 3183 deletions
diff --git a/src/batch.rs b/src/batch.rs
index 198e8d0..b5c88b8 100644
--- a/src/batch.rs
+++ b/src/batch.rs
@@ -1,31 +1,29 @@
-use crate::{MsgType, NlMsg};
-use crate::sys::{self as sys, libc};
-use std::ffi::c_void;
-use std::os::raw::c_char;
-use std::ptr;
+use libc;
+
use thiserror::Error;
+use crate::error::QueryError;
+use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter};
+use crate::sys::NFNL_SUBSYS_NFTABLES;
+use crate::{MsgType, ProtocolFamily};
+
+use nix::sys::socket::{
+ self, AddressFamily, MsgFlags, NetlinkAddr, SockAddr, SockFlag, SockProtocol, SockType,
+};
+
/// Error while communicating with netlink.
#[derive(Error, Debug)]
#[error("Error while communicating with netlink")]
pub struct NetlinkError(());
-#[cfg(feature = "query")]
-/// Check if the kernel supports batched netlink messages to netfilter.
-pub fn batch_is_supported() -> std::result::Result<bool, NetlinkError> {
- match unsafe { sys::nftnl_batch_is_supported() } {
- 1 => Ok(true),
- 0 => Ok(false),
- _ => Err(NetlinkError(())),
- }
-}
-
-/// A batch of netfilter messages to be performed in one atomic operation. Corresponds to
-/// `nftnl_batch` in libnftnl.
+/// A batch of netfilter messages to be performed in one atomic operation.
pub struct Batch {
- pub(crate) batch: *mut sys::nftnl_batch,
- pub(crate) seq: u32,
- pub(crate) is_empty: bool,
+ buf: Box<Vec<u8>>,
+ // the 'static lifetime here is a cheat, as the writer can only be used as long
+ // as `self.buf` exists. This is why this member must never be exposed directly to
+ // the rest of the crate (let alone publicly).
+ writer: NfNetlinkWriter<'static>,
+ seq: u32,
}
impl Batch {
@@ -33,48 +31,40 @@ impl Batch {
///
/// [default page size]: fn.default_batch_page_size.html
pub fn new() -> Self {
- Self::with_page_size(default_batch_page_size())
- }
-
- pub unsafe fn from_raw(batch: *mut sys::nftnl_batch, seq: u32) -> Self {
- Batch {
- batch,
+ // TODO: use a pinned Box ?
+ let mut buf = Box::new(Vec::with_capacity(default_batch_page_size() as usize));
+ let mut writer = NfNetlinkWriter::new(unsafe {
+ std::mem::transmute(Box::as_mut(&mut buf) as *mut Vec<u8>)
+ });
+ let seq = 0;
+ writer.write_header(
+ libc::NFNL_MSG_BATCH_BEGIN as u16,
+ ProtocolFamily::Unspec,
+ 0,
seq,
- // we assume this batch is not empty by default
- is_empty: false,
+ Some(libc::NFNL_SUBSYS_NFTABLES as u16),
+ );
+ writer.finalize_writing_object();
+ Batch {
+ buf,
+ writer,
+ seq: seq + 1,
}
}
- /// Creates a new nftnl batch with the given batch size.
- pub fn with_page_size(batch_page_size: u32) -> Self {
- let batch = try_alloc!(unsafe {
- sys::nftnl_batch_alloc(batch_page_size, crate::nft_nlmsg_maxsize())
- });
- let mut this = Batch {
- batch,
- seq: 0,
- is_empty: true,
- };
- this.write_begin_msg();
- this
- }
-
/// Adds the given message to this batch.
- pub fn add<T: NlMsg>(&mut self, msg: &T, msg_type: MsgType) {
+ pub fn add<T: NfNetlinkObject>(&mut self, msg: &T, msg_type: MsgType) {
trace!("Writing NlMsg with seq {} to batch", self.seq);
- unsafe { msg.write(self.current(), self.seq, msg_type) };
- self.is_empty = false;
- self.next()
+ msg.add_or_remove(&mut self.writer, msg_type, self.seq);
+ self.seq += 1;
}
- /// Adds all the messages in the given iterator to this batch. If any message fails to be
- /// added the error for that failure is returned and all messages up until that message stay
- /// added to the batch.
- pub fn add_iter<T, I>(&mut self, msg_iter: I, msg_type: MsgType)
- where
- T: NlMsg,
- I: Iterator<Item = T>,
- {
+ /// Adds all the messages in the given iterator to this batch.
+ pub fn add_iter<T: NfNetlinkObject, I: Iterator<Item = T>>(
+ &mut self,
+ msg_iter: I,
+ msg_type: MsgType,
+ ) {
for msg in msg_iter {
self.add(&msg, msg_type);
}
@@ -86,109 +76,46 @@ impl Batch {
/// Return None if there is no object in the batch (this could block forever).
///
/// [`FinalizedBatch`]: struct.FinalizedBatch.html
- pub fn finalize(mut self) -> Option<FinalizedBatch> {
- self.write_end_msg();
- if self.is_empty {
- return None;
- }
- Some(FinalizedBatch { batch: self })
- }
-
- fn current(&self) -> *mut c_void {
- unsafe { sys::nftnl_batch_buffer(self.batch) }
- }
-
- fn next(&mut self) {
- if unsafe { sys::nftnl_batch_update(self.batch) } < 0 {
- // See try_alloc definition.
- std::process::abort();
- }
- self.seq += 1;
- }
-
- fn write_begin_msg(&mut self) {
- unsafe { sys::nftnl_batch_begin(self.current() as *mut c_char, self.seq) };
- self.next();
- }
-
- fn write_end_msg(&mut self) {
- unsafe { sys::nftnl_batch_end(self.current() as *mut c_char, self.seq) };
- self.next();
- }
-
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns the raw handle.
- pub fn as_ptr(&self) -> *const sys::nftnl_batch {
- self.batch as *const sys::nftnl_batch
- }
-
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns a mutable version of the raw handle.
- pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_batch {
- self.batch
- }
-}
-
-impl Drop for Batch {
- fn drop(&mut self) {
- unsafe { sys::nftnl_batch_free(self.batch) };
- }
-}
-
-/// A wrapper over [`Batch`], guaranteed to start with a proper batch begin and end with a proper
-/// batch end message. Created from [`Batch::finalize`].
-///
-/// Can be turned into an iterator of the byte buffers to send to netlink to execute this batch.
-///
-/// [`Batch`]: struct.Batch.html
-/// [`Batch::finalize`]: struct.Batch.html#method.finalize
-pub struct FinalizedBatch {
- batch: Batch,
-}
-
-impl FinalizedBatch {
- /// Returns the iterator over byte buffers to send to netlink.
- pub fn iter(&mut self) -> Iter<'_> {
- let num_pages = unsafe { sys::nftnl_batch_iovec_len(self.batch.batch) as usize };
- let mut iovecs = vec![
- libc::iovec {
- iov_base: ptr::null_mut(),
- iov_len: 0,
- };
- num_pages
- ];
- let iovecs_ptr = iovecs.as_mut_ptr();
- unsafe {
- sys::nftnl_batch_iovec(self.batch.batch, iovecs_ptr, num_pages as u32);
- }
- Iter {
- iovecs: iovecs.into_iter(),
- _marker: ::std::marker::PhantomData,
+ pub fn finalize(mut self) -> Vec<u8> {
+ self.writer.write_header(
+ libc::NFNL_MSG_BATCH_END as u16,
+ ProtocolFamily::Unspec,
+ 0,
+ self.seq,
+ Some(NFNL_SUBSYS_NFTABLES as u16),
+ );
+ self.writer.finalize_writing_object();
+ *self.buf
+ }
+
+ pub fn send(self) -> Result<(), QueryError> {
+ use crate::query::{recv_and_process, socket_close_wrapper};
+
+ let sock = socket::socket(
+ AddressFamily::Netlink,
+ SockType::Raw,
+ SockFlag::empty(),
+ SockProtocol::NetlinkNetFilter,
+ )
+ .map_err(QueryError::NetlinkOpenError)?;
+
+ let max_seq = self.seq - 1;
+
+ let addr = SockAddr::Netlink(NetlinkAddr::new(0, 0));
+ // while this bind() is not strictly necessary, strace have trouble decoding the messages
+ // if we don't
+ socket::bind(sock, &addr).expect("bind");
+
+ let to_send = self.finalize();
+ let mut sent = 0;
+ while sent != to_send.len() {
+ sent += socket::send(sock, &to_send[sent..], MsgFlags::empty())
+ .map_err(QueryError::NetlinkSendError)?;
}
- }
-}
-
-impl<'a> IntoIterator for &'a mut FinalizedBatch {
- type Item = &'a [u8];
- type IntoIter = Iter<'a>;
-
- fn into_iter(self) -> Iter<'a> {
- self.iter()
- }
-}
-
-pub struct Iter<'a> {
- iovecs: ::std::vec::IntoIter<libc::iovec>,
- _marker: ::std::marker::PhantomData<&'a ()>,
-}
-
-impl<'a> Iterator for Iter<'a> {
- type Item = &'a [u8];
- fn next(&mut self) -> Option<&'a [u8]> {
- self.iovecs.next().map(|iovec| unsafe {
- ::std::slice::from_raw_parts(iovec.iov_base as *const u8, iovec.iov_len)
- })
+ Ok(socket_close_wrapper(sock, move |sock| {
+ recv_and_process(sock, Some(max_seq), None, &mut ())
+ })?)
}
}
diff --git a/src/chain.rs b/src/chain.rs
index a942a37..37e4cb3 100644
--- a/src/chain.rs
+++ b/src/chain.rs
@@ -1,41 +1,85 @@
-use crate::{MsgType, Table};
-use crate::sys::{self as sys, libc};
-#[cfg(feature = "query")]
-use std::convert::TryFrom;
-use std::{
- ffi::{c_void, CStr, CString},
- fmt,
- os::raw::c_char,
- rc::Rc,
+use libc::{NF_ACCEPT, NF_DROP};
+use rustables_macros::nfnetlink_struct;
+
+use crate::error::{DecodeError, QueryError};
+use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject};
+use crate::sys::{
+ NFTA_CHAIN_FLAGS, NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_POLICY, NFTA_CHAIN_TABLE,
+ NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN,
+ NFT_MSG_NEWCHAIN,
};
+use crate::{Batch, ProtocolFamily, Table};
+use std::fmt::Debug;
-pub type Priority = i32;
+pub type ChainPriority = i32;
/// The netfilter event hooks a chain can register for.
-#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
-#[repr(u16)]
-pub enum Hook {
+#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
+#[repr(i32)]
+pub enum HookClass {
/// Hook into the pre-routing stage of netfilter. Corresponds to `NF_INET_PRE_ROUTING`.
- PreRouting = libc::NF_INET_PRE_ROUTING as u16,
+ PreRouting = libc::NF_INET_PRE_ROUTING,
/// Hook into the input stage of netfilter. Corresponds to `NF_INET_LOCAL_IN`.
- In = libc::NF_INET_LOCAL_IN as u16,
+ In = libc::NF_INET_LOCAL_IN,
/// Hook into the forward stage of netfilter. Corresponds to `NF_INET_FORWARD`.
- Forward = libc::NF_INET_FORWARD as u16,
+ Forward = libc::NF_INET_FORWARD,
/// Hook into the output stage of netfilter. Corresponds to `NF_INET_LOCAL_OUT`.
- Out = libc::NF_INET_LOCAL_OUT as u16,
+ Out = libc::NF_INET_LOCAL_OUT,
/// Hook into the post-routing stage of netfilter. Corresponds to `NF_INET_POST_ROUTING`.
- PostRouting = libc::NF_INET_POST_ROUTING as u16,
+ PostRouting = libc::NF_INET_POST_ROUTING,
+}
+
+#[derive(Clone, PartialEq, Eq, Default, Debug)]
+#[nfnetlink_struct(nested = true)]
+pub struct Hook {
+ /// Define the action netfilter will apply to packets processed by this chain, but that did not match any rules in it.
+ #[field(NFTA_HOOK_HOOKNUM)]
+ class: u32,
+ #[field(NFTA_HOOK_PRIORITY)]
+ priority: u32,
+}
+
+impl Hook {
+ pub fn new(class: HookClass, priority: ChainPriority) -> Self {
+ Hook::default()
+ .with_class(class as u32)
+ .with_priority(priority as u32)
+ }
}
/// A chain policy. Decides what to do with a packet that was processed by the chain but did not
/// match any rules.
-#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
-#[repr(u32)]
-pub enum Policy {
+#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
+#[repr(i32)]
+pub enum ChainPolicy {
/// Accept the packet.
- Accept = libc::NF_ACCEPT as u32,
+ Accept = NF_ACCEPT,
/// Drop the packet.
- Drop = libc::NF_DROP as u32,
+ Drop = NF_DROP,
+}
+
+impl NfNetlinkAttribute for ChainPolicy {
+ fn get_size(&self) -> usize {
+ (*self as i32).get_size()
+ }
+
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ (*self as i32).write_payload(addr);
+ }
+}
+
+impl NfNetlinkDeserializable for ChainPolicy {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ let (v, remaining_data) = i32::deserialize(buf)?;
+ Ok((
+ match v {
+ NF_ACCEPT => ChainPolicy::Accept,
+ NF_DROP => ChainPolicy::Accept,
+ _ => return Err(DecodeError::UnknownChainPolicy),
+ },
+ remaining_data,
+ ))
+ }
}
/// Base chain type.
@@ -53,240 +97,117 @@ pub enum ChainType {
}
impl ChainType {
- fn as_c_str(&self) -> &'static [u8] {
+ fn as_str(&self) -> &'static str {
match *self {
- ChainType::Filter => b"filter\0",
- ChainType::Route => b"route\0",
- ChainType::Nat => b"nat\0",
+ ChainType::Filter => "filter",
+ ChainType::Route => "route",
+ ChainType::Nat => "nat",
}
}
}
-/// Abstraction of a `nftnl_chain`. Chains reside inside [`Table`]s and they hold [`Rule`]s.
-///
-/// There are two types of chains, "base chain" and "regular chain". See [`set_hook`] for more
-/// details.
+impl NfNetlinkAttribute for ChainType {
+ fn get_size(&self) -> usize {
+ self.as_str().len()
+ }
+
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ self.as_str().to_string().write_payload(addr);
+ }
+}
+
+impl NfNetlinkDeserializable for ChainType {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ let (s, remaining_data) = String::deserialize(buf)?;
+ Ok((
+ match s.as_str() {
+ "filter" => ChainType::Filter,
+ "route" => ChainType::Route,
+ "nat" => ChainType::Nat,
+ _ => return Err(DecodeError::UnknownChainType),
+ },
+ remaining_data,
+ ))
+ }
+}
+
+/// Abstraction over an nftable chain. Chains reside inside [`Table`]s and they hold [`Rule`]s.
///
/// [`Table`]: struct.Table.html
/// [`Rule`]: struct.Rule.html
-/// [`set_hook`]: #method.set_hook
+#[derive(PartialEq, Eq, Default, Debug)]
+#[nfnetlink_struct(derive_deserialize = false)]
pub struct Chain {
- pub(crate) chain: *mut sys::nftnl_chain,
- pub(crate) table: Rc<Table>,
+ family: ProtocolFamily,
+ #[field(NFTA_CHAIN_TABLE)]
+ table: String,
+ #[field(NFTA_CHAIN_NAME)]
+ name: String,
+ #[field(NFTA_CHAIN_HOOK)]
+ hook: Hook,
+ #[field(NFTA_CHAIN_POLICY)]
+ policy: ChainPolicy,
+ #[field(NFTA_CHAIN_TYPE, name_in_functions = "type")]
+ chain_type: ChainType,
+ #[field(NFTA_CHAIN_FLAGS)]
+ flags: u32,
+ #[field(NFTA_CHAIN_USERDATA)]
+ userdata: Vec<u8>,
}
impl Chain {
- /// Creates a new chain instance inside the given [`Table`] and with the given name.
+ /// Creates a new chain instance inside the given [`Table`].
///
/// [`Table`]: struct.Table.html
- pub fn new<T: AsRef<CStr>>(name: &T, table: Rc<Table>) -> Chain {
- unsafe {
- let chain = try_alloc!(sys::nftnl_chain_alloc());
- sys::nftnl_chain_set_u32(
- chain,
- sys::NFTNL_CHAIN_FAMILY as u16,
- table.get_family() as u32,
- );
- sys::nftnl_chain_set_str(
- chain,
- sys::NFTNL_CHAIN_TABLE as u16,
- table.get_name().as_ptr(),
- );
- sys::nftnl_chain_set_str(chain, sys::NFTNL_CHAIN_NAME as u16, name.as_ref().as_ptr());
- Chain { chain, table }
- }
- }
-
- pub unsafe fn from_raw(chain: *mut sys::nftnl_chain, table: Rc<Table>) -> Self {
- Chain { chain, table }
- }
+ pub fn new(table: &Table) -> Chain {
+ let mut chain = Chain::default();
+ chain.family = table.get_family();
- /// Sets the hook and priority for this chain. Without calling this method the chain will
- /// become a "regular chain" without any hook and will thus not receive any traffic unless
- /// some rule forward packets to it via goto or jump verdicts.
- ///
- /// By calling `set_hook` with a hook the chain that is created will be registered with that
- /// hook and is thus a "base chain". A "base chain" is an entry point for packets from the
- /// networking stack.
- pub fn set_hook(&mut self, hook: Hook, priority: Priority) {
- unsafe {
- sys::nftnl_chain_set_u32(self.chain, sys::NFTNL_CHAIN_HOOKNUM as u16, hook as u32);
- sys::nftnl_chain_set_s32(self.chain, sys::NFTNL_CHAIN_PRIO as u16, priority);
+ if let Some(table_name) = table.get_name() {
+ chain.set_table(table_name);
}
- }
- /// Set the type of a base chain. This only applies if the chain has been registered
- /// with a hook by calling `set_hook`.
- pub fn set_type(&mut self, chain_type: ChainType) {
- unsafe {
- sys::nftnl_chain_set_str(
- self.chain,
- sys::NFTNL_CHAIN_TYPE as u16,
- chain_type.as_c_str().as_ptr() as *const c_char,
- );
- }
+ chain
}
- /// Sets the default policy for this chain. That means what action netfilter will apply to
- /// packets processed by this chain, but that did not match any rules in it.
- pub fn set_policy(&mut self, policy: Policy) {
- unsafe {
- sys::nftnl_chain_set_u32(self.chain, sys::NFTNL_CHAIN_POLICY as u16, policy as u32);
- }
- }
-
- /// Returns the userdata of this chain.
- pub fn get_userdata(&self) -> Option<&CStr> {
- unsafe {
- let ptr = sys::nftnl_chain_get_str(self.chain, sys::NFTNL_CHAIN_USERDATA as u16);
- if ptr == std::ptr::null() {
- return None;
- }
- Some(CStr::from_ptr(ptr))
- }
- }
-
- /// Updates the userdata of this chain.
- pub fn set_userdata(&self, data: &CStr) {
- unsafe {
- sys::nftnl_chain_set_str(self.chain, sys::NFTNL_CHAIN_USERDATA as u16, data.as_ptr());
- }
- }
-
- /// Returns the name of this chain.
- pub fn get_name(&self) -> &CStr {
- unsafe {
- let ptr = sys::nftnl_chain_get_str(self.chain, sys::NFTNL_CHAIN_NAME as u16);
- if ptr.is_null() {
- panic!("Impossible situation: retrieving the name of a chain failed")
- } else {
- CStr::from_ptr(ptr)
- }
- }
- }
-
- /// Returns a textual description of the chain.
- pub fn get_str(&self) -> CString {
- let mut descr_buf = vec![0i8; 4096];
- unsafe {
- sys::nftnl_chain_snprintf(
- descr_buf.as_mut_ptr() as *mut c_char,
- (descr_buf.len() - 1) as u64,
- self.chain,
- sys::NFTNL_OUTPUT_DEFAULT,
- 0,
- );
- CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned()
- }
- }
-
- /// Returns a reference to the [`Table`] this chain belongs to.
- ///
- /// [`Table`]: struct.Table.html
- pub fn get_table(&self) -> Rc<Table> {
- self.table.clone()
- }
-
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns the raw handle.
- pub fn as_ptr(&self) -> *const sys::nftnl_chain {
- self.chain as *const sys::nftnl_chain
- }
-
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns a mutable version of the raw handle.
- pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_chain {
- self.chain
+ /// Appends this chain to `batch`
+ pub fn add_to_batch(self, batch: &mut Batch) -> Self {
+ batch.add(&self, crate::MsgType::Add);
+ self
}
}
-impl fmt::Debug for Chain {
- /// Returns a string representation of the chain.
- fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(fmt, "{:?}", self.get_str())
- }
-}
+impl NfNetlinkObject for Chain {
+ const MSG_TYPE_ADD: u32 = NFT_MSG_NEWCHAIN;
+ const MSG_TYPE_DEL: u32 = NFT_MSG_DELCHAIN;
-impl PartialEq for Chain {
- fn eq(&self, other: &Self) -> bool {
- self.get_table() == other.get_table() && self.get_name() == other.get_name()
+ fn get_family(&self) -> ProtocolFamily {
+ self.family
}
-}
-unsafe impl crate::NlMsg for Chain {
- unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) {
- let raw_msg_type = match msg_type {
- MsgType::Add => libc::NFT_MSG_NEWCHAIN,
- MsgType::Del => libc::NFT_MSG_DELCHAIN,
- };
- let flags: u16 = match msg_type {
- MsgType::Add => (libc::NLM_F_ACK | libc::NLM_F_CREATE) as u16,
- MsgType::Del => libc::NLM_F_ACK as u16,
- } | libc::NLM_F_ACK as u16;
- let header = sys::nftnl_nlmsg_build_hdr(
- buf as *mut c_char,
- raw_msg_type as u16,
- self.table.get_family() as u16,
- flags,
- seq,
- );
- sys::nftnl_chain_nlmsg_build_payload(header, self.chain);
+ fn set_family(&mut self, family: ProtocolFamily) {
+ self.family = family;
}
}
-impl Drop for Chain {
- fn drop(&mut self) {
- unsafe { sys::nftnl_chain_free(self.chain) };
- }
-}
-
-#[cfg(feature = "query")]
-pub fn get_chains_cb<'a>(
- header: &libc::nlmsghdr,
- (table, chains): &mut (&Rc<Table>, &mut Vec<Chain>),
-) -> libc::c_int {
- unsafe {
- let chain = sys::nftnl_chain_alloc();
- if chain == std::ptr::null_mut() {
- return mnl::mnl_sys::MNL_CB_ERROR;
- }
- let err = sys::nftnl_chain_nlmsg_parse(header, chain);
- if err < 0 {
- error!("Failed to parse nelink chain message - {}", err);
- sys::nftnl_chain_free(chain);
- return err;
- }
-
- let table_name = CStr::from_ptr(sys::nftnl_chain_get_str(
- chain,
- sys::NFTNL_CHAIN_TABLE as u16,
- ));
- let family = sys::nftnl_chain_get_u32(chain, sys::NFTNL_CHAIN_FAMILY as u16);
- let family = match crate::ProtoFamily::try_from(family as i32) {
- Ok(family) => family,
- Err(crate::InvalidProtocolFamily) => {
- error!("The netlink table didn't have a valid protocol family !?");
- sys::nftnl_chain_free(chain);
- return mnl::mnl_sys::MNL_CB_ERROR;
+pub fn list_chains_for_table(table: &Table) -> Result<Vec<Chain>, QueryError> {
+ let mut result = Vec::new();
+ crate::query::list_objects_with_data(
+ libc::NFT_MSG_GETCHAIN as u16,
+ &|chain: Chain, (table, chains): &mut (&Table, &mut Vec<Chain>)| {
+ if chain.get_table() == table.get_name() {
+ chains.push(chain);
+ } else {
+ info!(
+ "Ignoring chain {:?} because it doesn't map the table {:?}",
+ chain.get_name(),
+ table.get_name()
+ );
}
- };
-
- if table_name != table.get_name() {
- sys::nftnl_chain_free(chain);
- return mnl::mnl_sys::MNL_CB_OK;
- }
-
- if family != crate::ProtoFamily::Unspec && family != table.get_family() {
- sys::nftnl_chain_free(chain);
- return mnl::mnl_sys::MNL_CB_OK;
- }
-
- chains.push(Chain::from_raw(chain, table.clone()));
- }
- mnl::mnl_sys::MNL_CB_OK
-}
-
-#[cfg(feature = "query")]
-pub fn list_chains_for_table(table: Rc<Table>) -> Result<Vec<Chain>, crate::query::Error> {
- crate::query::list_objects_with_data(libc::NFT_MSG_GETCHAIN as u16, get_chains_cb, &table, None)
+ Ok(())
+ },
+ None,
+ &mut (&table, &mut result),
+ )?;
+ Ok(result)
}
diff --git a/src/chain_methods.rs b/src/chain_methods.rs
deleted file mode 100644
index d384c35..0000000
--- a/src/chain_methods.rs
+++ /dev/null
@@ -1,40 +0,0 @@
-use crate::{Batch, Chain, Hook, MsgType, Policy, Table};
-use std::ffi::CString;
-use std::rc::Rc;
-
-
-/// A helper trait over [`crate::Chain`].
-pub trait ChainMethods {
- /// Creates a new Chain instance from a [`crate::Hook`] over a [`crate::Table`].
- fn from_hook(hook: Hook, table: Rc<Table>) -> Self
- where Self: std::marker::Sized;
- /// Adds a [`crate::Policy`] to the current Chain.
- fn verdict(self, policy: Policy) -> Self;
- fn add_to_batch(self, batch: &mut Batch) -> Self;
-}
-
-
-impl ChainMethods for Chain {
- fn from_hook(hook: Hook, table: Rc<Table>) -> Self {
- let chain_name = match hook {
- Hook::PreRouting => "prerouting",
- Hook::Out => "out",
- Hook::PostRouting => "postrouting",
- Hook::Forward => "forward",
- Hook::In => "in",
- };
- let chain_name = CString::new(chain_name).unwrap();
- let mut chain = Chain::new(&chain_name, table);
- chain.set_hook(hook, 0);
- chain
- }
- fn verdict(mut self, policy: Policy) -> Self {
- self.set_policy(policy);
- self
- }
- fn add_to_batch(self, batch: &mut Batch) -> Self {
- batch.add(&self, MsgType::Add);
- self
- }
-}
-
diff --git a/src/data_type.rs b/src/data_type.rs
new file mode 100644
index 0000000..43a7f1a
--- /dev/null
+++ b/src/data_type.rs
@@ -0,0 +1,42 @@
+use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
+
+pub trait DataType {
+ const TYPE: u32;
+ const LEN: u32;
+
+ fn data(&self) -> Vec<u8>;
+}
+
+impl DataType for Ipv4Addr {
+ const TYPE: u32 = 7;
+ const LEN: u32 = 4;
+
+ fn data(&self) -> Vec<u8> {
+ self.octets().to_vec()
+ }
+}
+
+impl DataType for Ipv6Addr {
+ const TYPE: u32 = 8;
+ const LEN: u32 = 16;
+
+ fn data(&self) -> Vec<u8> {
+ self.octets().to_vec()
+ }
+}
+
+impl<const N: usize> DataType for [u8; N] {
+ const TYPE: u32 = 5;
+ const LEN: u32 = N as u32;
+
+ fn data(&self) -> Vec<u8> {
+ self.to_vec()
+ }
+}
+
+pub fn ip_to_vec(ip: IpAddr) -> Vec<u8> {
+ match ip {
+ IpAddr::V4(x) => x.octets().to_vec(),
+ IpAddr::V6(x) => x.octets().to_vec(),
+ }
+}
diff --git a/src/error.rs b/src/error.rs
new file mode 100644
index 0000000..f6b6247
--- /dev/null
+++ b/src/error.rs
@@ -0,0 +1,180 @@
+use std::string::FromUtf8Error;
+
+use nix::errno::Errno;
+use thiserror::Error;
+
+use crate::sys::nlmsgerr;
+
+#[derive(Error, Debug)]
+pub enum DecodeError {
+ #[error("The buffer is too small to hold a valid message")]
+ BufTooSmall,
+
+ #[error("The message is too small")]
+ NlMsgTooSmall,
+
+ #[error("The message holds unexpected data")]
+ InvalidDataSize,
+
+ #[error("Invalid subsystem, expected NFTABLES")]
+ InvalidSubsystem(u8),
+
+ #[error("Invalid version, expected NFNETLINK_V0")]
+ InvalidVersion(u8),
+
+ #[error("Invalid port ID")]
+ InvalidPortId(u32),
+
+ #[error("Invalid sequence number")]
+ InvalidSeq(u32),
+
+ #[error("The generation number was bumped in the kernel while the operation was running, interrupting it")]
+ ConcurrentGenerationUpdate,
+
+ #[error("Unsupported message type")]
+ UnsupportedType(u16),
+
+ #[error("Invalid attribute type")]
+ InvalidAttributeType,
+
+ #[error("Invalid type for a chain")]
+ UnknownChainType,
+
+ #[error("Invalid policy for a chain")]
+ UnknownChainPolicy,
+
+ #[error("Unknown type for a Meta expression")]
+ UnknownMetaType(u32),
+
+ #[error("Unsupported value for an icmp reject type")]
+ UnknownRejectType(u32),
+
+ #[error("Unsupported value for an icmp code in a reject expression")]
+ UnknownIcmpCode(u8),
+
+ #[error("Invalid value for a register")]
+ UnknownRegister(u32),
+
+ #[error("Invalid type for a verdict expression")]
+ UnknownVerdictType(i32),
+
+ #[error("Invalid type for a nat expression")]
+ UnknownNatType(i32),
+
+ #[error("Invalid type for a payload expression")]
+ UnknownPayloadType(u32),
+
+ #[error("Invalid type for a compare expression")]
+ UnknownCmpOp(u32),
+
+ #[error("Invalid type for a conntrack key")]
+ UnknownConntrackKey(u32),
+
+ #[error("Unsupported value for a link layer header field")]
+ UnknownLinkLayerHeaderField(u32, u32),
+
+ #[error("Unsupported value for an IPv4 header field")]
+ UnknownIPv4HeaderField(u32, u32),
+
+ #[error("Unsupported value for an IPv6 header field")]
+ UnknownIPv6HeaderField(u32, u32),
+
+ #[error("Unsupported value for a TCP header field")]
+ UnknownTCPHeaderField(u32, u32),
+
+ #[error("Unsupported value for an UDP header field")]
+ UnknownUDPHeaderField(u32, u32),
+
+ #[error("Unsupported value for an ICMPv6 header field")]
+ UnknownICMPv6HeaderField(u32, u32),
+
+ #[error("Missing the 'base' attribute to deserialize the payload object")]
+ PayloadMissingBase,
+
+ #[error("Missing the 'offset' attribute to deserialize the payload object")]
+ PayloadMissingOffset,
+
+ #[error("Missing the 'len' attribute to deserialize the payload object")]
+ PayloadMissingLen,
+
+ #[error("The object does not contain a name for the expression being parsed")]
+ MissingExpressionName,
+
+ #[error("Unsupported attribute type")]
+ UnsupportedAttributeType(u16),
+
+ #[error("Unexpected message type")]
+ UnexpectedType(u16),
+
+ #[error("The decoded String is not UTF8 compliant")]
+ StringDecodeFailure(#[from] FromUtf8Error),
+
+ #[error("Invalid value for a protocol family")]
+ UnknownProtocolFamily(i32),
+
+ #[error("A custom error occured")]
+ Custom(Box<dyn std::error::Error + 'static>),
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum BuilderError {
+ #[error("The length of the arguments are not compatible with each other")]
+ IncompatibleLength,
+
+ #[error("The table does not have a name")]
+ MissingTableName,
+
+ #[error("Missing information in the chain to create a rule")]
+ MissingChainInformationError,
+
+ #[error("Missing name for the set")]
+ MissingSetName,
+
+ #[error("The interface name is too long to be written")]
+ InterfaceNameTooLong,
+
+ #[error("The log prefix string is more than 127 characters long")]
+ TooLongLogPrefix,
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum QueryError {
+ #[error("Unable to open netlink socket to netfilter")]
+ NetlinkOpenError(#[source] nix::Error),
+
+ #[error("Unable to send netlink command to netfilter")]
+ NetlinkSendError(#[source] nix::Error),
+
+ #[error("Error while reading from netlink socket")]
+ NetlinkRecvError(#[source] nix::Error),
+
+ #[error("Error while processing an incoming netlink message")]
+ ProcessNetlinkError(#[from] DecodeError),
+
+ #[error("Error while building netlink objects in Rust")]
+ BuilderError(#[from] BuilderError),
+
+ #[error("Error received from the kernel")]
+ NetlinkError(nlmsgerr),
+
+ #[error("Custom error when customizing the query")]
+ InitError(#[from] Box<dyn std::error::Error + Send + 'static>),
+
+ #[error("Couldn't allocate a netlink object, out of memory ?")]
+ NetlinkAllocationFailed,
+
+ #[error("This socket is not a netlink socket")]
+ NotNetlinkSocket,
+
+ #[error("Couldn't retrieve information on a socket")]
+ RetrievingSocketInfoFailed,
+
+ #[error("Only a part of the message was sent")]
+ TruncatedSend,
+
+ #[error("Got a message without the NLM_F_MULTI flag, but a maximum sequence number was not specified")]
+ UndecidableMessageTermination,
+
+ #[error("Couldn't close the socket")]
+ CloseFailed(#[source] Errno),
+}
diff --git a/src/expr/bitwise.rs b/src/expr/bitwise.rs
index d34d22c..fb40a04 100644
--- a/src/expr/bitwise.rs
+++ b/src/expr/bitwise.rs
@@ -1,69 +1,47 @@
-use super::{Expression, Rule, ToSlice};
-use crate::sys::{self, libc};
-use std::ffi::c_void;
-use std::os::raw::c_char;
-
-/// Expression for performing bitwise masking and XOR on the data in a register.
-pub struct Bitwise<M: ToSlice, X: ToSlice> {
- mask: M,
- xor: X,
+use rustables_macros::nfnetlink_struct;
+
+use super::{Expression, Register};
+use crate::error::BuilderError;
+use crate::parser_impls::NfNetlinkData;
+use crate::sys::{
+ NFTA_BITWISE_DREG, NFTA_BITWISE_LEN, NFTA_BITWISE_MASK, NFTA_BITWISE_SREG, NFTA_BITWISE_XOR,
+};
+
+#[derive(Clone, PartialEq, Eq, Default, Debug)]
+#[nfnetlink_struct]
+pub struct Bitwise {
+ #[field(NFTA_BITWISE_SREG)]
+ sreg: Register,
+ #[field(NFTA_BITWISE_DREG)]
+ dreg: Register,
+ #[field(NFTA_BITWISE_LEN)]
+ len: u32,
+ #[field(NFTA_BITWISE_MASK)]
+ mask: NfNetlinkData,
+ #[field(NFTA_BITWISE_XOR)]
+ xor: NfNetlinkData,
}
-impl<M: ToSlice, X: ToSlice> Bitwise<M, X> {
- /// Returns a new `Bitwise` instance that first masks the value it's applied to with `mask` and
- /// then performs xor with the value in `xor`.
- pub fn new(mask: M, xor: X) -> Self {
- Self { mask, xor }
+impl Expression for Bitwise {
+ fn get_name() -> &'static str {
+ "bitwise"
}
}
-impl<M: ToSlice, X: ToSlice> Expression for Bitwise<M, X> {
- fn get_raw_name() -> *const c_char {
- b"bitwise\0" as *const _ as *const c_char
- }
-
- fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
- unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name()));
-
- let mask = self.mask.to_slice();
- let xor = self.xor.to_slice();
- assert!(mask.len() == xor.len());
- let len = mask.len() as u32;
-
- sys::nftnl_expr_set_u32(
- expr,
- sys::NFTNL_EXPR_BITWISE_SREG as u16,
- libc::NFT_REG_1 as u32,
- );
- sys::nftnl_expr_set_u32(
- expr,
- sys::NFTNL_EXPR_BITWISE_DREG as u16,
- libc::NFT_REG_1 as u32,
- );
- sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_BITWISE_LEN as u16, len);
-
- sys::nftnl_expr_set(
- expr,
- sys::NFTNL_EXPR_BITWISE_MASK as u16,
- mask.as_ref() as *const _ as *const c_void,
- len,
- );
- sys::nftnl_expr_set(
- expr,
- sys::NFTNL_EXPR_BITWISE_XOR as u16,
- xor.as_ref() as *const _ as *const c_void,
- len,
- );
-
- expr
+impl Bitwise {
+ /// Returns a new `Bitwise` instance that first masks the value it's applied to with `mask` and
+ /// then performs xor with the value in `xor`
+ pub fn new(mask: impl Into<Vec<u8>>, xor: impl Into<Vec<u8>>) -> Result<Self, BuilderError> {
+ let mask = mask.into();
+ let xor = xor.into();
+ if mask.len() != xor.len() {
+ return Err(BuilderError::IncompatibleLength);
}
+ Ok(Bitwise::default()
+ .with_sreg(Register::Reg1)
+ .with_dreg(Register::Reg1)
+ .with_len(mask.len() as u32)
+ .with_xor(NfNetlinkData::default().with_value(xor))
+ .with_mask(NfNetlinkData::default().with_value(mask)))
}
}
-
-#[macro_export]
-macro_rules! nft_expr_bitwise {
- (mask $mask:expr,xor $xor:expr) => {
- $crate::expr::Bitwise::new($mask, $xor)
- };
-}
diff --git a/src/expr/cmp.rs b/src/expr/cmp.rs
index f6ea900..86d3587 100644
--- a/src/expr/cmp.rs
+++ b/src/expr/cmp.rs
@@ -1,187 +1,64 @@
-use super::{DeserializationError, Expression, Rule, ToSlice};
-use crate::sys::{self, libc};
-use std::{
- borrow::Cow,
- ffi::{c_void, CString},
- os::raw::c_char,
+use rustables_macros::{nfnetlink_enum, nfnetlink_struct};
+
+use crate::{
+ parser_impls::NfNetlinkData,
+ sys::{
+ NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFT_CMP_EQ, NFT_CMP_GT, NFT_CMP_GTE, NFT_CMP_LT,
+ NFT_CMP_LTE, NFT_CMP_NEQ,
+ },
};
+use super::{Expression, Register};
+
/// Comparison operator.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
+#[nfnetlink_enum(u32, nested = true)]
pub enum CmpOp {
/// Equals.
- Eq,
+ Eq = NFT_CMP_EQ,
/// Not equal.
- Neq,
+ Neq = NFT_CMP_NEQ,
/// Less than.
- Lt,
+ Lt = NFT_CMP_LT,
/// Less than, or equal.
- Lte,
+ Lte = NFT_CMP_LTE,
/// Greater than.
- Gt,
+ Gt = NFT_CMP_GT,
/// Greater than, or equal.
- Gte,
-}
-
-impl CmpOp {
- /// Returns the corresponding `NFT_*` constant for this comparison operation.
- pub fn to_raw(self) -> u32 {
- use self::CmpOp::*;
- match self {
- Eq => libc::NFT_CMP_EQ as u32,
- Neq => libc::NFT_CMP_NEQ as u32,
- Lt => libc::NFT_CMP_LT as u32,
- Lte => libc::NFT_CMP_LTE as u32,
- Gt => libc::NFT_CMP_GT as u32,
- Gte => libc::NFT_CMP_GTE as u32,
- }
- }
-
- pub fn from_raw(val: u32) -> Result<Self, DeserializationError> {
- use self::CmpOp::*;
- match val as i32 {
- libc::NFT_CMP_EQ => Ok(Eq),
- libc::NFT_CMP_NEQ => Ok(Neq),
- libc::NFT_CMP_LT => Ok(Lt),
- libc::NFT_CMP_LTE => Ok(Lte),
- libc::NFT_CMP_GT => Ok(Gt),
- libc::NFT_CMP_GTE => Ok(Gte),
- _ => Err(DeserializationError::InvalidValue),
- }
- }
+ Gte = NFT_CMP_GTE,
}
/// Comparator expression. Allows comparing the content of the netfilter register with any value.
-#[derive(Debug, PartialEq)]
-pub struct Cmp<T> {
+#[derive(Default, Debug, Clone, PartialEq, Eq)]
+#[nfnetlink_struct]
+pub struct Cmp {
+ #[field(NFTA_CMP_SREG)]
+ sreg: Register,
+ #[field(NFTA_CMP_OP)]
op: CmpOp,
- data: T,
+ #[field(NFTA_CMP_DATA)]
+ data: NfNetlinkData,
}
-impl<T: ToSlice> Cmp<T> {
+impl Cmp {
/// Returns a new comparison expression comparing the value loaded in the register with the
/// data in `data` using the comparison operator `op`.
- pub fn new(op: CmpOp, data: T) -> Self {
- Cmp { op, data }
- }
-}
-
-impl<T: ToSlice> Expression for Cmp<T> {
- fn get_raw_name() -> *const c_char {
- b"cmp\0" as *const _ as *const c_char
- }
-
- fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
- unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name()));
-
- let data = self.data.to_slice();
- trace!("Creating a cmp expr comparing with data {:?}", data);
-
- sys::nftnl_expr_set_u32(
- expr,
- sys::NFTNL_EXPR_CMP_SREG as u16,
- libc::NFT_REG_1 as u32,
- );
- sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_CMP_OP as u16, self.op.to_raw());
- sys::nftnl_expr_set(
- expr,
- sys::NFTNL_EXPR_CMP_DATA as u16,
- data.as_ptr() as *const c_void,
- data.len() as u32,
- );
-
- expr
- }
- }
-}
-
-impl<const N: usize> Expression for Cmp<[u8; N]> {
- fn get_raw_name() -> *const c_char {
- Cmp::<u8>::get_raw_name()
- }
-
- /// The raw data contained inside `Cmp` expressions can only be deserialized to arrays of
- /// bytes, to ensure that the memory layout of retrieved data cannot be violated. It is your
- /// responsibility to provide the correct length of the byte data. If the data size is invalid,
- /// you will get the error `DeserializationError::InvalidDataSize`.
- ///
- /// Example (warning, no error checking!):
- /// ```rust
- /// use std::ffi::CString;
- /// use std::net::Ipv4Addr;
- /// use std::rc::Rc;
- ///
- /// use rustables::{Chain, expr::{Cmp, CmpOp}, ProtoFamily, Rule, Table};
- ///
- /// let table = Rc::new(Table::new(&CString::new("mytable").unwrap(), ProtoFamily::Inet));
- /// let chain = Rc::new(Chain::new(&CString::new("mychain").unwrap(), table));
- /// let mut rule = Rule::new(chain);
- /// rule.add_expr(&Cmp::new(CmpOp::Eq, 1337u16));
- /// for expr in Rc::new(rule).get_exprs() {
- /// println!("{:?}", expr.decode_expr::<Cmp<[u8; 2]>>().unwrap());
- /// }
- /// ```
- /// These limitations occur because casting bytes to any type of the same size
- /// as the raw input would be *extremely* dangerous in terms of memory safety.
- fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> {
- unsafe {
- let ref_len = std::mem::size_of::<[u8; N]>() as u32;
- let mut data_len = 0;
- let data = sys::nftnl_expr_get(
- expr,
- sys::NFTNL_EXPR_CMP_DATA as u16,
- &mut data_len as *mut u32,
- );
-
- if data.is_null() {
- return Err(DeserializationError::NullPointer);
- } else if data_len != ref_len {
- return Err(DeserializationError::InvalidDataSize);
- }
-
- let data = *(data as *const [u8; N]);
-
- let op = CmpOp::from_raw(sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_CMP_OP as u16))?;
- Ok(Cmp { op, data })
- }
- }
-
- // call to the other implementation to generate the expression
- fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr {
+ pub fn new(op: CmpOp, data: impl Into<Vec<u8>>) -> Self {
Cmp {
- data: &self.data as &[u8],
- op: self.op,
+ sreg: Some(Register::Reg1),
+ op: Some(op),
+ data: Some(NfNetlinkData::default().with_value(data.into())),
}
- .to_expr(rule)
}
}
-#[macro_export(local_inner_macros)]
-macro_rules! nft_expr_cmp {
- (@cmp_op ==) => {
- $crate::expr::CmpOp::Eq
- };
- (@cmp_op !=) => {
- $crate::expr::CmpOp::Neq
- };
- (@cmp_op <) => {
- $crate::expr::CmpOp::Lt
- };
- (@cmp_op <=) => {
- $crate::expr::CmpOp::Lte
- };
- (@cmp_op >) => {
- $crate::expr::CmpOp::Gt
- };
- (@cmp_op >=) => {
- $crate::expr::CmpOp::Gte
- };
- ($op:tt $data:expr) => {
- $crate::expr::Cmp::new(nft_expr_cmp!(@cmp_op $op), $data)
- };
+impl Expression for Cmp {
+ fn get_name() -> &'static str {
+ "cmp"
+ }
}
+/*
/// Can be used to compare the value loaded by [`Meta::IifName`] and [`Meta::OifName`]. Please note
/// that it is faster to check interface index than name.
///
@@ -207,13 +84,4 @@ impl ToSlice for InterfaceName {
Cow::from(bytes)
}
}
-
-impl<'a> ToSlice for &'a InterfaceName {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- let bytes = match *self {
- InterfaceName::Exact(ref name) => name.as_bytes_with_nul(),
- InterfaceName::StartingWith(ref name) => name.as_bytes(),
- };
- Cow::from(bytes)
- }
-}
+*/
diff --git a/src/expr/counter.rs b/src/expr/counter.rs
index 4732e85..d22fb8a 100644
--- a/src/expr/counter.rs
+++ b/src/expr/counter.rs
@@ -1,46 +1,21 @@
-use super::{DeserializationError, Expression, Rule};
+use rustables_macros::nfnetlink_struct;
+
+use super::Expression;
use crate::sys;
-use std::os::raw::c_char;
/// A counter expression adds a counter to the rule that is incremented to count number of packets
/// and number of bytes for all packets that have matched the rule.
-#[derive(Debug, PartialEq)]
+#[derive(Default, Clone, Debug, PartialEq, Eq)]
+#[nfnetlink_struct]
pub struct Counter {
+ #[field(sys::NFTA_COUNTER_BYTES)]
pub nb_bytes: u64,
+ #[field(sys::NFTA_COUNTER_PACKETS)]
pub nb_packets: u64,
}
-impl Counter {
- pub fn new() -> Self {
- Self {
- nb_bytes: 0,
- nb_packets: 0,
- }
- }
-}
-
impl Expression for Counter {
- fn get_raw_name() -> *const c_char {
- b"counter\0" as *const _ as *const c_char
- }
-
- fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> {
- unsafe {
- let nb_bytes = sys::nftnl_expr_get_u64(expr, sys::NFTNL_EXPR_CTR_BYTES as u16);
- let nb_packets = sys::nftnl_expr_get_u64(expr, sys::NFTNL_EXPR_CTR_PACKETS as u16);
- Ok(Counter {
- nb_bytes,
- nb_packets,
- })
- }
- }
-
- fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
- unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name()));
- sys::nftnl_expr_set_u64(expr, sys::NFTNL_EXPR_CTR_BYTES as u16, self.nb_bytes);
- sys::nftnl_expr_set_u64(expr, sys::NFTNL_EXPR_CTR_PACKETS as u16, self.nb_packets);
- expr
- }
+ fn get_name() -> &'static str {
+ "counter"
}
}
diff --git a/src/expr/ct.rs b/src/expr/ct.rs
index 7d6614c..ad76989 100644
--- a/src/expr/ct.rs
+++ b/src/expr/ct.rs
@@ -1,9 +1,13 @@
-use super::{DeserializationError, Expression, Rule};
-use crate::sys::{self, libc};
-use std::os::raw::c_char;
+use rustables_macros::{nfnetlink_enum, nfnetlink_struct};
+
+use crate::sys::{
+ NFTA_CT_DIRECTION, NFTA_CT_DREG, NFTA_CT_KEY, NFTA_CT_SREG, NFT_CT_MARK, NFT_CT_STATE,
+};
+
+use super::{Expression, Register};
bitflags::bitflags! {
- pub struct States: u32 {
+ pub struct ConnTrackState: u32 {
const INVALID = 1;
const ESTABLISHED = 2;
const RELATED = 4;
@@ -12,76 +16,54 @@ bitflags::bitflags! {
}
}
-pub enum Conntrack {
- State,
- Mark { set: bool },
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+#[nfnetlink_enum(u32, nested = true)]
+pub enum ConntrackKey {
+ State = NFT_CT_STATE,
+ Mark = NFT_CT_MARK,
}
-impl Conntrack {
- fn raw_key(&self) -> u32 {
- match *self {
- Conntrack::State => libc::NFT_CT_STATE as u32,
- Conntrack::Mark { .. } => libc::NFT_CT_MARK as u32,
- }
- }
+#[derive(Default, Clone, Debug, PartialEq, Eq)]
+#[nfnetlink_struct(nested = true)]
+pub struct Conntrack {
+ #[field(NFTA_CT_DREG)]
+ pub dreg: Register,
+ #[field(NFTA_CT_KEY)]
+ pub key: ConntrackKey,
+ #[field(NFTA_CT_DIRECTION)]
+ pub direction: u8,
+ #[field(NFTA_CT_SREG)]
+ pub sreg: Register,
}
impl Expression for Conntrack {
- fn get_raw_name() -> *const c_char {
- b"ct\0" as *const _ as *const c_char
+ fn get_name() -> &'static str {
+ "ct"
}
+}
- fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError>
- where
- Self: Sized,
- {
- unsafe {
- let ct_key = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_CT_KEY as u16);
- let ct_sreg_is_set = sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_CT_SREG as u16);
-
- match ct_key as i32 {
- libc::NFT_CT_STATE => Ok(Conntrack::State),
- libc::NFT_CT_MARK => Ok(Conntrack::Mark {
- set: ct_sreg_is_set,
- }),
- _ => Err(DeserializationError::InvalidValue),
- }
- }
+impl Conntrack {
+ pub fn new(key: ConntrackKey) -> Self {
+ Self::default().with_dreg(Register::Reg1).with_key(key)
}
- fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
- unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name()));
+ pub fn set_mark_value(&mut self, reg: Register) {
+ self.set_sreg(reg);
+ self.set_key(ConntrackKey::Mark);
+ }
- if let Conntrack::Mark { set: true } = self {
- sys::nftnl_expr_set_u32(
- expr,
- sys::NFTNL_EXPR_CT_SREG as u16,
- libc::NFT_REG_1 as u32,
- );
- } else {
- sys::nftnl_expr_set_u32(
- expr,
- sys::NFTNL_EXPR_CT_DREG as u16,
- libc::NFT_REG_1 as u32,
- );
- }
- sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_CT_KEY as u16, self.raw_key());
+ pub fn with_mark_value(mut self, reg: Register) -> Self {
+ self.set_mark_value(reg);
+ self
+ }
- expr
- }
+ pub fn retrieve_value(&mut self, key: ConntrackKey) {
+ self.set_key(key);
+ self.set_dreg(Register::Reg1);
}
-}
-#[macro_export]
-macro_rules! nft_expr_ct {
- (state) => {
- $crate::expr::Conntrack::State
- };
- (mark set) => {
- $crate::expr::Conntrack::Mark { set: true }
- };
- (mark) => {
- $crate::expr::Conntrack::Mark { set: false }
- };
+ pub fn with_retrieve_value(mut self, key: ConntrackKey) -> Self {
+ self.retrieve_value(key);
+ self
+ }
}
diff --git a/src/expr/immediate.rs b/src/expr/immediate.rs
index 71453b3..2fd9bd5 100644
--- a/src/expr/immediate.rs
+++ b/src/expr/immediate.rs
@@ -1,124 +1,50 @@
-use super::{DeserializationError, Expression, Register, Rule, ToSlice};
-use crate::sys;
-use std::ffi::c_void;
-use std::os::raw::c_char;
-
-/// An immediate expression. Used to set immediate data. Verdicts are handled separately by
-/// [crate::expr::Verdict].
-#[derive(Debug, Clone, Eq, PartialEq, Hash)]
-pub struct Immediate<T> {
- pub data: T,
- pub register: Register,
+use rustables_macros::nfnetlink_struct;
+
+use super::{Expression, Register, Verdict, VerdictKind, VerdictType};
+use crate::{
+ parser_impls::NfNetlinkData,
+ sys::{NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG},
+};
+
+#[derive(Clone, PartialEq, Eq, Default, Debug)]
+#[nfnetlink_struct]
+pub struct Immediate {
+ #[field(NFTA_IMMEDIATE_DREG)]
+ dreg: Register,
+ #[field(NFTA_IMMEDIATE_DATA)]
+ data: NfNetlinkData,
}
-impl<T> Immediate<T> {
- pub fn new(data: T, register: Register) -> Self {
- Self { data, register }
+impl Immediate {
+ pub fn new_data(data: Vec<u8>, register: Register) -> Self {
+ Immediate::default()
+ .with_dreg(register)
+ .with_data(NfNetlinkData::default().with_value(data))
}
-}
-
-impl<T: ToSlice> Expression for Immediate<T> {
- fn get_raw_name() -> *const c_char {
- b"immediate\0" as *const _ as *const c_char
- }
-
- fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
- unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name()));
- sys::nftnl_expr_set_u32(
- expr,
- sys::NFTNL_EXPR_IMM_DREG as u16,
- self.register.to_raw(),
- );
-
- let data = self.data.to_slice();
- sys::nftnl_expr_set(
- expr,
- sys::NFTNL_EXPR_IMM_DATA as u16,
- data.as_ptr() as *const c_void,
- data.len() as u32,
- );
-
- expr
+ pub fn new_verdict(kind: VerdictKind) -> Self {
+ let code = match kind {
+ VerdictKind::Drop => VerdictType::Drop,
+ VerdictKind::Accept => VerdictType::Accept,
+ VerdictKind::Queue => VerdictType::Queue,
+ VerdictKind::Continue => VerdictType::Continue,
+ VerdictKind::Break => VerdictType::Break,
+ VerdictKind::Jump { .. } => VerdictType::Jump,
+ VerdictKind::Goto { .. } => VerdictType::Goto,
+ VerdictKind::Return => VerdictType::Return,
+ };
+ let mut data = Verdict::default().with_code(code);
+ if let VerdictKind::Jump { chain } | VerdictKind::Goto { chain } = kind {
+ data.set_chain(chain);
}
+ Immediate::default()
+ .with_dreg(Register::Verdict)
+ .with_data(NfNetlinkData::default().with_verdict(data))
}
}
-impl<const N: usize> Expression for Immediate<[u8; N]> {
- fn get_raw_name() -> *const c_char {
- Immediate::<u8>::get_raw_name()
- }
-
- /// The raw data contained inside `Immediate` expressions can only be deserialized to arrays of
- /// bytes, to ensure that the memory layout of retrieved data cannot be violated. It is your
- /// responsibility to provide the correct length of the byte data. If the data size is invalid,
- /// you will get the error `DeserializationError::InvalidDataSize`.
- ///
- /// Example (warning, no error checking!):
- /// ```rust
- /// use std::ffi::CString;
- /// use std::net::Ipv4Addr;
- /// use std::rc::Rc;
- ///
- /// use rustables::{Chain, expr::{Immediate, Register}, ProtoFamily, Rule, Table};
- ///
- /// let table = Rc::new(Table::new(&CString::new("mytable").unwrap(), ProtoFamily::Inet));
- /// let chain = Rc::new(Chain::new(&CString::new("mychain").unwrap(), table));
- /// let mut rule = Rule::new(chain);
- /// rule.add_expr(&Immediate::new(42u8, Register::Reg1));
- /// for expr in Rc::new(rule).get_exprs() {
- /// println!("{:?}", expr.decode_expr::<Immediate<[u8; 1]>>().unwrap());
- /// }
- /// ```
- /// These limitations occur because casting bytes to any type of the same size as the raw input
- /// would be *extremely* dangerous in terms of memory safety.
- // As casting bytes to any type of the same size as the input would be *extremely* dangerous in
- // terms of memory safety, rustables only accept to deserialize expressions with variable-size
- // data to arrays of bytes, so that the memory layout cannot be invalid.
- fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> {
- unsafe {
- let ref_len = std::mem::size_of::<[u8; N]>() as u32;
- let mut data_len = 0;
- let data = sys::nftnl_expr_get(
- expr,
- sys::NFTNL_EXPR_IMM_DATA as u16,
- &mut data_len as *mut u32,
- );
-
- if data.is_null() {
- return Err(DeserializationError::NullPointer);
- } else if data_len != ref_len {
- return Err(DeserializationError::InvalidDataSize);
- }
-
- let data = *(data as *const [u8; N]);
-
- let register = Register::from_raw(sys::nftnl_expr_get_u32(
- expr,
- sys::NFTNL_EXPR_IMM_DREG as u16,
- ))?;
-
- Ok(Immediate { data, register })
- }
- }
-
- // call to the other implementation to generate the expression
- fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr {
- Immediate {
- register: self.register,
- data: &self.data as &[u8],
- }
- .to_expr(rule)
+impl Expression for Immediate {
+ fn get_name() -> &'static str {
+ "immediate"
}
}
-
-#[macro_export]
-macro_rules! nft_expr_immediate {
- (data $value:expr) => {
- $crate::expr::Immediate {
- data: $value,
- register: $crate::expr::Register::Reg1,
- }
- };
-}
diff --git a/src/expr/log.rs b/src/expr/log.rs
index 8d20b48..cc2728e 100644
--- a/src/expr/log.rs
+++ b/src/expr/log.rs
@@ -1,112 +1,41 @@
-use super::{DeserializationError, Expression, Rule};
-use crate::sys;
-use std::ffi::{CStr, CString};
-use std::os::raw::c_char;
-use thiserror::Error;
+use rustables_macros::nfnetlink_struct;
+use super::Expression;
+use crate::{
+ error::BuilderError,
+ sys::{NFTA_LOG_GROUP, NFTA_LOG_PREFIX},
+};
+
+#[derive(Clone, PartialEq, Eq, Default, Debug)]
+#[nfnetlink_struct]
/// A Log expression will log all packets that match the rule.
-#[derive(Debug, PartialEq)]
pub struct Log {
- pub group: Option<LogGroup>,
- pub prefix: Option<LogPrefix>,
+ #[field(NFTA_LOG_GROUP)]
+ group: u16,
+ #[field(NFTA_LOG_PREFIX)]
+ prefix: String,
}
-impl Expression for Log {
- fn get_raw_name() -> *const sys::libc::c_char {
- b"log\0" as *const _ as *const c_char
- }
-
- fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError>
- where
- Self: Sized,
- {
- unsafe {
- let mut group = None;
- if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_LOG_GROUP as u16) {
- group = Some(LogGroup(sys::nftnl_expr_get_u32(
- expr,
- sys::NFTNL_EXPR_LOG_GROUP as u16,
- ) as u16));
- }
- let mut prefix = None;
- if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16) {
- let raw_prefix = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16);
- if raw_prefix.is_null() {
- return Err(DeserializationError::NullPointer);
- } else {
- prefix = Some(LogPrefix(CStr::from_ptr(raw_prefix).to_owned()));
- }
- }
- Ok(Log { group, prefix })
+impl Log {
+ pub fn new(group: Option<u16>, prefix: Option<impl Into<String>>) -> Result<Log, BuilderError> {
+ let mut res = Log::default();
+ if let Some(group) = group {
+ res.set_group(group);
}
- }
-
- fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
- unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(b"log\0" as *const _ as *const c_char));
- if let Some(log_group) = self.group {
- sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOG_GROUP as u16, log_group.0 as u32);
- };
- if let Some(LogPrefix(prefix)) = &self.prefix {
- sys::nftnl_expr_set_str(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16, prefix.as_ptr());
- };
+ if let Some(prefix) = prefix {
+ let prefix = prefix.into();
- expr
+ if prefix.bytes().count() > 127 {
+ return Err(BuilderError::TooLongLogPrefix);
+ }
+ res.set_prefix(prefix);
}
+ Ok(res)
}
}
-#[derive(Error, Debug)]
-pub enum LogPrefixError {
- #[error("The log prefix string is more than 128 characters long")]
- TooLongPrefix,
- #[error("The log prefix string contains an invalid Nul character.")]
- PrefixContainsANul(#[from] std::ffi::NulError),
-}
-
-/// The NFLOG group that will be assigned to each log line.
-#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
-pub struct LogGroup(pub u16);
-
-/// A prefix that will get prepended to each log line.
-#[derive(Debug, Clone, PartialEq)]
-pub struct LogPrefix(CString);
-
-impl LogPrefix {
- /// Creates a new LogPrefix from a String. Converts it to CString as needed by nftnl. Note that
- /// LogPrefix should not be more than 127 characters long.
- pub fn new(prefix: &str) -> Result<Self, LogPrefixError> {
- if prefix.chars().count() > 127 {
- return Err(LogPrefixError::TooLongPrefix);
- }
- Ok(LogPrefix(CString::new(prefix)?))
+impl Expression for Log {
+ fn get_name() -> &'static str {
+ "log"
}
}
-
-#[macro_export]
-macro_rules! nft_expr_log {
- (group $group:ident prefix $prefix:expr) => {
- $crate::expr::Log {
- group: $group,
- prefix: $prefix,
- }
- };
- (prefix $prefix:expr) => {
- $crate::expr::Log {
- group: None,
- prefix: $prefix,
- }
- };
- (group $group:ident) => {
- $crate::expr::Log {
- group: $group,
- prefix: None,
- }
- };
- () => {
- $crate::expr::Log {
- group: None,
- prefix: None,
- }
- };
-}
diff --git a/src/expr/lookup.rs b/src/expr/lookup.rs
index a0cc021..2ef830e 100644
--- a/src/expr/lookup.rs
+++ b/src/expr/lookup.rs
@@ -1,78 +1,40 @@
-use super::{DeserializationError, Expression, Rule};
-use crate::set::Set;
-use crate::sys::{self, libc};
-use std::ffi::{CStr, CString};
-use std::os::raw::c_char;
+use rustables_macros::nfnetlink_struct;
-#[derive(Debug, PartialEq)]
+use super::{Expression, Register};
+use crate::error::BuilderError;
+use crate::sys::{NFTA_LOOKUP_DREG, NFTA_LOOKUP_SET, NFTA_LOOKUP_SET_ID, NFTA_LOOKUP_SREG};
+use crate::Set;
+
+#[derive(Clone, PartialEq, Eq, Default, Debug)]
+#[nfnetlink_struct]
pub struct Lookup {
- set_name: CString,
+ #[field(NFTA_LOOKUP_SET)]
+ set: String,
+ #[field(NFTA_LOOKUP_SREG)]
+ sreg: Register,
+ #[field(NFTA_LOOKUP_DREG)]
+ dreg: Register,
+ #[field(NFTA_LOOKUP_SET_ID)]
set_id: u32,
}
impl Lookup {
- /// Creates a new lookup entry. May return None if the set has no name.
- pub fn new<K>(set: &Set<K>) -> Option<Self> {
- set.get_name().map(|set_name| Lookup {
- set_name: set_name.to_owned(),
- set_id: set.get_id(),
- })
- }
-}
-
-impl Expression for Lookup {
- fn get_raw_name() -> *const libc::c_char {
- b"lookup\0" as *const _ as *const c_char
- }
-
- fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError>
- where
- Self: Sized,
- {
- unsafe {
- let set_name = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_LOOKUP_SET as u16);
- let set_id = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_LOOKUP_SET_ID as u16);
-
- if set_name.is_null() {
- return Err(DeserializationError::NullPointer);
- }
-
- let set_name = CStr::from_ptr(set_name).to_owned();
-
- Ok(Lookup { set_id, set_name })
+ /// Creates a new lookup entry. May return BuilderError::MissingSetName if the set has no name.
+ pub fn new(set: &Set) -> Result<Self, BuilderError> {
+ let mut res = Lookup::default()
+ .with_set(set.get_name().ok_or(BuilderError::MissingSetName)?)
+ .with_sreg(Register::Reg1);
+
+ if let Some(id) = set.get_id() {
+ res.set_set_id(*id);
}
- }
-
- fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
- unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name()));
-
- sys::nftnl_expr_set_u32(
- expr,
- sys::NFTNL_EXPR_LOOKUP_SREG as u16,
- libc::NFT_REG_1 as u32,
- );
- sys::nftnl_expr_set_str(
- expr,
- sys::NFTNL_EXPR_LOOKUP_SET as u16,
- self.set_name.as_ptr() as *const _ as *const c_char,
- );
- sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOOKUP_SET_ID as u16, self.set_id);
- // This code is left here since it's quite likely we need it again when we get further
- // if self.reverse {
- // sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOOKUP_FLAGS as u16,
- // libc::NFT_LOOKUP_F_INV as u32);
- // }
-
- expr
- }
+ Ok(res)
}
}
-#[macro_export]
-macro_rules! nft_expr_lookup {
- ($set:expr) => {
- $crate::expr::Lookup::new($set)
- };
+impl Expression for Lookup {
+ fn get_name() -> &'static str {
+ "lookup"
+ }
}
diff --git a/src/expr/masquerade.rs b/src/expr/masquerade.rs
index c1a06de..dce787f 100644
--- a/src/expr/masquerade.rs
+++ b/src/expr/masquerade.rs
@@ -1,24 +1,20 @@
-use super::{DeserializationError, Expression, Rule};
-use crate::sys;
-use std::os::raw::c_char;
+use rustables_macros::nfnetlink_struct;
+
+use super::Expression;
/// Sets the source IP to that of the output interface.
-#[derive(Debug, PartialEq)]
+#[derive(Default, Debug, PartialEq, Eq)]
+#[nfnetlink_struct(nested = true)]
pub struct Masquerade;
-impl Expression for Masquerade {
- fn get_raw_name() -> *const sys::libc::c_char {
- b"masq\0" as *const _ as *const c_char
- }
-
- fn from_expr(_expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError>
- where
- Self: Sized,
- {
- Ok(Masquerade)
+impl Clone for Masquerade {
+ fn clone(&self) -> Self {
+ Masquerade {}
}
+}
- fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
- try_alloc!(unsafe { sys::nftnl_expr_alloc(Self::get_raw_name()) })
+impl Expression for Masquerade {
+ fn get_name() -> &'static str {
+ "masq"
}
}
diff --git a/src/expr/meta.rs b/src/expr/meta.rs
index a015f65..3ecb1d1 100644
--- a/src/expr/meta.rs
+++ b/src/expr/meta.rs
@@ -1,175 +1,62 @@
-use super::{DeserializationError, Expression, Rule};
-use crate::sys::{self, libc};
-use std::os::raw::c_char;
+use rustables_macros::{nfnetlink_enum, nfnetlink_struct};
+
+use super::{Expression, Register};
+use crate::sys;
/// A meta expression refers to meta data associated with a packet.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
+#[nfnetlink_enum(u32)]
#[non_exhaustive]
-pub enum Meta {
+pub enum MetaType {
/// Packet ethertype protocol (skb->protocol), invalid in OUTPUT.
- Protocol,
+ Protocol = sys::NFT_META_PROTOCOL,
/// Packet mark.
- Mark { set: bool },
+ Mark = sys::NFT_META_MARK,
/// Packet input interface index (dev->ifindex).
- Iif,
+ Iif = sys::NFT_META_IIF,
/// Packet output interface index (dev->ifindex).
- Oif,
+ Oif = sys::NFT_META_OIF,
/// Packet input interface name (dev->name).
- IifName,
+ IifName = sys::NFT_META_IIFNAME,
/// Packet output interface name (dev->name).
- OifName,
+ OifName = sys::NFT_META_OIFNAME,
/// Packet input interface type (dev->type).
- IifType,
+ IifType = libc::NFT_META_IIFTYPE,
/// Packet output interface type (dev->type).
- OifType,
+ OifType = sys::NFT_META_OIFTYPE,
/// Originating socket UID (fsuid).
- SkUid,
+ SkUid = sys::NFT_META_SKUID,
/// Originating socket GID (fsgid).
- SkGid,
+ SkGid = sys::NFT_META_SKGID,
/// Netfilter protocol (Transport layer protocol).
- NfProto,
+ NfProto = sys::NFT_META_NFPROTO,
/// Layer 4 protocol number.
- L4Proto,
+ L4Proto = sys::NFT_META_L4PROTO,
/// Socket control group (skb->sk->sk_classid).
- Cgroup,
+ Cgroup = sys::NFT_META_CGROUP,
/// A 32bit pseudo-random number.
- PRandom,
+ PRandom = sys::NFT_META_PRANDOM,
}
-impl Meta {
- /// Returns the corresponding `NFT_*` constant for this meta expression.
- pub fn to_raw_key(&self) -> u32 {
- use Meta::*;
- match *self {
- Protocol => libc::NFT_META_PROTOCOL as u32,
- Mark { .. } => libc::NFT_META_MARK as u32,
- Iif => libc::NFT_META_IIF as u32,
- Oif => libc::NFT_META_OIF as u32,
- IifName => libc::NFT_META_IIFNAME as u32,
- OifName => libc::NFT_META_OIFNAME as u32,
- IifType => libc::NFT_META_IIFTYPE as u32,
- OifType => libc::NFT_META_OIFTYPE as u32,
- SkUid => libc::NFT_META_SKUID as u32,
- SkGid => libc::NFT_META_SKGID as u32,
- NfProto => libc::NFT_META_NFPROTO as u32,
- L4Proto => libc::NFT_META_L4PROTO as u32,
- Cgroup => libc::NFT_META_CGROUP as u32,
- PRandom => libc::NFT_META_PRANDOM as u32,
- }
- }
+#[derive(Clone, PartialEq, Eq, Default, Debug)]
+#[nfnetlink_struct]
+pub struct Meta {
+ #[field(sys::NFTA_META_DREG)]
+ dreg: Register,
+ #[field(sys::NFTA_META_KEY)]
+ key: MetaType,
+ #[field(sys::NFTA_META_SREG)]
+ sreg: Register,
+}
- fn from_raw(val: u32) -> Result<Self, DeserializationError> {
- match val as i32 {
- libc::NFT_META_PROTOCOL => Ok(Self::Protocol),
- libc::NFT_META_MARK => Ok(Self::Mark { set: false }),
- libc::NFT_META_IIF => Ok(Self::Iif),
- libc::NFT_META_OIF => Ok(Self::Oif),
- libc::NFT_META_IIFNAME => Ok(Self::IifName),
- libc::NFT_META_OIFNAME => Ok(Self::OifName),
- libc::NFT_META_IIFTYPE => Ok(Self::IifType),
- libc::NFT_META_OIFTYPE => Ok(Self::OifType),
- libc::NFT_META_SKUID => Ok(Self::SkUid),
- libc::NFT_META_SKGID => Ok(Self::SkGid),
- libc::NFT_META_NFPROTO => Ok(Self::NfProto),
- libc::NFT_META_L4PROTO => Ok(Self::L4Proto),
- libc::NFT_META_CGROUP => Ok(Self::Cgroup),
- libc::NFT_META_PRANDOM => Ok(Self::PRandom),
- _ => Err(DeserializationError::InvalidValue),
- }
+impl Meta {
+ pub fn new(ty: MetaType) -> Self {
+ Meta::default().with_dreg(Register::Reg1).with_key(ty)
}
}
impl Expression for Meta {
- fn get_raw_name() -> *const libc::c_char {
- b"meta\0" as *const _ as *const c_char
- }
-
- fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError>
- where
- Self: Sized,
- {
- unsafe {
- let mut ret = Self::from_raw(sys::nftnl_expr_get_u32(
- expr,
- sys::NFTNL_EXPR_META_KEY as u16,
- ))?;
-
- if let Self::Mark { ref mut set } = ret {
- *set = sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_META_SREG as u16);
- }
-
- Ok(ret)
- }
+ fn get_name() -> &'static str {
+ "meta"
}
-
- fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
- unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name()));
-
- if let Meta::Mark { set: true } = self {
- sys::nftnl_expr_set_u32(
- expr,
- sys::NFTNL_EXPR_META_SREG as u16,
- libc::NFT_REG_1 as u32,
- );
- } else {
- sys::nftnl_expr_set_u32(
- expr,
- sys::NFTNL_EXPR_META_DREG as u16,
- libc::NFT_REG_1 as u32,
- );
- }
- sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_META_KEY as u16, self.to_raw_key());
- expr
- }
- }
-}
-
-#[macro_export]
-macro_rules! nft_expr_meta {
- (proto) => {
- $crate::expr::Meta::Protocol
- };
- (mark set) => {
- $crate::expr::Meta::Mark { set: true }
- };
- (mark) => {
- $crate::expr::Meta::Mark { set: false }
- };
- (iif) => {
- $crate::expr::Meta::Iif
- };
- (oif) => {
- $crate::expr::Meta::Oif
- };
- (iifname) => {
- $crate::expr::Meta::IifName
- };
- (oifname) => {
- $crate::expr::Meta::OifName
- };
- (iiftype) => {
- $crate::expr::Meta::IifType
- };
- (oiftype) => {
- $crate::expr::Meta::OifType
- };
- (skuid) => {
- $crate::expr::Meta::SkUid
- };
- (skgid) => {
- $crate::expr::Meta::SkGid
- };
- (nfproto) => {
- $crate::expr::Meta::NfProto
- };
- (l4proto) => {
- $crate::expr::Meta::L4Proto
- };
- (cgroup) => {
- $crate::expr::Meta::Cgroup
- };
- (random) => {
- $crate::expr::Meta::PRandom
- };
}
diff --git a/src/expr/mod.rs b/src/expr/mod.rs
index dc59507..058b0cb 100644
--- a/src/expr/mod.rs
+++ b/src/expr/mod.rs
@@ -3,14 +3,14 @@
//!
//! [`Rule`]: struct.Rule.html
-use std::borrow::Cow;
-use std::net::IpAddr;
-use std::net::Ipv4Addr;
-use std::net::Ipv6Addr;
+use std::fmt::Debug;
-use super::rule::Rule;
-use crate::sys::{self, libc};
-use thiserror::Error;
+use rustables_macros::nfnetlink_struct;
+
+use crate::error::DecodeError;
+use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable};
+use crate::parser_impls::NfNetlinkList;
+use crate::sys::{self, NFTA_EXPR_DATA, NFTA_EXPR_NAME};
mod bitwise;
pub use self::bitwise::*;
@@ -46,7 +46,7 @@ mod payload;
pub use self::payload::*;
mod reject;
-pub use self::reject::{IcmpCode, Reject};
+pub use self::reject::{IcmpCode, Reject, RejectType};
mod register;
pub use self::register::Register;
@@ -54,189 +54,161 @@ pub use self::register::Register;
mod verdict;
pub use self::verdict::*;
-mod wrapper;
-pub use self::wrapper::ExpressionWrapper;
-
-#[derive(Debug, Error)]
-pub enum DeserializationError {
- #[error("The expected expression type doesn't match the name of the raw expression")]
- /// The expected expression type doesn't match the name of the raw expression.
- InvalidExpressionKind,
-
- #[error("Deserializing the requested type isn't implemented yet")]
- /// Deserializing the requested type isn't implemented yet.
- NotImplemented,
-
- #[error("The expression value cannot be deserialized to the requested type")]
- /// The expression value cannot be deserialized to the requested type.
- InvalidValue,
-
- #[error("A pointer was null while a non-null pointer was expected")]
- /// A pointer was null while a non-null pointer was expected.
- NullPointer,
-
- #[error(
- "The size of a raw value was incoherent with the expected type of the deserialized value"
- )]
- /// The size of a raw value was incoherent with the expected type of the deserialized value/
- InvalidDataSize,
-
- #[error(transparent)]
- /// Couldn't find a matching protocol.
- InvalidProtolFamily(#[from] super::InvalidProtocolFamily),
-}
-
-/// Trait for every safe wrapper of an nftables expression.
pub trait Expression {
- /// Returns the raw name used by nftables to identify the rule.
- fn get_raw_name() -> *const libc::c_char;
-
- /// Try to parse the expression from a raw nftables expression, returning a
- /// [DeserializationError] if the attempted parsing failed.
- fn from_expr(_expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError>
- where
- Self: Sized,
- {
- Err(DeserializationError::NotImplemented)
- }
-
- /// Allocates and returns the low level `nftnl_expr` representation of this expression. The
- /// caller to this method is responsible for freeing the expression.
- fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr;
+ fn get_name() -> &'static str;
}
-/// A type that can be converted into a byte buffer.
-pub trait ToSlice {
- /// Returns the data this type represents.
- fn to_slice(&self) -> Cow<'_, [u8]>;
+#[derive(Clone, PartialEq, Eq, Default, Debug)]
+#[nfnetlink_struct(nested = true, derive_decoder = false)]
+pub struct RawExpression {
+ #[field(NFTA_EXPR_NAME)]
+ name: String,
+ #[field(NFTA_EXPR_DATA)]
+ data: ExpressionVariant,
}
-impl<'a> ToSlice for &'a [u8] {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- Cow::Borrowed(self)
+impl<T> From<T> for RawExpression
+where
+ T: Expression,
+ ExpressionVariant: From<T>,
+{
+ fn from(val: T) -> Self {
+ RawExpression::default()
+ .with_name(T::get_name())
+ .with_data(ExpressionVariant::from(val))
}
}
-impl<'a> ToSlice for &'a [u16] {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- let ptr = self.as_ptr() as *const u8;
- let len = self.len() * 2;
- Cow::Borrowed(unsafe { std::slice::from_raw_parts(ptr, len) })
- }
-}
-
-impl ToSlice for IpAddr {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- match *self {
- IpAddr::V4(ref addr) => addr.to_slice(),
- IpAddr::V6(ref addr) => addr.to_slice(),
+macro_rules! create_expr_variant {
+ ($enum:ident $(, [$name:ident, $type:ty])+) => {
+ #[derive(Debug, Clone, PartialEq, Eq)]
+ pub enum $enum {
+ $(
+ $name($type),
+ )+
}
- }
-}
-impl ToSlice for Ipv4Addr {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- Cow::Owned(self.octets().to_vec())
- }
-}
-
-impl ToSlice for Ipv6Addr {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- Cow::Owned(self.octets().to_vec())
- }
-}
+ impl $crate::nlmsg::NfNetlinkAttribute for $enum {
+ fn is_nested(&self) -> bool {
+ true
+ }
+
+ fn get_size(&self) -> usize {
+ match self {
+ $(
+ $enum::$name(val) => val.get_size(),
+ )+
+ }
+ }
+
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ match self {
+ $(
+ $enum::$name(val) => val.write_payload(addr),
+ )+
+ }
+ }
+ }
-impl ToSlice for u8 {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- Cow::Owned(vec![*self])
- }
+ $(
+ impl From<$type> for $enum {
+ fn from(val: $type) -> Self {
+ $enum::$name(val)
+ }
+ }
+ )+
+
+ impl $crate::nlmsg::AttributeDecoder for RawExpression {
+ fn decode_attribute(
+ &mut self,
+ attr_type: u16,
+ buf: &[u8],
+ ) -> Result<(), $crate::error::DecodeError> {
+ debug!("Decoding attribute {} in an expression", attr_type);
+ match attr_type {
+ x if x == sys::NFTA_EXPR_NAME => {
+ debug!("Calling {}::deserialize()", std::any::type_name::<String>());
+ let (val, remaining) = String::deserialize(buf)?;
+ if remaining.len() != 0 {
+ return Err($crate::error::DecodeError::InvalidDataSize);
+ }
+ self.name = Some(val);
+ Ok(())
+ },
+ x if x == sys::NFTA_EXPR_DATA => {
+ // we can assume we have already the name parsed, as that's how we identify the
+ // type of expression
+ let name = self.name.as_ref()
+ .ok_or($crate::error::DecodeError::MissingExpressionName)?;
+ match name {
+ $(
+ x if x == <$type>::get_name() => {
+ debug!("Calling {}::deserialize()", std::any::type_name::<$type>());
+ let (res, remaining) = <$type>::deserialize(buf)?;
+ if remaining.len() != 0 {
+ return Err($crate::error::DecodeError::InvalidDataSize);
+ }
+ self.data = Some(ExpressionVariant::from(res));
+ Ok(())
+ },
+ )+
+ name => {
+ info!("Unrecognized expression '{}', generating an ExpressionRaw", name);
+ self.data = Some(ExpressionVariant::ExpressionRaw(ExpressionRaw::deserialize(buf)?.0));
+ Ok(())
+ }
+ }
+ },
+ _ => Err(DecodeError::UnsupportedAttributeType(attr_type)),
+ }
+ }
+ }
+ };
}
-impl ToSlice for u16 {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- let b0 = (*self & 0x00ff) as u8;
- let b1 = (*self >> 8) as u8;
- Cow::Owned(vec![b0, b1])
+create_expr_variant!(
+ ExpressionVariant,
+ [Bitwise, Bitwise],
+ [Cmp, Cmp],
+ [Conntrack, Conntrack],
+ [Counter, Counter],
+ [ExpressionRaw, ExpressionRaw],
+ [Immediate, Immediate],
+ [Log, Log],
+ [Lookup, Lookup],
+ [Masquerade, Masquerade],
+ [Meta, Meta],
+ [Nat, Nat],
+ [Payload, Payload],
+ [Reject, Reject]
+);
+
+pub type ExpressionList = NfNetlinkList<RawExpression>;
+
+// default type for expressions that we do not handle yet
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ExpressionRaw(Vec<u8>);
+
+impl NfNetlinkAttribute for ExpressionRaw {
+ fn get_size(&self) -> usize {
+ self.0.get_size()
}
-}
-impl ToSlice for u32 {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- let b0 = *self as u8;
- let b1 = (*self >> 8) as u8;
- let b2 = (*self >> 16) as u8;
- let b3 = (*self >> 24) as u8;
- Cow::Owned(vec![b0, b1, b2, b3])
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ self.0.write_payload(addr);
}
}
-impl ToSlice for i32 {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- let b0 = *self as u8;
- let b1 = (*self >> 8) as u8;
- let b2 = (*self >> 16) as u8;
- let b3 = (*self >> 24) as u8;
- Cow::Owned(vec![b0, b1, b2, b3])
+impl NfNetlinkDeserializable for ExpressionRaw {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((ExpressionRaw(buf.to_vec()), &[]))
}
}
-impl<'a> ToSlice for &'a str {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- Cow::from(self.as_bytes())
+// Because we loose the name of the expression when parsing, this is the only expression
+// where deserializing a message and then reserializing it is invalid
+impl Expression for ExpressionRaw {
+ fn get_name() -> &'static str {
+ "unknown_expression"
}
}
-
-#[macro_export(local_inner_macros)]
-macro_rules! nft_expr {
- (bitwise mask $mask:expr,xor $xor:expr) => {
- nft_expr_bitwise!(mask $mask, xor $xor)
- };
- (cmp $op:tt $data:expr) => {
- nft_expr_cmp!($op $data)
- };
- (counter) => {
- $crate::expr::Counter { nb_bytes: 0, nb_packets: 0}
- };
- (ct $key:ident set) => {
- nft_expr_ct!($key set)
- };
- (ct $key:ident) => {
- nft_expr_ct!($key)
- };
- (immediate $expr:ident $value:expr) => {
- nft_expr_immediate!($expr $value)
- };
- (log group $group:ident prefix $prefix:expr) => {
- nft_expr_log!(group $group prefix $prefix)
- };
- (log group $group:ident) => {
- nft_expr_log!(group $group)
- };
- (log prefix $prefix:expr) => {
- nft_expr_log!(prefix $prefix)
- };
- (log) => {
- nft_expr_log!()
- };
- (lookup $set:expr) => {
- nft_expr_lookup!($set)
- };
- (masquerade) => {
- $crate::expr::Masquerade
- };
- (meta $expr:ident set) => {
- nft_expr_meta!($expr set)
- };
- (meta $expr:ident) => {
- nft_expr_meta!($expr)
- };
- (payload $proto:ident $field:ident) => {
- nft_expr_payload!($proto $field)
- };
- (verdict $verdict:ident) => {
- nft_expr_verdict!($verdict)
- };
- (verdict $verdict:ident $chain:expr) => {
- nft_expr_verdict!($verdict $chain)
- };
-}
diff --git a/src/expr/nat.rs b/src/expr/nat.rs
index ce6b881..406b2e6 100644
--- a/src/expr/nat.rs
+++ b/src/expr/nat.rs
@@ -1,99 +1,37 @@
-use super::{DeserializationError, Expression, Register, Rule};
-use crate::ProtoFamily;
-use crate::sys::{self, libc};
-use std::{convert::TryFrom, os::raw::c_char};
+use rustables_macros::{nfnetlink_enum, nfnetlink_struct};
+
+use super::{Expression, Register};
+use crate::{
+ sys::{self, NFT_NAT_DNAT, NFT_NAT_SNAT},
+ ProtocolFamily,
+};
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
-#[repr(i32)]
+#[nfnetlink_enum(i32)]
pub enum NatType {
/// Source NAT. Changes the source address of a packet.
- SNat = libc::NFT_NAT_SNAT,
+ SNat = NFT_NAT_SNAT,
/// Destination NAT. Changes the destination address of a packet.
- DNat = libc::NFT_NAT_DNAT,
-}
-
-impl NatType {
- fn from_raw(val: u32) -> Result<Self, DeserializationError> {
- match val as i32 {
- libc::NFT_NAT_SNAT => Ok(NatType::SNat),
- libc::NFT_NAT_DNAT => Ok(NatType::DNat),
- _ => Err(DeserializationError::InvalidValue),
- }
- }
+ DNat = NFT_NAT_DNAT,
}
/// A source or destination NAT statement. Modifies the source or destination address (and possibly
/// port) of packets.
-#[derive(Debug, PartialEq)]
+#[derive(Default, Debug, Clone, PartialEq, Eq)]
+#[nfnetlink_struct(nested = true)]
pub struct Nat {
+ #[field(sys::NFTA_NAT_TYPE)]
pub nat_type: NatType,
- pub family: ProtoFamily,
+ #[field(sys::NFTA_NAT_FAMILY)]
+ pub family: ProtocolFamily,
+ #[field(sys::NFTA_NAT_REG_ADDR_MIN)]
pub ip_register: Register,
- pub port_register: Option<Register>,
+ #[field(sys::NFTA_NAT_REG_PROTO_MIN)]
+ pub port_register: Register,
}
impl Expression for Nat {
- fn get_raw_name() -> *const libc::c_char {
- b"nat\0" as *const _ as *const c_char
- }
-
- fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError>
- where
- Self: Sized,
- {
- unsafe {
- let nat_type = NatType::from_raw(sys::nftnl_expr_get_u32(
- expr,
- sys::NFTNL_EXPR_NAT_TYPE as u16,
- ))?;
-
- let family = ProtoFamily::try_from(sys::nftnl_expr_get_u32(
- expr,
- sys::NFTNL_EXPR_NAT_FAMILY as u16,
- ) as i32)?;
-
- let ip_register = Register::from_raw(sys::nftnl_expr_get_u32(
- expr,
- sys::NFTNL_EXPR_NAT_REG_ADDR_MIN as u16,
- ))?;
-
- let mut port_register = None;
- if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_NAT_REG_PROTO_MIN as u16) {
- port_register = Some(Register::from_raw(sys::nftnl_expr_get_u32(
- expr,
- sys::NFTNL_EXPR_NAT_REG_PROTO_MIN as u16,
- ))?);
- }
-
- Ok(Nat {
- ip_register,
- nat_type,
- family,
- port_register,
- })
- }
- }
-
- fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
- let expr = try_alloc!(unsafe { sys::nftnl_expr_alloc(Self::get_raw_name()) });
-
- unsafe {
- sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_NAT_TYPE as u16, self.nat_type as u32);
- sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_NAT_FAMILY as u16, self.family as u32);
- sys::nftnl_expr_set_u32(
- expr,
- sys::NFTNL_EXPR_NAT_REG_ADDR_MIN as u16,
- self.ip_register.to_raw(),
- );
- if let Some(port_register) = self.port_register {
- sys::nftnl_expr_set_u32(
- expr,
- sys::NFTNL_EXPR_NAT_REG_PROTO_MIN as u16,
- port_register.to_raw(),
- );
- }
- }
-
- expr
+ fn get_name() -> &'static str {
+ "nat"
}
}
diff --git a/src/expr/payload.rs b/src/expr/payload.rs
index a108fe8..d0b2cea 100644
--- a/src/expr/payload.rs
+++ b/src/expr/payload.rs
@@ -1,128 +1,96 @@
-use super::{DeserializationError, Expression, Rule};
-use crate::sys::{self, libc};
-use std::os::raw::c_char;
+use rustables_macros::nfnetlink_struct;
-pub trait HeaderField {
- fn offset(&self) -> u32;
- fn len(&self) -> u32;
+use super::{Expression, Register};
+use crate::{
+ error::DecodeError,
+ sys::{self, NFT_PAYLOAD_LL_HEADER, NFT_PAYLOAD_NETWORK_HEADER, NFT_PAYLOAD_TRANSPORT_HEADER},
+};
+
+/// Payload expressions refer to data from the packet's payload.
+#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]
+#[nfnetlink_struct(nested = true)]
+pub struct Payload {
+ #[field(sys::NFTA_PAYLOAD_DREG)]
+ dreg: Register,
+ #[field(sys::NFTA_PAYLOAD_BASE)]
+ base: u32,
+ #[field(sys::NFTA_PAYLOAD_OFFSET)]
+ offset: u32,
+ #[field(sys::NFTA_PAYLOAD_LEN)]
+ len: u32,
+ #[field(sys::NFTA_PAYLOAD_SREG)]
+ sreg: Register,
+}
+
+impl Expression for Payload {
+ fn get_name() -> &'static str {
+ "payload"
+ }
}
/// Payload expressions refer to data from the packet's payload.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
-pub enum Payload {
+pub enum HighLevelPayload {
LinkLayer(LLHeaderField),
Network(NetworkHeaderField),
Transport(TransportHeaderField),
}
-impl Payload {
- pub fn build(&self) -> RawPayload {
+impl HighLevelPayload {
+ pub fn build(&self) -> Payload {
match *self {
- Payload::LinkLayer(ref f) => RawPayload::LinkLayer(RawPayloadData {
- offset: f.offset(),
- len: f.len(),
- }),
- Payload::Network(ref f) => RawPayload::Network(RawPayloadData {
- offset: f.offset(),
- len: f.len(),
- }),
- Payload::Transport(ref f) => RawPayload::Transport(RawPayloadData {
- offset: f.offset(),
- len: f.len(),
- }),
+ HighLevelPayload::LinkLayer(ref f) => Payload::default()
+ .with_base(NFT_PAYLOAD_LL_HEADER)
+ .with_offset(f.offset())
+ .with_len(f.len()),
+ HighLevelPayload::Network(ref f) => Payload::default()
+ .with_base(NFT_PAYLOAD_NETWORK_HEADER)
+ .with_offset(f.offset())
+ .with_len(f.len()),
+ HighLevelPayload::Transport(ref f) => Payload::default()
+ .with_base(NFT_PAYLOAD_TRANSPORT_HEADER)
+ .with_offset(f.offset())
+ .with_len(f.len()),
}
+ .with_dreg(Register::Reg1)
}
}
-impl Expression for Payload {
- fn get_raw_name() -> *const libc::c_char {
- RawPayload::get_raw_name()
- }
-
- fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr {
- self.build().to_expr(rule)
- }
-}
-
-#[derive(Debug, Copy, Clone, Eq, PartialEq)]
-pub struct RawPayloadData {
- offset: u32,
- len: u32,
-}
-
-/// Because deserializing a `Payload` expression is not possible (there is not enough information
-/// in the expression itself), this enum should be used to deserialize payloads.
+/// Payload expressions refer to data from the packet's payload.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
-pub enum RawPayload {
- LinkLayer(RawPayloadData),
- Network(RawPayloadData),
- Transport(RawPayloadData),
+pub enum PayloadType {
+ LinkLayer(LLHeaderField),
+ Network,
+ Transport,
}
-impl RawPayload {
- fn base(&self) -> u32 {
- match self {
- Self::LinkLayer(_) => libc::NFT_PAYLOAD_LL_HEADER as u32,
- Self::Network(_) => libc::NFT_PAYLOAD_NETWORK_HEADER as u32,
- Self::Transport(_) => libc::NFT_PAYLOAD_TRANSPORT_HEADER as u32,
+impl PayloadType {
+ pub fn parse_from_payload(raw: &Payload) -> Result<Self, DecodeError> {
+ if raw.base.is_none() {
+ return Err(DecodeError::PayloadMissingBase);
}
- }
-}
-
-impl HeaderField for RawPayload {
- fn offset(&self) -> u32 {
- match self {
- Self::LinkLayer(ref f) | Self::Network(ref f) | Self::Transport(ref f) => f.offset,
+ if raw.len.is_none() {
+ return Err(DecodeError::PayloadMissingLen);
}
- }
-
- fn len(&self) -> u32 {
- match self {
- Self::LinkLayer(ref f) | Self::Network(ref f) | Self::Transport(ref f) => f.len,
+ if raw.offset.is_none() {
+ return Err(DecodeError::PayloadMissingOffset);
}
+ Ok(match raw.base {
+ Some(NFT_PAYLOAD_LL_HEADER) => PayloadType::LinkLayer(LLHeaderField::from_raw_data(
+ raw.offset.unwrap(),
+ raw.len.unwrap(),
+ )?),
+ Some(NFT_PAYLOAD_NETWORK_HEADER) => PayloadType::Network,
+ Some(NFT_PAYLOAD_TRANSPORT_HEADER) => PayloadType::Transport,
+ Some(v) => return Err(DecodeError::UnknownPayloadType(v)),
+ None => return Err(DecodeError::PayloadMissingBase),
+ })
}
}
-impl Expression for RawPayload {
- fn get_raw_name() -> *const libc::c_char {
- b"payload\0" as *const _ as *const c_char
- }
-
- fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> {
- unsafe {
- let base = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_BASE as u16);
- let offset = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_OFFSET as u16);
- let len = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_LEN as u16);
- match base as i32 {
- libc::NFT_PAYLOAD_LL_HEADER => Ok(Self::LinkLayer(RawPayloadData { offset, len })),
- libc::NFT_PAYLOAD_NETWORK_HEADER => {
- Ok(Self::Network(RawPayloadData { offset, len }))
- }
- libc::NFT_PAYLOAD_TRANSPORT_HEADER => {
- Ok(Self::Transport(RawPayloadData { offset, len }))
- }
-
- _ => return Err(DeserializationError::InvalidValue),
- }
- }
- }
-
- fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
- unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name()));
-
- sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_PAYLOAD_BASE as u16, self.base());
- sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_PAYLOAD_OFFSET as u16, self.offset());
- sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_PAYLOAD_LEN as u16, self.len());
- sys::nftnl_expr_set_u32(
- expr,
- sys::NFTNL_EXPR_PAYLOAD_DREG as u16,
- libc::NFT_REG_1 as u32,
- );
-
- expr
- }
- }
+pub trait HeaderField {
+ fn offset(&self) -> u32;
+ fn len(&self) -> u32;
}
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
@@ -154,58 +122,52 @@ impl HeaderField for LLHeaderField {
}
impl LLHeaderField {
- pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> {
- let off = data.offset;
- let len = data.len;
-
- if off == 0 && len == 6 {
- Ok(Self::Daddr)
- } else if off == 6 && len == 6 {
- Ok(Self::Saddr)
- } else if off == 12 && len == 2 {
- Ok(Self::EtherType)
- } else {
- Err(DeserializationError::InvalidValue)
- }
+ pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> {
+ Ok(match (offset, len) {
+ (0, 6) => Self::Daddr,
+ (6, 6) => Self::Saddr,
+ (12, 2) => Self::EtherType,
+ _ => return Err(DecodeError::UnknownLinkLayerHeaderField(offset, len)),
+ })
}
}
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum NetworkHeaderField {
- Ipv4(Ipv4HeaderField),
- Ipv6(Ipv6HeaderField),
+ IPv4(IPv4HeaderField),
+ IPv6(IPv6HeaderField),
}
impl HeaderField for NetworkHeaderField {
fn offset(&self) -> u32 {
use self::NetworkHeaderField::*;
match *self {
- Ipv4(ref f) => f.offset(),
- Ipv6(ref f) => f.offset(),
+ IPv4(ref f) => f.offset(),
+ IPv6(ref f) => f.offset(),
}
}
fn len(&self) -> u32 {
use self::NetworkHeaderField::*;
match *self {
- Ipv4(ref f) => f.len(),
- Ipv6(ref f) => f.len(),
+ IPv4(ref f) => f.len(),
+ IPv6(ref f) => f.len(),
}
}
}
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
-pub enum Ipv4HeaderField {
+pub enum IPv4HeaderField {
Ttl,
Protocol,
Saddr,
Daddr,
}
-impl HeaderField for Ipv4HeaderField {
+impl HeaderField for IPv4HeaderField {
fn offset(&self) -> u32 {
- use self::Ipv4HeaderField::*;
+ use self::IPv4HeaderField::*;
match *self {
Ttl => 8,
Protocol => 9,
@@ -215,7 +177,7 @@ impl HeaderField for Ipv4HeaderField {
}
fn len(&self) -> u32 {
- use self::Ipv4HeaderField::*;
+ use self::IPv4HeaderField::*;
match *self {
Ttl => 1,
Protocol => 1,
@@ -225,37 +187,30 @@ impl HeaderField for Ipv4HeaderField {
}
}
-impl Ipv4HeaderField {
- pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> {
- let off = data.offset;
- let len = data.len;
-
- if off == 8 && len == 1 {
- Ok(Self::Ttl)
- } else if off == 9 && len == 1 {
- Ok(Self::Protocol)
- } else if off == 12 && len == 4 {
- Ok(Self::Saddr)
- } else if off == 16 && len == 4 {
- Ok(Self::Daddr)
- } else {
- Err(DeserializationError::InvalidValue)
- }
+impl IPv4HeaderField {
+ pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> {
+ Ok(match (offset, len) {
+ (8, 1) => Self::Ttl,
+ (9, 1) => Self::Protocol,
+ (12, 4) => Self::Saddr,
+ (16, 4) => Self::Daddr,
+ _ => return Err(DecodeError::UnknownIPv4HeaderField(offset, len)),
+ })
}
}
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
-pub enum Ipv6HeaderField {
+pub enum IPv6HeaderField {
NextHeader,
HopLimit,
Saddr,
Daddr,
}
-impl HeaderField for Ipv6HeaderField {
+impl HeaderField for IPv6HeaderField {
fn offset(&self) -> u32 {
- use self::Ipv6HeaderField::*;
+ use self::IPv6HeaderField::*;
match *self {
NextHeader => 6,
HopLimit => 7,
@@ -265,7 +220,7 @@ impl HeaderField for Ipv6HeaderField {
}
fn len(&self) -> u32 {
- use self::Ipv6HeaderField::*;
+ use self::IPv6HeaderField::*;
match *self {
NextHeader => 1,
HopLimit => 1,
@@ -275,31 +230,24 @@ impl HeaderField for Ipv6HeaderField {
}
}
-impl Ipv6HeaderField {
- pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> {
- let off = data.offset;
- let len = data.len;
-
- if off == 6 && len == 1 {
- Ok(Self::NextHeader)
- } else if off == 7 && len == 1 {
- Ok(Self::HopLimit)
- } else if off == 8 && len == 16 {
- Ok(Self::Saddr)
- } else if off == 24 && len == 16 {
- Ok(Self::Daddr)
- } else {
- Err(DeserializationError::InvalidValue)
- }
+impl IPv6HeaderField {
+ pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> {
+ Ok(match (offset, len) {
+ (6, 1) => Self::NextHeader,
+ (7, 1) => Self::HopLimit,
+ (8, 16) => Self::Saddr,
+ (24, 16) => Self::Daddr,
+ _ => return Err(DecodeError::UnknownIPv6HeaderField(offset, len)),
+ })
}
}
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum TransportHeaderField {
- Tcp(TcpHeaderField),
- Udp(UdpHeaderField),
- Icmpv6(Icmpv6HeaderField),
+ Tcp(TCPHeaderField),
+ Udp(UDPHeaderField),
+ ICMPv6(ICMPv6HeaderField),
}
impl HeaderField for TransportHeaderField {
@@ -308,7 +256,7 @@ impl HeaderField for TransportHeaderField {
match *self {
Tcp(ref f) => f.offset(),
Udp(ref f) => f.offset(),
- Icmpv6(ref f) => f.offset(),
+ ICMPv6(ref f) => f.offset(),
}
}
@@ -317,21 +265,21 @@ impl HeaderField for TransportHeaderField {
match *self {
Tcp(ref f) => f.len(),
Udp(ref f) => f.len(),
- Icmpv6(ref f) => f.len(),
+ ICMPv6(ref f) => f.len(),
}
}
}
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
-pub enum TcpHeaderField {
+pub enum TCPHeaderField {
Sport,
Dport,
}
-impl HeaderField for TcpHeaderField {
+impl HeaderField for TCPHeaderField {
fn offset(&self) -> u32 {
- use self::TcpHeaderField::*;
+ use self::TCPHeaderField::*;
match *self {
Sport => 0,
Dport => 2,
@@ -339,7 +287,7 @@ impl HeaderField for TcpHeaderField {
}
fn len(&self) -> u32 {
- use self::TcpHeaderField::*;
+ use self::TCPHeaderField::*;
match *self {
Sport => 2,
Dport => 2,
@@ -347,32 +295,27 @@ impl HeaderField for TcpHeaderField {
}
}
-impl TcpHeaderField {
- pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> {
- let off = data.offset;
- let len = data.len;
-
- if off == 0 && len == 2 {
- Ok(Self::Sport)
- } else if off == 2 && len == 2 {
- Ok(Self::Dport)
- } else {
- Err(DeserializationError::InvalidValue)
- }
+impl TCPHeaderField {
+ pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> {
+ Ok(match (offset, len) {
+ (0, 2) => Self::Sport,
+ (2, 2) => Self::Dport,
+ _ => return Err(DecodeError::UnknownTCPHeaderField(offset, len)),
+ })
}
}
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
-pub enum UdpHeaderField {
+pub enum UDPHeaderField {
Sport,
Dport,
Len,
}
-impl HeaderField for UdpHeaderField {
+impl HeaderField for UDPHeaderField {
fn offset(&self) -> u32 {
- use self::UdpHeaderField::*;
+ use self::UDPHeaderField::*;
match *self {
Sport => 0,
Dport => 2,
@@ -381,7 +324,7 @@ impl HeaderField for UdpHeaderField {
}
fn len(&self) -> u32 {
- use self::UdpHeaderField::*;
+ use self::UDPHeaderField::*;
match *self {
Sport => 2,
Dport => 2,
@@ -390,34 +333,28 @@ impl HeaderField for UdpHeaderField {
}
}
-impl UdpHeaderField {
- pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> {
- let off = data.offset;
- let len = data.len;
-
- if off == 0 && len == 2 {
- Ok(Self::Sport)
- } else if off == 2 && len == 2 {
- Ok(Self::Dport)
- } else if off == 4 && len == 2 {
- Ok(Self::Len)
- } else {
- Err(DeserializationError::InvalidValue)
- }
+impl UDPHeaderField {
+ pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> {
+ Ok(match (offset, len) {
+ (0, 2) => Self::Sport,
+ (2, 2) => Self::Dport,
+ (4, 2) => Self::Len,
+ _ => return Err(DecodeError::UnknownUDPHeaderField(offset, len)),
+ })
}
}
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
-pub enum Icmpv6HeaderField {
+pub enum ICMPv6HeaderField {
Type,
Code,
Checksum,
}
-impl HeaderField for Icmpv6HeaderField {
+impl HeaderField for ICMPv6HeaderField {
fn offset(&self) -> u32 {
- use self::Icmpv6HeaderField::*;
+ use self::ICMPv6HeaderField::*;
match *self {
Type => 0,
Code => 1,
@@ -426,7 +363,7 @@ impl HeaderField for Icmpv6HeaderField {
}
fn len(&self) -> u32 {
- use self::Icmpv6HeaderField::*;
+ use self::ICMPv6HeaderField::*;
match *self {
Type => 1,
Code => 1,
@@ -435,97 +372,13 @@ impl HeaderField for Icmpv6HeaderField {
}
}
-impl Icmpv6HeaderField {
- pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> {
- let off = data.offset;
- let len = data.len;
-
- if off == 0 && len == 1 {
- Ok(Self::Type)
- } else if off == 1 && len == 1 {
- Ok(Self::Code)
- } else if off == 2 && len == 2 {
- Ok(Self::Checksum)
- } else {
- Err(DeserializationError::InvalidValue)
- }
+impl ICMPv6HeaderField {
+ pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> {
+ Ok(match (offset, len) {
+ (0, 1) => Self::Type,
+ (1, 1) => Self::Code,
+ (2, 2) => Self::Checksum,
+ _ => return Err(DecodeError::UnknownICMPv6HeaderField(offset, len)),
+ })
}
}
-
-#[macro_export(local_inner_macros)]
-macro_rules! nft_expr_payload {
- (@ipv4_field ttl) => {
- $crate::expr::Ipv4HeaderField::Ttl
- };
- (@ipv4_field protocol) => {
- $crate::expr::Ipv4HeaderField::Protocol
- };
- (@ipv4_field saddr) => {
- $crate::expr::Ipv4HeaderField::Saddr
- };
- (@ipv4_field daddr) => {
- $crate::expr::Ipv4HeaderField::Daddr
- };
-
- (@ipv6_field nextheader) => {
- $crate::expr::Ipv6HeaderField::NextHeader
- };
- (@ipv6_field hoplimit) => {
- $crate::expr::Ipv6HeaderField::HopLimit
- };
- (@ipv6_field saddr) => {
- $crate::expr::Ipv6HeaderField::Saddr
- };
- (@ipv6_field daddr) => {
- $crate::expr::Ipv6HeaderField::Daddr
- };
-
- (@tcp_field sport) => {
- $crate::expr::TcpHeaderField::Sport
- };
- (@tcp_field dport) => {
- $crate::expr::TcpHeaderField::Dport
- };
-
- (@udp_field sport) => {
- $crate::expr::UdpHeaderField::Sport
- };
- (@udp_field dport) => {
- $crate::expr::UdpHeaderField::Dport
- };
- (@udp_field len) => {
- $crate::expr::UdpHeaderField::Len
- };
-
- (ethernet daddr) => {
- $crate::expr::Payload::LinkLayer($crate::expr::LLHeaderField::Daddr)
- };
- (ethernet saddr) => {
- $crate::expr::Payload::LinkLayer($crate::expr::LLHeaderField::Saddr)
- };
- (ethernet ethertype) => {
- $crate::expr::Payload::LinkLayer($crate::expr::LLHeaderField::EtherType)
- };
-
- (ipv4 $field:ident) => {
- $crate::expr::Payload::Network($crate::expr::NetworkHeaderField::Ipv4(
- nft_expr_payload!(@ipv4_field $field),
- ))
- };
- (ipv6 $field:ident) => {
- $crate::expr::Payload::Network($crate::expr::NetworkHeaderField::Ipv6(
- nft_expr_payload!(@ipv6_field $field),
- ))
- };
-
- (tcp $field:ident) => {
- $crate::expr::Payload::Transport($crate::expr::TransportHeaderField::Tcp(
- nft_expr_payload!(@tcp_field $field),
- ))
- };
- (udp $field:ident) => {
- $crate::expr::Payload::Transport($crate::expr::TransportHeaderField::Udp(
- nft_expr_payload!(@udp_field $field),
- ))
- };
-}
diff --git a/src/expr/register.rs b/src/expr/register.rs
index a05af7e..9cc1bee 100644
--- a/src/expr/register.rs
+++ b/src/expr/register.rs
@@ -1,34 +1,17 @@
use std::fmt::Debug;
-use crate::sys::libc;
+use rustables_macros::nfnetlink_enum;
-use super::DeserializationError;
+use crate::sys::{NFT_REG_1, NFT_REG_2, NFT_REG_3, NFT_REG_4, NFT_REG_VERDICT};
/// A netfilter data register. The expressions store and read data to and from these when
/// evaluating rule statements.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
-#[repr(i32)]
+#[nfnetlink_enum(u32)]
pub enum Register {
- Verdict = libc::NFT_REG_VERDICT,
- Reg1 = libc::NFT_REG_1,
- Reg2 = libc::NFT_REG_2,
- Reg3 = libc::NFT_REG_3,
- Reg4 = libc::NFT_REG_4,
-}
-
-impl Register {
- pub fn to_raw(self) -> u32 {
- self as u32
- }
-
- pub fn from_raw(val: u32) -> Result<Self, DeserializationError> {
- match val as i32 {
- libc::NFT_REG_VERDICT => Ok(Self::Verdict),
- libc::NFT_REG_1 => Ok(Self::Reg1),
- libc::NFT_REG_2 => Ok(Self::Reg2),
- libc::NFT_REG_3 => Ok(Self::Reg3),
- libc::NFT_REG_4 => Ok(Self::Reg4),
- _ => Err(DeserializationError::InvalidValue),
- }
- }
+ Verdict = NFT_REG_VERDICT,
+ Reg1 = NFT_REG_1,
+ Reg2 = NFT_REG_2,
+ Reg3 = NFT_REG_3,
+ Reg4 = NFT_REG_4,
}
diff --git a/src/expr/reject.rs b/src/expr/reject.rs
index 19752ce..83fd843 100644
--- a/src/expr/reject.rs
+++ b/src/expr/reject.rs
@@ -1,95 +1,40 @@
-use super::{DeserializationError, Expression, Rule};
-use crate::ProtoFamily;
-use crate::sys::{self, libc::{self, c_char}};
+use rustables_macros::{nfnetlink_enum, nfnetlink_struct};
-/// A reject expression that defines the type of rejection message sent when discarding a packet.
-#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
-pub enum Reject {
- /// Returns an ICMP unreachable packet.
- Icmp(IcmpCode),
- /// Rejects by sending a TCP RST packet.
- TcpRst,
-}
+use crate::sys;
-impl Reject {
- fn to_raw(&self, family: ProtoFamily) -> u32 {
- use libc::*;
- let value = match *self {
- Self::Icmp(..) => match family {
- ProtoFamily::Bridge | ProtoFamily::Inet => NFT_REJECT_ICMPX_UNREACH,
- _ => NFT_REJECT_ICMP_UNREACH,
- },
- Self::TcpRst => NFT_REJECT_TCP_RST,
- };
- value as u32
- }
-}
+use super::Expression;
impl Expression for Reject {
- fn get_raw_name() -> *const libc::c_char {
- b"reject\0" as *const _ as *const c_char
+ fn get_name() -> &'static str {
+ "reject"
}
+}
- fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError>
- where
- Self: Sized,
- {
- unsafe {
- if sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_REJECT_TYPE as u16)
- == libc::NFT_REJECT_TCP_RST as u32
- {
- Ok(Self::TcpRst)
- } else {
- Ok(Self::Icmp(IcmpCode::from_raw(sys::nftnl_expr_get_u8(
- expr,
- sys::NFTNL_EXPR_REJECT_CODE as u16,
- ))?))
- }
- }
- }
-
- fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr {
- let family = rule.get_chain().get_table().get_family();
-
- unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name()));
-
- sys::nftnl_expr_set_u32(
- expr,
- sys::NFTNL_EXPR_REJECT_TYPE as u16,
- self.to_raw(family),
- );
-
- let reject_code = match *self {
- Reject::Icmp(code) => code as u8,
- Reject::TcpRst => 0,
- };
-
- sys::nftnl_expr_set_u8(expr, sys::NFTNL_EXPR_REJECT_CODE as u16, reject_code);
-
- expr
- }
- }
+#[derive(Clone, PartialEq, Eq, Default, Debug)]
+#[nfnetlink_struct]
+/// A reject expression that defines the type of rejection message sent when discarding a packet.
+pub struct Reject {
+ #[field(sys::NFTA_REJECT_TYPE, name_in_functions = "type")]
+ reject_type: RejectType,
+ #[field(sys::NFTA_REJECT_ICMP_CODE)]
+ icmp_code: IcmpCode,
}
/// An ICMP reject code.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
-#[repr(u8)]
-pub enum IcmpCode {
- NoRoute = libc::NFT_REJECT_ICMPX_NO_ROUTE as u8,
- PortUnreach = libc::NFT_REJECT_ICMPX_PORT_UNREACH as u8,
- HostUnreach = libc::NFT_REJECT_ICMPX_HOST_UNREACH as u8,
- AdminProhibited = libc::NFT_REJECT_ICMPX_ADMIN_PROHIBITED as u8,
+#[nfnetlink_enum(u32)]
+pub enum RejectType {
+ IcmpUnreach = sys::NFT_REJECT_ICMP_UNREACH,
+ TcpRst = sys::NFT_REJECT_TCP_RST,
+ IcmpxUnreach = sys::NFT_REJECT_ICMPX_UNREACH,
}
-impl IcmpCode {
- fn from_raw(code: u8) -> Result<Self, DeserializationError> {
- match code as i32 {
- libc::NFT_REJECT_ICMPX_NO_ROUTE => Ok(Self::NoRoute),
- libc::NFT_REJECT_ICMPX_PORT_UNREACH => Ok(Self::PortUnreach),
- libc::NFT_REJECT_ICMPX_HOST_UNREACH => Ok(Self::HostUnreach),
- libc::NFT_REJECT_ICMPX_ADMIN_PROHIBITED => Ok(Self::AdminProhibited),
- _ => Err(DeserializationError::InvalidValue),
- }
- }
+/// An ICMP reject code.
+#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
+#[nfnetlink_enum(u8)]
+pub enum IcmpCode {
+ NoRoute = sys::NFT_REJECT_ICMPX_NO_ROUTE,
+ PortUnreach = sys::NFT_REJECT_ICMPX_PORT_UNREACH,
+ HostUnreach = sys::NFT_REJECT_ICMPX_HOST_UNREACH,
+ AdminProhibited = sys::NFT_REJECT_ICMPX_ADMIN_PROHIBITED,
}
diff --git a/src/expr/verdict.rs b/src/expr/verdict.rs
index 3c4c374..7edf7cd 100644
--- a/src/expr/verdict.rs
+++ b/src/expr/verdict.rs
@@ -1,11 +1,39 @@
-use super::{DeserializationError, Expression, Rule};
-use crate::sys::{self, libc::{self, c_char}};
-use std::ffi::{CStr, CString};
+use std::fmt::Debug;
+
+use libc::{NF_ACCEPT, NF_DROP, NF_QUEUE};
+use rustables_macros::{nfnetlink_enum, nfnetlink_struct};
+
+use crate::sys::{
+ NFTA_VERDICT_CHAIN, NFTA_VERDICT_CHAIN_ID, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE,
+ NFT_GOTO, NFT_JUMP, NFT_RETURN,
+};
+
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
+#[nfnetlink_enum(i32)]
+pub enum VerdictType {
+ Drop = NF_DROP,
+ Accept = NF_ACCEPT,
+ Queue = NF_QUEUE,
+ Continue = NFT_CONTINUE,
+ Break = NFT_BREAK,
+ Jump = NFT_JUMP,
+ Goto = NFT_GOTO,
+ Return = NFT_RETURN,
+}
+
+#[derive(Clone, PartialEq, Eq, Default, Debug)]
+#[nfnetlink_struct(nested = true)]
+pub struct Verdict {
+ #[field(NFTA_VERDICT_CODE)]
+ code: VerdictType,
+ #[field(NFTA_VERDICT_CHAIN)]
+ chain: String,
+ #[field(NFTA_VERDICT_CHAIN_ID)]
+ chain_id: u32,
+}
-/// A verdict expression. In the background, this is usually an "Immediate" expression in nftnl
-/// terms, but here it is simplified to only represent a verdict.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
-pub enum Verdict {
+pub enum VerdictKind {
/// Silently drop the packet.
Drop,
/// Accept the packet and let it pass.
@@ -14,135 +42,10 @@ pub enum Verdict {
Continue,
Break,
Jump {
- chain: CString,
+ chain: String,
},
Goto {
- chain: CString,
+ chain: String,
},
Return,
}
-
-impl Verdict {
- fn chain(&self) -> Option<&CStr> {
- match *self {
- Verdict::Jump { ref chain } => Some(chain.as_c_str()),
- Verdict::Goto { ref chain } => Some(chain.as_c_str()),
- _ => None,
- }
- }
-}
-
-impl Expression for Verdict {
- fn get_raw_name() -> *const libc::c_char {
- b"immediate\0" as *const _ as *const c_char
- }
-
- fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> {
- unsafe {
- let mut chain = None;
- if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16) {
- let raw_chain = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16);
-
- if raw_chain.is_null() {
- return Err(DeserializationError::NullPointer);
- }
- chain = Some(CStr::from_ptr(raw_chain).to_owned());
- }
-
- let verdict = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_IMM_VERDICT as u16);
-
- match verdict as i32 {
- libc::NF_DROP => Ok(Verdict::Drop),
- libc::NF_ACCEPT => Ok(Verdict::Accept),
- libc::NF_QUEUE => Ok(Verdict::Queue),
- libc::NFT_CONTINUE => Ok(Verdict::Continue),
- libc::NFT_BREAK => Ok(Verdict::Break),
- libc::NFT_JUMP => {
- if let Some(chain) = chain {
- Ok(Verdict::Jump { chain })
- } else {
- Err(DeserializationError::InvalidValue)
- }
- }
- libc::NFT_GOTO => {
- if let Some(chain) = chain {
- Ok(Verdict::Goto { chain })
- } else {
- Err(DeserializationError::InvalidValue)
- }
- }
- libc::NFT_RETURN => Ok(Verdict::Return),
- _ => Err(DeserializationError::InvalidValue),
- }
- }
- }
-
- fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
- let immediate_const = match *self {
- Verdict::Drop => libc::NF_DROP,
- Verdict::Accept => libc::NF_ACCEPT,
- Verdict::Queue => libc::NF_QUEUE,
- Verdict::Continue => libc::NFT_CONTINUE,
- Verdict::Break => libc::NFT_BREAK,
- Verdict::Jump { .. } => libc::NFT_JUMP,
- Verdict::Goto { .. } => libc::NFT_GOTO,
- Verdict::Return => libc::NFT_RETURN,
- };
- unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(
- b"immediate\0" as *const _ as *const c_char
- ));
-
- sys::nftnl_expr_set_u32(
- expr,
- sys::NFTNL_EXPR_IMM_DREG as u16,
- libc::NFT_REG_VERDICT as u32,
- );
-
- if let Some(chain) = self.chain() {
- sys::nftnl_expr_set_str(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16, chain.as_ptr());
- }
- sys::nftnl_expr_set_u32(
- expr,
- sys::NFTNL_EXPR_IMM_VERDICT as u16,
- immediate_const as u32,
- );
-
- expr
- }
- }
-}
-
-#[macro_export]
-macro_rules! nft_expr_verdict {
- (drop) => {
- $crate::expr::Verdict::Drop
- };
- (accept) => {
- $crate::expr::Verdict::Accept
- };
- (reject icmp $code:expr) => {
- $crate::expr::Verdict::Reject(RejectionType::Icmp($code))
- };
- (reject tcp-rst) => {
- $crate::expr::Verdict::Reject(RejectionType::TcpRst)
- };
- (queue) => {
- $crate::expr::Verdict::Queue
- };
- (continue) => {
- $crate::expr::Verdict::Continue
- };
- (break) => {
- $crate::expr::Verdict::Break
- };
- (jump $chain:expr) => {
- $crate::expr::Verdict::Jump { chain: $chain }
- };
- (goto $chain:expr) => {
- $crate::expr::Verdict::Goto { chain: $chain }
- };
- (return) => {
- $crate::expr::Verdict::Return
- };
-}
diff --git a/src/expr/wrapper.rs b/src/expr/wrapper.rs
deleted file mode 100644
index 12ef60b..0000000
--- a/src/expr/wrapper.rs
+++ /dev/null
@@ -1,61 +0,0 @@
-use std::ffi::CStr;
-use std::ffi::CString;
-use std::fmt::Debug;
-use std::rc::Rc;
-use std::os::raw::c_char;
-
-use super::{DeserializationError, Expression};
-use crate::{sys, Rule};
-
-pub struct ExpressionWrapper {
- pub(crate) expr: *const sys::nftnl_expr,
- // we also need the rule here to ensure that the rule lives as long as the `expr` pointer
- #[allow(dead_code)]
- pub(crate) rule: Rc<Rule>,
-}
-
-impl Debug for ExpressionWrapper {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{:?}", self.get_str())
- }
-}
-
-impl ExpressionWrapper {
- /// Retrieves a textual description of the expression.
- pub fn get_str(&self) -> CString {
- let mut descr_buf = vec![0i8; 4096];
- unsafe {
- sys::nftnl_expr_snprintf(
- descr_buf.as_mut_ptr() as *mut c_char,
- (descr_buf.len() - 1) as u64,
- self.expr,
- sys::NFTNL_OUTPUT_DEFAULT,
- 0,
- );
- CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned()
- }
- }
-
- /// Retrieves the type of expression ("log", "counter", ...).
- pub fn get_kind(&self) -> Option<&CStr> {
- unsafe {
- let ptr = sys::nftnl_expr_get_str(self.expr, sys::NFTNL_EXPR_NAME as u16);
- if !ptr.is_null() {
- Some(CStr::from_ptr(ptr))
- } else {
- None
- }
- }
- }
-
- /// Attempts to decode the expression as the type T.
- pub fn decode_expr<T: Expression>(&self) -> Result<T, DeserializationError> {
- if let Some(kind) = self.get_kind() {
- let raw_name = unsafe { CStr::from_ptr(T::get_raw_name()) };
- if kind == raw_name {
- return T::from_expr(self.expr);
- }
- }
- Err(DeserializationError::InvalidExpressionKind)
- }
-}
diff --git a/src/lib.rs b/src/lib.rs
index fbb96f3..dec5b76 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,4 +1,4 @@
-// Copyryght (c) 2021 GPL lafleur@boum.org and Simon Thoby
+// Copyryght (c) 2021-2022 GPL lafleur@boum.org and Simon Thoby
//
// This file is free software: you may copy, redistribute and/or modify it
// under the terms of the GNU General Public License as published by the
@@ -24,106 +24,70 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.
-//! Safe abstraction for [`libnftnl`]. Provides userspace access to the in-kernel nf_tables
-//! subsystem. Can be used to create and remove tables, chains, sets and rules from the nftables
+//! Safe abstraction for userspace access to the in-kernel nf_tables subsystem.
+//! Can be used to create and remove tables, chains, sets and rules from the nftables
//! firewall, the successor to iptables.
//!
//! This library currently has quite rough edges and does not make adding and removing netfilter
//! entries super easy and elegant. That is partly because the library needs more work, but also
//! partly because nftables is super low level and extremely customizable, making it hard, and
//! probably wrong, to try and create a too simple/limited wrapper. See examples for inspiration.
-//! One can also look at how the original project this crate was developed to support uses it:
-//! [Mullvad VPN app](https://github.com/mullvad/mullvadvpn-app)
//!
-//! Understanding how to use [`libnftnl`] and implementing this crate has mostly been done by
-//! reading the source code for the [`nftables`] program and attaching debuggers to the `nft`
-//! binary. Since the implementation is mostly based on trial and error, there might of course be
-//! a number of places where the underlying library is used in an invalid or not intended way.
-//! Large portions of [`libnftnl`] are also not covered yet. Contributions are welcome!
+//! Understanding how to use the netlink subsystem and implementing this crate has mostly been done by
+//! reading the source code for the [`nftables`] userspace program and its corresponding kernel code,
+//! as well as attaching debuggers to the `nft` binary.
+//! Since the implementation is mostly based on trial and error, there might of course be
+//! a number of places where the forged netlink messages are used in an invalid or not intended way.
+//! Contributions are welcome!
//!
-//! # Supported versions of `libnftnl`
-//!
-//! This crate will automatically link to the currently installed version of libnftnl upon build.
-//! It requires libnftnl version 1.0.6 or higher. See how the low level FFI bindings to the C
-//! library are generated in [`build.rs`].
-//!
-//! # Access to raw handles
-//!
-//! Retrieving raw handles is considered unsafe and should only ever be enabled if you absolutely
-//! need it. It is disabled by default and hidden behind the feature gate `unsafe-raw-handles`.
-//! The reason for that special treatment is we cannot guarantee the lack of aliasing. For
-//! example, a program using a const handle to a object in a thread and writing through a mutable
-//! handle in another could reach all kind of undefined (and dangerous!) behaviors. By enabling
-//! that feature flag, you acknowledge that guaranteeing the respect of safety invariants is now
-//! your responsibility! Despite these shortcomings, that feature is still available because it
-//! may allow you to perform manipulations that this library doesn't currently expose. If that is
-//! your case, we would be very happy to hear from you and maybe help you get the necessary
-//! functionality upstream.
-//!
-//! Our current lack of confidence in our availability to provide a safe abstraction over the use
-//! of raw handles in the face of concurrency is the reason we decided to settly on `Rc` pointers
-//! instead of `Arc` (besides, this should gives us some nice performance boost, not that it
-//! matters much of course) and why we do not declare the types exposed by the library as `Send`
-//! nor `Sync`.
-//!
-//! [`libnftnl`]: https://netfilter.org/projects/libnftnl/
//! [`nftables`]: https://netfilter.org/projects/nftables/
-//! [`build.rs`]: https://gitlab.com/rustwall/rustables/-/blob/master/build.rs
-
-use thiserror::Error;
#[macro_use]
extern crate log;
-pub mod sys;
-use std::{convert::TryFrom, ffi::c_void, ops::Deref};
-use sys::libc;
-
-macro_rules! try_alloc {
- ($e:expr) => {{
- let ptr = $e;
- if ptr.is_null() {
- // OOM, and the tried allocation was likely very small,
- // so we are in a very tight situation. We do what libstd does, aborts.
- std::process::abort();
- }
- ptr
- }};
-}
+use libc;
+
+use rustables_macros::nfnetlink_enum;
+use std::convert::TryFrom;
mod batch;
-#[cfg(feature = "query")]
-pub use batch::{batch_is_supported, default_batch_page_size};
-pub use batch::{Batch, FinalizedBatch, NetlinkError};
+pub use batch::{default_batch_page_size, Batch};
-pub mod expr;
+pub mod data_type;
-pub mod table;
+mod table;
+pub use table::list_tables;
pub use table::Table;
-#[cfg(feature = "query")]
-pub use table::{get_tables_cb, list_tables};
mod chain;
-#[cfg(feature = "query")]
-pub use chain::{get_chains_cb, list_chains_for_table};
-pub use chain::{Chain, ChainType, Hook, Policy, Priority};
+pub use chain::list_chains_for_table;
+pub use chain::{Chain, ChainPolicy, ChainPriority, ChainType, Hook, HookClass};
-mod chain_methods;
-pub use chain_methods::ChainMethods;
+pub mod error;
pub mod query;
+pub(crate) mod nlmsg;
+pub(crate) mod parser;
+pub(crate) mod parser_impls;
+
mod rule;
+pub use rule::list_rules_for_chain;
pub use rule::Rule;
-#[cfg(feature = "query")]
-pub use rule::{get_rules_cb, list_rules_for_chain};
+
+pub mod expr;
mod rule_methods;
-pub use rule_methods::{iface_index, Protocol, RuleMethods, Error as MatchError};
+pub use rule_methods::{iface_index, Protocol};
pub mod set;
pub use set::Set;
+pub mod sys;
+
+#[cfg(test)]
+mod tests;
+
/// The type of the message as it's sent to netfilter. A message consists of an object, such as a
/// [`Table`], [`Chain`] or [`Rule`] for example, and a [`MsgType`] to describe what to do with
/// that object. If a [`Table`] object is sent with `MsgType::Add` then that table will be added
@@ -133,7 +97,7 @@ pub use set::Set;
/// [`Chain`]: struct.Chain.html
/// [`Rule`]: struct.Rule.html
/// [`MsgType`]: enum.MsgType.html
-#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum MsgType {
/// Add the object to netfilter.
Add,
@@ -142,69 +106,22 @@ pub enum MsgType {
}
/// Denotes a protocol. Used to specify which protocol a table or set belongs to.
-#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
-#[repr(u16)]
-pub enum ProtoFamily {
- Unspec = libc::NFPROTO_UNSPEC as u16,
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+#[nfnetlink_enum(i32)]
+pub enum ProtocolFamily {
+ Unspec = libc::NFPROTO_UNSPEC,
/// Inet - Means both IPv4 and IPv6
- Inet = libc::NFPROTO_INET as u16,
- Ipv4 = libc::NFPROTO_IPV4 as u16,
- Arp = libc::NFPROTO_ARP as u16,
- NetDev = libc::NFPROTO_NETDEV as u16,
- Bridge = libc::NFPROTO_BRIDGE as u16,
- Ipv6 = libc::NFPROTO_IPV6 as u16,
- DecNet = libc::NFPROTO_DECNET as u16,
-}
-#[derive(Error, Debug)]
-#[error("Couldn't find a matching protocol")]
-pub struct InvalidProtocolFamily;
-
-impl TryFrom<i32> for ProtoFamily {
- type Error = InvalidProtocolFamily;
- fn try_from(value: i32) -> Result<Self, Self::Error> {
- match value {
- libc::NFPROTO_UNSPEC => Ok(ProtoFamily::Unspec),
- libc::NFPROTO_INET => Ok(ProtoFamily::Inet),
- libc::NFPROTO_IPV4 => Ok(ProtoFamily::Ipv4),
- libc::NFPROTO_ARP => Ok(ProtoFamily::Arp),
- libc::NFPROTO_NETDEV => Ok(ProtoFamily::NetDev),
- libc::NFPROTO_BRIDGE => Ok(ProtoFamily::Bridge),
- libc::NFPROTO_IPV6 => Ok(ProtoFamily::Ipv6),
- libc::NFPROTO_DECNET => Ok(ProtoFamily::DecNet),
- _ => Err(InvalidProtocolFamily),
- }
- }
+ Inet = libc::NFPROTO_INET,
+ Ipv4 = libc::NFPROTO_IPV4,
+ Arp = libc::NFPROTO_ARP,
+ NetDev = libc::NFPROTO_NETDEV,
+ Bridge = libc::NFPROTO_BRIDGE,
+ Ipv6 = libc::NFPROTO_IPV6,
+ DecNet = libc::NFPROTO_DECNET,
}
-/// Trait for all types in this crate that can serialize to a Netlink message.
-///
-/// # Unsafe
-///
-/// This trait is unsafe to implement because it must never serialize to anything larger than the
-/// largest possible netlink message. Internally the `nft_nlmsg_maxsize()` function is used to
-/// make sure the `buf` pointer passed to `write` always has room for the largest possible Netlink
-/// message.
-pub unsafe trait NlMsg {
- /// Serializes the Netlink message to the buffer at `buf`. `buf` must have space for at least
- /// `nft_nlmsg_maxsize()` bytes. This is not checked by the compiler, which is why this method
- /// is unsafe.
- unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType);
-}
-
-unsafe impl<T, R> NlMsg for T
-where
- T: Deref<Target = R>,
- R: NlMsg,
-{
- unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) {
- self.deref().write(buf, seq, msg_type);
+impl Default for ProtocolFamily {
+ fn default() -> Self {
+ ProtocolFamily::Unspec
}
}
-
-/// The largest nf_tables netlink message is the set element message, which contains the
-/// NFTA_SET_ELEM_LIST_ELEMENTS attribute. This attribute is a nest that describes the set
-/// elements. Given that the netlink attribute length (nla_len) is 16 bits, the largest message is
-/// a bit larger than 64 KBytes.
-pub fn nft_nlmsg_maxsize() -> u32 {
- u32::from(::std::u16::MAX) + unsafe { libc::sysconf(libc::_SC_PAGESIZE) } as u32
-}
diff --git a/src/nlmsg.rs b/src/nlmsg.rs
new file mode 100644
index 0000000..1c5b519
--- /dev/null
+++ b/src/nlmsg.rs
@@ -0,0 +1,182 @@
+use std::{fmt::Debug, mem::size_of};
+
+use crate::{
+ error::DecodeError,
+ sys::{
+ nfgenmsg, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END,
+ NFNL_SUBSYS_NFTABLES, NLMSG_ALIGNTO, NLM_F_ACK, NLM_F_CREATE,
+ },
+ MsgType, ProtocolFamily,
+};
+///
+/// The largest nf_tables netlink message is the set element message, which contains the
+/// NFTA_SET_ELEM_LIST_ELEMENTS attribute. This attribute is a nest that describes the set
+/// elements. Given that the netlink attribute length (nla_len) is 16 bits, the largest message is
+/// a bit larger than 64 KBytes.
+pub fn nft_nlmsg_maxsize() -> u32 {
+ u32::from(::std::u16::MAX) + unsafe { libc::sysconf(libc::_SC_PAGESIZE) } as u32
+}
+
+#[inline]
+pub const fn pad_netlink_object_with_variable_size(size: usize) -> usize {
+ // align on a 4 bytes boundary
+ (size + (NLMSG_ALIGNTO as usize - 1)) & !(NLMSG_ALIGNTO as usize - 1)
+}
+
+#[inline]
+pub const fn pad_netlink_object<T>() -> usize {
+ let size = size_of::<T>();
+ pad_netlink_object_with_variable_size(size)
+}
+
+pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 {
+ ((x & 0xff00) >> 8) as u8
+}
+
+pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 {
+ (x & 0x00ff) as u8
+}
+
+pub struct NfNetlinkWriter<'a> {
+ buf: &'a mut Vec<u8>,
+ headers: Option<(usize, usize)>,
+}
+
+impl<'a> NfNetlinkWriter<'a> {
+ pub fn new(buf: &'a mut Vec<u8>) -> NfNetlinkWriter<'a> {
+ NfNetlinkWriter { buf, headers: None }
+ }
+
+ pub fn add_data_zeroed<'b>(&'b mut self, size: usize) -> &'b mut [u8] {
+ let padded_size = pad_netlink_object_with_variable_size(size);
+ let start = self.buf.len();
+ self.buf.resize(start + padded_size, 0);
+
+ if let Some((msghdr_idx, _nfgenmsg_idx)) = self.headers {
+ let mut hdr: &mut nlmsghdr = unsafe {
+ std::mem::transmute(self.buf[msghdr_idx..].as_mut_ptr() as *mut nlmsghdr)
+ };
+ hdr.nlmsg_len += padded_size as u32;
+ }
+
+ &mut self.buf[start..start + size]
+ }
+
+ // rewrite of `__nftnl_nlmsg_build_hdr`
+ pub fn write_header(
+ &mut self,
+ msg_type: u16,
+ family: ProtocolFamily,
+ flags: u16,
+ seq: u32,
+ ressource_id: Option<u16>,
+ ) {
+ if self.headers.is_some() {
+ error!("Calling write_header while still holding headers open!?");
+ }
+
+ let nlmsghdr_len = pad_netlink_object::<nlmsghdr>();
+ let nfgenmsg_len = pad_netlink_object::<nfgenmsg>();
+
+ let nlmsghdr_buf = self.add_data_zeroed(nlmsghdr_len);
+ let mut hdr: &mut nlmsghdr =
+ unsafe { std::mem::transmute(nlmsghdr_buf.as_mut_ptr() as *mut nlmsghdr) };
+ hdr.nlmsg_len = (nlmsghdr_len + nfgenmsg_len) as u32;
+ hdr.nlmsg_type = msg_type;
+ // batch messages are not specific to the nftables subsystem
+ if msg_type != NFNL_MSG_BATCH_BEGIN as u16 && msg_type != NFNL_MSG_BATCH_END as u16 {
+ hdr.nlmsg_type |= (NFNL_SUBSYS_NFTABLES as u16) << 8;
+ }
+ hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16 | flags;
+ hdr.nlmsg_seq = seq;
+
+ let nfgenmsg_buf = self.add_data_zeroed(nfgenmsg_len);
+ let mut nfgenmsg: &mut nfgenmsg =
+ unsafe { std::mem::transmute(nfgenmsg_buf.as_mut_ptr() as *mut nfgenmsg) };
+ nfgenmsg.nfgen_family = family as u8;
+ nfgenmsg.version = NFNETLINK_V0 as u8;
+ nfgenmsg.res_id = ressource_id.unwrap_or(0);
+
+ self.headers = Some((
+ self.buf.len() - (nlmsghdr_len + nfgenmsg_len),
+ self.buf.len() - nfgenmsg_len,
+ ));
+ }
+
+ pub fn finalize_writing_object(&mut self) {
+ self.headers = None;
+ }
+}
+
+pub trait AttributeDecoder {
+ fn decode_attribute(&mut self, attr_type: u16, buf: &[u8]) -> Result<(), DecodeError>;
+}
+
+pub trait NfNetlinkDeserializable: Sized {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError>;
+}
+
+pub trait NfNetlinkObject:
+ Sized + AttributeDecoder + NfNetlinkDeserializable + NfNetlinkAttribute
+{
+ const MSG_TYPE_ADD: u32;
+ const MSG_TYPE_DEL: u32;
+
+ fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32) {
+ let raw_msg_type = match msg_type {
+ MsgType::Add => Self::MSG_TYPE_ADD,
+ MsgType::Del => Self::MSG_TYPE_DEL,
+ } as u16;
+ writer.write_header(
+ raw_msg_type,
+ self.get_family(),
+ (if let MsgType::Add = msg_type {
+ self.get_add_flags()
+ } else {
+ self.get_del_flags()
+ } | NLM_F_ACK) as u16,
+ seq,
+ None,
+ );
+ let buf = writer.add_data_zeroed(self.get_size());
+ unsafe {
+ self.write_payload(buf.as_mut_ptr());
+ }
+ writer.finalize_writing_object();
+ }
+
+ fn get_family(&self) -> ProtocolFamily;
+
+ fn set_family(&mut self, _family: ProtocolFamily) {
+ // the default impl do nothing, because some types are family-agnostic
+ }
+
+ fn with_family(mut self, family: ProtocolFamily) -> Self {
+ self.set_family(family);
+ self
+ }
+
+ fn get_add_flags(&self) -> u32 {
+ NLM_F_CREATE
+ }
+
+ fn get_del_flags(&self) -> u32 {
+ 0
+ }
+}
+
+pub type NetlinkType = u16;
+
+pub trait NfNetlinkAttribute: Debug + Sized {
+ // is it a nested argument that must be marked with a NLA_F_NESTED flag?
+ fn is_nested(&self) -> bool {
+ false
+ }
+
+ fn get_size(&self) -> usize {
+ size_of::<Self>()
+ }
+
+ // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr, self.get_size());
+ unsafe fn write_payload(&self, addr: *mut u8);
+}
diff --git a/src/parser.rs b/src/parser.rs
new file mode 100644
index 0000000..6ea34c1
--- /dev/null
+++ b/src/parser.rs
@@ -0,0 +1,216 @@
+use std::{
+ fmt::{Debug, DebugStruct},
+ mem::{size_of, transmute},
+};
+
+use crate::{
+ error::DecodeError,
+ nlmsg::{
+ get_operation_from_nlmsghdr_type, get_subsystem_from_nlmsghdr_type, pad_netlink_object,
+ pad_netlink_object_with_variable_size, AttributeDecoder, NetlinkType, NfNetlinkAttribute,
+ },
+ sys::{
+ nfgenmsg, nlattr, nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN,
+ NFNL_MSG_BATCH_END, NFNL_SUBSYS_NFTABLES, NLA_F_NESTED, NLA_TYPE_MASK, NLMSG_DONE,
+ NLMSG_ERROR, NLMSG_MIN_TYPE, NLMSG_NOOP, NLM_F_DUMP_INTR,
+ },
+};
+
+pub fn get_nlmsghdr(buf: &[u8]) -> Result<nlmsghdr, DecodeError> {
+ let size_of_hdr = size_of::<nlmsghdr>();
+
+ if buf.len() < size_of_hdr {
+ return Err(DecodeError::BufTooSmall);
+ }
+
+ let nlmsghdr_ptr = buf[0..size_of_hdr].as_ptr() as *const nlmsghdr;
+ let nlmsghdr = unsafe { *nlmsghdr_ptr };
+
+ if nlmsghdr.nlmsg_len as usize > buf.len() || (nlmsghdr.nlmsg_len as usize) < size_of_hdr {
+ return Err(DecodeError::NlMsgTooSmall);
+ }
+
+ if nlmsghdr.nlmsg_flags & NLM_F_DUMP_INTR as u16 != 0 {
+ return Err(DecodeError::ConcurrentGenerationUpdate);
+ }
+
+ Ok(nlmsghdr)
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub enum NlMsg<'a> {
+ Done,
+ Noop,
+ Error(nlmsgerr),
+ NfGenMsg(nfgenmsg, &'a [u8]),
+}
+
+pub fn parse_nlmsg<'a>(buf: &'a [u8]) -> Result<(nlmsghdr, NlMsg<'a>), DecodeError> {
+ // in theory the message is composed of the following parts:
+ // - nlmsghdr (contains the message size and type)
+ // - struct nlmsgerr OR nfgenmsg (nftables header that describes the message family)
+ // - the raw value that we want to validate (if the previous part is nfgenmsg)
+ let hdr = get_nlmsghdr(buf)?;
+
+ let size_of_hdr = pad_netlink_object::<nlmsghdr>();
+
+ if hdr.nlmsg_type < NLMSG_MIN_TYPE as u16 {
+ match hdr.nlmsg_type as u32 {
+ x if x == NLMSG_NOOP => return Ok((hdr, NlMsg::Noop)),
+ x if x == NLMSG_ERROR => {
+ if (hdr.nlmsg_len as usize) < size_of_hdr + size_of::<nlmsgerr>() {
+ return Err(DecodeError::NlMsgTooSmall);
+ }
+ let mut err = unsafe {
+ *(buf[size_of_hdr..size_of_hdr + size_of::<nlmsgerr>()].as_ptr()
+ as *const nlmsgerr)
+ };
+ // some APIs return negative values, while other return positive values
+ err.error = err.error.abs();
+ return Ok((hdr, NlMsg::Error(err)));
+ }
+ x if x == NLMSG_DONE => return Ok((hdr, NlMsg::Done)),
+ x => return Err(DecodeError::UnsupportedType(x as u16)),
+ }
+ }
+
+ // batch messages are not specific to the nftables subsystem
+ if hdr.nlmsg_type != NFNL_MSG_BATCH_BEGIN as u16 && hdr.nlmsg_type != NFNL_MSG_BATCH_END as u16
+ {
+ // verify that we are decoding nftables messages
+ let subsys = get_subsystem_from_nlmsghdr_type(hdr.nlmsg_type);
+ if subsys != NFNL_SUBSYS_NFTABLES as u8 {
+ return Err(DecodeError::InvalidSubsystem(subsys));
+ }
+ }
+
+ let size_of_nfgenmsg = pad_netlink_object::<nfgenmsg>();
+ if hdr.nlmsg_len as usize > buf.len()
+ || (hdr.nlmsg_len as usize) < size_of_hdr + size_of_nfgenmsg
+ {
+ return Err(DecodeError::NlMsgTooSmall);
+ }
+
+ let nfgenmsg_ptr = buf[size_of_hdr..size_of_hdr + size_of_nfgenmsg].as_ptr() as *const nfgenmsg;
+ let nfgenmsg = unsafe { *nfgenmsg_ptr };
+
+ if nfgenmsg.version != NFNETLINK_V0 as u8 {
+ return Err(DecodeError::InvalidVersion(nfgenmsg.version));
+ }
+
+ let raw_value = &buf[size_of_hdr + size_of_nfgenmsg..hdr.nlmsg_len as usize];
+
+ Ok((hdr, NlMsg::NfGenMsg(nfgenmsg, raw_value)))
+}
+
+/// Write the attribute, preceded by a `libc::nlattr`
+// rewrite of `mnl_attr_put`
+pub unsafe fn write_attribute<'a>(
+ ty: NetlinkType,
+ obj: &impl NfNetlinkAttribute,
+ mut buf: *mut u8,
+) {
+ let header_len = pad_netlink_object::<libc::nlattr>();
+ // copy the header
+ *(buf as *mut nlattr) = nlattr {
+ // nla_len contains the header size + the unpadded attribute length
+ nla_len: (header_len + obj.get_size() as usize) as u16,
+ nla_type: if obj.is_nested() {
+ ty | NLA_F_NESTED as u16
+ } else {
+ ty
+ },
+ };
+ buf = buf.offset(pad_netlink_object::<nlattr>() as isize);
+ // copy the attribute data itself
+ obj.write_payload(buf);
+}
+
+pub(crate) fn read_attributes<T: AttributeDecoder + Default>(buf: &[u8]) -> Result<T, DecodeError> {
+ debug!(
+ "Calling <{} as NfNetlinkDeserialize>::deserialize()",
+ std::any::type_name::<T>()
+ );
+ let mut remaining_size = buf.len();
+ let mut pos = 0;
+ let mut res = T::default();
+ while remaining_size > pad_netlink_object::<nlattr>() {
+ let nlattr = unsafe { *transmute::<*const u8, *const nlattr>(buf[pos..].as_ptr()) };
+ // ignore the byteorder and nested attributes
+ let nla_type = nlattr.nla_type & NLA_TYPE_MASK as u16;
+
+ pos += pad_netlink_object::<nlattr>();
+ let attr_remaining_size = nlattr.nla_len as usize - pad_netlink_object::<nlattr>();
+ match T::decode_attribute(&mut res, nla_type, &buf[pos..pos + attr_remaining_size]) {
+ Ok(()) => {}
+ Err(DecodeError::UnsupportedAttributeType(t)) => info!(
+ "Ignoring unsupported attribute type {} for type {}",
+ t,
+ std::any::type_name::<T>()
+ ),
+ Err(e) => return Err(e),
+ }
+ pos += pad_netlink_object_with_variable_size(attr_remaining_size);
+
+ remaining_size -= pad_netlink_object_with_variable_size(nlattr.nla_len as usize);
+ }
+
+ if remaining_size != 0 {
+ Err(DecodeError::InvalidDataSize)
+ } else {
+ Ok(res)
+ }
+}
+
+pub trait InnerFormat {
+ fn inner_format_struct<'a, 'b: 'a>(
+ &'a self,
+ s: DebugStruct<'a, 'b>,
+ ) -> Result<DebugStruct<'a, 'b>, std::fmt::Error>;
+}
+
+pub trait Parsable
+where
+ Self: Sized,
+{
+ fn parse_object(
+ buf: &[u8],
+ add_obj: u32,
+ del_obj: u32,
+ ) -> Result<(Self, nfgenmsg, &[u8]), DecodeError>;
+}
+
+impl<T> Parsable for T
+where
+ T: AttributeDecoder + Default + Sized,
+{
+ fn parse_object(
+ buf: &[u8],
+ add_obj: u32,
+ del_obj: u32,
+ ) -> Result<(Self, nfgenmsg, &[u8]), DecodeError> {
+ debug!("parse_object() started");
+ let (hdr, msg) = parse_nlmsg(buf)?;
+
+ let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as u32;
+
+ if op != add_obj && op != del_obj {
+ return Err(DecodeError::UnexpectedType(hdr.nlmsg_type));
+ }
+
+ let obj_size = hdr.nlmsg_len as usize
+ - pad_netlink_object_with_variable_size(size_of::<nlmsghdr>() + size_of::<nfgenmsg>());
+
+ let remaining_data_offset = pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize);
+ let remaining_data = &buf[remaining_data_offset..];
+
+ let (nfgenmsg, res) = match msg {
+ NlMsg::NfGenMsg(nfgenmsg, content) => {
+ (nfgenmsg, read_attributes(&content[..obj_size])?)
+ }
+ _ => return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)),
+ };
+
+ Ok((res, nfgenmsg, remaining_data))
+ }
+}
diff --git a/src/parser_impls.rs b/src/parser_impls.rs
new file mode 100644
index 0000000..b2681bb
--- /dev/null
+++ b/src/parser_impls.rs
@@ -0,0 +1,243 @@
+use std::{fmt::Debug, mem::transmute};
+
+use rustables_macros::nfnetlink_struct;
+
+use crate::{
+ error::DecodeError,
+ expr::Verdict,
+ nlmsg::{
+ pad_netlink_object, pad_netlink_object_with_variable_size, NfNetlinkAttribute,
+ NfNetlinkDeserializable, NfNetlinkObject,
+ },
+ parser::{write_attribute, Parsable},
+ sys::{nlattr, NFTA_DATA_VALUE, NFTA_DATA_VERDICT, NFTA_LIST_ELEM, NLA_TYPE_MASK},
+ ProtocolFamily,
+};
+
+impl NfNetlinkAttribute for u8 {
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ *addr = *self;
+ }
+}
+
+impl NfNetlinkDeserializable for u8 {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((buf[0], &buf[1..]))
+ }
+}
+
+impl NfNetlinkAttribute for u16 {
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ *(addr as *mut Self) = self.to_be();
+ }
+}
+
+impl NfNetlinkDeserializable for u16 {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((u16::from_be_bytes([buf[0], buf[1]]), &buf[2..]))
+ }
+}
+
+impl NfNetlinkAttribute for i32 {
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ *(addr as *mut Self) = self.to_be();
+ }
+}
+
+impl NfNetlinkDeserializable for i32 {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((
+ i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]),
+ &buf[4..],
+ ))
+ }
+}
+
+impl NfNetlinkAttribute for u32 {
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ *(addr as *mut Self) = self.to_be();
+ }
+}
+
+impl NfNetlinkDeserializable for u32 {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((
+ u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]),
+ &buf[4..],
+ ))
+ }
+}
+
+impl NfNetlinkAttribute for u64 {
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ *(addr as *mut Self) = self.to_be();
+ }
+}
+
+impl NfNetlinkDeserializable for u64 {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((
+ u64::from_be_bytes([
+ buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7],
+ ]),
+ &buf[8..],
+ ))
+ }
+}
+
+impl NfNetlinkAttribute for String {
+ fn get_size(&self) -> usize {
+ self.len()
+ }
+
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ std::ptr::copy_nonoverlapping(self.as_bytes().as_ptr(), addr, self.len());
+ }
+}
+
+impl NfNetlinkDeserializable for String {
+ fn deserialize(mut buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ // ignore the NULL byte terminator, if any
+ if buf.len() > 0 && buf[buf.len() - 1] == 0 {
+ buf = &buf[..buf.len() - 1];
+ }
+ Ok((String::from_utf8(buf.to_vec())?, &[]))
+ }
+}
+
+impl NfNetlinkAttribute for Vec<u8> {
+ fn get_size(&self) -> usize {
+ self.len()
+ }
+
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ std::ptr::copy_nonoverlapping(self.as_ptr(), addr, self.len());
+ }
+}
+
+impl NfNetlinkDeserializable for Vec<u8> {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((buf.to_vec(), &[]))
+ }
+}
+#[derive(Clone, PartialEq, Eq, Default, Debug)]
+#[nfnetlink_struct(nested = true)]
+pub struct NfNetlinkData {
+ #[field(NFTA_DATA_VALUE)]
+ value: Vec<u8>,
+ #[field(NFTA_DATA_VERDICT)]
+ verdict: Verdict,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Default)]
+pub struct NfNetlinkList<T>
+where
+ T: NfNetlinkDeserializable + NfNetlinkAttribute + Debug + Clone + Eq + Default,
+{
+ objs: Vec<T>,
+}
+
+impl<T> NfNetlinkList<T>
+where
+ T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default,
+{
+ pub fn add_value(&mut self, e: impl Into<T>) {
+ self.objs.push(e.into());
+ }
+
+ pub fn with_value(mut self, e: impl Into<T>) -> Self {
+ self.add_value(e);
+ self
+ }
+
+ pub fn iter<'a>(&'a self) -> impl Iterator<Item = &'a T> {
+ self.objs.iter()
+ }
+}
+
+impl<T> NfNetlinkAttribute for NfNetlinkList<T>
+where
+ T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default,
+{
+ fn is_nested(&self) -> bool {
+ true
+ }
+
+ fn get_size(&self) -> usize {
+ // one nlattr LIST_ELEM per object
+ self.objs.iter().fold(0, |acc, item| {
+ acc + item.get_size() + pad_netlink_object::<nlattr>()
+ })
+ }
+
+ unsafe fn write_payload(&self, mut addr: *mut u8) {
+ for item in &self.objs {
+ write_attribute(NFTA_LIST_ELEM, item, addr);
+ addr = addr.offset((pad_netlink_object::<nlattr>() + item.get_size()) as isize);
+ }
+ }
+}
+
+impl<T> NfNetlinkDeserializable for NfNetlinkList<T>
+where
+ T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default,
+{
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ let mut objs = Vec::new();
+
+ let mut pos = 0;
+ while buf.len() - pos > pad_netlink_object::<nlattr>() {
+ let nlattr = unsafe { *transmute::<*const u8, *const nlattr>(buf[pos..].as_ptr()) };
+ // ignore the byteorder and nested attributes
+ let nla_type = nlattr.nla_type & NLA_TYPE_MASK as u16;
+
+ if nla_type != NFTA_LIST_ELEM {
+ return Err(DecodeError::UnsupportedAttributeType(nla_type));
+ }
+
+ let (obj, remaining) = T::deserialize(
+ &buf[pos + pad_netlink_object::<nlattr>()..pos + nlattr.nla_len as usize],
+ )?;
+ if remaining.len() != 0 {
+ return Err(DecodeError::InvalidDataSize);
+ }
+ objs.push(obj);
+
+ pos += pad_netlink_object_with_variable_size(nlattr.nla_len as usize);
+ }
+
+ if pos != buf.len() {
+ Err(DecodeError::InvalidDataSize)
+ } else {
+ Ok((Self { objs }, &[]))
+ }
+ }
+}
+
+impl<O, T> From<Vec<O>> for NfNetlinkList<T>
+where
+ T: From<O>,
+ T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default,
+{
+ fn from(v: Vec<O>) -> Self {
+ NfNetlinkList {
+ objs: v.into_iter().map(T::from).collect(),
+ }
+ }
+}
+
+impl<T> NfNetlinkDeserializable for T
+where
+ T: NfNetlinkObject + Parsable,
+{
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ let (mut obj, nfgenmsg, remaining_data) = Self::parse_object(
+ buf,
+ <T as NfNetlinkObject>::MSG_TYPE_ADD,
+ <T as NfNetlinkObject>::MSG_TYPE_DEL,
+ )?;
+ obj.set_family(ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?);
+
+ Ok((obj, remaining_data))
+ }
+}
diff --git a/src/query.rs b/src/query.rs
index bc1d02e..7cf5050 100644
--- a/src/query.rs
+++ b/src/query.rs
@@ -1,129 +1,178 @@
-use crate::{nft_nlmsg_maxsize, sys, ProtoFamily};
-use sys::libc;
-
-/// Returns a buffer containing a netlink message which requests a list of all the netfilter
-/// matching objects (e.g. tables, chains, rules, ...).
-/// Supply the type of objects to retrieve (e.g. libc::NFT_MSG_GETTABLE), and optionally a callback
-/// to execute on the header, to set parameters for example.
-/// To pass arbitrary data inside that callback, please use a closure.
-pub fn get_list_of_objects<Error>(
- seq: u32,
- target: u16,
- setup_cb: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>,
-) -> Result<Vec<u8>, Error> {
- let mut buffer = vec![0; nft_nlmsg_maxsize() as usize];
- let hdr = unsafe {
- &mut *sys::nftnl_nlmsg_build_hdr(
- buffer.as_mut_ptr() as *mut libc::c_char,
- target,
- ProtoFamily::Unspec as u16,
- (libc::NLM_F_ROOT | libc::NLM_F_MATCH) as u16,
- seq,
- )
- };
- if let Some(cb) = setup_cb {
- cb(hdr)?;
- }
- Ok(buffer)
-}
-
-#[cfg(feature = "query")]
-mod inner {
- use crate::FinalizedBatch;
-
- use super::*;
-
- #[derive(thiserror::Error, Debug)]
- pub enum Error {
- #[error("Unable to open netlink socket to netfilter")]
- NetlinkOpenError(#[source] std::io::Error),
-
- #[error("Unable to send netlink command to netfilter")]
- NetlinkSendError(#[source] std::io::Error),
-
- #[error("Error while reading from netlink socket")]
- NetlinkRecvError(#[source] std::io::Error),
+use std::os::unix::prelude::RawFd;
+
+use nix::sys::socket::{self, AddressFamily, MsgFlags, SockFlag, SockProtocol, SockType};
+
+use crate::{
+ error::QueryError,
+ nlmsg::{
+ nft_nlmsg_maxsize, pad_netlink_object_with_variable_size, NfNetlinkAttribute,
+ NfNetlinkObject, NfNetlinkWriter,
+ },
+ parser::{parse_nlmsg, NlMsg},
+ sys::{NLM_F_DUMP, NLM_F_MULTI},
+ ProtocolFamily,
+};
+
+pub(crate) fn recv_and_process<'a, T>(
+ sock: RawFd,
+ max_seq: Option<u32>,
+ cb: Option<&dyn Fn(&[u8], &mut T) -> Result<(), QueryError>>,
+ working_data: &'a mut T,
+) -> Result<(), QueryError> {
+ let mut msg_buffer = vec![0; 2 * nft_nlmsg_maxsize() as usize];
+ let mut buf_start = 0;
+ let mut end_pos = 0;
+
+ loop {
+ let nb_recv = socket::recv(sock, &mut msg_buffer[end_pos..], MsgFlags::empty())
+ .map_err(QueryError::NetlinkRecvError)?;
+ if nb_recv <= 0 {
+ return Ok(());
+ }
+ end_pos += nb_recv;
+ loop {
+ let buf = &msg_buffer.as_slice()[buf_start..end_pos];
+ // exit the loop and try to receive further messages when we consumed all the buffer
+ if buf.len() == 0 {
+ break;
+ }
- #[error("Error while processing an incoming netlink message")]
- ProcessNetlinkError(#[source] std::io::Error),
+ debug!("Calling parse_nlmsg");
+ let (nlmsghdr, msg) = parse_nlmsg(&buf)?;
+ debug!("Got a valid netlink message: {:?} {:?}", nlmsghdr, msg);
+
+ match msg {
+ NlMsg::Done => {
+ return Ok(());
+ }
+ NlMsg::Error(e) => {
+ if e.error != 0 {
+ return Err(QueryError::NetlinkError(e));
+ }
+ }
+ NlMsg::Noop => {}
+ NlMsg::NfGenMsg(_genmsg, _data) => {
+ if let Some(cb) = cb {
+ cb(&buf[0..nlmsghdr.nlmsg_len as usize], working_data)?;
+ }
+ }
+ }
- #[error("Custom error when customizing the query")]
- InitError(#[from] Box<dyn std::error::Error + 'static>),
+ // we cannot know when a sequence of messages will end if the messages do not end
+ // with an NlMsg::Done marker while if a maximum sequence number wasn't specified
+ if max_seq.is_none() && nlmsghdr.nlmsg_flags & NLM_F_MULTI as u16 == 0 {
+ return Err(QueryError::UndecidableMessageTermination);
+ }
- #[error("Couldn't allocate a netlink object, out of memory ?")]
- NetlinkAllocationFailed,
- }
+ // retrieve the next message
+ if let Some(max_seq) = max_seq {
+ if nlmsghdr.nlmsg_seq >= max_seq {
+ return Ok(());
+ }
+ }
- /// Lists objects of a certain type (e.g. libc::NFT_MSG_GETTABLE) with the help of a helper
- /// function called by mnl::cb_run2.
- /// The callback expects a tuple of additional data (supplied as an argument to this function)
- /// and of the output vector, to which it should append the parsed object it received.
- pub fn list_objects_with_data<'a, A, T>(
- data_type: u16,
- cb: fn(&libc::nlmsghdr, &mut (&'a A, &mut Vec<T>)) -> libc::c_int,
- additional_data: &'a A,
- req_hdr_customize: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>,
- ) -> Result<Vec<T>, Error>
- where
- T: 'a,
- {
- debug!("listing objects of kind {}", data_type);
- let socket = mnl::Socket::new(mnl::Bus::Netfilter).map_err(Error::NetlinkOpenError)?;
-
- let seq = 0;
- let portid = 0;
-
- let chains_buf = get_list_of_objects(seq, data_type, req_hdr_customize)?;
- socket.send(&chains_buf).map_err(Error::NetlinkSendError)?;
-
- let mut res = Vec::new();
-
- let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize];
- while socket
- .recv(&mut msg_buffer)
- .map_err(Error::NetlinkRecvError)?
- > 0
- {
- if let mnl::CbResult::Stop = mnl::cb_run2(
- &msg_buffer,
- seq,
- portid,
- cb,
- &mut (additional_data, &mut res),
- )
- .map_err(Error::ProcessNetlinkError)?
- {
- break;
+ // netlink messages are 4bytes aligned
+ let aligned_length = pad_netlink_object_with_variable_size(nlmsghdr.nlmsg_len as usize);
+ buf_start += aligned_length;
+ }
+ // Ensure that we always have nft_nlmsg_maxsize() free space available in the buffer.
+ // We achieve this by relocating the buffer content at the beginning of the buffer
+ if end_pos >= nft_nlmsg_maxsize() as usize {
+ if buf_start < end_pos {
+ unsafe {
+ std::ptr::copy(
+ msg_buffer[buf_start..end_pos].as_ptr(),
+ msg_buffer.as_mut_ptr(),
+ end_pos - buf_start,
+ );
+ }
}
+ end_pos = end_pos - buf_start;
+ buf_start = 0;
}
-
- Ok(res)
}
+}
- pub fn send_batch(batch: &mut FinalizedBatch) -> Result<(), Error> {
- let socket = mnl::Socket::new(mnl::Bus::Netfilter).map_err(Error::NetlinkOpenError)?;
+pub(crate) fn socket_close_wrapper<E>(
+ sock: RawFd,
+ cb: impl FnOnce(RawFd) -> Result<(), E>,
+) -> Result<(), QueryError>
+where
+ QueryError: From<E>,
+{
+ let ret = cb(sock);
- let seq = 0;
- let portid = socket.portid();
+ // we don't need to shutdown the socket (in fact, Linux doesn't support that operation;
+ // and return EOPNOTSUPP if we try)
+ nix::unistd::close(sock).map_err(QueryError::CloseFailed)?;
- socket.send_all(batch).map_err(Error::NetlinkSendError)?;
- debug!("sent");
+ Ok(ret?)
+}
- let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize];
- while socket
- .recv(&mut msg_buffer)
- .map_err(Error::NetlinkRecvError)?
- > 0
- {
- if let mnl::CbResult::Stop =
- mnl::cb_run(&msg_buffer, seq, portid).map_err(Error::ProcessNetlinkError)?
- {
- break;
- }
+/// Returns a buffer containing a netlink message which requests a list of all the netfilter
+/// matching objects (e.g. tables, chains, rules, ...).
+/// Supply the type of objects to retrieve (e.g. libc::NFT_MSG_GETTABLE), and a search filter.
+pub fn get_list_of_objects<T: NfNetlinkAttribute>(
+ msg_type: u16,
+ seq: u32,
+ filter: Option<&T>,
+) -> Result<Vec<u8>, QueryError> {
+ let mut buffer = Vec::new();
+ let mut writer = NfNetlinkWriter::new(&mut buffer);
+ writer.write_header(
+ msg_type,
+ ProtocolFamily::Unspec,
+ NLM_F_DUMP as u16,
+ seq,
+ None,
+ );
+ if let Some(filter) = filter {
+ let buf = writer.add_data_zeroed(filter.get_size());
+ unsafe {
+ filter.write_payload(buf.as_mut_ptr());
}
- Ok(())
}
+ writer.finalize_writing_object();
+ Ok(buffer)
}
-#[cfg(feature = "query")]
-pub use inner::*;
+/// Lists objects of a certain type (e.g. libc::NFT_MSG_GETTABLE) with the help of a helper
+/// function called by mnl::cb_run2.
+/// The callback expects a tuple of additional data (supplied as an argument to this function)
+/// and of the output vector, to which it should append the parsed object it received.
+pub fn list_objects_with_data<'a, Object, Accumulator>(
+ data_type: u16,
+ cb: &dyn Fn(Object, &mut Accumulator) -> Result<(), QueryError>,
+ filter: Option<&Object>,
+ working_data: &'a mut Accumulator,
+) -> Result<(), QueryError>
+where
+ Object: NfNetlinkObject + NfNetlinkAttribute,
+{
+ debug!("Listing objects of kind {}", data_type);
+ let sock = socket::socket(
+ AddressFamily::Netlink,
+ SockType::Raw,
+ SockFlag::empty(),
+ SockProtocol::NetlinkNetFilter,
+ )
+ .map_err(QueryError::NetlinkOpenError)?;
+
+ let seq = 0;
+
+ let chains_buf = get_list_of_objects(data_type, seq, filter)?;
+ socket::send(sock, &chains_buf, MsgFlags::empty()).map_err(QueryError::NetlinkSendError)?;
+
+ socket_close_wrapper(sock, move |sock| {
+ // the kernel should return NLM_F_MULTI objects
+ recv_and_process(
+ sock,
+ None,
+ Some(&|buf: &[u8], working_data: &mut Accumulator| {
+ debug!("Calling Object::deserialize()");
+ cb(Object::deserialize(buf)?.0, working_data)
+ }),
+ working_data,
+ )
+ })
+}
diff --git a/src/rule.rs b/src/rule.rs
index 2ee5308..858b9ce 100644
--- a/src/rule.rs
+++ b/src/rule.rs
@@ -1,341 +1,111 @@
-use crate::expr::ExpressionWrapper;
-use crate::{chain::Chain, expr::Expression, MsgType};
-use crate::sys::{self, libc};
-use std::ffi::{c_void, CStr, CString};
use std::fmt::Debug;
-use std::os::raw::c_char;
-use std::rc::Rc;
+
+use rustables_macros::nfnetlink_struct;
+
+use crate::chain::Chain;
+use crate::error::{BuilderError, QueryError};
+use crate::expr::{ExpressionList, RawExpression};
+use crate::nlmsg::NfNetlinkObject;
+use crate::query::list_objects_with_data;
+use crate::sys::{
+ NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, NFTA_RULE_HANDLE, NFTA_RULE_ID, NFTA_RULE_POSITION,
+ NFTA_RULE_TABLE, NFTA_RULE_USERDATA, NFT_MSG_DELRULE, NFT_MSG_NEWRULE, NLM_F_APPEND,
+ NLM_F_CREATE,
+};
+use crate::{Batch, ProtocolFamily};
/// A nftables firewall rule.
+#[derive(Clone, PartialEq, Eq, Default, Debug)]
+#[nfnetlink_struct(derive_deserialize = false)]
pub struct Rule {
- pub(crate) rule: *mut sys::nftnl_rule,
- pub(crate) chain: Rc<Chain>,
+ family: ProtocolFamily,
+ #[field(NFTA_RULE_TABLE)]
+ table: String,
+ #[field(NFTA_RULE_CHAIN)]
+ chain: String,
+ #[field(NFTA_RULE_HANDLE)]
+ handle: u64,
+ #[field(NFTA_RULE_EXPRESSIONS)]
+ expressions: ExpressionList,
+ #[field(NFTA_RULE_POSITION)]
+ position: u64,
+ #[field(NFTA_RULE_USERDATA)]
+ userdata: Vec<u8>,
+ #[field(NFTA_RULE_ID)]
+ id: u32,
}
impl Rule {
/// Creates a new rule object in the given [`Chain`].
///
/// [`Chain`]: struct.Chain.html
- pub fn new(chain: Rc<Chain>) -> Rule {
- unsafe {
- let rule = try_alloc!(sys::nftnl_rule_alloc());
- sys::nftnl_rule_set_u32(
- rule,
- sys::NFTNL_RULE_FAMILY as u16,
- chain.get_table().get_family() as u32,
- );
- sys::nftnl_rule_set_str(
- rule,
- sys::NFTNL_RULE_TABLE as u16,
- chain.get_table().get_name().as_ptr(),
- );
- sys::nftnl_rule_set_str(
- rule,
- sys::NFTNL_RULE_CHAIN as u16,
- chain.get_name().as_ptr(),
- );
-
- Rule { rule, chain }
- }
- }
-
- pub unsafe fn from_raw(rule: *mut sys::nftnl_rule, chain: Rc<Chain>) -> Self {
- Rule { rule, chain }
- }
-
- pub fn get_position(&self) -> u64 {
- unsafe { sys::nftnl_rule_get_u64(self.rule, sys::NFTNL_RULE_POSITION as u16) }
- }
-
- /// Sets the position of this rule within the chain it lives in. By default a new rule is added
- /// to the end of the chain.
- pub fn set_position(&mut self, position: u64) {
- unsafe {
- sys::nftnl_rule_set_u64(self.rule, sys::NFTNL_RULE_POSITION as u16, position);
- }
- }
-
- pub fn get_handle(&self) -> u64 {
- unsafe { sys::nftnl_rule_get_u64(self.rule, sys::NFTNL_RULE_HANDLE as u16) }
- }
-
- pub fn set_handle(&mut self, handle: u64) {
- unsafe {
- sys::nftnl_rule_set_u64(self.rule, sys::NFTNL_RULE_HANDLE as u16, handle);
- }
- }
-
- /// Adds an expression to this rule. Expressions are evaluated from first to last added.
- /// As soon as an expression does not match the packet it's being evaluated for, evaluation
- /// stops and the packet is evaluated against the next rule in the chain.
- pub fn add_expr(&mut self, expr: &impl Expression) {
- unsafe { sys::nftnl_rule_add_expr(self.rule, expr.to_expr(self)) }
- }
-
- /// Returns a reference to the [`Chain`] this rule lives in.
- ///
- /// [`Chain`]: struct.Chain.html
- pub fn get_chain(&self) -> Rc<Chain> {
- self.chain.clone()
- }
-
- /// Returns the userdata of this chain.
- pub fn get_userdata(&self) -> Option<&CStr> {
- unsafe {
- let ptr = sys::nftnl_rule_get_str(self.rule, sys::NFTNL_RULE_USERDATA as u16);
- if !ptr.is_null() {
- Some(CStr::from_ptr(ptr))
- } else {
- None
- }
- }
- }
-
- /// Updates the userdata of this chain.
- pub fn set_userdata(&self, data: &CStr) {
- unsafe {
- sys::nftnl_rule_set_str(self.rule, sys::NFTNL_RULE_USERDATA as u16, data.as_ptr());
- }
- }
-
- /// Returns a textual description of the rule.
- pub fn get_str(&self) -> CString {
- let mut descr_buf = vec![0i8; 4096];
- unsafe {
- sys::nftnl_rule_snprintf(
- descr_buf.as_mut_ptr() as *mut c_char,
- (descr_buf.len() - 1) as u64,
- self.rule,
- sys::NFTNL_OUTPUT_DEFAULT,
- 0,
- );
- CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned()
- }
- }
-
- /// Retrieves an iterator to loop over the expressions of the rule.
- pub fn get_exprs(self: &Rc<Self>) -> RuleExprsIter {
- RuleExprsIter::new(self.clone())
- }
-
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns the raw handle.
- pub fn as_ptr(&self) -> *const sys::nftnl_rule {
- self.rule as *const sys::nftnl_rule
- }
-
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns a mutable version of the raw handle.
- pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_rule {
- self.rule
- }
-
- /// Performs a deep comparizon of rules, by checking they have the same expressions inside.
- /// This is not enabled by default in our PartialEq implementation because of the difficulty to
- /// compare an expression generated by the library with the expressions returned by the kernel
- /// when iterating over the currently in-use rules. The kernel-returned expressions may have
- /// additional attributes despite being generated from the same rule. This is particularly true
- /// for the 'nat' expression).
- pub fn deep_eq(&self, other: &Self) -> bool {
- if self != other {
- return false;
- }
-
- let self_exprs =
- try_alloc!(unsafe { sys::nftnl_expr_iter_create(self.rule as *const sys::nftnl_rule) });
- let other_exprs = try_alloc!(unsafe {
- sys::nftnl_expr_iter_create(other.rule as *const sys::nftnl_rule)
- });
-
- loop {
- let self_next = unsafe { sys::nftnl_expr_iter_next(self_exprs) };
- let other_next = unsafe { sys::nftnl_expr_iter_next(other_exprs) };
- if self_next.is_null() && other_next.is_null() {
- return true;
- } else if self_next.is_null() || other_next.is_null() {
- return false;
- }
-
- // we are falling back on comparing the strings, because there is no easy mechanism to
- // perform a memcmp() between the two expressions :/
- let mut self_str = [0; 256];
- let mut other_str = [0; 256];
- unsafe {
- sys::nftnl_expr_snprintf(
- self_str.as_mut_ptr(),
- (self_str.len() - 1) as u64,
- self_next,
- sys::NFTNL_OUTPUT_DEFAULT,
- 0,
- );
- sys::nftnl_expr_snprintf(
- other_str.as_mut_ptr(),
- (other_str.len() - 1) as u64,
- other_next,
- sys::NFTNL_OUTPUT_DEFAULT,
- 0,
- );
+ pub fn new(chain: &Chain) -> Result<Rule, BuilderError> {
+ Ok(Rule::default()
+ .with_family(chain.get_family())
+ .with_table(
+ chain
+ .get_table()
+ .ok_or(BuilderError::MissingChainInformationError)?,
+ )
+ .with_chain(
+ chain
+ .get_name()
+ .ok_or(BuilderError::MissingChainInformationError)?,
+ ))
+ }
+
+ pub fn add_expr(&mut self, e: impl Into<RawExpression>) {
+ let exprs = match self.get_mut_expressions() {
+ Some(x) => x,
+ None => {
+ self.set_expressions(ExpressionList::default());
+ self.get_mut_expressions().unwrap()
}
-
- if self_str != other_str {
- return false;
- }
- }
- }
-}
-
-impl Debug for Rule {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{:?}", self.get_str())
- }
-}
-
-impl PartialEq for Rule {
- fn eq(&self, other: &Self) -> bool {
- if self.get_chain() != other.get_chain() {
- return false;
- }
-
- unsafe {
- if sys::nftnl_rule_is_set(self.rule, sys::NFTNL_RULE_HANDLE as u16)
- && sys::nftnl_rule_is_set(other.rule, sys::NFTNL_RULE_HANDLE as u16)
- {
- if self.get_handle() != other.get_handle() {
- return false;
- }
- }
- if sys::nftnl_rule_is_set(self.rule, sys::NFTNL_RULE_POSITION as u16)
- && sys::nftnl_rule_is_set(other.rule, sys::NFTNL_RULE_POSITION as u16)
- {
- if self.get_position() != other.get_position() {
- return false;
- }
- }
- }
-
- return false;
- }
-}
-
-unsafe impl crate::NlMsg for Rule {
- unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) {
- let type_ = match msg_type {
- MsgType::Add => libc::NFT_MSG_NEWRULE,
- MsgType::Del => libc::NFT_MSG_DELRULE,
};
- let flags: u16 = match msg_type {
- MsgType::Add => (libc::NLM_F_CREATE | libc::NLM_F_APPEND | libc::NLM_F_EXCL) as u16,
- MsgType::Del => 0u16,
- } | libc::NLM_F_ACK as u16;
- let header = sys::nftnl_nlmsg_build_hdr(
- buf as *mut c_char,
- type_ as u16,
- self.chain.get_table().get_family() as u16,
- flags,
- seq,
- );
- sys::nftnl_rule_nlmsg_build_payload(header, self.rule);
+ exprs.add_value(e);
}
-}
-impl Drop for Rule {
- fn drop(&mut self) {
- unsafe { sys::nftnl_rule_free(self.rule) };
+ pub fn with_expr(mut self, e: impl Into<RawExpression>) -> Self {
+ self.add_expr(e);
+ self
}
-}
-
-pub struct RuleExprsIter {
- rule: Rc<Rule>,
- iter: *mut sys::nftnl_expr_iter,
-}
-impl RuleExprsIter {
- fn new(rule: Rc<Rule>) -> Self {
- let iter =
- try_alloc!(unsafe { sys::nftnl_expr_iter_create(rule.rule as *const sys::nftnl_rule) });
- RuleExprsIter { rule, iter }
+ /// Appends this rule to `batch`
+ pub fn add_to_batch(self, batch: &mut Batch) -> Self {
+ batch.add(&self, crate::MsgType::Add);
+ self
}
}
-impl Iterator for RuleExprsIter {
- type Item = ExpressionWrapper;
+impl NfNetlinkObject for Rule {
+ const MSG_TYPE_ADD: u32 = NFT_MSG_NEWRULE;
+ const MSG_TYPE_DEL: u32 = NFT_MSG_DELRULE;
- fn next(&mut self) -> Option<Self::Item> {
- let next = unsafe { sys::nftnl_expr_iter_next(self.iter) };
- if next.is_null() {
- trace!("RulesExprsIter iterator ending");
- None
- } else {
- trace!("RulesExprsIter returning new expression");
- Some(ExpressionWrapper {
- expr: next,
- rule: self.rule.clone(),
- })
- }
+ fn get_family(&self) -> ProtocolFamily {
+ self.family
}
-}
-impl Drop for RuleExprsIter {
- fn drop(&mut self) {
- unsafe { sys::nftnl_expr_iter_destroy(self.iter) };
+ fn set_family(&mut self, family: ProtocolFamily) {
+ self.family = family;
}
-}
-
-#[cfg(feature = "query")]
-pub fn get_rules_cb(
- header: &libc::nlmsghdr,
- (chain, rules): &mut (&Rc<Chain>, &mut Vec<Rule>),
-) -> libc::c_int {
- unsafe {
- let rule = sys::nftnl_rule_alloc();
- if rule == std::ptr::null_mut() {
- return mnl::mnl_sys::MNL_CB_ERROR;
- }
- let err = sys::nftnl_rule_nlmsg_parse(header, rule);
- if err < 0 {
- error!("Failed to parse nelink rule message - {}", err);
- sys::nftnl_rule_free(rule);
- return err;
- }
- rules.push(Rule::from_raw(rule, chain.clone()));
+ // append at the end of the chain, instead of the beginning
+ fn get_add_flags(&self) -> u32 {
+ NLM_F_CREATE | NLM_F_APPEND
}
- mnl::mnl_sys::MNL_CB_OK
}
-#[cfg(feature = "query")]
-pub fn list_rules_for_chain(chain: &Rc<Chain>) -> Result<Vec<Rule>, crate::query::Error> {
- crate::query::list_objects_with_data(
+pub fn list_rules_for_chain(chain: &Chain) -> Result<Vec<Rule>, QueryError> {
+ let mut result = Vec::new();
+ list_objects_with_data(
libc::NFT_MSG_GETRULE as u16,
- get_rules_cb,
- &chain,
- // only retrieve rules from the currently targetted chain
- Some(&|hdr| unsafe {
- let rule = sys::nftnl_rule_alloc();
- if rule as *const _ == std::ptr::null() {
- return Err(crate::query::Error::NetlinkAllocationFailed);
- }
-
- sys::nftnl_rule_set_str(
- rule,
- sys::NFTNL_RULE_TABLE as u16,
- chain.get_table().get_name().as_ptr(),
- );
- sys::nftnl_rule_set_u32(
- rule,
- sys::NFTNL_RULE_FAMILY as u16,
- chain.get_table().get_family() as u32,
- );
- sys::nftnl_rule_set_str(
- rule,
- sys::NFTNL_RULE_CHAIN as u16,
- chain.get_name().as_ptr(),
- );
-
- sys::nftnl_rule_nlmsg_build_payload(hdr, rule);
-
- sys::nftnl_rule_free(rule);
+ &|rule: Rule, rules: &mut Vec<Rule>| {
+ rules.push(rule);
Ok(())
- }),
- )
+ },
+ // only retrieve rules from the currently targetted chain
+ Some(&Rule::new(chain)?),
+ &mut result,
+ )?;
+ Ok(result)
}
diff --git a/src/rule_methods.rs b/src/rule_methods.rs
index d7145d7..dff9bf6 100644
--- a/src/rule_methods.rs
+++ b/src/rule_methods.rs
@@ -1,230 +1,211 @@
-use crate::{Batch, Rule, nft_expr, sys::libc};
-use crate::expr::{LogGroup, LogPrefix};
-use ipnetwork::IpNetwork;
-use std::ffi::{CString, NulError};
+use std::ffi::CString;
use std::net::IpAddr;
-use std::num::ParseIntError;
-
-#[derive(thiserror::Error, Debug)]
-pub enum Error {
- #[error("Unable to open netlink socket to netfilter")]
- NetlinkOpenError(#[source] std::io::Error),
- #[error("Firewall is already started")]
- AlreadyDone,
- #[error("Error converting from a C string to a string")]
- CStringError(#[from] NulError),
- #[error("no interface found under that name")]
- NoSuchIface,
- #[error("Error converting from a string to an integer")]
- ParseError(#[from] ParseIntError),
- #[error("the interface name is too long")]
- NameTooLong,
-}
+use ipnetwork::IpNetwork;
+use crate::data_type::ip_to_vec;
+use crate::error::BuilderError;
+use crate::expr::ct::{ConnTrackState, Conntrack, ConntrackKey};
+use crate::expr::{
+ Bitwise, Cmp, CmpOp, HighLevelPayload, IPv4HeaderField, IPv6HeaderField, Immediate, Meta,
+ MetaType, NetworkHeaderField, TCPHeaderField, TransportHeaderField, UDPHeaderField,
+ VerdictKind,
+};
+use crate::Rule;
/// Simple protocol description. Note that it does not implement other layer 4 protocols as
/// IGMP et al. See [`Rule::igmp`] for a workaround.
-#[derive(Debug, Clone)]
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Protocol {
TCP,
- UDP
+ UDP,
}
-/// A RuleMethods trait over [`crate::Rule`], to make it match some criteria, and give it a
-/// verdict. Mostly adapted from [talpid-core's firewall].
-/// All methods return the rule itself, allowing them to be chained. Usage example :
-/// ```rust
-/// use rustables::{Batch, Chain, ChainMethods, Protocol, ProtoFamily, Rule, RuleMethods, Table, MsgType, Hook};
-/// use std::ffi::CString;
-/// use std::rc::Rc;
-/// let table = Rc::new(Table::new(&CString::new("main_table").unwrap(), ProtoFamily::Inet));
-/// let mut batch = Batch::new();
-/// batch.add(&table, MsgType::Add);
-/// let inbound = Rc::new(Chain::from_hook(Hook::In, Rc::clone(&table))
-/// .add_to_batch(&mut batch));
-/// let rule = Rule::new(inbound)
-/// .dport("80", &Protocol::TCP).unwrap()
-/// .accept()
-/// .add_to_batch(&mut batch);
-/// ```
-/// [talpid-core's firewall]:
-/// https://github.com/mullvad/mullvadvpn-app/blob/d92376b4d1df9b547930c68aa9bae9640ff2a022/talpid-core/src/firewall/linux.rs
-pub trait RuleMethods {
- /// Matches ICMP packets.
- fn icmp(self) -> Self;
- /// Matches IGMP packets.
- fn igmp(self) -> Self;
- /// Matches packets to destination `port` and `protocol`.
- fn dport(self, port: &str, protocol: &Protocol) -> Result<Self, Error>
- where Self: std::marker::Sized;
- /// Matches packets on `protocol`.
- fn protocol(self, protocol: Protocol) -> Result<Self, Error>
- where Self: std::marker::Sized;
- /// Matches packets in an already established connection.
- fn established(self) -> Self where Self: std::marker::Sized;
- /// Matches packets going through `iface_index`. Interface indexes can be queried with
- /// `iface_index()`.
- fn iface_id(self, iface_index: libc::c_uint) -> Result<Self, Error>
- where Self: std::marker::Sized;
- /// Matches packets going through `iface_name`, an interface name, as in "wlan0" or "lo".
- fn iface(self, iface_name: &str) -> Result<Self, Error>
- where Self: std::marker::Sized;
- /// Adds a log instruction to the rule. `group` is the NFLog group, `prefix` is a prefix
- /// appended to each log line.
- fn log(self, group: Option<LogGroup>, prefix: Option<LogPrefix>) -> Self;
- /// Matches packets whose source IP address is `saddr`.
- fn saddr(self, ip: IpAddr) -> Self;
- /// Matches packets whose source network is `snet`.
- fn snetwork(self, ip: IpNetwork) -> Self;
- /// Adds the `Accept` verdict to the rule. The packet will be sent to destination.
- fn accept(self) -> Self;
- /// Adds the `Drop` verdict to the rule. The packet will be dropped.
- fn drop(self) -> Self;
- /// Appends this rule to `batch`.
- fn add_to_batch(self, batch: &mut Batch) -> Self;
-}
-
-/// A trait to add helper functions to match some criterium over `crate::Rule`.
-impl RuleMethods for Rule {
- fn icmp(mut self) -> Self {
- self.add_expr(&nft_expr!(meta l4proto));
- //self.add_expr(&nft_expr!(cmp == libc::IPPROTO_ICMPV6 as u8));
- self.add_expr(&nft_expr!(cmp == libc::IPPROTO_ICMP as u8));
- self
- }
- fn igmp(mut self) -> Self {
- self.add_expr(&nft_expr!(meta l4proto));
- self.add_expr(&nft_expr!(cmp == libc::IPPROTO_IGMP as u8));
+impl Rule {
+ fn match_port(mut self, port: u16, protocol: Protocol, source: bool) -> Self {
+ self = self.protocol(protocol);
+ self.add_expr(
+ HighLevelPayload::Transport(match protocol {
+ Protocol::TCP => TransportHeaderField::Tcp(if source {
+ TCPHeaderField::Sport
+ } else {
+ TCPHeaderField::Dport
+ }),
+ Protocol::UDP => TransportHeaderField::Udp(if source {
+ UDPHeaderField::Sport
+ } else {
+ UDPHeaderField::Dport
+ }),
+ })
+ .build(),
+ );
+ self.add_expr(Cmp::new(CmpOp::Eq, port.to_be_bytes()));
self
}
- fn dport(mut self, port: &str, protocol: &Protocol) -> Result<Self, Error> {
- self.add_expr(&nft_expr!(meta l4proto));
- match protocol {
- &Protocol::TCP => {
- self.add_expr(&nft_expr!(cmp == libc::IPPROTO_TCP as u8));
- self.add_expr(&nft_expr!(payload tcp dport));
- },
- &Protocol::UDP => {
- self.add_expr(&nft_expr!(cmp == libc::IPPROTO_UDP as u8));
- self.add_expr(&nft_expr!(payload udp dport));
- }
- }
- // Convert the port to Big-Endian number spelling.
- // See https://github.com/mullvad/mullvadvpn-app/blob/d92376b4d1df9b547930c68aa9bae9640ff2a022/talpid-core/src/firewall/linux.rs#L969
- self.add_expr(&nft_expr!(cmp == port.parse::<u16>()?.to_be()));
- Ok(self)
- }
- fn protocol(mut self, protocol: Protocol) -> Result<Self, Error> {
- self.add_expr(&nft_expr!(meta l4proto));
- match protocol {
- Protocol::TCP => {
- self.add_expr(&nft_expr!(cmp == libc::IPPROTO_TCP as u8));
- },
- Protocol::UDP => {
- self.add_expr(&nft_expr!(cmp == libc::IPPROTO_UDP as u8));
- }
- }
- Ok(self)
- }
- fn established(mut self) -> Self {
- let allowed_states = crate::expr::ct::States::ESTABLISHED.bits();
- self.add_expr(&nft_expr!(ct state));
- self.add_expr(&nft_expr!(bitwise mask allowed_states, xor 0u32));
- self.add_expr(&nft_expr!(cmp != 0u32));
- self
- }
- fn iface_id(mut self, iface_index: libc::c_uint) -> Result<Self, Error> {
- self.add_expr(&nft_expr!(meta iif));
- self.add_expr(&nft_expr!(cmp == iface_index));
- Ok(self)
- }
- fn iface(mut self, iface_name: &str) -> Result<Self, Error> {
- if iface_name.len() >= libc::IFNAMSIZ {
- return Err(Error::NameTooLong);
- }
- let mut name_arr = [0u8; libc::IFNAMSIZ];
- for (pos, i) in iface_name.bytes().enumerate() {
- name_arr[pos] = i;
- }
- self.add_expr(&nft_expr!(meta iifname));
- self.add_expr(&nft_expr!(cmp == name_arr.as_ref()));
- Ok(self)
- }
- fn saddr(mut self, ip: IpAddr) -> Self {
- self.add_expr(&nft_expr!(meta nfproto));
+ pub fn match_ip(mut self, ip: IpAddr, source: bool) -> Self {
+ self.add_expr(Meta::new(MetaType::NfProto));
match ip {
IpAddr::V4(addr) => {
- self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV4 as u8));
- self.add_expr(&nft_expr!(payload ipv4 saddr));
- self.add_expr(&nft_expr!(cmp == addr))
- },
+ self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV4 as u8]));
+ self.add_expr(
+ HighLevelPayload::Network(NetworkHeaderField::IPv4(if source {
+ IPv4HeaderField::Saddr
+ } else {
+ IPv4HeaderField::Daddr
+ }))
+ .build(),
+ );
+ self.add_expr(Cmp::new(CmpOp::Eq, addr.octets()));
+ }
IpAddr::V6(addr) => {
- self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV6 as u8));
- self.add_expr(&nft_expr!(payload ipv6 saddr));
- self.add_expr(&nft_expr!(cmp == addr))
+ self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8]));
+ self.add_expr(
+ HighLevelPayload::Network(NetworkHeaderField::IPv6(if source {
+ IPv6HeaderField::Saddr
+ } else {
+ IPv6HeaderField::Daddr
+ }))
+ .build(),
+ );
+ self.add_expr(Cmp::new(CmpOp::Eq, addr.octets()));
}
}
self
}
- fn snetwork(mut self, net: IpNetwork) -> Self {
- self.add_expr(&nft_expr!(meta nfproto));
+
+ pub fn match_network(mut self, net: IpNetwork, source: bool) -> Result<Self, BuilderError> {
+ self.add_expr(Meta::new(MetaType::NfProto));
match net {
IpNetwork::V4(_) => {
- self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV4 as u8));
- self.add_expr(&nft_expr!(payload ipv4 saddr));
- self.add_expr(&nft_expr!(bitwise mask net.mask(), xor 0u32));
- self.add_expr(&nft_expr!(cmp == net.network()));
- },
+ self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV4 as u8]));
+ self.add_expr(
+ HighLevelPayload::Network(NetworkHeaderField::IPv4(if source {
+ IPv4HeaderField::Saddr
+ } else {
+ IPv4HeaderField::Daddr
+ }))
+ .build(),
+ );
+ self.add_expr(Bitwise::new(ip_to_vec(net.mask()), 0u32.to_be_bytes())?);
+ }
IpNetwork::V6(_) => {
- self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV6 as u8));
- self.add_expr(&nft_expr!(payload ipv6 saddr));
- self.add_expr(&nft_expr!(bitwise mask net.mask(), xor &[0u16; 8][..]));
- self.add_expr(&nft_expr!(cmp == net.network()));
+ self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8]));
+ self.add_expr(
+ HighLevelPayload::Network(NetworkHeaderField::IPv6(if source {
+ IPv6HeaderField::Saddr
+ } else {
+ IPv6HeaderField::Daddr
+ }))
+ .build(),
+ );
+ self.add_expr(Bitwise::new(ip_to_vec(net.mask()), 0u128.to_be_bytes())?);
}
}
+ self.add_expr(Cmp::new(CmpOp::Eq, ip_to_vec(net.network())));
+ Ok(self)
+ }
+}
+
+impl Rule {
+ /// Matches ICMP packets.
+ pub fn icmp(mut self) -> Self {
+ // quid of icmpv6?
+ self.add_expr(Meta::new(MetaType::L4Proto));
+ self.add_expr(Cmp::new(CmpOp::Eq, [libc::IPPROTO_ICMP as u8]));
self
}
- fn log(mut self, group: Option<LogGroup>, prefix: Option<LogPrefix>) -> Self {
- match (group.is_some(), prefix.is_some()) {
- (true, true) => {
- self.add_expr(&nft_expr!(log group group prefix prefix));
- },
- (false, true) => {
- self.add_expr(&nft_expr!(log prefix prefix));
- },
- (true, false) => {
- self.add_expr(&nft_expr!(log group group));
- },
- (false, false) => {
- self.add_expr(&nft_expr!(log));
- }
- }
+ /// Matches IGMP packets.
+ pub fn igmp(mut self) -> Self {
+ self.add_expr(Meta::new(MetaType::L4Proto));
+ self.add_expr(Cmp::new(CmpOp::Eq, [libc::IPPROTO_IGMP as u8]));
self
}
- fn accept(mut self) -> Self {
- self.add_expr(&nft_expr!(verdict accept));
+ /// Matches packets from source `port` and `protocol`.
+ pub fn sport(self, port: u16, protocol: Protocol) -> Self {
+ self.match_port(port, protocol, false)
+ }
+ /// Matches packets to destination `port` and `protocol`.
+ pub fn dport(self, port: u16, protocol: Protocol) -> Self {
+ self.match_port(port, protocol, false)
+ }
+ /// Matches packets on `protocol`.
+ pub fn protocol(mut self, protocol: Protocol) -> Self {
+ self.add_expr(Meta::new(MetaType::L4Proto));
+ self.add_expr(Cmp::new(
+ CmpOp::Eq,
+ [match protocol {
+ Protocol::TCP => libc::IPPROTO_TCP,
+ Protocol::UDP => libc::IPPROTO_UDP,
+ } as u8],
+ ));
+ self
+ }
+ /// Matches packets in an already established connection.
+ pub fn established(mut self) -> Result<Self, BuilderError> {
+ let allowed_states = ConnTrackState::ESTABLISHED.bits();
+ self.add_expr(Conntrack::new(ConntrackKey::State));
+ self.add_expr(Bitwise::new(
+ allowed_states.to_le_bytes(),
+ 0u32.to_be_bytes(),
+ )?);
+ self.add_expr(Cmp::new(CmpOp::Neq, 0u32.to_be_bytes()));
+ Ok(self)
+ }
+ /// Matches packets going through `iface_index`. Interface indexes can be queried with
+ /// `iface_index()`.
+ pub fn iface_id(mut self, iface_index: libc::c_uint) -> Self {
+ self.add_expr(Meta::new(MetaType::Iif));
+ self.add_expr(Cmp::new(CmpOp::Eq, iface_index.to_be_bytes()));
self
}
- fn drop(mut self) -> Self {
- self.add_expr(&nft_expr!(verdict drop));
+ /// Matches packets going through `iface_name`, an interface name, as in "wlan0" or "lo"
+ pub fn iface(mut self, iface_name: &str) -> Result<Self, BuilderError> {
+ if iface_name.len() >= libc::IFNAMSIZ {
+ return Err(BuilderError::InterfaceNameTooLong);
+ }
+ let mut iface_vec = iface_name.as_bytes().to_vec();
+ // null terminator
+ iface_vec.push(0u8);
+
+ self.add_expr(Meta::new(MetaType::IifName));
+ self.add_expr(Cmp::new(CmpOp::Eq, iface_vec));
+ Ok(self)
+ }
+ /// Matches packets whose source IP address is `saddr`.
+ pub fn saddr(self, ip: IpAddr) -> Self {
+ self.match_ip(ip, true)
+ }
+ /// Matches packets whose destination IP address is `saddr`.
+ pub fn daddr(self, ip: IpAddr) -> Self {
+ self.match_ip(ip, false)
+ }
+ /// Matches packets whose source network is `net`.
+ pub fn snetwork(self, net: IpNetwork) -> Result<Self, BuilderError> {
+ self.match_network(net, true)
+ }
+ /// Matches packets whose destination network is `net`.
+ pub fn dnetwork(self, net: IpNetwork) -> Result<Self, BuilderError> {
+ self.match_network(net, false)
+ }
+ /// Adds the `Accept` verdict to the rule. The packet will be sent to destination.
+ pub fn accept(mut self) -> Self {
+ self.add_expr(Immediate::new_verdict(VerdictKind::Accept));
self
}
- fn add_to_batch(self, batch: &mut Batch) -> Self {
- batch.add(&self, crate::MsgType::Add);
+ /// Adds the `Drop` verdict to the rule. The packet will be dropped.
+ pub fn drop(mut self) -> Self {
+ self.add_expr(Immediate::new_verdict(VerdictKind::Drop));
self
}
}
/// Looks up the interface index for a given interface name.
-pub fn iface_index(name: &str) -> Result<libc::c_uint, Error> {
+pub fn iface_index(name: &str) -> Result<libc::c_uint, std::io::Error> {
let c_name = CString::new(name)?;
let index = unsafe { libc::if_nametoindex(c_name.as_ptr()) };
match index {
- 0 => Err(Error::NoSuchIface),
- _ => Ok(index)
+ 0 => Err(std::io::Error::last_os_error()),
+ _ => Ok(index),
}
}
-
-
diff --git a/src/set.rs b/src/set.rs
index 90712c5..ab29770 100644
--- a/src/set.rs
+++ b/src/set.rs
@@ -1,273 +1,116 @@
-use crate::sys::{self, libc};
-use crate::{table::Table, MsgType, ProtoFamily};
-use std::{
- cell::Cell,
- ffi::{c_void, CStr, CString},
- fmt::Debug,
- net::{Ipv4Addr, Ipv6Addr},
- os::raw::c_char,
- rc::Rc,
+use rustables_macros::nfnetlink_struct;
+
+use crate::data_type::DataType;
+use crate::error::BuilderError;
+use crate::nlmsg::NfNetlinkObject;
+use crate::parser_impls::{NfNetlinkData, NfNetlinkList};
+use crate::sys::{
+ NFTA_SET_ELEM_KEY, NFTA_SET_ELEM_LIST_ELEMENTS, NFTA_SET_ELEM_LIST_SET,
+ NFTA_SET_ELEM_LIST_TABLE, NFTA_SET_FLAGS, NFTA_SET_ID, NFTA_SET_KEY_LEN, NFTA_SET_KEY_TYPE,
+ NFTA_SET_NAME, NFTA_SET_TABLE, NFTA_SET_USERDATA, NFT_MSG_DELSET, NFT_MSG_DELSETELEM,
+ NFT_MSG_NEWSET, NFT_MSG_NEWSETELEM,
};
-
-#[macro_export]
-macro_rules! nft_set {
- ($name:expr, $id:expr, $table:expr, $family:expr) => {
- $crate::set::Set::new($name, $id, $table, $family)
- };
- ($name:expr, $id:expr, $table:expr, $family:expr; [ ]) => {
- nft_set!($name, $id, $table, $family)
- };
- ($name:expr, $id:expr, $table:expr, $family:expr; [ $($value:expr,)* ]) => {{
- let mut set = nft_set!($name, $id, $table, $family).expect("Set allocation failed");
- $(
- set.add($value).expect(stringify!(Unable to add $value to set $name));
- )*
- set
- }};
-}
-
-pub struct Set<K> {
- pub(crate) set: *mut sys::nftnl_set,
- pub(crate) table: Rc<Table>,
- pub(crate) family: ProtoFamily,
- _marker: ::std::marker::PhantomData<K>,
+use crate::table::Table;
+use crate::ProtocolFamily;
+use std::fmt::Debug;
+use std::marker::PhantomData;
+
+#[derive(Default, Debug, Clone, PartialEq, Eq)]
+#[nfnetlink_struct(derive_deserialize = false)]
+pub struct Set {
+ pub family: ProtocolFamily,
+ #[field(NFTA_SET_TABLE)]
+ pub table: String,
+ #[field(NFTA_SET_NAME)]
+ pub name: String,
+ #[field(NFTA_SET_FLAGS)]
+ pub flags: u32,
+ #[field(NFTA_SET_KEY_TYPE)]
+ pub key_type: u32,
+ #[field(NFTA_SET_KEY_LEN)]
+ pub key_len: u32,
+ #[field(NFTA_SET_ID)]
+ pub id: u32,
+ #[field(NFTA_SET_USERDATA)]
+ pub userdata: String,
}
-impl<K> Set<K> {
- pub fn new(name: &CStr, id: u32, table: Rc<Table>, family: ProtoFamily) -> Self
- where
- K: SetKey,
- {
- unsafe {
- let set = try_alloc!(sys::nftnl_set_alloc());
-
- sys::nftnl_set_set_u32(set, sys::NFTNL_SET_FAMILY as u16, family as u32);
- sys::nftnl_set_set_str(set, sys::NFTNL_SET_TABLE as u16, table.get_name().as_ptr());
- sys::nftnl_set_set_str(set, sys::NFTNL_SET_NAME as u16, name.as_ptr());
- sys::nftnl_set_set_u32(set, sys::NFTNL_SET_ID as u16, id);
-
- sys::nftnl_set_set_u32(
- set,
- sys::NFTNL_SET_FLAGS as u16,
- (libc::NFT_SET_ANONYMOUS | libc::NFT_SET_CONSTANT) as u32,
- );
- sys::nftnl_set_set_u32(set, sys::NFTNL_SET_KEY_TYPE as u16, K::TYPE);
- sys::nftnl_set_set_u32(set, sys::NFTNL_SET_KEY_LEN as u16, K::LEN);
-
- Set {
- set,
- table,
- family,
- _marker: ::std::marker::PhantomData,
- }
- }
- }
-
- pub unsafe fn from_raw(set: *mut sys::nftnl_set, table: Rc<Table>, family: ProtoFamily) -> Self
- where
- K: SetKey,
- {
- Set {
- set,
- table,
- family,
- _marker: ::std::marker::PhantomData,
- }
- }
-
- pub fn add(&mut self, key: &K)
- where
- K: SetKey,
- {
- unsafe {
- let elem = try_alloc!(sys::nftnl_set_elem_alloc());
-
- let data = key.data();
- let data_len = data.len() as u32;
- trace!("Adding key {:?} with len {}", data, data_len);
- sys::nftnl_set_elem_set(
- elem,
- sys::NFTNL_SET_ELEM_KEY as u16,
- data.as_ref() as *const _ as *const c_void,
- data_len,
- );
- sys::nftnl_set_elem_add(self.set, elem);
- }
- }
-
- pub fn elems_iter(&self) -> SetElemsIter<K> {
- SetElemsIter::new(self)
- }
+impl NfNetlinkObject for Set {
+ const MSG_TYPE_ADD: u32 = NFT_MSG_NEWSET;
+ const MSG_TYPE_DEL: u32 = NFT_MSG_DELSET;
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns the raw handle.
- pub fn as_ptr(&self) -> *const sys::nftnl_set {
- self.set as *const sys::nftnl_set
- }
-
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns a mutable version of the raw handle.
- pub fn as_mut_ptr(&self) -> *mut sys::nftnl_set {
- self.set
- }
-
- pub fn get_family(&self) -> ProtoFamily {
+ fn get_family(&self) -> ProtocolFamily {
self.family
}
- /// Returns a textual description of the set.
- pub fn get_str(&self) -> CString {
- let mut descr_buf = vec![0i8; 4096];
- unsafe {
- sys::nftnl_set_snprintf(
- descr_buf.as_mut_ptr() as *mut c_char,
- (descr_buf.len() - 1) as u64,
- self.set,
- sys::NFTNL_OUTPUT_DEFAULT,
- 0,
- );
- CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned()
- }
- }
-
- pub fn get_name(&self) -> Option<&CStr> {
- unsafe {
- let ptr = sys::nftnl_set_get_str(self.set, sys::NFTNL_SET_NAME as u16);
- if !ptr.is_null() {
- Some(CStr::from_ptr(ptr))
- } else {
- None
- }
- }
- }
-
- pub fn get_id(&self) -> u32 {
- unsafe { sys::nftnl_set_get_u32(self.set, sys::NFTNL_SET_ID as u16) }
+ fn set_family(&mut self, family: ProtocolFamily) {
+ self.family = family;
}
}
-impl<K> Debug for Set<K> {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{:?}", self.get_str())
- }
-}
-
-unsafe impl<K> crate::NlMsg for Set<K> {
- unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) {
- let type_ = match msg_type {
- MsgType::Add => libc::NFT_MSG_NEWSET,
- MsgType::Del => libc::NFT_MSG_DELSET,
- };
- let header = sys::nftnl_nlmsg_build_hdr(
- buf as *mut c_char,
- type_ as u16,
- self.table.get_family() as u16,
- (libc::NLM_F_APPEND | libc::NLM_F_CREATE | libc::NLM_F_ACK) as u16,
- seq,
- );
- sys::nftnl_set_nlmsg_build_payload(header, self.set);
- }
+pub struct SetBuilder<K: DataType> {
+ inner: Set,
+ list: SetElementList,
+ _phantom: PhantomData<K>,
}
-impl<K> Drop for Set<K> {
- fn drop(&mut self) {
- unsafe { sys::nftnl_set_free(self.set) };
- }
-}
-
-pub struct SetElemsIter<'a, K> {
- set: &'a Set<K>,
- iter: *mut sys::nftnl_set_elems_iter,
- ret: Rc<Cell<i32>>,
-}
-
-impl<'a, K> SetElemsIter<'a, K> {
- fn new(set: &'a Set<K>) -> Self {
- let iter = try_alloc!(unsafe {
- sys::nftnl_set_elems_iter_create(set.set as *const sys::nftnl_set)
+impl<K: DataType> SetBuilder<K> {
+ pub fn new(name: impl Into<String>, table: &Table) -> Result<Self, BuilderError> {
+ let table_name = table.get_name().ok_or(BuilderError::MissingTableName)?;
+ let set_name = name.into();
+ let set = Set::default()
+ .with_key_type(K::TYPE)
+ .with_key_len(K::LEN)
+ .with_table(table_name)
+ .with_name(&set_name);
+
+ Ok(SetBuilder {
+ inner: set,
+ list: SetElementList {
+ table: Some(table_name.clone()),
+ set: Some(set_name),
+ elements: Some(SetElementListElements::default()),
+ },
+ _phantom: PhantomData,
+ })
+ }
+
+ pub fn add(&mut self, key: &K) {
+ self.list.elements.as_mut().unwrap().add_value(SetElement {
+ key: Some(NfNetlinkData::default().with_value(key.data())),
});
- SetElemsIter {
- set,
- iter,
- ret: Rc::new(Cell::new(1)),
- }
}
-}
-
-impl<'a, K> Iterator for SetElemsIter<'a, K> {
- type Item = SetElemsMsg<'a, K>;
- fn next(&mut self) -> Option<Self::Item> {
- if self.ret.get() <= 0 || unsafe { sys::nftnl_set_elems_iter_cur(self.iter).is_null() } {
- trace!("SetElemsIter iterator ending");
- None
- } else {
- trace!("SetElemsIter returning new SetElemsMsg");
- Some(SetElemsMsg {
- set: self.set,
- iter: self.iter,
- ret: self.ret.clone(),
- })
- }
+ pub fn finish(self) -> (Set, SetElementList) {
+ (self.inner, self.list)
}
}
-impl<'a, K> Drop for SetElemsIter<'a, K> {
- fn drop(&mut self) {
- unsafe { sys::nftnl_set_elems_iter_destroy(self.iter) };
- }
+#[derive(Default, Debug, Clone, PartialEq, Eq)]
+#[nfnetlink_struct(nested = true, derive_deserialize = false)]
+pub struct SetElementList {
+ #[field(NFTA_SET_ELEM_LIST_TABLE)]
+ pub table: String,
+ #[field(NFTA_SET_ELEM_LIST_SET)]
+ pub set: String,
+ #[field(NFTA_SET_ELEM_LIST_ELEMENTS)]
+ pub elements: SetElementListElements,
}
-pub struct SetElemsMsg<'a, K> {
- set: &'a Set<K>,
- iter: *mut sys::nftnl_set_elems_iter,
- ret: Rc<Cell<i32>>,
-}
+impl NfNetlinkObject for SetElementList {
+ const MSG_TYPE_ADD: u32 = NFT_MSG_NEWSETELEM;
+ const MSG_TYPE_DEL: u32 = NFT_MSG_DELSETELEM;
-unsafe impl<'a, K> crate::NlMsg for SetElemsMsg<'a, K> {
- unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) {
- trace!("Writing SetElemsMsg to NlMsg");
- let (type_, flags) = match msg_type {
- MsgType::Add => (
- libc::NFT_MSG_NEWSETELEM,
- libc::NLM_F_CREATE | libc::NLM_F_EXCL | libc::NLM_F_ACK,
- ),
- MsgType::Del => (libc::NFT_MSG_DELSETELEM, libc::NLM_F_ACK),
- };
- let header = sys::nftnl_nlmsg_build_hdr(
- buf as *mut c_char,
- type_ as u16,
- self.set.get_family() as u16,
- flags as u16,
- seq,
- );
- self.ret.set(sys::nftnl_set_elems_nlmsg_build_payload_iter(
- header, self.iter,
- ));
+ fn get_family(&self) -> ProtocolFamily {
+ ProtocolFamily::Unspec
}
}
-pub trait SetKey {
- const TYPE: u32;
- const LEN: u32;
-
- fn data(&self) -> Box<[u8]>;
-}
-
-impl SetKey for Ipv4Addr {
- const TYPE: u32 = 7;
- const LEN: u32 = 4;
-
- fn data(&self) -> Box<[u8]> {
- self.octets().to_vec().into_boxed_slice()
- }
+#[derive(Default, Debug, Clone, PartialEq, Eq)]
+#[nfnetlink_struct(nested = true)]
+pub struct SetElement {
+ #[field(NFTA_SET_ELEM_KEY)]
+ pub key: NfNetlinkData,
}
-impl SetKey for Ipv6Addr {
- const TYPE: u32 = 8;
- const LEN: u32 = 16;
-
- fn data(&self) -> Box<[u8]> {
- self.octets().to_vec().into_boxed_slice()
- }
-}
+type SetElementListElements = NfNetlinkList<SetElement>;
diff --git a/src/sys.rs b/src/sys.rs
new file mode 100644
index 0000000..4384a1c
--- /dev/null
+++ b/src/sys.rs
@@ -0,0 +1,3 @@
+#![allow(non_camel_case_types, dead_code)]
+
+include!(concat!(env!("OUT_DIR"), "/sys.rs"));
diff --git a/src/table.rs b/src/table.rs
index 593fffb..81a26ef 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -1,171 +1,68 @@
-use crate::{MsgType, ProtoFamily};
-use crate::sys::{self, libc};
-#[cfg(feature = "query")]
-use std::convert::TryFrom;
-use std::{
- ffi::{c_void, CStr, CString},
- fmt::Debug,
- os::raw::c_char,
+use std::fmt::Debug;
+
+use rustables_macros::nfnetlink_struct;
+
+use crate::error::QueryError;
+use crate::nlmsg::NfNetlinkObject;
+use crate::sys::{
+ NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFTA_TABLE_USERDATA, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE,
+ NFT_MSG_NEWTABLE,
};
+use crate::{Batch, ProtocolFamily};
-/// Abstraction of `nftnl_table`, the top level container in netfilter. A table has a protocol
+/// Abstraction of a `nftnl_table`, the top level container in netfilter. A table has a protocol
/// family and contains [`Chain`]s that in turn hold the rules.
///
/// [`Chain`]: struct.Chain.html
+#[derive(Default, PartialEq, Eq, Debug)]
+#[nfnetlink_struct(derive_deserialize = false)]
pub struct Table {
- table: *mut sys::nftnl_table,
- family: ProtoFamily,
+ family: ProtocolFamily,
+ #[field(NFTA_TABLE_NAME)]
+ name: String,
+ #[field(NFTA_TABLE_FLAGS)]
+ flags: u32,
+ #[field(NFTA_TABLE_USERDATA)]
+ userdata: Vec<u8>,
}
impl Table {
- /// Creates a new table instance with the given name and protocol family.
- pub fn new<T: AsRef<CStr>>(name: &T, family: ProtoFamily) -> Table {
- unsafe {
- let table = try_alloc!(sys::nftnl_table_alloc());
-
- sys::nftnl_table_set_u32(table, sys::NFTNL_TABLE_FAMILY as u16, family as u32);
- sys::nftnl_table_set_str(table, sys::NFTNL_TABLE_NAME as u16, name.as_ref().as_ptr());
- sys::nftnl_table_set_u32(table, sys::NFTNL_TABLE_FLAGS as u16, 0u32);
- Table { table, family }
- }
- }
-
- pub unsafe fn from_raw(table: *mut sys::nftnl_table, family: ProtoFamily) -> Self {
- Table { table, family }
- }
-
- /// Returns the name of this table.
- pub fn get_name(&self) -> &CStr {
- unsafe {
- let ptr = sys::nftnl_table_get_str(self.table, sys::NFTNL_TABLE_NAME as u16);
- if ptr.is_null() {
- panic!("Impossible situation: retrieving the name of a chain failed")
- } else {
- CStr::from_ptr(ptr)
- }
- }
- }
-
- /// Returns a textual description of the table.
- pub fn get_str(&self) -> CString {
- let mut descr_buf = vec![0i8; 4096];
- unsafe {
- sys::nftnl_table_snprintf(
- descr_buf.as_mut_ptr() as *mut c_char,
- (descr_buf.len() - 1) as u64,
- self.table,
- sys::NFTNL_OUTPUT_DEFAULT,
- 0,
- );
- CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned()
- }
- }
-
- /// Returns the protocol family for this table.
- pub fn get_family(&self) -> ProtoFamily {
- self.family
- }
-
- /// Returns the userdata of this chain.
- pub fn get_userdata(&self) -> Option<&CStr> {
- unsafe {
- let ptr = sys::nftnl_table_get_str(self.table, sys::NFTNL_TABLE_USERDATA as u16);
- if !ptr.is_null() {
- Some(CStr::from_ptr(ptr))
- } else {
- None
- }
- }
+ pub fn new(family: ProtocolFamily) -> Table {
+ let mut res = Self::default();
+ res.family = family;
+ res
}
- /// Updates the userdata of this chain.
- pub fn set_userdata(&self, data: &CStr) {
- unsafe {
- sys::nftnl_table_set_str(self.table, sys::NFTNL_TABLE_USERDATA as u16, data.as_ptr());
- }
- }
-
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns the raw handle.
- pub fn as_ptr(&self) -> *const sys::nftnl_table {
- self.table as *const sys::nftnl_table
- }
-
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns a mutable version of the raw handle.
- pub fn as_mut_ptr(&self) -> *mut sys::nftnl_table {
- self.table
- }
-}
-
-impl PartialEq for Table {
- fn eq(&self, other: &Self) -> bool {
- self.get_name() == other.get_name() && self.get_family() == other.get_family()
+ /// Appends this rule to `batch`
+ pub fn add_to_batch(self, batch: &mut Batch) -> Self {
+ batch.add(&self, crate::MsgType::Add);
+ self
}
}
-impl Debug for Table {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{:?}", self.get_str())
- }
-}
+impl NfNetlinkObject for Table {
+ const MSG_TYPE_ADD: u32 = NFT_MSG_NEWTABLE;
+ const MSG_TYPE_DEL: u32 = NFT_MSG_DELTABLE;
-unsafe impl crate::NlMsg for Table {
- unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) {
- let raw_msg_type = match msg_type {
- MsgType::Add => libc::NFT_MSG_NEWTABLE,
- MsgType::Del => libc::NFT_MSG_DELTABLE,
- };
- let header = sys::nftnl_nlmsg_build_hdr(
- buf as *mut c_char,
- raw_msg_type as u16,
- self.family as u16,
- libc::NLM_F_ACK as u16,
- seq,
- );
- sys::nftnl_table_nlmsg_build_payload(header, self.table);
- }
-}
-
-impl Drop for Table {
- fn drop(&mut self) {
- unsafe { sys::nftnl_table_free(self.table) };
+ fn get_family(&self) -> ProtocolFamily {
+ self.family
}
-}
-#[cfg(feature = "query")]
-/// A callback to parse the response for messages created with `get_tables_nlmsg`.
-pub fn get_tables_cb(
- header: &libc::nlmsghdr,
- (_, tables): &mut (&(), &mut Vec<Table>),
-) -> libc::c_int {
- unsafe {
- let table = sys::nftnl_table_alloc();
- if table == std::ptr::null_mut() {
- return mnl::mnl_sys::MNL_CB_ERROR;
- }
- let err = sys::nftnl_table_nlmsg_parse(header, table);
- if err < 0 {
- error!("Failed to parse nelink table message - {}", err);
- sys::nftnl_table_free(table);
- return err;
- }
- let family = sys::nftnl_table_get_u32(table, sys::NFTNL_TABLE_FAMILY as u16);
- match crate::ProtoFamily::try_from(family as i32) {
- Ok(family) => {
- tables.push(Table::from_raw(table, family));
- mnl::mnl_sys::MNL_CB_OK
- }
- Err(crate::InvalidProtocolFamily) => {
- error!("The netlink table didn't have a valid protocol family !?");
- sys::nftnl_table_free(table);
- mnl::mnl_sys::MNL_CB_ERROR
- }
- }
+ fn set_family(&mut self, family: ProtocolFamily) {
+ self.family = family;
}
}
-#[cfg(feature = "query")]
-pub fn list_tables() -> Result<Vec<Table>, crate::query::Error> {
- crate::query::list_objects_with_data(libc::NFT_MSG_GETTABLE as u16, get_tables_cb, &(), None)
+pub fn list_tables() -> Result<Vec<Table>, QueryError> {
+ let mut result = Vec::new();
+ crate::query::list_objects_with_data(
+ NFT_MSG_GETTABLE as u16,
+ &|table: Table, tables: &mut Vec<Table>| {
+ tables.push(table);
+ Ok(())
+ },
+ None,
+ &mut result,
+ )?;
+ Ok(result)
}
diff --git a/src/tests/batch.rs b/src/tests/batch.rs
new file mode 100644
index 0000000..12f373f
--- /dev/null
+++ b/src/tests/batch.rs
@@ -0,0 +1,96 @@
+use std::mem::size_of;
+
+use libc::{AF_UNSPEC, NFNL_MSG_BATCH_BEGIN, NLM_F_REQUEST};
+use nix::libc::NFNL_MSG_BATCH_END;
+
+use crate::nlmsg::{pad_netlink_object_with_variable_size, NfNetlinkDeserializable};
+use crate::parser::{parse_nlmsg, NlMsg};
+use crate::sys::{nfgenmsg, nlmsghdr, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES};
+use crate::{Batch, MsgType, Table};
+
+use super::get_test_table;
+
+const HEADER_SIZE: u32 =
+ pad_netlink_object_with_variable_size(size_of::<nlmsghdr>() + size_of::<nfgenmsg>()) as u32;
+
+const DEFAULT_BATCH_BEGIN_HDR: nlmsghdr = nlmsghdr {
+ nlmsg_len: HEADER_SIZE,
+ nlmsg_flags: NLM_F_REQUEST as u16,
+ nlmsg_type: NFNL_MSG_BATCH_BEGIN as u16,
+ nlmsg_seq: 0,
+ nlmsg_pid: 0,
+};
+const DEFAULT_BATCH_MSG: NlMsg = NlMsg::NfGenMsg(
+ nfgenmsg {
+ nfgen_family: AF_UNSPEC as u8,
+ version: NFNETLINK_V0 as u8,
+ res_id: NFNL_SUBSYS_NFTABLES as u16,
+ },
+ &[],
+);
+
+const DEFAULT_BATCH_END_HDR: nlmsghdr = nlmsghdr {
+ nlmsg_len: HEADER_SIZE,
+ nlmsg_flags: NLM_F_REQUEST as u16,
+ nlmsg_type: NFNL_MSG_BATCH_END as u16,
+ nlmsg_seq: 1,
+ nlmsg_pid: 0,
+};
+
+#[test]
+fn batch_empty() {
+ let batch = Batch::new();
+ let buf = batch.finalize();
+
+ let (hdr, msg) = parse_nlmsg(&buf).expect("Invalid nlmsg message");
+ assert_eq!(hdr, DEFAULT_BATCH_BEGIN_HDR);
+ assert_eq!(msg, DEFAULT_BATCH_MSG);
+
+ let remaining_data_offset = pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize);
+
+ let (hdr, msg) = parse_nlmsg(&buf[remaining_data_offset..]).expect("Invalid nlmsg message");
+ assert_eq!(hdr, DEFAULT_BATCH_END_HDR);
+ assert_eq!(msg, DEFAULT_BATCH_MSG);
+}
+
+#[test]
+fn batch_with_objects() {
+ let mut original_tables = vec![];
+ for i in 0..10 {
+ let mut table = get_test_table();
+ table.set_userdata(vec![i as u8]);
+ original_tables.push(table);
+ }
+
+ let mut batch = Batch::new();
+ for i in 0..10 {
+ batch.add(
+ &original_tables[i],
+ if i % 2 == 0 {
+ MsgType::Add
+ } else {
+ MsgType::Del
+ },
+ );
+ }
+ let buf = batch.finalize();
+
+ let (hdr, msg) = parse_nlmsg(&buf).expect("Invalid nlmsg message");
+ assert_eq!(hdr, DEFAULT_BATCH_BEGIN_HDR);
+ assert_eq!(msg, DEFAULT_BATCH_MSG);
+ let mut remaining_data = &buf[pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize)..];
+
+ for i in 0..10 {
+ let (deserialized_table, rest) =
+ Table::deserialize(&remaining_data).expect("could not deserialize a table");
+ remaining_data = rest;
+
+ assert_eq!(deserialized_table, original_tables[i]);
+ }
+
+ let (hdr, msg) = parse_nlmsg(&remaining_data).expect("Invalid nlmsg message");
+ let mut end_hdr = DEFAULT_BATCH_END_HDR;
+ end_hdr.nlmsg_seq = 11;
+ assert_eq!(hdr, end_hdr);
+ assert_eq!(msg, DEFAULT_BATCH_MSG);
+}
diff --git a/src/tests/chain.rs b/src/tests/chain.rs
new file mode 100644
index 0000000..7f696e6
--- /dev/null
+++ b/src/tests/chain.rs
@@ -0,0 +1,120 @@
+use crate::{
+ nlmsg::get_operation_from_nlmsghdr_type,
+ sys::{
+ NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_TABLE, NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA,
+ NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN,
+ },
+ ChainType, Hook, HookClass, MsgType,
+};
+
+use super::{
+ get_test_chain, get_test_nlmsg, get_test_nlmsg_with_msg_type, NetlinkExpr, CHAIN_NAME,
+ CHAIN_USERDATA, TABLE_NAME,
+};
+
+#[test]
+fn new_empty_chain() {
+ let mut chain = get_test_chain();
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut chain);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWCHAIN as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 52);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_CHAIN_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_CHAIN_NAME, CHAIN_NAME.as_bytes().to_vec()),
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn new_empty_chain_with_hook_and_type() {
+ let mut chain = get_test_chain()
+ .with_hook(Hook::new(HookClass::In, 0))
+ .with_type(ChainType::Filter);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut chain);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWCHAIN as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 84);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_CHAIN_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_CHAIN_NAME, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_CHAIN_TYPE, "filter".as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_CHAIN_HOOK,
+ vec![
+ NetlinkExpr::List(vec![NetlinkExpr::Final(
+ NFTA_HOOK_HOOKNUM,
+ vec![0, 0, 0, 1]
+ )]),
+ NetlinkExpr::List(vec![NetlinkExpr::Final(
+ NFTA_HOOK_PRIORITY,
+ vec![0, 0, 0, 0]
+ )])
+ ]
+ ),
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn new_empty_chain_with_userdata() {
+ let mut chain = get_test_chain();
+ chain.set_userdata(CHAIN_USERDATA);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut chain);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWCHAIN as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 72);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_CHAIN_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_CHAIN_NAME, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_CHAIN_USERDATA, CHAIN_USERDATA.as_bytes().to_vec())
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn delete_empty_chain() {
+ let mut chain = get_test_chain();
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) =
+ get_test_nlmsg_with_msg_type(&mut buf, &mut chain, MsgType::Del);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_DELCHAIN as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 52);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_CHAIN_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_CHAIN_NAME, CHAIN_NAME.as_bytes().to_vec()),
+ ])
+ .to_raw()
+ );
+}
diff --git a/src/tests/expr.rs b/src/tests/expr.rs
new file mode 100644
index 0000000..35c4fea
--- /dev/null
+++ b/src/tests/expr.rs
@@ -0,0 +1,591 @@
+use std::net::Ipv4Addr;
+
+use libc::NF_DROP;
+
+use crate::{
+ expr::{
+ Bitwise, Cmp, CmpOp, Conntrack, ConntrackKey, Counter, ExpressionList, HeaderField,
+ HighLevelPayload, IcmpCode, Immediate, Log, Lookup, Masquerade, Meta, MetaType, Nat,
+ NatType, Register, Reject, RejectType, TCPHeaderField, TransportHeaderField, VerdictKind,
+ },
+ set::SetBuilder,
+ sys::{
+ NFTA_BITWISE_DREG, NFTA_BITWISE_LEN, NFTA_BITWISE_MASK, NFTA_BITWISE_SREG,
+ NFTA_BITWISE_XOR, NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFTA_COUNTER_BYTES,
+ NFTA_COUNTER_PACKETS, NFTA_CT_DREG, NFTA_CT_KEY, NFTA_DATA_VALUE, NFTA_DATA_VERDICT,
+ NFTA_EXPR_DATA, NFTA_EXPR_NAME, NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG, NFTA_LIST_ELEM,
+ NFTA_LOG_GROUP, NFTA_LOG_PREFIX, NFTA_LOOKUP_SET, NFTA_LOOKUP_SREG, NFTA_META_DREG,
+ NFTA_META_KEY, NFTA_NAT_FAMILY, NFTA_NAT_REG_ADDR_MIN, NFTA_NAT_TYPE, NFTA_PAYLOAD_BASE,
+ NFTA_PAYLOAD_DREG, NFTA_PAYLOAD_LEN, NFTA_PAYLOAD_OFFSET, NFTA_REJECT_ICMP_CODE,
+ NFTA_REJECT_TYPE, NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, NFTA_RULE_TABLE,
+ NFTA_VERDICT_CODE, NFT_CMP_EQ, NFT_CT_STATE, NFT_META_PROTOCOL, NFT_NAT_SNAT,
+ NFT_PAYLOAD_TRANSPORT_HEADER, NFT_REG_1, NFT_REG_VERDICT, NFT_REJECT_ICMPX_UNREACH,
+ },
+ tests::{get_test_table, SET_NAME},
+ ProtocolFamily,
+};
+
+use super::{get_test_nlmsg, get_test_rule, NetlinkExpr, CHAIN_NAME, TABLE_NAME};
+
+#[test]
+fn bitwise_expr_is_valid() {
+ let netmask = Ipv4Addr::new(255, 255, 255, 0);
+ let bitwise = Bitwise::new(netmask.octets(), [0, 0, 0, 0]).unwrap();
+ let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(bitwise));
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 124);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"bitwise".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(
+ NFTA_BITWISE_SREG,
+ NFT_REG_1.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(
+ NFTA_BITWISE_DREG,
+ NFT_REG_1.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(NFTA_BITWISE_LEN, 4u32.to_be_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_BITWISE_MASK,
+ vec![NetlinkExpr::Final(
+ NFTA_DATA_VALUE,
+ vec![255, 255, 255, 0]
+ )]
+ ),
+ NetlinkExpr::Nested(
+ NFTA_BITWISE_XOR,
+ vec![NetlinkExpr::Final(
+ NFTA_DATA_VALUE,
+ 0u32.to_be_bytes().to_vec()
+ )]
+ )
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn cmp_expr_is_valid() {
+ let val = [1u8, 2, 3, 4];
+ let cmp = Cmp::new(CmpOp::Eq, val.clone());
+ let mut rule = get_test_rule().with_expressions(vec![cmp]);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 100);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"cmp".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(NFTA_CMP_SREG, NFT_REG_1.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_CMP_OP, NFT_CMP_EQ.to_be_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_CMP_DATA,
+ vec![NetlinkExpr::Final(NFTA_DATA_VALUE, val.to_vec())]
+ )
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn counter_expr_is_valid() {
+ let nb_bytes = 123456u64;
+ let nb_packets = 987u64;
+ let counter = Counter::default()
+ .with_nb_bytes(nb_bytes)
+ .with_nb_packets(nb_packets);
+
+ let mut rule = get_test_rule().with_expressions(vec![counter]);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 100);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"counter".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(
+ NFTA_COUNTER_BYTES,
+ nb_bytes.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(
+ NFTA_COUNTER_PACKETS,
+ nb_packets.to_be_bytes().to_vec()
+ )
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn ct_expr_is_valid() {
+ let ct = Conntrack::default().with_retrieve_value(ConntrackKey::State);
+ let mut rule = get_test_rule().with_expressions(vec![ct]);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 88);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"ct".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(
+ NFTA_CT_KEY,
+ NFT_CT_STATE.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(NFTA_CT_DREG, NFT_REG_1.to_be_bytes().to_vec())
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ )
+}
+
+#[test]
+fn immediate_expr_is_valid() {
+ let immediate = Immediate::new_data(vec![42u8], Register::Reg1);
+ let mut rule =
+ get_test_rule().with_expressions(ExpressionList::default().with_value(immediate));
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 100);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"immediate".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(
+ NFTA_IMMEDIATE_DREG,
+ NFT_REG_1.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Nested(
+ NFTA_IMMEDIATE_DATA,
+ vec![NetlinkExpr::Final(1u16, 42u8.to_be_bytes().to_vec())]
+ )
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn log_expr_is_valid() {
+ let log = Log::new(Some(1337), Some("mockprefix")).expect("Could not build a log expression");
+ let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(log));
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 96);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"log".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(NFTA_LOG_GROUP, 1337u16.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_LOG_PREFIX, b"mockprefix".to_vec()),
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn lookup_expr_is_valid() {
+ let table = get_test_table();
+ let mut set_builder = SetBuilder::new(SET_NAME, &table).unwrap();
+ let address: Ipv4Addr = [8, 8, 8, 8].into();
+ set_builder.add(&address);
+ let (set, _set_elements) = set_builder.finish();
+ let lookup = Lookup::new(&set).unwrap();
+
+ let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(lookup));
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 96);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"lookup".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(NFTA_LOOKUP_SET, b"mockset".to_vec()),
+ NetlinkExpr::Final(
+ NFTA_LOOKUP_SREG,
+ NFT_REG_1.to_be_bytes().to_vec()
+ ),
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn masquerade_expr_is_valid() {
+ let masquerade = Masquerade::default();
+ let mut rule = get_test_rule().with_expressions(vec![masquerade]);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 72);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"masq".to_vec()),
+ NetlinkExpr::Nested(NFTA_EXPR_DATA, vec![]),
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn meta_expr_is_valid() {
+ let meta = Meta::default()
+ .with_key(MetaType::Protocol)
+ .with_dreg(Register::Reg1);
+ let mut rule = get_test_rule().with_expressions(vec![meta]);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 88);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"meta".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(
+ NFTA_META_KEY,
+ NFT_META_PROTOCOL.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(
+ NFTA_META_DREG,
+ NFT_REG_1.to_be_bytes().to_vec()
+ )
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn nat_expr_is_valid() {
+ let nat = Nat::default()
+ .with_nat_type(NatType::SNat)
+ .with_family(ProtocolFamily::Ipv4)
+ .with_ip_register(Register::Reg1);
+ let mut rule = get_test_rule().with_expressions(vec![nat]);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 96);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"nat".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(
+ NFTA_NAT_TYPE,
+ NFT_NAT_SNAT.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(
+ NFTA_NAT_FAMILY,
+ (ProtocolFamily::Ipv4 as u32).to_be_bytes().to_vec(),
+ ),
+ NetlinkExpr::Final(
+ NFTA_NAT_REG_ADDR_MIN,
+ NFT_REG_1.to_be_bytes().to_vec()
+ )
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn payload_expr_is_valid() {
+ let tcp_header_field = TCPHeaderField::Sport;
+ let transport_header_field = TransportHeaderField::Tcp(tcp_header_field);
+ let payload = HighLevelPayload::Transport(transport_header_field);
+ let mut rule = get_test_rule().with_expressions(vec![payload.build()]);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 108);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"payload".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(
+ NFTA_PAYLOAD_DREG,
+ NFT_REG_1.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(
+ NFTA_PAYLOAD_BASE,
+ NFT_PAYLOAD_TRANSPORT_HEADER.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(
+ NFTA_PAYLOAD_OFFSET,
+ tcp_header_field.offset().to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(
+ NFTA_PAYLOAD_LEN,
+ tcp_header_field.len().to_be_bytes().to_vec()
+ ),
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn reject_expr_is_valid() {
+ let code = IcmpCode::NoRoute;
+ let reject = Reject::default()
+ .with_type(RejectType::IcmpxUnreach)
+ .with_icmp_code(code);
+ let mut rule = get_test_rule().with_expressions(vec![reject]);
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 92);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"reject".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(
+ NFTA_REJECT_TYPE,
+ NFT_REJECT_ICMPX_UNREACH.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(
+ NFTA_REJECT_ICMP_CODE,
+ (code as u8).to_be_bytes().to_vec()
+ ),
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn verdict_expr_is_valid() {
+ let verdict = Immediate::new_verdict(VerdictKind::Drop);
+ let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(verdict));
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 104);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"immediate".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(
+ NFTA_IMMEDIATE_DREG,
+ NFT_REG_VERDICT.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Nested(
+ NFTA_IMMEDIATE_DATA,
+ vec![NetlinkExpr::Nested(
+ NFTA_DATA_VERDICT,
+ vec![NetlinkExpr::Final(
+ NFTA_VERDICT_CODE,
+ NF_DROP.to_be_bytes().to_vec()
+ ),]
+ )],
+ ),
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
diff --git a/src/tests/mod.rs b/src/tests/mod.rs
new file mode 100644
index 0000000..75fe8b0
--- /dev/null
+++ b/src/tests/mod.rs
@@ -0,0 +1,193 @@
+use crate::data_type::DataType;
+use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter};
+use crate::parser::{parse_nlmsg, NlMsg};
+use crate::set::{Set, SetBuilder};
+use crate::{sys::*, Chain, MsgType, ProtocolFamily, Rule, Table};
+
+mod batch;
+mod chain;
+mod expr;
+mod rule;
+mod set;
+mod table;
+
+pub const TABLE_NAME: &'static str = "mocktable";
+pub const CHAIN_NAME: &'static str = "mockchain";
+pub const SET_NAME: &'static str = "mockset";
+
+pub const TABLE_USERDATA: &'static str = "mocktabledata";
+pub const CHAIN_USERDATA: &'static str = "mockchaindata";
+pub const RULE_USERDATA: &'static str = "mockruledata";
+pub const SET_USERDATA: &'static str = "mocksetdata";
+
+type NetLinkType = u16;
+
+#[derive(Debug, thiserror::Error)]
+#[error("empty data")]
+pub struct EmptyDataError;
+
+#[derive(Debug, Clone, Eq, Ord)]
+pub enum NetlinkExpr {
+ Nested(NetLinkType, Vec<NetlinkExpr>),
+ Final(NetLinkType, Vec<u8>),
+ List(Vec<NetlinkExpr>),
+}
+
+impl NetlinkExpr {
+ pub fn to_raw(self) -> Vec<u8> {
+ match self.sort() {
+ NetlinkExpr::Final(ty, val) => {
+ let len = val.len() + 4;
+ let mut res = Vec::with_capacity(len);
+
+ res.extend(&(len as u16).to_le_bytes());
+ res.extend(&ty.to_le_bytes());
+ res.extend(val);
+ // alignment
+ while res.len() % 4 != 0 {
+ res.push(0);
+ }
+
+ res
+ }
+ NetlinkExpr::Nested(ty, exprs) => {
+ // some heuristic to decrease allocations (even though this is
+ // only useful for testing so performance is not an objective)
+ let mut sub = Vec::with_capacity(exprs.len() * 50);
+
+ for expr in exprs {
+ sub.append(&mut expr.to_raw());
+ }
+
+ let len = sub.len() + 4;
+ let mut res = Vec::with_capacity(len);
+
+ // set the "NESTED" flag
+ res.extend(&(len as u16).to_le_bytes());
+ res.extend(&(ty | NLA_F_NESTED as u16).to_le_bytes());
+ res.extend(sub);
+
+ res
+ }
+ NetlinkExpr::List(exprs) => {
+ // some heuristic to decrease allocations (even though this is
+ // only useful for testing so performance is not an objective)
+ let mut list = Vec::with_capacity(exprs.len() * 50);
+
+ for expr in exprs {
+ list.append(&mut expr.to_raw());
+ }
+
+ list
+ }
+ }
+ }
+
+ pub fn sort(self) -> Self {
+ match self {
+ NetlinkExpr::Final(_, _) => self,
+ NetlinkExpr::Nested(ty, mut exprs) => {
+ exprs.sort();
+ NetlinkExpr::Nested(ty, exprs)
+ }
+ NetlinkExpr::List(mut exprs) => {
+ exprs.sort();
+ NetlinkExpr::List(exprs)
+ }
+ }
+ }
+}
+
+impl PartialEq for NetlinkExpr {
+ fn eq(&self, other: &Self) -> bool {
+ match (self.clone().sort(), other.clone().sort()) {
+ (NetlinkExpr::Nested(k1, v1), NetlinkExpr::Nested(k2, v2)) => k1 == k2 && v1 == v2,
+ (NetlinkExpr::Final(k1, v1), NetlinkExpr::Final(k2, v2)) => k1 == k2 && v1 == v2,
+ (NetlinkExpr::List(v1), NetlinkExpr::List(v2)) => v1 == v2,
+ _ => false,
+ }
+ }
+}
+
+impl PartialOrd for NetlinkExpr {
+ fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+ match (self, other) {
+ (
+ NetlinkExpr::Nested(k1, _) | NetlinkExpr::Final(k1, _),
+ NetlinkExpr::Nested(k2, _) | NetlinkExpr::Final(k2, _),
+ ) => k1.partial_cmp(k2),
+ (NetlinkExpr::List(v1), NetlinkExpr::List(v2)) => v1.partial_cmp(v2),
+ (_, NetlinkExpr::List(_)) => Some(std::cmp::Ordering::Less),
+ (NetlinkExpr::List(_), _) => Some(std::cmp::Ordering::Greater),
+ }
+ }
+}
+
+pub fn get_test_table() -> Table {
+ Table::new(ProtocolFamily::Inet)
+ .with_name(TABLE_NAME)
+ .with_flags(0u32)
+}
+
+pub fn get_test_table_raw_expr() -> NetlinkExpr {
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_TABLE_FLAGS, 0u32.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_TABLE_NAME, TABLE_NAME.as_bytes().to_vec()),
+ ])
+ .sort()
+}
+
+pub fn get_test_table_with_userdata_raw_expr() -> NetlinkExpr {
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_TABLE_FLAGS, 0u32.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_TABLE_NAME, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_TABLE_USERDATA, TABLE_USERDATA.as_bytes().to_vec()),
+ ])
+ .sort()
+}
+
+pub fn get_test_chain() -> Chain {
+ Chain::new(&get_test_table()).with_name(CHAIN_NAME)
+}
+
+pub fn get_test_rule() -> Rule {
+ Rule::new(&get_test_chain()).unwrap()
+}
+
+pub fn get_test_set<K: DataType>() -> Set {
+ SetBuilder::<K>::new(SET_NAME, &get_test_table())
+ .expect("Couldn't create a set")
+ .finish()
+ .0
+ .with_userdata(SET_USERDATA)
+}
+
+pub fn get_test_nlmsg_with_msg_type<'a>(
+ buf: &'a mut Vec<u8>,
+ obj: &mut impl NfNetlinkObject,
+ msg_type: MsgType,
+) -> (nlmsghdr, nfgenmsg, &'a [u8]) {
+ let mut writer = NfNetlinkWriter::new(buf);
+ obj.add_or_remove(&mut writer, msg_type, 0);
+
+ let (hdr, msg) = parse_nlmsg(buf.as_slice()).expect("Couldn't parse the message");
+
+ let (nfgenmsg, raw_value) = match msg {
+ NlMsg::NfGenMsg(nfgenmsg, raw_value) => (nfgenmsg, raw_value),
+ _ => panic!("Invalid return value type, expected a valid message"),
+ };
+
+ // sanity checks on the global message (this should be very similar/factorisable for the
+ // most part in other tests)
+ // TODO: check the messages flags
+ assert_eq!(nfgenmsg.res_id.to_be(), 0);
+
+ (hdr, nfgenmsg, raw_value)
+}
+
+pub fn get_test_nlmsg<'a>(
+ buf: &'a mut Vec<u8>,
+ obj: &mut impl NfNetlinkObject,
+) -> (nlmsghdr, nfgenmsg, &'a [u8]) {
+ get_test_nlmsg_with_msg_type(buf, obj, MsgType::Add)
+}
diff --git a/src/tests/rule.rs b/src/tests/rule.rs
new file mode 100644
index 0000000..08b4139
--- /dev/null
+++ b/src/tests/rule.rs
@@ -0,0 +1,132 @@
+use crate::{
+ nlmsg::get_operation_from_nlmsghdr_type,
+ sys::{
+ NFTA_RULE_CHAIN, NFTA_RULE_HANDLE, NFTA_RULE_POSITION, NFTA_RULE_TABLE, NFTA_RULE_USERDATA,
+ NFT_MSG_DELRULE, NFT_MSG_NEWRULE,
+ },
+ MsgType,
+};
+
+use super::{
+ get_test_nlmsg, get_test_nlmsg_with_msg_type, get_test_rule, NetlinkExpr, CHAIN_NAME,
+ RULE_USERDATA, TABLE_NAME,
+};
+
+#[test]
+fn new_empty_rule() {
+ let mut rule = get_test_rule();
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWRULE as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 52);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn new_empty_rule_with_userdata() {
+ let mut rule = get_test_rule().with_userdata(RULE_USERDATA);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWRULE as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 68);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_USERDATA, RULE_USERDATA.as_bytes().to_vec())
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn new_empty_rule_with_position_and_handle() {
+ let handle: u64 = 1337;
+ let position: u64 = 42;
+ let mut rule = get_test_rule().with_handle(handle).with_position(position);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWRULE as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 76);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_HANDLE, handle.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_POSITION, position.to_be_bytes().to_vec()),
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn delete_empty_rule() {
+ let mut rule = get_test_rule();
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) =
+ get_test_nlmsg_with_msg_type(&mut buf, &mut rule, MsgType::Del);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_DELRULE as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 52);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn delete_empty_rule_with_handle() {
+ let handle: u64 = 42;
+ let mut rule = get_test_rule().with_handle(handle);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) =
+ get_test_nlmsg_with_msg_type(&mut buf, &mut rule, MsgType::Del);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_DELRULE as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 64);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_HANDLE, handle.to_be_bytes().to_vec()),
+ ])
+ .to_raw()
+ );
+}
diff --git a/src/tests/set.rs b/src/tests/set.rs
new file mode 100644
index 0000000..6c8247c
--- /dev/null
+++ b/src/tests/set.rs
@@ -0,0 +1,119 @@
+use std::net::{Ipv4Addr, Ipv6Addr};
+
+use crate::{
+ data_type::DataType,
+ nlmsg::get_operation_from_nlmsghdr_type,
+ set::SetBuilder,
+ sys::{
+ NFTA_DATA_VALUE, NFTA_LIST_ELEM, NFTA_SET_ELEM_KEY, NFTA_SET_ELEM_LIST_ELEMENTS,
+ NFTA_SET_ELEM_LIST_SET, NFTA_SET_ELEM_LIST_TABLE, NFTA_SET_KEY_LEN, NFTA_SET_KEY_TYPE,
+ NFTA_SET_NAME, NFTA_SET_TABLE, NFTA_SET_USERDATA, NFT_MSG_DELSET, NFT_MSG_NEWSET,
+ NFT_MSG_NEWSETELEM,
+ },
+ MsgType,
+};
+
+use super::{
+ get_test_nlmsg, get_test_nlmsg_with_msg_type, get_test_set, get_test_table, NetlinkExpr,
+ SET_NAME, SET_USERDATA, TABLE_NAME,
+};
+
+#[test]
+fn new_empty_set() {
+ let mut set = get_test_set::<Ipv4Addr>();
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut set);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWSET as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 80);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_SET_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv4Addr::TYPE.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv4Addr::LEN.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_USERDATA, SET_USERDATA.as_bytes().to_vec()),
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn delete_empty_set() {
+ let mut set = get_test_set::<Ipv6Addr>();
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) =
+ get_test_nlmsg_with_msg_type(&mut buf, &mut set, MsgType::Del);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_DELSET as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 80);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_SET_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv6Addr::TYPE.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv6Addr::LEN.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_USERDATA, SET_USERDATA.as_bytes().to_vec()),
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn new_set_with_data() {
+ let ip1 = Ipv4Addr::new(127, 0, 0, 1);
+ let ip2 = Ipv4Addr::new(1, 1, 1, 1);
+ let mut set_builder = SetBuilder::<Ipv4Addr>::new(SET_NAME.to_string(), &get_test_table())
+ .expect("Couldn't create a set");
+
+ set_builder.add(&ip1);
+ set_builder.add(&ip2);
+ let (_set, mut elem_list) = set_builder.finish();
+
+ let mut buf = Vec::new();
+
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut elem_list);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWSETELEM as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 84);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_SET_ELEM_LIST_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_ELEM_LIST_SET, SET_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_SET_ELEM_LIST_ELEMENTS,
+ vec![
+ NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![NetlinkExpr::Nested(
+ NFTA_DATA_VALUE,
+ vec![NetlinkExpr::Final(NFTA_SET_ELEM_KEY, ip1.data().to_vec())]
+ )]
+ ),
+ NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![NetlinkExpr::Nested(
+ NFTA_DATA_VALUE,
+ vec![NetlinkExpr::Final(NFTA_SET_ELEM_KEY, ip2.data().to_vec())]
+ )]
+ ),
+ ]
+ ),
+ ])
+ .to_raw()
+ );
+}
diff --git a/src/tests/table.rs b/src/tests/table.rs
new file mode 100644
index 0000000..39bf399
--- /dev/null
+++ b/src/tests/table.rs
@@ -0,0 +1,67 @@
+use crate::{
+ nlmsg::{get_operation_from_nlmsghdr_type, nft_nlmsg_maxsize, NfNetlinkDeserializable},
+ sys::{NFT_MSG_DELTABLE, NFT_MSG_NEWTABLE},
+ MsgType, Table,
+};
+
+use super::{
+ get_test_nlmsg, get_test_nlmsg_with_msg_type, get_test_table, get_test_table_raw_expr,
+ get_test_table_with_userdata_raw_expr, TABLE_USERDATA,
+};
+
+#[test]
+fn new_empty_table() {
+ let mut table = get_test_table();
+ let mut buf = Vec::with_capacity(nft_nlmsg_maxsize() as usize);
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut table);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWTABLE as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 44);
+
+ assert_eq!(raw_expr, get_test_table_raw_expr().to_raw());
+}
+
+#[test]
+fn new_empty_table_with_userdata() {
+ let mut table = get_test_table();
+ table.set_userdata(TABLE_USERDATA.as_bytes().to_vec());
+ let mut buf = Vec::with_capacity(nft_nlmsg_maxsize() as usize);
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut table);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWTABLE as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 64);
+
+ assert_eq!(raw_expr, get_test_table_with_userdata_raw_expr().to_raw());
+}
+
+#[test]
+fn delete_empty_table() {
+ let mut table = get_test_table();
+ let mut buf = Vec::with_capacity(nft_nlmsg_maxsize() as usize);
+ let (nlmsghdr, _nfgenmsg, raw_expr) =
+ get_test_nlmsg_with_msg_type(&mut buf, &mut table, MsgType::Del);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_DELTABLE as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 44);
+
+ assert_eq!(raw_expr, get_test_table_raw_expr().to_raw());
+}
+
+#[test]
+fn parse_table() {
+ let mut table = get_test_table();
+ table.set_userdata(TABLE_USERDATA.as_bytes().to_vec());
+ let mut buf = Vec::with_capacity(nft_nlmsg_maxsize() as usize);
+ let (_nlmsghdr, _nfgenmsg, _raw_expr) = get_test_nlmsg(&mut buf, &mut table);
+
+ let (deserialized_table, remaining) =
+ Table::deserialize(&buf).expect("Couldn't deserialize the object");
+ assert_eq!(table, deserialized_table);
+ assert_eq!(remaining.len(), 0);
+}