aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon THOBY <git@nightmared.fr>2021-11-05 06:23:45 +0000
committerSimon THOBY <git@nightmared.fr>2021-11-05 06:23:45 +0000
commit46b22d88c36863851e4b27efa767d28c8aeecfe0 (patch)
treeab1a638de7587e7b73fe64093428218e1c545004
parent3f61ea42bd291c208d07006d8019c25d588f9183 (diff)
parent1bec5a5c30541e47e9c7cff839ac0e7dd3fb6215 (diff)
Merge branch 'manipulate-exprs' into 'master'
Add functions to iterate over the expressions of existing rules See merge request rustwall/rustables!3
-rw-r--r--rustables/src/batch.rs7
-rw-r--r--rustables/src/chain.rs13
-rw-r--r--rustables/src/expr/bitwise.rs11
-rw-r--r--rustables/src/expr/cmp.rs179
-rw-r--r--rustables/src/expr/counter.rs39
-rw-r--r--rustables/src/expr/ct.rs26
-rw-r--r--rustables/src/expr/immediate.rs90
-rw-r--r--rustables/src/expr/log.rs78
-rw-r--r--rustables/src/expr/lookup.rs41
-rw-r--r--rustables/src/expr/masquerade.rs16
-rw-r--r--rustables/src/expr/meta.rs49
-rw-r--r--rustables/src/expr/mod.rs170
-rw-r--r--rustables/src/expr/nat.rs59
-rw-r--r--rustables/src/expr/payload.rs221
-rw-r--r--rustables/src/expr/register.rs34
-rw-r--r--rustables/src/expr/reject.rs99
-rw-r--r--rustables/src/expr/verdict.rs160
-rw-r--r--rustables/src/expr/wrapper.rs62
-rw-r--r--rustables/src/rule.rs132
-rw-r--r--rustables/src/set.rs14
-rw-r--r--rustables/src/table.rs13
21 files changed, 1188 insertions, 325 deletions
diff --git a/rustables/src/batch.rs b/rustables/src/batch.rs
index 2af1a7c..3cdd52b 100644
--- a/rustables/src/batch.rs
+++ b/rustables/src/batch.rs
@@ -10,6 +10,7 @@ use thiserror::Error;
#[error("Error while communicating with netlink")]
pub struct NetlinkError(());
+#[cfg(feature = "query")]
/// Check if the kernel supports batched netlink messages to netfilter.
pub fn batch_is_supported() -> std::result::Result<bool, NetlinkError> {
match unsafe { sys::nftnl_batch_is_supported() } {
@@ -22,9 +23,9 @@ pub fn batch_is_supported() -> std::result::Result<bool, NetlinkError> {
/// A batch of netfilter messages to be performed in one atomic operation. Corresponds to
/// `nftnl_batch` in libnftnl.
pub struct Batch {
- batch: *mut sys::nftnl_batch,
- seq: u32,
- is_empty: bool,
+ pub(crate) batch: *mut sys::nftnl_batch,
+ pub(crate) seq: u32,
+ pub(crate) is_empty: bool,
}
impl Batch {
diff --git a/rustables/src/chain.rs b/rustables/src/chain.rs
index a5e732e..ac9c57d 100644
--- a/rustables/src/chain.rs
+++ b/rustables/src/chain.rs
@@ -1,7 +1,8 @@
use crate::{MsgType, Table};
use rustables_sys::{self as sys, libc};
+#[cfg(feature = "query")]
+use std::convert::TryFrom;
use std::{
- convert::TryFrom,
ffi::{c_void, CStr, CString},
fmt,
os::raw::c_char,
@@ -70,8 +71,8 @@ impl ChainType {
/// [`Rule`]: struct.Rule.html
/// [`set_hook`]: #method.set_hook
pub struct Chain {
- chain: *mut sys::nftnl_chain,
- table: Rc<Table>,
+ pub(crate) chain: *mut sys::nftnl_chain,
+ pub(crate) table: Rc<Table>,
}
impl Chain {
@@ -156,7 +157,11 @@ impl Chain {
pub fn get_name(&self) -> &CStr {
unsafe {
let ptr = sys::nftnl_chain_get_str(self.chain, sys::NFTNL_CHAIN_NAME as u16);
- CStr::from_ptr(ptr)
+ if ptr.is_null() {
+ panic!("Impossible situation: retrieving the name of a chain failed")
+ } else {
+ CStr::from_ptr(ptr)
+ }
}
}
diff --git a/rustables/src/expr/bitwise.rs b/rustables/src/expr/bitwise.rs
index 1eb81ab..a5d9343 100644
--- a/rustables/src/expr/bitwise.rs
+++ b/rustables/src/expr/bitwise.rs
@@ -1,5 +1,4 @@
-use super::{Expression, Rule};
-use crate::expr::cmp::ToSlice;
+use super::{Expression, Rule, ToSlice};
use rustables_sys::{self as sys, libc};
use std::ffi::c_void;
use std::os::raw::c_char;
@@ -19,11 +18,13 @@ impl<M: ToSlice, X: ToSlice> Bitwise<M, X> {
}
impl<M: ToSlice, X: ToSlice> Expression for Bitwise<M, X> {
+ fn get_raw_name() -> *const c_char {
+ b"bitwise\0" as *const _ as *const c_char
+ }
+
fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(
- b"bitwise\0" as *const _ as *const c_char
- ));
+ let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name()));
let mask = self.mask.to_slice();
let xor = self.xor.to_slice();
diff --git a/rustables/src/expr/cmp.rs b/rustables/src/expr/cmp.rs
index 5c56492..747974d 100644
--- a/rustables/src/expr/cmp.rs
+++ b/rustables/src/expr/cmp.rs
@@ -1,15 +1,13 @@
-use super::{Expression, Rule};
+use super::{DeserializationError, Expression, Rule, ToSlice};
use rustables_sys::{self as sys, libc};
use std::{
borrow::Cow,
ffi::{c_void, CString},
- net::{IpAddr, Ipv4Addr, Ipv6Addr},
os::raw::c_char,
- slice,
};
/// Comparison operator.
-#[derive(Copy, Clone, Eq, PartialEq)]
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum CmpOp {
/// Equals.
Eq,
@@ -38,10 +36,24 @@ impl CmpOp {
Gte => libc::NFT_CMP_GTE as u32,
}
}
+
+ pub fn from_raw(val: u32) -> Result<Self, DeserializationError> {
+ use self::CmpOp::*;
+ match val as i32 {
+ libc::NFT_CMP_EQ => Ok(Eq),
+ libc::NFT_CMP_NEQ => Ok(Neq),
+ libc::NFT_CMP_LT => Ok(Lt),
+ libc::NFT_CMP_LTE => Ok(Lte),
+ libc::NFT_CMP_GT => Ok(Gt),
+ libc::NFT_CMP_GTE => Ok(Gte),
+ _ => Err(DeserializationError::InvalidValue),
+ }
+ }
}
/// Comparator expression. Allows comparing the content of the netfilter register with any value.
-pub struct Cmp<T: ToSlice> {
+#[derive(Debug, PartialEq)]
+pub struct Cmp<T> {
op: CmpOp,
data: T,
}
@@ -55,9 +67,13 @@ impl<T: ToSlice> Cmp<T> {
}
impl<T: ToSlice> Expression for Cmp<T> {
+ fn get_raw_name() -> *const c_char {
+ b"cmp\0" as *const _ as *const c_char
+ }
+
fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(b"cmp\0" as *const _ as *const c_char));
+ let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name()));
let data = self.data.to_slice();
trace!("Creating a cmp expr comparing with data {:?}", data);
@@ -71,7 +87,7 @@ impl<T: ToSlice> Expression for Cmp<T> {
sys::nftnl_expr_set(
expr,
sys::NFTNL_EXPR_CMP_DATA as u16,
- data.as_ref() as *const _ as *const c_void,
+ data.as_ptr() as *const c_void,
data.len() as u32,
);
@@ -80,6 +96,68 @@ impl<T: ToSlice> Expression for Cmp<T> {
}
}
+impl<const N: usize> Expression for Cmp<[u8; N]> {
+ fn get_raw_name() -> *const c_char {
+ Cmp::<u8>::get_raw_name()
+ }
+
+ /// The raw data contained inside `Cmp` expressions can only be deserialized to
+ /// arrays of bytes, to ensure that the memory layout of retrieved data cannot be
+ /// violated. It is your responsibility to provide the correct length of the byte
+ /// data. If the data size is invalid, you will get the error
+ /// `DeserializationError::InvalidDataSize`.
+ ///
+ /// Example (warning, no error checking!):
+ /// ```rust
+ /// use std::ffi::CString;
+ /// use std::net::Ipv4Addr;
+ /// use std::rc::Rc;
+ ///
+ /// use rustables::{Chain, expr::{Cmp, CmpOp}, ProtoFamily, Rule, Table};
+ ///
+ /// let table = Rc::new(Table::new(&CString::new("mytable").unwrap(), ProtoFamily::Inet));
+ /// let chain = Rc::new(Chain::new(&CString::new("mychain").unwrap(), table));
+ /// let mut rule = Rule::new(chain);
+ /// rule.add_expr(&Cmp::new(CmpOp::Eq, 1337u16));
+ /// for expr in Rc::new(rule).get_exprs() {
+ /// println!("{:?}", expr.decode_expr::<Cmp<[u8; 2]>>().unwrap());
+ /// }
+ /// ```
+ /// These limitations occur because casting bytes to any type of the same size
+ /// as the raw input would be *extremely* dangerous in terms of memory safety.
+ fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> {
+ unsafe {
+ let ref_len = std::mem::size_of::<[u8; N]>() as u32;
+ let mut data_len = 0;
+ let data = sys::nftnl_expr_get(
+ expr,
+ sys::NFTNL_EXPR_CMP_DATA as u16,
+ &mut data_len as *mut u32,
+ );
+
+ if data.is_null() {
+ return Err(DeserializationError::NullPointer);
+ } else if data_len != ref_len {
+ return Err(DeserializationError::InvalidDataSize);
+ }
+
+ let data = *(data as *const [u8; N]);
+
+ let op = CmpOp::from_raw(sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_CMP_OP as u16))?;
+ Ok(Cmp { op, data })
+ }
+ }
+
+ // call to the other implementation to generate the expression
+ fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr {
+ Cmp {
+ data: &self.data as &[u8],
+ op: self.op,
+ }
+ .to_expr(rule)
+ }
+}
+
#[macro_export(local_inner_macros)]
macro_rules! nft_expr_cmp {
(@cmp_op ==) => {
@@ -105,93 +183,6 @@ macro_rules! nft_expr_cmp {
};
}
-/// A type that can be converted into a byte buffer.
-pub trait ToSlice {
- /// Returns the data this type represents.
- fn to_slice(&self) -> Cow<'_, [u8]>;
-}
-
-impl<'a> ToSlice for [u8; 0] {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- Cow::Borrowed(&[])
- }
-}
-
-impl<'a> ToSlice for &'a [u8] {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- Cow::Borrowed(self)
- }
-}
-
-impl<'a> ToSlice for &'a [u16] {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- let ptr = self.as_ptr() as *const u8;
- let len = self.len() * 2;
- Cow::Borrowed(unsafe { slice::from_raw_parts(ptr, len) })
- }
-}
-
-impl ToSlice for IpAddr {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- match *self {
- IpAddr::V4(ref addr) => addr.to_slice(),
- IpAddr::V6(ref addr) => addr.to_slice(),
- }
- }
-}
-
-impl ToSlice for Ipv4Addr {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- Cow::Owned(self.octets().to_vec())
- }
-}
-
-impl ToSlice for Ipv6Addr {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- Cow::Owned(self.octets().to_vec())
- }
-}
-
-impl ToSlice for u8 {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- Cow::Owned(vec![*self])
- }
-}
-
-impl ToSlice for u16 {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- let b0 = (*self & 0x00ff) as u8;
- let b1 = (*self >> 8) as u8;
- Cow::Owned(vec![b0, b1])
- }
-}
-
-impl ToSlice for u32 {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- let b0 = *self as u8;
- let b1 = (*self >> 8) as u8;
- let b2 = (*self >> 16) as u8;
- let b3 = (*self >> 24) as u8;
- Cow::Owned(vec![b0, b1, b2, b3])
- }
-}
-
-impl ToSlice for i32 {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- let b0 = *self as u8;
- let b1 = (*self >> 8) as u8;
- let b2 = (*self >> 16) as u8;
- let b3 = (*self >> 24) as u8;
- Cow::Owned(vec![b0, b1, b2, b3])
- }
-}
-
-impl<'a> ToSlice for &'a str {
- fn to_slice(&self) -> Cow<'_, [u8]> {
- Cow::from(self.as_bytes())
- }
-}
-
/// Can be used to compare the value loaded by [`Meta::IifName`] and [`Meta::OifName`]. Please
/// note that it is faster to check interface index than name.
///
diff --git a/rustables/src/expr/counter.rs b/rustables/src/expr/counter.rs
index c2a0b5d..099e7fa 100644
--- a/rustables/src/expr/counter.rs
+++ b/rustables/src/expr/counter.rs
@@ -1,13 +1,46 @@
-use super::{Expression, Rule};
+use super::{DeserializationError, Expression, Rule};
use rustables_sys as sys;
use std::os::raw::c_char;
/// A counter expression adds a counter to the rule that is incremented to count number of packets
/// and number of bytes for all packets that has matched the rule.
-pub struct Counter;
+#[derive(Debug, PartialEq)]
+pub struct Counter {
+ pub nb_bytes: u64,
+ pub nb_packets: u64,
+}
+
+impl Counter {
+ pub fn new() -> Self {
+ Self {
+ nb_bytes: 0,
+ nb_packets: 0,
+ }
+ }
+}
impl Expression for Counter {
+ fn get_raw_name() -> *const c_char {
+ b"counter\0" as *const _ as *const c_char
+ }
+
+ fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> {
+ unsafe {
+ let nb_bytes = sys::nftnl_expr_get_u64(expr, sys::NFTNL_EXPR_CTR_BYTES as u16);
+ let nb_packets = sys::nftnl_expr_get_u64(expr, sys::NFTNL_EXPR_CTR_PACKETS as u16);
+ Ok(Counter {
+ nb_bytes,
+ nb_packets,
+ })
+ }
+ }
+
fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
- try_alloc!(unsafe { sys::nftnl_expr_alloc(b"counter\0" as *const _ as *const c_char) })
+ unsafe {
+ let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name()));
+ sys::nftnl_expr_set_u64(expr, sys::NFTNL_EXPR_CTR_BYTES as u16, self.nb_bytes);
+ sys::nftnl_expr_set_u64(expr, sys::NFTNL_EXPR_CTR_PACKETS as u16, self.nb_packets);
+ expr
+ }
}
}
diff --git a/rustables/src/expr/ct.rs b/rustables/src/expr/ct.rs
index c0349ab..001aef8 100644
--- a/rustables/src/expr/ct.rs
+++ b/rustables/src/expr/ct.rs
@@ -1,4 +1,4 @@
-use super::{Expression, Rule};
+use super::{DeserializationError, Expression, Rule};
use rustables_sys::{self as sys, libc};
use std::os::raw::c_char;
@@ -27,9 +27,31 @@ impl Conntrack {
}
impl Expression for Conntrack {
+ fn get_raw_name() -> *const c_char {
+ b"ct\0" as *const _ as *const c_char
+ }
+
+ fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError>
+ where
+ Self: Sized,
+ {
+ unsafe {
+ let ct_key = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_CT_KEY as u16);
+ let ct_sreg_is_set = sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_CT_SREG as u16);
+
+ match ct_key as i32 {
+ libc::NFT_CT_STATE => Ok(Conntrack::State),
+ libc::NFT_CT_MARK => Ok(Conntrack::Mark {
+ set: ct_sreg_is_set,
+ }),
+ _ => Err(DeserializationError::InvalidValue),
+ }
+ }
+ }
+
fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(b"ct\0" as *const _ as *const c_char));
+ let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name()));
if let Conntrack::Mark { set: true } = self {
sys::nftnl_expr_set_u32(
diff --git a/rustables/src/expr/immediate.rs b/rustables/src/expr/immediate.rs
index e5ccc2a..ff4ad04 100644
--- a/rustables/src/expr/immediate.rs
+++ b/rustables/src/expr/immediate.rs
@@ -1,11 +1,10 @@
-use super::{Expression, Register, Rule};
+use super::{DeserializationError, Expression, Register, Rule, ToSlice};
use rustables_sys as sys;
use std::ffi::c_void;
-use std::mem::size_of_val;
use std::os::raw::c_char;
/// An immediate expression. Used to set immediate data.
-/// Verdicts are handled separately by [Verdict].
+/// Verdicts are handled separately by [crate::expr::Verdict].
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Immediate<T> {
pub data: T,
@@ -18,12 +17,14 @@ impl<T> Immediate<T> {
}
}
-impl<T> Expression for Immediate<T> {
+impl<T: ToSlice> Expression for Immediate<T> {
+ fn get_raw_name() -> *const c_char {
+ b"immediate\0" as *const _ as *const c_char
+ }
+
fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(
- b"immediate\0" as *const _ as *const c_char
- ));
+ let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name()));
sys::nftnl_expr_set_u32(
expr,
@@ -31,11 +32,12 @@ impl<T> Expression for Immediate<T> {
self.register.to_raw(),
);
+ let data = self.data.to_slice();
sys::nftnl_expr_set(
expr,
sys::NFTNL_EXPR_IMM_DATA as u16,
- &self.data as *const _ as *const c_void,
- size_of_val(&self.data) as u32,
+ data.as_ptr() as *const c_void,
+ data.len() as u32,
);
expr
@@ -43,6 +45,76 @@ impl<T> Expression for Immediate<T> {
}
}
+impl<const N: usize> Expression for Immediate<[u8; N]> {
+ fn get_raw_name() -> *const c_char {
+ Immediate::<u8>::get_raw_name()
+ }
+
+ /// The raw data contained inside `Immediate` expressions can only be deserialized to
+ /// arrays of bytes, to ensure that the memory layout of retrieved data cannot be
+ /// violated. It is your responsibility to provide the correct length of the byte
+ /// data. If the data size is invalid, you will get the error
+ /// `DeserializationError::InvalidDataSize`.
+ ///
+ /// Example (warning, no error checking!):
+ /// ```rust
+ /// use std::ffi::CString;
+ /// use std::net::Ipv4Addr;
+ /// use std::rc::Rc;
+ ///
+ /// use rustables::{Chain, expr::{Immediate, Register}, ProtoFamily, Rule, Table};
+ ///
+ /// let table = Rc::new(Table::new(&CString::new("mytable").unwrap(), ProtoFamily::Inet));
+ /// let chain = Rc::new(Chain::new(&CString::new("mychain").unwrap(), table));
+ /// let mut rule = Rule::new(chain);
+ /// rule.add_expr(&Immediate::new(42u8, Register::Reg1));
+ /// for expr in Rc::new(rule).get_exprs() {
+ /// println!("{:?}", expr.decode_expr::<Immediate<[u8; 1]>>().unwrap());
+ /// }
+ /// ```
+ /// These limitations occur because casting bytes to any type of the same size
+ /// as the raw input would be *extremely* dangerous in terms of memory safety.
+ // As casting bytes to any type of the same size as the input would
+ // be *extremely* dangerous in terms of memory safety,
+ // rustables only accept to deserialize expressions with variable-size data
+ // to arrays of bytes, so that the memory layout cannot be invalid.
+ fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> {
+ unsafe {
+ let ref_len = std::mem::size_of::<[u8; N]>() as u32;
+ let mut data_len = 0;
+ let data = sys::nftnl_expr_get(
+ expr,
+ sys::NFTNL_EXPR_IMM_DATA as u16,
+ &mut data_len as *mut u32,
+ );
+
+ if data.is_null() {
+ return Err(DeserializationError::NullPointer);
+ } else if data_len != ref_len {
+ return Err(DeserializationError::InvalidDataSize);
+ }
+
+ let data = *(data as *const [u8; N]);
+
+ let register = Register::from_raw(sys::nftnl_expr_get_u32(
+ expr,
+ sys::NFTNL_EXPR_IMM_DREG as u16,
+ ))?;
+
+ Ok(Immediate { data, register })
+ }
+ }
+
+ // call to the other implementation to generate the expression
+ fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr {
+ Immediate {
+ register: self.register,
+ data: &self.data as &[u8],
+ }
+ .to_expr(rule)
+ }
+}
+
#[macro_export]
macro_rules! nft_expr_immediate {
(data $value:expr) => {
diff --git a/rustables/src/expr/log.rs b/rustables/src/expr/log.rs
index aa7a8b7..db96ba9 100644
--- a/rustables/src/expr/log.rs
+++ b/rustables/src/expr/log.rs
@@ -1,34 +1,54 @@
-use super::{Expression, Rule};
+use super::{DeserializationError, Expression, Rule};
use rustables_sys as sys;
+use std::ffi::{CStr, CString};
use std::os::raw::c_char;
-use std::ffi::CString;
use thiserror::Error;
/// A Log expression will log all packets that match the rule.
+#[derive(Debug, PartialEq)]
pub struct Log {
pub group: Option<LogGroup>,
- pub prefix: Option<LogPrefix>
+ pub prefix: Option<LogPrefix>,
}
impl Expression for Log {
- fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
+ fn get_raw_name() -> *const sys::libc::c_char {
+ b"log\0" as *const _ as *const c_char
+ }
+
+ fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError>
+ where
+ Self: Sized,
+ {
unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(
- b"log\0" as *const _ as *const c_char
- ));
- if let Some(log_group) = self.group {
- sys::nftnl_expr_set_u32(
+ let mut group = None;
+ if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_LOG_GROUP as u16) {
+ group = Some(LogGroup(sys::nftnl_expr_get_u32(
expr,
sys::NFTNL_EXPR_LOG_GROUP as u16,
- log_group.0 as u32,
- );
+ ) as u16));
+ }
+ let mut prefix = None;
+ if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16) {
+ let raw_prefix = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16);
+ if raw_prefix.is_null() {
+ return Err(DeserializationError::NullPointer);
+ } else {
+ prefix = Some(LogPrefix(CStr::from_ptr(raw_prefix).to_owned()));
+ }
+ }
+ Ok(Log { group, prefix })
+ }
+ }
+
+ fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
+ unsafe {
+ let expr = try_alloc!(sys::nftnl_expr_alloc(b"log\0" as *const _ as *const c_char));
+ if let Some(log_group) = self.group {
+ sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOG_GROUP as u16, log_group.0 as u32);
};
if let Some(LogPrefix(prefix)) = &self.prefix {
- sys::nftnl_expr_set_str(
- expr,
- sys::NFTNL_EXPR_LOG_PREFIX as u16,
- prefix.as_ptr()
- );
+ sys::nftnl_expr_set_str(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16, prefix.as_ptr());
};
expr
@@ -41,8 +61,7 @@ pub enum LogPrefixError {
#[error("The log prefix string is more than 128 characters long")]
TooLongPrefix,
#[error("The log prefix string contains an invalid Nul character.")]
- PrefixContainsANul(#[from] std::ffi::NulError)
-
+ PrefixContainsANul(#[from] std::ffi::NulError),
}
/// The NFLOG group that will be assigned to each log line.
@@ -58,25 +77,36 @@ impl LogPrefix {
/// that LogPrefix should not be more than 127 characters long.
pub fn new(prefix: &str) -> Result<Self, LogPrefixError> {
if prefix.chars().count() > 127 {
- return Err(LogPrefixError::TooLongPrefix)
+ return Err(LogPrefixError::TooLongPrefix);
}
Ok(LogPrefix(CString::new(prefix)?))
}
}
-
#[macro_export]
macro_rules! nft_expr_log {
(group $group:ident prefix $prefix:expr) => {
- $crate::expr::Log { group: $group, prefix: $prefix }
+ $crate::expr::Log {
+ group: $group,
+ prefix: $prefix,
+ }
};
(prefix $prefix:expr) => {
- $crate::expr::Log { group: None, prefix: $prefix }
+ $crate::expr::Log {
+ group: None,
+ prefix: $prefix,
+ }
};
(group $group:ident) => {
- $crate::expr::Log { group: $group, prefix: None }
+ $crate::expr::Log {
+ group: $group,
+ prefix: None,
+ }
};
() => {
- $crate::expr::Log { group: None, prefix: None }
+ $crate::expr::Log {
+ group: None,
+ prefix: None,
+ }
};
}
diff --git a/rustables/src/expr/lookup.rs b/rustables/src/expr/lookup.rs
index ac22440..7796b29 100644
--- a/rustables/src/expr/lookup.rs
+++ b/rustables/src/expr/lookup.rs
@@ -1,29 +1,52 @@
-use super::{Expression, Rule};
+use super::{DeserializationError, Expression, Rule};
use crate::set::Set;
use rustables_sys::{self as sys, libc};
-use std::ffi::CString;
+use std::ffi::{CStr, CString};
use std::os::raw::c_char;
+#[derive(Debug, PartialEq)]
pub struct Lookup {
set_name: CString,
set_id: u32,
}
impl Lookup {
- pub fn new<K>(set: &Set<'_, K>) -> Self {
- Lookup {
- set_name: set.get_name().to_owned(),
+ /// Creates a new lookup entry.
+ /// May return None if the set have no name.
+ pub fn new<K>(set: &Set<'_, K>) -> Option<Self> {
+ set.get_name().map(|set_name| Lookup {
+ set_name: set_name.to_owned(),
set_id: set.get_id(),
- }
+ })
}
}
impl Expression for Lookup {
+ fn get_raw_name() -> *const libc::c_char {
+ b"lookup\0" as *const _ as *const c_char
+ }
+
+ fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError>
+ where
+ Self: Sized,
+ {
+ unsafe {
+ let set_name = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_LOOKUP_SET as u16);
+ let set_id = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_LOOKUP_SET_ID as u16);
+
+ if set_name.is_null() {
+ return Err(DeserializationError::NullPointer);
+ }
+
+ let set_name = CStr::from_ptr(set_name).to_owned();
+
+ Ok(Lookup { set_id, set_name })
+ }
+ }
+
fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(
- b"lookup\0" as *const _ as *const c_char
- ));
+ let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name()));
sys::nftnl_expr_set_u32(
expr,
diff --git a/rustables/src/expr/masquerade.rs b/rustables/src/expr/masquerade.rs
index 66e9e0e..40565d5 100644
--- a/rustables/src/expr/masquerade.rs
+++ b/rustables/src/expr/masquerade.rs
@@ -1,12 +1,24 @@
-use super::{Expression, Rule};
+use super::{DeserializationError, Expression, Rule};
use rustables_sys as sys;
use std::os::raw::c_char;
/// Sets the source IP to that of the output interface.
+#[derive(Debug, PartialEq)]
pub struct Masquerade;
impl Expression for Masquerade {
+ fn get_raw_name() -> *const sys::libc::c_char {
+ b"masq\0" as *const _ as *const c_char
+ }
+
+ fn from_expr(_expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError>
+ where
+ Self: Sized,
+ {
+ Ok(Masquerade)
+ }
+
fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
- try_alloc!(unsafe { sys::nftnl_expr_alloc(b"masq\0" as *const _ as *const c_char) })
+ try_alloc!(unsafe { sys::nftnl_expr_alloc(Self::get_raw_name()) })
}
}
diff --git a/rustables/src/expr/meta.rs b/rustables/src/expr/meta.rs
index a91cb27..199f3d3 100644
--- a/rustables/src/expr/meta.rs
+++ b/rustables/src/expr/meta.rs
@@ -1,8 +1,9 @@
-use super::{Expression, Rule};
+use super::{DeserializationError, Expression, Rule};
use rustables_sys::{self as sys, libc};
use std::os::raw::c_char;
/// A meta expression refers to meta data associated with a packet.
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum Meta {
/// Packet ethertype protocol (skb->protocol), invalid in OUTPUT.
@@ -56,14 +57,54 @@ impl Meta {
PRandom => libc::NFT_META_PRANDOM as u32,
}
}
+
+ fn from_raw(val: u32) -> Result<Self, DeserializationError> {
+ match val as i32 {
+ libc::NFT_META_PROTOCOL => Ok(Self::Protocol),
+ libc::NFT_META_MARK => Ok(Self::Mark { set: false }),
+ libc::NFT_META_IIF => Ok(Self::Iif),
+ libc::NFT_META_OIF => Ok(Self::Oif),
+ libc::NFT_META_IIFNAME => Ok(Self::IifName),
+ libc::NFT_META_OIFNAME => Ok(Self::OifName),
+ libc::NFT_META_IIFTYPE => Ok(Self::IifType),
+ libc::NFT_META_OIFTYPE => Ok(Self::OifType),
+ libc::NFT_META_SKUID => Ok(Self::SkUid),
+ libc::NFT_META_SKGID => Ok(Self::SkGid),
+ libc::NFT_META_NFPROTO => Ok(Self::NfProto),
+ libc::NFT_META_L4PROTO => Ok(Self::L4Proto),
+ libc::NFT_META_CGROUP => Ok(Self::Cgroup),
+ libc::NFT_META_PRANDOM => Ok(Self::PRandom),
+ _ => Err(DeserializationError::InvalidValue),
+ }
+ }
}
impl Expression for Meta {
+ fn get_raw_name() -> *const libc::c_char {
+ b"meta\0" as *const _ as *const c_char
+ }
+
+ fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError>
+ where
+ Self: Sized,
+ {
+ unsafe {
+ let mut ret = Self::from_raw(sys::nftnl_expr_get_u32(
+ expr,
+ sys::NFTNL_EXPR_META_KEY as u16,
+ ))?;
+
+ if let Self::Mark { ref mut set } = ret {
+ *set = sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_META_SREG as u16);
+ }
+
+ Ok(ret)
+ }
+ }
+
fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(
- b"meta\0" as *const _ as *const c_char
- ));
+ let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name()));
if let Meta::Mark { set: true } = self {
sys::nftnl_expr_set_u32(
diff --git a/rustables/src/expr/mod.rs b/rustables/src/expr/mod.rs
index 99ea44b..b20a752 100644
--- a/rustables/src/expr/mod.rs
+++ b/rustables/src/expr/mod.rs
@@ -3,32 +3,14 @@
//!
//! [`Rule`]: struct.Rule.html
+use std::borrow::Cow;
+use std::net::IpAddr;
+use std::net::Ipv4Addr;
+use std::net::Ipv6Addr;
+
use super::rule::Rule;
use rustables_sys::{self as sys, libc};
-
-/// Trait for every safe wrapper of an nftables expression.
-pub trait Expression {
- /// Allocates and returns the low level `nftnl_expr` representation of this expression.
- /// The caller to this method is responsible for freeing the expression.
- fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr;
-}
-
-/// A netfilter data register. The expressions store and read data to and from these
-/// when evaluating rule statements.
-#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
-#[repr(i32)]
-pub enum Register {
- Reg1 = libc::NFT_REG_1,
- Reg2 = libc::NFT_REG_2,
- Reg3 = libc::NFT_REG_3,
- Reg4 = libc::NFT_REG_4,
-}
-
-impl Register {
- pub fn to_raw(self) -> u32 {
- self as u32
- }
-}
+use thiserror::Error;
mod bitwise;
pub use self::bitwise::*;
@@ -63,9 +45,147 @@ pub use self::nat::*;
mod payload;
pub use self::payload::*;
+mod reject;
+pub use self::reject::{IcmpCode, Reject};
+
+mod register;
+pub use self::register::Register;
+
mod verdict;
pub use self::verdict::*;
+mod wrapper;
+pub use self::wrapper::ExpressionWrapper;
+
+#[derive(Debug, Error)]
+pub enum DeserializationError {
+ #[error("The expected expression type doesn't match the name of the raw expression")]
+ /// The expected expression type doesn't match the name of the raw expression
+ InvalidExpressionKind,
+
+ #[error("Deserializing the requested type isn't implemented yet")]
+ /// Deserializing the requested type isn't implemented yet
+ NotImplemented,
+
+ #[error("The expression value cannot be deserialized to the requested type")]
+ /// The expression value cannot be deserialized to the requested type
+ InvalidValue,
+
+ #[error("A pointer was null while a non-null pointer was expected")]
+ /// A pointer was null while a non-null pointer was expected
+ NullPointer,
+
+ #[error(
+ "The size of a raw value was incoherent with the expected type of the deserialized value"
+ )]
+ /// The size of a raw value was incoherent with the expected type of the deserialized value
+ InvalidDataSize,
+
+ #[error(transparent)]
+ /// Couldn't find a matching protocol
+ InvalidProtolFamily(#[from] super::InvalidProtocolFamily),
+}
+
+/// Trait for every safe wrapper of an nftables expression.
+pub trait Expression {
+ /// Returns the raw name used by nftables to identify the rule.
+ fn get_raw_name() -> *const libc::c_char;
+
+ /// Try to parse the expression from a raw nftables expression,
+ /// returning a [DeserializationError] if the attempted parsing failed.
+ fn from_expr(_expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError>
+ where
+ Self: Sized,
+ {
+ Err(DeserializationError::NotImplemented)
+ }
+
+ /// Allocates and returns the low level `nftnl_expr` representation of this expression.
+ /// The caller to this method is responsible for freeing the expression.
+ fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr;
+}
+
+/// A type that can be converted into a byte buffer.
+pub trait ToSlice {
+ /// Returns the data this type represents.
+ fn to_slice(&self) -> Cow<'_, [u8]>;
+}
+
+impl<'a> ToSlice for &'a [u8] {
+ fn to_slice(&self) -> Cow<'_, [u8]> {
+ Cow::Borrowed(self)
+ }
+}
+
+impl<'a> ToSlice for &'a [u16] {
+ fn to_slice(&self) -> Cow<'_, [u8]> {
+ let ptr = self.as_ptr() as *const u8;
+ let len = self.len() * 2;
+ Cow::Borrowed(unsafe { std::slice::from_raw_parts(ptr, len) })
+ }
+}
+
+impl ToSlice for IpAddr {
+ fn to_slice(&self) -> Cow<'_, [u8]> {
+ match *self {
+ IpAddr::V4(ref addr) => addr.to_slice(),
+ IpAddr::V6(ref addr) => addr.to_slice(),
+ }
+ }
+}
+
+impl ToSlice for Ipv4Addr {
+ fn to_slice(&self) -> Cow<'_, [u8]> {
+ Cow::Owned(self.octets().to_vec())
+ }
+}
+
+impl ToSlice for Ipv6Addr {
+ fn to_slice(&self) -> Cow<'_, [u8]> {
+ Cow::Owned(self.octets().to_vec())
+ }
+}
+
+impl ToSlice for u8 {
+ fn to_slice(&self) -> Cow<'_, [u8]> {
+ Cow::Owned(vec![*self])
+ }
+}
+
+impl ToSlice for u16 {
+ fn to_slice(&self) -> Cow<'_, [u8]> {
+ let b0 = (*self & 0x00ff) as u8;
+ let b1 = (*self >> 8) as u8;
+ Cow::Owned(vec![b0, b1])
+ }
+}
+
+impl ToSlice for u32 {
+ fn to_slice(&self) -> Cow<'_, [u8]> {
+ let b0 = *self as u8;
+ let b1 = (*self >> 8) as u8;
+ let b2 = (*self >> 16) as u8;
+ let b3 = (*self >> 24) as u8;
+ Cow::Owned(vec![b0, b1, b2, b3])
+ }
+}
+
+impl ToSlice for i32 {
+ fn to_slice(&self) -> Cow<'_, [u8]> {
+ let b0 = *self as u8;
+ let b1 = (*self >> 8) as u8;
+ let b2 = (*self >> 16) as u8;
+ let b3 = (*self >> 24) as u8;
+ Cow::Owned(vec![b0, b1, b2, b3])
+ }
+}
+
+impl<'a> ToSlice for &'a str {
+ fn to_slice(&self) -> Cow<'_, [u8]> {
+ Cow::from(self.as_bytes())
+ }
+}
+
#[macro_export(local_inner_macros)]
macro_rules! nft_expr {
(bitwise mask $mask:expr,xor $xor:expr) => {
@@ -75,7 +195,7 @@ macro_rules! nft_expr {
nft_expr_cmp!($op $data)
};
(counter) => {
- $crate::expr::Counter
+ $crate::expr::Counter { nb_bytes: 0, nb_packets: 0}
};
(ct $key:ident set) => {
nft_expr_ct!($key set)
diff --git a/rustables/src/expr/nat.rs b/rustables/src/expr/nat.rs
index d60e5ea..51f439f 100644
--- a/rustables/src/expr/nat.rs
+++ b/rustables/src/expr/nat.rs
@@ -1,7 +1,7 @@
-use super::{Expression, Register, Rule};
+use super::{DeserializationError, Expression, Register, Rule};
use crate::ProtoFamily;
use rustables_sys::{self as sys, libc};
-use std::os::raw::c_char;
+use std::{convert::TryFrom, os::raw::c_char};
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
#[repr(i32)]
@@ -12,8 +12,19 @@ pub enum NatType {
DNat = libc::NFT_NAT_DNAT,
}
+impl NatType {
+ fn from_raw(val: u32) -> Result<Self, DeserializationError> {
+ match val as i32 {
+ libc::NFT_NAT_SNAT => Ok(NatType::SNat),
+ libc::NFT_NAT_DNAT => Ok(NatType::DNat),
+ _ => Err(DeserializationError::InvalidValue),
+ }
+ }
+}
+
/// A source or destination NAT statement. Modifies the source or destination address
/// (and possibly port) of packets.
+#[derive(Debug, PartialEq)]
pub struct Nat {
pub nat_type: NatType,
pub family: ProtoFamily,
@@ -22,9 +33,49 @@ pub struct Nat {
}
impl Expression for Nat {
+ fn get_raw_name() -> *const libc::c_char {
+ b"nat\0" as *const _ as *const c_char
+ }
+
+ fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError>
+ where
+ Self: Sized,
+ {
+ unsafe {
+ let nat_type = NatType::from_raw(sys::nftnl_expr_get_u32(
+ expr,
+ sys::NFTNL_EXPR_NAT_TYPE as u16,
+ ))?;
+
+ let family = ProtoFamily::try_from(sys::nftnl_expr_get_u32(
+ expr,
+ sys::NFTNL_EXPR_NAT_FAMILY as u16,
+ ) as i32)?;
+
+ let ip_register = Register::from_raw(sys::nftnl_expr_get_u32(
+ expr,
+ sys::NFTNL_EXPR_NAT_REG_ADDR_MIN as u16,
+ ))?;
+
+ let mut port_register = None;
+ if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_NAT_REG_PROTO_MIN as u16) {
+ port_register = Some(Register::from_raw(sys::nftnl_expr_get_u32(
+ expr,
+ sys::NFTNL_EXPR_NAT_REG_PROTO_MIN as u16,
+ ))?);
+ }
+
+ Ok(Nat {
+ ip_register,
+ nat_type,
+ family,
+ port_register,
+ })
+ }
+ }
+
fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
- let expr =
- try_alloc!(unsafe { sys::nftnl_expr_alloc(b"nat\0" as *const _ as *const c_char) });
+ let expr = try_alloc!(unsafe { sys::nftnl_expr_alloc(Self::get_raw_name()) });
unsafe {
sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_NAT_TYPE as u16, self.nat_type as u32);
diff --git a/rustables/src/expr/payload.rs b/rustables/src/expr/payload.rs
index 2da4e1f..25a71ad 100644
--- a/rustables/src/expr/payload.rs
+++ b/rustables/src/expr/payload.rs
@@ -1,4 +1,4 @@
-use super::{Expression, Rule};
+use super::{DeserializationError, Expression, Rule};
use rustables_sys::{self as sys, libc};
use std::os::raw::c_char;
@@ -8,7 +8,7 @@ trait HeaderField {
}
/// Payload expressions refer to data from the packet's payload.
-#[derive(Copy, Clone, Eq, PartialEq)]
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Payload {
LinkLayer(LLHeaderField),
Network(NetworkHeaderField),
@@ -16,41 +16,100 @@ pub enum Payload {
}
impl Payload {
- fn base(self) -> u32 {
+ pub fn build(&self) -> RawPayload {
+ match *self {
+ Payload::LinkLayer(ref f) => RawPayload::LinkLayer(RawPayloadData {
+ offset: f.offset(),
+ len: f.len(),
+ }),
+ Payload::Network(ref f) => RawPayload::Network(RawPayloadData {
+ offset: f.offset(),
+ len: f.len(),
+ }),
+ Payload::Transport(ref f) => RawPayload::Transport(RawPayloadData {
+ offset: f.offset(),
+ len: f.offset(),
+ }),
+ }
+ }
+}
+
+impl Expression for Payload {
+ fn get_raw_name() -> *const libc::c_char {
+ RawPayload::get_raw_name()
+ }
+
+ fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr {
+ self.build().to_expr(rule)
+ }
+}
+
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
+pub struct RawPayloadData {
+ offset: u32,
+ len: u32,
+}
+
+/// Because deserializing a `Payload` expression is not possible (there is not enough information
+/// in the expression itself, this enum should be used to deserialize payloads.
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
+pub enum RawPayload {
+ LinkLayer(RawPayloadData),
+ Network(RawPayloadData),
+ Transport(RawPayloadData),
+}
+
+impl RawPayload {
+ fn base(&self) -> u32 {
match self {
- Payload::LinkLayer(_) => libc::NFT_PAYLOAD_LL_HEADER as u32,
- Payload::Network(_) => libc::NFT_PAYLOAD_NETWORK_HEADER as u32,
- Payload::Transport(_) => libc::NFT_PAYLOAD_TRANSPORT_HEADER as u32,
+ Self::LinkLayer(_) => libc::NFT_PAYLOAD_LL_HEADER as u32,
+ Self::Network(_) => libc::NFT_PAYLOAD_NETWORK_HEADER as u32,
+ Self::Transport(_) => libc::NFT_PAYLOAD_TRANSPORT_HEADER as u32,
}
}
}
-impl HeaderField for Payload {
+impl HeaderField for RawPayload {
fn offset(&self) -> u32 {
- use self::Payload::*;
- match *self {
- LinkLayer(ref f) => f.offset(),
- Network(ref f) => f.offset(),
- Transport(ref f) => f.offset(),
+ match self {
+ Self::LinkLayer(ref f) | Self::Network(ref f) | Self::Transport(ref f) => f.offset,
}
}
fn len(&self) -> u32 {
- use self::Payload::*;
- match *self {
- LinkLayer(ref f) => f.len(),
- Network(ref f) => f.len(),
- Transport(ref f) => f.len(),
+ match self {
+ Self::LinkLayer(ref f) | Self::Network(ref f) | Self::Transport(ref f) => f.len,
}
}
}
-impl Expression for Payload {
+impl Expression for RawPayload {
+ fn get_raw_name() -> *const libc::c_char {
+ b"payload\0" as *const _ as *const c_char
+ }
+
+ fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> {
+ unsafe {
+ let base = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_BASE as u16);
+ let offset = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_OFFSET as u16);
+ let len = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_LEN as u16);
+ match base as i32 {
+ libc::NFT_PAYLOAD_LL_HEADER => Ok(Self::LinkLayer(RawPayloadData { offset, len })),
+ libc::NFT_PAYLOAD_NETWORK_HEADER => {
+ Ok(Self::Network(RawPayloadData { offset, len }))
+ }
+ libc::NFT_PAYLOAD_TRANSPORT_HEADER => {
+ Ok(Self::Transport(RawPayloadData { offset, len }))
+ }
+
+ _ => return Err(DeserializationError::InvalidValue),
+ }
+ }
+ }
+
fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
unsafe {
- let expr = try_alloc!(sys::nftnl_expr_alloc(
- b"payload\0" as *const _ as *const c_char
- ));
+ let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name()));
sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_PAYLOAD_BASE as u16, self.base());
sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_PAYLOAD_OFFSET as u16, self.offset());
@@ -66,7 +125,7 @@ impl Expression for Payload {
}
}
-#[derive(Copy, Clone, Eq, PartialEq)]
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum LLHeaderField {
Daddr,
@@ -94,7 +153,24 @@ impl HeaderField for LLHeaderField {
}
}
-#[derive(Copy, Clone, Eq, PartialEq)]
+impl LLHeaderField {
+ pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> {
+ let off = data.offset;
+ let len = data.len;
+
+ if off == 0 && len == 6 {
+ Ok(Self::Daddr)
+ } else if off == 6 && len == 6 {
+ Ok(Self::Saddr)
+ } else if off == 12 && len == 2 {
+ Ok(Self::EtherType)
+ } else {
+ Err(DeserializationError::InvalidValue)
+ }
+ }
+}
+
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum NetworkHeaderField {
Ipv4(Ipv4HeaderField),
Ipv6(Ipv6HeaderField),
@@ -118,7 +194,7 @@ impl HeaderField for NetworkHeaderField {
}
}
-#[derive(Copy, Clone, Eq, PartialEq)]
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum Ipv4HeaderField {
Ttl,
@@ -149,7 +225,26 @@ impl HeaderField for Ipv4HeaderField {
}
}
-#[derive(Copy, Clone, Eq, PartialEq)]
+impl Ipv4HeaderField {
+ pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> {
+ let off = data.offset;
+ let len = data.len;
+
+ if off == 8 && len == 1 {
+ Ok(Self::Ttl)
+ } else if off == 9 && len == 1 {
+ Ok(Self::Protocol)
+ } else if off == 12 && len == 4 {
+ Ok(Self::Saddr)
+ } else if off == 16 && len == 4 {
+ Ok(Self::Daddr)
+ } else {
+ Err(DeserializationError::InvalidValue)
+ }
+ }
+}
+
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum Ipv6HeaderField {
NextHeader,
@@ -180,7 +275,26 @@ impl HeaderField for Ipv6HeaderField {
}
}
-#[derive(Copy, Clone, Eq, PartialEq)]
+impl Ipv6HeaderField {
+ pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> {
+ let off = data.offset;
+ let len = data.len;
+
+ if off == 6 && len == 1 {
+ Ok(Self::NextHeader)
+ } else if off == 7 && len == 1 {
+ Ok(Self::HopLimit)
+ } else if off == 8 && len == 16 {
+ Ok(Self::Saddr)
+ } else if off == 24 && len == 16 {
+ Ok(Self::Daddr)
+ } else {
+ Err(DeserializationError::InvalidValue)
+ }
+ }
+}
+
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum TransportHeaderField {
Tcp(TcpHeaderField),
@@ -208,7 +322,7 @@ impl HeaderField for TransportHeaderField {
}
}
-#[derive(Copy, Clone, Eq, PartialEq)]
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum TcpHeaderField {
Sport,
@@ -233,7 +347,22 @@ impl HeaderField for TcpHeaderField {
}
}
-#[derive(Copy, Clone, Eq, PartialEq)]
+impl TcpHeaderField {
+ pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> {
+ let off = data.offset;
+ let len = data.len;
+
+ if off == 0 && len == 2 {
+ Ok(Self::Sport)
+ } else if off == 2 && len == 2 {
+ Ok(Self::Dport)
+ } else {
+ Err(DeserializationError::InvalidValue)
+ }
+ }
+}
+
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum UdpHeaderField {
Sport,
@@ -261,7 +390,24 @@ impl HeaderField for UdpHeaderField {
}
}
-#[derive(Copy, Clone, Eq, PartialEq)]
+impl UdpHeaderField {
+ pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> {
+ let off = data.offset;
+ let len = data.len;
+
+ if off == 0 && len == 2 {
+ Ok(Self::Sport)
+ } else if off == 2 && len == 2 {
+ Ok(Self::Dport)
+ } else if off == 4 && len == 2 {
+ Ok(Self::Len)
+ } else {
+ Err(DeserializationError::InvalidValue)
+ }
+ }
+}
+
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum Icmpv6HeaderField {
Type,
@@ -289,6 +435,23 @@ impl HeaderField for Icmpv6HeaderField {
}
}
+impl Icmpv6HeaderField {
+ pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> {
+ let off = data.offset;
+ let len = data.len;
+
+ if off == 0 && len == 1 {
+ Ok(Self::Type)
+ } else if off == 1 && len == 1 {
+ Ok(Self::Code)
+ } else if off == 2 && len == 2 {
+ Ok(Self::Checksum)
+ } else {
+ Err(DeserializationError::InvalidValue)
+ }
+ }
+}
+
#[macro_export(local_inner_macros)]
macro_rules! nft_expr_payload {
(@ipv4_field ttl) => {
diff --git a/rustables/src/expr/register.rs b/rustables/src/expr/register.rs
new file mode 100644
index 0000000..2cfcc3b
--- /dev/null
+++ b/rustables/src/expr/register.rs
@@ -0,0 +1,34 @@
+use std::fmt::Debug;
+
+use rustables_sys::libc;
+
+use super::DeserializationError;
+
+/// A netfilter data register. The expressions store and read data to and from these
+/// when evaluating rule statements.
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
+#[repr(i32)]
+pub enum Register {
+ Verdict = libc::NFT_REG_VERDICT,
+ Reg1 = libc::NFT_REG_1,
+ Reg2 = libc::NFT_REG_2,
+ Reg3 = libc::NFT_REG_3,
+ Reg4 = libc::NFT_REG_4,
+}
+
+impl Register {
+ pub fn to_raw(self) -> u32 {
+ self as u32
+ }
+
+ pub fn from_raw(val: u32) -> Result<Self, DeserializationError> {
+ match val as i32 {
+ libc::NFT_REG_VERDICT => Ok(Self::Verdict),
+ libc::NFT_REG_1 => Ok(Self::Reg1),
+ libc::NFT_REG_2 => Ok(Self::Reg2),
+ libc::NFT_REG_3 => Ok(Self::Reg3),
+ libc::NFT_REG_4 => Ok(Self::Reg4),
+ _ => Err(DeserializationError::InvalidValue),
+ }
+ }
+}
diff --git a/rustables/src/expr/reject.rs b/rustables/src/expr/reject.rs
new file mode 100644
index 0000000..550a287
--- /dev/null
+++ b/rustables/src/expr/reject.rs
@@ -0,0 +1,99 @@
+use super::{DeserializationError, Expression, Rule};
+use crate::ProtoFamily;
+use rustables_sys::{
+ self as sys,
+ libc::{self, c_char},
+};
+
+/// A reject expression that defines the type of rejection message sent
+/// when discarding a packet.
+#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
+pub enum Reject {
+ /// Return an ICMP unreachable packet
+ Icmp(IcmpCode),
+ /// Reject by sending a TCP RST packet
+ TcpRst,
+}
+
+impl Reject {
+ fn to_raw(&self, family: ProtoFamily) -> u32 {
+ use libc::*;
+ let value = match *self {
+ Self::Icmp(..) => match family {
+ ProtoFamily::Bridge | ProtoFamily::Inet => NFT_REJECT_ICMPX_UNREACH,
+ _ => NFT_REJECT_ICMP_UNREACH,
+ },
+ Self::TcpRst => NFT_REJECT_TCP_RST,
+ };
+ value as u32
+ }
+}
+
+impl Expression for Reject {
+ fn get_raw_name() -> *const libc::c_char {
+ b"reject\0" as *const _ as *const c_char
+ }
+
+ fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError>
+ where
+ Self: Sized,
+ {
+ unsafe {
+ if sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_REJECT_TYPE as u16)
+ == libc::NFT_REJECT_TCP_RST as u32
+ {
+ Ok(Self::TcpRst)
+ } else {
+ Ok(Self::Icmp(IcmpCode::from_raw(sys::nftnl_expr_get_u8(
+ expr,
+ sys::NFTNL_EXPR_REJECT_CODE as u16,
+ ))?))
+ }
+ }
+ }
+
+ fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr {
+ let family = rule.get_chain().get_table().get_family();
+
+ unsafe {
+ let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name()));
+
+ sys::nftnl_expr_set_u32(
+ expr,
+ sys::NFTNL_EXPR_REJECT_TYPE as u16,
+ self.to_raw(family),
+ );
+
+ let reject_code = match *self {
+ Reject::Icmp(code) => code as u8,
+ Reject::TcpRst => 0,
+ };
+
+ sys::nftnl_expr_set_u8(expr, sys::NFTNL_EXPR_REJECT_CODE as u16, reject_code);
+
+ expr
+ }
+ }
+}
+
+/// An ICMP reject code.
+#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
+#[repr(u8)]
+pub enum IcmpCode {
+ NoRoute = libc::NFT_REJECT_ICMPX_NO_ROUTE as u8,
+ PortUnreach = libc::NFT_REJECT_ICMPX_PORT_UNREACH as u8,
+ HostUnreach = libc::NFT_REJECT_ICMPX_HOST_UNREACH as u8,
+ AdminProhibited = libc::NFT_REJECT_ICMPX_ADMIN_PROHIBITED as u8,
+}
+
+impl IcmpCode {
+ fn from_raw(code: u8) -> Result<Self, DeserializationError> {
+ match code as i32 {
+ libc::NFT_REJECT_ICMPX_NO_ROUTE => Ok(Self::NoRoute),
+ libc::NFT_REJECT_ICMPX_PORT_UNREACH => Ok(Self::PortUnreach),
+ libc::NFT_REJECT_ICMPX_HOST_UNREACH => Ok(Self::HostUnreach),
+ libc::NFT_REJECT_ICMPX_ADMIN_PROHIBITED => Ok(Self::AdminProhibited),
+ _ => Err(DeserializationError::InvalidValue),
+ }
+ }
+}
diff --git a/rustables/src/expr/verdict.rs b/rustables/src/expr/verdict.rs
index dc006bb..6a6b802 100644
--- a/rustables/src/expr/verdict.rs
+++ b/rustables/src/expr/verdict.rs
@@ -1,5 +1,4 @@
-use super::{Expression, Rule};
-use crate::ProtoFamily;
+use super::{DeserializationError, Expression, Rule};
use rustables_sys::{
self as sys,
libc::{self, c_char},
@@ -14,8 +13,6 @@ pub enum Verdict {
Drop,
/// Accept the packet and let it pass.
Accept,
- /// Reject the packet and return a message.
- Reject(RejectionType),
Queue,
Continue,
Break,
@@ -28,88 +25,7 @@ pub enum Verdict {
Return,
}
-/// The type of rejection message sent by the Reject verdict.
-#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
-pub enum RejectionType {
- /// Return an ICMP unreachable packet
- Icmp(IcmpCode),
- /// Reject by sending a TCP RST packet
- TcpRst,
-}
-
-impl RejectionType {
- fn to_raw(&self, family: ProtoFamily) -> u32 {
- use libc::*;
- let value = match *self {
- RejectionType::Icmp(..) => match family {
- ProtoFamily::Bridge | ProtoFamily::Inet => NFT_REJECT_ICMPX_UNREACH,
- _ => NFT_REJECT_ICMP_UNREACH,
- },
- RejectionType::TcpRst => NFT_REJECT_TCP_RST,
- };
- value as u32
- }
-}
-
-/// An ICMP reject code.
-#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
-#[repr(u8)]
-pub enum IcmpCode {
- NoRoute = libc::NFT_REJECT_ICMPX_NO_ROUTE as u8,
- PortUnreach = libc::NFT_REJECT_ICMPX_PORT_UNREACH as u8,
- HostUnreach = libc::NFT_REJECT_ICMPX_HOST_UNREACH as u8,
- AdminProhibited = libc::NFT_REJECT_ICMPX_ADMIN_PROHIBITED as u8,
-}
-
impl Verdict {
- unsafe fn to_immediate_expr(&self, immediate_const: i32) -> *mut sys::nftnl_expr {
- let expr = try_alloc!(sys::nftnl_expr_alloc(
- b"immediate\0" as *const _ as *const c_char
- ));
-
- sys::nftnl_expr_set_u32(
- expr,
- sys::NFTNL_EXPR_IMM_DREG as u16,
- libc::NFT_REG_VERDICT as u32,
- );
-
- if let Some(chain) = self.chain() {
- sys::nftnl_expr_set_str(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16, chain.as_ptr());
- }
- sys::nftnl_expr_set_u32(
- expr,
- sys::NFTNL_EXPR_IMM_VERDICT as u16,
- immediate_const as u32,
- );
-
- expr
- }
-
- unsafe fn to_reject_expr(
- &self,
- reject_type: RejectionType,
- family: ProtoFamily,
- ) -> *mut sys::nftnl_expr {
- let expr = try_alloc!(sys::nftnl_expr_alloc(
- b"reject\0" as *const _ as *const c_char
- ));
-
- sys::nftnl_expr_set_u32(
- expr,
- sys::NFTNL_EXPR_REJECT_TYPE as u16,
- reject_type.to_raw(family),
- );
-
- let reject_code = match reject_type {
- RejectionType::Icmp(code) => code as u8,
- RejectionType::TcpRst => 0,
- };
-
- sys::nftnl_expr_set_u8(expr, sys::NFTNL_EXPR_REJECT_CODE as u16, reject_code);
-
- expr
- }
-
fn chain(&self) -> Option<&CStr> {
match *self {
Verdict::Jump { ref chain } => Some(chain.as_c_str()),
@@ -120,7 +36,51 @@ impl Verdict {
}
impl Expression for Verdict {
- fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr {
+ fn get_raw_name() -> *const libc::c_char {
+ b"immediate\0" as *const _ as *const c_char
+ }
+
+ fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> {
+ unsafe {
+ let mut chain = None;
+ if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16) {
+ let raw_chain = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16);
+
+ if raw_chain.is_null() {
+ return Err(DeserializationError::NullPointer);
+ }
+ chain = Some(CStr::from_ptr(raw_chain).to_owned());
+ }
+
+ let verdict = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_IMM_VERDICT as u16);
+
+ match verdict as i32 {
+ libc::NF_DROP => Ok(Verdict::Drop),
+ libc::NF_ACCEPT => Ok(Verdict::Accept),
+ libc::NF_QUEUE => Ok(Verdict::Queue),
+ libc::NFT_CONTINUE => Ok(Verdict::Continue),
+ libc::NFT_BREAK => Ok(Verdict::Break),
+ libc::NFT_JUMP => {
+ if let Some(chain) = chain {
+ Ok(Verdict::Jump { chain })
+ } else {
+ Err(DeserializationError::InvalidValue)
+ }
+ }
+ libc::NFT_GOTO => {
+ if let Some(chain) = chain {
+ Ok(Verdict::Goto { chain })
+ } else {
+ Err(DeserializationError::InvalidValue)
+ }
+ }
+ libc::NFT_RETURN => Ok(Verdict::Return),
+ _ => Err(DeserializationError::InvalidValue),
+ }
+ }
+ }
+
+ fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr {
let immediate_const = match *self {
Verdict::Drop => libc::NF_DROP,
Verdict::Accept => libc::NF_ACCEPT,
@@ -130,13 +90,29 @@ impl Expression for Verdict {
Verdict::Jump { .. } => libc::NFT_JUMP,
Verdict::Goto { .. } => libc::NFT_GOTO,
Verdict::Return => libc::NFT_RETURN,
- Verdict::Reject(reject_type) => {
- return unsafe {
- self.to_reject_expr(reject_type, rule.get_chain().get_table().get_family())
- }
- }
};
- unsafe { self.to_immediate_expr(immediate_const) }
+ unsafe {
+ let expr = try_alloc!(sys::nftnl_expr_alloc(
+ b"immediate\0" as *const _ as *const c_char
+ ));
+
+ sys::nftnl_expr_set_u32(
+ expr,
+ sys::NFTNL_EXPR_IMM_DREG as u16,
+ libc::NFT_REG_VERDICT as u32,
+ );
+
+ if let Some(chain) = self.chain() {
+ sys::nftnl_expr_set_str(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16, chain.as_ptr());
+ }
+ sys::nftnl_expr_set_u32(
+ expr,
+ sys::NFTNL_EXPR_IMM_VERDICT as u16,
+ immediate_const as u32,
+ );
+
+ expr
+ }
}
}
diff --git a/rustables/src/expr/wrapper.rs b/rustables/src/expr/wrapper.rs
new file mode 100644
index 0000000..5162c21
--- /dev/null
+++ b/rustables/src/expr/wrapper.rs
@@ -0,0 +1,62 @@
+use std::ffi::CStr;
+use std::ffi::CString;
+use std::fmt::Debug;
+use std::rc::Rc;
+
+use super::DeserializationError;
+use super::Expression;
+use crate::Rule;
+use rustables_sys as sys;
+
+pub struct ExpressionWrapper {
+ pub(crate) expr: *const sys::nftnl_expr,
+ // we also need the rule here to ensure that the rule lives as long as the `expr` pointer
+ #[allow(dead_code)]
+ pub(crate) rule: Rc<Rule>,
+}
+
+impl Debug for ExpressionWrapper {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{:?}", self.get_str())
+ }
+}
+
+impl ExpressionWrapper {
+ /// Retrieves a textual description of the expression.
+ pub fn get_str(&self) -> CString {
+ let mut descr_buf = vec![0i8; 4096];
+ unsafe {
+ sys::nftnl_expr_snprintf(
+ descr_buf.as_mut_ptr(),
+ (descr_buf.len() - 1) as u64,
+ self.expr,
+ sys::NFTNL_OUTPUT_DEFAULT,
+ 0,
+ );
+ CStr::from_ptr(descr_buf.as_ptr()).to_owned()
+ }
+ }
+
+ /// Retrieves the type of expression ("log", "counter", ...).
+ pub fn get_kind(&self) -> Option<&CStr> {
+ unsafe {
+ let ptr = sys::nftnl_expr_get_str(self.expr, sys::NFTNL_EXPR_NAME as u16);
+ if !ptr.is_null() {
+ Some(CStr::from_ptr(ptr))
+ } else {
+ None
+ }
+ }
+ }
+
+ /// Attempt to decode the expression as the type T.
+ pub fn decode_expr<T: Expression>(&self) -> Result<T, DeserializationError> {
+ if let Some(kind) = self.get_kind() {
+ let raw_name = unsafe { CStr::from_ptr(T::get_raw_name()) };
+ if kind == raw_name {
+ return T::from_expr(self.expr);
+ }
+ }
+ Err(DeserializationError::InvalidExpressionKind)
+ }
+}
diff --git a/rustables/src/rule.rs b/rustables/src/rule.rs
index 6e15db7..b315daf 100644
--- a/rustables/src/rule.rs
+++ b/rustables/src/rule.rs
@@ -1,3 +1,4 @@
+use crate::expr::ExpressionWrapper;
use crate::{chain::Chain, expr::Expression, MsgType};
use rustables_sys::{self as sys, libc};
use std::ffi::{c_void, CStr, CString};
@@ -7,8 +8,8 @@ use std::rc::Rc;
/// A nftables firewall rule.
pub struct Rule {
- rule: *mut sys::nftnl_rule,
- chain: Rc<Chain>,
+ pub(crate) rule: *mut sys::nftnl_rule,
+ pub(crate) chain: Rc<Chain>,
}
impl Rule {
@@ -82,14 +83,15 @@ impl Rule {
pub fn get_userdata(&self) -> Option<&CStr> {
unsafe {
let ptr = sys::nftnl_rule_get_str(self.rule, sys::NFTNL_RULE_USERDATA as u16);
- if ptr == std::ptr::null() {
- return None;
+ if !ptr.is_null() {
+ Some(CStr::from_ptr(ptr))
+ } else {
+ None
}
- Some(CStr::from_ptr(ptr))
}
}
- /// Update the userdata of this chain.
+ /// Updates the userdata of this chain.
pub fn set_userdata(&self, data: &CStr) {
unsafe {
sys::nftnl_rule_set_str(self.rule, sys::NFTNL_RULE_USERDATA as u16, data.as_ptr());
@@ -111,6 +113,11 @@ impl Rule {
}
}
+ /// Retrieves an iterator to loop over the expressions of the rule
+ pub fn get_exprs(self: &Rc<Self>) -> RuleExprsIter {
+ RuleExprsIter::new(self.clone())
+ }
+
#[cfg(feature = "unsafe-raw-handles")]
/// Returns the raw handle.
pub fn as_ptr(&self) -> *const sys::nftnl_rule {
@@ -122,6 +129,59 @@ impl Rule {
pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_rule {
self.rule
}
+
+ /// Perform a deep comparizon of rules, by checking they have the same expressions inside.
+ /// This is not enabled by default in our PartialEq implementation because of the
+ /// difficulty to compare an expression generated by the library with the expressions returned
+ /// by the kernel when iterating over the currently in-use rules. The kernel-returned
+ /// expressions may have additional attributes despite being generated from the same rule.
+ /// This is particularly true for the 'nat' expression).
+ pub fn deep_eq(&self, other: &Self) -> bool {
+ if self != other {
+ return false;
+ }
+
+ let self_exprs =
+ try_alloc!(unsafe { sys::nftnl_expr_iter_create(self.rule as *const sys::nftnl_rule) });
+ let other_exprs = try_alloc!(unsafe {
+ sys::nftnl_expr_iter_create(other.rule as *const sys::nftnl_rule)
+ });
+
+ loop {
+ let self_next = unsafe { sys::nftnl_expr_iter_next(self_exprs) };
+ let other_next = unsafe { sys::nftnl_expr_iter_next(other_exprs) };
+ if self_next.is_null() && other_next.is_null() {
+ return true;
+ } else if self_next.is_null() || other_next.is_null() {
+ return false;
+ }
+
+ // we are falling back on comparing the strings, because there is no easy mechanism to
+ // perform a memcmp() between the two expressions :/
+ let mut self_str = [0; 256];
+ let mut other_str = [0; 256];
+ unsafe {
+ sys::nftnl_expr_snprintf(
+ self_str.as_mut_ptr(),
+ (self_str.len() - 1) as u64,
+ self_next,
+ sys::NFTNL_OUTPUT_DEFAULT,
+ 0,
+ );
+ sys::nftnl_expr_snprintf(
+ other_str.as_mut_ptr(),
+ (other_str.len() - 1) as u64,
+ other_next,
+ sys::NFTNL_OUTPUT_DEFAULT,
+ 0,
+ );
+ }
+
+ if self_str != other_str {
+ return false;
+ }
+ }
+ }
}
impl Debug for Rule {
@@ -132,7 +192,28 @@ impl Debug for Rule {
impl PartialEq for Rule {
fn eq(&self, other: &Self) -> bool {
- self.get_chain() == other.get_chain() && self.get_handle() == other.get_handle()
+ if self.get_chain() != other.get_chain() {
+ return false;
+ }
+
+ unsafe {
+ if sys::nftnl_rule_is_set(self.rule, sys::NFTNL_RULE_HANDLE as u16)
+ && sys::nftnl_rule_is_set(other.rule, sys::NFTNL_RULE_HANDLE as u16)
+ {
+ if self.get_handle() != other.get_handle() {
+ return false;
+ }
+ }
+ if sys::nftnl_rule_is_set(self.rule, sys::NFTNL_RULE_POSITION as u16)
+ && sys::nftnl_rule_is_set(other.rule, sys::NFTNL_RULE_POSITION as u16)
+ {
+ if self.get_position() != other.get_position() {
+ return false;
+ }
+ }
+ }
+
+ return false;
}
}
@@ -163,6 +244,43 @@ impl Drop for Rule {
}
}
+pub struct RuleExprsIter {
+ rule: Rc<Rule>,
+ iter: *mut sys::nftnl_expr_iter,
+}
+
+impl RuleExprsIter {
+ fn new(rule: Rc<Rule>) -> Self {
+ let iter =
+ try_alloc!(unsafe { sys::nftnl_expr_iter_create(rule.rule as *const sys::nftnl_rule) });
+ RuleExprsIter { rule, iter }
+ }
+}
+
+impl Iterator for RuleExprsIter {
+ type Item = ExpressionWrapper;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ let next = unsafe { sys::nftnl_expr_iter_next(self.iter) };
+ if next.is_null() {
+ trace!("RulesExprsIter iterator ending");
+ None
+ } else {
+ trace!("RulesExprsIter returning new expression");
+ Some(ExpressionWrapper {
+ expr: next,
+ rule: self.rule.clone(),
+ })
+ }
+ }
+}
+
+impl Drop for RuleExprsIter {
+ fn drop(&mut self) {
+ unsafe { sys::nftnl_expr_iter_destroy(self.iter) };
+ }
+}
+
#[cfg(feature = "query")]
pub fn get_rules_cb(
header: &libc::nlmsghdr,
diff --git a/rustables/src/set.rs b/rustables/src/set.rs
index d8c84d6..aef74db 100644
--- a/rustables/src/set.rs
+++ b/rustables/src/set.rs
@@ -27,9 +27,9 @@ macro_rules! nft_set {
}
pub struct Set<'a, K> {
- set: *mut sys::nftnl_set,
- table: &'a Table,
- family: ProtoFamily,
+ pub(crate) set: *mut sys::nftnl_set,
+ pub(crate) table: &'a Table,
+ pub(crate) family: ProtoFamily,
_marker: ::std::marker::PhantomData<K>,
}
@@ -130,10 +130,14 @@ impl<'a, K> Set<'a, K> {
}
}
- pub fn get_name(&self) -> &CStr {
+ pub fn get_name(&self) -> Option<&CStr> {
unsafe {
let ptr = sys::nftnl_set_get_str(self.set, sys::NFTNL_SET_NAME as u16);
- CStr::from_ptr(ptr)
+ if !ptr.is_null() {
+ Some(CStr::from_ptr(ptr))
+ } else {
+ None
+ }
}
}
diff --git a/rustables/src/table.rs b/rustables/src/table.rs
index dc09b5e..7cc475f 100644
--- a/rustables/src/table.rs
+++ b/rustables/src/table.rs
@@ -38,7 +38,11 @@ impl Table {
pub fn get_name(&self) -> &CStr {
unsafe {
let ptr = sys::nftnl_table_get_str(self.table, sys::NFTNL_TABLE_NAME as u16);
- CStr::from_ptr(ptr)
+ if ptr.is_null() {
+ panic!("Impossible situation: retrieving the name of a chain failed")
+ } else {
+ CStr::from_ptr(ptr)
+ }
}
}
@@ -66,10 +70,11 @@ impl Table {
pub fn get_userdata(&self) -> Option<&CStr> {
unsafe {
let ptr = sys::nftnl_table_get_str(self.table, sys::NFTNL_TABLE_USERDATA as u16);
- if ptr == std::ptr::null() {
- return None;
+ if !ptr.is_null() {
+ Some(CStr::from_ptr(ptr))
+ } else {
+ None
}
- Some(CStr::from_ptr(ptr))
}
}