diff options
Diffstat (limited to 'src/chain.rs')
-rw-r--r-- | src/chain.rs | 381 |
1 files changed, 151 insertions, 230 deletions
diff --git a/src/chain.rs b/src/chain.rs index a942a37..37e4cb3 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -1,41 +1,85 @@ -use crate::{MsgType, Table}; -use crate::sys::{self as sys, libc}; -#[cfg(feature = "query")] -use std::convert::TryFrom; -use std::{ - ffi::{c_void, CStr, CString}, - fmt, - os::raw::c_char, - rc::Rc, +use libc::{NF_ACCEPT, NF_DROP}; +use rustables_macros::nfnetlink_struct; + +use crate::error::{DecodeError, QueryError}; +use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject}; +use crate::sys::{ + NFTA_CHAIN_FLAGS, NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_POLICY, NFTA_CHAIN_TABLE, + NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, + NFT_MSG_NEWCHAIN, }; +use crate::{Batch, ProtocolFamily, Table}; +use std::fmt::Debug; -pub type Priority = i32; +pub type ChainPriority = i32; /// The netfilter event hooks a chain can register for. -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[repr(u16)] -pub enum Hook { +#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +#[repr(i32)] +pub enum HookClass { /// Hook into the pre-routing stage of netfilter. Corresponds to `NF_INET_PRE_ROUTING`. - PreRouting = libc::NF_INET_PRE_ROUTING as u16, + PreRouting = libc::NF_INET_PRE_ROUTING, /// Hook into the input stage of netfilter. Corresponds to `NF_INET_LOCAL_IN`. - In = libc::NF_INET_LOCAL_IN as u16, + In = libc::NF_INET_LOCAL_IN, /// Hook into the forward stage of netfilter. Corresponds to `NF_INET_FORWARD`. - Forward = libc::NF_INET_FORWARD as u16, + Forward = libc::NF_INET_FORWARD, /// Hook into the output stage of netfilter. Corresponds to `NF_INET_LOCAL_OUT`. - Out = libc::NF_INET_LOCAL_OUT as u16, + Out = libc::NF_INET_LOCAL_OUT, /// Hook into the post-routing stage of netfilter. Corresponds to `NF_INET_POST_ROUTING`. - PostRouting = libc::NF_INET_POST_ROUTING as u16, + PostRouting = libc::NF_INET_POST_ROUTING, +} + +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(nested = true)] +pub struct Hook { + /// Define the action netfilter will apply to packets processed by this chain, but that did not match any rules in it. + #[field(NFTA_HOOK_HOOKNUM)] + class: u32, + #[field(NFTA_HOOK_PRIORITY)] + priority: u32, +} + +impl Hook { + pub fn new(class: HookClass, priority: ChainPriority) -> Self { + Hook::default() + .with_class(class as u32) + .with_priority(priority as u32) + } } /// A chain policy. Decides what to do with a packet that was processed by the chain but did not /// match any rules. -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[repr(u32)] -pub enum Policy { +#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +#[repr(i32)] +pub enum ChainPolicy { /// Accept the packet. - Accept = libc::NF_ACCEPT as u32, + Accept = NF_ACCEPT, /// Drop the packet. - Drop = libc::NF_DROP as u32, + Drop = NF_DROP, +} + +impl NfNetlinkAttribute for ChainPolicy { + fn get_size(&self) -> usize { + (*self as i32).get_size() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + (*self as i32).write_payload(addr); + } +} + +impl NfNetlinkDeserializable for ChainPolicy { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let (v, remaining_data) = i32::deserialize(buf)?; + Ok(( + match v { + NF_ACCEPT => ChainPolicy::Accept, + NF_DROP => ChainPolicy::Accept, + _ => return Err(DecodeError::UnknownChainPolicy), + }, + remaining_data, + )) + } } /// Base chain type. @@ -53,240 +97,117 @@ pub enum ChainType { } impl ChainType { - fn as_c_str(&self) -> &'static [u8] { + fn as_str(&self) -> &'static str { match *self { - ChainType::Filter => b"filter\0", - ChainType::Route => b"route\0", - ChainType::Nat => b"nat\0", + ChainType::Filter => "filter", + ChainType::Route => "route", + ChainType::Nat => "nat", } } } -/// Abstraction of a `nftnl_chain`. Chains reside inside [`Table`]s and they hold [`Rule`]s. -/// -/// There are two types of chains, "base chain" and "regular chain". See [`set_hook`] for more -/// details. +impl NfNetlinkAttribute for ChainType { + fn get_size(&self) -> usize { + self.as_str().len() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + self.as_str().to_string().write_payload(addr); + } +} + +impl NfNetlinkDeserializable for ChainType { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let (s, remaining_data) = String::deserialize(buf)?; + Ok(( + match s.as_str() { + "filter" => ChainType::Filter, + "route" => ChainType::Route, + "nat" => ChainType::Nat, + _ => return Err(DecodeError::UnknownChainType), + }, + remaining_data, + )) + } +} + +/// Abstraction over an nftable chain. Chains reside inside [`Table`]s and they hold [`Rule`]s. /// /// [`Table`]: struct.Table.html /// [`Rule`]: struct.Rule.html -/// [`set_hook`]: #method.set_hook +#[derive(PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(derive_deserialize = false)] pub struct Chain { - pub(crate) chain: *mut sys::nftnl_chain, - pub(crate) table: Rc<Table>, + family: ProtocolFamily, + #[field(NFTA_CHAIN_TABLE)] + table: String, + #[field(NFTA_CHAIN_NAME)] + name: String, + #[field(NFTA_CHAIN_HOOK)] + hook: Hook, + #[field(NFTA_CHAIN_POLICY)] + policy: ChainPolicy, + #[field(NFTA_CHAIN_TYPE, name_in_functions = "type")] + chain_type: ChainType, + #[field(NFTA_CHAIN_FLAGS)] + flags: u32, + #[field(NFTA_CHAIN_USERDATA)] + userdata: Vec<u8>, } impl Chain { - /// Creates a new chain instance inside the given [`Table`] and with the given name. + /// Creates a new chain instance inside the given [`Table`]. /// /// [`Table`]: struct.Table.html - pub fn new<T: AsRef<CStr>>(name: &T, table: Rc<Table>) -> Chain { - unsafe { - let chain = try_alloc!(sys::nftnl_chain_alloc()); - sys::nftnl_chain_set_u32( - chain, - sys::NFTNL_CHAIN_FAMILY as u16, - table.get_family() as u32, - ); - sys::nftnl_chain_set_str( - chain, - sys::NFTNL_CHAIN_TABLE as u16, - table.get_name().as_ptr(), - ); - sys::nftnl_chain_set_str(chain, sys::NFTNL_CHAIN_NAME as u16, name.as_ref().as_ptr()); - Chain { chain, table } - } - } - - pub unsafe fn from_raw(chain: *mut sys::nftnl_chain, table: Rc<Table>) -> Self { - Chain { chain, table } - } + pub fn new(table: &Table) -> Chain { + let mut chain = Chain::default(); + chain.family = table.get_family(); - /// Sets the hook and priority for this chain. Without calling this method the chain will - /// become a "regular chain" without any hook and will thus not receive any traffic unless - /// some rule forward packets to it via goto or jump verdicts. - /// - /// By calling `set_hook` with a hook the chain that is created will be registered with that - /// hook and is thus a "base chain". A "base chain" is an entry point for packets from the - /// networking stack. - pub fn set_hook(&mut self, hook: Hook, priority: Priority) { - unsafe { - sys::nftnl_chain_set_u32(self.chain, sys::NFTNL_CHAIN_HOOKNUM as u16, hook as u32); - sys::nftnl_chain_set_s32(self.chain, sys::NFTNL_CHAIN_PRIO as u16, priority); + if let Some(table_name) = table.get_name() { + chain.set_table(table_name); } - } - /// Set the type of a base chain. This only applies if the chain has been registered - /// with a hook by calling `set_hook`. - pub fn set_type(&mut self, chain_type: ChainType) { - unsafe { - sys::nftnl_chain_set_str( - self.chain, - sys::NFTNL_CHAIN_TYPE as u16, - chain_type.as_c_str().as_ptr() as *const c_char, - ); - } + chain } - /// Sets the default policy for this chain. That means what action netfilter will apply to - /// packets processed by this chain, but that did not match any rules in it. - pub fn set_policy(&mut self, policy: Policy) { - unsafe { - sys::nftnl_chain_set_u32(self.chain, sys::NFTNL_CHAIN_POLICY as u16, policy as u32); - } - } - - /// Returns the userdata of this chain. - pub fn get_userdata(&self) -> Option<&CStr> { - unsafe { - let ptr = sys::nftnl_chain_get_str(self.chain, sys::NFTNL_CHAIN_USERDATA as u16); - if ptr == std::ptr::null() { - return None; - } - Some(CStr::from_ptr(ptr)) - } - } - - /// Updates the userdata of this chain. - pub fn set_userdata(&self, data: &CStr) { - unsafe { - sys::nftnl_chain_set_str(self.chain, sys::NFTNL_CHAIN_USERDATA as u16, data.as_ptr()); - } - } - - /// Returns the name of this chain. - pub fn get_name(&self) -> &CStr { - unsafe { - let ptr = sys::nftnl_chain_get_str(self.chain, sys::NFTNL_CHAIN_NAME as u16); - if ptr.is_null() { - panic!("Impossible situation: retrieving the name of a chain failed") - } else { - CStr::from_ptr(ptr) - } - } - } - - /// Returns a textual description of the chain. - pub fn get_str(&self) -> CString { - let mut descr_buf = vec![0i8; 4096]; - unsafe { - sys::nftnl_chain_snprintf( - descr_buf.as_mut_ptr() as *mut c_char, - (descr_buf.len() - 1) as u64, - self.chain, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); - CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned() - } - } - - /// Returns a reference to the [`Table`] this chain belongs to. - /// - /// [`Table`]: struct.Table.html - pub fn get_table(&self) -> Rc<Table> { - self.table.clone() - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns the raw handle. - pub fn as_ptr(&self) -> *const sys::nftnl_chain { - self.chain as *const sys::nftnl_chain - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns a mutable version of the raw handle. - pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_chain { - self.chain + /// Appends this chain to `batch` + pub fn add_to_batch(self, batch: &mut Batch) -> Self { + batch.add(&self, crate::MsgType::Add); + self } } -impl fmt::Debug for Chain { - /// Returns a string representation of the chain. - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "{:?}", self.get_str()) - } -} +impl NfNetlinkObject for Chain { + const MSG_TYPE_ADD: u32 = NFT_MSG_NEWCHAIN; + const MSG_TYPE_DEL: u32 = NFT_MSG_DELCHAIN; -impl PartialEq for Chain { - fn eq(&self, other: &Self) -> bool { - self.get_table() == other.get_table() && self.get_name() == other.get_name() + fn get_family(&self) -> ProtocolFamily { + self.family } -} -unsafe impl crate::NlMsg for Chain { - unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { - let raw_msg_type = match msg_type { - MsgType::Add => libc::NFT_MSG_NEWCHAIN, - MsgType::Del => libc::NFT_MSG_DELCHAIN, - }; - let flags: u16 = match msg_type { - MsgType::Add => (libc::NLM_F_ACK | libc::NLM_F_CREATE) as u16, - MsgType::Del => libc::NLM_F_ACK as u16, - } | libc::NLM_F_ACK as u16; - let header = sys::nftnl_nlmsg_build_hdr( - buf as *mut c_char, - raw_msg_type as u16, - self.table.get_family() as u16, - flags, - seq, - ); - sys::nftnl_chain_nlmsg_build_payload(header, self.chain); + fn set_family(&mut self, family: ProtocolFamily) { + self.family = family; } } -impl Drop for Chain { - fn drop(&mut self) { - unsafe { sys::nftnl_chain_free(self.chain) }; - } -} - -#[cfg(feature = "query")] -pub fn get_chains_cb<'a>( - header: &libc::nlmsghdr, - (table, chains): &mut (&Rc<Table>, &mut Vec<Chain>), -) -> libc::c_int { - unsafe { - let chain = sys::nftnl_chain_alloc(); - if chain == std::ptr::null_mut() { - return mnl::mnl_sys::MNL_CB_ERROR; - } - let err = sys::nftnl_chain_nlmsg_parse(header, chain); - if err < 0 { - error!("Failed to parse nelink chain message - {}", err); - sys::nftnl_chain_free(chain); - return err; - } - - let table_name = CStr::from_ptr(sys::nftnl_chain_get_str( - chain, - sys::NFTNL_CHAIN_TABLE as u16, - )); - let family = sys::nftnl_chain_get_u32(chain, sys::NFTNL_CHAIN_FAMILY as u16); - let family = match crate::ProtoFamily::try_from(family as i32) { - Ok(family) => family, - Err(crate::InvalidProtocolFamily) => { - error!("The netlink table didn't have a valid protocol family !?"); - sys::nftnl_chain_free(chain); - return mnl::mnl_sys::MNL_CB_ERROR; +pub fn list_chains_for_table(table: &Table) -> Result<Vec<Chain>, QueryError> { + let mut result = Vec::new(); + crate::query::list_objects_with_data( + libc::NFT_MSG_GETCHAIN as u16, + &|chain: Chain, (table, chains): &mut (&Table, &mut Vec<Chain>)| { + if chain.get_table() == table.get_name() { + chains.push(chain); + } else { + info!( + "Ignoring chain {:?} because it doesn't map the table {:?}", + chain.get_name(), + table.get_name() + ); } - }; - - if table_name != table.get_name() { - sys::nftnl_chain_free(chain); - return mnl::mnl_sys::MNL_CB_OK; - } - - if family != crate::ProtoFamily::Unspec && family != table.get_family() { - sys::nftnl_chain_free(chain); - return mnl::mnl_sys::MNL_CB_OK; - } - - chains.push(Chain::from_raw(chain, table.clone())); - } - mnl::mnl_sys::MNL_CB_OK -} - -#[cfg(feature = "query")] -pub fn list_chains_for_table(table: Rc<Table>) -> Result<Vec<Chain>, crate::query::Error> { - crate::query::list_objects_with_data(libc::NFT_MSG_GETCHAIN as u16, get_chains_cb, &table, None) + Ok(()) + }, + None, + &mut (&table, &mut result), + )?; + Ok(result) } |