aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/add-rules.rs62
-rw-r--r--src/batch.rs6
-rw-r--r--src/chain.rs76
-rw-r--r--src/lib.rs49
-rw-r--r--src/nlmsg.rs33
-rw-r--r--src/parser.rs107
-rw-r--r--src/query.rs10
-rw-r--r--src/table.rs19
-rw-r--r--tests/lib.rs21
9 files changed, 215 insertions, 168 deletions
diff --git a/examples/add-rules.rs b/examples/add-rules.rs
index 3fd1f49..75fc63e 100644
--- a/examples/add-rules.rs
+++ b/examples/add-rules.rs
@@ -37,8 +37,11 @@
//! ```
use ipnetwork::{IpNetwork, Ipv4Network};
-use rustables::{list_chains_for_table, list_tables, Batch, ProtoFamily, Table};
-//use rustables::{nft_expr, query::send_batch, sys::libc, Batch, Chain, ProtoFamily, Rule, Table};
+use rustables::{
+ chain::HookClass, list_chains_for_table, list_tables, Batch, Chain, ChainPolicy, Hook, MsgType,
+ ProtocolFamily, Table,
+};
+//use rustables::{nft_expr, query::send_batch, sys::libc, Batch, Chain, Rule, Table};
use std::{ffi::CString, io, net::Ipv4Addr, rc::Rc};
const TABLE_NAME: &str = "example-table";
@@ -46,44 +49,35 @@ const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets";
const IN_CHAIN_NAME: &str = "chain-for-incoming-packets";
fn main() -> Result<(), Error> {
- /*
// Create a batch. This is used to store all the netlink messages we will later send.
// Creating a new batch also automatically writes the initial batch begin message needed
// to tell netlink this is a single transaction that might arrive over multiple netlink packets.
let mut batch = Batch::new();
// Create a netfilter table operating on both IPv4 and IPv6 (ProtoFamily::Inet)
- let table = Table::new(TABLE_NAME, ProtoFamily::Inet);
+ let table = Table::new(ProtocolFamily::Inet).with_name(TABLE_NAME);
// Add the table to the batch with the `MsgType::Add` type, thus instructing netfilter to add
- // this table under its `ProtoFamily::Inet` ruleset.
- batch.add(&table, rustables::MsgType::Add);
+ // this table under its `ProtocolFamily::Inet` ruleset.
+ batch.add(&table, MsgType::Add);
- let table = Table::new("lool", ProtoFamily::Inet);
+ // Create input and output chains under the table we created above.
+ // Hook the chains to the input and output event hooks, with highest priority (priority zero).
+ let mut out_chain = Chain::new(&table).with_name(OUT_CHAIN_NAME);
+ let mut in_chain = Chain::new(&table).with_name(IN_CHAIN_NAME);
- batch.add(&table, rustables::MsgType::Add);
+ out_chain.set_hook(Hook::new(HookClass::Out, 0));
+ in_chain.set_hook(Hook::new(HookClass::In, 0));
+
+ // Set the default policies on the chains. If no rule matches a packet processed by the
+ // `out_chain` or the `in_chain` it will accept the packet.
+ out_chain.set_policy(ChainPolicy::Accept);
+ in_chain.set_policy(ChainPolicy::Accept);
+
+ // Add the two chains to the batch with the `MsgType` to tell netfilter to create the chains
+ // under the table.
+ batch.add(&out_chain, MsgType::Add);
+ batch.add(&in_chain, MsgType::Add);
- // // Create input and output chains under the table we created above.
- // // Hook the chains to the input and output event hooks, with highest priority (priority zero).
- // // See the `Chain::set_hook` documentation for details.
- // let mut out_chain = Chain::new(OUT_CHAIN_NAME, Rc::clone(&table));
- // let mut in_chain = Chain::new(IN_CHAIN_NAME, Rc::clone(&table));
- //
- // out_chain.set_hook(rustables::Hook::Out, 0);
- // in_chain.set_hook(rustables::Hook::In, 0);
- //
- // // Set the default policies on the chains. If no rule matches a packet processed by the
- // // `out_chain` or the `in_chain` it will accept the packet.
- // out_chain.set_policy(rustables::Policy::Accept);
- // in_chain.set_policy(rustables::Policy::Accept);
- //
- // let out_chain = Rc::new(out_chain);
- // let in_chain = Rc::new(in_chain);
- //
- // // Add the two chains to the batch with the `MsgType` to tell netfilter to create the chains
- // // under the table.
- // batch.add(&Rc::clone(&out_chain), rustables::MsgType::Add);
- // batch.add(&Rc::clone(&in_chain), rustables::MsgType::Add);
- //
// // === ADD RULE ALLOWING ALL TRAFFIC TO THE LOOPBACK DEVICE ===
//
// // Create a new rule object under the input chain.
@@ -175,14 +169,8 @@ fn main() -> Result<(), Error> {
// Finalize the batch and send it. This means the batch end message is written into the batch, telling
// netfilter the we reached the end of the transaction message. It's also converted to a
// Vec<u8>, containing the raw netlink data so it can be sent over a netlink socket to netfilter.
+ // Finally, the batch is sent over to the kernel.
Ok(batch.send()?)
- */
-
- env_logger::init();
- let tables = list_tables()?;
- println!("{:?}", tables);
- println!("{:?}", list_chains_for_table(&tables[0]));
- Ok(())
}
// Look up the interface index for a given interface name.
diff --git a/src/batch.rs b/src/batch.rs
index a1c7e0f..d885813 100644
--- a/src/batch.rs
+++ b/src/batch.rs
@@ -4,7 +4,7 @@ use thiserror::Error;
use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter};
use crate::sys::NFNL_SUBSYS_NFTABLES;
-use crate::{MsgType, ProtoFamily};
+use crate::{MsgType, ProtocolFamily};
use crate::query::Error;
use nix::sys::socket::{
@@ -39,7 +39,7 @@ impl Batch {
let seq = 0;
writer.write_header(
libc::NFNL_MSG_BATCH_BEGIN as u16,
- ProtoFamily::Unspec,
+ ProtocolFamily::Unspec,
0,
seq,
Some(libc::NFNL_SUBSYS_NFTABLES as u16),
@@ -79,7 +79,7 @@ impl Batch {
pub fn finalize(mut self) -> Vec<u8> {
self.writer.write_header(
libc::NFNL_MSG_BATCH_END as u16,
- ProtoFamily::Unspec,
+ ProtocolFamily::Unspec,
0,
self.seq,
Some(NFNL_SUBSYS_NFTABLES as u16),
diff --git a/src/chain.rs b/src/chain.rs
index 000a196..60f5f10 100644
--- a/src/chain.rs
+++ b/src/chain.rs
@@ -1,4 +1,5 @@
-use crate::nlmsg::NfNetlinkSerializable;
+use libc::{NF_ACCEPT, NF_DROP};
+
use crate::nlmsg::{
NfNetlinkAttribute, NfNetlinkAttributes, NfNetlinkDeserializable, NfNetlinkObject,
NfNetlinkWriter,
@@ -7,10 +8,11 @@ use crate::parser::{
parse_object, DecodeError, InnerFormat, NestedAttribute, NfNetlinkAttributeReader,
};
use crate::sys::{self, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN, NLM_F_ACK};
-use crate::{impl_attr_getters_and_setters, MsgType, ProtoFamily, Table};
+use crate::{impl_attr_getters_and_setters, MsgType, ProtocolFamily, Table};
+use std::convert::TryFrom;
use std::fmt::Debug;
-pub type Priority = i32;
+pub type ChainPriority = i32;
/// The netfilter event hooks a chain can register for.
#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
@@ -34,7 +36,7 @@ pub struct Hook {
}
impl Hook {
- pub fn new(class: HookClass, priority: Priority) -> Self {
+ pub fn new(class: HookClass, priority: ChainPriority) -> Self {
Hook {
inner: NestedAttribute::new(),
}
@@ -73,6 +75,10 @@ impl_attr_getters_and_setters!(
);
impl NfNetlinkAttribute for Hook {
+ fn is_nested(&self) -> bool {
+ true
+ }
+
fn get_size(&self) -> usize {
self.inner.get_size()
}
@@ -93,12 +99,36 @@ impl NfNetlinkDeserializable for Hook {
/// A chain policy. Decides what to do with a packet that was processed by the chain but did not
/// match any rules.
#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
-#[repr(u32)]
-pub enum Policy {
+#[repr(i32)]
+pub enum ChainPolicy {
/// Accept the packet.
- Accept = libc::NF_ACCEPT as u32,
+ Accept = NF_ACCEPT,
/// Drop the packet.
- Drop = libc::NF_DROP as u32,
+ Drop = NF_DROP,
+}
+
+impl NfNetlinkAttribute for ChainPolicy {
+ fn get_size(&self) -> usize {
+ (*self as i32).get_size()
+ }
+
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ (*self as i32).write_payload(addr);
+ }
+}
+
+impl NfNetlinkDeserializable for ChainPolicy {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ let (v, remaining_data) = i32::deserialize(buf)?;
+ Ok((
+ match v {
+ NF_ACCEPT => ChainPolicy::Accept,
+ NF_DROP => ChainPolicy::Accept,
+ _ => return Err(DecodeError::UnknownChainPolicy),
+ },
+ remaining_data,
+ ))
+ }
}
/// Base chain type.
@@ -160,6 +190,7 @@ impl NfNetlinkDeserializable for ChainType {
/// [`set_hook`]: #method.set_hook
#[derive(PartialEq, Eq)]
pub struct Chain {
+ family: ProtocolFamily,
inner: NfNetlinkAttributes,
}
@@ -169,6 +200,7 @@ impl Chain {
/// [`Table`]: struct.Table.html
pub fn new(table: &Table) -> Chain {
let mut chain = Chain {
+ family: table.get_family(),
inner: NfNetlinkAttributes::new(),
};
@@ -207,7 +239,9 @@ impl PartialEq for Chain {
impl Debug for Chain {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- self.inner_format_struct(f.debug_struct("Chain"))?.finish()
+ let mut res = f.debug_struct("Chain");
+ res.field("family", &self.family);
+ self.inner_format_struct(res)?.finish()
}
}
@@ -217,13 +251,7 @@ impl NfNetlinkObject for Chain {
MsgType::Add => NFT_MSG_NEWCHAIN,
MsgType::Del => NFT_MSG_DELCHAIN,
} as u16;
- writer.write_header(
- raw_msg_type,
- ProtoFamily::Unspec,
- NLM_F_ACK as u16,
- seq,
- None,
- );
+ writer.write_header(raw_msg_type, self.family, NLM_F_ACK as u16, seq, None);
self.inner.serialize(writer);
writer.finalize_writing_object();
}
@@ -231,10 +259,16 @@ impl NfNetlinkObject for Chain {
impl NfNetlinkDeserializable for Chain {
fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
- let (inner, _nfgenmsg, remaining_data) =
+ let (inner, nfgenmsg, remaining_data) =
parse_object::<Self>(buf, NFT_MSG_NEWCHAIN, NFT_MSG_DELCHAIN)?;
- Ok((Self { inner }, remaining_data))
+ Ok((
+ Self {
+ inner,
+ family: ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?,
+ },
+ remaining_data,
+ ))
}
}
@@ -250,11 +284,11 @@ impl_attr_getters_and_setters!(
// By calling `set_hook` with a hook the chain that is created will be registered with that
// hook and is thus a "base chain". A "base chain" is an entry point for packets from the
// networking stack.
- (set_hook, get_hook, with_hook, sys::NFTA_CHAIN_HOOK, ChainHook, Hook),
- (get_policy, set_policy, with_policy, sys::NFTA_CHAIN_POLICY, U32, u32),
+ (get_hook, set_hook, with_hook, sys::NFTA_CHAIN_HOOK, ChainHook, Hook),
+ (get_policy, set_policy, with_policy, sys::NFTA_CHAIN_POLICY, ChainPolicy, ChainPolicy),
(get_table, set_table, with_table, sys::NFTA_CHAIN_TABLE, String, String),
// This only applies if the chain has been registered with a hook by calling `set_hook`.
- (get_type, set_type, with_type, sys::NFTA_CHAIN_TYPE, String, String),
+ (get_type, set_type, with_type, sys::NFTA_CHAIN_TYPE, ChainType, ChainType),
(
get_userdata,
set_userdata,
diff --git a/src/lib.rs b/src/lib.rs
index 61ee5fc..db23b28 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -70,6 +70,7 @@
//! [`nftables`]: https://netfilter.org/projects/nftables/
//! [`build.rs`]: https://gitlab.com/rustwall/rustables/-/blob/master/build.rs
+use parser::DecodeError;
use thiserror::Error;
#[macro_use]
@@ -102,7 +103,7 @@ pub use table::Table;
pub mod chain;
pub use chain::list_chains_for_table;
-pub use chain::{Chain, ChainType, Hook, Policy, Priority};
+pub use chain::{Chain, ChainPolicy, ChainPriority, ChainType, Hook};
//mod chain_methods;
//pub use chain_methods::ChainMethods;
@@ -141,36 +142,32 @@ pub enum MsgType {
/// Denotes a protocol. Used to specify which protocol a table or set belongs to.
#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
-#[repr(u32)]
-pub enum ProtoFamily {
- Unspec = libc::NFPROTO_UNSPEC as u32,
+#[repr(i32)]
+pub enum ProtocolFamily {
+ Unspec = libc::NFPROTO_UNSPEC,
/// Inet - Means both IPv4 and IPv6
- Inet = libc::NFPROTO_INET as u32,
- Ipv4 = libc::NFPROTO_IPV4 as u32,
- Arp = libc::NFPROTO_ARP as u32,
- NetDev = libc::NFPROTO_NETDEV as u32,
- Bridge = libc::NFPROTO_BRIDGE as u32,
- Ipv6 = libc::NFPROTO_IPV6 as u32,
- DecNet = libc::NFPROTO_DECNET as u32,
+ Inet = libc::NFPROTO_INET,
+ Ipv4 = libc::NFPROTO_IPV4,
+ Arp = libc::NFPROTO_ARP,
+ NetDev = libc::NFPROTO_NETDEV,
+ Bridge = libc::NFPROTO_BRIDGE,
+ Ipv6 = libc::NFPROTO_IPV6,
+ DecNet = libc::NFPROTO_DECNET,
}
-#[derive(Error, Debug)]
-#[error("Couldn't find a matching protocol")]
-pub struct InvalidProtocolFamily;
-
-impl TryFrom<i32> for ProtoFamily {
- type Error = InvalidProtocolFamily;
+impl TryFrom<i32> for ProtocolFamily {
+ type Error = DecodeError;
fn try_from(value: i32) -> Result<Self, Self::Error> {
match value {
- libc::NFPROTO_UNSPEC => Ok(ProtoFamily::Unspec),
- libc::NFPROTO_INET => Ok(ProtoFamily::Inet),
- libc::NFPROTO_IPV4 => Ok(ProtoFamily::Ipv4),
- libc::NFPROTO_ARP => Ok(ProtoFamily::Arp),
- libc::NFPROTO_NETDEV => Ok(ProtoFamily::NetDev),
- libc::NFPROTO_BRIDGE => Ok(ProtoFamily::Bridge),
- libc::NFPROTO_IPV6 => Ok(ProtoFamily::Ipv6),
- libc::NFPROTO_DECNET => Ok(ProtoFamily::DecNet),
- _ => Err(InvalidProtocolFamily),
+ libc::NFPROTO_UNSPEC => Ok(ProtocolFamily::Unspec),
+ libc::NFPROTO_INET => Ok(ProtocolFamily::Inet),
+ libc::NFPROTO_IPV4 => Ok(ProtocolFamily::Ipv4),
+ libc::NFPROTO_ARP => Ok(ProtocolFamily::Arp),
+ libc::NFPROTO_NETDEV => Ok(ProtocolFamily::NetDev),
+ libc::NFPROTO_BRIDGE => Ok(ProtocolFamily::Bridge),
+ libc::NFPROTO_IPV6 => Ok(ProtocolFamily::Ipv6),
+ libc::NFPROTO_DECNET => Ok(ProtocolFamily::DecNet),
+ _ => Err(DecodeError::InvalidProtocolFamily(value)),
}
}
}
diff --git a/src/nlmsg.rs b/src/nlmsg.rs
index a1bb200..b7f90e9 100644
--- a/src/nlmsg.rs
+++ b/src/nlmsg.rs
@@ -1,19 +1,14 @@
-use std::{
- collections::HashMap,
- fmt::Debug,
- marker::PhantomData,
- mem::{size_of, transmute},
-};
+use std::{collections::BTreeMap, fmt::Debug, mem::size_of};
use crate::{
parser::{
pad_netlink_object, pad_netlink_object_with_variable_size, AttributeType, DecodeError,
},
sys::{
- nfgenmsg, nlattr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END,
+ nfgenmsg, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END,
NFNL_SUBSYS_NFTABLES,
},
- MsgType, ProtoFamily,
+ MsgType, ProtocolFamily,
};
pub struct NfNetlinkWriter<'a> {
@@ -49,7 +44,7 @@ impl<'a> NfNetlinkWriter<'a> {
pub fn write_header(
&mut self,
msg_type: u16,
- family: ProtoFamily,
+ family: ProtocolFamily,
flags: u16,
seq: u32,
ressource_id: Option<u16>,
@@ -103,13 +98,14 @@ pub trait NfNetlinkObject: Sized + AttributeDecoder + NfNetlinkDeserializable {
fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32);
}
-pub trait NfNetlinkSerializable {
- fn serialize<'a>(&self, writer: &mut NfNetlinkWriter<'a>);
-}
-
pub type NetlinkType = u16;
pub trait NfNetlinkAttribute: Debug + Sized {
+ // is it a nested argument that must be marked with a NLA_F_NESTED flag?
+ fn is_nested(&self) -> bool {
+ false
+ }
+
fn get_size(&self) -> usize {
size_of::<Self>()
}
@@ -120,13 +116,13 @@ pub trait NfNetlinkAttribute: Debug + Sized {
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NfNetlinkAttributes {
- pub attributes: HashMap<NetlinkType, AttributeType>,
+ pub attributes: BTreeMap<NetlinkType, AttributeType>,
}
impl NfNetlinkAttributes {
pub fn new() -> Self {
NfNetlinkAttributes {
- attributes: HashMap::new(),
+ attributes: BTreeMap::new(),
}
}
@@ -137,4 +133,11 @@ impl NfNetlinkAttributes {
pub fn get_attr(&self, ty: NetlinkType) -> Option<&AttributeType> {
self.attributes.get(&ty)
}
+
+ pub fn serialize<'a>(&self, writer: &mut NfNetlinkWriter<'a>) {
+ let buf = writer.add_data_zeroed(self.get_size());
+ unsafe {
+ self.write_payload(buf.as_mut_ptr());
+ }
+ }
}
diff --git a/src/parser.rs b/src/parser.rs
index 2d05f4f..25033d2 100644
--- a/src/parser.rs
+++ b/src/parser.rs
@@ -1,6 +1,6 @@
use std::{
any::TypeId,
- collections::HashMap,
+ convert::TryFrom,
fmt::{Debug, DebugStruct},
mem::{size_of, transmute},
string::FromUtf8Error,
@@ -11,14 +11,14 @@ use thiserror::Error;
use crate::{
nlmsg::{
AttributeDecoder, NetlinkType, NfNetlinkAttribute, NfNetlinkAttributes,
- NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkSerializable, NfNetlinkWriter,
+ NfNetlinkDeserializable, NfNetlinkWriter,
},
sys::{
nfgenmsg, nlattr, nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN,
- NFNL_MSG_BATCH_END, NFNL_SUBSYS_NFTABLES, NLA_TYPE_MASK, NLMSG_ALIGNTO, NLMSG_DONE,
- NLMSG_ERROR, NLMSG_MIN_TYPE, NLMSG_NOOP, NLM_F_DUMP_INTR,
+ NFNL_MSG_BATCH_END, NFNL_SUBSYS_NFTABLES, NLA_F_NESTED, NLA_TYPE_MASK, NLMSG_ALIGNTO,
+ NLMSG_DONE, NLMSG_ERROR, NLMSG_MIN_TYPE, NLMSG_NOOP, NLM_F_DUMP_INTR,
},
- InvalidProtocolFamily, ProtoFamily,
+ ProtocolFamily,
};
#[derive(Error, Debug)]
@@ -56,6 +56,9 @@ pub enum DecodeError {
#[error("Invalid type for a chain")]
UnknownChainType,
+ #[error("Invalid policy for a chain")]
+ UnknownChainPolicy,
+
#[error("Unsupported attribute type")]
UnsupportedAttributeType(u16),
@@ -66,7 +69,7 @@ pub enum DecodeError {
StringDecodeFailure(#[from] FromUtf8Error),
#[error("Invalid value for a protocol family")]
- InvalidProtocolFamily(#[from] InvalidProtocolFamily),
+ InvalidProtocolFamily(i32),
#[error("A custom error occured")]
Custom(Box<dyn std::error::Error + 'static>),
@@ -189,29 +192,21 @@ pub fn parse_nlmsg<'a>(buf: &'a [u8]) -> Result<(nlmsghdr, NlMsg<'a>), DecodeErr
/// Write the attribute, preceded by a `libc::nlattr`
// rewrite of `mnl_attr_put`
-fn write_attribute<'a>(ty: NetlinkType, obj: &AttributeType, writer: &mut NfNetlinkWriter<'a>) {
- // copy the header
+unsafe fn write_attribute<'a>(ty: NetlinkType, obj: &AttributeType, mut buf: *mut u8) {
let header_len = pad_netlink_object::<libc::nlattr>();
- let header = libc::nlattr {
+ // copy the header
+ *(buf as *mut nlattr) = nlattr {
// nla_len contains the header size + the unpadded attribute length
nla_len: (header_len + obj.get_size() as usize) as u16,
- nla_type: ty,
+ nla_type: if obj.is_nested() {
+ ty | NLA_F_NESTED as u16
+ } else {
+ ty
+ },
};
-
- let buf = writer.add_data_zeroed(header_len);
- unsafe {
- std::ptr::copy_nonoverlapping(
- &header as *const libc::nlattr as *const u8,
- buf.as_mut_ptr(),
- header_len as usize,
- );
- }
-
- let buf = writer.add_data_zeroed(obj.get_size());
+ buf = buf.offset(pad_netlink_object::<nlattr>() as isize);
// copy the attribute data itself
- unsafe {
- obj.write_payload(buf.as_mut_ptr());
- }
+ obj.write_payload(buf);
}
impl NfNetlinkAttribute for u8 {
@@ -228,7 +223,7 @@ impl NfNetlinkDeserializable for u8 {
impl NfNetlinkAttribute for u16 {
unsafe fn write_payload(&self, addr: *mut u8) {
- *(addr as *mut Self) = *self;
+ *(addr as *mut Self) = self.to_be();
}
}
@@ -240,7 +235,7 @@ impl NfNetlinkDeserializable for u16 {
impl NfNetlinkAttribute for i32 {
unsafe fn write_payload(&self, addr: *mut u8) {
- *(addr as *mut Self) = *self;
+ *(addr as *mut Self) = self.to_be();
}
}
@@ -255,7 +250,7 @@ impl NfNetlinkDeserializable for i32 {
impl NfNetlinkAttribute for u32 {
unsafe fn write_payload(&self, addr: *mut u8) {
- *(addr as *mut Self) = *self;
+ *(addr as *mut Self) = self.to_be();
}
}
@@ -270,7 +265,7 @@ impl NfNetlinkDeserializable for u32 {
impl NfNetlinkAttribute for u64 {
unsafe fn write_payload(&self, addr: *mut u8) {
- *(addr as *mut Self) = *self;
+ *(addr as *mut Self) = self.to_be();
}
}
@@ -322,11 +317,24 @@ impl NfNetlinkDeserializable for Vec<u8> {
}
}
+impl NfNetlinkAttribute for ProtocolFamily {
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ (*self as i32).write_payload(addr);
+ }
+}
+
+impl NfNetlinkDeserializable for ProtocolFamily {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ let (v, remaining_data) = i32::deserialize(buf)?;
+ Ok((Self::try_from(v)?, remaining_data))
+ }
+}
+
pub type NestedAttribute = NfNetlinkAttributes;
// parts of the NfNetlinkAttribute trait we need for handling nested objects
-impl NestedAttribute {
- pub fn get_size(&self) -> usize {
+impl NfNetlinkAttribute for NestedAttribute {
+ fn get_size(&self) -> usize {
let mut size = 0;
for (_type, attr) in self.attributes.iter() {
@@ -338,15 +346,12 @@ impl NestedAttribute {
size
}
- pub unsafe fn write_payload(&self, mut addr: *mut u8) {
+ unsafe fn write_payload(&self, mut addr: *mut u8) {
for (ty, attr) in self.attributes.iter() {
- *(addr as *mut nlattr) = nlattr {
- nla_len: attr.get_size() as u16,
- nla_type: *ty,
- };
- addr = addr.offset(pad_netlink_object::<nlattr>() as isize);
- attr.write_payload(addr);
- addr = addr.offset(pad_netlink_object_with_variable_size(attr.get_size()) as isize);
+ write_attribute(*ty, attr, addr);
+ let size = pad_netlink_object::<nlattr>()
+ + pad_netlink_object_with_variable_size(attr.get_size());
+ addr = addr.offset(size as isize);
}
}
}
@@ -412,17 +417,6 @@ impl<'a> NfNetlinkAttributeReader<'a> {
}
}
-impl NfNetlinkSerializable for NfNetlinkAttributes {
- fn serialize<'a>(&self, writer: &mut NfNetlinkWriter<'a>) {
- // TODO: improve performance by not sorting this
- let mut keys: Vec<&NetlinkType> = self.attributes.keys().collect();
- keys.sort();
- for k in keys {
- write_attribute(*k, self.attributes.get(k).unwrap(), writer);
- }
- }
-}
-
macro_rules! impl_attribute_holder {
($enum_name:ident, $([$internal_name:ident, $type:ty]),+) => {
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -433,6 +427,14 @@ macro_rules! impl_attribute_holder {
}
impl NfNetlinkAttribute for $enum_name {
+ fn is_nested(&self) -> bool {
+ match self {
+ $(
+ $enum_name::$internal_name(val) => val.is_nested()
+ ),+
+ }
+ }
+
fn get_size(&self) -> usize {
match self {
$(
@@ -480,12 +482,15 @@ impl_attribute_holder!(
[U32, u32],
[U64, u64],
[VecU8, Vec<u8>],
- [ChainHook, crate::chain::Hook]
+ [ChainHook, crate::chain::Hook],
+ [ChainPolicy, crate::chain::ChainPolicy],
+ [ChainType, crate::chain::ChainType],
+ [ProtocolFamily, crate::ProtocolFamily]
);
#[macro_export]
macro_rules! impl_attr_getters_and_setters {
- ($struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => {
+ ($struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty $(, $nested:literal)?)),+]) => {
impl $struct {
$(
#[allow(dead_code)]
diff --git a/src/query.rs b/src/query.rs
index f84586a..da886c0 100644
--- a/src/query.rs
+++ b/src/query.rs
@@ -4,7 +4,7 @@ use crate::{
nlmsg::{NfNetlinkObject, NfNetlinkWriter},
parser::{nft_nlmsg_maxsize, pad_netlink_object_with_variable_size},
sys::{nlmsgerr, NLM_F_DUMP, NLM_F_MULTI},
- ProtoFamily,
+ ProtocolFamily,
};
use nix::{
@@ -156,7 +156,13 @@ where
{
let mut buffer = Vec::new();
let mut writer = NfNetlinkWriter::new(&mut buffer);
- writer.write_header(msg_type, ProtoFamily::Unspec, NLM_F_DUMP as u16, seq, None);
+ writer.write_header(
+ msg_type,
+ ProtocolFamily::Unspec,
+ NLM_F_DUMP as u16,
+ seq,
+ None,
+ );
writer.finalize_writing_object();
if let Some(filter) = filter {
filter.add_or_remove(&mut writer, crate::MsgType::Add, 0);
diff --git a/src/table.rs b/src/table.rs
index a21f3f2..768eedd 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -2,15 +2,14 @@ use std::convert::TryFrom;
use std::fmt::Debug;
use crate::nlmsg::{
- AttributeDecoder, NfNetlinkAttributes, NfNetlinkDeserializable, NfNetlinkObject,
- NfNetlinkSerializable, NfNetlinkWriter,
+ NfNetlinkAttributes, NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkWriter,
};
use crate::parser::{parse_object, DecodeError, InnerFormat};
use crate::sys::{
- self, NFTA_OBJ_TABLE, NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE,
- NFT_MSG_NEWTABLE, NLM_F_ACK,
+ self, NFNL_SUBSYS_NFTABLES, NFTA_OBJ_TABLE, NFTA_TABLE_FLAGS, NFTA_TABLE_NAME,
+ NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, NFT_MSG_NEWTABLE, NLM_F_ACK,
};
-use crate::{impl_attr_getters_and_setters, MsgType, ProtoFamily};
+use crate::{impl_attr_getters_and_setters, MsgType, ProtocolFamily};
/// Abstraction of a `nftnl_table`, the top level container in netfilter. A table has a protocol
/// family and contains [`Chain`]s that in turn hold the rules.
@@ -19,17 +18,21 @@ use crate::{impl_attr_getters_and_setters, MsgType, ProtoFamily};
#[derive(PartialEq, Eq)]
pub struct Table {
inner: NfNetlinkAttributes,
- family: ProtoFamily,
+ family: ProtocolFamily,
}
impl Table {
- pub fn new(family: ProtoFamily) -> Table {
+ pub fn new(family: ProtocolFamily) -> Table {
Table {
inner: NfNetlinkAttributes::new(),
family,
}
}
+ pub fn get_family(&self) -> ProtocolFamily {
+ self.family
+ }
+
/*
/// Returns a textual description of the table.
pub fn get_str(&self) -> CString {
@@ -83,7 +86,7 @@ impl NfNetlinkDeserializable for Table {
Ok((
Self {
inner,
- family: ProtoFamily::try_from(nfgenmsg.nfgen_family as i32)?,
+ family: ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?,
},
remaining_data,
))
diff --git a/tests/lib.rs b/tests/lib.rs
index cf5ddb4..0268b1a 100644
--- a/tests/lib.rs
+++ b/tests/lib.rs
@@ -5,9 +5,9 @@ use libc::AF_UNIX;
use rustables::nlmsg::{NfNetlinkObject, NfNetlinkWriter};
//use rustables::set::SetKey;
use rustables::{sys::*, Chain};
-use rustables::{MsgType, ProtoFamily, Table};
+use rustables::{MsgType, ProtocolFamily, Table};
-//use rustables::{nft_nlmsg_maxsize, Chain, MsgType, NlMsg, ProtoFamily, Rule, Set, Table};
+//use rustables::{nft_nlmsg_maxsize, Chain, MsgType, NlMsg, Rule, Set, Table};
pub const TABLE_NAME: &'static str = "mocktable";
pub const CHAIN_NAME: &'static str = "mockchain";
@@ -26,7 +26,7 @@ type NetLinkType = u16;
#[error("empty data")]
pub struct EmptyDataError;
-#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
+#[derive(Debug, Clone, Eq, PartialOrd, Ord)]
pub enum NetlinkExpr {
Nested(NetLinkType, Vec<NetlinkExpr>),
Final(NetLinkType, Vec<u8>),
@@ -64,7 +64,7 @@ impl NetlinkExpr {
// set the "NESTED" flag
res.extend(&(len as u16).to_le_bytes());
- res.extend(&(ty | 0x8000).to_le_bytes());
+ res.extend(&(ty | NLA_F_NESTED as u16).to_le_bytes());
res.extend(sub);
res
@@ -98,8 +98,19 @@ impl NetlinkExpr {
}
}
+impl PartialEq for NetlinkExpr {
+ fn eq(&self, other: &Self) -> bool {
+ match (self.clone().sort(), other.clone().sort()) {
+ (NetlinkExpr::Nested(k1, v1), NetlinkExpr::Nested(k2, v2)) => k1 == k2 && v1 == v2,
+ (NetlinkExpr::Final(k1, v1), NetlinkExpr::Final(k2, v2)) => k1 == k2 && v1 == v2,
+ (NetlinkExpr::List(v1), NetlinkExpr::List(v2)) => v1 == v2,
+ _ => false,
+ }
+ }
+}
+
pub fn get_test_table() -> Table {
- Table::new(ProtoFamily::Inet)
+ Table::new(ProtocolFamily::Inet)
.with_name(TABLE_NAME)
.with_flags(0u32)
}