aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--build.rs6
-rw-r--r--examples/add-rules.rs7
-rw-r--r--src/chain.rs286
-rw-r--r--src/lib.rs20
-rw-r--r--src/nlmsg.rs65
-rw-r--r--src/parser.rs193
-rw-r--r--src/table.rs47
-rw-r--r--tests/batch.rs1
-rw-r--r--tests/table.rs2
9 files changed, 412 insertions, 215 deletions
diff --git a/build.rs b/build.rs
index 1eef525..0d4903d 100644
--- a/build.rs
+++ b/build.rs
@@ -4,7 +4,6 @@ use bindgen;
use lazy_static::lazy_static;
use regex::{Captures, Regex};
use std::borrow::Cow;
-use std::env;
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
@@ -16,11 +15,6 @@ fn main() {
generate_sys();
}
-fn get_env(var: &'static str) -> Option<PathBuf> {
- println!("cargo:rerun-if-env-changed={}", var);
- env::var_os(var).map(PathBuf::from)
-}
-
/// `bindgen`erate a rust sys file from the C kernel headers of the nf_tables capabilities.
fn generate_sys() {
// Tell cargo to invalidate the built crate whenever the headers change.
diff --git a/examples/add-rules.rs b/examples/add-rules.rs
index 11e7b6f..812721c 100644
--- a/examples/add-rules.rs
+++ b/examples/add-rules.rs
@@ -37,7 +37,7 @@
//! ```
use ipnetwork::{IpNetwork, Ipv4Network};
-use rustables::{Batch, ProtoFamily, Table};
+use rustables::{table::list_tables, Batch, ProtoFamily, Table};
//use rustables::{nft_expr, query::send_batch, sys::libc, Batch, Chain, ProtoFamily, Rule, Table};
use std::{ffi::CString, io, net::Ipv4Addr, rc::Rc};
@@ -46,6 +46,7 @@ const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets";
const IN_CHAIN_NAME: &str = "chain-for-incoming-packets";
fn main() -> Result<(), Error> {
+ /*
// Create a batch. This is used to store all the netlink messages we will later send.
// Creating a new batch also automatically writes the initial batch begin message needed
// to tell netlink this is a single transaction that might arrive over multiple netlink packets.
@@ -175,6 +176,10 @@ fn main() -> Result<(), Error> {
// netfilter the we reached the end of the transaction message. It's also converted to a
// Vec<u8>, containing the raw netlink data so it can be sent over a netlink socket to netfilter.
Ok(batch.send()?)
+ */
+
+ println!("{:?}", list_tables());
+ Ok(())
}
// Look up the interface index for a given interface name.
diff --git a/src/chain.rs b/src/chain.rs
index a99d7f8..e29b239 100644
--- a/src/chain.rs
+++ b/src/chain.rs
@@ -1,9 +1,10 @@
-use crate::nlmsg::NlMsg;
-#[cfg(feature = "query")]
-use crate::query::{Nfgenmsg, ParseError};
-use crate::sys::{self as sys, libc};
-use crate::{MsgType, Table};
-#[cfg(feature = "query")]
+use crate::nlmsg::{
+ AttributeDecoder, NfNetlinkAttribute, NfNetlinkAttributes, NfNetlinkDeserializable,
+ NfNetlinkObject, NfNetlinkWriter,
+};
+use crate::parser::{DecodeError, NestedAttribute, NfNetlinkAttributeReader};
+use crate::sys::{self, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN, NLM_F_ACK};
+use crate::{impl_attr_getters_and_setters, MsgType, ProtoFamily, Table};
use std::convert::TryFrom;
use std::{
ffi::{c_void, CStr, CString},
@@ -15,24 +16,80 @@ use std::{
pub type Priority = i32;
/// The netfilter event hooks a chain can register for.
-#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
-#[repr(u16)]
-pub enum Hook {
+#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
+#[repr(u32)]
+pub enum HookClass {
/// Hook into the pre-routing stage of netfilter. Corresponds to `NF_INET_PRE_ROUTING`.
- PreRouting = libc::NF_INET_PRE_ROUTING as u16,
+ PreRouting = libc::NF_INET_PRE_ROUTING as u32,
/// Hook into the input stage of netfilter. Corresponds to `NF_INET_LOCAL_IN`.
- In = libc::NF_INET_LOCAL_IN as u16,
+ In = libc::NF_INET_LOCAL_IN as u32,
/// Hook into the forward stage of netfilter. Corresponds to `NF_INET_FORWARD`.
- Forward = libc::NF_INET_FORWARD as u16,
+ Forward = libc::NF_INET_FORWARD as u32,
/// Hook into the output stage of netfilter. Corresponds to `NF_INET_LOCAL_OUT`.
- Out = libc::NF_INET_LOCAL_OUT as u16,
+ Out = libc::NF_INET_LOCAL_OUT as u32,
/// Hook into the post-routing stage of netfilter. Corresponds to `NF_INET_POST_ROUTING`.
- PostRouting = libc::NF_INET_POST_ROUTING as u16,
+ PostRouting = libc::NF_INET_POST_ROUTING as u32,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct Hook {
+ inner: NestedAttribute,
+}
+
+impl Hook {
+ fn new(class: HookClass, priority: Priority) -> Self {
+ Hook {
+ inner: NestedAttribute::new(),
+ }
+ .with_hook_class(class as u32)
+ .with_hook_priority(priority as u32)
+ }
+}
+
+impl_attr_getters_and_setters!(
+ Hook,
+ [
+ // Define the action netfilter will apply to packets processed by this chain, but that did not match any rules in it.
+ (
+ get_hook_class,
+ set_hook_class,
+ with_hook_class,
+ sys::NFTA_HOOK_HOOKNUM,
+ U32,
+ u32
+ ),
+ (
+ get_hook_priority,
+ set_hook_priority,
+ with_hook_priority,
+ sys::NFTA_HOOK_PRIORITY,
+ U32,
+ u32
+ )
+ ]
+);
+
+impl NfNetlinkAttribute for Hook {
+ fn get_size(&self) -> usize {
+ self.inner.get_size()
+ }
+
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ self.inner.write_payload(addr)
+ }
+}
+
+impl NfNetlinkDeserializable for Hook {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ let reader = NfNetlinkAttributeReader::new(buf, buf.len())?;
+ let inner = reader.decode::<Self>()?;
+ Ok((Hook { inner }, &[]))
+ }
}
/// A chain policy. Decides what to do with a packet that was processed by the chain but did not
/// match any rules.
-#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
+#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
#[repr(u32)]
pub enum Policy {
/// Accept the packet.
@@ -73,37 +130,28 @@ impl ChainType {
/// [`Table`]: struct.Table.html
/// [`Rule`]: struct.Rule.html
/// [`set_hook`]: #method.set_hook
+#[derive(Debug, PartialEq, Eq)]
pub struct Chain {
- pub(crate) chain: *mut sys::nftnl_chain,
- pub(crate) table: Rc<Table>,
+ inner: NfNetlinkAttributes,
}
impl Chain {
- /// Creates a new chain instance inside the given [`Table`] and with the given name.
+ /// Creates a new chain instance inside the given [`Table`].
///
/// [`Table`]: struct.Table.html
- pub fn new<T: AsRef<CStr>>(name: &T, table: Rc<Table>) -> Chain {
- unsafe {
- let chain = try_alloc!(sys::nftnl_chain_alloc());
- sys::nftnl_chain_set_u32(
- chain,
- sys::NFTNL_CHAIN_FAMILY as u16,
- table.get_family() as u32,
- );
- sys::nftnl_chain_set_str(
- chain,
- sys::NFTNL_CHAIN_TABLE as u16,
- table.get_name().as_ptr(),
- );
- sys::nftnl_chain_set_str(chain, sys::NFTNL_CHAIN_NAME as u16, name.as_ref().as_ptr());
- Chain { chain, table }
+ pub fn new<T: AsRef<CStr>>(table: &Table) -> Chain {
+ let mut chain = Chain {
+ inner: NfNetlinkAttributes::new(),
+ };
+
+ if let Some(table_name) = table.get_name() {
+ chain.set_table(table_name);
}
- }
- pub unsafe fn from_raw(chain: *mut sys::nftnl_chain, table: Rc<Table>) -> Self {
- Chain { chain, table }
+ chain
}
+ /*
/// Sets the hook and priority for this chain. Without calling this method the chain will
/// become a "regular chain" without any hook and will thus not receive any traffic unless
/// some rule forward packets to it via goto or jump verdicts.
@@ -112,62 +160,12 @@ impl Chain {
/// hook and is thus a "base chain". A "base chain" is an entry point for packets from the
/// networking stack.
pub fn set_hook(&mut self, hook: Hook, priority: Priority) {
- unsafe {
- sys::nftnl_chain_set_u32(self.chain, sys::NFTNL_CHAIN_HOOKNUM as u16, hook as u32);
- sys::nftnl_chain_set_s32(self.chain, sys::NFTNL_CHAIN_PRIO as u16, priority);
- }
- }
-
- /// Set the type of a base chain. This only applies if the chain has been registered
- /// with a hook by calling `set_hook`.
- pub fn set_type(&mut self, chain_type: ChainType) {
- unsafe {
- sys::nftnl_chain_set_str(
- self.chain,
- sys::NFTNL_CHAIN_TYPE as u16,
- chain_type.as_c_str().as_ptr() as *const c_char,
- );
- }
- }
-
- /// Sets the default policy for this chain. That means what action netfilter will apply to
- /// packets processed by this chain, but that did not match any rules in it.
- pub fn set_policy(&mut self, policy: Policy) {
- unsafe {
- sys::nftnl_chain_set_u32(self.chain, sys::NFTNL_CHAIN_POLICY as u16, policy as u32);
- }
- }
-
- /// Returns the userdata of this chain.
- pub fn get_userdata(&self) -> Option<&CStr> {
- unsafe {
- let ptr = sys::nftnl_chain_get_str(self.chain, sys::NFTNL_CHAIN_USERDATA as u16);
- if ptr == std::ptr::null() {
- return None;
- }
- Some(CStr::from_ptr(ptr))
- }
- }
-
- /// Updates the userdata of this chain.
- pub fn set_userdata(&self, data: &CStr) {
- unsafe {
- sys::nftnl_chain_set_str(self.chain, sys::NFTNL_CHAIN_USERDATA as u16, data.as_ptr());
- }
- }
-
- /// Returns the name of this chain.
- pub fn get_name(&self) -> &CStr {
- unsafe {
- let ptr = sys::nftnl_chain_get_str(self.chain, sys::NFTNL_CHAIN_NAME as u16);
- if ptr.is_null() {
- panic!("Impossible situation: retrieving the name of a chain failed")
- } else {
- CStr::from_ptr(ptr)
- }
- }
+ self.set_hook_type(hook);
+ self.set_hook_priority(priority);
}
+ */
+ /*
/// Returns a textual description of the chain.
pub fn get_str(&self) -> CString {
let mut descr_buf = vec![0i8; 4096];
@@ -182,27 +180,10 @@ impl Chain {
CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned()
}
}
-
- /// Returns a reference to the [`Table`] this chain belongs to.
- ///
- /// [`Table`]: struct.Table.html
- pub fn get_table(&self) -> Rc<Table> {
- self.table.clone()
- }
-
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns the raw handle.
- pub fn as_ptr(&self) -> *const sys::nftnl_chain {
- self.chain as *const sys::nftnl_chain
- }
-
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns a mutable version of the raw handle.
- pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_chain {
- self.chain
- }
+ */
}
+/*
impl fmt::Debug for Chain {
/// Returns a string representation of the chain.
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
@@ -215,7 +196,64 @@ impl PartialEq for Chain {
self.get_table() == other.get_table() && self.get_name() == other.get_name()
}
}
+*/
+
+/*
+impl NfNetlinkObject for Chain {
+ fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32) {
+ let raw_msg_type = match msg_type {
+ MsgType::Add => NFT_MSG_NEWCHAIN,
+ MsgType::Del => NFT_MSG_DELCHAIN,
+ } as u16;
+ writer.write_header(
+ raw_msg_type,
+ ProtoFamily::Unspec,
+ NLM_F_ACK as u16,
+ seq,
+ None,
+ );
+ self.inner.serialize(writer);
+ writer.finalize_writing_object();
+ }
+
+ fn decode_attribute(attr_type: u16, buf: &[u8]) -> Result<AttributeType, DecodeError> {
+ match attr_type {
+ NFTA_TABLE_NAME => Ok(AttributeType::String(String::from_utf8(buf.to_vec())?)),
+ NFTA_TABLE_FLAGS => {
+ let val = [buf[0], buf[1], buf[2], buf[3]];
+
+ Ok(AttributeType::U32(u32::from_ne_bytes(val)))
+ }
+ NFTA_TABLE_USERDATA => Ok(AttributeType::VecU8(buf.to_vec())),
+ _ => Err(DecodeError::UnsupportedAttributeType(attr_type)),
+ }
+ }
+
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ let (hdr, msg) = parse_nlmsg(buf)?;
+
+ let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as u32;
+
+ if op != NFT_MSG_NEWTABLE && op != NFT_MSG_DELTABLE {
+ return Err(DecodeError::UnexpectedType(hdr.nlmsg_type));
+ }
+
+ let (nfgenmsg, attrs, remaining_data) = parse_object(hdr, msg, buf)?;
+
+ let inner = attrs.decode::<Table>()?;
+
+ Ok((
+ Table {
+ inner,
+ family: ProtoFamily::try_from(nfgenmsg.family as i32)?,
+ },
+ remaining_data,
+ ))
+ }
+}
+*/
+/*
unsafe impl NlMsg for Chain {
unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) {
let raw_msg_type = match msg_type {
@@ -243,7 +281,6 @@ impl Drop for Chain {
}
}
-#[cfg(feature = "query")]
pub fn get_chains_cb<'a>(
header: &libc::nlmsghdr,
_genmsg: &Nfgenmsg,
@@ -302,15 +339,38 @@ pub fn get_chains_cb<'a>(
Ok(())
}
+*/
+
+impl_attr_getters_and_setters!(
+ Chain,
+ [
+ (get_flags, set_flags, with_flags, sys::NFTA_CHAIN_FLAGS, U32, u32),
+ (get_name, set_name, with_name, sys::NFTA_CHAIN_NAME, String, String),
+ (set_hook, get_hook, with_hook, sys::NFTA_CHAIN_HOOK, ChainHook, Hook),
+ (get_policy, set_policy, with_policy, sys::NFTA_CHAIN_POLICY, U32, u32),
+ (get_table, set_table, with_table, sys::NFTA_CHAIN_TABLE, String, String),
+ // This only applies if the chain has been registered with a hook by calling `set_hook`.
+ (get_type, set_type, with_type, sys::NFTA_CHAIN_TYPE, String, String),
+ (
+ get_userdata,
+ set_userdata,
+ with_userdata,
+ sys::NFTA_CHAIN_USERDATA,
+ VecU8,
+ Vec<u8>
+ )
+ ]
+);
-#[cfg(feature = "query")]
-pub fn list_chains_for_table(table: Rc<Table>) -> Result<Vec<Chain>, crate::query::Error> {
+/*
+pub fn list_chains_for_table(table: &Table) -> Result<Vec<Chain>, crate::query::Error> {
let mut result = Vec::new();
crate::query::list_objects_with_data(
libc::NFT_MSG_GETCHAIN as u16,
&get_chains_cb,
- &mut (&table, &mut result),
None,
+ &mut (&table, &mut result),
)?;
Ok(result)
}
+*/
diff --git a/src/lib.rs b/src/lib.rs
index 6208bb5..7c330bb 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -100,7 +100,7 @@ pub mod table;
pub use table::Table;
//pub use table::{get_tables_cb, list_tables};
//
-//mod chain;
+mod chain;
//pub use chain::{get_chains_cb, list_chains_for_table};
//pub use chain::{Chain, ChainType, Hook, Policy, Priority};
@@ -141,17 +141,17 @@ pub enum MsgType {
/// Denotes a protocol. Used to specify which protocol a table or set belongs to.
#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
-#[repr(u16)]
+#[repr(u32)]
pub enum ProtoFamily {
- Unspec = libc::NFPROTO_UNSPEC as u16,
+ Unspec = libc::NFPROTO_UNSPEC as u32,
/// Inet - Means both IPv4 and IPv6
- Inet = libc::NFPROTO_INET as u16,
- Ipv4 = libc::NFPROTO_IPV4 as u16,
- Arp = libc::NFPROTO_ARP as u16,
- NetDev = libc::NFPROTO_NETDEV as u16,
- Bridge = libc::NFPROTO_BRIDGE as u16,
- Ipv6 = libc::NFPROTO_IPV6 as u16,
- DecNet = libc::NFPROTO_DECNET as u16,
+ Inet = libc::NFPROTO_INET as u32,
+ Ipv4 = libc::NFPROTO_IPV4 as u32,
+ Arp = libc::NFPROTO_ARP as u32,
+ NetDev = libc::NFPROTO_NETDEV as u32,
+ Bridge = libc::NFPROTO_BRIDGE as u32,
+ Ipv6 = libc::NFPROTO_IPV6 as u32,
+ DecNet = libc::NFPROTO_DECNET as u32,
}
#[derive(Error, Debug)]
diff --git a/src/nlmsg.rs b/src/nlmsg.rs
index 868560a..435fed3 100644
--- a/src/nlmsg.rs
+++ b/src/nlmsg.rs
@@ -1,15 +1,18 @@
-use std::{collections::HashMap, fmt::Debug, marker::PhantomData, mem::size_of, ops::Deref};
-
-use libc::{
- nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END,
- NFNL_SUBSYS_NFTABLES, NLMSG_MIN_TYPE, NLM_F_DUMP_INTR,
+use std::{
+ collections::HashMap,
+ fmt::Debug,
+ marker::PhantomData,
+ mem::{size_of, transmute},
};
-use thiserror::Error;
use crate::{
parser::{
- pad_netlink_object, pad_netlink_object_with_variable_size, Attribute, DecodeError,
- NfNetlinkAttributes, Nfgenmsg,
+ pad_netlink_object, pad_netlink_object_with_variable_size, AttributeType, DecodeError,
+ Nfgenmsg,
+ },
+ sys::{
+ nlattr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END,
+ NFNL_SUBSYS_NFTABLES, NLA_TYPE_MASK,
},
MsgType, ProtoFamily,
};
@@ -123,10 +126,50 @@ impl<'a> HeaderStack<'a> {
}
}
-pub trait NfNetlinkObject: Sized {
+pub trait AttributeDecoder {
+ fn decode_attribute(attr_type: u16, buf: &[u8]) -> Result<AttributeType, DecodeError>;
+}
+
+pub trait NfNetlinkDeserializable: Sized {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError>;
+}
+
+pub trait NfNetlinkObject: Sized + AttributeDecoder + NfNetlinkDeserializable {
fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32);
+}
+
+pub trait NfNetlinkSerializable {
+ fn serialize<'a>(&self, writer: &mut NfNetlinkWriter<'a>);
+}
- fn decode_attribute(attr_type: u16, buf: &[u8]) -> Result<Attribute, DecodeError>;
+pub type NetlinkType = u16;
- fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError>;
+pub trait NfNetlinkAttribute: Debug + Sized {
+ fn get_size(&self) -> usize {
+ size_of::<Self>()
+ }
+
+ // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr, self.get_size());
+ unsafe fn write_payload(&self, addr: *mut u8);
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct NfNetlinkAttributes {
+ pub attributes: HashMap<NetlinkType, AttributeType>,
+}
+
+impl NfNetlinkAttributes {
+ pub fn new() -> Self {
+ NfNetlinkAttributes {
+ attributes: HashMap::new(),
+ }
+ }
+
+ pub fn set_attr(&mut self, ty: NetlinkType, obj: AttributeType) {
+ self.attributes.insert(ty, obj);
+ }
+
+ pub fn get_attr(&self, ty: NetlinkType) -> Option<&AttributeType> {
+ self.attributes.get(&ty)
+ }
}
diff --git a/src/parser.rs b/src/parser.rs
index a01c3cd..8e12c5e 100644
--- a/src/parser.rs
+++ b/src/parser.rs
@@ -1,4 +1,5 @@
use std::{
+ any::TypeId,
collections::HashMap,
fmt::Debug,
mem::{size_of, transmute},
@@ -8,7 +9,10 @@ use std::{
use thiserror::Error;
use crate::{
- nlmsg::{NfNetlinkObject, NfNetlinkWriter},
+ nlmsg::{
+ AttributeDecoder, NetlinkType, NfNetlinkAttribute, NfNetlinkAttributes,
+ NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkSerializable, NfNetlinkWriter,
+ },
sys::{
nlattr, nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END,
NFNL_SUBSYS_NFTABLES, NLA_TYPE_MASK, NLMSG_ALIGNTO, NLMSG_DONE, NLMSG_ERROR,
@@ -25,6 +29,9 @@ pub enum DecodeError {
#[error("The message is too small")]
NlMsgTooSmall,
+ #[error("The message holds unexpected data")]
+ InvalidDataSize,
+
#[error("Invalid subsystem, expected NFTABLES")]
InvalidSubsystem(u8),
@@ -71,13 +78,13 @@ pub fn nft_nlmsg_maxsize() -> u32 {
}
#[inline]
-pub fn pad_netlink_object_with_variable_size(size: usize) -> usize {
+pub const fn pad_netlink_object_with_variable_size(size: usize) -> usize {
// align on a 4 bytes boundary
(size + (NLMSG_ALIGNTO as usize - 1)) & !(NLMSG_ALIGNTO as usize - 1)
}
#[inline]
-pub fn pad_netlink_object<T>() -> usize {
+pub const fn pad_netlink_object<T>() -> usize {
let size = size_of::<T>();
pad_netlink_object_with_variable_size(size)
}
@@ -185,20 +192,9 @@ pub fn parse_nlmsg<'a>(buf: &'a [u8]) -> Result<(nlmsghdr, NlMsg<'a>), DecodeErr
Ok((hdr, NlMsg::NfGenMsg(nfgenmsg, raw_value)))
}
-pub type NetlinkType = u16;
-
-pub trait NfNetlinkAttribute: Debug + Sized {
- fn get_size(&self) -> usize {
- size_of::<Self>()
- }
-
- // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr, self.get_size());
- unsafe fn write_payload(&self, addr: *mut u8);
-}
-
/// Write the attribute, preceded by a `libc::nlattr`
// rewrite of `mnl_attr_put`
-fn write_attribute<'a>(ty: NetlinkType, obj: &Attribute, writer: &mut NfNetlinkWriter<'a>) {
+fn write_attribute<'a>(ty: NetlinkType, obj: &AttributeType, writer: &mut NfNetlinkWriter<'a>) {
// copy the header
let header_len = pad_netlink_object::<libc::nlattr>();
let header = libc::nlattr {
@@ -223,15 +219,15 @@ fn write_attribute<'a>(ty: NetlinkType, obj: &Attribute, writer: &mut NfNetlinkW
}
}
-impl NfNetlinkAttribute for ProtoFamily {
+impl NfNetlinkAttribute for u8 {
unsafe fn write_payload(&self, addr: *mut u8) {
- *(addr as *mut u32) = *self as u32;
+ *addr = *self;
}
}
-impl NfNetlinkAttribute for u8 {
- unsafe fn write_payload(&self, addr: *mut u8) {
- *addr = *self;
+impl NfNetlinkDeserializable for u8 {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((buf[0], &buf[1..]))
}
}
@@ -241,18 +237,60 @@ impl NfNetlinkAttribute for u16 {
}
}
+impl NfNetlinkDeserializable for u16 {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((u16::from_be_bytes([buf[0], buf[1]]), &buf[2..]))
+ }
+}
+
+impl NfNetlinkAttribute for i32 {
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ *(addr as *mut Self) = *self;
+ }
+}
+
+impl NfNetlinkDeserializable for i32 {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((
+ i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]),
+ &buf[4..],
+ ))
+ }
+}
+
impl NfNetlinkAttribute for u32 {
unsafe fn write_payload(&self, addr: *mut u8) {
*(addr as *mut Self) = *self;
}
}
+impl NfNetlinkDeserializable for u32 {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((
+ u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]),
+ &buf[4..],
+ ))
+ }
+}
+
impl NfNetlinkAttribute for u64 {
unsafe fn write_payload(&self, addr: *mut u8) {
*(addr as *mut Self) = *self;
}
}
+impl NfNetlinkDeserializable for u64 {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((
+ u64::from_be_bytes([
+ buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7],
+ ]),
+ &buf[8..],
+ ))
+ }
+}
+
+// TODO: safe handling for null-delimited strings
impl NfNetlinkAttribute for String {
fn get_size(&self) -> usize {
self.len()
@@ -263,6 +301,12 @@ impl NfNetlinkAttribute for String {
}
}
+impl NfNetlinkDeserializable for String {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((String::from_utf8(buf.to_vec())?, &[]))
+ }
+}
+
impl NfNetlinkAttribute for Vec<u8> {
fn get_size(&self) -> usize {
self.len()
@@ -273,24 +317,38 @@ impl NfNetlinkAttribute for Vec<u8> {
}
}
-#[derive(Debug, PartialEq, Eq)]
-pub struct NfNetlinkAttributes {
- attributes: HashMap<NetlinkType, Attribute>,
+impl NfNetlinkDeserializable for Vec<u8> {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ Ok((buf.to_vec(), &[]))
+ }
}
-impl NfNetlinkAttributes {
- pub fn new() -> Self {
- NfNetlinkAttributes {
- attributes: HashMap::new(),
+pub type NestedAttribute = NfNetlinkAttributes;
+
+// parts of the NfNetlinkAttribute trait we need for handling nested objects
+impl NestedAttribute {
+ pub fn get_size(&self) -> usize {
+ let mut size = 0;
+
+ for (_type, attr) in self.attributes.iter() {
+ // Attribute header + attribute value
+ size += pad_netlink_object::<nlattr>()
+ + pad_netlink_object_with_variable_size(attr.get_size());
}
- }
- pub fn set_attr(&mut self, ty: NetlinkType, obj: Attribute) {
- self.attributes.insert(ty, obj);
+ size
}
- pub fn get_attr(&self, ty: NetlinkType) -> Option<&Attribute> {
- self.attributes.get(&ty)
+ pub unsafe fn write_payload(&self, mut addr: *mut u8) {
+ for (ty, attr) in self.attributes.iter() {
+ *(addr as *mut nlattr) = nlattr {
+ nla_len: attr.get_size() as u16,
+ nla_type: *ty,
+ };
+ addr = addr.offset(pad_netlink_object::<nlattr>() as isize);
+ attr.write_payload(addr);
+ addr = addr.offset(pad_netlink_object_with_variable_size(attr.get_size()) as isize);
+ }
}
}
@@ -319,7 +377,9 @@ impl<'a> NfNetlinkAttributeReader<'a> {
&self.buf[self.pos..]
}
- pub fn decode<T: NfNetlinkObject>(mut self) -> Result<NfNetlinkAttributes, DecodeError> {
+ pub fn decode<T: AttributeDecoder + 'static>(
+ mut self,
+ ) -> Result<NfNetlinkAttributes, DecodeError> {
while self.remaining_size > pad_netlink_object::<nlattr>() {
let nlattr =
unsafe { *transmute::<*const u8, *const nlattr>(self.buf[self.pos..].as_ptr()) };
@@ -328,19 +388,28 @@ impl<'a> NfNetlinkAttributeReader<'a> {
self.pos += pad_netlink_object::<nlattr>();
let attr_remaining_size = nlattr.nla_len as usize - pad_netlink_object::<nlattr>();
- self.attrs.set_attr(
+ match T::decode_attribute(
nla_type,
- T::decode_attribute(
- nla_type,
- &self.buf[self.pos..self.pos + attr_remaining_size],
- )?,
- );
+ &self.buf[self.pos..self.pos + attr_remaining_size],
+ ) {
+ Ok(x) => self.attrs.set_attr(nla_type, x),
+ Err(DecodeError::UnsupportedAttributeType(t)) => info!(
+ "Ignore attribute type {} for type id {:?}",
+ t,
+ TypeId::of::<T>()
+ ),
+ Err(e) => return Err(e),
+ }
self.pos += pad_netlink_object_with_variable_size(attr_remaining_size);
self.remaining_size -= pad_netlink_object_with_variable_size(nlattr.nla_len as usize);
}
- Ok(self.attrs)
+ if self.remaining_size != 0 {
+ Err(DecodeError::InvalidDataSize)
+ } else {
+ Ok(self.attrs)
+ }
}
}
@@ -364,11 +433,7 @@ pub fn parse_object<'a>(
}
}
-pub trait SerializeNfNetlink {
- fn serialize<'a>(&self, writer: &mut NfNetlinkWriter<'a>);
-}
-
-impl SerializeNfNetlink for NfNetlinkAttributes {
+impl NfNetlinkSerializable for NfNetlinkAttributes {
fn serialize<'a>(&self, writer: &mut NfNetlinkWriter<'a>) {
// TODO: improve performance by not sorting this
let mut keys: Vec<&NetlinkType> = self.attributes.keys().collect();
@@ -379,9 +444,9 @@ impl SerializeNfNetlink for NfNetlinkAttributes {
}
}
-macro_rules! impl_attribute {
+macro_rules! impl_attribute_holder {
($enum_name:ident, $([$internal_name:ident, $type:ty]),+) => {
- #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
+ #[derive(Debug, Clone, PartialEq, Eq)]
pub enum $enum_name {
$(
$internal_name($type),
@@ -403,7 +468,6 @@ macro_rules! impl_attribute {
$enum_name::$internal_name(val) => val.write_payload(addr)
),+
}
-
}
}
@@ -421,15 +485,16 @@ macro_rules! impl_attribute {
};
}
-impl_attribute!(
- Attribute,
+impl_attribute_holder!(
+ AttributeType,
[String, String],
[U8, u8],
[U16, u16],
+ [I32, i32],
[U32, u32],
[U64, u64],
[VecU8, Vec<u8>],
- [ProtoFamily, ProtoFamily]
+ [ChainHook, crate::chain::Hook]
);
#[macro_export]
@@ -439,20 +504,40 @@ macro_rules! impl_attr_getters_and_setters {
$(
#[allow(dead_code)]
pub fn $getter_name(&self) -> Option<&$type> {
- self.inner.get_attr($attr_name as $crate::parser::NetlinkType).map(|x| x.$internal_name()).flatten()
+ self.inner.get_attr($attr_name as $crate::nlmsg::NetlinkType).map(|x| x.$internal_name()).flatten()
}
#[allow(dead_code)]
pub fn $setter_name(&mut self, val: impl Into<$type>) {
- self.inner.set_attr($attr_name as $crate::parser::NetlinkType, $crate::parser::Attribute::$internal_name(val.into()));
+ self.inner.set_attr($attr_name as $crate::nlmsg::NetlinkType, $crate::parser::AttributeType::$internal_name(val.into()));
}
#[allow(dead_code)]
pub fn $in_place_edit_name(mut self, val: impl Into<$type>) -> Self {
- self.inner.set_attr($attr_name as $crate::parser::NetlinkType, $crate::parser::Attribute::$internal_name(val.into()));
+ self.inner.set_attr($attr_name as $crate::nlmsg::NetlinkType, $crate::parser::AttributeType::$internal_name(val.into()));
self
}
+
)+
}
+
+ impl $crate::nlmsg::AttributeDecoder for $struct {
+ #[allow(dead_code)]
+ fn decode_attribute(attr_type: u16, buf: &[u8]) -> Result<$crate::parser::AttributeType, $crate::parser::DecodeError> {
+ use $crate::nlmsg::NfNetlinkDeserializable;
+ match attr_type {
+ $(
+ x if x == $attr_name => {
+ let (val, remaining) = <$type>::deserialize(buf)?;
+ if remaining.len() != 0 {
+ return Err($crate::parser::DecodeError::InvalidDataSize);
+ }
+ Ok($crate::parser::AttributeType::$internal_name(val))
+ },
+ )+
+ _ => Err($crate::parser::DecodeError::UnsupportedAttributeType(attr_type)),
+ }
+ }
+ }
};
}
diff --git a/src/table.rs b/src/table.rs
index 23495a4..66dc667 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -1,10 +1,13 @@
use std::convert::TryFrom;
use std::fmt::Debug;
-use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter};
+use crate::nlmsg::{
+ AttributeDecoder, NfNetlinkAttributes, NfNetlinkDeserializable, NfNetlinkObject,
+ NfNetlinkSerializable, NfNetlinkWriter,
+};
use crate::parser::{
- get_operation_from_nlmsghdr_type, parse_nlmsg, parse_object, Attribute, DecodeError,
- NfNetlinkAttributeReader, NfNetlinkAttributes, Nfgenmsg, NlMsg, SerializeNfNetlink,
+ get_operation_from_nlmsghdr_type, parse_nlmsg, parse_object, DecodeError,
+ NfNetlinkAttributeReader,
};
use crate::sys::{
self, NFTA_OBJ_TABLE, NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE,
@@ -12,7 +15,7 @@ use crate::sys::{
};
use crate::{impl_attr_getters_and_setters, MsgType, ProtoFamily};
-/// Abstraction of `nftnl_table`, the top level container in netfilter. A table has a protocol
+/// Abstraction of a `nftnl_table`, the top level container in netfilter. A table has a protocol
/// family and contains [`Chain`]s that in turn hold the rules.
///
/// [`Chain`]: struct.Chain.html
@@ -65,20 +68,9 @@ impl NfNetlinkObject for Table {
self.inner.serialize(writer);
writer.finalize_writing_object();
}
+}
- fn decode_attribute(attr_type: u16, buf: &[u8]) -> Result<Attribute, DecodeError> {
- match attr_type {
- NFTA_TABLE_NAME => Ok(Attribute::String(String::from_utf8(buf.to_vec())?)),
- NFTA_TABLE_FLAGS => {
- let val = [buf[0], buf[1], buf[2], buf[3]];
-
- Ok(Attribute::U32(u32::from_ne_bytes(val)))
- }
- NFTA_TABLE_USERDATA => Ok(Attribute::VecU8(buf.to_vec())),
- _ => Err(DecodeError::UnsupportedAttributeType(attr_type)),
- }
- }
-
+impl NfNetlinkDeserializable for Table {
fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
let (hdr, msg) = parse_nlmsg(buf)?;
@@ -102,9 +94,27 @@ impl NfNetlinkObject for Table {
}
}
+/*
+impl AttributeDecoder for Table {
+ fn decode_attribute(attr_type: u16, buf: &[u8]) -> Result<AttributeType, DecodeError> {
+ match attr_type {
+ NFTA_TABLE_NAME => Ok(AttributeType::String(String::from_utf8(buf.to_vec())?)),
+ NFTA_TABLE_FLAGS => {
+ let val = [buf[0], buf[1], buf[2], buf[3]];
+
+ Ok(AttributeType::U32(u32::from_ne_bytes(val)))
+ }
+ NFTA_TABLE_USERDATA => Ok(AttributeType::VecU8(buf.to_vec())),
+ _ => Err(DecodeError::UnsupportedAttributeType(attr_type)),
+ }
+ }
+}
+*/
+
impl_attr_getters_and_setters!(
Table,
[
+ (get_flags, set_flags, with_flags, sys::NFTA_TABLE_FLAGS, U32, u32),
(get_name, set_name, with_name, sys::NFTA_TABLE_NAME, String, String),
(
get_userdata,
@@ -113,8 +123,7 @@ impl_attr_getters_and_setters!(
sys::NFTA_TABLE_USERDATA,
VecU8,
Vec<u8>
- ),
- (get_flags, set_flags, with_flags, sys::NFTA_TABLE_FLAGS, U32, u32)
+ )
]
);
diff --git a/tests/batch.rs b/tests/batch.rs
index 081ee97..740fc19 100644
--- a/tests/batch.rs
+++ b/tests/batch.rs
@@ -1,6 +1,7 @@
mod sys;
use libc::NFNL_MSG_BATCH_BEGIN;
use nix::libc::NFNL_MSG_BATCH_END;
+use rustables::nlmsg::NfNetlinkDeserializable;
use rustables::nlmsg::NfNetlinkObject;
use rustables::parser::{get_operation_from_nlmsghdr_type, parse_nlmsg, parse_object};
use rustables::{Batch, MsgType, Table};
diff --git a/tests/table.rs b/tests/table.rs
index d8a5f1e..5961d65 100644
--- a/tests/table.rs
+++ b/tests/table.rs
@@ -1,6 +1,6 @@
mod sys;
use rustables::{
- nlmsg::NfNetlinkObject,
+ nlmsg::NfNetlinkDeserializable,
parser::{get_operation_from_nlmsghdr_type, nft_nlmsg_maxsize},
MsgType, Table,
};