diff options
-rw-r--r-- | examples/add-rules.rs | 62 | ||||
-rw-r--r-- | src/batch.rs | 6 | ||||
-rw-r--r-- | src/chain.rs | 76 | ||||
-rw-r--r-- | src/lib.rs | 49 | ||||
-rw-r--r-- | src/nlmsg.rs | 33 | ||||
-rw-r--r-- | src/parser.rs | 107 | ||||
-rw-r--r-- | src/query.rs | 10 | ||||
-rw-r--r-- | src/table.rs | 19 | ||||
-rw-r--r-- | tests/lib.rs | 21 |
9 files changed, 215 insertions, 168 deletions
diff --git a/examples/add-rules.rs b/examples/add-rules.rs index 3fd1f49..75fc63e 100644 --- a/examples/add-rules.rs +++ b/examples/add-rules.rs @@ -37,8 +37,11 @@ //! ``` use ipnetwork::{IpNetwork, Ipv4Network}; -use rustables::{list_chains_for_table, list_tables, Batch, ProtoFamily, Table}; -//use rustables::{nft_expr, query::send_batch, sys::libc, Batch, Chain, ProtoFamily, Rule, Table}; +use rustables::{ + chain::HookClass, list_chains_for_table, list_tables, Batch, Chain, ChainPolicy, Hook, MsgType, + ProtocolFamily, Table, +}; +//use rustables::{nft_expr, query::send_batch, sys::libc, Batch, Chain, Rule, Table}; use std::{ffi::CString, io, net::Ipv4Addr, rc::Rc}; const TABLE_NAME: &str = "example-table"; @@ -46,44 +49,35 @@ const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets"; const IN_CHAIN_NAME: &str = "chain-for-incoming-packets"; fn main() -> Result<(), Error> { - /* // Create a batch. This is used to store all the netlink messages we will later send. // Creating a new batch also automatically writes the initial batch begin message needed // to tell netlink this is a single transaction that might arrive over multiple netlink packets. let mut batch = Batch::new(); // Create a netfilter table operating on both IPv4 and IPv6 (ProtoFamily::Inet) - let table = Table::new(TABLE_NAME, ProtoFamily::Inet); + let table = Table::new(ProtocolFamily::Inet).with_name(TABLE_NAME); // Add the table to the batch with the `MsgType::Add` type, thus instructing netfilter to add - // this table under its `ProtoFamily::Inet` ruleset. - batch.add(&table, rustables::MsgType::Add); + // this table under its `ProtocolFamily::Inet` ruleset. + batch.add(&table, MsgType::Add); - let table = Table::new("lool", ProtoFamily::Inet); + // Create input and output chains under the table we created above. + // Hook the chains to the input and output event hooks, with highest priority (priority zero). + let mut out_chain = Chain::new(&table).with_name(OUT_CHAIN_NAME); + let mut in_chain = Chain::new(&table).with_name(IN_CHAIN_NAME); - batch.add(&table, rustables::MsgType::Add); + out_chain.set_hook(Hook::new(HookClass::Out, 0)); + in_chain.set_hook(Hook::new(HookClass::In, 0)); + + // Set the default policies on the chains. If no rule matches a packet processed by the + // `out_chain` or the `in_chain` it will accept the packet. + out_chain.set_policy(ChainPolicy::Accept); + in_chain.set_policy(ChainPolicy::Accept); + + // Add the two chains to the batch with the `MsgType` to tell netfilter to create the chains + // under the table. + batch.add(&out_chain, MsgType::Add); + batch.add(&in_chain, MsgType::Add); - // // Create input and output chains under the table we created above. - // // Hook the chains to the input and output event hooks, with highest priority (priority zero). - // // See the `Chain::set_hook` documentation for details. - // let mut out_chain = Chain::new(OUT_CHAIN_NAME, Rc::clone(&table)); - // let mut in_chain = Chain::new(IN_CHAIN_NAME, Rc::clone(&table)); - // - // out_chain.set_hook(rustables::Hook::Out, 0); - // in_chain.set_hook(rustables::Hook::In, 0); - // - // // Set the default policies on the chains. If no rule matches a packet processed by the - // // `out_chain` or the `in_chain` it will accept the packet. - // out_chain.set_policy(rustables::Policy::Accept); - // in_chain.set_policy(rustables::Policy::Accept); - // - // let out_chain = Rc::new(out_chain); - // let in_chain = Rc::new(in_chain); - // - // // Add the two chains to the batch with the `MsgType` to tell netfilter to create the chains - // // under the table. - // batch.add(&Rc::clone(&out_chain), rustables::MsgType::Add); - // batch.add(&Rc::clone(&in_chain), rustables::MsgType::Add); - // // // === ADD RULE ALLOWING ALL TRAFFIC TO THE LOOPBACK DEVICE === // // // Create a new rule object under the input chain. @@ -175,14 +169,8 @@ fn main() -> Result<(), Error> { // Finalize the batch and send it. This means the batch end message is written into the batch, telling // netfilter the we reached the end of the transaction message. It's also converted to a // Vec<u8>, containing the raw netlink data so it can be sent over a netlink socket to netfilter. + // Finally, the batch is sent over to the kernel. Ok(batch.send()?) - */ - - env_logger::init(); - let tables = list_tables()?; - println!("{:?}", tables); - println!("{:?}", list_chains_for_table(&tables[0])); - Ok(()) } // Look up the interface index for a given interface name. diff --git a/src/batch.rs b/src/batch.rs index a1c7e0f..d885813 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -4,7 +4,7 @@ use thiserror::Error; use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter}; use crate::sys::NFNL_SUBSYS_NFTABLES; -use crate::{MsgType, ProtoFamily}; +use crate::{MsgType, ProtocolFamily}; use crate::query::Error; use nix::sys::socket::{ @@ -39,7 +39,7 @@ impl Batch { let seq = 0; writer.write_header( libc::NFNL_MSG_BATCH_BEGIN as u16, - ProtoFamily::Unspec, + ProtocolFamily::Unspec, 0, seq, Some(libc::NFNL_SUBSYS_NFTABLES as u16), @@ -79,7 +79,7 @@ impl Batch { pub fn finalize(mut self) -> Vec<u8> { self.writer.write_header( libc::NFNL_MSG_BATCH_END as u16, - ProtoFamily::Unspec, + ProtocolFamily::Unspec, 0, self.seq, Some(NFNL_SUBSYS_NFTABLES as u16), diff --git a/src/chain.rs b/src/chain.rs index 000a196..60f5f10 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -1,4 +1,5 @@ -use crate::nlmsg::NfNetlinkSerializable; +use libc::{NF_ACCEPT, NF_DROP}; + use crate::nlmsg::{ NfNetlinkAttribute, NfNetlinkAttributes, NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkWriter, @@ -7,10 +8,11 @@ use crate::parser::{ parse_object, DecodeError, InnerFormat, NestedAttribute, NfNetlinkAttributeReader, }; use crate::sys::{self, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN, NLM_F_ACK}; -use crate::{impl_attr_getters_and_setters, MsgType, ProtoFamily, Table}; +use crate::{impl_attr_getters_and_setters, MsgType, ProtocolFamily, Table}; +use std::convert::TryFrom; use std::fmt::Debug; -pub type Priority = i32; +pub type ChainPriority = i32; /// The netfilter event hooks a chain can register for. #[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] @@ -34,7 +36,7 @@ pub struct Hook { } impl Hook { - pub fn new(class: HookClass, priority: Priority) -> Self { + pub fn new(class: HookClass, priority: ChainPriority) -> Self { Hook { inner: NestedAttribute::new(), } @@ -73,6 +75,10 @@ impl_attr_getters_and_setters!( ); impl NfNetlinkAttribute for Hook { + fn is_nested(&self) -> bool { + true + } + fn get_size(&self) -> usize { self.inner.get_size() } @@ -93,12 +99,36 @@ impl NfNetlinkDeserializable for Hook { /// A chain policy. Decides what to do with a packet that was processed by the chain but did not /// match any rules. #[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] -#[repr(u32)] -pub enum Policy { +#[repr(i32)] +pub enum ChainPolicy { /// Accept the packet. - Accept = libc::NF_ACCEPT as u32, + Accept = NF_ACCEPT, /// Drop the packet. - Drop = libc::NF_DROP as u32, + Drop = NF_DROP, +} + +impl NfNetlinkAttribute for ChainPolicy { + fn get_size(&self) -> usize { + (*self as i32).get_size() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + (*self as i32).write_payload(addr); + } +} + +impl NfNetlinkDeserializable for ChainPolicy { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let (v, remaining_data) = i32::deserialize(buf)?; + Ok(( + match v { + NF_ACCEPT => ChainPolicy::Accept, + NF_DROP => ChainPolicy::Accept, + _ => return Err(DecodeError::UnknownChainPolicy), + }, + remaining_data, + )) + } } /// Base chain type. @@ -160,6 +190,7 @@ impl NfNetlinkDeserializable for ChainType { /// [`set_hook`]: #method.set_hook #[derive(PartialEq, Eq)] pub struct Chain { + family: ProtocolFamily, inner: NfNetlinkAttributes, } @@ -169,6 +200,7 @@ impl Chain { /// [`Table`]: struct.Table.html pub fn new(table: &Table) -> Chain { let mut chain = Chain { + family: table.get_family(), inner: NfNetlinkAttributes::new(), }; @@ -207,7 +239,9 @@ impl PartialEq for Chain { impl Debug for Chain { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.inner_format_struct(f.debug_struct("Chain"))?.finish() + let mut res = f.debug_struct("Chain"); + res.field("family", &self.family); + self.inner_format_struct(res)?.finish() } } @@ -217,13 +251,7 @@ impl NfNetlinkObject for Chain { MsgType::Add => NFT_MSG_NEWCHAIN, MsgType::Del => NFT_MSG_DELCHAIN, } as u16; - writer.write_header( - raw_msg_type, - ProtoFamily::Unspec, - NLM_F_ACK as u16, - seq, - None, - ); + writer.write_header(raw_msg_type, self.family, NLM_F_ACK as u16, seq, None); self.inner.serialize(writer); writer.finalize_writing_object(); } @@ -231,10 +259,16 @@ impl NfNetlinkObject for Chain { impl NfNetlinkDeserializable for Chain { fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - let (inner, _nfgenmsg, remaining_data) = + let (inner, nfgenmsg, remaining_data) = parse_object::<Self>(buf, NFT_MSG_NEWCHAIN, NFT_MSG_DELCHAIN)?; - Ok((Self { inner }, remaining_data)) + Ok(( + Self { + inner, + family: ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?, + }, + remaining_data, + )) } } @@ -250,11 +284,11 @@ impl_attr_getters_and_setters!( // By calling `set_hook` with a hook the chain that is created will be registered with that // hook and is thus a "base chain". A "base chain" is an entry point for packets from the // networking stack. - (set_hook, get_hook, with_hook, sys::NFTA_CHAIN_HOOK, ChainHook, Hook), - (get_policy, set_policy, with_policy, sys::NFTA_CHAIN_POLICY, U32, u32), + (get_hook, set_hook, with_hook, sys::NFTA_CHAIN_HOOK, ChainHook, Hook), + (get_policy, set_policy, with_policy, sys::NFTA_CHAIN_POLICY, ChainPolicy, ChainPolicy), (get_table, set_table, with_table, sys::NFTA_CHAIN_TABLE, String, String), // This only applies if the chain has been registered with a hook by calling `set_hook`. - (get_type, set_type, with_type, sys::NFTA_CHAIN_TYPE, String, String), + (get_type, set_type, with_type, sys::NFTA_CHAIN_TYPE, ChainType, ChainType), ( get_userdata, set_userdata, @@ -70,6 +70,7 @@ //! [`nftables`]: https://netfilter.org/projects/nftables/ //! [`build.rs`]: https://gitlab.com/rustwall/rustables/-/blob/master/build.rs +use parser::DecodeError; use thiserror::Error; #[macro_use] @@ -102,7 +103,7 @@ pub use table::Table; pub mod chain; pub use chain::list_chains_for_table; -pub use chain::{Chain, ChainType, Hook, Policy, Priority}; +pub use chain::{Chain, ChainPolicy, ChainPriority, ChainType, Hook}; //mod chain_methods; //pub use chain_methods::ChainMethods; @@ -141,36 +142,32 @@ pub enum MsgType { /// Denotes a protocol. Used to specify which protocol a table or set belongs to. #[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] -#[repr(u32)] -pub enum ProtoFamily { - Unspec = libc::NFPROTO_UNSPEC as u32, +#[repr(i32)] +pub enum ProtocolFamily { + Unspec = libc::NFPROTO_UNSPEC, /// Inet - Means both IPv4 and IPv6 - Inet = libc::NFPROTO_INET as u32, - Ipv4 = libc::NFPROTO_IPV4 as u32, - Arp = libc::NFPROTO_ARP as u32, - NetDev = libc::NFPROTO_NETDEV as u32, - Bridge = libc::NFPROTO_BRIDGE as u32, - Ipv6 = libc::NFPROTO_IPV6 as u32, - DecNet = libc::NFPROTO_DECNET as u32, + Inet = libc::NFPROTO_INET, + Ipv4 = libc::NFPROTO_IPV4, + Arp = libc::NFPROTO_ARP, + NetDev = libc::NFPROTO_NETDEV, + Bridge = libc::NFPROTO_BRIDGE, + Ipv6 = libc::NFPROTO_IPV6, + DecNet = libc::NFPROTO_DECNET, } -#[derive(Error, Debug)] -#[error("Couldn't find a matching protocol")] -pub struct InvalidProtocolFamily; - -impl TryFrom<i32> for ProtoFamily { - type Error = InvalidProtocolFamily; +impl TryFrom<i32> for ProtocolFamily { + type Error = DecodeError; fn try_from(value: i32) -> Result<Self, Self::Error> { match value { - libc::NFPROTO_UNSPEC => Ok(ProtoFamily::Unspec), - libc::NFPROTO_INET => Ok(ProtoFamily::Inet), - libc::NFPROTO_IPV4 => Ok(ProtoFamily::Ipv4), - libc::NFPROTO_ARP => Ok(ProtoFamily::Arp), - libc::NFPROTO_NETDEV => Ok(ProtoFamily::NetDev), - libc::NFPROTO_BRIDGE => Ok(ProtoFamily::Bridge), - libc::NFPROTO_IPV6 => Ok(ProtoFamily::Ipv6), - libc::NFPROTO_DECNET => Ok(ProtoFamily::DecNet), - _ => Err(InvalidProtocolFamily), + libc::NFPROTO_UNSPEC => Ok(ProtocolFamily::Unspec), + libc::NFPROTO_INET => Ok(ProtocolFamily::Inet), + libc::NFPROTO_IPV4 => Ok(ProtocolFamily::Ipv4), + libc::NFPROTO_ARP => Ok(ProtocolFamily::Arp), + libc::NFPROTO_NETDEV => Ok(ProtocolFamily::NetDev), + libc::NFPROTO_BRIDGE => Ok(ProtocolFamily::Bridge), + libc::NFPROTO_IPV6 => Ok(ProtocolFamily::Ipv6), + libc::NFPROTO_DECNET => Ok(ProtocolFamily::DecNet), + _ => Err(DecodeError::InvalidProtocolFamily(value)), } } } diff --git a/src/nlmsg.rs b/src/nlmsg.rs index a1bb200..b7f90e9 100644 --- a/src/nlmsg.rs +++ b/src/nlmsg.rs @@ -1,19 +1,14 @@ -use std::{ - collections::HashMap, - fmt::Debug, - marker::PhantomData, - mem::{size_of, transmute}, -}; +use std::{collections::BTreeMap, fmt::Debug, mem::size_of}; use crate::{ parser::{ pad_netlink_object, pad_netlink_object_with_variable_size, AttributeType, DecodeError, }, sys::{ - nfgenmsg, nlattr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END, + nfgenmsg, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END, NFNL_SUBSYS_NFTABLES, }, - MsgType, ProtoFamily, + MsgType, ProtocolFamily, }; pub struct NfNetlinkWriter<'a> { @@ -49,7 +44,7 @@ impl<'a> NfNetlinkWriter<'a> { pub fn write_header( &mut self, msg_type: u16, - family: ProtoFamily, + family: ProtocolFamily, flags: u16, seq: u32, ressource_id: Option<u16>, @@ -103,13 +98,14 @@ pub trait NfNetlinkObject: Sized + AttributeDecoder + NfNetlinkDeserializable { fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32); } -pub trait NfNetlinkSerializable { - fn serialize<'a>(&self, writer: &mut NfNetlinkWriter<'a>); -} - pub type NetlinkType = u16; pub trait NfNetlinkAttribute: Debug + Sized { + // is it a nested argument that must be marked with a NLA_F_NESTED flag? + fn is_nested(&self) -> bool { + false + } + fn get_size(&self) -> usize { size_of::<Self>() } @@ -120,13 +116,13 @@ pub trait NfNetlinkAttribute: Debug + Sized { #[derive(Debug, Clone, PartialEq, Eq)] pub struct NfNetlinkAttributes { - pub attributes: HashMap<NetlinkType, AttributeType>, + pub attributes: BTreeMap<NetlinkType, AttributeType>, } impl NfNetlinkAttributes { pub fn new() -> Self { NfNetlinkAttributes { - attributes: HashMap::new(), + attributes: BTreeMap::new(), } } @@ -137,4 +133,11 @@ impl NfNetlinkAttributes { pub fn get_attr(&self, ty: NetlinkType) -> Option<&AttributeType> { self.attributes.get(&ty) } + + pub fn serialize<'a>(&self, writer: &mut NfNetlinkWriter<'a>) { + let buf = writer.add_data_zeroed(self.get_size()); + unsafe { + self.write_payload(buf.as_mut_ptr()); + } + } } diff --git a/src/parser.rs b/src/parser.rs index 2d05f4f..25033d2 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1,6 +1,6 @@ use std::{ any::TypeId, - collections::HashMap, + convert::TryFrom, fmt::{Debug, DebugStruct}, mem::{size_of, transmute}, string::FromUtf8Error, @@ -11,14 +11,14 @@ use thiserror::Error; use crate::{ nlmsg::{ AttributeDecoder, NetlinkType, NfNetlinkAttribute, NfNetlinkAttributes, - NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkSerializable, NfNetlinkWriter, + NfNetlinkDeserializable, NfNetlinkWriter, }, sys::{ nfgenmsg, nlattr, nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, - NFNL_MSG_BATCH_END, NFNL_SUBSYS_NFTABLES, NLA_TYPE_MASK, NLMSG_ALIGNTO, NLMSG_DONE, - NLMSG_ERROR, NLMSG_MIN_TYPE, NLMSG_NOOP, NLM_F_DUMP_INTR, + NFNL_MSG_BATCH_END, NFNL_SUBSYS_NFTABLES, NLA_F_NESTED, NLA_TYPE_MASK, NLMSG_ALIGNTO, + NLMSG_DONE, NLMSG_ERROR, NLMSG_MIN_TYPE, NLMSG_NOOP, NLM_F_DUMP_INTR, }, - InvalidProtocolFamily, ProtoFamily, + ProtocolFamily, }; #[derive(Error, Debug)] @@ -56,6 +56,9 @@ pub enum DecodeError { #[error("Invalid type for a chain")] UnknownChainType, + #[error("Invalid policy for a chain")] + UnknownChainPolicy, + #[error("Unsupported attribute type")] UnsupportedAttributeType(u16), @@ -66,7 +69,7 @@ pub enum DecodeError { StringDecodeFailure(#[from] FromUtf8Error), #[error("Invalid value for a protocol family")] - InvalidProtocolFamily(#[from] InvalidProtocolFamily), + InvalidProtocolFamily(i32), #[error("A custom error occured")] Custom(Box<dyn std::error::Error + 'static>), @@ -189,29 +192,21 @@ pub fn parse_nlmsg<'a>(buf: &'a [u8]) -> Result<(nlmsghdr, NlMsg<'a>), DecodeErr /// Write the attribute, preceded by a `libc::nlattr` // rewrite of `mnl_attr_put` -fn write_attribute<'a>(ty: NetlinkType, obj: &AttributeType, writer: &mut NfNetlinkWriter<'a>) { - // copy the header +unsafe fn write_attribute<'a>(ty: NetlinkType, obj: &AttributeType, mut buf: *mut u8) { let header_len = pad_netlink_object::<libc::nlattr>(); - let header = libc::nlattr { + // copy the header + *(buf as *mut nlattr) = nlattr { // nla_len contains the header size + the unpadded attribute length nla_len: (header_len + obj.get_size() as usize) as u16, - nla_type: ty, + nla_type: if obj.is_nested() { + ty | NLA_F_NESTED as u16 + } else { + ty + }, }; - - let buf = writer.add_data_zeroed(header_len); - unsafe { - std::ptr::copy_nonoverlapping( - &header as *const libc::nlattr as *const u8, - buf.as_mut_ptr(), - header_len as usize, - ); - } - - let buf = writer.add_data_zeroed(obj.get_size()); + buf = buf.offset(pad_netlink_object::<nlattr>() as isize); // copy the attribute data itself - unsafe { - obj.write_payload(buf.as_mut_ptr()); - } + obj.write_payload(buf); } impl NfNetlinkAttribute for u8 { @@ -228,7 +223,7 @@ impl NfNetlinkDeserializable for u8 { impl NfNetlinkAttribute for u16 { unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = *self; + *(addr as *mut Self) = self.to_be(); } } @@ -240,7 +235,7 @@ impl NfNetlinkDeserializable for u16 { impl NfNetlinkAttribute for i32 { unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = *self; + *(addr as *mut Self) = self.to_be(); } } @@ -255,7 +250,7 @@ impl NfNetlinkDeserializable for i32 { impl NfNetlinkAttribute for u32 { unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = *self; + *(addr as *mut Self) = self.to_be(); } } @@ -270,7 +265,7 @@ impl NfNetlinkDeserializable for u32 { impl NfNetlinkAttribute for u64 { unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = *self; + *(addr as *mut Self) = self.to_be(); } } @@ -322,11 +317,24 @@ impl NfNetlinkDeserializable for Vec<u8> { } } +impl NfNetlinkAttribute for ProtocolFamily { + unsafe fn write_payload(&self, addr: *mut u8) { + (*self as i32).write_payload(addr); + } +} + +impl NfNetlinkDeserializable for ProtocolFamily { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let (v, remaining_data) = i32::deserialize(buf)?; + Ok((Self::try_from(v)?, remaining_data)) + } +} + pub type NestedAttribute = NfNetlinkAttributes; // parts of the NfNetlinkAttribute trait we need for handling nested objects -impl NestedAttribute { - pub fn get_size(&self) -> usize { +impl NfNetlinkAttribute for NestedAttribute { + fn get_size(&self) -> usize { let mut size = 0; for (_type, attr) in self.attributes.iter() { @@ -338,15 +346,12 @@ impl NestedAttribute { size } - pub unsafe fn write_payload(&self, mut addr: *mut u8) { + unsafe fn write_payload(&self, mut addr: *mut u8) { for (ty, attr) in self.attributes.iter() { - *(addr as *mut nlattr) = nlattr { - nla_len: attr.get_size() as u16, - nla_type: *ty, - }; - addr = addr.offset(pad_netlink_object::<nlattr>() as isize); - attr.write_payload(addr); - addr = addr.offset(pad_netlink_object_with_variable_size(attr.get_size()) as isize); + write_attribute(*ty, attr, addr); + let size = pad_netlink_object::<nlattr>() + + pad_netlink_object_with_variable_size(attr.get_size()); + addr = addr.offset(size as isize); } } } @@ -412,17 +417,6 @@ impl<'a> NfNetlinkAttributeReader<'a> { } } -impl NfNetlinkSerializable for NfNetlinkAttributes { - fn serialize<'a>(&self, writer: &mut NfNetlinkWriter<'a>) { - // TODO: improve performance by not sorting this - let mut keys: Vec<&NetlinkType> = self.attributes.keys().collect(); - keys.sort(); - for k in keys { - write_attribute(*k, self.attributes.get(k).unwrap(), writer); - } - } -} - macro_rules! impl_attribute_holder { ($enum_name:ident, $([$internal_name:ident, $type:ty]),+) => { #[derive(Debug, Clone, PartialEq, Eq)] @@ -433,6 +427,14 @@ macro_rules! impl_attribute_holder { } impl NfNetlinkAttribute for $enum_name { + fn is_nested(&self) -> bool { + match self { + $( + $enum_name::$internal_name(val) => val.is_nested() + ),+ + } + } + fn get_size(&self) -> usize { match self { $( @@ -480,12 +482,15 @@ impl_attribute_holder!( [U32, u32], [U64, u64], [VecU8, Vec<u8>], - [ChainHook, crate::chain::Hook] + [ChainHook, crate::chain::Hook], + [ChainPolicy, crate::chain::ChainPolicy], + [ChainType, crate::chain::ChainType], + [ProtocolFamily, crate::ProtocolFamily] ); #[macro_export] macro_rules! impl_attr_getters_and_setters { - ($struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { + ($struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty $(, $nested:literal)?)),+]) => { impl $struct { $( #[allow(dead_code)] diff --git a/src/query.rs b/src/query.rs index f84586a..da886c0 100644 --- a/src/query.rs +++ b/src/query.rs @@ -4,7 +4,7 @@ use crate::{ nlmsg::{NfNetlinkObject, NfNetlinkWriter}, parser::{nft_nlmsg_maxsize, pad_netlink_object_with_variable_size}, sys::{nlmsgerr, NLM_F_DUMP, NLM_F_MULTI}, - ProtoFamily, + ProtocolFamily, }; use nix::{ @@ -156,7 +156,13 @@ where { let mut buffer = Vec::new(); let mut writer = NfNetlinkWriter::new(&mut buffer); - writer.write_header(msg_type, ProtoFamily::Unspec, NLM_F_DUMP as u16, seq, None); + writer.write_header( + msg_type, + ProtocolFamily::Unspec, + NLM_F_DUMP as u16, + seq, + None, + ); writer.finalize_writing_object(); if let Some(filter) = filter { filter.add_or_remove(&mut writer, crate::MsgType::Add, 0); diff --git a/src/table.rs b/src/table.rs index a21f3f2..768eedd 100644 --- a/src/table.rs +++ b/src/table.rs @@ -2,15 +2,14 @@ use std::convert::TryFrom; use std::fmt::Debug; use crate::nlmsg::{ - AttributeDecoder, NfNetlinkAttributes, NfNetlinkDeserializable, NfNetlinkObject, - NfNetlinkSerializable, NfNetlinkWriter, + NfNetlinkAttributes, NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkWriter, }; use crate::parser::{parse_object, DecodeError, InnerFormat}; use crate::sys::{ - self, NFTA_OBJ_TABLE, NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, - NFT_MSG_NEWTABLE, NLM_F_ACK, + self, NFNL_SUBSYS_NFTABLES, NFTA_OBJ_TABLE, NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, + NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, NFT_MSG_NEWTABLE, NLM_F_ACK, }; -use crate::{impl_attr_getters_and_setters, MsgType, ProtoFamily}; +use crate::{impl_attr_getters_and_setters, MsgType, ProtocolFamily}; /// Abstraction of a `nftnl_table`, the top level container in netfilter. A table has a protocol /// family and contains [`Chain`]s that in turn hold the rules. @@ -19,17 +18,21 @@ use crate::{impl_attr_getters_and_setters, MsgType, ProtoFamily}; #[derive(PartialEq, Eq)] pub struct Table { inner: NfNetlinkAttributes, - family: ProtoFamily, + family: ProtocolFamily, } impl Table { - pub fn new(family: ProtoFamily) -> Table { + pub fn new(family: ProtocolFamily) -> Table { Table { inner: NfNetlinkAttributes::new(), family, } } + pub fn get_family(&self) -> ProtocolFamily { + self.family + } + /* /// Returns a textual description of the table. pub fn get_str(&self) -> CString { @@ -83,7 +86,7 @@ impl NfNetlinkDeserializable for Table { Ok(( Self { inner, - family: ProtoFamily::try_from(nfgenmsg.nfgen_family as i32)?, + family: ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?, }, remaining_data, )) diff --git a/tests/lib.rs b/tests/lib.rs index cf5ddb4..0268b1a 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -5,9 +5,9 @@ use libc::AF_UNIX; use rustables::nlmsg::{NfNetlinkObject, NfNetlinkWriter}; //use rustables::set::SetKey; use rustables::{sys::*, Chain}; -use rustables::{MsgType, ProtoFamily, Table}; +use rustables::{MsgType, ProtocolFamily, Table}; -//use rustables::{nft_nlmsg_maxsize, Chain, MsgType, NlMsg, ProtoFamily, Rule, Set, Table}; +//use rustables::{nft_nlmsg_maxsize, Chain, MsgType, NlMsg, Rule, Set, Table}; pub const TABLE_NAME: &'static str = "mocktable"; pub const CHAIN_NAME: &'static str = "mockchain"; @@ -26,7 +26,7 @@ type NetLinkType = u16; #[error("empty data")] pub struct EmptyDataError; -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, Eq, PartialOrd, Ord)] pub enum NetlinkExpr { Nested(NetLinkType, Vec<NetlinkExpr>), Final(NetLinkType, Vec<u8>), @@ -64,7 +64,7 @@ impl NetlinkExpr { // set the "NESTED" flag res.extend(&(len as u16).to_le_bytes()); - res.extend(&(ty | 0x8000).to_le_bytes()); + res.extend(&(ty | NLA_F_NESTED as u16).to_le_bytes()); res.extend(sub); res @@ -98,8 +98,19 @@ impl NetlinkExpr { } } +impl PartialEq for NetlinkExpr { + fn eq(&self, other: &Self) -> bool { + match (self.clone().sort(), other.clone().sort()) { + (NetlinkExpr::Nested(k1, v1), NetlinkExpr::Nested(k2, v2)) => k1 == k2 && v1 == v2, + (NetlinkExpr::Final(k1, v1), NetlinkExpr::Final(k2, v2)) => k1 == k2 && v1 == v2, + (NetlinkExpr::List(v1), NetlinkExpr::List(v2)) => v1 == v2, + _ => false, + } + } +} + pub fn get_test_table() -> Table { - Table::new(ProtoFamily::Inet) + Table::new(ProtocolFamily::Inet) .with_name(TABLE_NAME) .with_flags(0u32) } |