aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/batch.rs8
-rw-r--r--src/chain.rs84
-rw-r--r--src/data_type.rs35
-rw-r--r--src/error.rs174
-rw-r--r--src/expr/bitwise.rs17
-rw-r--r--src/expr/cmp.rs18
-rw-r--r--src/expr/immediate.rs31
-rw-r--r--src/expr/lookup.rs94
-rw-r--r--src/expr/meta.rs6
-rw-r--r--src/expr/mod.rs165
-rw-r--r--src/expr/payload.rs2
-rw-r--r--src/expr/verdict.rs25
-rw-r--r--src/lib.rs94
-rw-r--r--src/nlmsg.rs135
-rw-r--r--src/parser.rs280
-rw-r--r--src/parser_impls.rs243
-rw-r--r--src/query.rs84
-rw-r--r--src/rule.rs73
-rw-r--r--src/set.rs365
-rw-r--r--src/table.rs49
-rw-r--r--src/tests/batch.rs96
-rw-r--r--src/tests/chain.rs120
-rw-r--r--src/tests/expr.rs589
-rw-r--r--src/tests/mod.rs195
-rw-r--r--src/tests/rule.rs132
-rw-r--r--src/tests/set.rs122
-rw-r--r--src/tests/table.rs67
27 files changed, 2172 insertions, 1131 deletions
diff --git a/src/batch.rs b/src/batch.rs
index d885813..b5c88b8 100644
--- a/src/batch.rs
+++ b/src/batch.rs
@@ -2,11 +2,11 @@ use libc;
use thiserror::Error;
+use crate::error::QueryError;
use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter};
use crate::sys::NFNL_SUBSYS_NFTABLES;
use crate::{MsgType, ProtocolFamily};
-use crate::query::Error;
use nix::sys::socket::{
self, AddressFamily, MsgFlags, NetlinkAddr, SockAddr, SockFlag, SockProtocol, SockType,
};
@@ -88,7 +88,7 @@ impl Batch {
*self.buf
}
- pub fn send(mut self) -> Result<(), Error> {
+ pub fn send(self) -> Result<(), QueryError> {
use crate::query::{recv_and_process, socket_close_wrapper};
let sock = socket::socket(
@@ -97,7 +97,7 @@ impl Batch {
SockFlag::empty(),
SockProtocol::NetlinkNetFilter,
)
- .map_err(Error::NetlinkOpenError)?;
+ .map_err(QueryError::NetlinkOpenError)?;
let max_seq = self.seq - 1;
@@ -110,7 +110,7 @@ impl Batch {
let mut sent = 0;
while sent != to_send.len() {
sent += socket::send(sock, &to_send[sent..], MsgFlags::empty())
- .map_err(Error::NetlinkSendError)?;
+ .map_err(QueryError::NetlinkSendError)?;
}
Ok(socket_close_wrapper(sock, move |sock| {
diff --git a/src/chain.rs b/src/chain.rs
index 7a62fb2..0ce0ad8 100644
--- a/src/chain.rs
+++ b/src/chain.rs
@@ -1,15 +1,14 @@
use libc::{NF_ACCEPT, NF_DROP};
use rustables_macros::nfnetlink_struct;
-use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkWriter};
-use crate::parser::{DecodeError, Parsable};
+use crate::error::{DecodeError, QueryError};
+use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject};
use crate::sys::{
NFTA_CHAIN_FLAGS, NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_POLICY, NFTA_CHAIN_TABLE,
NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN,
- NFT_MSG_NEWCHAIN, NLM_F_ACK, NLM_F_CREATE,
+ NFT_MSG_NEWCHAIN,
};
-use crate::{MsgType, ProtocolFamily, Table};
-use std::convert::TryFrom;
+use crate::{ProtocolFamily, Table};
use std::fmt::Debug;
pub type ChainPriority = i32;
@@ -132,14 +131,10 @@ impl NfNetlinkDeserializable for ChainType {
}
}
-/// Abstraction of a `nftnl_chain`. Chains reside inside [`Table`]s and they hold [`Rule`]s.
-///
-/// There are two types of chains, "base chain" and "regular chain". See [`set_hook`] for more
-/// details.
+/// Abstraction over an nftable chain. Chains reside inside [`Table`]s and they hold [`Rule`]s.
///
/// [`Table`]: struct.Table.html
/// [`Rule`]: struct.Rule.html
-/// [`set_hook`]: #method.set_hook
#[derive(PartialEq, Eq, Default, Debug)]
#[nfnetlink_struct(derive_deserialize = false)]
pub struct Chain {
@@ -166,7 +161,7 @@ impl Chain {
/// [`Table`]: struct.Table.html
pub fn new(table: &Table) -> Chain {
let mut chain = Chain::default();
- chain.family = table.family;
+ chain.family = table.get_family();
if let Some(table_name) = table.get_name() {
chain.set_table(table_name);
@@ -174,73 +169,22 @@ impl Chain {
chain
}
-
- pub fn get_family(&self) -> ProtocolFamily {
- self.family
- }
-
- /*
- /// Returns a textual description of the chain.
- pub fn get_str(&self) -> CString {
- let mut descr_buf = vec![0i8; 4096];
- unsafe {
- sys::nftnl_chain_snprintf(
- descr_buf.as_mut_ptr() as *mut c_char,
- (descr_buf.len() - 1) as u64,
- self.chain,
- sys::NFTNL_OUTPUT_DEFAULT,
- 0,
- );
- CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned()
- }
- }
- */
}
-/*
-impl PartialEq for Chain {
- fn eq(&self, other: &Self) -> bool {
- self.get_table() == other.get_table() && self.get_name() == other.get_name()
- }
-}
-*/
-
impl NfNetlinkObject for Chain {
- fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32) {
- let raw_msg_type = match msg_type {
- MsgType::Add => NFT_MSG_NEWCHAIN,
- MsgType::Del => NFT_MSG_DELCHAIN,
- } as u16;
- writer.write_header(
- raw_msg_type,
- self.family,
- (if let MsgType::Add = msg_type {
- NLM_F_CREATE
- } else {
- 0
- } | NLM_F_ACK) as u16,
- seq,
- None,
- );
- let buf = writer.add_data_zeroed(self.get_size());
- unsafe {
- self.write_payload(buf.as_mut_ptr());
- }
- writer.finalize_writing_object();
- }
-}
+ const MSG_TYPE_ADD: u32 = NFT_MSG_NEWCHAIN;
+ const MSG_TYPE_DEL: u32 = NFT_MSG_DELCHAIN;
-impl NfNetlinkDeserializable for Chain {
- fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
- let (mut obj, nfgenmsg, remaining_data) =
- Self::parse_object(buf, NFT_MSG_NEWCHAIN, NFT_MSG_DELCHAIN)?;
- obj.family = ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?;
+ fn get_family(&self) -> ProtocolFamily {
+ self.family
+ }
- Ok((obj, remaining_data))
+ fn set_family(&mut self, family: ProtocolFamily) {
+ self.family = family;
}
}
-pub fn list_chains_for_table(table: &Table) -> Result<Vec<Chain>, crate::query::Error> {
+pub fn list_chains_for_table(table: &Table) -> Result<Vec<Chain>, QueryError> {
let mut result = Vec::new();
crate::query::list_objects_with_data(
libc::NFT_MSG_GETCHAIN as u16,
diff --git a/src/data_type.rs b/src/data_type.rs
new file mode 100644
index 0000000..f9c97cb
--- /dev/null
+++ b/src/data_type.rs
@@ -0,0 +1,35 @@
+use std::net::{Ipv4Addr, Ipv6Addr};
+
+pub trait DataType {
+ const TYPE: u32;
+ const LEN: u32;
+
+ fn data(&self) -> Vec<u8>;
+}
+
+impl DataType for Ipv4Addr {
+ const TYPE: u32 = 7;
+ const LEN: u32 = 4;
+
+ fn data(&self) -> Vec<u8> {
+ self.octets().to_vec()
+ }
+}
+
+impl DataType for Ipv6Addr {
+ const TYPE: u32 = 8;
+ const LEN: u32 = 16;
+
+ fn data(&self) -> Vec<u8> {
+ self.octets().to_vec()
+ }
+}
+
+impl<const N: usize> DataType for [u8; N] {
+ const TYPE: u32 = 5;
+ const LEN: u32 = N as u32;
+
+ fn data(&self) -> Vec<u8> {
+ self.to_vec()
+ }
+}
diff --git a/src/error.rs b/src/error.rs
new file mode 100644
index 0000000..eae6898
--- /dev/null
+++ b/src/error.rs
@@ -0,0 +1,174 @@
+use std::string::FromUtf8Error;
+
+use nix::errno::Errno;
+use thiserror::Error;
+
+use crate::sys::nlmsgerr;
+
+#[derive(Error, Debug)]
+pub enum DecodeError {
+ #[error("The buffer is too small to hold a valid message")]
+ BufTooSmall,
+
+ #[error("The message is too small")]
+ NlMsgTooSmall,
+
+ #[error("The message holds unexpected data")]
+ InvalidDataSize,
+
+ #[error("Invalid subsystem, expected NFTABLES")]
+ InvalidSubsystem(u8),
+
+ #[error("Invalid version, expected NFNETLINK_V0")]
+ InvalidVersion(u8),
+
+ #[error("Invalid port ID")]
+ InvalidPortId(u32),
+
+ #[error("Invalid sequence number")]
+ InvalidSeq(u32),
+
+ #[error("The generation number was bumped in the kernel while the operation was running, interrupting it")]
+ ConcurrentGenerationUpdate,
+
+ #[error("Unsupported message type")]
+ UnsupportedType(u16),
+
+ #[error("Invalid attribute type")]
+ InvalidAttributeType,
+
+ #[error("Invalid type for a chain")]
+ UnknownChainType,
+
+ #[error("Invalid policy for a chain")]
+ UnknownChainPolicy,
+
+ #[error("Unknown type for a Meta expression")]
+ UnknownMetaType(u32),
+
+ #[error("Unsupported value for an icmp reject type")]
+ UnknownRejectType(u32),
+
+ #[error("Unsupported value for an icmp code in a reject expression")]
+ UnknownIcmpCode(u8),
+
+ #[error("Invalid value for a register")]
+ UnknownRegister(u32),
+
+ #[error("Invalid type for a verdict expression")]
+ UnknownVerdictType(i32),
+
+ #[error("Invalid type for a nat expression")]
+ UnknownNatType(i32),
+
+ #[error("Invalid type for a payload expression")]
+ UnknownPayloadType(u32),
+
+ #[error("Invalid type for a compare expression")]
+ UnknownCmpOp(u32),
+
+ #[error("Invalid type for a conntrack key")]
+ UnknownConntrackKey(u32),
+
+ #[error("Unsupported value for a link layer header field")]
+ UnknownLinkLayerHeaderField(u32, u32),
+
+ #[error("Unsupported value for an IPv4 header field")]
+ UnknownIPv4HeaderField(u32, u32),
+
+ #[error("Unsupported value for an IPv6 header field")]
+ UnknownIPv6HeaderField(u32, u32),
+
+ #[error("Unsupported value for a TCP header field")]
+ UnknownTCPHeaderField(u32, u32),
+
+ #[error("Unsupported value for an UDP header field")]
+ UnknownUDPHeaderField(u32, u32),
+
+ #[error("Unsupported value for an ICMPv6 header field")]
+ UnknownICMPv6HeaderField(u32, u32),
+
+ #[error("Missing the 'base' attribute to deserialize the payload object")]
+ PayloadMissingBase,
+
+ #[error("Missing the 'offset' attribute to deserialize the payload object")]
+ PayloadMissingOffset,
+
+ #[error("Missing the 'len' attribute to deserialize the payload object")]
+ PayloadMissingLen,
+
+ #[error("The object does not contain a name for the expression being parsed")]
+ MissingExpressionName,
+
+ #[error("Unsupported attribute type")]
+ UnsupportedAttributeType(u16),
+
+ #[error("Unexpected message type")]
+ UnexpectedType(u16),
+
+ #[error("The decoded String is not UTF8 compliant")]
+ StringDecodeFailure(#[from] FromUtf8Error),
+
+ #[error("Invalid value for a protocol family")]
+ UnknownProtocolFamily(i32),
+
+ #[error("A custom error occured")]
+ Custom(Box<dyn std::error::Error + 'static>),
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum BuilderError {
+ #[error("The length of the arguments are not compatible with each other")]
+ IncompatibleLength,
+
+ #[error("The table does not have a name")]
+ MissingTableName,
+
+ #[error("Missing information in the chain to create a rule")]
+ MissingChainInformationError,
+
+ #[error("Missing name for the set")]
+ MissingSetName,
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum QueryError {
+ #[error("Unable to open netlink socket to netfilter")]
+ NetlinkOpenError(#[source] nix::Error),
+
+ #[error("Unable to send netlink command to netfilter")]
+ NetlinkSendError(#[source] nix::Error),
+
+ #[error("Error while reading from netlink socket")]
+ NetlinkRecvError(#[source] nix::Error),
+
+ #[error("Error while processing an incoming netlink message")]
+ ProcessNetlinkError(#[from] DecodeError),
+
+ #[error("Error while building netlink objects in Rust")]
+ BuilderError(#[from] BuilderError),
+
+ #[error("Error received from the kernel")]
+ NetlinkError(nlmsgerr),
+
+ #[error("Custom error when customizing the query")]
+ InitError(#[from] Box<dyn std::error::Error + Send + 'static>),
+
+ #[error("Couldn't allocate a netlink object, out of memory ?")]
+ NetlinkAllocationFailed,
+
+ #[error("This socket is not a netlink socket")]
+ NotNetlinkSocket,
+
+ #[error("Couldn't retrieve information on a socket")]
+ RetrievingSocketInfoFailed,
+
+ #[error("Only a part of the message was sent")]
+ TruncatedSend,
+
+ #[error("Got a message without the NLM_F_MULTI flag, but a maximum sequence number was not specified")]
+ UndecidableMessageTermination,
+
+ #[error("Couldn't close the socket")]
+ CloseFailed(#[source] Errno),
+}
diff --git a/src/expr/bitwise.rs b/src/expr/bitwise.rs
index 29d2d63..fb40a04 100644
--- a/src/expr/bitwise.rs
+++ b/src/expr/bitwise.rs
@@ -1,7 +1,8 @@
use rustables_macros::nfnetlink_struct;
-use super::{Expression, ExpressionData, Register};
-use crate::parser::DecodeError;
+use super::{Expression, Register};
+use crate::error::BuilderError;
+use crate::parser_impls::NfNetlinkData;
use crate::sys::{
NFTA_BITWISE_DREG, NFTA_BITWISE_LEN, NFTA_BITWISE_MASK, NFTA_BITWISE_SREG, NFTA_BITWISE_XOR,
};
@@ -16,9 +17,9 @@ pub struct Bitwise {
#[field(NFTA_BITWISE_LEN)]
len: u32,
#[field(NFTA_BITWISE_MASK)]
- mask: ExpressionData,
+ mask: NfNetlinkData,
#[field(NFTA_BITWISE_XOR)]
- xor: ExpressionData,
+ xor: NfNetlinkData,
}
impl Expression for Bitwise {
@@ -30,17 +31,17 @@ impl Expression for Bitwise {
impl Bitwise {
/// Returns a new `Bitwise` instance that first masks the value it's applied to with `mask` and
/// then performs xor with the value in `xor`
- pub fn new(mask: impl Into<Vec<u8>>, xor: impl Into<Vec<u8>>) -> Result<Self, DecodeError> {
+ pub fn new(mask: impl Into<Vec<u8>>, xor: impl Into<Vec<u8>>) -> Result<Self, BuilderError> {
let mask = mask.into();
let xor = xor.into();
if mask.len() != xor.len() {
- return Err(DecodeError::IncompatibleLength);
+ return Err(BuilderError::IncompatibleLength);
}
Ok(Bitwise::default()
.with_sreg(Register::Reg1)
.with_dreg(Register::Reg1)
.with_len(mask.len() as u32)
- .with_xor(ExpressionData::default().with_value(xor))
- .with_mask(ExpressionData::default().with_value(mask)))
+ .with_xor(NfNetlinkData::default().with_value(xor))
+ .with_mask(NfNetlinkData::default().with_value(mask)))
}
}
diff --git a/src/expr/cmp.rs b/src/expr/cmp.rs
index d69f73c..223902f 100644
--- a/src/expr/cmp.rs
+++ b/src/expr/cmp.rs
@@ -1,11 +1,15 @@
use rustables_macros::{nfnetlink_enum, nfnetlink_struct};
-use crate::sys::{
- NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFT_CMP_EQ, NFT_CMP_GT, NFT_CMP_GTE, NFT_CMP_LT,
- NFT_CMP_LTE, NFT_CMP_NEQ,
+use crate::{
+ data_type::DataType,
+ parser_impls::NfNetlinkData,
+ sys::{
+ NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFT_CMP_EQ, NFT_CMP_GT, NFT_CMP_GTE, NFT_CMP_LT,
+ NFT_CMP_LTE, NFT_CMP_NEQ,
+ },
};
-use super::{Expression, ExpressionData, Register};
+use super::{Expression, Register};
/// Comparison operator.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
@@ -34,17 +38,17 @@ pub struct Cmp {
#[field(NFTA_CMP_OP)]
op: CmpOp,
#[field(NFTA_CMP_DATA)]
- data: ExpressionData,
+ data: NfNetlinkData,
}
impl Cmp {
/// Returns a new comparison expression comparing the value loaded in the register with the
/// data in `data` using the comparison operator `op`.
- pub fn new(op: CmpOp, data: impl Into<Vec<u8>>) -> Self {
+ pub fn new(op: CmpOp, data: impl DataType) -> Self {
Cmp {
sreg: Some(Register::Reg1),
op: Some(op),
- data: Some(ExpressionData::default().with_value(data)),
+ data: Some(NfNetlinkData::default().with_value(data.data())),
}
}
}
diff --git a/src/expr/immediate.rs b/src/expr/immediate.rs
index 134f7e1..2fd9bd5 100644
--- a/src/expr/immediate.rs
+++ b/src/expr/immediate.rs
@@ -1,7 +1,10 @@
use rustables_macros::nfnetlink_struct;
-use super::{Expression, ExpressionData, Register};
-use crate::sys::{NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG};
+use super::{Expression, Register, Verdict, VerdictKind, VerdictType};
+use crate::{
+ parser_impls::NfNetlinkData,
+ sys::{NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG},
+};
#[derive(Clone, PartialEq, Eq, Default, Debug)]
#[nfnetlink_struct]
@@ -9,14 +12,34 @@ pub struct Immediate {
#[field(NFTA_IMMEDIATE_DREG)]
dreg: Register,
#[field(NFTA_IMMEDIATE_DATA)]
- data: ExpressionData,
+ data: NfNetlinkData,
}
impl Immediate {
pub fn new_data(data: Vec<u8>, register: Register) -> Self {
Immediate::default()
.with_dreg(register)
- .with_data(ExpressionData::default().with_value(data))
+ .with_data(NfNetlinkData::default().with_value(data))
+ }
+
+ pub fn new_verdict(kind: VerdictKind) -> Self {
+ let code = match kind {
+ VerdictKind::Drop => VerdictType::Drop,
+ VerdictKind::Accept => VerdictType::Accept,
+ VerdictKind::Queue => VerdictType::Queue,
+ VerdictKind::Continue => VerdictType::Continue,
+ VerdictKind::Break => VerdictType::Break,
+ VerdictKind::Jump { .. } => VerdictType::Jump,
+ VerdictKind::Goto { .. } => VerdictType::Goto,
+ VerdictKind::Return => VerdictType::Return,
+ };
+ let mut data = Verdict::default().with_code(code);
+ if let VerdictKind::Jump { chain } | VerdictKind::Goto { chain } = kind {
+ data.set_chain(chain);
+ }
+ Immediate::default()
+ .with_dreg(Register::Verdict)
+ .with_data(NfNetlinkData::default().with_verdict(data))
}
}
diff --git a/src/expr/lookup.rs b/src/expr/lookup.rs
index a0cc021..2ef830e 100644
--- a/src/expr/lookup.rs
+++ b/src/expr/lookup.rs
@@ -1,78 +1,40 @@
-use super::{DeserializationError, Expression, Rule};
-use crate::set::Set;
-use crate::sys::{self, libc};
-use std::ffi::{CStr, CString};
-use std::os::raw::c_char;
+use rustables_macros::nfnetlink_struct;
-#[derive(Debug, PartialEq)]
+use super::{Expression, Register};
+use crate::error::BuilderError;
+use crate::sys::{NFTA_LOOKUP_DREG, NFTA_LOOKUP_SET, NFTA_LOOKUP_SET_ID, NFTA_LOOKUP_SREG};
+use crate::Set;
+
+#[derive(Clone, PartialEq, Eq, Default, Debug)]
+#[nfnetlink_struct]
pub struct Lookup {
- set_name: CString,
+ #[field(NFTA_LOOKUP_SET)]
+ set: String,
+ #[field(NFTA_LOOKUP_SREG)]
+ sreg: Register,
+ #[field(NFTA_LOOKUP_DREG)]
+ dreg: Register,
+ #[field(NFTA_LOOKUP_SET_ID)]
set_id: u32,
}
impl Lookup {
- /// Creates a new lookup entry. May return None if the set has no name.
- pub fn new<K>(set: &Set<K>) -> Option<Self> {
- set.get_name().map(|set_name| Lookup {
- set_name: set_name.to_owned(),
- set_id: set.get_id(),
- })
- }
-}
-
-impl Expression for Lookup {
- fn get_raw_name() -> *const libc::c_char {
- b"lookup\0" as *const _ as *const c_char
- }
-
- fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError>
- where
- Self: Sized,
- {
- unsafe {
- let set_name = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_LOOKUP_SET as u16);
- let set_id = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_LOOKUP_SET_ID as u16);
-
- if set_name.is_null() {
- return Err(DeserializationError::NullPointer);
- }
-
- let set_name = CStr::from_ptr(set_name).to_owned();
-
- Ok(Lookup { set_id, set_name })
+ /// Creates a new lookup entry. May return BuilderError::MissingSetName if the set has no name.
+ pub fn new(set: &Set) -> Result<Self, BuilderError> {
+ let mut res = Lookup::default()
+ .with_set(set.get_name().ok_or(BuilderError::MissingSetName)?)
+ .with_sreg(Register::Reg1);
+
+ if let Some(id) = set.get_id() {
+ res.set_set_id(*id);
}
- }
-
- fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
- unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name()));
-
- sys::nftnl_expr_set_u32(
- expr,
- sys::NFTNL_EXPR_LOOKUP_SREG as u16,
- libc::NFT_REG_1 as u32,
- );
- sys::nftnl_expr_set_str(
- expr,
- sys::NFTNL_EXPR_LOOKUP_SET as u16,
- self.set_name.as_ptr() as *const _ as *const c_char,
- );
- sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOOKUP_SET_ID as u16, self.set_id);
- // This code is left here since it's quite likely we need it again when we get further
- // if self.reverse {
- // sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOOKUP_FLAGS as u16,
- // libc::NFT_LOOKUP_F_INV as u32);
- // }
-
- expr
- }
+ Ok(res)
}
}
-#[macro_export]
-macro_rules! nft_expr_lookup {
- ($set:expr) => {
- $crate::expr::Lookup::new($set)
- };
+impl Expression for Lookup {
+ fn get_name() -> &'static str {
+ "lookup"
+ }
}
diff --git a/src/expr/meta.rs b/src/expr/meta.rs
index 79016bd..d0fecee 100644
--- a/src/expr/meta.rs
+++ b/src/expr/meta.rs
@@ -49,6 +49,12 @@ pub struct Meta {
sreg: Register,
}
+impl Meta {
+ pub fn new(ty: MetaType) -> Self {
+ Meta::default().with_dreg(Register::Reg1).with_key(ty)
+ }
+}
+
impl Expression for Meta {
fn get_name() -> &'static str {
"meta"
diff --git a/src/expr/mod.rs b/src/expr/mod.rs
index cfc01c8..979ebb2 100644
--- a/src/expr/mod.rs
+++ b/src/expr/mod.rs
@@ -4,21 +4,15 @@
//! [`Rule`]: struct.Rule.html
use std::fmt::Debug;
-use std::mem::transmute;
-
-use crate::nlmsg::NfNetlinkAttribute;
-use crate::nlmsg::NfNetlinkDeserializable;
-use crate::parser::pad_netlink_object;
-use crate::parser::pad_netlink_object_with_variable_size;
-use crate::parser::write_attribute;
-use crate::parser::DecodeError;
-use crate::sys::{self, nlattr};
-use crate::sys::{
- NFTA_DATA_VALUE, NFTA_DATA_VERDICT, NFTA_EXPR_DATA, NFTA_EXPR_NAME, NLA_TYPE_MASK,
-};
+
use rustables_macros::nfnetlink_struct;
use thiserror::Error;
+use crate::error::DecodeError;
+use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable};
+use crate::parser_impls::NfNetlinkList;
+use crate::sys::{self, NFTA_EXPR_DATA, NFTA_EXPR_NAME};
+
mod bitwise;
pub use self::bitwise::*;
@@ -36,11 +30,9 @@ pub use self::immediate::*;
mod log;
pub use self::log::*;
-/*
mod lookup;
pub use self::lookup::*;
-*/
mod masquerade;
pub use self::masquerade::*;
@@ -105,19 +97,18 @@ pub struct RawExpression {
data: ExpressionVariant,
}
-impl RawExpression {
- pub fn new<T>(expr: T) -> Self
- where
- T: Expression,
- ExpressionVariant: From<T>,
- {
+impl<T> From<T> for RawExpression
+where
+ T: Expression,
+ ExpressionVariant: From<T>,
+{
+ fn from(val: T) -> Self {
RawExpression::default()
.with_name(T::get_name())
- .with_data(ExpressionVariant::from(expr))
+ .with_data(ExpressionVariant::from(val))
}
}
-#[macro_export]
macro_rules! create_expr_variant {
($enum:ident $(, [$name:ident, $type:ty])+) => {
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -162,14 +153,14 @@ macro_rules! create_expr_variant {
&mut self,
attr_type: u16,
buf: &[u8],
- ) -> Result<(), $crate::parser::DecodeError> {
+ ) -> Result<(), $crate::error::DecodeError> {
debug!("Decoding attribute {} in an expression", attr_type);
match attr_type {
x if x == sys::NFTA_EXPR_NAME => {
debug!("Calling {}::deserialize()", std::any::type_name::<String>());
let (val, remaining) = String::deserialize(buf)?;
if remaining.len() != 0 {
- return Err($crate::parser::DecodeError::InvalidDataSize);
+ return Err($crate::error::DecodeError::InvalidDataSize);
}
self.name = Some(val);
Ok(())
@@ -178,14 +169,14 @@ macro_rules! create_expr_variant {
// we can assume we have already the name parsed, as that's how we identify the
// type of expression
let name = self.name.as_ref()
- .ok_or($crate::parser::DecodeError::MissingExpressionName)?;
+ .ok_or($crate::error::DecodeError::MissingExpressionName)?;
match name {
$(
x if x == <$type>::get_name() => {
debug!("Calling {}::deserialize()", std::any::type_name::<$type>());
let (res, remaining) = <$type>::deserialize(buf)?;
if remaining.len() != 0 {
- return Err($crate::parser::DecodeError::InvalidDataSize);
+ return Err($crate::error::DecodeError::InvalidDataSize);
}
self.data = Some(ExpressionVariant::from(res));
Ok(())
@@ -207,126 +198,22 @@ macro_rules! create_expr_variant {
create_expr_variant!(
ExpressionVariant,
- [Log, Log],
- [Immediate, Immediate],
[Bitwise, Bitwise],
+ [Cmp, Cmp],
+ [Conntrack, Conntrack],
+ [Counter, Counter],
[ExpressionRaw, ExpressionRaw],
+ [Immediate, Immediate],
+ [Log, Log],
+ [Lookup, Lookup],
+ [Masquerade, Masquerade],
[Meta, Meta],
- [Reject, Reject],
- [Counter, Counter],
[Nat, Nat],
[Payload, Payload],
- [Cmp, Cmp],
- [Conntrack, Conntrack],
- [Masquerade, Masquerade]
+ [Reject, Reject]
);
-#[derive(Debug, Clone, PartialEq, Eq, Default)]
-pub struct ExpressionList {
- exprs: Vec<RawExpression>,
-}
-
-impl ExpressionList {
- /// Useful to add raw expressions because RawExpression cannot infer alone its type
- pub fn add_raw_expression(&mut self, e: RawExpression) {
- self.exprs.push(e);
- }
-
- pub fn add_expression<T>(&mut self, e: T)
- where
- T: Expression,
- ExpressionVariant: From<T>,
- {
- self.exprs.push(RawExpression::new(e));
- }
-
- pub fn with_expression<T>(mut self, e: T) -> Self
- where
- T: Expression,
- ExpressionVariant: From<T>,
- {
- self.add_expression(e);
- self
- }
-
- pub fn iter<'a>(&'a self) -> impl Iterator<Item = &'a ExpressionVariant> {
- self.exprs.iter().map(|e| e.get_data().unwrap())
- }
-}
-
-impl NfNetlinkAttribute for ExpressionList {
- fn is_nested(&self) -> bool {
- true
- }
-
- fn get_size(&self) -> usize {
- // one nlattr LIST_ELEM per object
- self.exprs.iter().fold(0, |acc, item| {
- acc + item.get_size() + pad_netlink_object::<nlattr>()
- })
- }
-
- unsafe fn write_payload(&self, mut addr: *mut u8) {
- for item in &self.exprs {
- write_attribute(sys::NFTA_LIST_ELEM, item, addr);
- addr = addr.offset((pad_netlink_object::<nlattr>() + item.get_size()) as isize);
- }
- }
-}
-
-impl NfNetlinkDeserializable for ExpressionList {
- fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
- let mut exprs = Vec::new();
-
- let mut pos = 0;
- while buf.len() - pos > pad_netlink_object::<nlattr>() {
- let nlattr = unsafe { *transmute::<*const u8, *const nlattr>(buf[pos..].as_ptr()) };
- // ignore the byteorder and nested attributes
- let nla_type = nlattr.nla_type & NLA_TYPE_MASK as u16;
-
- if nla_type != sys::NFTA_LIST_ELEM {
- return Err(DecodeError::UnsupportedAttributeType(nla_type));
- }
-
- let (expr, remaining) = RawExpression::deserialize(
- &buf[pos + pad_netlink_object::<nlattr>()..pos + nlattr.nla_len as usize],
- )?;
- if remaining.len() != 0 {
- return Err(DecodeError::InvalidDataSize);
- }
- exprs.push(expr);
-
- pos += pad_netlink_object_with_variable_size(nlattr.nla_len as usize);
- }
-
- if pos != buf.len() {
- Err(DecodeError::InvalidDataSize)
- } else {
- Ok((Self { exprs }, &[]))
- }
- }
-}
-
-impl<T> From<Vec<T>> for ExpressionList
-where
- ExpressionVariant: From<T>,
- T: Expression,
-{
- fn from(v: Vec<T>) -> Self {
- ExpressionList {
- exprs: v.into_iter().map(RawExpression::new).collect(),
- }
- }
-}
-
-#[derive(Clone, PartialEq, Eq, Default, Debug)]
-#[nfnetlink_struct(nested = true)]
-pub struct ExpressionData {
- #[field(NFTA_DATA_VALUE)]
- value: Vec<u8>,
- #[field(NFTA_DATA_VERDICT)]
- verdict: VerdictAttribute,
-}
+pub type ExpressionList = NfNetlinkList<RawExpression>;
// default type for expressions that we do not handle yet
#[derive(Debug, Clone, PartialEq, Eq)]
diff --git a/src/expr/payload.rs b/src/expr/payload.rs
index 490a4ec..d0b2cea 100644
--- a/src/expr/payload.rs
+++ b/src/expr/payload.rs
@@ -2,7 +2,7 @@ use rustables_macros::nfnetlink_struct;
use super::{Expression, Register};
use crate::{
- parser::DecodeError,
+ error::DecodeError,
sys::{self, NFT_PAYLOAD_LL_HEADER, NFT_PAYLOAD_NETWORK_HEADER, NFT_PAYLOAD_TRANSPORT_HEADER},
};
diff --git a/src/expr/verdict.rs b/src/expr/verdict.rs
index c4facfb..7edf7cd 100644
--- a/src/expr/verdict.rs
+++ b/src/expr/verdict.rs
@@ -3,7 +3,6 @@ use std::fmt::Debug;
use libc::{NF_ACCEPT, NF_DROP, NF_QUEUE};
use rustables_macros::{nfnetlink_enum, nfnetlink_struct};
-use super::{ExpressionData, Immediate, Register};
use crate::sys::{
NFTA_VERDICT_CHAIN, NFTA_VERDICT_CHAIN_ID, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE,
NFT_GOTO, NFT_JUMP, NFT_RETURN,
@@ -24,7 +23,7 @@ pub enum VerdictType {
#[derive(Clone, PartialEq, Eq, Default, Debug)]
#[nfnetlink_struct(nested = true)]
-pub struct VerdictAttribute {
+pub struct Verdict {
#[field(NFTA_VERDICT_CODE)]
code: VerdictType,
#[field(NFTA_VERDICT_CHAIN)]
@@ -50,25 +49,3 @@ pub enum VerdictKind {
},
Return,
}
-
-impl Immediate {
- pub fn new_verdict(kind: VerdictKind) -> Self {
- let code = match kind {
- VerdictKind::Drop => VerdictType::Drop,
- VerdictKind::Accept => VerdictType::Accept,
- VerdictKind::Queue => VerdictType::Queue,
- VerdictKind::Continue => VerdictType::Continue,
- VerdictKind::Break => VerdictType::Break,
- VerdictKind::Jump { .. } => VerdictType::Jump,
- VerdictKind::Goto { .. } => VerdictType::Goto,
- VerdictKind::Return => VerdictType::Return,
- };
- let mut data = VerdictAttribute::default().with_code(code);
- if let VerdictKind::Jump { chain } | VerdictKind::Goto { chain } = kind {
- data.set_chain(chain);
- }
- Immediate::default()
- .with_dreg(Register::Verdict)
- .with_data(ExpressionData::default().with_verdict(data))
- }
-}
diff --git a/src/lib.rs b/src/lib.rs
index fecbc83..1ad1eed 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,4 +1,4 @@
-// Copyryght (c) 2021 GPL lafleur@boum.org and Simon Thoby
+// Copyryght (c) 2021-2022 GPL lafleur@boum.org and Simon Thoby
//
// This file is free software: you may copy, redistribute and/or modify it
// under the terms of the GNU General Public License as published by the
@@ -24,64 +24,37 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.
-//! Safe abstraction for [`libnftnl`]. Provides userspace access to the in-kernel nf_tables
-//! subsystem. Can be used to create and remove tables, chains, sets and rules from the nftables
+//! Safe abstraction for userspace access to the in-kernel nf_tables subsystem.
+//! Can be used to create and remove tables, chains, sets and rules from the nftables
//! firewall, the successor to iptables.
//!
//! This library currently has quite rough edges and does not make adding and removing netfilter
//! entries super easy and elegant. That is partly because the library needs more work, but also
//! partly because nftables is super low level and extremely customizable, making it hard, and
//! probably wrong, to try and create a too simple/limited wrapper. See examples for inspiration.
-//! One can also look at how the original project this crate was developed to support uses it:
-//! [Mullvad VPN app](https://github.com/mullvad/mullvadvpn-app)
//!
-//! Understanding how to use [`libnftnl`] and implementing this crate has mostly been done by
-//! reading the source code for the [`nftables`] program and attaching debuggers to the `nft`
-//! binary. Since the implementation is mostly based on trial and error, there might of course be
-//! a number of places where the underlying library is used in an invalid or not intended way.
-//! Large portions of [`libnftnl`] are also not covered yet. Contributions are welcome!
+//! Understanding how to use the netlink subsystem and implementing this crate has mostly been done by
+//! reading the source code for the [`nftables`] userspace program and its corresponding kernel code,
+//! as well as attaching debuggers to the `nft` binary.
+//! Since the implementation is mostly based on trial and error, there might of course be
+//! a number of places where the forged netlink messages are used in an invalid or not intended way.
+//! Contributions are welcome!
//!
-//! # Supported versions of `libnftnl`
-//!
-//! This crate will automatically link to the currently installed version of libnftnl upon build.
-//! It requires libnftnl version 1.0.6 or higher. See how the low level FFI bindings to the C
-//! library are generated in [`build.rs`].
-//!
-//! # Access to raw handles
-//!
-//! Retrieving raw handles is considered unsafe and should only ever be enabled if you absolutely
-//! need it. It is disabled by default and hidden behind the feature gate `unsafe-raw-handles`.
-//! The reason for that special treatment is we cannot guarantee the lack of aliasing. For
-//! example, a program using a const handle to a object in a thread and writing through a mutable
-//! handle in another could reach all kind of undefined (and dangerous!) behaviors. By enabling
-//! that feature flag, you acknowledge that guaranteeing the respect of safety invariants is now
-//! your responsibility! Despite these shortcomings, that feature is still available because it
-//! may allow you to perform manipulations that this library doesn't currently expose. If that is
-//! your case, we would be very happy to hear from you and maybe help you get the necessary
-//! functionality upstream.
-//!
-//! Our current lack of confidence in our availability to provide a safe abstraction over the use
-//! of raw handles in the face of concurrency is the reason we decided to settly on `Rc` pointers
-//! instead of `Arc` (besides, this should gives us some nice performance boost, not that it
-//! matters much of course) and why we do not declare the types exposed by the library as `Send`
-//! nor `Sync`.
-//!
-//! [`libnftnl`]: https://netfilter.org/projects/libnftnl/
//! [`nftables`]: https://netfilter.org/projects/nftables/
-//! [`build.rs`]: https://gitlab.com/rustwall/rustables/-/blob/master/build.rs
-
-use parser::DecodeError;
#[macro_use]
extern crate log;
-pub mod sys;
use libc;
+
+use rustables_macros::nfnetlink_enum;
use std::convert::TryFrom;
mod batch;
pub use batch::{default_batch_page_size, Batch};
+mod data_type;
+
mod table;
pub use table::list_tables;
pub use table::Table;
@@ -90,13 +63,16 @@ mod chain;
pub use chain::list_chains_for_table;
pub use chain::{Chain, ChainPolicy, ChainPriority, ChainType, Hook, HookClass};
+pub mod error;
+
//mod chain_methods;
//pub use chain_methods::ChainMethods;
pub mod query;
-pub mod nlmsg;
-pub mod parser;
+pub(crate) mod nlmsg;
+pub(crate) mod parser;
+pub(crate) mod parser_impls;
mod rule;
pub use rule::list_rules_for_chain;
@@ -107,8 +83,13 @@ pub mod expr;
//mod rule_methods;
//pub use rule_methods::{iface_index, Error as MatchError, Protocol, RuleMethods};
-//pub mod set;
-//pub use set::Set;
+pub mod set;
+pub use set::Set;
+
+pub mod sys;
+
+#[cfg(test)]
+mod tests;
/// The type of the message as it's sent to netfilter. A message consists of an object, such as a
/// [`Table`], [`Chain`] or [`Rule`] for example, and a [`MsgType`] to describe what to do with
@@ -119,7 +100,7 @@ pub mod expr;
/// [`Chain`]: struct.Chain.html
/// [`Rule`]: struct.Rule.html
/// [`MsgType`]: enum.MsgType.html
-#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum MsgType {
/// Add the object to netfilter.
Add,
@@ -128,8 +109,8 @@ pub enum MsgType {
}
/// Denotes a protocol. Used to specify which protocol a table or set belongs to.
-#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
-#[repr(i32)]
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+#[nfnetlink_enum(i32)]
pub enum ProtocolFamily {
Unspec = libc::NFPROTO_UNSPEC,
/// Inet - Means both IPv4 and IPv6
@@ -144,23 +125,6 @@ pub enum ProtocolFamily {
impl Default for ProtocolFamily {
fn default() -> Self {
- Self::Unspec
- }
-}
-
-impl TryFrom<i32> for ProtocolFamily {
- type Error = DecodeError;
- fn try_from(value: i32) -> Result<Self, Self::Error> {
- match value {
- libc::NFPROTO_UNSPEC => Ok(ProtocolFamily::Unspec),
- libc::NFPROTO_INET => Ok(ProtocolFamily::Inet),
- libc::NFPROTO_IPV4 => Ok(ProtocolFamily::Ipv4),
- libc::NFPROTO_ARP => Ok(ProtocolFamily::Arp),
- libc::NFPROTO_NETDEV => Ok(ProtocolFamily::NetDev),
- libc::NFPROTO_BRIDGE => Ok(ProtocolFamily::Bridge),
- libc::NFPROTO_IPV6 => Ok(ProtocolFamily::Ipv6),
- libc::NFPROTO_DECNET => Ok(ProtocolFamily::DecNet),
- _ => Err(DecodeError::InvalidProtocolFamily(value)),
- }
+ ProtocolFamily::Unspec
}
}
diff --git a/src/nlmsg.rs b/src/nlmsg.rs
index 8563a37..b3710bf 100644
--- a/src/nlmsg.rs
+++ b/src/nlmsg.rs
@@ -1,13 +1,41 @@
use std::{fmt::Debug, mem::size_of};
use crate::{
- parser::{pad_netlink_object, pad_netlink_object_with_variable_size, DecodeError},
+ error::DecodeError,
sys::{
nfgenmsg, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END,
- NFNL_SUBSYS_NFTABLES,
+ NFNL_SUBSYS_NFTABLES, NLMSG_ALIGNTO, NLM_F_ACK, NLM_F_CREATE,
},
MsgType, ProtocolFamily,
};
+///
+/// The largest nf_tables netlink message is the set element message, which contains the
+/// NFTA_SET_ELEM_LIST_ELEMENTS attribute. This attribute is a nest that describes the set
+/// elements. Given that the netlink attribute length (nla_len) is 16 bits, the largest message is
+/// a bit larger than 64 KBytes.
+pub fn nft_nlmsg_maxsize() -> u32 {
+ u32::from(::std::u16::MAX) + unsafe { libc::sysconf(libc::_SC_PAGESIZE) } as u32
+}
+
+#[inline]
+pub const fn pad_netlink_object_with_variable_size(size: usize) -> usize {
+ // align on a 4 bytes boundary
+ (size + (NLMSG_ALIGNTO as usize - 1)) & !(NLMSG_ALIGNTO as usize - 1)
+}
+
+#[inline]
+pub const fn pad_netlink_object<T>() -> usize {
+ let size = size_of::<T>();
+ pad_netlink_object_with_variable_size(size)
+}
+
+pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 {
+ ((x & 0xff00) >> 8) as u8
+}
+
+pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 {
+ (x & 0x00ff) as u8
+}
pub struct NfNetlinkWriter<'a> {
buf: &'a mut Vec<u8>,
@@ -92,76 +120,67 @@ pub trait NfNetlinkDeserializable: Sized {
fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError>;
}
-pub trait NfNetlinkObject: Sized + AttributeDecoder + NfNetlinkDeserializable {
- fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32);
-}
-
-pub type NetlinkType = u16;
-
-pub trait NfNetlinkAttribute: Debug + Sized {
- // is it a nested argument that must be marked with a NLA_F_NESTED flag?
- fn is_nested(&self) -> bool {
- false
+pub trait NfNetlinkObject:
+ Sized + AttributeDecoder + NfNetlinkDeserializable + NfNetlinkAttribute
+{
+ const MSG_TYPE_ADD: u32;
+ const MSG_TYPE_DEL: u32;
+
+ fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32) {
+ let raw_msg_type = match msg_type {
+ MsgType::Add => Self::MSG_TYPE_ADD,
+ MsgType::Del => Self::MSG_TYPE_DEL,
+ } as u16;
+ writer.write_header(
+ raw_msg_type,
+ self.get_family(),
+ (if let MsgType::Add = msg_type {
+ self.get_add_flags()
+ } else {
+ self.get_del_flags()
+ } | NLM_F_ACK) as u16,
+ seq,
+ None,
+ );
+ let buf = writer.add_data_zeroed(self.get_size());
+ unsafe {
+ self.write_payload(buf.as_mut_ptr());
+ }
+ writer.finalize_writing_object();
}
- fn get_size(&self) -> usize {
- size_of::<Self>()
- }
+ fn get_family(&self) -> ProtocolFamily;
- // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr, self.get_size());
- unsafe fn write_payload(&self, addr: *mut u8);
-}
-
-/*
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct NfNetlinkAttributes {
- pub attributes: BTreeMap<NetlinkType, AttributeType>,
-}
-
-impl NfNetlinkAttributes {
- pub fn new() -> Self {
- NfNetlinkAttributes {
- attributes: BTreeMap::new(),
- }
+ fn set_family(&mut self, _family: ProtocolFamily) {
+ // the default impl do nothing, because some types are family-agnostic
}
- pub fn set_attr(&mut self, ty: NetlinkType, obj: AttributeType) {
- self.attributes.insert(ty, obj);
+ fn with_family(mut self, family: ProtocolFamily) -> Self {
+ self.set_family(family);
+ self
}
- pub fn get_attr(&self, ty: NetlinkType) -> Option<&AttributeType> {
- self.attributes.get(&ty)
+ fn get_add_flags(&self) -> u32 {
+ NLM_F_CREATE
}
- pub fn serialize<'a>(&self, writer: &mut NfNetlinkWriter<'a>) {
- let buf = writer.add_data_zeroed(self.get_size());
- unsafe {
- self.write_payload(buf.as_mut_ptr());
- }
+ fn get_del_flags(&self) -> u32 {
+ 0
}
}
-impl NfNetlinkAttribute for NfNetlinkAttributes {
- fn get_size(&self) -> usize {
- let mut size = 0;
-
- for (_type, attr) in self.attributes.iter() {
- // Attribute header + attribute value
- size += pad_netlink_object::<nlattr>()
- + pad_netlink_object_with_variable_size(attr.get_size());
- }
+pub type NetlinkType = u16;
- size
+pub trait NfNetlinkAttribute: Debug + Sized {
+ // is it a nested argument that must be marked with a NLA_F_NESTED flag?
+ fn is_nested(&self) -> bool {
+ false
}
- unsafe fn write_payload(&self, mut addr: *mut u8) {
- for (ty, attr) in self.attributes.iter() {
- debug!("writing attribute {} - {:?}", ty, attr);
- write_attribute(*ty, attr, addr);
- let size = pad_netlink_object::<nlattr>()
- + pad_netlink_object_with_variable_size(attr.get_size());
- addr = addr.offset(size as isize);
- }
+ fn get_size(&self) -> usize {
+ size_of::<Self>()
}
+
+ // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr, self.get_size());
+ unsafe fn write_payload(&self, addr: *mut u8);
}
-*/
diff --git a/src/parser.rs b/src/parser.rs
index c402dae..6ea34c1 100644
--- a/src/parser.rs
+++ b/src/parser.rs
@@ -1,167 +1,21 @@
use std::{
- convert::TryFrom,
fmt::{Debug, DebugStruct},
mem::{size_of, transmute},
- string::FromUtf8Error,
};
-use thiserror::Error;
-
use crate::{
- nlmsg::{AttributeDecoder, NetlinkType, NfNetlinkAttribute, NfNetlinkDeserializable},
+ error::DecodeError,
+ nlmsg::{
+ get_operation_from_nlmsghdr_type, get_subsystem_from_nlmsghdr_type, pad_netlink_object,
+ pad_netlink_object_with_variable_size, AttributeDecoder, NetlinkType, NfNetlinkAttribute,
+ },
sys::{
nfgenmsg, nlattr, nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN,
- NFNL_MSG_BATCH_END, NFNL_SUBSYS_NFTABLES, NLA_F_NESTED, NLA_TYPE_MASK, NLMSG_ALIGNTO,
- NLMSG_DONE, NLMSG_ERROR, NLMSG_MIN_TYPE, NLMSG_NOOP, NLM_F_DUMP_INTR,
+ NFNL_MSG_BATCH_END, NFNL_SUBSYS_NFTABLES, NLA_F_NESTED, NLA_TYPE_MASK, NLMSG_DONE,
+ NLMSG_ERROR, NLMSG_MIN_TYPE, NLMSG_NOOP, NLM_F_DUMP_INTR,
},
- ProtocolFamily,
};
-#[derive(Error, Debug)]
-pub enum DecodeError {
- #[error("The buffer is too small to hold a valid message")]
- BufTooSmall,
-
- #[error("The message is too small")]
- NlMsgTooSmall,
-
- #[error("The message holds unexpected data")]
- InvalidDataSize,
-
- #[error("Missing information in the chain to create a rule")]
- MissingChainInformationError,
-
- #[error("The length of the arguments are not compatible with each other")]
- IncompatibleLength,
-
- #[error("Invalid subsystem, expected NFTABLES")]
- InvalidSubsystem(u8),
-
- #[error("Invalid version, expected NFNETLINK_V0")]
- InvalidVersion(u8),
-
- #[error("Invalid port ID")]
- InvalidPortId(u32),
-
- #[error("Invalid sequence number")]
- InvalidSeq(u32),
-
- #[error("The generation number was bumped in the kernel while the operation was running, interrupting it")]
- ConcurrentGenerationUpdate,
-
- #[error("Unsupported message type")]
- UnsupportedType(u16),
-
- #[error("Invalid attribute type")]
- InvalidAttributeType,
-
- #[error("Invalid type for a chain")]
- UnknownChainType,
-
- #[error("Invalid policy for a chain")]
- UnknownChainPolicy,
-
- #[error("Unknown type for a Meta expression")]
- UnknownMetaType(u32),
-
- #[error("Unsupported value for an icmp reject type")]
- UnknownRejectType(u32),
-
- #[error("Unsupported value for an icmp code in a reject expression")]
- UnknownIcmpCode(u8),
-
- #[error("Invalid value for a register")]
- UnknownRegister(u32),
-
- #[error("Invalid type for a verdict expression")]
- UnknownVerdictType(i32),
-
- #[error("Invalid type for a nat expression")]
- UnknownNatType(i32),
-
- #[error("Invalid type for a payload expression")]
- UnknownPayloadType(u32),
-
- #[error("Invalid type for a compare expression")]
- UnknownCmpOp(u32),
-
- #[error("Invalid type for a conntrack key")]
- UnknownConntrackKey(u32),
-
- #[error("Unsupported value for a link layer header field")]
- UnknownLinkLayerHeaderField(u32, u32),
-
- #[error("Unsupported value for an IPv4 header field")]
- UnknownIPv4HeaderField(u32, u32),
-
- #[error("Unsupported value for an IPv6 header field")]
- UnknownIPv6HeaderField(u32, u32),
-
- #[error("Unsupported value for a TCP header field")]
- UnknownTCPHeaderField(u32, u32),
-
- #[error("Unsupported value for an UDP header field")]
- UnknownUDPHeaderField(u32, u32),
-
- #[error("Unsupported value for an ICMPv6 header field")]
- UnknownICMPv6HeaderField(u32, u32),
-
- #[error("Missing the 'base' attribute to deserialize the payload object")]
- PayloadMissingBase,
-
- #[error("Missing the 'offset' attribute to deserialize the payload object")]
- PayloadMissingOffset,
-
- #[error("Missing the 'len' attribute to deserialize the payload object")]
- PayloadMissingLen,
-
- #[error("The object does not contain a name for the expression being parsed")]
- MissingExpressionName,
-
- #[error("Unsupported attribute type")]
- UnsupportedAttributeType(u16),
-
- #[error("Unexpected message type")]
- UnexpectedType(u16),
-
- #[error("The decoded String is not UTF8 compliant")]
- StringDecodeFailure(#[from] FromUtf8Error),
-
- #[error("Invalid value for a protocol family")]
- InvalidProtocolFamily(i32),
-
- #[error("A custom error occured")]
- Custom(Box<dyn std::error::Error + 'static>),
-}
-
-/// The largest nf_tables netlink message is the set element message, which contains the
-/// NFTA_SET_ELEM_LIST_ELEMENTS attribute. This attribute is a nest that describes the set
-/// elements. Given that the netlink attribute length (nla_len) is 16 bits, the largest message is
-/// a bit larger than 64 KBytes.
-pub fn nft_nlmsg_maxsize() -> u32 {
- u32::from(::std::u16::MAX) + unsafe { libc::sysconf(libc::_SC_PAGESIZE) } as u32
-}
-
-#[inline]
-pub const fn pad_netlink_object_with_variable_size(size: usize) -> usize {
- // align on a 4 bytes boundary
- (size + (NLMSG_ALIGNTO as usize - 1)) & !(NLMSG_ALIGNTO as usize - 1)
-}
-
-#[inline]
-pub const fn pad_netlink_object<T>() -> usize {
- let size = size_of::<T>();
- pad_netlink_object_with_variable_size(size)
-}
-
-pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 {
- ((x & 0xff00) >> 8) as u8
-}
-
-pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 {
- (x & 0x00ff) as u8
-}
-
pub fn get_nlmsghdr(buf: &[u8]) -> Result<nlmsghdr, DecodeError> {
let size_of_hdr = size_of::<nlmsghdr>();
@@ -272,126 +126,6 @@ pub unsafe fn write_attribute<'a>(
obj.write_payload(buf);
}
-impl NfNetlinkAttribute for u8 {
- unsafe fn write_payload(&self, addr: *mut u8) {
- *addr = *self;
- }
-}
-
-impl NfNetlinkDeserializable for u8 {
- fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
- Ok((buf[0], &buf[1..]))
- }
-}
-
-impl NfNetlinkAttribute for u16 {
- unsafe fn write_payload(&self, addr: *mut u8) {
- *(addr as *mut Self) = self.to_be();
- }
-}
-
-impl NfNetlinkDeserializable for u16 {
- fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
- Ok((u16::from_be_bytes([buf[0], buf[1]]), &buf[2..]))
- }
-}
-
-impl NfNetlinkAttribute for i32 {
- unsafe fn write_payload(&self, addr: *mut u8) {
- *(addr as *mut Self) = self.to_be();
- }
-}
-
-impl NfNetlinkDeserializable for i32 {
- fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
- Ok((
- i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]),
- &buf[4..],
- ))
- }
-}
-
-impl NfNetlinkAttribute for u32 {
- unsafe fn write_payload(&self, addr: *mut u8) {
- *(addr as *mut Self) = self.to_be();
- }
-}
-
-impl NfNetlinkDeserializable for u32 {
- fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
- Ok((
- u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]),
- &buf[4..],
- ))
- }
-}
-
-impl NfNetlinkAttribute for u64 {
- unsafe fn write_payload(&self, addr: *mut u8) {
- *(addr as *mut Self) = self.to_be();
- }
-}
-
-impl NfNetlinkDeserializable for u64 {
- fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
- Ok((
- u64::from_be_bytes([
- buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7],
- ]),
- &buf[8..],
- ))
- }
-}
-
-impl NfNetlinkAttribute for String {
- fn get_size(&self) -> usize {
- self.len()
- }
-
- unsafe fn write_payload(&self, addr: *mut u8) {
- std::ptr::copy_nonoverlapping(self.as_bytes().as_ptr(), addr, self.len());
- }
-}
-
-impl NfNetlinkDeserializable for String {
- fn deserialize(mut buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
- // ignore the NULL byte terminator, if any
- if buf.len() > 0 && buf[buf.len() - 1] == 0 {
- buf = &buf[..buf.len() - 1];
- }
- Ok((String::from_utf8(buf.to_vec())?, &[]))
- }
-}
-
-impl NfNetlinkAttribute for Vec<u8> {
- fn get_size(&self) -> usize {
- self.len()
- }
-
- unsafe fn write_payload(&self, addr: *mut u8) {
- std::ptr::copy_nonoverlapping(self.as_ptr(), addr, self.len());
- }
-}
-
-impl NfNetlinkDeserializable for Vec<u8> {
- fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
- Ok((buf.to_vec(), &[]))
- }
-}
-
-impl NfNetlinkAttribute for ProtocolFamily {
- unsafe fn write_payload(&self, addr: *mut u8) {
- (*self as i32).write_payload(addr);
- }
-}
-
-impl NfNetlinkDeserializable for ProtocolFamily {
- fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
- let (v, remaining_data) = i32::deserialize(buf)?;
- Ok((Self::try_from(v)?, remaining_data))
- }
-}
-
pub(crate) fn read_attributes<T: AttributeDecoder + Default>(buf: &[u8]) -> Result<T, DecodeError> {
debug!(
"Calling <{} as NfNetlinkDeserialize>::deserialize()",
diff --git a/src/parser_impls.rs b/src/parser_impls.rs
new file mode 100644
index 0000000..b2681bb
--- /dev/null
+++ b/src/parser_impls.rs
@@ -0,0 +1,243 @@
+use std::{fmt::Debug, mem::transmute};
+
+use rustables_macros::nfnetlink_struct;
+
+use crate::{
+ error::DecodeError,
+ expr::Verdict,
+ nlmsg::{
+ pad_netlink_object, pad_netlink_object_with_variable_size, NfNetlinkAttribute,
+ NfNetlinkDeserializable, NfNetlinkObject,
+ },
+ parser::{write_attribute, Parsable},
+ sys::{nlattr, NFTA_DATA_VALUE, NFTA_DATA_VERDICT, NFTA_LIST_ELEM, NLA_TYPE_MASK},
+ ProtocolFamily,
+};
+
+impl NfNetlinkAttribute for u8 {
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ *addr = *self;
+ }
+}
+
+impl NfNetlinkDeserializable for u8 {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((buf[0], &buf[1..]))
+ }
+}
+
+impl NfNetlinkAttribute for u16 {
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ *(addr as *mut Self) = self.to_be();
+ }
+}
+
+impl NfNetlinkDeserializable for u16 {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((u16::from_be_bytes([buf[0], buf[1]]), &buf[2..]))
+ }
+}
+
+impl NfNetlinkAttribute for i32 {
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ *(addr as *mut Self) = self.to_be();
+ }
+}
+
+impl NfNetlinkDeserializable for i32 {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((
+ i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]),
+ &buf[4..],
+ ))
+ }
+}
+
+impl NfNetlinkAttribute for u32 {
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ *(addr as *mut Self) = self.to_be();
+ }
+}
+
+impl NfNetlinkDeserializable for u32 {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((
+ u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]),
+ &buf[4..],
+ ))
+ }
+}
+
+impl NfNetlinkAttribute for u64 {
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ *(addr as *mut Self) = self.to_be();
+ }
+}
+
+impl NfNetlinkDeserializable for u64 {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((
+ u64::from_be_bytes([
+ buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7],
+ ]),
+ &buf[8..],
+ ))
+ }
+}
+
+impl NfNetlinkAttribute for String {
+ fn get_size(&self) -> usize {
+ self.len()
+ }
+
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ std::ptr::copy_nonoverlapping(self.as_bytes().as_ptr(), addr, self.len());
+ }
+}
+
+impl NfNetlinkDeserializable for String {
+ fn deserialize(mut buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ // ignore the NULL byte terminator, if any
+ if buf.len() > 0 && buf[buf.len() - 1] == 0 {
+ buf = &buf[..buf.len() - 1];
+ }
+ Ok((String::from_utf8(buf.to_vec())?, &[]))
+ }
+}
+
+impl NfNetlinkAttribute for Vec<u8> {
+ fn get_size(&self) -> usize {
+ self.len()
+ }
+
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ std::ptr::copy_nonoverlapping(self.as_ptr(), addr, self.len());
+ }
+}
+
+impl NfNetlinkDeserializable for Vec<u8> {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((buf.to_vec(), &[]))
+ }
+}
+#[derive(Clone, PartialEq, Eq, Default, Debug)]
+#[nfnetlink_struct(nested = true)]
+pub struct NfNetlinkData {
+ #[field(NFTA_DATA_VALUE)]
+ value: Vec<u8>,
+ #[field(NFTA_DATA_VERDICT)]
+ verdict: Verdict,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Default)]
+pub struct NfNetlinkList<T>
+where
+ T: NfNetlinkDeserializable + NfNetlinkAttribute + Debug + Clone + Eq + Default,
+{
+ objs: Vec<T>,
+}
+
+impl<T> NfNetlinkList<T>
+where
+ T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default,
+{
+ pub fn add_value(&mut self, e: impl Into<T>) {
+ self.objs.push(e.into());
+ }
+
+ pub fn with_value(mut self, e: impl Into<T>) -> Self {
+ self.add_value(e);
+ self
+ }
+
+ pub fn iter<'a>(&'a self) -> impl Iterator<Item = &'a T> {
+ self.objs.iter()
+ }
+}
+
+impl<T> NfNetlinkAttribute for NfNetlinkList<T>
+where
+ T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default,
+{
+ fn is_nested(&self) -> bool {
+ true
+ }
+
+ fn get_size(&self) -> usize {
+ // one nlattr LIST_ELEM per object
+ self.objs.iter().fold(0, |acc, item| {
+ acc + item.get_size() + pad_netlink_object::<nlattr>()
+ })
+ }
+
+ unsafe fn write_payload(&self, mut addr: *mut u8) {
+ for item in &self.objs {
+ write_attribute(NFTA_LIST_ELEM, item, addr);
+ addr = addr.offset((pad_netlink_object::<nlattr>() + item.get_size()) as isize);
+ }
+ }
+}
+
+impl<T> NfNetlinkDeserializable for NfNetlinkList<T>
+where
+ T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default,
+{
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ let mut objs = Vec::new();
+
+ let mut pos = 0;
+ while buf.len() - pos > pad_netlink_object::<nlattr>() {
+ let nlattr = unsafe { *transmute::<*const u8, *const nlattr>(buf[pos..].as_ptr()) };
+ // ignore the byteorder and nested attributes
+ let nla_type = nlattr.nla_type & NLA_TYPE_MASK as u16;
+
+ if nla_type != NFTA_LIST_ELEM {
+ return Err(DecodeError::UnsupportedAttributeType(nla_type));
+ }
+
+ let (obj, remaining) = T::deserialize(
+ &buf[pos + pad_netlink_object::<nlattr>()..pos + nlattr.nla_len as usize],
+ )?;
+ if remaining.len() != 0 {
+ return Err(DecodeError::InvalidDataSize);
+ }
+ objs.push(obj);
+
+ pos += pad_netlink_object_with_variable_size(nlattr.nla_len as usize);
+ }
+
+ if pos != buf.len() {
+ Err(DecodeError::InvalidDataSize)
+ } else {
+ Ok((Self { objs }, &[]))
+ }
+ }
+}
+
+impl<O, T> From<Vec<O>> for NfNetlinkList<T>
+where
+ T: From<O>,
+ T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default,
+{
+ fn from(v: Vec<O>) -> Self {
+ NfNetlinkList {
+ objs: v.into_iter().map(T::from).collect(),
+ }
+ }
+}
+
+impl<T> NfNetlinkDeserializable for T
+where
+ T: NfNetlinkObject + Parsable,
+{
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ let (mut obj, nfgenmsg, remaining_data) = Self::parse_object(
+ buf,
+ <T as NfNetlinkObject>::MSG_TYPE_ADD,
+ <T as NfNetlinkObject>::MSG_TYPE_DEL,
+ )?;
+ obj.set_family(ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?);
+
+ Ok((obj, remaining_data))
+ }
+}
diff --git a/src/query.rs b/src/query.rs
index 294cbfe..7cf5050 100644
--- a/src/query.rs
+++ b/src/query.rs
@@ -1,71 +1,31 @@
use std::os::unix::prelude::RawFd;
+use nix::sys::socket::{self, AddressFamily, MsgFlags, SockFlag, SockProtocol, SockType};
+
use crate::{
- nlmsg::{NfNetlinkAttribute, NfNetlinkObject, NfNetlinkWriter},
- parser::{nft_nlmsg_maxsize, pad_netlink_object_with_variable_size},
- sys::{nlmsgerr, NLM_F_DUMP, NLM_F_MULTI},
+ error::QueryError,
+ nlmsg::{
+ nft_nlmsg_maxsize, pad_netlink_object_with_variable_size, NfNetlinkAttribute,
+ NfNetlinkObject, NfNetlinkWriter,
+ },
+ parser::{parse_nlmsg, NlMsg},
+ sys::{NLM_F_DUMP, NLM_F_MULTI},
ProtocolFamily,
};
-use nix::{
- errno::Errno,
- sys::socket::{self, AddressFamily, MsgFlags, SockFlag, SockProtocol, SockType},
-};
-
-use crate::parser::{parse_nlmsg, DecodeError, NlMsg};
-
-#[derive(thiserror::Error, Debug)]
-pub enum Error {
- #[error("Unable to open netlink socket to netfilter")]
- NetlinkOpenError(#[source] nix::Error),
-
- #[error("Unable to send netlink command to netfilter")]
- NetlinkSendError(#[source] nix::Error),
-
- #[error("Error while reading from netlink socket")]
- NetlinkRecvError(#[source] nix::Error),
-
- #[error("Error while processing an incoming netlink message")]
- ProcessNetlinkError(#[from] DecodeError),
-
- #[error("Error received from the kernel")]
- NetlinkError(nlmsgerr),
-
- #[error("Custom error when customizing the query")]
- InitError(#[from] Box<dyn std::error::Error + Send + 'static>),
-
- #[error("Couldn't allocate a netlink object, out of memory ?")]
- NetlinkAllocationFailed,
-
- #[error("This socket is not a netlink socket")]
- NotNetlinkSocket,
-
- #[error("Couldn't retrieve information on a socket")]
- RetrievingSocketInfoFailed,
-
- #[error("Only a part of the message was sent")]
- TruncatedSend,
-
- #[error("Got a message without the NLM_F_MULTI flag, but a maximum sequence number was not specified")]
- UndecidableMessageTermination,
-
- #[error("Couldn't close the socket")]
- CloseFailed(#[source] Errno),
-}
-
pub(crate) fn recv_and_process<'a, T>(
sock: RawFd,
max_seq: Option<u32>,
- cb: Option<&dyn Fn(&[u8], &mut T) -> Result<(), Error>>,
+ cb: Option<&dyn Fn(&[u8], &mut T) -> Result<(), QueryError>>,
working_data: &'a mut T,
-) -> Result<(), Error> {
+) -> Result<(), QueryError> {
let mut msg_buffer = vec![0; 2 * nft_nlmsg_maxsize() as usize];
let mut buf_start = 0;
let mut end_pos = 0;
loop {
let nb_recv = socket::recv(sock, &mut msg_buffer[end_pos..], MsgFlags::empty())
- .map_err(Error::NetlinkRecvError)?;
+ .map_err(QueryError::NetlinkRecvError)?;
if nb_recv <= 0 {
return Ok(());
}
@@ -87,7 +47,7 @@ pub(crate) fn recv_and_process<'a, T>(
}
NlMsg::Error(e) => {
if e.error != 0 {
- return Err(Error::NetlinkError(e));
+ return Err(QueryError::NetlinkError(e));
}
}
NlMsg::Noop => {}
@@ -101,7 +61,7 @@ pub(crate) fn recv_and_process<'a, T>(
// we cannot know when a sequence of messages will end if the messages do not end
// with an NlMsg::Done marker while if a maximum sequence number wasn't specified
if max_seq.is_none() && nlmsghdr.nlmsg_flags & NLM_F_MULTI as u16 == 0 {
- return Err(Error::UndecidableMessageTermination);
+ return Err(QueryError::UndecidableMessageTermination);
}
// retrieve the next message
@@ -136,15 +96,15 @@ pub(crate) fn recv_and_process<'a, T>(
pub(crate) fn socket_close_wrapper<E>(
sock: RawFd,
cb: impl FnOnce(RawFd) -> Result<(), E>,
-) -> Result<(), Error>
+) -> Result<(), QueryError>
where
- Error: From<E>,
+ QueryError: From<E>,
{
let ret = cb(sock);
// we don't need to shutdown the socket (in fact, Linux doesn't support that operation;
// and return EOPNOTSUPP if we try)
- nix::unistd::close(sock).map_err(Error::CloseFailed)?;
+ nix::unistd::close(sock).map_err(QueryError::CloseFailed)?;
Ok(ret?)
}
@@ -156,7 +116,7 @@ pub fn get_list_of_objects<T: NfNetlinkAttribute>(
msg_type: u16,
seq: u32,
filter: Option<&T>,
-) -> Result<Vec<u8>, Error> {
+) -> Result<Vec<u8>, QueryError> {
let mut buffer = Vec::new();
let mut writer = NfNetlinkWriter::new(&mut buffer);
writer.write_header(
@@ -182,10 +142,10 @@ pub fn get_list_of_objects<T: NfNetlinkAttribute>(
/// and of the output vector, to which it should append the parsed object it received.
pub fn list_objects_with_data<'a, Object, Accumulator>(
data_type: u16,
- cb: &dyn Fn(Object, &mut Accumulator) -> Result<(), Error>,
+ cb: &dyn Fn(Object, &mut Accumulator) -> Result<(), QueryError>,
filter: Option<&Object>,
working_data: &'a mut Accumulator,
-) -> Result<(), Error>
+) -> Result<(), QueryError>
where
Object: NfNetlinkObject + NfNetlinkAttribute,
{
@@ -196,12 +156,12 @@ where
SockFlag::empty(),
SockProtocol::NetlinkNetFilter,
)
- .map_err(Error::NetlinkOpenError)?;
+ .map_err(QueryError::NetlinkOpenError)?;
let seq = 0;
let chains_buf = get_list_of_objects(data_type, seq, filter)?;
- socket::send(sock, &chains_buf, MsgFlags::empty()).map_err(Error::NetlinkSendError)?;
+ socket::send(sock, &chains_buf, MsgFlags::empty()).map_err(QueryError::NetlinkSendError)?;
socket_close_wrapper(sock, move |sock| {
// the kernel should return NLM_F_MULTI objects
diff --git a/src/rule.rs b/src/rule.rs
index 5d13ac4..7f732d3 100644
--- a/src/rule.rs
+++ b/src/rule.rs
@@ -1,22 +1,24 @@
+use std::fmt::Debug;
+
use rustables_macros::nfnetlink_struct;
+use crate::chain::Chain;
+use crate::error::{BuilderError, QueryError};
use crate::expr::ExpressionList;
-use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkWriter};
-use crate::parser::{DecodeError, Parsable};
+use crate::nlmsg::NfNetlinkObject;
use crate::query::list_objects_with_data;
use crate::sys::{
NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, NFTA_RULE_HANDLE, NFTA_RULE_ID, NFTA_RULE_POSITION,
- NFTA_RULE_TABLE, NFTA_RULE_USERDATA, NFT_MSG_DELRULE, NFT_MSG_NEWRULE, NLM_F_ACK, NLM_F_CREATE,
+ NFTA_RULE_TABLE, NFTA_RULE_USERDATA, NFT_MSG_DELRULE, NFT_MSG_NEWRULE, NLM_F_APPEND,
+ NLM_F_CREATE,
};
use crate::ProtocolFamily;
-use crate::{chain::Chain, MsgType};
-use std::convert::TryFrom;
-use std::fmt::Debug;
/// A nftables firewall rule.
#[derive(Clone, PartialEq, Eq, Default, Debug)]
#[nfnetlink_struct(derive_deserialize = false)]
pub struct Rule {
+ family: ProtocolFamily,
#[field(NFTA_RULE_TABLE)]
table: String,
#[field(NFTA_RULE_CHAIN)]
@@ -31,78 +33,47 @@ pub struct Rule {
userdata: Vec<u8>,
#[field(NFTA_RULE_ID)]
id: u32,
- family: ProtocolFamily,
}
impl Rule {
/// Creates a new rule object in the given [`Chain`].
///
/// [`Chain`]: struct.Chain.html
- pub fn new(chain: &Chain) -> Result<Rule, DecodeError> {
+ pub fn new(chain: &Chain) -> Result<Rule, BuilderError> {
Ok(Rule::default()
.with_family(chain.get_family())
.with_table(
chain
.get_table()
- .ok_or(DecodeError::MissingChainInformationError)?,
+ .ok_or(BuilderError::MissingChainInformationError)?,
)
.with_chain(
chain
.get_name()
- .ok_or(DecodeError::MissingChainInformationError)?,
+ .ok_or(BuilderError::MissingChainInformationError)?,
))
}
+}
+
+impl NfNetlinkObject for Rule {
+ const MSG_TYPE_ADD: u32 = NFT_MSG_NEWRULE;
+ const MSG_TYPE_DEL: u32 = NFT_MSG_DELRULE;
- pub fn get_family(&self) -> ProtocolFamily {
+ fn get_family(&self) -> ProtocolFamily {
self.family
}
- pub fn set_family(&mut self, family: ProtocolFamily) {
+ fn set_family(&mut self, family: ProtocolFamily) {
self.family = family;
}
- pub fn with_family(mut self, family: ProtocolFamily) -> Self {
- self.set_family(family);
- self
- }
-}
-
-impl NfNetlinkObject for Rule {
- fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32) {
- let raw_msg_type = match msg_type {
- MsgType::Add => NFT_MSG_NEWRULE,
- MsgType::Del => NFT_MSG_DELRULE,
- } as u16;
- writer.write_header(
- raw_msg_type,
- self.family,
- (if let MsgType::Add = msg_type {
- NLM_F_CREATE
- } else {
- 0
- } | NLM_F_ACK) as u16,
- seq,
- None,
- );
- let buf = writer.add_data_zeroed(self.get_size());
- unsafe {
- self.write_payload(buf.as_mut_ptr());
- }
- writer.finalize_writing_object();
- }
-}
-
-impl NfNetlinkDeserializable for Rule {
- fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
- let (mut obj, nfgenmsg, remaining_data) =
- Self::parse_object(buf, NFT_MSG_NEWRULE, NFT_MSG_DELRULE)?;
- obj.family = ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?;
-
- Ok((obj, remaining_data))
+ // append at the end of the chain, instead of the beginning
+ fn get_add_flags(&self) -> u32 {
+ NLM_F_CREATE | NLM_F_APPEND
}
}
-pub fn list_rules_for_chain(chain: &Chain) -> Result<Vec<Rule>, crate::query::Error> {
+pub fn list_rules_for_chain(chain: &Chain) -> Result<Vec<Rule>, QueryError> {
let mut result = Vec::new();
list_objects_with_data(
libc::NFT_MSG_GETRULE as u16,
diff --git a/src/set.rs b/src/set.rs
index b153450..32d1666 100644
--- a/src/set.rs
+++ b/src/set.rs
@@ -1,278 +1,117 @@
-use crate::nlmsg::NlMsg;
-use crate::sys::{self, libc};
-use crate::{table::Table, MsgType};
-use std::{
- cell::Cell,
- ffi::{c_void, CStr, CString},
- fmt::Debug,
- net::{Ipv4Addr, Ipv6Addr},
- os::raw::c_char,
- rc::Rc,
+use rustables_macros::nfnetlink_struct;
+
+use crate::data_type::DataType;
+use crate::error::BuilderError;
+use crate::nlmsg::NfNetlinkObject;
+use crate::parser_impls::{NfNetlinkData, NfNetlinkList};
+use crate::sys::{
+ NFTA_SET_ELEM_KEY, NFTA_SET_ELEM_LIST_ELEMENTS, NFTA_SET_ELEM_LIST_SET,
+ NFTA_SET_ELEM_LIST_TABLE, NFTA_SET_FLAGS, NFTA_SET_ID, NFTA_SET_KEY_LEN, NFTA_SET_KEY_TYPE,
+ NFTA_SET_NAME, NFTA_SET_TABLE, NFTA_SET_USERDATA, NFT_MSG_DELSET, NFT_MSG_DELSETELEM,
+ NFT_MSG_NEWSET, NFT_MSG_NEWSETELEM,
};
-
-#[macro_export]
-macro_rules! nft_set {
- ($name:expr, $id:expr, $table:expr) => {
- $crate::set::Set::new(Some($name), $id, $table, $family)
- };
- ($name:expr, $id:expr, $table:expr; [ ]) => {
- nft_set!(Some($name), $id, $table)
- };
- ($name:expr, $id:expr, $table:expr; [ $($value:expr,)* ]) => {{
- let mut set = nft_set!(Some($name), $id, $table).expect("Set allocation failed");
- $(
- set.add($value).expect(stringify!(Unable to add $value to set $name));
- )*
- set
- }};
-}
-
-pub struct Set<K> {
- pub(crate) set: *mut sys::nftnl_set,
- pub(crate) table: Rc<Table>,
- _marker: ::std::marker::PhantomData<K>,
-}
-
-impl<K> Set<K> {
- pub fn new(name: &CStr, id: u32, table: Rc<Table>) -> Self
- where
- K: SetKey,
- {
- unsafe {
- let set = try_alloc!(sys::nftnl_set_alloc());
-
- sys::nftnl_set_set_u32(set, sys::NFTNL_SET_FAMILY as u16, table.get_family() as u32);
- sys::nftnl_set_set_str(set, sys::NFTNL_SET_TABLE as u16, table.get_name().as_ptr());
- sys::nftnl_set_set_str(set, sys::NFTNL_SET_NAME as u16, name.as_ptr());
- sys::nftnl_set_set_u32(set, sys::NFTNL_SET_ID as u16, id);
-
- sys::nftnl_set_set_u32(
- set,
- sys::NFTNL_SET_FLAGS as u16,
- (libc::NFT_SET_ANONYMOUS | libc::NFT_SET_CONSTANT) as u32,
- );
- sys::nftnl_set_set_u32(set, sys::NFTNL_SET_KEY_TYPE as u16, K::TYPE);
- sys::nftnl_set_set_u32(set, sys::NFTNL_SET_KEY_LEN as u16, K::LEN);
-
- Set {
- set,
- table,
- _marker: ::std::marker::PhantomData,
- }
- }
- }
-
- pub unsafe fn from_raw(set: *mut sys::nftnl_set, table: Rc<Table>) -> Self
- where
- K: SetKey,
- {
- Set {
- set,
- table,
- _marker: ::std::marker::PhantomData,
- }
- }
-
- pub fn add(&mut self, key: &K)
- where
- K: SetKey,
- {
- unsafe {
- let elem = try_alloc!(sys::nftnl_set_elem_alloc());
-
- let data = key.data();
- let data_len = data.len() as u32;
- trace!("Adding key {:?} with len {}", data, data_len);
- sys::nftnl_set_elem_set(
- elem,
- sys::NFTNL_SET_ELEM_KEY as u16,
- data.as_ref() as *const _ as *const c_void,
- data_len,
- );
- sys::nftnl_set_elem_add(self.set, elem);
- }
- }
-
- pub fn elems_iter(&self) -> SetElemsIter<K> {
- SetElemsIter::new(self)
- }
-
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns the raw handle.
- pub fn as_ptr(&self) -> *const sys::nftnl_set {
- self.set as *const sys::nftnl_set
- }
-
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns a mutable version of the raw handle.
- pub fn as_mut_ptr(&self) -> *mut sys::nftnl_set {
- self.set
- }
-
- /// Returns a textual description of the set.
- pub fn get_str(&self) -> CString {
- let mut descr_buf = vec![0i8; 4096];
- unsafe {
- sys::nftnl_set_snprintf(
- descr_buf.as_mut_ptr() as *mut c_char,
- (descr_buf.len() - 1) as u64,
- self.set,
- sys::NFTNL_OUTPUT_DEFAULT,
- 0,
- );
- CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned()
- }
- }
-
- pub fn get_name(&self) -> Option<&CStr> {
- unsafe {
- let ptr = sys::nftnl_set_get_str(self.set, sys::NFTNL_SET_NAME as u16);
- if !ptr.is_null() {
- Some(CStr::from_ptr(ptr))
- } else {
- None
- }
- }
- }
-
- pub fn get_id(&self) -> u32 {
- unsafe { sys::nftnl_set_get_u32(self.set, sys::NFTNL_SET_ID as u16) }
- }
-}
-
-impl<K> Debug for Set<K> {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{:?}", self.get_str())
- }
-}
-
-unsafe impl<K> NlMsg for Set<K> {
- unsafe fn write(&self, buf: &mut Vec<u8>, seq: u32, msg_type: MsgType) {
- let type_ = match msg_type {
- MsgType::Add => libc::NFT_MSG_NEWSET,
- MsgType::Del => libc::NFT_MSG_DELSET,
- };
- /*
- let header = sys::nftnl_nlmsg_build_hdr(
- buf as *mut c_char,
- type_ as u16,
- self.table.get_family() as u16,
- (libc::NLM_F_APPEND | libc::NLM_F_CREATE | libc::NLM_F_ACK) as u16,
- seq,
- );
- sys::nftnl_set_nlmsg_build_payload(header, self.set);
- */
- }
-}
-
-impl<K> Drop for Set<K> {
- fn drop(&mut self) {
- unsafe { sys::nftnl_set_free(self.set) };
- }
-}
-
-pub struct SetElemsIter<'a, K> {
- set: &'a Set<K>,
- iter: *mut sys::nftnl_set_elems_iter,
- ret: Rc<Cell<i32>>,
-}
-
-impl<'a, K> SetElemsIter<'a, K> {
- fn new(set: &'a Set<K>) -> Self {
- let iter = try_alloc!(unsafe {
- sys::nftnl_set_elems_iter_create(set.set as *const sys::nftnl_set)
+use crate::table::Table;
+use crate::ProtocolFamily;
+use std::fmt::Debug;
+use std::marker::PhantomData;
+
+#[derive(Default, Debug, Clone, PartialEq, Eq)]
+#[nfnetlink_struct(derive_deserialize = false)]
+pub struct Set {
+ pub family: ProtocolFamily,
+ #[field(NFTA_SET_TABLE)]
+ pub table: String,
+ #[field(NFTA_SET_NAME)]
+ pub name: String,
+ #[field(NFTA_SET_FLAGS)]
+ pub flags: u32,
+ #[field(NFTA_SET_KEY_TYPE)]
+ pub key_type: u32,
+ #[field(NFTA_SET_KEY_LEN)]
+ pub key_len: u32,
+ #[field(NFTA_SET_ID)]
+ pub id: u32,
+ #[field(NFTA_SET_USERDATA)]
+ pub userdata: String,
+}
+
+impl NfNetlinkObject for Set {
+ const MSG_TYPE_ADD: u32 = NFT_MSG_NEWSET;
+ const MSG_TYPE_DEL: u32 = NFT_MSG_DELSET;
+
+ fn get_family(&self) -> ProtocolFamily {
+ self.family
+ }
+
+ fn set_family(&mut self, family: ProtocolFamily) {
+ self.family = family;
+ }
+}
+
+pub struct SetBuilder<K: DataType> {
+ inner: Set,
+ list: SetElementList,
+ _phantom: PhantomData<K>,
+}
+
+impl<K: DataType> SetBuilder<K> {
+ pub fn new(name: impl Into<String>, id: u32, table: &Table) -> Result<Self, BuilderError> {
+ let table_name = table.get_name().ok_or(BuilderError::MissingTableName)?;
+ let set_name = name.into();
+ let set = Set::default()
+ .with_id(id)
+ .with_key_type(K::TYPE)
+ .with_key_len(K::LEN)
+ .with_table(table_name)
+ .with_name(&set_name);
+
+ Ok(SetBuilder {
+ inner: set,
+ list: SetElementList {
+ table: Some(table_name.clone()),
+ set: Some(set_name),
+ elements: Some(SetElementListElements::default()),
+ },
+ _phantom: PhantomData,
+ })
+ }
+
+ pub fn add(&mut self, key: &K) {
+ self.list.elements.as_mut().unwrap().add_value(SetElement {
+ key: Some(NfNetlinkData::default().with_value(key.data())),
});
- SetElemsIter {
- set,
- iter,
- ret: Rc::new(Cell::new(1)),
- }
}
-}
-
-impl<'a, K> Iterator for SetElemsIter<'a, K> {
- type Item = SetElemsMsg<'a, K>;
- fn next(&mut self) -> Option<Self::Item> {
- if self.ret.get() <= 0 || unsafe { sys::nftnl_set_elems_iter_cur(self.iter).is_null() } {
- trace!("SetElemsIter iterator ending");
- None
- } else {
- trace!("SetElemsIter returning new SetElemsMsg");
- Some(SetElemsMsg {
- set: self.set,
- iter: self.iter,
- ret: self.ret.clone(),
- })
- }
+ pub fn finish(self) -> (Set, SetElementList) {
+ (self.inner, self.list)
}
}
-impl<'a, K> Drop for SetElemsIter<'a, K> {
- fn drop(&mut self) {
- unsafe { sys::nftnl_set_elems_iter_destroy(self.iter) };
- }
+#[derive(Default, Debug, Clone, PartialEq, Eq)]
+#[nfnetlink_struct(nested = true, derive_deserialize = false)]
+pub struct SetElementList {
+ #[field(NFTA_SET_ELEM_LIST_TABLE)]
+ pub table: String,
+ #[field(NFTA_SET_ELEM_LIST_SET)]
+ pub set: String,
+ #[field(NFTA_SET_ELEM_LIST_ELEMENTS)]
+ pub elements: SetElementListElements,
}
-pub struct SetElemsMsg<'a, K> {
- set: &'a Set<K>,
- iter: *mut sys::nftnl_set_elems_iter,
- ret: Rc<Cell<i32>>,
-}
+impl NfNetlinkObject for SetElementList {
+ const MSG_TYPE_ADD: u32 = NFT_MSG_NEWSETELEM;
+ const MSG_TYPE_DEL: u32 = NFT_MSG_DELSETELEM;
-unsafe impl<'a, K> NlMsg for SetElemsMsg<'a, K> {
- unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) {
- trace!("Writing SetElemsMsg to NlMsg");
- let (type_, flags) = match msg_type {
- MsgType::Add => (
- libc::NFT_MSG_NEWSETELEM,
- libc::NLM_F_CREATE | libc::NLM_F_EXCL | libc::NLM_F_ACK,
- ),
- MsgType::Del => (libc::NFT_MSG_DELSETELEM, libc::NLM_F_ACK),
- };
- let header = sys::nftnl_nlmsg_build_hdr(
- buf as *mut c_char,
- type_ as u16,
- self.set.table.get_family() as u16,
- flags as u16,
- seq,
- );
- self.ret.set(sys::nftnl_set_elems_nlmsg_build_payload_iter(
- header, self.iter,
- ));
+ fn get_family(&self) -> ProtocolFamily {
+ ProtocolFamily::Unspec
}
}
-pub trait SetKey {
- const TYPE: u32;
- const LEN: u32;
-
- fn data(&self) -> Box<[u8]>;
+#[derive(Default, Debug, Clone, PartialEq, Eq)]
+#[nfnetlink_struct(nested = true)]
+pub struct SetElement {
+ #[field(NFTA_SET_ELEM_KEY)]
+ pub key: NfNetlinkData,
}
-impl SetKey for Ipv4Addr {
- const TYPE: u32 = 7;
- const LEN: u32 = 4;
-
- fn data(&self) -> Box<[u8]> {
- self.octets().to_vec().into_boxed_slice()
- }
-}
-
-impl SetKey for Ipv6Addr {
- const TYPE: u32 = 8;
- const LEN: u32 = 16;
-
- fn data(&self) -> Box<[u8]> {
- self.octets().to_vec().into_boxed_slice()
- }
-}
-
-impl<const N: usize> SetKey for [u8; N] {
- const TYPE: u32 = 5;
- const LEN: u32 = N as u32;
-
- fn data(&self) -> Box<[u8]> {
- Box::new(*self)
- }
-}
+type SetElementListElements = NfNetlinkList<SetElement>;
diff --git a/src/table.rs b/src/table.rs
index e6a6a1a..63bf669 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -1,15 +1,14 @@
-use std::convert::TryFrom;
use std::fmt::Debug;
use rustables_macros::nfnetlink_struct;
-use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkWriter};
-use crate::parser::{DecodeError, Parsable};
+use crate::error::QueryError;
+use crate::nlmsg::NfNetlinkObject;
use crate::sys::{
NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFTA_TABLE_USERDATA, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE,
- NFT_MSG_NEWTABLE, NLM_F_ACK, NLM_F_CREATE,
+ NFT_MSG_NEWTABLE,
};
-use crate::{MsgType, ProtocolFamily};
+use crate::ProtocolFamily;
/// Abstraction of a `nftnl_table`, the top level container in netfilter. A table has a protocol
/// family and contains [`Chain`]s that in turn hold the rules.
@@ -18,13 +17,13 @@ use crate::{MsgType, ProtocolFamily};
#[derive(Default, PartialEq, Eq, Debug)]
#[nfnetlink_struct(derive_deserialize = false)]
pub struct Table {
+ family: ProtocolFamily,
#[field(NFTA_TABLE_NAME)]
name: String,
#[field(NFTA_TABLE_FLAGS)]
flags: u32,
#[field(NFTA_TABLE_USERDATA)]
userdata: Vec<u8>,
- pub family: ProtocolFamily,
}
impl Table {
@@ -36,41 +35,19 @@ impl Table {
}
impl NfNetlinkObject for Table {
- fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32) {
- let raw_msg_type = match msg_type {
- MsgType::Add => NFT_MSG_NEWTABLE,
- MsgType::Del => NFT_MSG_DELTABLE,
- } as u16;
- writer.write_header(
- raw_msg_type,
- self.family,
- (if let MsgType::Add = msg_type {
- NLM_F_CREATE
- } else {
- 0
- } | NLM_F_ACK) as u16,
- seq,
- None,
- );
- let buf = writer.add_data_zeroed(self.get_size());
- unsafe {
- self.write_payload(buf.as_mut_ptr());
- }
- writer.finalize_writing_object();
- }
-}
+ const MSG_TYPE_ADD: u32 = NFT_MSG_NEWTABLE;
+ const MSG_TYPE_DEL: u32 = NFT_MSG_DELTABLE;
-impl NfNetlinkDeserializable for Table {
- fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
- let (mut obj, nfgenmsg, remaining_data) =
- Self::parse_object(buf, NFT_MSG_NEWTABLE, NFT_MSG_DELTABLE)?;
- obj.family = ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?;
+ fn get_family(&self) -> ProtocolFamily {
+ self.family
+ }
- Ok((obj, remaining_data))
+ fn set_family(&mut self, family: ProtocolFamily) {
+ self.family = family;
}
}
-pub fn list_tables() -> Result<Vec<Table>, crate::query::Error> {
+pub fn list_tables() -> Result<Vec<Table>, QueryError> {
let mut result = Vec::new();
crate::query::list_objects_with_data(
NFT_MSG_GETTABLE as u16,
diff --git a/src/tests/batch.rs b/src/tests/batch.rs
new file mode 100644
index 0000000..12f373f
--- /dev/null
+++ b/src/tests/batch.rs
@@ -0,0 +1,96 @@
+use std::mem::size_of;
+
+use libc::{AF_UNSPEC, NFNL_MSG_BATCH_BEGIN, NLM_F_REQUEST};
+use nix::libc::NFNL_MSG_BATCH_END;
+
+use crate::nlmsg::{pad_netlink_object_with_variable_size, NfNetlinkDeserializable};
+use crate::parser::{parse_nlmsg, NlMsg};
+use crate::sys::{nfgenmsg, nlmsghdr, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES};
+use crate::{Batch, MsgType, Table};
+
+use super::get_test_table;
+
+const HEADER_SIZE: u32 =
+ pad_netlink_object_with_variable_size(size_of::<nlmsghdr>() + size_of::<nfgenmsg>()) as u32;
+
+const DEFAULT_BATCH_BEGIN_HDR: nlmsghdr = nlmsghdr {
+ nlmsg_len: HEADER_SIZE,
+ nlmsg_flags: NLM_F_REQUEST as u16,
+ nlmsg_type: NFNL_MSG_BATCH_BEGIN as u16,
+ nlmsg_seq: 0,
+ nlmsg_pid: 0,
+};
+const DEFAULT_BATCH_MSG: NlMsg = NlMsg::NfGenMsg(
+ nfgenmsg {
+ nfgen_family: AF_UNSPEC as u8,
+ version: NFNETLINK_V0 as u8,
+ res_id: NFNL_SUBSYS_NFTABLES as u16,
+ },
+ &[],
+);
+
+const DEFAULT_BATCH_END_HDR: nlmsghdr = nlmsghdr {
+ nlmsg_len: HEADER_SIZE,
+ nlmsg_flags: NLM_F_REQUEST as u16,
+ nlmsg_type: NFNL_MSG_BATCH_END as u16,
+ nlmsg_seq: 1,
+ nlmsg_pid: 0,
+};
+
+#[test]
+fn batch_empty() {
+ let batch = Batch::new();
+ let buf = batch.finalize();
+
+ let (hdr, msg) = parse_nlmsg(&buf).expect("Invalid nlmsg message");
+ assert_eq!(hdr, DEFAULT_BATCH_BEGIN_HDR);
+ assert_eq!(msg, DEFAULT_BATCH_MSG);
+
+ let remaining_data_offset = pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize);
+
+ let (hdr, msg) = parse_nlmsg(&buf[remaining_data_offset..]).expect("Invalid nlmsg message");
+ assert_eq!(hdr, DEFAULT_BATCH_END_HDR);
+ assert_eq!(msg, DEFAULT_BATCH_MSG);
+}
+
+#[test]
+fn batch_with_objects() {
+ let mut original_tables = vec![];
+ for i in 0..10 {
+ let mut table = get_test_table();
+ table.set_userdata(vec![i as u8]);
+ original_tables.push(table);
+ }
+
+ let mut batch = Batch::new();
+ for i in 0..10 {
+ batch.add(
+ &original_tables[i],
+ if i % 2 == 0 {
+ MsgType::Add
+ } else {
+ MsgType::Del
+ },
+ );
+ }
+ let buf = batch.finalize();
+
+ let (hdr, msg) = parse_nlmsg(&buf).expect("Invalid nlmsg message");
+ assert_eq!(hdr, DEFAULT_BATCH_BEGIN_HDR);
+ assert_eq!(msg, DEFAULT_BATCH_MSG);
+ let mut remaining_data = &buf[pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize)..];
+
+ for i in 0..10 {
+ let (deserialized_table, rest) =
+ Table::deserialize(&remaining_data).expect("could not deserialize a table");
+ remaining_data = rest;
+
+ assert_eq!(deserialized_table, original_tables[i]);
+ }
+
+ let (hdr, msg) = parse_nlmsg(&remaining_data).expect("Invalid nlmsg message");
+ let mut end_hdr = DEFAULT_BATCH_END_HDR;
+ end_hdr.nlmsg_seq = 11;
+ assert_eq!(hdr, end_hdr);
+ assert_eq!(msg, DEFAULT_BATCH_MSG);
+}
diff --git a/src/tests/chain.rs b/src/tests/chain.rs
new file mode 100644
index 0000000..7f696e6
--- /dev/null
+++ b/src/tests/chain.rs
@@ -0,0 +1,120 @@
+use crate::{
+ nlmsg::get_operation_from_nlmsghdr_type,
+ sys::{
+ NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_TABLE, NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA,
+ NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN,
+ },
+ ChainType, Hook, HookClass, MsgType,
+};
+
+use super::{
+ get_test_chain, get_test_nlmsg, get_test_nlmsg_with_msg_type, NetlinkExpr, CHAIN_NAME,
+ CHAIN_USERDATA, TABLE_NAME,
+};
+
+#[test]
+fn new_empty_chain() {
+ let mut chain = get_test_chain();
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut chain);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWCHAIN as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 52);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_CHAIN_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_CHAIN_NAME, CHAIN_NAME.as_bytes().to_vec()),
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn new_empty_chain_with_hook_and_type() {
+ let mut chain = get_test_chain()
+ .with_hook(Hook::new(HookClass::In, 0))
+ .with_type(ChainType::Filter);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut chain);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWCHAIN as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 84);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_CHAIN_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_CHAIN_NAME, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_CHAIN_TYPE, "filter".as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_CHAIN_HOOK,
+ vec![
+ NetlinkExpr::List(vec![NetlinkExpr::Final(
+ NFTA_HOOK_HOOKNUM,
+ vec![0, 0, 0, 1]
+ )]),
+ NetlinkExpr::List(vec![NetlinkExpr::Final(
+ NFTA_HOOK_PRIORITY,
+ vec![0, 0, 0, 0]
+ )])
+ ]
+ ),
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn new_empty_chain_with_userdata() {
+ let mut chain = get_test_chain();
+ chain.set_userdata(CHAIN_USERDATA);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut chain);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWCHAIN as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 72);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_CHAIN_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_CHAIN_NAME, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_CHAIN_USERDATA, CHAIN_USERDATA.as_bytes().to_vec())
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn delete_empty_chain() {
+ let mut chain = get_test_chain();
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) =
+ get_test_nlmsg_with_msg_type(&mut buf, &mut chain, MsgType::Del);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_DELCHAIN as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 52);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_CHAIN_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_CHAIN_NAME, CHAIN_NAME.as_bytes().to_vec()),
+ ])
+ .to_raw()
+ );
+}
diff --git a/src/tests/expr.rs b/src/tests/expr.rs
new file mode 100644
index 0000000..141f6ac
--- /dev/null
+++ b/src/tests/expr.rs
@@ -0,0 +1,589 @@
+use std::net::Ipv4Addr;
+
+use libc::NF_DROP;
+
+use crate::{
+ expr::{
+ Bitwise, Cmp, CmpOp, Conntrack, ConntrackKey, Counter, ExpressionList, HeaderField,
+ HighLevelPayload, IcmpCode, Immediate, Log, Masquerade, Meta, MetaType, Nat, NatType,
+ Register, Reject, RejectType, TCPHeaderField, TransportHeaderField, VerdictKind,
+ },
+ sys::{
+ NFTA_BITWISE_DREG, NFTA_BITWISE_LEN, NFTA_BITWISE_MASK, NFTA_BITWISE_SREG,
+ NFTA_BITWISE_XOR, NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFTA_COUNTER_BYTES,
+ NFTA_COUNTER_PACKETS, NFTA_CT_DREG, NFTA_CT_KEY, NFTA_DATA_VALUE, NFTA_DATA_VERDICT,
+ NFTA_EXPR_DATA, NFTA_EXPR_NAME, NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG, NFTA_LIST_ELEM,
+ NFTA_LOG_GROUP, NFTA_LOG_PREFIX, NFTA_META_DREG, NFTA_META_KEY, NFTA_NAT_FAMILY,
+ NFTA_NAT_REG_ADDR_MIN, NFTA_NAT_TYPE, NFTA_PAYLOAD_BASE, NFTA_PAYLOAD_DREG,
+ NFTA_PAYLOAD_LEN, NFTA_PAYLOAD_OFFSET, NFTA_REJECT_ICMP_CODE, NFTA_REJECT_TYPE,
+ NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, NFTA_RULE_TABLE, NFTA_VERDICT_CODE, NFT_CMP_EQ,
+ NFT_CT_STATE, NFT_META_PROTOCOL, NFT_NAT_SNAT, NFT_PAYLOAD_TRANSPORT_HEADER, NFT_REG_1,
+ NFT_REG_VERDICT, NFT_REJECT_ICMPX_UNREACH,
+ },
+ ProtocolFamily,
+};
+
+use super::{get_test_nlmsg, get_test_rule, NetlinkExpr, CHAIN_NAME, TABLE_NAME};
+
+#[test]
+fn bitwise_expr_is_valid() {
+ let netmask = Ipv4Addr::new(255, 255, 255, 0);
+ let bitwise = Bitwise::new(netmask.octets(), [0, 0, 0, 0]).unwrap();
+ let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(bitwise));
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 124);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"bitwise".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(
+ NFTA_BITWISE_SREG,
+ NFT_REG_1.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(
+ NFTA_BITWISE_DREG,
+ NFT_REG_1.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(NFTA_BITWISE_LEN, 4u32.to_be_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_BITWISE_MASK,
+ vec![NetlinkExpr::Final(
+ NFTA_DATA_VALUE,
+ vec![255, 255, 255, 0]
+ )]
+ ),
+ NetlinkExpr::Nested(
+ NFTA_BITWISE_XOR,
+ vec![NetlinkExpr::Final(
+ NFTA_DATA_VALUE,
+ 0u32.to_be_bytes().to_vec()
+ )]
+ )
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn cmp_expr_is_valid() {
+ let val = [1u8, 2, 3, 4];
+ let cmp = Cmp::new(CmpOp::Eq, val.clone());
+ let mut rule = get_test_rule().with_expressions(vec![cmp]);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 100);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"cmp".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(NFTA_CMP_SREG, NFT_REG_1.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_CMP_OP, NFT_CMP_EQ.to_be_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_CMP_DATA,
+ vec![NetlinkExpr::Final(NFTA_DATA_VALUE, val.to_vec())]
+ )
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn counter_expr_is_valid() {
+ let nb_bytes = 123456u64;
+ let nb_packets = 987u64;
+ let counter = Counter::default()
+ .with_nb_bytes(nb_bytes)
+ .with_nb_packets(nb_packets);
+
+ let mut rule = get_test_rule().with_expressions(vec![counter]);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 100);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"counter".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(
+ NFTA_COUNTER_BYTES,
+ nb_bytes.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(
+ NFTA_COUNTER_PACKETS,
+ nb_packets.to_be_bytes().to_vec()
+ )
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn ct_expr_is_valid() {
+ let ct = Conntrack::default().with_retrieve_value(ConntrackKey::State);
+ let mut rule = get_test_rule().with_expressions(vec![ct]);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 88);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"ct".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(
+ NFTA_CT_KEY,
+ NFT_CT_STATE.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(NFTA_CT_DREG, NFT_REG_1.to_be_bytes().to_vec())
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ )
+}
+
+#[test]
+fn immediate_expr_is_valid() {
+ let immediate = Immediate::new_data(vec![42u8], Register::Reg1);
+ let mut rule =
+ get_test_rule().with_expressions(ExpressionList::default().with_value(immediate));
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 100);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"immediate".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(
+ NFTA_IMMEDIATE_DREG,
+ NFT_REG_1.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Nested(
+ NFTA_IMMEDIATE_DATA,
+ vec![NetlinkExpr::Final(1u16, 42u8.to_be_bytes().to_vec())]
+ )
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn log_expr_is_valid() {
+ let log = Log::new(Some(1337), Some("mockprefix")).expect("Could not build a log expression");
+ let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(log));
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 96);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"log".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(NFTA_LOG_GROUP, 1337u16.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_LOG_PREFIX, b"mockprefix".to_vec()),
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+/*
+#[test]
+fn lookup_expr_is_valid() {
+ let set_name = &CStr::from_bytes_with_nul(b"mockset\0").unwrap();
+ let mut rule = get_test_rule();
+ let table = rule.get_chain().get_table();
+ let mut set = Set::new(set_name, 0, table);
+ let address: Ipv4Addr = [8, 8, 8, 8].into();
+ set.add(&address);
+ let lookup = Lookup::new(&set).unwrap();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &lookup);
+ assert_eq!(nlmsghdr.nlmsg_len, 104);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"lookup\0".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(
+ NFTA_LOOKUP_SREG,
+ NFT_REG_1.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(NFTA_LOOKUP_SET, b"mockset\0".to_vec()),
+ NetlinkExpr::Final(NFTA_LOOKUP_SET_ID, 0u32.to_be_bytes().to_vec()),
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+*/
+
+#[test]
+fn masquerade_expr_is_valid() {
+ let masquerade = Masquerade::default();
+ let mut rule = get_test_rule().with_expressions(vec![masquerade]);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 72);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"masq".to_vec()),
+ NetlinkExpr::Nested(NFTA_EXPR_DATA, vec![]),
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn meta_expr_is_valid() {
+ let meta = Meta::default()
+ .with_key(MetaType::Protocol)
+ .with_dreg(Register::Reg1);
+ let mut rule = get_test_rule().with_expressions(vec![meta]);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 88);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"meta".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(
+ NFTA_META_KEY,
+ NFT_META_PROTOCOL.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(
+ NFTA_META_DREG,
+ NFT_REG_1.to_be_bytes().to_vec()
+ )
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn nat_expr_is_valid() {
+ let nat = Nat::default()
+ .with_nat_type(NatType::SNat)
+ .with_family(ProtocolFamily::Ipv4)
+ .with_ip_register(Register::Reg1);
+ let mut rule = get_test_rule().with_expressions(vec![nat]);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 96);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"nat".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(
+ NFTA_NAT_TYPE,
+ NFT_NAT_SNAT.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(
+ NFTA_NAT_FAMILY,
+ (ProtocolFamily::Ipv4 as u32).to_be_bytes().to_vec(),
+ ),
+ NetlinkExpr::Final(
+ NFTA_NAT_REG_ADDR_MIN,
+ NFT_REG_1.to_be_bytes().to_vec()
+ )
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn payload_expr_is_valid() {
+ let tcp_header_field = TCPHeaderField::Sport;
+ let transport_header_field = TransportHeaderField::Tcp(tcp_header_field);
+ let payload = HighLevelPayload::Transport(transport_header_field);
+ let mut rule = get_test_rule().with_expressions(vec![payload.build()]);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 108);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"payload".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(
+ NFTA_PAYLOAD_DREG,
+ NFT_REG_1.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(
+ NFTA_PAYLOAD_BASE,
+ NFT_PAYLOAD_TRANSPORT_HEADER.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(
+ NFTA_PAYLOAD_OFFSET,
+ tcp_header_field.offset().to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(
+ NFTA_PAYLOAD_LEN,
+ tcp_header_field.len().to_be_bytes().to_vec()
+ ),
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn reject_expr_is_valid() {
+ let code = IcmpCode::NoRoute;
+ let reject = Reject::default()
+ .with_type(RejectType::IcmpxUnreach)
+ .with_icmp_code(code);
+ let mut rule = get_test_rule().with_expressions(vec![reject]);
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 92);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"reject".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(
+ NFTA_REJECT_TYPE,
+ NFT_REJECT_ICMPX_UNREACH.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Final(
+ NFTA_REJECT_ICMP_CODE,
+ (code as u8).to_be_bytes().to_vec()
+ ),
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn verdict_expr_is_valid() {
+ let verdict = Immediate::new_verdict(VerdictKind::Drop);
+ let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(verdict));
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 104);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_RULE_EXPRESSIONS,
+ vec![NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"immediate".to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_EXPR_DATA,
+ vec![
+ NetlinkExpr::Final(
+ NFTA_IMMEDIATE_DREG,
+ NFT_REG_VERDICT.to_be_bytes().to_vec()
+ ),
+ NetlinkExpr::Nested(
+ NFTA_IMMEDIATE_DATA,
+ vec![NetlinkExpr::Nested(
+ NFTA_DATA_VERDICT,
+ vec![NetlinkExpr::Final(
+ NFTA_VERDICT_CODE,
+ NF_DROP.to_be_bytes().to_vec()
+ ),]
+ )],
+ ),
+ ]
+ )
+ ]
+ )]
+ )
+ ])
+ .to_raw()
+ );
+}
diff --git a/src/tests/mod.rs b/src/tests/mod.rs
new file mode 100644
index 0000000..3693d35
--- /dev/null
+++ b/src/tests/mod.rs
@@ -0,0 +1,195 @@
+use crate::data_type::DataType;
+use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter};
+use crate::parser::{parse_nlmsg, NlMsg};
+use crate::set::{Set, SetBuilder};
+use crate::{sys::*, Chain, MsgType, ProtocolFamily, Rule, Table};
+
+mod batch;
+mod chain;
+mod expr;
+mod rule;
+mod set;
+mod table;
+
+pub const TABLE_NAME: &'static str = "mocktable";
+pub const CHAIN_NAME: &'static str = "mockchain";
+pub const SET_NAME: &'static str = "mockset";
+
+pub const TABLE_USERDATA: &'static str = "mocktabledata";
+pub const CHAIN_USERDATA: &'static str = "mockchaindata";
+pub const RULE_USERDATA: &'static str = "mockruledata";
+pub const SET_USERDATA: &'static str = "mocksetdata";
+
+pub const SET_ID: u32 = 123456;
+
+type NetLinkType = u16;
+
+#[derive(Debug, thiserror::Error)]
+#[error("empty data")]
+pub struct EmptyDataError;
+
+#[derive(Debug, Clone, Eq, Ord)]
+pub enum NetlinkExpr {
+ Nested(NetLinkType, Vec<NetlinkExpr>),
+ Final(NetLinkType, Vec<u8>),
+ List(Vec<NetlinkExpr>),
+}
+
+impl NetlinkExpr {
+ pub fn to_raw(self) -> Vec<u8> {
+ match self.sort() {
+ NetlinkExpr::Final(ty, val) => {
+ let len = val.len() + 4;
+ let mut res = Vec::with_capacity(len);
+
+ res.extend(&(len as u16).to_le_bytes());
+ res.extend(&ty.to_le_bytes());
+ res.extend(val);
+ // alignment
+ while res.len() % 4 != 0 {
+ res.push(0);
+ }
+
+ res
+ }
+ NetlinkExpr::Nested(ty, exprs) => {
+ // some heuristic to decrease allocations (even though this is
+ // only useful for testing so performance is not an objective)
+ let mut sub = Vec::with_capacity(exprs.len() * 50);
+
+ for expr in exprs {
+ sub.append(&mut expr.to_raw());
+ }
+
+ let len = sub.len() + 4;
+ let mut res = Vec::with_capacity(len);
+
+ // set the "NESTED" flag
+ res.extend(&(len as u16).to_le_bytes());
+ res.extend(&(ty | NLA_F_NESTED as u16).to_le_bytes());
+ res.extend(sub);
+
+ res
+ }
+ NetlinkExpr::List(exprs) => {
+ // some heuristic to decrease allocations (even though this is
+ // only useful for testing so performance is not an objective)
+ let mut list = Vec::with_capacity(exprs.len() * 50);
+
+ for expr in exprs {
+ list.append(&mut expr.to_raw());
+ }
+
+ list
+ }
+ }
+ }
+
+ pub fn sort(self) -> Self {
+ match self {
+ NetlinkExpr::Final(_, _) => self,
+ NetlinkExpr::Nested(ty, mut exprs) => {
+ exprs.sort();
+ NetlinkExpr::Nested(ty, exprs)
+ }
+ NetlinkExpr::List(mut exprs) => {
+ exprs.sort();
+ NetlinkExpr::List(exprs)
+ }
+ }
+ }
+}
+
+impl PartialEq for NetlinkExpr {
+ fn eq(&self, other: &Self) -> bool {
+ match (self.clone().sort(), other.clone().sort()) {
+ (NetlinkExpr::Nested(k1, v1), NetlinkExpr::Nested(k2, v2)) => k1 == k2 && v1 == v2,
+ (NetlinkExpr::Final(k1, v1), NetlinkExpr::Final(k2, v2)) => k1 == k2 && v1 == v2,
+ (NetlinkExpr::List(v1), NetlinkExpr::List(v2)) => v1 == v2,
+ _ => false,
+ }
+ }
+}
+
+impl PartialOrd for NetlinkExpr {
+ fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+ match (self, other) {
+ (
+ NetlinkExpr::Nested(k1, _) | NetlinkExpr::Final(k1, _),
+ NetlinkExpr::Nested(k2, _) | NetlinkExpr::Final(k2, _),
+ ) => k1.partial_cmp(k2),
+ (NetlinkExpr::List(v1), NetlinkExpr::List(v2)) => v1.partial_cmp(v2),
+ (_, NetlinkExpr::List(_)) => Some(std::cmp::Ordering::Less),
+ (NetlinkExpr::List(_), _) => Some(std::cmp::Ordering::Greater),
+ }
+ }
+}
+
+pub fn get_test_table() -> Table {
+ Table::new(ProtocolFamily::Inet)
+ .with_name(TABLE_NAME)
+ .with_flags(0u32)
+}
+
+pub fn get_test_table_raw_expr() -> NetlinkExpr {
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_TABLE_FLAGS, 0u32.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_TABLE_NAME, TABLE_NAME.as_bytes().to_vec()),
+ ])
+ .sort()
+}
+
+pub fn get_test_table_with_userdata_raw_expr() -> NetlinkExpr {
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_TABLE_FLAGS, 0u32.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_TABLE_NAME, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_TABLE_USERDATA, TABLE_USERDATA.as_bytes().to_vec()),
+ ])
+ .sort()
+}
+
+pub fn get_test_chain() -> Chain {
+ Chain::new(&get_test_table()).with_name(CHAIN_NAME)
+}
+
+pub fn get_test_rule() -> Rule {
+ Rule::new(&get_test_chain()).unwrap()
+}
+
+pub fn get_test_set<K: DataType>() -> Set {
+ SetBuilder::<K>::new(SET_NAME, SET_ID, &get_test_table())
+ .expect("Couldn't create a set")
+ .finish()
+ .0
+ .with_userdata(SET_USERDATA)
+}
+
+pub fn get_test_nlmsg_with_msg_type<'a>(
+ buf: &'a mut Vec<u8>,
+ obj: &mut impl NfNetlinkObject,
+ msg_type: MsgType,
+) -> (nlmsghdr, nfgenmsg, &'a [u8]) {
+ let mut writer = NfNetlinkWriter::new(buf);
+ obj.add_or_remove(&mut writer, msg_type, 0);
+
+ let (hdr, msg) = parse_nlmsg(buf.as_slice()).expect("Couldn't parse the message");
+
+ let (nfgenmsg, raw_value) = match msg {
+ NlMsg::NfGenMsg(nfgenmsg, raw_value) => (nfgenmsg, raw_value),
+ _ => panic!("Invalid return value type, expected a valid message"),
+ };
+
+ // sanity checks on the global message (this should be very similar/factorisable for the
+ // most part in other tests)
+ // TODO: check the messages flags
+ assert_eq!(nfgenmsg.res_id.to_be(), 0);
+
+ (hdr, nfgenmsg, raw_value)
+}
+
+pub fn get_test_nlmsg<'a>(
+ buf: &'a mut Vec<u8>,
+ obj: &mut impl NfNetlinkObject,
+) -> (nlmsghdr, nfgenmsg, &'a [u8]) {
+ get_test_nlmsg_with_msg_type(buf, obj, MsgType::Add)
+}
diff --git a/src/tests/rule.rs b/src/tests/rule.rs
new file mode 100644
index 0000000..08b4139
--- /dev/null
+++ b/src/tests/rule.rs
@@ -0,0 +1,132 @@
+use crate::{
+ nlmsg::get_operation_from_nlmsghdr_type,
+ sys::{
+ NFTA_RULE_CHAIN, NFTA_RULE_HANDLE, NFTA_RULE_POSITION, NFTA_RULE_TABLE, NFTA_RULE_USERDATA,
+ NFT_MSG_DELRULE, NFT_MSG_NEWRULE,
+ },
+ MsgType,
+};
+
+use super::{
+ get_test_nlmsg, get_test_nlmsg_with_msg_type, get_test_rule, NetlinkExpr, CHAIN_NAME,
+ RULE_USERDATA, TABLE_NAME,
+};
+
+#[test]
+fn new_empty_rule() {
+ let mut rule = get_test_rule();
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWRULE as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 52);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn new_empty_rule_with_userdata() {
+ let mut rule = get_test_rule().with_userdata(RULE_USERDATA);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWRULE as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 68);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_USERDATA, RULE_USERDATA.as_bytes().to_vec())
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn new_empty_rule_with_position_and_handle() {
+ let handle: u64 = 1337;
+ let position: u64 = 42;
+ let mut rule = get_test_rule().with_handle(handle).with_position(position);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWRULE as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 76);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_HANDLE, handle.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_POSITION, position.to_be_bytes().to_vec()),
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn delete_empty_rule() {
+ let mut rule = get_test_rule();
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) =
+ get_test_nlmsg_with_msg_type(&mut buf, &mut rule, MsgType::Del);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_DELRULE as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 52);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn delete_empty_rule_with_handle() {
+ let handle: u64 = 42;
+ let mut rule = get_test_rule().with_handle(handle);
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) =
+ get_test_nlmsg_with_msg_type(&mut buf, &mut rule, MsgType::Del);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_DELRULE as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 64);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_HANDLE, handle.to_be_bytes().to_vec()),
+ ])
+ .to_raw()
+ );
+}
diff --git a/src/tests/set.rs b/src/tests/set.rs
new file mode 100644
index 0000000..db27ced
--- /dev/null
+++ b/src/tests/set.rs
@@ -0,0 +1,122 @@
+use std::net::{Ipv4Addr, Ipv6Addr};
+
+use crate::{
+ data_type::DataType,
+ nlmsg::get_operation_from_nlmsghdr_type,
+ set::SetBuilder,
+ sys::{
+ NFTA_DATA_VALUE, NFTA_LIST_ELEM, NFTA_SET_ELEM_KEY, NFTA_SET_ELEM_LIST_ELEMENTS,
+ NFTA_SET_ELEM_LIST_SET, NFTA_SET_ELEM_LIST_TABLE, NFTA_SET_ID, NFTA_SET_KEY_LEN,
+ NFTA_SET_KEY_TYPE, NFTA_SET_NAME, NFTA_SET_TABLE, NFTA_SET_USERDATA, NFT_MSG_DELSET,
+ NFT_MSG_NEWSET, NFT_MSG_NEWSETELEM,
+ },
+ MsgType,
+};
+
+use super::{
+ get_test_nlmsg, get_test_nlmsg_with_msg_type, get_test_set, get_test_table, NetlinkExpr,
+ SET_ID, SET_NAME, SET_USERDATA, TABLE_NAME,
+};
+
+#[test]
+fn new_empty_set() {
+ let mut set = get_test_set::<Ipv4Addr>();
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut set);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWSET as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 88);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_SET_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv4Addr::TYPE.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv4Addr::LEN.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_ID, SET_ID.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_USERDATA, SET_USERDATA.as_bytes().to_vec()),
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn delete_empty_set() {
+ let mut set = get_test_set::<Ipv6Addr>();
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) =
+ get_test_nlmsg_with_msg_type(&mut buf, &mut set, MsgType::Del);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_DELSET as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 88);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_SET_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv6Addr::TYPE.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv6Addr::LEN.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_ID, SET_ID.to_be_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_USERDATA, SET_USERDATA.as_bytes().to_vec()),
+ ])
+ .to_raw()
+ );
+}
+
+#[test]
+fn new_set_with_data() {
+ let ip1 = Ipv4Addr::new(127, 0, 0, 1);
+ let ip2 = Ipv4Addr::new(1, 1, 1, 1);
+ let mut set_builder =
+ SetBuilder::<Ipv4Addr>::new(SET_NAME.to_string(), SET_ID, &get_test_table())
+ .expect("Couldn't create a set");
+
+ set_builder.add(&ip1);
+ set_builder.add(&ip2);
+ let (_set, mut elem_list) = set_builder.finish();
+
+ let mut buf = Vec::new();
+
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut elem_list);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWSETELEM as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 84);
+
+ assert_eq!(
+ raw_expr,
+ NetlinkExpr::List(vec![
+ NetlinkExpr::Final(NFTA_SET_ELEM_LIST_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_SET_ELEM_LIST_SET, SET_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Nested(
+ NFTA_SET_ELEM_LIST_ELEMENTS,
+ vec![
+ NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![NetlinkExpr::Nested(
+ NFTA_DATA_VALUE,
+ vec![NetlinkExpr::Final(NFTA_SET_ELEM_KEY, ip1.data().to_vec())]
+ )]
+ ),
+ NetlinkExpr::Nested(
+ NFTA_LIST_ELEM,
+ vec![NetlinkExpr::Nested(
+ NFTA_DATA_VALUE,
+ vec![NetlinkExpr::Final(NFTA_SET_ELEM_KEY, ip2.data().to_vec())]
+ )]
+ ),
+ ]
+ ),
+ ])
+ .to_raw()
+ );
+}
diff --git a/src/tests/table.rs b/src/tests/table.rs
new file mode 100644
index 0000000..39bf399
--- /dev/null
+++ b/src/tests/table.rs
@@ -0,0 +1,67 @@
+use crate::{
+ nlmsg::{get_operation_from_nlmsghdr_type, nft_nlmsg_maxsize, NfNetlinkDeserializable},
+ sys::{NFT_MSG_DELTABLE, NFT_MSG_NEWTABLE},
+ MsgType, Table,
+};
+
+use super::{
+ get_test_nlmsg, get_test_nlmsg_with_msg_type, get_test_table, get_test_table_raw_expr,
+ get_test_table_with_userdata_raw_expr, TABLE_USERDATA,
+};
+
+#[test]
+fn new_empty_table() {
+ let mut table = get_test_table();
+ let mut buf = Vec::with_capacity(nft_nlmsg_maxsize() as usize);
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut table);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWTABLE as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 44);
+
+ assert_eq!(raw_expr, get_test_table_raw_expr().to_raw());
+}
+
+#[test]
+fn new_empty_table_with_userdata() {
+ let mut table = get_test_table();
+ table.set_userdata(TABLE_USERDATA.as_bytes().to_vec());
+ let mut buf = Vec::with_capacity(nft_nlmsg_maxsize() as usize);
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut table);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_NEWTABLE as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 64);
+
+ assert_eq!(raw_expr, get_test_table_with_userdata_raw_expr().to_raw());
+}
+
+#[test]
+fn delete_empty_table() {
+ let mut table = get_test_table();
+ let mut buf = Vec::with_capacity(nft_nlmsg_maxsize() as usize);
+ let (nlmsghdr, _nfgenmsg, raw_expr) =
+ get_test_nlmsg_with_msg_type(&mut buf, &mut table, MsgType::Del);
+ assert_eq!(
+ get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
+ NFT_MSG_DELTABLE as u8
+ );
+ assert_eq!(nlmsghdr.nlmsg_len, 44);
+
+ assert_eq!(raw_expr, get_test_table_raw_expr().to_raw());
+}
+
+#[test]
+fn parse_table() {
+ let mut table = get_test_table();
+ table.set_userdata(TABLE_USERDATA.as_bytes().to_vec());
+ let mut buf = Vec::with_capacity(nft_nlmsg_maxsize() as usize);
+ let (_nlmsghdr, _nfgenmsg, _raw_expr) = get_test_nlmsg(&mut buf, &mut table);
+
+ let (deserialized_table, remaining) =
+ Table::deserialize(&buf).expect("Couldn't deserialize the object");
+ assert_eq!(table, deserialized_table);
+ assert_eq!(remaining.len(), 0);
+}