diff options
Diffstat (limited to 'src/rule.rs')
-rw-r--r-- | src/rule.rs | 392 |
1 files changed, 81 insertions, 311 deletions
diff --git a/src/rule.rs b/src/rule.rs index 2ee5308..858b9ce 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -1,341 +1,111 @@ -use crate::expr::ExpressionWrapper; -use crate::{chain::Chain, expr::Expression, MsgType}; -use crate::sys::{self, libc}; -use std::ffi::{c_void, CStr, CString}; use std::fmt::Debug; -use std::os::raw::c_char; -use std::rc::Rc; + +use rustables_macros::nfnetlink_struct; + +use crate::chain::Chain; +use crate::error::{BuilderError, QueryError}; +use crate::expr::{ExpressionList, RawExpression}; +use crate::nlmsg::NfNetlinkObject; +use crate::query::list_objects_with_data; +use crate::sys::{ + NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, NFTA_RULE_HANDLE, NFTA_RULE_ID, NFTA_RULE_POSITION, + NFTA_RULE_TABLE, NFTA_RULE_USERDATA, NFT_MSG_DELRULE, NFT_MSG_NEWRULE, NLM_F_APPEND, + NLM_F_CREATE, +}; +use crate::{Batch, ProtocolFamily}; /// A nftables firewall rule. +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(derive_deserialize = false)] pub struct Rule { - pub(crate) rule: *mut sys::nftnl_rule, - pub(crate) chain: Rc<Chain>, + family: ProtocolFamily, + #[field(NFTA_RULE_TABLE)] + table: String, + #[field(NFTA_RULE_CHAIN)] + chain: String, + #[field(NFTA_RULE_HANDLE)] + handle: u64, + #[field(NFTA_RULE_EXPRESSIONS)] + expressions: ExpressionList, + #[field(NFTA_RULE_POSITION)] + position: u64, + #[field(NFTA_RULE_USERDATA)] + userdata: Vec<u8>, + #[field(NFTA_RULE_ID)] + id: u32, } impl Rule { /// Creates a new rule object in the given [`Chain`]. /// /// [`Chain`]: struct.Chain.html - pub fn new(chain: Rc<Chain>) -> Rule { - unsafe { - let rule = try_alloc!(sys::nftnl_rule_alloc()); - sys::nftnl_rule_set_u32( - rule, - sys::NFTNL_RULE_FAMILY as u16, - chain.get_table().get_family() as u32, - ); - sys::nftnl_rule_set_str( - rule, - sys::NFTNL_RULE_TABLE as u16, - chain.get_table().get_name().as_ptr(), - ); - sys::nftnl_rule_set_str( - rule, - sys::NFTNL_RULE_CHAIN as u16, - chain.get_name().as_ptr(), - ); - - Rule { rule, chain } - } - } - - pub unsafe fn from_raw(rule: *mut sys::nftnl_rule, chain: Rc<Chain>) -> Self { - Rule { rule, chain } - } - - pub fn get_position(&self) -> u64 { - unsafe { sys::nftnl_rule_get_u64(self.rule, sys::NFTNL_RULE_POSITION as u16) } - } - - /// Sets the position of this rule within the chain it lives in. By default a new rule is added - /// to the end of the chain. - pub fn set_position(&mut self, position: u64) { - unsafe { - sys::nftnl_rule_set_u64(self.rule, sys::NFTNL_RULE_POSITION as u16, position); - } - } - - pub fn get_handle(&self) -> u64 { - unsafe { sys::nftnl_rule_get_u64(self.rule, sys::NFTNL_RULE_HANDLE as u16) } - } - - pub fn set_handle(&mut self, handle: u64) { - unsafe { - sys::nftnl_rule_set_u64(self.rule, sys::NFTNL_RULE_HANDLE as u16, handle); - } - } - - /// Adds an expression to this rule. Expressions are evaluated from first to last added. - /// As soon as an expression does not match the packet it's being evaluated for, evaluation - /// stops and the packet is evaluated against the next rule in the chain. - pub fn add_expr(&mut self, expr: &impl Expression) { - unsafe { sys::nftnl_rule_add_expr(self.rule, expr.to_expr(self)) } - } - - /// Returns a reference to the [`Chain`] this rule lives in. - /// - /// [`Chain`]: struct.Chain.html - pub fn get_chain(&self) -> Rc<Chain> { - self.chain.clone() - } - - /// Returns the userdata of this chain. - pub fn get_userdata(&self) -> Option<&CStr> { - unsafe { - let ptr = sys::nftnl_rule_get_str(self.rule, sys::NFTNL_RULE_USERDATA as u16); - if !ptr.is_null() { - Some(CStr::from_ptr(ptr)) - } else { - None - } - } - } - - /// Updates the userdata of this chain. - pub fn set_userdata(&self, data: &CStr) { - unsafe { - sys::nftnl_rule_set_str(self.rule, sys::NFTNL_RULE_USERDATA as u16, data.as_ptr()); - } - } - - /// Returns a textual description of the rule. - pub fn get_str(&self) -> CString { - let mut descr_buf = vec![0i8; 4096]; - unsafe { - sys::nftnl_rule_snprintf( - descr_buf.as_mut_ptr() as *mut c_char, - (descr_buf.len() - 1) as u64, - self.rule, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); - CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned() - } - } - - /// Retrieves an iterator to loop over the expressions of the rule. - pub fn get_exprs(self: &Rc<Self>) -> RuleExprsIter { - RuleExprsIter::new(self.clone()) - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns the raw handle. - pub fn as_ptr(&self) -> *const sys::nftnl_rule { - self.rule as *const sys::nftnl_rule - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns a mutable version of the raw handle. - pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_rule { - self.rule - } - - /// Performs a deep comparizon of rules, by checking they have the same expressions inside. - /// This is not enabled by default in our PartialEq implementation because of the difficulty to - /// compare an expression generated by the library with the expressions returned by the kernel - /// when iterating over the currently in-use rules. The kernel-returned expressions may have - /// additional attributes despite being generated from the same rule. This is particularly true - /// for the 'nat' expression). - pub fn deep_eq(&self, other: &Self) -> bool { - if self != other { - return false; - } - - let self_exprs = - try_alloc!(unsafe { sys::nftnl_expr_iter_create(self.rule as *const sys::nftnl_rule) }); - let other_exprs = try_alloc!(unsafe { - sys::nftnl_expr_iter_create(other.rule as *const sys::nftnl_rule) - }); - - loop { - let self_next = unsafe { sys::nftnl_expr_iter_next(self_exprs) }; - let other_next = unsafe { sys::nftnl_expr_iter_next(other_exprs) }; - if self_next.is_null() && other_next.is_null() { - return true; - } else if self_next.is_null() || other_next.is_null() { - return false; - } - - // we are falling back on comparing the strings, because there is no easy mechanism to - // perform a memcmp() between the two expressions :/ - let mut self_str = [0; 256]; - let mut other_str = [0; 256]; - unsafe { - sys::nftnl_expr_snprintf( - self_str.as_mut_ptr(), - (self_str.len() - 1) as u64, - self_next, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); - sys::nftnl_expr_snprintf( - other_str.as_mut_ptr(), - (other_str.len() - 1) as u64, - other_next, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); + pub fn new(chain: &Chain) -> Result<Rule, BuilderError> { + Ok(Rule::default() + .with_family(chain.get_family()) + .with_table( + chain + .get_table() + .ok_or(BuilderError::MissingChainInformationError)?, + ) + .with_chain( + chain + .get_name() + .ok_or(BuilderError::MissingChainInformationError)?, + )) + } + + pub fn add_expr(&mut self, e: impl Into<RawExpression>) { + let exprs = match self.get_mut_expressions() { + Some(x) => x, + None => { + self.set_expressions(ExpressionList::default()); + self.get_mut_expressions().unwrap() } - - if self_str != other_str { - return false; - } - } - } -} - -impl Debug for Rule { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.get_str()) - } -} - -impl PartialEq for Rule { - fn eq(&self, other: &Self) -> bool { - if self.get_chain() != other.get_chain() { - return false; - } - - unsafe { - if sys::nftnl_rule_is_set(self.rule, sys::NFTNL_RULE_HANDLE as u16) - && sys::nftnl_rule_is_set(other.rule, sys::NFTNL_RULE_HANDLE as u16) - { - if self.get_handle() != other.get_handle() { - return false; - } - } - if sys::nftnl_rule_is_set(self.rule, sys::NFTNL_RULE_POSITION as u16) - && sys::nftnl_rule_is_set(other.rule, sys::NFTNL_RULE_POSITION as u16) - { - if self.get_position() != other.get_position() { - return false; - } - } - } - - return false; - } -} - -unsafe impl crate::NlMsg for Rule { - unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { - let type_ = match msg_type { - MsgType::Add => libc::NFT_MSG_NEWRULE, - MsgType::Del => libc::NFT_MSG_DELRULE, }; - let flags: u16 = match msg_type { - MsgType::Add => (libc::NLM_F_CREATE | libc::NLM_F_APPEND | libc::NLM_F_EXCL) as u16, - MsgType::Del => 0u16, - } | libc::NLM_F_ACK as u16; - let header = sys::nftnl_nlmsg_build_hdr( - buf as *mut c_char, - type_ as u16, - self.chain.get_table().get_family() as u16, - flags, - seq, - ); - sys::nftnl_rule_nlmsg_build_payload(header, self.rule); + exprs.add_value(e); } -} -impl Drop for Rule { - fn drop(&mut self) { - unsafe { sys::nftnl_rule_free(self.rule) }; + pub fn with_expr(mut self, e: impl Into<RawExpression>) -> Self { + self.add_expr(e); + self } -} - -pub struct RuleExprsIter { - rule: Rc<Rule>, - iter: *mut sys::nftnl_expr_iter, -} -impl RuleExprsIter { - fn new(rule: Rc<Rule>) -> Self { - let iter = - try_alloc!(unsafe { sys::nftnl_expr_iter_create(rule.rule as *const sys::nftnl_rule) }); - RuleExprsIter { rule, iter } + /// Appends this rule to `batch` + pub fn add_to_batch(self, batch: &mut Batch) -> Self { + batch.add(&self, crate::MsgType::Add); + self } } -impl Iterator for RuleExprsIter { - type Item = ExpressionWrapper; +impl NfNetlinkObject for Rule { + const MSG_TYPE_ADD: u32 = NFT_MSG_NEWRULE; + const MSG_TYPE_DEL: u32 = NFT_MSG_DELRULE; - fn next(&mut self) -> Option<Self::Item> { - let next = unsafe { sys::nftnl_expr_iter_next(self.iter) }; - if next.is_null() { - trace!("RulesExprsIter iterator ending"); - None - } else { - trace!("RulesExprsIter returning new expression"); - Some(ExpressionWrapper { - expr: next, - rule: self.rule.clone(), - }) - } + fn get_family(&self) -> ProtocolFamily { + self.family } -} -impl Drop for RuleExprsIter { - fn drop(&mut self) { - unsafe { sys::nftnl_expr_iter_destroy(self.iter) }; + fn set_family(&mut self, family: ProtocolFamily) { + self.family = family; } -} - -#[cfg(feature = "query")] -pub fn get_rules_cb( - header: &libc::nlmsghdr, - (chain, rules): &mut (&Rc<Chain>, &mut Vec<Rule>), -) -> libc::c_int { - unsafe { - let rule = sys::nftnl_rule_alloc(); - if rule == std::ptr::null_mut() { - return mnl::mnl_sys::MNL_CB_ERROR; - } - let err = sys::nftnl_rule_nlmsg_parse(header, rule); - if err < 0 { - error!("Failed to parse nelink rule message - {}", err); - sys::nftnl_rule_free(rule); - return err; - } - rules.push(Rule::from_raw(rule, chain.clone())); + // append at the end of the chain, instead of the beginning + fn get_add_flags(&self) -> u32 { + NLM_F_CREATE | NLM_F_APPEND } - mnl::mnl_sys::MNL_CB_OK } -#[cfg(feature = "query")] -pub fn list_rules_for_chain(chain: &Rc<Chain>) -> Result<Vec<Rule>, crate::query::Error> { - crate::query::list_objects_with_data( +pub fn list_rules_for_chain(chain: &Chain) -> Result<Vec<Rule>, QueryError> { + let mut result = Vec::new(); + list_objects_with_data( libc::NFT_MSG_GETRULE as u16, - get_rules_cb, - &chain, - // only retrieve rules from the currently targetted chain - Some(&|hdr| unsafe { - let rule = sys::nftnl_rule_alloc(); - if rule as *const _ == std::ptr::null() { - return Err(crate::query::Error::NetlinkAllocationFailed); - } - - sys::nftnl_rule_set_str( - rule, - sys::NFTNL_RULE_TABLE as u16, - chain.get_table().get_name().as_ptr(), - ); - sys::nftnl_rule_set_u32( - rule, - sys::NFTNL_RULE_FAMILY as u16, - chain.get_table().get_family() as u32, - ); - sys::nftnl_rule_set_str( - rule, - sys::NFTNL_RULE_CHAIN as u16, - chain.get_name().as_ptr(), - ); - - sys::nftnl_rule_nlmsg_build_payload(hdr, rule); - - sys::nftnl_rule_free(rule); + &|rule: Rule, rules: &mut Vec<Rule>| { + rules.push(rule); Ok(()) - }), - ) + }, + // only retrieve rules from the currently targetted chain + Some(&Rule::new(chain)?), + &mut result, + )?; + Ok(result) } |