diff options
-rw-r--r-- | build.rs | 6 | ||||
-rw-r--r-- | examples/add-rules.rs | 7 | ||||
-rw-r--r-- | src/chain.rs | 286 | ||||
-rw-r--r-- | src/lib.rs | 20 | ||||
-rw-r--r-- | src/nlmsg.rs | 65 | ||||
-rw-r--r-- | src/parser.rs | 193 | ||||
-rw-r--r-- | src/table.rs | 47 | ||||
-rw-r--r-- | tests/batch.rs | 1 | ||||
-rw-r--r-- | tests/table.rs | 2 |
9 files changed, 412 insertions, 215 deletions
@@ -4,7 +4,6 @@ use bindgen; use lazy_static::lazy_static; use regex::{Captures, Regex}; use std::borrow::Cow; -use std::env; use std::fs::File; use std::io::Write; use std::path::PathBuf; @@ -16,11 +15,6 @@ fn main() { generate_sys(); } -fn get_env(var: &'static str) -> Option<PathBuf> { - println!("cargo:rerun-if-env-changed={}", var); - env::var_os(var).map(PathBuf::from) -} - /// `bindgen`erate a rust sys file from the C kernel headers of the nf_tables capabilities. fn generate_sys() { // Tell cargo to invalidate the built crate whenever the headers change. diff --git a/examples/add-rules.rs b/examples/add-rules.rs index 11e7b6f..812721c 100644 --- a/examples/add-rules.rs +++ b/examples/add-rules.rs @@ -37,7 +37,7 @@ //! ``` use ipnetwork::{IpNetwork, Ipv4Network}; -use rustables::{Batch, ProtoFamily, Table}; +use rustables::{table::list_tables, Batch, ProtoFamily, Table}; //use rustables::{nft_expr, query::send_batch, sys::libc, Batch, Chain, ProtoFamily, Rule, Table}; use std::{ffi::CString, io, net::Ipv4Addr, rc::Rc}; @@ -46,6 +46,7 @@ const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets"; const IN_CHAIN_NAME: &str = "chain-for-incoming-packets"; fn main() -> Result<(), Error> { + /* // Create a batch. This is used to store all the netlink messages we will later send. // Creating a new batch also automatically writes the initial batch begin message needed // to tell netlink this is a single transaction that might arrive over multiple netlink packets. @@ -175,6 +176,10 @@ fn main() -> Result<(), Error> { // netfilter the we reached the end of the transaction message. It's also converted to a // Vec<u8>, containing the raw netlink data so it can be sent over a netlink socket to netfilter. Ok(batch.send()?) + */ + + println!("{:?}", list_tables()); + Ok(()) } // Look up the interface index for a given interface name. diff --git a/src/chain.rs b/src/chain.rs index a99d7f8..e29b239 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -1,9 +1,10 @@ -use crate::nlmsg::NlMsg; -#[cfg(feature = "query")] -use crate::query::{Nfgenmsg, ParseError}; -use crate::sys::{self as sys, libc}; -use crate::{MsgType, Table}; -#[cfg(feature = "query")] +use crate::nlmsg::{ + AttributeDecoder, NfNetlinkAttribute, NfNetlinkAttributes, NfNetlinkDeserializable, + NfNetlinkObject, NfNetlinkWriter, +}; +use crate::parser::{DecodeError, NestedAttribute, NfNetlinkAttributeReader}; +use crate::sys::{self, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN, NLM_F_ACK}; +use crate::{impl_attr_getters_and_setters, MsgType, ProtoFamily, Table}; use std::convert::TryFrom; use std::{ ffi::{c_void, CStr, CString}, @@ -15,24 +16,80 @@ use std::{ pub type Priority = i32; /// The netfilter event hooks a chain can register for. -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[repr(u16)] -pub enum Hook { +#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +#[repr(u32)] +pub enum HookClass { /// Hook into the pre-routing stage of netfilter. Corresponds to `NF_INET_PRE_ROUTING`. - PreRouting = libc::NF_INET_PRE_ROUTING as u16, + PreRouting = libc::NF_INET_PRE_ROUTING as u32, /// Hook into the input stage of netfilter. Corresponds to `NF_INET_LOCAL_IN`. - In = libc::NF_INET_LOCAL_IN as u16, + In = libc::NF_INET_LOCAL_IN as u32, /// Hook into the forward stage of netfilter. Corresponds to `NF_INET_FORWARD`. - Forward = libc::NF_INET_FORWARD as u16, + Forward = libc::NF_INET_FORWARD as u32, /// Hook into the output stage of netfilter. Corresponds to `NF_INET_LOCAL_OUT`. - Out = libc::NF_INET_LOCAL_OUT as u16, + Out = libc::NF_INET_LOCAL_OUT as u32, /// Hook into the post-routing stage of netfilter. Corresponds to `NF_INET_POST_ROUTING`. - PostRouting = libc::NF_INET_POST_ROUTING as u16, + PostRouting = libc::NF_INET_POST_ROUTING as u32, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Hook { + inner: NestedAttribute, +} + +impl Hook { + fn new(class: HookClass, priority: Priority) -> Self { + Hook { + inner: NestedAttribute::new(), + } + .with_hook_class(class as u32) + .with_hook_priority(priority as u32) + } +} + +impl_attr_getters_and_setters!( + Hook, + [ + // Define the action netfilter will apply to packets processed by this chain, but that did not match any rules in it. + ( + get_hook_class, + set_hook_class, + with_hook_class, + sys::NFTA_HOOK_HOOKNUM, + U32, + u32 + ), + ( + get_hook_priority, + set_hook_priority, + with_hook_priority, + sys::NFTA_HOOK_PRIORITY, + U32, + u32 + ) + ] +); + +impl NfNetlinkAttribute for Hook { + fn get_size(&self) -> usize { + self.inner.get_size() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + self.inner.write_payload(addr) + } +} + +impl NfNetlinkDeserializable for Hook { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let reader = NfNetlinkAttributeReader::new(buf, buf.len())?; + let inner = reader.decode::<Self>()?; + Ok((Hook { inner }, &[])) + } } /// A chain policy. Decides what to do with a packet that was processed by the chain but did not /// match any rules. -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] #[repr(u32)] pub enum Policy { /// Accept the packet. @@ -73,37 +130,28 @@ impl ChainType { /// [`Table`]: struct.Table.html /// [`Rule`]: struct.Rule.html /// [`set_hook`]: #method.set_hook +#[derive(Debug, PartialEq, Eq)] pub struct Chain { - pub(crate) chain: *mut sys::nftnl_chain, - pub(crate) table: Rc<Table>, + inner: NfNetlinkAttributes, } impl Chain { - /// Creates a new chain instance inside the given [`Table`] and with the given name. + /// Creates a new chain instance inside the given [`Table`]. /// /// [`Table`]: struct.Table.html - pub fn new<T: AsRef<CStr>>(name: &T, table: Rc<Table>) -> Chain { - unsafe { - let chain = try_alloc!(sys::nftnl_chain_alloc()); - sys::nftnl_chain_set_u32( - chain, - sys::NFTNL_CHAIN_FAMILY as u16, - table.get_family() as u32, - ); - sys::nftnl_chain_set_str( - chain, - sys::NFTNL_CHAIN_TABLE as u16, - table.get_name().as_ptr(), - ); - sys::nftnl_chain_set_str(chain, sys::NFTNL_CHAIN_NAME as u16, name.as_ref().as_ptr()); - Chain { chain, table } + pub fn new<T: AsRef<CStr>>(table: &Table) -> Chain { + let mut chain = Chain { + inner: NfNetlinkAttributes::new(), + }; + + if let Some(table_name) = table.get_name() { + chain.set_table(table_name); } - } - pub unsafe fn from_raw(chain: *mut sys::nftnl_chain, table: Rc<Table>) -> Self { - Chain { chain, table } + chain } + /* /// Sets the hook and priority for this chain. Without calling this method the chain will /// become a "regular chain" without any hook and will thus not receive any traffic unless /// some rule forward packets to it via goto or jump verdicts. @@ -112,62 +160,12 @@ impl Chain { /// hook and is thus a "base chain". A "base chain" is an entry point for packets from the /// networking stack. pub fn set_hook(&mut self, hook: Hook, priority: Priority) { - unsafe { - sys::nftnl_chain_set_u32(self.chain, sys::NFTNL_CHAIN_HOOKNUM as u16, hook as u32); - sys::nftnl_chain_set_s32(self.chain, sys::NFTNL_CHAIN_PRIO as u16, priority); - } - } - - /// Set the type of a base chain. This only applies if the chain has been registered - /// with a hook by calling `set_hook`. - pub fn set_type(&mut self, chain_type: ChainType) { - unsafe { - sys::nftnl_chain_set_str( - self.chain, - sys::NFTNL_CHAIN_TYPE as u16, - chain_type.as_c_str().as_ptr() as *const c_char, - ); - } - } - - /// Sets the default policy for this chain. That means what action netfilter will apply to - /// packets processed by this chain, but that did not match any rules in it. - pub fn set_policy(&mut self, policy: Policy) { - unsafe { - sys::nftnl_chain_set_u32(self.chain, sys::NFTNL_CHAIN_POLICY as u16, policy as u32); - } - } - - /// Returns the userdata of this chain. - pub fn get_userdata(&self) -> Option<&CStr> { - unsafe { - let ptr = sys::nftnl_chain_get_str(self.chain, sys::NFTNL_CHAIN_USERDATA as u16); - if ptr == std::ptr::null() { - return None; - } - Some(CStr::from_ptr(ptr)) - } - } - - /// Updates the userdata of this chain. - pub fn set_userdata(&self, data: &CStr) { - unsafe { - sys::nftnl_chain_set_str(self.chain, sys::NFTNL_CHAIN_USERDATA as u16, data.as_ptr()); - } - } - - /// Returns the name of this chain. - pub fn get_name(&self) -> &CStr { - unsafe { - let ptr = sys::nftnl_chain_get_str(self.chain, sys::NFTNL_CHAIN_NAME as u16); - if ptr.is_null() { - panic!("Impossible situation: retrieving the name of a chain failed") - } else { - CStr::from_ptr(ptr) - } - } + self.set_hook_type(hook); + self.set_hook_priority(priority); } + */ + /* /// Returns a textual description of the chain. pub fn get_str(&self) -> CString { let mut descr_buf = vec![0i8; 4096]; @@ -182,27 +180,10 @@ impl Chain { CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned() } } - - /// Returns a reference to the [`Table`] this chain belongs to. - /// - /// [`Table`]: struct.Table.html - pub fn get_table(&self) -> Rc<Table> { - self.table.clone() - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns the raw handle. - pub fn as_ptr(&self) -> *const sys::nftnl_chain { - self.chain as *const sys::nftnl_chain - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns a mutable version of the raw handle. - pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_chain { - self.chain - } + */ } +/* impl fmt::Debug for Chain { /// Returns a string representation of the chain. fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -215,7 +196,64 @@ impl PartialEq for Chain { self.get_table() == other.get_table() && self.get_name() == other.get_name() } } +*/ + +/* +impl NfNetlinkObject for Chain { + fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32) { + let raw_msg_type = match msg_type { + MsgType::Add => NFT_MSG_NEWCHAIN, + MsgType::Del => NFT_MSG_DELCHAIN, + } as u16; + writer.write_header( + raw_msg_type, + ProtoFamily::Unspec, + NLM_F_ACK as u16, + seq, + None, + ); + self.inner.serialize(writer); + writer.finalize_writing_object(); + } + + fn decode_attribute(attr_type: u16, buf: &[u8]) -> Result<AttributeType, DecodeError> { + match attr_type { + NFTA_TABLE_NAME => Ok(AttributeType::String(String::from_utf8(buf.to_vec())?)), + NFTA_TABLE_FLAGS => { + let val = [buf[0], buf[1], buf[2], buf[3]]; + + Ok(AttributeType::U32(u32::from_ne_bytes(val))) + } + NFTA_TABLE_USERDATA => Ok(AttributeType::VecU8(buf.to_vec())), + _ => Err(DecodeError::UnsupportedAttributeType(attr_type)), + } + } + + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let (hdr, msg) = parse_nlmsg(buf)?; + + let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as u32; + + if op != NFT_MSG_NEWTABLE && op != NFT_MSG_DELTABLE { + return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)); + } + + let (nfgenmsg, attrs, remaining_data) = parse_object(hdr, msg, buf)?; + + let inner = attrs.decode::<Table>()?; + + Ok(( + Table { + inner, + family: ProtoFamily::try_from(nfgenmsg.family as i32)?, + }, + remaining_data, + )) + } +} +*/ +/* unsafe impl NlMsg for Chain { unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { let raw_msg_type = match msg_type { @@ -243,7 +281,6 @@ impl Drop for Chain { } } -#[cfg(feature = "query")] pub fn get_chains_cb<'a>( header: &libc::nlmsghdr, _genmsg: &Nfgenmsg, @@ -302,15 +339,38 @@ pub fn get_chains_cb<'a>( Ok(()) } +*/ + +impl_attr_getters_and_setters!( + Chain, + [ + (get_flags, set_flags, with_flags, sys::NFTA_CHAIN_FLAGS, U32, u32), + (get_name, set_name, with_name, sys::NFTA_CHAIN_NAME, String, String), + (set_hook, get_hook, with_hook, sys::NFTA_CHAIN_HOOK, ChainHook, Hook), + (get_policy, set_policy, with_policy, sys::NFTA_CHAIN_POLICY, U32, u32), + (get_table, set_table, with_table, sys::NFTA_CHAIN_TABLE, String, String), + // This only applies if the chain has been registered with a hook by calling `set_hook`. + (get_type, set_type, with_type, sys::NFTA_CHAIN_TYPE, String, String), + ( + get_userdata, + set_userdata, + with_userdata, + sys::NFTA_CHAIN_USERDATA, + VecU8, + Vec<u8> + ) + ] +); -#[cfg(feature = "query")] -pub fn list_chains_for_table(table: Rc<Table>) -> Result<Vec<Chain>, crate::query::Error> { +/* +pub fn list_chains_for_table(table: &Table) -> Result<Vec<Chain>, crate::query::Error> { let mut result = Vec::new(); crate::query::list_objects_with_data( libc::NFT_MSG_GETCHAIN as u16, &get_chains_cb, - &mut (&table, &mut result), None, + &mut (&table, &mut result), )?; Ok(result) } +*/ @@ -100,7 +100,7 @@ pub mod table; pub use table::Table; //pub use table::{get_tables_cb, list_tables}; // -//mod chain; +mod chain; //pub use chain::{get_chains_cb, list_chains_for_table}; //pub use chain::{Chain, ChainType, Hook, Policy, Priority}; @@ -141,17 +141,17 @@ pub enum MsgType { /// Denotes a protocol. Used to specify which protocol a table or set belongs to. #[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] -#[repr(u16)] +#[repr(u32)] pub enum ProtoFamily { - Unspec = libc::NFPROTO_UNSPEC as u16, + Unspec = libc::NFPROTO_UNSPEC as u32, /// Inet - Means both IPv4 and IPv6 - Inet = libc::NFPROTO_INET as u16, - Ipv4 = libc::NFPROTO_IPV4 as u16, - Arp = libc::NFPROTO_ARP as u16, - NetDev = libc::NFPROTO_NETDEV as u16, - Bridge = libc::NFPROTO_BRIDGE as u16, - Ipv6 = libc::NFPROTO_IPV6 as u16, - DecNet = libc::NFPROTO_DECNET as u16, + Inet = libc::NFPROTO_INET as u32, + Ipv4 = libc::NFPROTO_IPV4 as u32, + Arp = libc::NFPROTO_ARP as u32, + NetDev = libc::NFPROTO_NETDEV as u32, + Bridge = libc::NFPROTO_BRIDGE as u32, + Ipv6 = libc::NFPROTO_IPV6 as u32, + DecNet = libc::NFPROTO_DECNET as u32, } #[derive(Error, Debug)] diff --git a/src/nlmsg.rs b/src/nlmsg.rs index 868560a..435fed3 100644 --- a/src/nlmsg.rs +++ b/src/nlmsg.rs @@ -1,15 +1,18 @@ -use std::{collections::HashMap, fmt::Debug, marker::PhantomData, mem::size_of, ops::Deref}; - -use libc::{ - nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END, - NFNL_SUBSYS_NFTABLES, NLMSG_MIN_TYPE, NLM_F_DUMP_INTR, +use std::{ + collections::HashMap, + fmt::Debug, + marker::PhantomData, + mem::{size_of, transmute}, }; -use thiserror::Error; use crate::{ parser::{ - pad_netlink_object, pad_netlink_object_with_variable_size, Attribute, DecodeError, - NfNetlinkAttributes, Nfgenmsg, + pad_netlink_object, pad_netlink_object_with_variable_size, AttributeType, DecodeError, + Nfgenmsg, + }, + sys::{ + nlattr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END, + NFNL_SUBSYS_NFTABLES, NLA_TYPE_MASK, }, MsgType, ProtoFamily, }; @@ -123,10 +126,50 @@ impl<'a> HeaderStack<'a> { } } -pub trait NfNetlinkObject: Sized { +pub trait AttributeDecoder { + fn decode_attribute(attr_type: u16, buf: &[u8]) -> Result<AttributeType, DecodeError>; +} + +pub trait NfNetlinkDeserializable: Sized { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError>; +} + +pub trait NfNetlinkObject: Sized + AttributeDecoder + NfNetlinkDeserializable { fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32); +} + +pub trait NfNetlinkSerializable { + fn serialize<'a>(&self, writer: &mut NfNetlinkWriter<'a>); +} - fn decode_attribute(attr_type: u16, buf: &[u8]) -> Result<Attribute, DecodeError>; +pub type NetlinkType = u16; - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError>; +pub trait NfNetlinkAttribute: Debug + Sized { + fn get_size(&self) -> usize { + size_of::<Self>() + } + + // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr, self.get_size()); + unsafe fn write_payload(&self, addr: *mut u8); +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NfNetlinkAttributes { + pub attributes: HashMap<NetlinkType, AttributeType>, +} + +impl NfNetlinkAttributes { + pub fn new() -> Self { + NfNetlinkAttributes { + attributes: HashMap::new(), + } + } + + pub fn set_attr(&mut self, ty: NetlinkType, obj: AttributeType) { + self.attributes.insert(ty, obj); + } + + pub fn get_attr(&self, ty: NetlinkType) -> Option<&AttributeType> { + self.attributes.get(&ty) + } } diff --git a/src/parser.rs b/src/parser.rs index a01c3cd..8e12c5e 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1,4 +1,5 @@ use std::{ + any::TypeId, collections::HashMap, fmt::Debug, mem::{size_of, transmute}, @@ -8,7 +9,10 @@ use std::{ use thiserror::Error; use crate::{ - nlmsg::{NfNetlinkObject, NfNetlinkWriter}, + nlmsg::{ + AttributeDecoder, NetlinkType, NfNetlinkAttribute, NfNetlinkAttributes, + NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkSerializable, NfNetlinkWriter, + }, sys::{ nlattr, nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END, NFNL_SUBSYS_NFTABLES, NLA_TYPE_MASK, NLMSG_ALIGNTO, NLMSG_DONE, NLMSG_ERROR, @@ -25,6 +29,9 @@ pub enum DecodeError { #[error("The message is too small")] NlMsgTooSmall, + #[error("The message holds unexpected data")] + InvalidDataSize, + #[error("Invalid subsystem, expected NFTABLES")] InvalidSubsystem(u8), @@ -71,13 +78,13 @@ pub fn nft_nlmsg_maxsize() -> u32 { } #[inline] -pub fn pad_netlink_object_with_variable_size(size: usize) -> usize { +pub const fn pad_netlink_object_with_variable_size(size: usize) -> usize { // align on a 4 bytes boundary (size + (NLMSG_ALIGNTO as usize - 1)) & !(NLMSG_ALIGNTO as usize - 1) } #[inline] -pub fn pad_netlink_object<T>() -> usize { +pub const fn pad_netlink_object<T>() -> usize { let size = size_of::<T>(); pad_netlink_object_with_variable_size(size) } @@ -185,20 +192,9 @@ pub fn parse_nlmsg<'a>(buf: &'a [u8]) -> Result<(nlmsghdr, NlMsg<'a>), DecodeErr Ok((hdr, NlMsg::NfGenMsg(nfgenmsg, raw_value))) } -pub type NetlinkType = u16; - -pub trait NfNetlinkAttribute: Debug + Sized { - fn get_size(&self) -> usize { - size_of::<Self>() - } - - // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr, self.get_size()); - unsafe fn write_payload(&self, addr: *mut u8); -} - /// Write the attribute, preceded by a `libc::nlattr` // rewrite of `mnl_attr_put` -fn write_attribute<'a>(ty: NetlinkType, obj: &Attribute, writer: &mut NfNetlinkWriter<'a>) { +fn write_attribute<'a>(ty: NetlinkType, obj: &AttributeType, writer: &mut NfNetlinkWriter<'a>) { // copy the header let header_len = pad_netlink_object::<libc::nlattr>(); let header = libc::nlattr { @@ -223,15 +219,15 @@ fn write_attribute<'a>(ty: NetlinkType, obj: &Attribute, writer: &mut NfNetlinkW } } -impl NfNetlinkAttribute for ProtoFamily { +impl NfNetlinkAttribute for u8 { unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut u32) = *self as u32; + *addr = *self; } } -impl NfNetlinkAttribute for u8 { - unsafe fn write_payload(&self, addr: *mut u8) { - *addr = *self; +impl NfNetlinkDeserializable for u8 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok((buf[0], &buf[1..])) } } @@ -241,18 +237,60 @@ impl NfNetlinkAttribute for u16 { } } +impl NfNetlinkDeserializable for u16 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok((u16::from_be_bytes([buf[0], buf[1]]), &buf[2..])) + } +} + +impl NfNetlinkAttribute for i32 { + unsafe fn write_payload(&self, addr: *mut u8) { + *(addr as *mut Self) = *self; + } +} + +impl NfNetlinkDeserializable for i32 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok(( + i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), + &buf[4..], + )) + } +} + impl NfNetlinkAttribute for u32 { unsafe fn write_payload(&self, addr: *mut u8) { *(addr as *mut Self) = *self; } } +impl NfNetlinkDeserializable for u32 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok(( + u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), + &buf[4..], + )) + } +} + impl NfNetlinkAttribute for u64 { unsafe fn write_payload(&self, addr: *mut u8) { *(addr as *mut Self) = *self; } } +impl NfNetlinkDeserializable for u64 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok(( + u64::from_be_bytes([ + buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], + ]), + &buf[8..], + )) + } +} + +// TODO: safe handling for null-delimited strings impl NfNetlinkAttribute for String { fn get_size(&self) -> usize { self.len() @@ -263,6 +301,12 @@ impl NfNetlinkAttribute for String { } } +impl NfNetlinkDeserializable for String { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok((String::from_utf8(buf.to_vec())?, &[])) + } +} + impl NfNetlinkAttribute for Vec<u8> { fn get_size(&self) -> usize { self.len() @@ -273,24 +317,38 @@ impl NfNetlinkAttribute for Vec<u8> { } } -#[derive(Debug, PartialEq, Eq)] -pub struct NfNetlinkAttributes { - attributes: HashMap<NetlinkType, Attribute>, +impl NfNetlinkDeserializable for Vec<u8> { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok((buf.to_vec(), &[])) + } } -impl NfNetlinkAttributes { - pub fn new() -> Self { - NfNetlinkAttributes { - attributes: HashMap::new(), +pub type NestedAttribute = NfNetlinkAttributes; + +// parts of the NfNetlinkAttribute trait we need for handling nested objects +impl NestedAttribute { + pub fn get_size(&self) -> usize { + let mut size = 0; + + for (_type, attr) in self.attributes.iter() { + // Attribute header + attribute value + size += pad_netlink_object::<nlattr>() + + pad_netlink_object_with_variable_size(attr.get_size()); } - } - pub fn set_attr(&mut self, ty: NetlinkType, obj: Attribute) { - self.attributes.insert(ty, obj); + size } - pub fn get_attr(&self, ty: NetlinkType) -> Option<&Attribute> { - self.attributes.get(&ty) + pub unsafe fn write_payload(&self, mut addr: *mut u8) { + for (ty, attr) in self.attributes.iter() { + *(addr as *mut nlattr) = nlattr { + nla_len: attr.get_size() as u16, + nla_type: *ty, + }; + addr = addr.offset(pad_netlink_object::<nlattr>() as isize); + attr.write_payload(addr); + addr = addr.offset(pad_netlink_object_with_variable_size(attr.get_size()) as isize); + } } } @@ -319,7 +377,9 @@ impl<'a> NfNetlinkAttributeReader<'a> { &self.buf[self.pos..] } - pub fn decode<T: NfNetlinkObject>(mut self) -> Result<NfNetlinkAttributes, DecodeError> { + pub fn decode<T: AttributeDecoder + 'static>( + mut self, + ) -> Result<NfNetlinkAttributes, DecodeError> { while self.remaining_size > pad_netlink_object::<nlattr>() { let nlattr = unsafe { *transmute::<*const u8, *const nlattr>(self.buf[self.pos..].as_ptr()) }; @@ -328,19 +388,28 @@ impl<'a> NfNetlinkAttributeReader<'a> { self.pos += pad_netlink_object::<nlattr>(); let attr_remaining_size = nlattr.nla_len as usize - pad_netlink_object::<nlattr>(); - self.attrs.set_attr( + match T::decode_attribute( nla_type, - T::decode_attribute( - nla_type, - &self.buf[self.pos..self.pos + attr_remaining_size], - )?, - ); + &self.buf[self.pos..self.pos + attr_remaining_size], + ) { + Ok(x) => self.attrs.set_attr(nla_type, x), + Err(DecodeError::UnsupportedAttributeType(t)) => info!( + "Ignore attribute type {} for type id {:?}", + t, + TypeId::of::<T>() + ), + Err(e) => return Err(e), + } self.pos += pad_netlink_object_with_variable_size(attr_remaining_size); self.remaining_size -= pad_netlink_object_with_variable_size(nlattr.nla_len as usize); } - Ok(self.attrs) + if self.remaining_size != 0 { + Err(DecodeError::InvalidDataSize) + } else { + Ok(self.attrs) + } } } @@ -364,11 +433,7 @@ pub fn parse_object<'a>( } } -pub trait SerializeNfNetlink { - fn serialize<'a>(&self, writer: &mut NfNetlinkWriter<'a>); -} - -impl SerializeNfNetlink for NfNetlinkAttributes { +impl NfNetlinkSerializable for NfNetlinkAttributes { fn serialize<'a>(&self, writer: &mut NfNetlinkWriter<'a>) { // TODO: improve performance by not sorting this let mut keys: Vec<&NetlinkType> = self.attributes.keys().collect(); @@ -379,9 +444,9 @@ impl SerializeNfNetlink for NfNetlinkAttributes { } } -macro_rules! impl_attribute { +macro_rules! impl_attribute_holder { ($enum_name:ident, $([$internal_name:ident, $type:ty]),+) => { - #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] + #[derive(Debug, Clone, PartialEq, Eq)] pub enum $enum_name { $( $internal_name($type), @@ -403,7 +468,6 @@ macro_rules! impl_attribute { $enum_name::$internal_name(val) => val.write_payload(addr) ),+ } - } } @@ -421,15 +485,16 @@ macro_rules! impl_attribute { }; } -impl_attribute!( - Attribute, +impl_attribute_holder!( + AttributeType, [String, String], [U8, u8], [U16, u16], + [I32, i32], [U32, u32], [U64, u64], [VecU8, Vec<u8>], - [ProtoFamily, ProtoFamily] + [ChainHook, crate::chain::Hook] ); #[macro_export] @@ -439,20 +504,40 @@ macro_rules! impl_attr_getters_and_setters { $( #[allow(dead_code)] pub fn $getter_name(&self) -> Option<&$type> { - self.inner.get_attr($attr_name as $crate::parser::NetlinkType).map(|x| x.$internal_name()).flatten() + self.inner.get_attr($attr_name as $crate::nlmsg::NetlinkType).map(|x| x.$internal_name()).flatten() } #[allow(dead_code)] pub fn $setter_name(&mut self, val: impl Into<$type>) { - self.inner.set_attr($attr_name as $crate::parser::NetlinkType, $crate::parser::Attribute::$internal_name(val.into())); + self.inner.set_attr($attr_name as $crate::nlmsg::NetlinkType, $crate::parser::AttributeType::$internal_name(val.into())); } #[allow(dead_code)] pub fn $in_place_edit_name(mut self, val: impl Into<$type>) -> Self { - self.inner.set_attr($attr_name as $crate::parser::NetlinkType, $crate::parser::Attribute::$internal_name(val.into())); + self.inner.set_attr($attr_name as $crate::nlmsg::NetlinkType, $crate::parser::AttributeType::$internal_name(val.into())); self } + )+ } + + impl $crate::nlmsg::AttributeDecoder for $struct { + #[allow(dead_code)] + fn decode_attribute(attr_type: u16, buf: &[u8]) -> Result<$crate::parser::AttributeType, $crate::parser::DecodeError> { + use $crate::nlmsg::NfNetlinkDeserializable; + match attr_type { + $( + x if x == $attr_name => { + let (val, remaining) = <$type>::deserialize(buf)?; + if remaining.len() != 0 { + return Err($crate::parser::DecodeError::InvalidDataSize); + } + Ok($crate::parser::AttributeType::$internal_name(val)) + }, + )+ + _ => Err($crate::parser::DecodeError::UnsupportedAttributeType(attr_type)), + } + } + } }; } diff --git a/src/table.rs b/src/table.rs index 23495a4..66dc667 100644 --- a/src/table.rs +++ b/src/table.rs @@ -1,10 +1,13 @@ use std::convert::TryFrom; use std::fmt::Debug; -use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter}; +use crate::nlmsg::{ + AttributeDecoder, NfNetlinkAttributes, NfNetlinkDeserializable, NfNetlinkObject, + NfNetlinkSerializable, NfNetlinkWriter, +}; use crate::parser::{ - get_operation_from_nlmsghdr_type, parse_nlmsg, parse_object, Attribute, DecodeError, - NfNetlinkAttributeReader, NfNetlinkAttributes, Nfgenmsg, NlMsg, SerializeNfNetlink, + get_operation_from_nlmsghdr_type, parse_nlmsg, parse_object, DecodeError, + NfNetlinkAttributeReader, }; use crate::sys::{ self, NFTA_OBJ_TABLE, NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, @@ -12,7 +15,7 @@ use crate::sys::{ }; use crate::{impl_attr_getters_and_setters, MsgType, ProtoFamily}; -/// Abstraction of `nftnl_table`, the top level container in netfilter. A table has a protocol +/// Abstraction of a `nftnl_table`, the top level container in netfilter. A table has a protocol /// family and contains [`Chain`]s that in turn hold the rules. /// /// [`Chain`]: struct.Chain.html @@ -65,20 +68,9 @@ impl NfNetlinkObject for Table { self.inner.serialize(writer); writer.finalize_writing_object(); } +} - fn decode_attribute(attr_type: u16, buf: &[u8]) -> Result<Attribute, DecodeError> { - match attr_type { - NFTA_TABLE_NAME => Ok(Attribute::String(String::from_utf8(buf.to_vec())?)), - NFTA_TABLE_FLAGS => { - let val = [buf[0], buf[1], buf[2], buf[3]]; - - Ok(Attribute::U32(u32::from_ne_bytes(val))) - } - NFTA_TABLE_USERDATA => Ok(Attribute::VecU8(buf.to_vec())), - _ => Err(DecodeError::UnsupportedAttributeType(attr_type)), - } - } - +impl NfNetlinkDeserializable for Table { fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { let (hdr, msg) = parse_nlmsg(buf)?; @@ -102,9 +94,27 @@ impl NfNetlinkObject for Table { } } +/* +impl AttributeDecoder for Table { + fn decode_attribute(attr_type: u16, buf: &[u8]) -> Result<AttributeType, DecodeError> { + match attr_type { + NFTA_TABLE_NAME => Ok(AttributeType::String(String::from_utf8(buf.to_vec())?)), + NFTA_TABLE_FLAGS => { + let val = [buf[0], buf[1], buf[2], buf[3]]; + + Ok(AttributeType::U32(u32::from_ne_bytes(val))) + } + NFTA_TABLE_USERDATA => Ok(AttributeType::VecU8(buf.to_vec())), + _ => Err(DecodeError::UnsupportedAttributeType(attr_type)), + } + } +} +*/ + impl_attr_getters_and_setters!( Table, [ + (get_flags, set_flags, with_flags, sys::NFTA_TABLE_FLAGS, U32, u32), (get_name, set_name, with_name, sys::NFTA_TABLE_NAME, String, String), ( get_userdata, @@ -113,8 +123,7 @@ impl_attr_getters_and_setters!( sys::NFTA_TABLE_USERDATA, VecU8, Vec<u8> - ), - (get_flags, set_flags, with_flags, sys::NFTA_TABLE_FLAGS, U32, u32) + ) ] ); diff --git a/tests/batch.rs b/tests/batch.rs index 081ee97..740fc19 100644 --- a/tests/batch.rs +++ b/tests/batch.rs @@ -1,6 +1,7 @@ mod sys; use libc::NFNL_MSG_BATCH_BEGIN; use nix::libc::NFNL_MSG_BATCH_END; +use rustables::nlmsg::NfNetlinkDeserializable; use rustables::nlmsg::NfNetlinkObject; use rustables::parser::{get_operation_from_nlmsghdr_type, parse_nlmsg, parse_object}; use rustables::{Batch, MsgType, Table}; diff --git a/tests/table.rs b/tests/table.rs index d8a5f1e..5961d65 100644 --- a/tests/table.rs +++ b/tests/table.rs @@ -1,6 +1,6 @@ mod sys; use rustables::{ - nlmsg::NfNetlinkObject, + nlmsg::NfNetlinkDeserializable, parser::{get_operation_from_nlmsghdr_type, nft_nlmsg_maxsize}, MsgType, Table, }; |