aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/batch.rs164
-rw-r--r--src/chain.rs3
-rw-r--r--src/lib.rs82
-rw-r--r--src/nlmsg.rs137
-rw-r--r--src/parser.rs453
-rw-r--r--src/query.rs196
-rw-r--r--src/rule.rs7
-rw-r--r--src/set.rs9
-rw-r--r--src/table.rs166
9 files changed, 803 insertions, 414 deletions
diff --git a/src/batch.rs b/src/batch.rs
index 198e8d0..714dc55 100644
--- a/src/batch.rs
+++ b/src/batch.rs
@@ -1,5 +1,7 @@
-use crate::{MsgType, NlMsg};
-use crate::sys::{self as sys, libc};
+use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter};
+use crate::sys::{self};
+use crate::{MsgType, ProtoFamily};
+use libc;
use std::ffi::c_void;
use std::os::raw::c_char;
use std::ptr;
@@ -10,22 +12,15 @@ use thiserror::Error;
#[error("Error while communicating with netlink")]
pub struct NetlinkError(());
-#[cfg(feature = "query")]
-/// Check if the kernel supports batched netlink messages to netfilter.
-pub fn batch_is_supported() -> std::result::Result<bool, NetlinkError> {
- match unsafe { sys::nftnl_batch_is_supported() } {
- 1 => Ok(true),
- 0 => Ok(false),
- _ => Err(NetlinkError(())),
- }
-}
-
/// A batch of netfilter messages to be performed in one atomic operation. Corresponds to
/// `nftnl_batch` in libnftnl.
pub struct Batch {
- pub(crate) batch: *mut sys::nftnl_batch,
- pub(crate) seq: u32,
- pub(crate) is_empty: bool,
+ buf: Box<Vec<u8>>,
+ // the 'static lifetime here is a cheat, as the writer can only be used as long
+ // as `self.buf` exists. This is why this member must never be exposed directly to
+ // the rest of the crate (let alone publicly).
+ writer: NfNetlinkWriter<'static>,
+ seq: u32,
}
impl Batch {
@@ -33,48 +28,38 @@ impl Batch {
///
/// [default page size]: fn.default_batch_page_size.html
pub fn new() -> Self {
- Self::with_page_size(default_batch_page_size())
- }
-
- pub unsafe fn from_raw(batch: *mut sys::nftnl_batch, seq: u32) -> Self {
+ // TODO: use a pinned Box ?
+ let mut buf = Box::new(Vec::with_capacity(default_batch_page_size() as usize));
+ let mut writer = NfNetlinkWriter::new(unsafe {
+ std::mem::transmute(Box::as_mut(&mut buf) as *mut Vec<u8>)
+ });
+ writer.write_header(
+ libc::NFNL_MSG_BATCH_BEGIN as u16,
+ ProtoFamily::Unspec,
+ 0,
+ 0,
+ Some(libc::NFNL_SUBSYS_NFTABLES as u16),
+ );
Batch {
- batch,
- seq,
- // we assume this batch is not empty by default
- is_empty: false,
+ buf,
+ writer,
+ seq: 1,
}
}
- /// Creates a new nftnl batch with the given batch size.
- pub fn with_page_size(batch_page_size: u32) -> Self {
- let batch = try_alloc!(unsafe {
- sys::nftnl_batch_alloc(batch_page_size, crate::nft_nlmsg_maxsize())
- });
- let mut this = Batch {
- batch,
- seq: 0,
- is_empty: true,
- };
- this.write_begin_msg();
- this
- }
-
/// Adds the given message to this batch.
- pub fn add<T: NlMsg>(&mut self, msg: &T, msg_type: MsgType) {
+ pub fn add<T: NfNetlinkObject>(&mut self, msg: &T, msg_type: MsgType) {
trace!("Writing NlMsg with seq {} to batch", self.seq);
- unsafe { msg.write(self.current(), self.seq, msg_type) };
- self.is_empty = false;
- self.next()
+ msg.add_or_remove(&mut self.writer, msg_type, self.seq);
+ self.seq += 1;
}
- /// Adds all the messages in the given iterator to this batch. If any message fails to be
- /// added the error for that failure is returned and all messages up until that message stay
- /// added to the batch.
- pub fn add_iter<T, I>(&mut self, msg_iter: I, msg_type: MsgType)
- where
- T: NlMsg,
- I: Iterator<Item = T>,
- {
+ /// Adds all the messages in the given iterator to this batch.
+ pub fn add_iter<T: NfNetlinkObject, I: Iterator<Item = T>>(
+ &mut self,
+ msg_iter: I,
+ msg_type: MsgType,
+ ) {
for msg in msg_iter {
self.add(&msg, msg_type);
}
@@ -86,53 +71,52 @@ impl Batch {
/// Return None if there is no object in the batch (this could block forever).
///
/// [`FinalizedBatch`]: struct.FinalizedBatch.html
- pub fn finalize(mut self) -> Option<FinalizedBatch> {
- self.write_end_msg();
- if self.is_empty {
- return None;
+ pub fn finalize(mut self) -> FinalizedBatch {
+ self.writer.write_header(
+ libc::NFNL_MSG_BATCH_END as u16,
+ ProtoFamily::Unspec,
+ 0,
+ self.seq,
+ Some(libc::NFNL_SUBSYS_NFTABLES as u16),
+ );
+ FinalizedBatch { batch: self }
+ }
+
+ /*
+ fn current(&self) -> *mut c_void {
+ unsafe { sys::nftnl_batch_buffer(self.batch) }
}
- Some(FinalizedBatch { batch: self })
- }
-
- fn current(&self) -> *mut c_void {
- unsafe { sys::nftnl_batch_buffer(self.batch) }
- }
- fn next(&mut self) {
- if unsafe { sys::nftnl_batch_update(self.batch) } < 0 {
- // See try_alloc definition.
- std::process::abort();
+ fn next(&mut self) {
+ if unsafe { sys::nftnl_batch_update(self.batch) } < 0 {
+ // See try_alloc definition.
+ std::process::abort();
+ }
+ self.seq += 1;
}
- self.seq += 1;
- }
-
- fn write_begin_msg(&mut self) {
- unsafe { sys::nftnl_batch_begin(self.current() as *mut c_char, self.seq) };
- self.next();
- }
- fn write_end_msg(&mut self) {
- unsafe { sys::nftnl_batch_end(self.current() as *mut c_char, self.seq) };
- self.next();
- }
+ fn write_begin_msg(&mut self) {
+ unsafe { sys::nftnl_batch_begin(self.current() as *mut c_char, self.seq) };
+ self.next();
+ }
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns the raw handle.
- pub fn as_ptr(&self) -> *const sys::nftnl_batch {
- self.batch as *const sys::nftnl_batch
- }
+ fn write_end_msg(&mut self) {
+ unsafe { sys::nftnl_batch_end(self.current() as *mut c_char, self.seq) };
+ self.next();
+ }
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns a mutable version of the raw handle.
- pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_batch {
- self.batch
- }
-}
+ #[cfg(feature = "unsafe-raw-handles")]
+ /// Returns the raw handle.
+ pub fn as_ptr(&self) -> *const sys::nftnl_batch {
+ self.batch as *const sys::nftnl_batch
+ }
-impl Drop for Batch {
- fn drop(&mut self) {
- unsafe { sys::nftnl_batch_free(self.batch) };
- }
+ #[cfg(feature = "unsafe-raw-handles")]
+ /// Returns a mutable version of the raw handle.
+ pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_batch {
+ self.batch
+ }
+ */
}
/// A wrapper over [`Batch`], guaranteed to start with a proper batch begin and end with a proper
@@ -146,6 +130,7 @@ pub struct FinalizedBatch {
batch: Batch,
}
+/*
impl FinalizedBatch {
/// Returns the iterator over byte buffers to send to netlink.
pub fn iter(&mut self) -> Iter<'_> {
@@ -191,6 +176,7 @@ impl<'a> Iterator for Iter<'a> {
})
}
}
+*/
/// Selected batch page is 256 Kbytes long to load ruleset of half a million rules without hitting
/// -EMSGSIZE due to large iovec.
diff --git a/src/chain.rs b/src/chain.rs
index 18e3c64..a99d7f8 100644
--- a/src/chain.rs
+++ b/src/chain.rs
@@ -1,3 +1,4 @@
+use crate::nlmsg::NlMsg;
#[cfg(feature = "query")]
use crate::query::{Nfgenmsg, ParseError};
use crate::sys::{self as sys, libc};
@@ -215,7 +216,7 @@ impl PartialEq for Chain {
}
}
-unsafe impl crate::NlMsg for Chain {
+unsafe impl NlMsg for Chain {
unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) {
let raw_msg_type = match msg_type {
MsgType::Add => libc::NFT_MSG_NEWCHAIN,
diff --git a/src/lib.rs b/src/lib.rs
index 5d40c5a..60643fe 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -76,8 +76,8 @@ use thiserror::Error;
extern crate log;
pub mod sys;
+use libc;
use std::{convert::TryFrom, ffi::c_void, ops::Deref};
-use sys::libc;
macro_rules! try_alloc {
($e:expr) => {{
@@ -92,37 +92,40 @@ macro_rules! try_alloc {
}
mod batch;
-#[cfg(feature = "query")]
-pub use batch::{batch_is_supported, default_batch_page_size};
-pub use batch::{Batch, FinalizedBatch, NetlinkError};
+//#[cfg(feature = "query")]
+//pub use batch::{batch_is_supported, default_batch_page_size};
+//pub use batch::{Batch, FinalizedBatch, NetlinkError};
-pub mod expr;
+//pub mod expr;
pub mod table;
pub use table::Table;
-#[cfg(feature = "query")]
-pub use table::{get_tables_cb, list_tables};
-
-mod chain;
-#[cfg(feature = "query")]
-pub use chain::{get_chains_cb, list_chains_for_table};
-pub use chain::{Chain, ChainType, Hook, Policy, Priority};
+//#[cfg(feature = "query")]
+//pub use table::{get_tables_cb, list_tables};
+//
+//mod chain;
+//#[cfg(feature = "query")]
+//pub use chain::{get_chains_cb, list_chains_for_table};
+//pub use chain::{Chain, ChainType, Hook, Policy, Priority};
-mod chain_methods;
-pub use chain_methods::ChainMethods;
+//mod chain_methods;
+//pub use chain_methods::ChainMethods;
pub mod query;
-mod rule;
-pub use rule::Rule;
-#[cfg(feature = "query")]
-pub use rule::{get_rules_cb, list_rules_for_chain};
+pub mod nlmsg;
+pub mod parser;
-mod rule_methods;
-pub use rule_methods::{iface_index, Error as MatchError, Protocol, RuleMethods};
+//mod rule;
+//pub use rule::Rule;
+//#[cfg(feature = "query")]
+//pub use rule::{get_rules_cb, list_rules_for_chain};
-pub mod set;
-pub use set::Set;
+//mod rule_methods;
+//pub use rule_methods::{iface_index, Error as MatchError, Protocol, RuleMethods};
+
+//pub mod set;
+//pub use set::Set;
/// The type of the message as it's sent to netfilter. A message consists of an object, such as a
/// [`Table`], [`Chain`] or [`Rule`] for example, and a [`MsgType`] to describe what to do with
@@ -142,7 +145,7 @@ pub enum MsgType {
}
/// Denotes a protocol. Used to specify which protocol a table or set belongs to.
-#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
+#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
#[repr(u16)]
pub enum ProtoFamily {
Unspec = libc::NFPROTO_UNSPEC as u16,
@@ -176,36 +179,3 @@ impl TryFrom<i32> for ProtoFamily {
}
}
}
-
-/// Trait for all types in this crate that can serialize to a Netlink message.
-///
-/// # Unsafe
-///
-/// This trait is unsafe to implement because it must never serialize to anything larger than the
-/// largest possible netlink message. Internally the `nft_nlmsg_maxsize()` function is used to
-/// make sure the `buf` pointer passed to `write` always has room for the largest possible Netlink
-/// message.
-pub unsafe trait NlMsg {
- /// Serializes the Netlink message to the buffer at `buf`. `buf` must have space for at least
- /// `nft_nlmsg_maxsize()` bytes. This is not checked by the compiler, which is why this method
- /// is unsafe.
- unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType);
-}
-
-unsafe impl<T, R> NlMsg for T
-where
- T: Deref<Target = R>,
- R: NlMsg,
-{
- unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) {
- self.deref().write(buf, seq, msg_type);
- }
-}
-
-/// The largest nf_tables netlink message is the set element message, which contains the
-/// NFTA_SET_ELEM_LIST_ELEMENTS attribute. This attribute is a nest that describes the set
-/// elements. Given that the netlink attribute length (nla_len) is 16 bits, the largest message is
-/// a bit larger than 64 KBytes.
-pub fn nft_nlmsg_maxsize() -> u32 {
- u32::from(::std::u16::MAX) + unsafe { libc::sysconf(libc::_SC_PAGESIZE) } as u32
-}
diff --git a/src/nlmsg.rs b/src/nlmsg.rs
new file mode 100644
index 0000000..70a4f01
--- /dev/null
+++ b/src/nlmsg.rs
@@ -0,0 +1,137 @@
+use std::{collections::HashMap, fmt::Debug, marker::PhantomData, mem::size_of, ops::Deref};
+
+use libc::{
+ nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES, NLMSG_MIN_TYPE, NLM_F_DUMP_INTR,
+};
+use thiserror::Error;
+
+use crate::{
+ parser::{
+ pad_netlink_object, pad_netlink_object_with_variable_size, Attribute, DecodeError,
+ NfNetlinkAttributes, Nfgenmsg,
+ },
+ MsgType, ProtoFamily,
+};
+
+/*
+/// Trait for all types in this crate that can serialize to a Netlink message.
+pub trait NlMsg {
+ /// Serializes the Netlink message to the buffer at `buf`.
+ fn write(&self, buf: &mut Vec<u8>, msg_type: MsgType, seq: u32);
+}
+
+impl<T, R> NlMsg for T
+where
+ T: Deref<Target = R>,
+ R: NlMsg,
+{
+ fn write(&self, buf: &mut Vec<u8>, msg_type: MsgType, seq: u32) {
+ self.deref().write(buf, msg_type, seq);
+ }
+}
+*/
+
+pub struct NfNetlinkWriter<'a> {
+ buf: &'a mut Vec<u8>,
+ headers: HeaderStack<'a>,
+}
+
+impl<'a> NfNetlinkWriter<'a> {
+ pub fn new(buf: &'a mut Vec<u8>) -> NfNetlinkWriter<'a> {
+ NfNetlinkWriter {
+ buf,
+ headers: HeaderStack::new(),
+ }
+ }
+
+ pub fn add_data_zeroed<'b>(&'b mut self, size: usize) -> &'b mut [u8] {
+ let padded_size = pad_netlink_object_with_variable_size(size);
+ let start = self.buf.len();
+ self.buf.resize(start + padded_size, 0);
+
+ self.headers.add_size(padded_size as u32);
+
+ &mut self.buf[start..start + size]
+ }
+
+ pub fn extract_buffer(self) -> &'a mut Vec<u8> {
+ self.buf
+ }
+
+ // rewrite of `__nftnl_nlmsg_build_hdr`
+ pub fn write_header(
+ &mut self,
+ msg_type: u16,
+ family: ProtoFamily,
+ flags: u16,
+ seq: u32,
+ ressource_id: Option<u16>,
+ ) {
+ let nlmsghdr_len = pad_netlink_object::<nlmsghdr>();
+ let nfgenmsg_len = pad_netlink_object::<Nfgenmsg>();
+ let nlmsghdr_buf = self.add_data_zeroed(nlmsghdr_len);
+
+ let mut hdr: &mut nlmsghdr =
+ unsafe { std::mem::transmute(nlmsghdr_buf.as_mut_ptr() as *mut nlmsghdr) };
+ //let mut hdr = &mut unsafe { *(nlmsghdr_buf.as_mut_ptr() as *mut nlmsghdr) };
+
+ hdr.nlmsg_len = (nlmsghdr_len + nfgenmsg_len) as u32;
+ hdr.nlmsg_type = ((NFNL_SUBSYS_NFTABLES as u16) << 8) | msg_type;
+ hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16 | flags;
+ hdr.nlmsg_seq = seq;
+
+ let nfgenmsg_buf = self.add_data_zeroed(nfgenmsg_len);
+
+ let mut nfgenmsg: &mut Nfgenmsg =
+ unsafe { std::mem::transmute(nfgenmsg_buf.as_mut_ptr() as *mut Nfgenmsg) };
+ nfgenmsg.family = family as u8;
+ nfgenmsg.version = NFNETLINK_V0 as u8;
+ nfgenmsg.res_id = ressource_id.unwrap_or(0);
+
+ self.headers.add_level(hdr, Some(nfgenmsg));
+ }
+
+ pub fn get_current_header(&mut self) -> Option<&mut nlmsghdr> {
+ let stack_size = self.headers.stack.len();
+ if stack_size > 0 {
+ Some(unsafe { std::mem::transmute(self.headers.stack[stack_size - 1].0) })
+ } else {
+ None
+ }
+ }
+}
+
+struct HeaderStack<'a> {
+ stack: Vec<(*mut nlmsghdr, Option<*mut Nfgenmsg>)>,
+ lifetime: PhantomData<&'a ()>,
+}
+
+impl<'a> HeaderStack<'a> {
+ fn new() -> HeaderStack<'a> {
+ HeaderStack {
+ stack: Vec::new(),
+ lifetime: PhantomData,
+ }
+ }
+
+ /// resize all the stacked netlink containers to hold additional_size new bytes
+ fn add_size(&mut self, additional_size: u32) {
+ for (hdr, _) in &mut self.stack {
+ unsafe {
+ (**hdr).nlmsg_len = (**hdr).nlmsg_len + additional_size;
+ }
+ }
+ }
+
+ fn add_level(&mut self, hdr: *mut nlmsghdr, nfgenmsg: Option<*mut Nfgenmsg>) {
+ self.stack.push((hdr, nfgenmsg));
+ }
+}
+
+pub trait NfNetlinkObject: Sized {
+ fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32);
+
+ fn decode_attribute(attr_type: u16, buf: &[u8]) -> Result<Attribute, DecodeError>;
+
+ fn deserialize(buf: &[u8]) -> Result<(&[u8], Self), DecodeError>;
+}
diff --git a/src/parser.rs b/src/parser.rs
new file mode 100644
index 0000000..ddcfbf4
--- /dev/null
+++ b/src/parser.rs
@@ -0,0 +1,453 @@
+use std::{
+ collections::HashMap,
+ fmt::Debug,
+ mem::{self, size_of, transmute},
+ str::Utf8Error,
+ string::FromUtf8Error,
+};
+
+use libc::{
+ nlattr, nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES, NLA_TYPE_MASK, NLMSG_MIN_TYPE,
+ NLM_F_DUMP_INTR,
+};
+use thiserror::Error;
+
+use crate::{
+ nlmsg::{NfNetlinkObject, NfNetlinkWriter},
+ InvalidProtocolFamily, ProtoFamily,
+};
+
+#[derive(Error, Debug)]
+pub enum DecodeError {
+ #[error("The buffer is too small to hold a valid message")]
+ BufTooSmall,
+
+ #[error("The message is too small")]
+ NlMsgTooSmall,
+
+ #[error("Invalid subsystem, expected NFTABLES")]
+ InvalidSubsystem(u8),
+
+ #[error("Invalid version, expected NFNETLINK_V0")]
+ InvalidVersion(u8),
+
+ #[error("Invalid port ID")]
+ InvalidPortId(u32),
+
+ #[error("Invalid sequence number")]
+ InvalidSeq(u32),
+
+ #[error("The generation number was bumped in the kernel while the operation was running, interrupting it")]
+ ConcurrentGenerationUpdate,
+
+ #[error("Unsupported message type")]
+ UnsupportedType(u16),
+
+ #[error("Invalid attribute type")]
+ InvalidAttributeType,
+
+ #[error("Unsupported attribute type")]
+ UnsupportedAttributeType(u16),
+
+ #[error("Unexpected message type")]
+ UnexpectedType(u16),
+
+ #[error("The decoded String is not UTF8 compliant")]
+ StringDecodeFailure(#[from] FromUtf8Error),
+
+ #[error("Invalid value for a protocol family")]
+ InvalidProtocolFamily(#[from] InvalidProtocolFamily),
+
+ #[error("A custom error occured")]
+ Custom(Box<dyn std::error::Error + 'static>),
+}
+
+/// The largest nf_tables netlink message is the set element message, which contains the
+/// NFTA_SET_ELEM_LIST_ELEMENTS attribute. This attribute is a nest that describes the set
+/// elements. Given that the netlink attribute length (nla_len) is 16 bits, the largest message is
+/// a bit larger than 64 KBytes.
+pub fn nft_nlmsg_maxsize() -> u32 {
+ u32::from(::std::u16::MAX) + unsafe { libc::sysconf(libc::_SC_PAGESIZE) } as u32
+}
+
+#[inline]
+pub fn pad_netlink_object_with_variable_size(size: usize) -> usize {
+ // align on a 4 bytes boundary
+ (size + 3) & (!3)
+}
+
+#[inline]
+pub fn pad_netlink_object<T>() -> usize {
+ let size = size_of::<T>();
+ // align on a 4 bytes boundary
+ pad_netlink_object_with_variable_size(size)
+}
+
+#[repr(C)]
+#[derive(Debug, Clone, Copy)]
+pub struct Nfgenmsg {
+ pub family: u8, /* AF_xxx */
+ pub version: u8, /* nfnetlink version */
+ pub res_id: u16, /* resource id */
+}
+
+pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 {
+ ((x & 0xff00) >> 8) as u8
+}
+
+pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 {
+ (x & 0x00ff) as u8
+}
+
+pub fn get_nlmsghdr(buf: &[u8]) -> Result<nlmsghdr, DecodeError> {
+ let size_of_hdr = size_of::<nlmsghdr>();
+
+ if buf.len() < size_of_hdr {
+ return Err(DecodeError::BufTooSmall);
+ }
+
+ let nlmsghdr_ptr = buf[0..size_of_hdr].as_ptr() as *const nlmsghdr;
+ let nlmsghdr = unsafe { *nlmsghdr_ptr };
+
+ if nlmsghdr.nlmsg_len as usize > buf.len() || (nlmsghdr.nlmsg_len as usize) < size_of_hdr {
+ return Err(DecodeError::NlMsgTooSmall);
+ }
+
+ if nlmsghdr.nlmsg_flags & NLM_F_DUMP_INTR as u16 != 0 {
+ return Err(DecodeError::ConcurrentGenerationUpdate);
+ }
+
+ Ok(nlmsghdr)
+}
+pub enum NlMsg<'a> {
+ Done,
+ Noop,
+ Error(nlmsgerr),
+ NfGenMsg(Nfgenmsg, &'a [u8]),
+}
+
+pub fn parse_nlmsg<'a>(buf: &'a [u8]) -> Result<(nlmsghdr, NlMsg<'a>), DecodeError> {
+ // in theory the message is composed of the following parts:
+ // - nlmsghdr (contains the message size and type)
+ // - struct nlmsgerr OR nfgenmsg (nftables header that describes the message family)
+ // - the raw value that we want to validate (if the previous part is nfgenmsg)
+ let nlmsghdr = get_nlmsghdr(buf)?;
+
+ let size_of_hdr = pad_netlink_object::<nlmsghdr>();
+
+ if nlmsghdr.nlmsg_type < NLMSG_MIN_TYPE as u16 {
+ match nlmsghdr.nlmsg_type as libc::c_int {
+ NLMSG_NOOP => return Ok((nlmsghdr, NlMsg::Noop)),
+ NLMSG_ERROR => {
+ if nlmsghdr.nlmsg_len as usize > buf.len()
+ || (nlmsghdr.nlmsg_len as usize) < size_of_hdr + size_of::<nlmsgerr>()
+ {
+ return Err(DecodeError::NlMsgTooSmall);
+ }
+ let mut err = unsafe {
+ *(buf[size_of_hdr..size_of_hdr + size_of::<nlmsgerr>()].as_ptr()
+ as *const nlmsgerr)
+ };
+ // some APIs return negative values, while other return positive values
+ err.error = err.error.abs();
+ return Ok((nlmsghdr, NlMsg::Error(err)));
+ }
+ NLMSG_DONE => return Ok((nlmsghdr, NlMsg::Done)),
+ x => return Err(DecodeError::UnsupportedType(x as u16)),
+ }
+ }
+
+ let subsys = get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type);
+ if subsys != NFNL_SUBSYS_NFTABLES as u8 {
+ return Err(DecodeError::InvalidSubsystem(subsys));
+ }
+
+ let size_of_nfgenmsg = pad_netlink_object::<Nfgenmsg>();
+ if nlmsghdr.nlmsg_len as usize > buf.len()
+ || (nlmsghdr.nlmsg_len as usize) < size_of_hdr + size_of_nfgenmsg
+ {
+ return Err(DecodeError::NlMsgTooSmall);
+ }
+
+ let nfgenmsg_ptr = buf[size_of_hdr..size_of_hdr + size_of_nfgenmsg].as_ptr() as *const Nfgenmsg;
+ let nfgenmsg = unsafe { *nfgenmsg_ptr };
+ let subsys = get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type);
+ if subsys != NFNL_SUBSYS_NFTABLES as u8 {
+ return Err(DecodeError::InvalidSubsystem(subsys));
+ }
+ if nfgenmsg.version != NFNETLINK_V0 as u8 {
+ return Err(DecodeError::InvalidVersion(nfgenmsg.version));
+ }
+
+ let raw_value = &buf[size_of_hdr + size_of_nfgenmsg..nlmsghdr.nlmsg_len as usize];
+
+ Ok((nlmsghdr, NlMsg::NfGenMsg(nfgenmsg, raw_value)))
+}
+
+pub type NetlinkType = u16;
+
+pub trait NfNetlinkAttribute: Debug + Sized {
+ fn get_size(&self) -> usize {
+ size_of::<Self>()
+ }
+
+ unsafe fn write_payload(&self, addr: *mut u8);
+ // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr, self.get_size());
+}
+
+/// Write the attribute, preceded by a `libc::nlattr`
+// rewrite of `mnl_attr_put`
+fn write_attribute<'a>(ty: NetlinkType, obj: &Attribute, writer: &mut NfNetlinkWriter<'a>) {
+ // copy the header
+ let header_len = pad_netlink_object::<libc::nlattr>();
+ let header = libc::nlattr {
+ // nla_len contains the header size + the unpadded attribute length
+ nla_len: (header_len + obj.get_size() as usize) as u16,
+ nla_type: ty,
+ };
+
+ let buf = writer.add_data_zeroed(header_len);
+ unsafe {
+ std::ptr::copy_nonoverlapping(
+ &header as *const libc::nlattr as *const u8,
+ buf.as_mut_ptr(),
+ header_len as usize,
+ );
+ }
+
+ let buf = writer.add_data_zeroed(obj.get_size());
+ // copy the attribute data itself
+ unsafe {
+ obj.write_payload(buf.as_mut_ptr());
+ }
+}
+
+impl NfNetlinkAttribute for ProtoFamily {
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ *(addr as *mut u32) = *self as u32;
+ }
+}
+
+impl NfNetlinkAttribute for u8 {
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ *addr = *self;
+ }
+}
+
+impl NfNetlinkAttribute for u16 {
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ *(addr as *mut Self) = *self;
+ }
+}
+
+impl NfNetlinkAttribute for u32 {
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ *(addr as *mut Self) = *self;
+ }
+}
+
+impl NfNetlinkAttribute for u64 {
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ *(addr as *mut Self) = *self;
+ }
+}
+
+impl NfNetlinkAttribute for String {
+ fn get_size(&self) -> usize {
+ self.len()
+ }
+
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ std::ptr::copy_nonoverlapping(self.as_bytes().as_ptr(), addr, self.len());
+ }
+}
+
+impl NfNetlinkAttribute for Vec<u8> {
+ fn get_size(&self) -> usize {
+ self.len()
+ }
+
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ std::ptr::copy_nonoverlapping(self.as_ptr(), addr, self.len());
+ }
+}
+
+#[derive(Debug, PartialEq, Eq)]
+pub struct NfNetlinkAttributes {
+ attributes: HashMap<NetlinkType, Attribute>,
+}
+
+impl NfNetlinkAttributes {
+ pub fn new() -> Self {
+ NfNetlinkAttributes {
+ attributes: HashMap::new(),
+ }
+ }
+
+ pub fn set_attr(&mut self, ty: NetlinkType, obj: Attribute) {
+ self.attributes.insert(ty, obj);
+ }
+
+ pub fn get_attr(&self, ty: NetlinkType) -> Option<&Attribute> {
+ self.attributes.get(&ty)
+ }
+}
+
+pub struct NfNetlinkAttributeReader<'a> {
+ buf: &'a [u8],
+ pos: usize,
+ remaining_size: usize,
+ attrs: NfNetlinkAttributes,
+}
+
+impl<'a> NfNetlinkAttributeReader<'a> {
+ pub fn new(buf: &'a [u8], remaining_size: usize) -> Result<Self, DecodeError> {
+ if buf.len() < remaining_size {
+ return Err(DecodeError::BufTooSmall);
+ }
+
+ Ok(Self {
+ buf,
+ pos: 0,
+ remaining_size,
+ attrs: NfNetlinkAttributes::new(),
+ })
+ }
+
+ pub fn decode<T: NfNetlinkObject>(
+ mut self,
+ ) -> Result<(&'a [u8], NfNetlinkAttributes), DecodeError> {
+ while self.remaining_size > pad_netlink_object::<nlattr>() {
+ let nlattr =
+ unsafe { *transmute::<*const u8, *const nlattr>(self.buf[self.pos..].as_ptr()) };
+ // TODO: ignore the byteorder and nested attributes for now
+ let nla_type = nlattr.nla_type & NLA_TYPE_MASK as u16;
+
+ self.pos += pad_netlink_object::<nlattr>();
+ let attr_remaining_size = nlattr.nla_len as usize - pad_netlink_object::<nlattr>();
+ self.attrs.set_attr(
+ nla_type,
+ T::decode_attribute(
+ nla_type,
+ &self.buf[self.pos..self.pos + attr_remaining_size],
+ )?,
+ );
+ self.pos += pad_netlink_object_with_variable_size(attr_remaining_size);
+
+ self.remaining_size -= pad_netlink_object_with_variable_size(nlattr.nla_len as usize);
+ }
+
+ Ok((&self.buf[self.pos..], self.attrs))
+ }
+}
+
+pub fn expect_msgtype_in_nlmsg<'a>(
+ buf: &'a [u8],
+ nlmsg_type: u8,
+) -> Result<(nlmsghdr, Nfgenmsg, &'a [u8], NfNetlinkAttributeReader<'a>), DecodeError> {
+ let (hdr, msg) = parse_nlmsg(buf)?;
+
+ if get_operation_from_nlmsghdr_type(hdr.nlmsg_type) != nlmsg_type {
+ return Err(DecodeError::UnexpectedType(hdr.nlmsg_type));
+ }
+
+ let remaining_size = hdr.nlmsg_len as usize
+ - pad_netlink_object_with_variable_size(size_of::<nlmsghdr>() + size_of::<Nfgenmsg>());
+
+ match msg {
+ NlMsg::NfGenMsg(nfgenmsg, content) => Ok((
+ hdr,
+ nfgenmsg,
+ content,
+ NfNetlinkAttributeReader::new(content, remaining_size)?,
+ )),
+ _ => Err(DecodeError::UnexpectedType(hdr.nlmsg_type)),
+ }
+}
+
+pub trait SerializeNfNetlink {
+ fn serialize<'a>(&self, writer: &mut NfNetlinkWriter<'a>);
+}
+
+impl SerializeNfNetlink for NfNetlinkAttributes {
+ fn serialize<'a>(&self, writer: &mut NfNetlinkWriter<'a>) {
+ // TODO: improve performance by not sorting this
+ let mut keys: Vec<&NetlinkType> = self.attributes.keys().collect();
+ keys.sort();
+ for k in keys {
+ write_attribute(*k, self.attributes.get(k).unwrap(), writer);
+ }
+ }
+}
+
+macro_rules! impl_attribute {
+ ($enum_name:ident, $([$internal_name:ident, $type:ty]),+) => {
+ #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
+ pub enum $enum_name {
+ $(
+ $internal_name($type),
+ )+
+ }
+
+ impl NfNetlinkAttribute for $enum_name {
+ fn get_size(&self) -> usize {
+ match self {
+ $(
+ $enum_name::$internal_name(val) => val.get_size()
+ ),+
+ }
+ }
+
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ match self {
+ $(
+ $enum_name::$internal_name(val) => val.write_payload(addr)
+ ),+
+ }
+
+ }
+ }
+
+ impl $enum_name {
+ $(
+ #[allow(non_snake_case)]
+ pub fn $internal_name(&self) -> Option<&$type> {
+ match self {
+ $enum_name::$internal_name(val) => Some(val),
+ _ => None
+ }
+ }
+ )+
+ }
+ };
+}
+
+impl_attribute!(
+ Attribute,
+ [String, String],
+ [U8, u8],
+ [U16, u16],
+ [U32, u32],
+ [U64, u64],
+ [VecU8, Vec<u8>],
+ [ProtoFamily, ProtoFamily]
+);
+
+#[macro_export]
+macro_rules! impl_attr_getters_and_setters {
+ ($struct:ident, [$(($getter_name:ident, $setter_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => {
+ impl $struct {
+ $(
+ #[allow(dead_code)]
+ pub fn $getter_name(&self) -> Option<&$type> {
+ self.inner.get_attr($attr_name as $crate::parser::NetlinkType).map(|x| x.$internal_name()).flatten()
+ }
+
+ #[allow(dead_code)]
+ pub fn $setter_name(&mut self, val: $type) {
+ self.inner.set_attr($attr_name as $crate::parser::NetlinkType, $crate::parser::Attribute::$internal_name(val));
+ }
+ )+
+ }
+ };
+}
diff --git a/src/query.rs b/src/query.rs
index 1c81cdd..80fdc75 100644
--- a/src/query.rs
+++ b/src/query.rs
@@ -1,163 +1,14 @@
use std::mem::size_of;
-use crate::{nft_nlmsg_maxsize, sys, ProtoFamily};
+use crate::{
+ nlmsg::NfNetlinkWriter,
+ parser::{nft_nlmsg_maxsize, Nfgenmsg},
+ sys, ProtoFamily,
+};
use libc::{
nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES, NLMSG_DONE, NLMSG_ERROR,
NLMSG_MIN_TYPE, NLMSG_NOOP, NLM_F_DUMP_INTR,
};
-use sys::libc;
-
-#[repr(C)]
-#[derive(Debug, Clone, Copy)]
-pub struct Nfgenmsg {
- pub family: u8, /* AF_xxx */
- pub version: u8, /* nfnetlink version */
- pub res_id: u16, /* resource id */
-}
-
-#[derive(thiserror::Error, Debug)]
-pub enum ParseError {
- #[error("The buffer is too small to hold a valid message")]
- BufTooSmall,
-
- #[error("The message is too small")]
- NlMsgTooSmall,
-
- #[error("Invalid subsystem, expected NFTABLES")]
- InvalidSubsystem(u8),
-
- #[error("Invalid version, expected NFNETLINK_V0")]
- InvalidVersion(u8),
-
- #[error("Invalid port ID")]
- InvalidPortId(u32),
-
- #[error("Invalid sequence number")]
- InvalidSeq(u32),
-
- #[error("The generation number was bumped in the kernel while the operation was running, interrupting it")]
- ConcurrentGenerationUpdate,
-
- #[error("Unsupported message type")]
- UnsupportedType(u16),
-
- #[error("A custom error occured")]
- Custom(Box<dyn std::error::Error + 'static>),
-}
-
-pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 {
- ((x & 0xff00) >> 8) as u8
-}
-
-pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 {
- (x & 0x00ff) as u8
-}
-
-pub unsafe fn get_nlmsghdr(
- buf: &[u8],
- expected_seq: u32,
- expected_port_id: u32,
-) -> Result<&nlmsghdr, ParseError> {
- let size_of_hdr = size_of::<nlmsghdr>();
-
- if buf.len() < size_of_hdr {
- return Err(ParseError::BufTooSmall);
- }
-
- let nlmsghdr_ptr = buf[0..size_of_hdr].as_ptr() as *const nlmsghdr;
- let nlmsghdr = *nlmsghdr_ptr;
-
- if nlmsghdr.nlmsg_len as usize > buf.len() || (nlmsghdr.nlmsg_len as usize) < size_of_hdr {
- println!("a: {}, {}", buf.len(), nlmsghdr.nlmsg_len);
- return Err(ParseError::NlMsgTooSmall);
- }
-
- if nlmsghdr.nlmsg_pid != 0 && expected_port_id != 0 && nlmsghdr.nlmsg_pid != expected_port_id {
- return Err(ParseError::InvalidPortId(nlmsghdr.nlmsg_pid));
- }
-
- if nlmsghdr.nlmsg_seq != 0 && expected_seq != 0 && nlmsghdr.nlmsg_seq != expected_seq {
- return Err(ParseError::InvalidSeq(nlmsghdr.nlmsg_seq));
- }
-
- if nlmsghdr.nlmsg_flags & NLM_F_DUMP_INTR as u16 != 0 {
- return Err(ParseError::ConcurrentGenerationUpdate);
- }
-
- Ok(&*nlmsghdr_ptr as &nlmsghdr)
-}
-
-pub enum NlMsg<'a> {
- Done,
- Noop,
- Error(nlmsgerr),
- NfGenMsg(&'a Nfgenmsg, &'a [u8]),
-}
-
-pub unsafe fn parse_nlmsg<'a>(
- buf: &'a [u8],
- expected_seq: u32,
- expected_port_id: u32,
-) -> Result<(&'a nlmsghdr, NlMsg<'a>), ParseError> {
- // in theory the message is composed of the following parts:
- // - nlmsghdr (contains the message size and type)
- // - struct nlmsgerr OR nfgenmsg (nftables header that describes the message family)
- // - the raw value that we want to validate (if the previous part is nfgenmsg)
- let nlmsghdr = get_nlmsghdr(buf, expected_seq, expected_port_id)?;
-
- let size_of_hdr = size_of::<nlmsghdr>();
-
- if nlmsghdr.nlmsg_type < NLMSG_MIN_TYPE as u16 {
- match nlmsghdr.nlmsg_type as libc::c_int {
- NLMSG_NOOP => return Ok((nlmsghdr, NlMsg::Noop)),
- NLMSG_ERROR => {
- if nlmsghdr.nlmsg_len as usize > buf.len()
- || (nlmsghdr.nlmsg_len as usize) < size_of_hdr + size_of::<nlmsgerr>()
- {
- println!("b: {}, {}", buf.len(), nlmsghdr.nlmsg_len);
- return Err(ParseError::NlMsgTooSmall);
- }
- let mut err = *(buf[size_of_hdr..size_of_hdr + size_of::<nlmsgerr>()].as_ptr()
- as *const nlmsgerr);
- // some APIs return negative values, while other return positive values
- err.error = err.error.abs();
- return Ok((nlmsghdr, NlMsg::Error(err)));
- }
- NLMSG_DONE => return Ok((nlmsghdr, NlMsg::Done)),
- x => return Err(ParseError::UnsupportedType(x as u16)),
- }
- }
-
- let subsys = get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type);
- if subsys != NFNL_SUBSYS_NFTABLES as u8 {
- return Err(ParseError::InvalidSubsystem(subsys));
- }
-
- let size_of_nfgenmsg = size_of::<Nfgenmsg>();
- if nlmsghdr.nlmsg_len as usize > buf.len()
- || (nlmsghdr.nlmsg_len as usize) < size_of_hdr + size_of_nfgenmsg
- {
- println!("c: {}, {}", buf.len(), nlmsghdr.nlmsg_len);
- return Err(ParseError::NlMsgTooSmall);
- }
-
- let nfgenmsg_ptr = buf[size_of_hdr..size_of_hdr + size_of_nfgenmsg].as_ptr() as *const Nfgenmsg;
- let nfgenmsg = *nfgenmsg_ptr;
- let subsys = get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type);
- if subsys != NFNL_SUBSYS_NFTABLES as u8 {
- return Err(ParseError::InvalidSubsystem(subsys));
- }
- if nfgenmsg.version != NFNETLINK_V0 as u8 {
- return Err(ParseError::InvalidVersion(nfgenmsg.version));
- }
-
- let raw_value = &buf[size_of_hdr + size_of_nfgenmsg..nlmsghdr.nlmsg_len as usize];
-
- Ok((
- nlmsghdr,
- NlMsg::NfGenMsg(&*nfgenmsg_ptr as &Nfgenmsg, raw_value),
- ))
-}
/// Returns a buffer containing a netlink message which requests a list of all the netfilter
/// matching objects (e.g. tables, chains, rules, ...).
@@ -165,22 +16,23 @@ pub unsafe fn parse_nlmsg<'a>(
/// to execute on the header, to set parameters for example.
/// To pass arbitrary data inside that callback, please use a closure.
pub fn get_list_of_objects<Error>(
+ msg_type: u16,
seq: u32,
- target: u16,
setup_cb: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>,
) -> Result<Vec<u8>, Error> {
let mut buffer = vec![0; nft_nlmsg_maxsize() as usize];
- let hdr = unsafe {
- &mut *sys::nftnl_nlmsg_build_hdr(
- buffer.as_mut_ptr() as *mut libc::c_char,
- target,
- ProtoFamily::Unspec as u16,
- (libc::NLM_F_ROOT | libc::NLM_F_MATCH) as u16,
- seq,
- )
- };
+ let mut writer = &mut NfNetlinkWriter::new(&mut buffer);
+ writer.write_header(
+ msg_type,
+ ProtoFamily::Unspec,
+ (libc::NLM_F_ROOT | libc::NLM_F_MATCH) as u16,
+ seq,
+ None,
+ );
if let Some(cb) = setup_cb {
- cb(hdr)?;
+ cb(writer
+ .get_current_header()
+ .expect("Fatal error: mising header"))?;
}
Ok(buffer)
}
@@ -191,12 +43,12 @@ mod inner {
use nix::{
errno::Errno,
- sys::socket::{
- self, AddressFamily, MsgFlags, NetlinkAddr, SockAddr, SockFlag, SockProtocol, SockType,
- },
+ sys::socket::{self, AddressFamily, MsgFlags, SockFlag, SockProtocol, SockType},
};
- use crate::FinalizedBatch;
+ //use crate::FinalizedBatch;
+
+ use crate::nlmsg::{parse_nlmsg, DecodeError, NlMsg};
use super::*;
@@ -212,7 +64,7 @@ mod inner {
NetlinkRecvError(#[source] nix::Error),
#[error("Error while processing an incoming netlink message")]
- ProcessNetlinkError(#[from] ParseError),
+ ProcessNetlinkError(#[from] DecodeError),
#[error("Error received from the kernel")]
NetlinkError(nlmsgerr),
@@ -316,7 +168,7 @@ mod inner {
let seq = 0;
let portid = 0;
- let chains_buf = get_list_of_objects(seq, data_type, req_hdr_customize)?;
+ let chains_buf = get_list_of_objects(data_type, seq, req_hdr_customize)?;
socket::send(sock, &chains_buf, MsgFlags::empty()).map_err(Error::NetlinkSendError)?;
Ok(socket_close_wrapper(sock, move |sock| {
@@ -324,6 +176,7 @@ mod inner {
})?)
}
+ /*
pub fn send_batch(batch: &mut FinalizedBatch) -> Result<(), Error> {
let sock = socket::socket(
AddressFamily::Netlink,
@@ -357,6 +210,7 @@ mod inner {
recv_and_process(sock, &|_, _, _, _| Ok(()), &mut (), seq, portid)
})?)
}
+ */
}
#[cfg(feature = "query")]
diff --git a/src/rule.rs b/src/rule.rs
index 66beef8..80ca0c7 100644
--- a/src/rule.rs
+++ b/src/rule.rs
@@ -1,4 +1,5 @@
use crate::expr::ExpressionWrapper;
+use crate::nlmsg::NlMsg;
#[cfg(feature = "query")]
use crate::query::{Nfgenmsg, ParseError};
use crate::sys::{self, libc};
@@ -219,8 +220,8 @@ impl PartialEq for Rule {
}
}
-unsafe impl crate::NlMsg for Rule {
- unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) {
+unsafe impl NlMsg for Rule {
+ unsafe fn write(&self, buf: &mut Vec<u8>, seq: u32, msg_type: MsgType) {
let type_ = match msg_type {
MsgType::Add => libc::NFT_MSG_NEWRULE,
MsgType::Del => libc::NFT_MSG_DELRULE,
@@ -229,6 +230,7 @@ unsafe impl crate::NlMsg for Rule {
MsgType::Add => (libc::NLM_F_CREATE | libc::NLM_F_APPEND | libc::NLM_F_EXCL) as u16,
MsgType::Del => 0u16,
} | libc::NLM_F_ACK as u16;
+ /*
let header = sys::nftnl_nlmsg_build_hdr(
buf as *mut c_char,
type_ as u16,
@@ -237,6 +239,7 @@ unsafe impl crate::NlMsg for Rule {
seq,
);
sys::nftnl_rule_nlmsg_build_payload(header, self.rule);
+ */
}
}
diff --git a/src/set.rs b/src/set.rs
index b8c45ac..b153450 100644
--- a/src/set.rs
+++ b/src/set.rs
@@ -1,3 +1,4 @@
+use crate::nlmsg::NlMsg;
use crate::sys::{self, libc};
use crate::{table::Table, MsgType};
use std::{
@@ -145,12 +146,13 @@ impl<K> Debug for Set<K> {
}
}
-unsafe impl<K> crate::NlMsg for Set<K> {
- unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) {
+unsafe impl<K> NlMsg for Set<K> {
+ unsafe fn write(&self, buf: &mut Vec<u8>, seq: u32, msg_type: MsgType) {
let type_ = match msg_type {
MsgType::Add => libc::NFT_MSG_NEWSET,
MsgType::Del => libc::NFT_MSG_DELSET,
};
+ /*
let header = sys::nftnl_nlmsg_build_hdr(
buf as *mut c_char,
type_ as u16,
@@ -159,6 +161,7 @@ unsafe impl<K> crate::NlMsg for Set<K> {
seq,
);
sys::nftnl_set_nlmsg_build_payload(header, self.set);
+ */
}
}
@@ -217,7 +220,7 @@ pub struct SetElemsMsg<'a, K> {
ret: Rc<Cell<i32>>,
}
-unsafe impl<'a, K> crate::NlMsg for SetElemsMsg<'a, K> {
+unsafe impl<'a, K> NlMsg for SetElemsMsg<'a, K> {
unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) {
trace!("Writing SetElemsMsg to NlMsg");
let (type_, flags) = match msg_type {
diff --git a/src/table.rs b/src/table.rs
index 332cc99..6b34291 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -1,53 +1,42 @@
-#[cfg(feature = "query")]
-use crate::query::{Nfgenmsg, ParseError};
-use crate::sys::{self, libc};
-use crate::{MsgType, ProtoFamily};
-#[cfg(feature = "query")]
use std::convert::TryFrom;
-use std::{
- ffi::{c_void, CStr, CString},
- fmt::Debug,
- os::raw::c_char,
+use std::fmt::Debug;
+
+use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter};
+use crate::parser::{
+ expect_msgtype_in_nlmsg, parse_nlmsg, Attribute, DecodeError, NfNetlinkAttributeReader,
+ NfNetlinkAttributes, NlMsg, SerializeNfNetlink,
};
+use crate::sys::{self, NFTA_OBJ_TABLE, NFTA_TABLE_FLAGS, NFTA_TABLE_NAME};
+use crate::{impl_attr_getters_and_setters, MsgType, ProtoFamily};
+#[cfg(feature = "query")]
+use crate::{parser::Nfgenmsg, query::ParseError};
+use libc;
/// Abstraction of `nftnl_table`, the top level container in netfilter. A table has a protocol
/// family and contains [`Chain`]s that in turn hold the rules.
///
/// [`Chain`]: struct.Chain.html
+#[derive(Debug, PartialEq, Eq)]
pub struct Table {
- table: *mut sys::nftnl_table,
+ inner: NfNetlinkAttributes,
family: ProtoFamily,
}
impl Table {
/// Creates a new table instance with the given name and protocol family.
- pub fn new<T: AsRef<CStr>>(name: &T, family: ProtoFamily) -> Table {
- unsafe {
- let table = try_alloc!(sys::nftnl_table_alloc());
+ pub fn new<T: Into<String>>(name: T, family: ProtoFamily) -> Table {
+ let mut res = Table {
+ inner: NfNetlinkAttributes::new(),
+ family,
+ };
- sys::nftnl_table_set_u32(table, sys::NFTNL_TABLE_FAMILY as u16, family as u32);
- sys::nftnl_table_set_str(table, sys::NFTNL_TABLE_NAME as u16, name.as_ref().as_ptr());
- sys::nftnl_table_set_u32(table, sys::NFTNL_TABLE_FLAGS as u16, 0u32);
- Table { table, family }
- }
- }
+ res.set_name(name.into());
+ res.set_flags(0);
- pub unsafe fn from_raw(table: *mut sys::nftnl_table, family: ProtoFamily) -> Self {
- Table { table, family }
- }
-
- /// Returns the name of this table.
- pub fn get_name(&self) -> &CStr {
- unsafe {
- let ptr = sys::nftnl_table_get_str(self.table, sys::NFTNL_TABLE_NAME as u16);
- if ptr.is_null() {
- panic!("Impossible situation: retrieving the name of a table failed")
- } else {
- CStr::from_ptr(ptr)
- }
- }
+ res
}
+ /*
/// Returns a textual description of the table.
pub fn get_str(&self) -> CString {
let mut descr_buf = vec![0i8; 4096];
@@ -62,79 +51,71 @@ impl Table {
CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned()
}
}
-
- /// Returns the protocol family for this table.
- pub fn get_family(&self) -> ProtoFamily {
- self.family
- }
-
- /// Returns the userdata of this chain.
- pub fn get_userdata(&self) -> Option<&CStr> {
- unsafe {
- let ptr = sys::nftnl_table_get_str(self.table, sys::NFTNL_TABLE_USERDATA as u16);
- if !ptr.is_null() {
- Some(CStr::from_ptr(ptr))
- } else {
- None
- }
- }
- }
-
- /// Updates the userdata of this chain.
- pub fn set_userdata(&self, data: &CStr) {
- unsafe {
- sys::nftnl_table_set_str(self.table, sys::NFTNL_TABLE_USERDATA as u16, data.as_ptr());
- }
- }
-
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns the raw handle.
- pub fn as_ptr(&self) -> *const sys::nftnl_table {
- self.table as *const sys::nftnl_table
- }
-
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns a mutable version of the raw handle.
- pub fn as_mut_ptr(&self) -> *mut sys::nftnl_table {
- self.table
- }
+ */
}
-
+/*
impl PartialEq for Table {
fn eq(&self, other: &Self) -> bool {
- self.get_name() == other.get_name() && self.get_family() == other.get_family()
- }
-}
-
-impl Debug for Table {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{:?}", self.get_str())
+ self.get_name() == other.get_name() && self.family == other.family
}
}
+*/
-unsafe impl crate::NlMsg for Table {
- unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) {
+impl NfNetlinkObject for Table {
+ fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32) {
let raw_msg_type = match msg_type {
MsgType::Add => libc::NFT_MSG_NEWTABLE,
MsgType::Del => libc::NFT_MSG_DELTABLE,
- };
- let header = sys::nftnl_nlmsg_build_hdr(
- buf as *mut c_char,
- raw_msg_type as u16,
- self.family as u16,
- libc::NLM_F_ACK as u16,
- seq,
- );
- sys::nftnl_table_nlmsg_build_payload(header, self.table);
+ } as u16;
+ writer.write_header(raw_msg_type, self.family, libc::NLM_F_ACK as u16, seq, None);
+ self.inner.serialize(writer);
}
-}
-impl Drop for Table {
- fn drop(&mut self) {
- unsafe { sys::nftnl_table_free(self.table) };
+ fn decode_attribute(attr_type: u16, buf: &[u8]) -> Result<Attribute, DecodeError> {
+ match attr_type {
+ NFTA_TABLE_NAME => Ok(Attribute::String(String::from_utf8(buf.to_vec())?)),
+ NFTA_TABLE_FLAGS => {
+ let val = [buf[0], buf[1], buf[2], buf[3]];
+
+ Ok(Attribute::U32(u32::from_ne_bytes(val)))
+ }
+ NFTA_TABLE_USERDATA => Ok(Attribute::VecU8(buf.to_vec())),
+ _ => Err(DecodeError::UnsupportedAttributeType(attr_type)),
+ }
+ }
+
+ fn deserialize(buf: &[u8]) -> Result<(&[u8], Self), DecodeError> {
+ let (hdr, nfgenmsg, content, mut attrs) =
+ expect_msgtype_in_nlmsg(buf, libc::NFT_MSG_NEWTABLE as u8)?;
+
+ let (remaining_buf, inner) = attrs.decode::<Table>()?;
+
+ Ok((
+ remaining_buf,
+ Table {
+ inner,
+ family: ProtoFamily::try_from(nfgenmsg.family as i32)?,
+ },
+ ))
}
}
+impl_attr_getters_and_setters!(
+ Table,
+ [
+ (get_name, set_name, sys::NFTA_TABLE_NAME, String, String),
+ (
+ get_userdata,
+ set_userdata,
+ sys::NFTA_TABLE_USERDATA,
+ VecU8,
+ Vec<u8>
+ ),
+ (get_flags, set_flags, sys::NFTA_TABLE_FLAGS, U32, u32)
+ ]
+);
+
+/*
#[cfg(feature = "query")]
/// A callback to parse the response for messages created with `get_tables_nlmsg`.
pub fn get_tables_cb(
@@ -190,3 +171,4 @@ pub fn list_tables() -> Result<Vec<Table>, crate::query::Error> {
)?;
Ok(result)
}
+*/