diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/chain.rs | 58 | ||||
-rw-r--r-- | src/expr/immediate.rs | 174 | ||||
-rw-r--r-- | src/expr/log.rs | 112 | ||||
-rw-r--r-- | src/expr/mod.rs | 361 | ||||
-rw-r--r-- | src/expr/register.rs | 50 | ||||
-rw-r--r-- | src/expr/verdict.rs | 232 | ||||
-rw-r--r-- | src/lib.rs | 8 | ||||
-rw-r--r-- | src/nlmsg.rs | 35 | ||||
-rw-r--r-- | src/parser.rs | 117 | ||||
-rw-r--r-- | src/query.rs | 34 | ||||
-rw-r--r-- | src/rule.rs | 239 | ||||
-rw-r--r-- | src/table.rs | 14 |
12 files changed, 813 insertions, 621 deletions
diff --git a/src/chain.rs b/src/chain.rs index 60f5f10..cce0fa9 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -4,10 +4,8 @@ use crate::nlmsg::{ NfNetlinkAttribute, NfNetlinkAttributes, NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkWriter, }; -use crate::parser::{ - parse_object, DecodeError, InnerFormat, NestedAttribute, NfNetlinkAttributeReader, -}; -use crate::sys::{self, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN, NLM_F_ACK}; +use crate::parser::{parse_object, DecodeError, InnerFormat, NfNetlinkAttributeReader}; +use crate::sys::{self, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN, NLM_F_ACK, NLM_F_CREATE}; use crate::{impl_attr_getters_and_setters, MsgType, ProtocolFamily, Table}; use std::convert::TryFrom; use std::fmt::Debug; @@ -16,32 +14,32 @@ pub type ChainPriority = i32; /// The netfilter event hooks a chain can register for. #[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] -#[repr(u32)] +#[repr(i32)] pub enum HookClass { /// Hook into the pre-routing stage of netfilter. Corresponds to `NF_INET_PRE_ROUTING`. - PreRouting = libc::NF_INET_PRE_ROUTING as u32, + PreRouting = libc::NF_INET_PRE_ROUTING, /// Hook into the input stage of netfilter. Corresponds to `NF_INET_LOCAL_IN`. - In = libc::NF_INET_LOCAL_IN as u32, + In = libc::NF_INET_LOCAL_IN, /// Hook into the forward stage of netfilter. Corresponds to `NF_INET_FORWARD`. - Forward = libc::NF_INET_FORWARD as u32, + Forward = libc::NF_INET_FORWARD, /// Hook into the output stage of netfilter. Corresponds to `NF_INET_LOCAL_OUT`. - Out = libc::NF_INET_LOCAL_OUT as u32, + Out = libc::NF_INET_LOCAL_OUT, /// Hook into the post-routing stage of netfilter. Corresponds to `NF_INET_POST_ROUTING`. - PostRouting = libc::NF_INET_POST_ROUTING as u32, + PostRouting = libc::NF_INET_POST_ROUTING, } #[derive(Clone, PartialEq, Eq)] pub struct Hook { - inner: NestedAttribute, + inner: NfNetlinkAttributes, } impl Hook { pub fn new(class: HookClass, priority: ChainPriority) -> Self { Hook { - inner: NestedAttribute::new(), + inner: NfNetlinkAttributes::new(), } - .with_hook_class(class as u32) - .with_hook_priority(priority as u32) + .with_class(class as u32) + .with_priority(priority as u32) } } @@ -56,17 +54,17 @@ impl_attr_getters_and_setters!( [ // Define the action netfilter will apply to packets processed by this chain, but that did not match any rules in it. ( - get_hook_class, - set_hook_class, - with_hook_class, + get_class, + set_class, + with_class, sys::NFTA_HOOK_HOOKNUM, U32, u32 ), ( - get_hook_priority, - set_hook_priority, - with_hook_priority, + get_priority, + set_priority, + with_priority, sys::NFTA_HOOK_PRIORITY, U32, u32 @@ -211,6 +209,14 @@ impl Chain { chain } + pub fn get_family(&self) -> ProtocolFamily { + self.family + } + + fn raw_attributes(&self) -> &NfNetlinkAttributes { + &self.inner + } + /* /// Returns a textual description of the chain. pub fn get_str(&self) -> CString { @@ -251,7 +257,17 @@ impl NfNetlinkObject for Chain { MsgType::Add => NFT_MSG_NEWCHAIN, MsgType::Del => NFT_MSG_DELCHAIN, } as u16; - writer.write_header(raw_msg_type, self.family, NLM_F_ACK as u16, seq, None); + writer.write_header( + raw_msg_type, + self.family, + (if let MsgType::Add = msg_type { + NLM_F_CREATE + } else { + 0 + } | NLM_F_ACK) as u16, + seq, + None, + ); self.inner.serialize(writer); writer.finalize_writing_object(); } diff --git a/src/expr/immediate.rs b/src/expr/immediate.rs index 71453b3..e9f7b5b 100644 --- a/src/expr/immediate.rs +++ b/src/expr/immediate.rs @@ -1,124 +1,60 @@ -use super::{DeserializationError, Expression, Register, Rule, ToSlice}; -use crate::sys; -use std::ffi::c_void; -use std::os::raw::c_char; - -/// An immediate expression. Used to set immediate data. Verdicts are handled separately by -/// [crate::expr::Verdict]. -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub struct Immediate<T> { - pub data: T, - pub register: Register, -} - -impl<T> Immediate<T> { - pub fn new(data: T, register: Register) -> Self { - Self { data, register } - } -} - -impl<T: ToSlice> Expression for Immediate<T> { - fn get_raw_name() -> *const c_char { - b"immediate\0" as *const _ as *const c_char - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_IMM_DREG as u16, - self.register.to_raw(), - ); - - let data = self.data.to_slice(); - sys::nftnl_expr_set( - expr, - sys::NFTNL_EXPR_IMM_DATA as u16, - data.as_ptr() as *const c_void, - data.len() as u32, - ); - - expr - } +use super::{Expression, Register, VerdictAttribute}; +use crate::{create_expr_type, sys}; + +create_expr_type!( + nested with_builder : ImmediateData, + [ + ( + get_value, + set_value, + with_value, + sys::NFTA_DATA_VALUE, + VecU8, + Vec<u8> + ), + ( + get_verdict, + set_verdict, + with_verdict, + sys::NFTA_DATA_VERDICT, + ExprVerdictAttribute, + VerdictAttribute + ) + ] +); + +create_expr_type!( + inline with_builder : Immediate, + [ + ( + get_dreg, + set_dreg, + with_dreg, + sys::NFTA_IMMEDIATE_DREG, + Register, + Register + ), + ( + get_data, + set_data, + with_data, + sys::NFTA_IMMEDIATE_DATA, + ExprImmediateData, + ImmediateData + ) + ] +); + +impl Immediate { + pub fn new_data(data: Vec<u8>, register: Register) -> Self { + Immediate::builder() + .with_dreg(register) + .with_data(ImmediateData::builder().with_value(data)) } } -impl<const N: usize> Expression for Immediate<[u8; N]> { - fn get_raw_name() -> *const c_char { - Immediate::<u8>::get_raw_name() +impl Expression for Immediate { + fn get_name() -> &'static str { + "immediate" } - - /// The raw data contained inside `Immediate` expressions can only be deserialized to arrays of - /// bytes, to ensure that the memory layout of retrieved data cannot be violated. It is your - /// responsibility to provide the correct length of the byte data. If the data size is invalid, - /// you will get the error `DeserializationError::InvalidDataSize`. - /// - /// Example (warning, no error checking!): - /// ```rust - /// use std::ffi::CString; - /// use std::net::Ipv4Addr; - /// use std::rc::Rc; - /// - /// use rustables::{Chain, expr::{Immediate, Register}, ProtoFamily, Rule, Table}; - /// - /// let table = Rc::new(Table::new(&CString::new("mytable").unwrap(), ProtoFamily::Inet)); - /// let chain = Rc::new(Chain::new(&CString::new("mychain").unwrap(), table)); - /// let mut rule = Rule::new(chain); - /// rule.add_expr(&Immediate::new(42u8, Register::Reg1)); - /// for expr in Rc::new(rule).get_exprs() { - /// println!("{:?}", expr.decode_expr::<Immediate<[u8; 1]>>().unwrap()); - /// } - /// ``` - /// These limitations occur because casting bytes to any type of the same size as the raw input - /// would be *extremely* dangerous in terms of memory safety. - // As casting bytes to any type of the same size as the input would be *extremely* dangerous in - // terms of memory safety, rustables only accept to deserialize expressions with variable-size - // data to arrays of bytes, so that the memory layout cannot be invalid. - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { - unsafe { - let ref_len = std::mem::size_of::<[u8; N]>() as u32; - let mut data_len = 0; - let data = sys::nftnl_expr_get( - expr, - sys::NFTNL_EXPR_IMM_DATA as u16, - &mut data_len as *mut u32, - ); - - if data.is_null() { - return Err(DeserializationError::NullPointer); - } else if data_len != ref_len { - return Err(DeserializationError::InvalidDataSize); - } - - let data = *(data as *const [u8; N]); - - let register = Register::from_raw(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_IMM_DREG as u16, - ))?; - - Ok(Immediate { data, register }) - } - } - - // call to the other implementation to generate the expression - fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr { - Immediate { - register: self.register, - data: &self.data as &[u8], - } - .to_expr(rule) - } -} - -#[macro_export] -macro_rules! nft_expr_immediate { - (data $value:expr) => { - $crate::expr::Immediate { - data: $value, - register: $crate::expr::Register::Reg1, - } - }; } diff --git a/src/expr/log.rs b/src/expr/log.rs index 8d20b48..cf50cb2 100644 --- a/src/expr/log.rs +++ b/src/expr/log.rs @@ -1,21 +1,61 @@ -use super::{DeserializationError, Expression, Rule}; +use super::{Expression, ExpressionError}; +use crate::create_expr_type; +use crate::nlmsg::NfNetlinkAttributes; use crate::sys; -use std::ffi::{CStr, CString}; -use std::os::raw::c_char; -use thiserror::Error; -/// A Log expression will log all packets that match the rule. -#[derive(Debug, PartialEq)] -pub struct Log { - pub group: Option<LogGroup>, - pub prefix: Option<LogPrefix>, +// A Log expression will log all packets that match the rule. +create_expr_type!( + inline with_builder : Log, + [ + ( + get_group, + set_group, + with_group, + sys::NFTA_LOG_GROUP, + U32, + u32 + ), + ( + get_prefix, + set_prefix, + with_prefix, + sys::NFTA_LOG_PREFIX, + String, + String + ) + ] +); + +impl Log { + pub fn new( + group: Option<u16>, + prefix: Option<impl Into<String>>, + ) -> Result<Log, ExpressionError> { + let mut res = Log { + inner: NfNetlinkAttributes::new(), + //pub group: Option<LogGroup>, + //pub prefix: Option<LogPrefix>, + }; + if let Some(group) = group { + res.set_group(group); + } + if let Some(prefix) = prefix { + let prefix = prefix.into(); + + if prefix.bytes().count() > 127 { + return Err(ExpressionError::TooLongLogPrefix); + } + res.set_prefix(prefix); + } + Ok(res) + } } impl Expression for Log { - fn get_raw_name() -> *const sys::libc::c_char { - b"log\0" as *const _ as *const c_char + fn get_name() -> &'static str { + "log" } - + /* fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> where Self: Sized, @@ -54,59 +94,21 @@ impl Expression for Log { expr } } -} - -#[derive(Error, Debug)] -pub enum LogPrefixError { - #[error("The log prefix string is more than 128 characters long")] - TooLongPrefix, - #[error("The log prefix string contains an invalid Nul character.")] - PrefixContainsANul(#[from] std::ffi::NulError), -} - -/// The NFLOG group that will be assigned to each log line. -#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] -pub struct LogGroup(pub u16); - -/// A prefix that will get prepended to each log line. -#[derive(Debug, Clone, PartialEq)] -pub struct LogPrefix(CString); - -impl LogPrefix { - /// Creates a new LogPrefix from a String. Converts it to CString as needed by nftnl. Note that - /// LogPrefix should not be more than 127 characters long. - pub fn new(prefix: &str) -> Result<Self, LogPrefixError> { - if prefix.chars().count() > 127 { - return Err(LogPrefixError::TooLongPrefix); - } - Ok(LogPrefix(CString::new(prefix)?)) - } + */ } #[macro_export] macro_rules! nft_expr_log { (group $group:ident prefix $prefix:expr) => { - $crate::expr::Log { - group: $group, - prefix: $prefix, - } + $crate::expr::Log::new(Some($group), Some($prefix)) }; (prefix $prefix:expr) => { - $crate::expr::Log { - group: None, - prefix: $prefix, - } + $crate::expr::Log::new(None, Some($prefix)) }; (group $group:ident) => { - $crate::expr::Log { - group: $group, - prefix: None, - } + $crate::expr::Log::new(Some($group), None) }; () => { - $crate::expr::Log { - group: None, - prefix: None, - } + $crate::expr::Log::new(None, None) }; } diff --git a/src/expr/mod.rs b/src/expr/mod.rs index dc59507..4c702b2 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -4,14 +4,29 @@ //! [`Rule`]: struct.Rule.html use std::borrow::Cow; +use std::fmt::Debug; +use std::mem::transmute; use std::net::IpAddr; use std::net::Ipv4Addr; use std::net::Ipv6Addr; +use std::slice::Iter; use super::rule::Rule; -use crate::sys::{self, libc}; +use crate::nlmsg::AttributeDecoder; +use crate::nlmsg::NfNetlinkAttribute; +use crate::nlmsg::NfNetlinkAttributes; +use crate::nlmsg::NfNetlinkDeserializable; +use crate::parser::pad_netlink_object; +use crate::parser::pad_netlink_object_with_variable_size; +use crate::parser::write_attribute; +use crate::parser::AttributeType; +use crate::parser::DecodeError; +use crate::parser::InnerFormat; +use crate::sys::{self, nlattr}; +use libc::NLA_TYPE_MASK; use thiserror::Error; +/* mod bitwise; pub use self::bitwise::*; @@ -23,12 +38,14 @@ pub use self::counter::*; pub mod ct; pub use self::ct::*; +*/ mod immediate; pub use self::immediate::*; mod log; pub use self::log::*; +/* mod lookup; pub use self::lookup::*; @@ -47,6 +64,7 @@ pub use self::payload::*; mod reject; pub use self::reject::{IcmpCode, Reject}; +*/ mod register; pub use self::register::Register; @@ -54,11 +72,18 @@ pub use self::register::Register; mod verdict; pub use self::verdict::*; +/* + mod wrapper; pub use self::wrapper::ExpressionWrapper; +*/ #[derive(Debug, Error)] -pub enum DeserializationError { +pub enum ExpressionError { + #[error("The log prefix string is more than 127 characters long")] + /// The log prefix string is more than 127 characters long + TooLongLogPrefix, + #[error("The expected expression type doesn't match the name of the raw expression")] /// The expected expression type doesn't match the name of the raw expression. InvalidExpressionKind, @@ -80,109 +105,295 @@ pub enum DeserializationError { )] /// The size of a raw value was incoherent with the expected type of the deserialized value/ InvalidDataSize, - - #[error(transparent)] - /// Couldn't find a matching protocol. - InvalidProtolFamily(#[from] super::InvalidProtocolFamily), } -/// Trait for every safe wrapper of an nftables expression. pub trait Expression { - /// Returns the raw name used by nftables to identify the rule. - fn get_raw_name() -> *const libc::c_char; + fn get_name() -> &'static str; +} + +// wrapper for the general case, as we need to create many holder types given the depth of some +// netlink expressions +#[macro_export] +macro_rules! create_expr_type { + (without_decoder : $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { + #[derive(Clone, PartialEq, Eq)] + pub struct $struct { + inner: $crate::nlmsg::NfNetlinkAttributes, + } + + + $crate::impl_attr_getters_and_setters!(without_decoder $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); + + impl std::fmt::Debug for $struct { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use $crate::parser::InnerFormat; + self.inner_format_struct(f.debug_struct(stringify!($struct)))? + .finish() + } + } + + + impl $crate::nlmsg::NfNetlinkDeserializable for $struct { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), $crate::parser::DecodeError> { + let reader = $crate::parser::NfNetlinkAttributeReader::new(buf, buf.len())?; + let inner = reader.decode::<Self>()?; + Ok(($struct { inner }, &[])) + } + } + + }; + ($struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { + create_expr_type!(without_decoder : $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); + $crate::impl_attr_getters_and_setters!(decoder $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); + }; + (with_builder : $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { + create_expr_type!($struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); + + impl $struct { + pub fn builder() -> Self { + Self { inner: $crate::nlmsg::NfNetlinkAttributes::new() } + } + } + }; + (inline $($($attrs:ident) +)? : $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { + create_expr_type!($($($attrs) + :)? $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); + + impl $crate::nlmsg::NfNetlinkAttribute for $struct { + fn get_size(&self) -> usize { + self.inner.get_size() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + self.inner.write_payload(addr) + } + } + }; + (nested $($($attrs:ident) +)? : $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { + create_expr_type!($($($attrs) + :)? $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); + + impl $crate::nlmsg::NfNetlinkAttribute for $struct { + fn is_nested(&self) -> bool { + true + } + + fn get_size(&self) -> usize { + self.inner.get_size() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + self.inner.write_payload(addr) + } + } + }; +} - /// Try to parse the expression from a raw nftables expression, returning a - /// [DeserializationError] if the attempted parsing failed. - fn from_expr(_expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> +create_expr_type!( + nested without_decoder : ExpressionHolder, [ + // Define the action netfilter will apply to packets processed by this chain, but that did not match any rules in it. + ( + get_name, + set_name, + with_name, + sys::NFTA_EXPR_NAME, + String, + String + ), + ( + get_data, + set_data, + with_data, + sys::NFTA_EXPR_DATA, + ExpressionVariant, + ExpressionVariant + ) +]); + +impl ExpressionHolder { + pub fn new<T>(expr: T) -> Self where - Self: Sized, + T: Expression, + ExpressionVariant: From<T>, { - Err(DeserializationError::NotImplemented) + ExpressionHolder { + inner: NfNetlinkAttributes::new(), + } + .with_name(T::get_name()) + .with_data(ExpressionVariant::from(expr)) } - - /// Allocates and returns the low level `nftnl_expr` representation of this expression. The - /// caller to this method is responsible for freeing the expression. - fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr; } -/// A type that can be converted into a byte buffer. -pub trait ToSlice { - /// Returns the data this type represents. - fn to_slice(&self) -> Cow<'_, [u8]>; -} +#[macro_export] +macro_rules! create_expr_variant { + ($enum:ident $(, [$name:ident, $type:ty])+) => { + #[derive(Debug, Clone, PartialEq, Eq)] + pub enum $enum { + $( + $name($type), + )+ + } -impl<'a> ToSlice for &'a [u8] { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Borrowed(self) - } + impl $crate::nlmsg::NfNetlinkAttribute for $enum { + fn is_nested(&self) -> bool { + true + } + + fn get_size(&self) -> usize { + match self { + $( + $enum::$name(val) => val.get_size(), + )+ + } + } + + unsafe fn write_payload(&self, addr: *mut u8) { + match self { + $( + $enum::$name(val) => val.write_payload(addr), + )+ + } + } + } + + $( + impl From<$type> for $enum { + fn from(val: $type) -> Self { + $enum::$name(val) + } + } + )+ + + impl AttributeDecoder for ExpressionHolder { + fn decode_attribute( + attrs: &NfNetlinkAttributes, + attr_type: u16, + buf: &[u8], + ) -> Result<AttributeType, DecodeError> { + debug!("Decoding attribute {} in an expression", attr_type); + match attr_type { + x if x == sys::NFTA_EXPR_NAME => { + debug!("Calling {}::deserialize()", std::any::type_name::<String>()); + let (val, remaining) = String::deserialize(buf)?; + if remaining.len() != 0 { + return Err(DecodeError::InvalidDataSize); + } + Ok(AttributeType::String(val)) + }, + x if x == sys::NFTA_EXPR_DATA => { + // we can assume we have already the name parsed, as that's how we identify the + // type of expression + let name = attrs + .get_attr(sys::NFTA_EXPR_NAME) + .ok_or(DecodeError::MissingExpressionName)?; + match name { + $( + AttributeType::String(x) if x == <$type>::get_name() => { + debug!("Calling {}::deserialize()", std::any::type_name::<$type>()); + let (res, remaining) = <$type>::deserialize(buf)?; + if remaining.len() != 0 { + return Err($crate::parser::DecodeError::InvalidDataSize); + } + Ok(AttributeType::ExpressionVariant(ExpressionVariant::from(res))) + }, + )+ + AttributeType::String(name) => Err(DecodeError::UnknownExpressionName(name.to_string())), + _ => unreachable!() + } + }, + _ => Err(DecodeError::UnsupportedAttributeType(attr_type)), + } + } + } + }; } -impl<'a> ToSlice for &'a [u16] { - fn to_slice(&self) -> Cow<'_, [u8]> { - let ptr = self.as_ptr() as *const u8; - let len = self.len() * 2; - Cow::Borrowed(unsafe { std::slice::from_raw_parts(ptr, len) }) - } +create_expr_variant!(ExpressionVariant, [Log, Log], [Immediate, Immediate]); + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ExpressionList { + exprs: Vec<AttributeType>, } -impl ToSlice for IpAddr { - fn to_slice(&self) -> Cow<'_, [u8]> { - match *self { - IpAddr::V4(ref addr) => addr.to_slice(), - IpAddr::V6(ref addr) => addr.to_slice(), - } +impl ExpressionList { + pub fn builder() -> Self { + Self { exprs: Vec::new() } } -} -impl ToSlice for Ipv4Addr { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Owned(self.octets().to_vec()) + pub fn add_expression<T>(&mut self, e: T) + where + T: Expression, + ExpressionVariant: From<T>, + { + self.exprs + .push(AttributeType::Expression(ExpressionHolder::new(e))); } -} -impl ToSlice for Ipv6Addr { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Owned(self.octets().to_vec()) + pub fn with_expression<T>(mut self, e: T) -> Self + where + T: Expression, + ExpressionVariant: From<T>, + { + self.add_expression(e); + self } -} -impl ToSlice for u8 { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Owned(vec![*self]) + pub fn iter<'a>(&'a self) -> impl Iterator<Item = &'a ExpressionVariant> { + self.exprs.iter().map(|t| match t { + AttributeType::Expression(e) => e.get_data().unwrap(), + _ => unreachable!(), + }) } } -impl ToSlice for u16 { - fn to_slice(&self) -> Cow<'_, [u8]> { - let b0 = (*self & 0x00ff) as u8; - let b1 = (*self >> 8) as u8; - Cow::Owned(vec![b0, b1]) +impl NfNetlinkAttribute for ExpressionList { + fn is_nested(&self) -> bool { + true } -} -impl ToSlice for u32 { - fn to_slice(&self) -> Cow<'_, [u8]> { - let b0 = *self as u8; - let b1 = (*self >> 8) as u8; - let b2 = (*self >> 16) as u8; - let b3 = (*self >> 24) as u8; - Cow::Owned(vec![b0, b1, b2, b3]) + fn get_size(&self) -> usize { + // one nlattr LIST_ELEM per object + self.exprs.iter().fold(0, |acc, item| { + acc + item.get_size() + pad_netlink_object::<nlattr>() + }) } -} -impl ToSlice for i32 { - fn to_slice(&self) -> Cow<'_, [u8]> { - let b0 = *self as u8; - let b1 = (*self >> 8) as u8; - let b2 = (*self >> 16) as u8; - let b3 = (*self >> 24) as u8; - Cow::Owned(vec![b0, b1, b2, b3]) + unsafe fn write_payload(&self, mut addr: *mut u8) { + for item in &self.exprs { + write_attribute(sys::NFTA_LIST_ELEM, item, addr); + addr = addr.offset((pad_netlink_object::<nlattr>() + item.get_size()) as isize); + } } } -impl<'a> ToSlice for &'a str { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::from(self.as_bytes()) +impl NfNetlinkDeserializable for ExpressionList { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let mut exprs = Vec::new(); + + let mut pos = 0; + while buf.len() - pos > pad_netlink_object::<nlattr>() { + let nlattr = unsafe { *transmute::<*const u8, *const nlattr>(buf[pos..].as_ptr()) }; + // ignore the byteorder and nested attributes + let nla_type = nlattr.nla_type & NLA_TYPE_MASK as u16; + + if nla_type != sys::NFTA_LIST_ELEM { + return Err(DecodeError::UnsupportedAttributeType(nla_type)); + } + + let (expr, remaining) = ExpressionHolder::deserialize( + &buf[pos + pad_netlink_object::<nlattr>()..pos + nlattr.nla_len as usize], + )?; + if remaining.len() != 0 { + return Err(DecodeError::InvalidDataSize); + } + exprs.push(AttributeType::Expression(expr)); + + pos += pad_netlink_object_with_variable_size(nlattr.nla_len as usize); + } + + if pos != buf.len() { + Err(DecodeError::InvalidDataSize) + } else { + Ok((Self { exprs }, &[])) + } } } diff --git a/src/expr/register.rs b/src/expr/register.rs index a05af7e..def58a5 100644 --- a/src/expr/register.rs +++ b/src/expr/register.rs @@ -1,34 +1,42 @@ use std::fmt::Debug; -use crate::sys::libc; - -use super::DeserializationError; +use crate::{ + nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable}, + parser::DecodeError, + sys::{NFT_REG_1, NFT_REG_2, NFT_REG_3, NFT_REG_4, NFT_REG_VERDICT}, +}; /// A netfilter data register. The expressions store and read data to and from these when /// evaluating rule statements. #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[repr(i32)] +#[repr(u32)] pub enum Register { - Verdict = libc::NFT_REG_VERDICT, - Reg1 = libc::NFT_REG_1, - Reg2 = libc::NFT_REG_2, - Reg3 = libc::NFT_REG_3, - Reg4 = libc::NFT_REG_4, + Verdict = NFT_REG_VERDICT, + Reg1 = NFT_REG_1, + Reg2 = NFT_REG_2, + Reg3 = NFT_REG_3, + Reg4 = NFT_REG_4, } -impl Register { - pub fn to_raw(self) -> u32 { - self as u32 +impl NfNetlinkAttribute for Register { + unsafe fn write_payload(&self, addr: *mut u8) { + (*self as u32).write_payload(addr); } +} - pub fn from_raw(val: u32) -> Result<Self, DeserializationError> { - match val as i32 { - libc::NFT_REG_VERDICT => Ok(Self::Verdict), - libc::NFT_REG_1 => Ok(Self::Reg1), - libc::NFT_REG_2 => Ok(Self::Reg2), - libc::NFT_REG_3 => Ok(Self::Reg3), - libc::NFT_REG_4 => Ok(Self::Reg4), - _ => Err(DeserializationError::InvalidValue), - } +impl NfNetlinkDeserializable for Register { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), crate::parser::DecodeError> { + let (val, remaining) = u32::deserialize(buf)?; + Ok(( + match val { + NFT_REG_VERDICT => Self::Verdict, + NFT_REG_1 => Self::Reg1, + NFT_REG_2 => Self::Reg2, + NFT_REG_3 => Self::Reg3, + NFT_REG_4 => Self::Reg4, + _ => return Err(DecodeError::UnknownRegisterValue), + }, + remaining, + )) } } diff --git a/src/expr/verdict.rs b/src/expr/verdict.rs index 3c4c374..326ef3b 100644 --- a/src/expr/verdict.rs +++ b/src/expr/verdict.rs @@ -1,11 +1,90 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys::{self, libc::{self, c_char}}; -use std::ffi::{CStr, CString}; +use std::fmt::Debug; + +use libc::{NF_ACCEPT, NF_DROP, NF_QUEUE}; + +use super::{Expression, Immediate, ImmediateData, Register, Rule}; +use crate::{ + create_expr_type, impl_attr_getters_and_setters, + nlmsg::{NfNetlinkAttribute, NfNetlinkAttributes, NfNetlinkDeserializable}, + parser::{DecodeError, InnerFormat}, + sys::{self, NFT_BREAK, NFT_CONTINUE, NFT_GOTO, NFT_JUMP, NFT_REG_VERDICT, NFT_RETURN}, +}; + +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[repr(i32)] +pub enum VerdictType { + Drop = NF_DROP, + Accept = NF_ACCEPT, + Queue = NF_QUEUE, + Continue = NFT_CONTINUE, + Break = NFT_BREAK, + Jump = NFT_JUMP, + Goto = NFT_GOTO, + Return = NFT_RETURN, +} + +impl NfNetlinkAttribute for VerdictType { + fn get_size(&self) -> usize { + (*self as i32).get_size() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + (*self as i32).write_payload(addr); + } +} + +impl NfNetlinkDeserializable for VerdictType { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let (v, remaining_data) = i32::deserialize(buf)?; + Ok(( + match v { + NF_DROP => VerdictType::Drop, + NF_ACCEPT => VerdictType::Accept, + NF_QUEUE => VerdictType::Queue, + NFT_CONTINUE => VerdictType::Continue, + NFT_BREAK => VerdictType::Break, + NFT_JUMP => VerdictType::Jump, + NFT_GOTO => VerdictType::Goto, + NFT_RETURN => VerdictType::Goto, + _ => return Err(DecodeError::UnknownExpressionVerdictType), + }, + remaining_data, + )) + } +} + +create_expr_type!( + nested with_builder : VerdictAttribute, + [ + ( + get_code, + set_code, + with_code, + sys::NFTA_VERDICT_CODE, + ExprVerdictType, + VerdictType + ), + ( + get_chain, + set_chain, + with_chain, + sys::NFTA_VERDICT_CHAIN, + String, + String + ), + ( + get_chain_id, + set_chain_id, + with_chain_id, + sys::NFTA_VERDICT_CHAIN_ID, + U32, + u32 + ) + ] +); -/// A verdict expression. In the background, this is usually an "Immediate" expression in nftnl -/// terms, but here it is simplified to only represent a verdict. #[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub enum Verdict { +pub enum VerdictKind { /// Silently drop the packet. Drop, /// Accept the packet and let it pass. @@ -14,135 +93,32 @@ pub enum Verdict { Continue, Break, Jump { - chain: CString, + chain: String, }, Goto { - chain: CString, + chain: String, }, Return, } -impl Verdict { - fn chain(&self) -> Option<&CStr> { - match *self { - Verdict::Jump { ref chain } => Some(chain.as_c_str()), - Verdict::Goto { ref chain } => Some(chain.as_c_str()), - _ => None, - } - } -} - -impl Expression for Verdict { - fn get_raw_name() -> *const libc::c_char { - b"immediate\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { - unsafe { - let mut chain = None; - if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16) { - let raw_chain = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16); - - if raw_chain.is_null() { - return Err(DeserializationError::NullPointer); - } - chain = Some(CStr::from_ptr(raw_chain).to_owned()); - } - - let verdict = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_IMM_VERDICT as u16); - - match verdict as i32 { - libc::NF_DROP => Ok(Verdict::Drop), - libc::NF_ACCEPT => Ok(Verdict::Accept), - libc::NF_QUEUE => Ok(Verdict::Queue), - libc::NFT_CONTINUE => Ok(Verdict::Continue), - libc::NFT_BREAK => Ok(Verdict::Break), - libc::NFT_JUMP => { - if let Some(chain) = chain { - Ok(Verdict::Jump { chain }) - } else { - Err(DeserializationError::InvalidValue) - } - } - libc::NFT_GOTO => { - if let Some(chain) = chain { - Ok(Verdict::Goto { chain }) - } else { - Err(DeserializationError::InvalidValue) - } - } - libc::NFT_RETURN => Ok(Verdict::Return), - _ => Err(DeserializationError::InvalidValue), - } - } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - let immediate_const = match *self { - Verdict::Drop => libc::NF_DROP, - Verdict::Accept => libc::NF_ACCEPT, - Verdict::Queue => libc::NF_QUEUE, - Verdict::Continue => libc::NFT_CONTINUE, - Verdict::Break => libc::NFT_BREAK, - Verdict::Jump { .. } => libc::NFT_JUMP, - Verdict::Goto { .. } => libc::NFT_GOTO, - Verdict::Return => libc::NFT_RETURN, +impl Immediate { + pub fn new_verdict(kind: VerdictKind) -> Self { + let code = match kind { + VerdictKind::Drop => VerdictType::Drop, + VerdictKind::Accept => VerdictType::Accept, + VerdictKind::Queue => VerdictType::Queue, + VerdictKind::Continue => VerdictType::Continue, + VerdictKind::Break => VerdictType::Break, + VerdictKind::Jump { .. } => VerdictType::Jump, + VerdictKind::Goto { .. } => VerdictType::Goto, + VerdictKind::Return => VerdictType::Return, }; - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc( - b"immediate\0" as *const _ as *const c_char - )); - - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_IMM_DREG as u16, - libc::NFT_REG_VERDICT as u32, - ); - - if let Some(chain) = self.chain() { - sys::nftnl_expr_set_str(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16, chain.as_ptr()); - } - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_IMM_VERDICT as u16, - immediate_const as u32, - ); - - expr + let mut data = VerdictAttribute::builder().with_code(code); + if let VerdictKind::Jump { chain } | VerdictKind::Goto { chain } = kind { + data.set_chain(chain); } + Immediate::builder() + .with_dreg(Register::Verdict) + .with_data(ImmediateData::builder().with_verdict(data)) } } - -#[macro_export] -macro_rules! nft_expr_verdict { - (drop) => { - $crate::expr::Verdict::Drop - }; - (accept) => { - $crate::expr::Verdict::Accept - }; - (reject icmp $code:expr) => { - $crate::expr::Verdict::Reject(RejectionType::Icmp($code)) - }; - (reject tcp-rst) => { - $crate::expr::Verdict::Reject(RejectionType::TcpRst) - }; - (queue) => { - $crate::expr::Verdict::Queue - }; - (continue) => { - $crate::expr::Verdict::Continue - }; - (break) => { - $crate::expr::Verdict::Break - }; - (jump $chain:expr) => { - $crate::expr::Verdict::Jump { chain: $chain } - }; - (goto $chain:expr) => { - $crate::expr::Verdict::Goto { chain: $chain } - }; - (return) => { - $crate::expr::Verdict::Return - }; -} @@ -95,7 +95,7 @@ macro_rules! try_alloc { mod batch; pub use batch::{default_batch_page_size, Batch}; -//pub mod expr; +pub mod expr; mod table; pub use table::list_tables; @@ -113,9 +113,9 @@ pub mod query; pub mod nlmsg; pub mod parser; -//mod rule; -//pub use rule::Rule; -//pub use rule::{get_rules_cb, list_rules_for_chain}; +mod rule; +pub use rule::list_rules_for_chain; +pub use rule::Rule; //mod rule_methods; //pub use rule_methods::{iface_index, Error as MatchError, Protocol, RuleMethods}; diff --git a/src/nlmsg.rs b/src/nlmsg.rs index b7f90e9..8960146 100644 --- a/src/nlmsg.rs +++ b/src/nlmsg.rs @@ -2,10 +2,11 @@ use std::{collections::BTreeMap, fmt::Debug, mem::size_of}; use crate::{ parser::{ - pad_netlink_object, pad_netlink_object_with_variable_size, AttributeType, DecodeError, + pad_netlink_object, pad_netlink_object_with_variable_size, write_attribute, AttributeType, + DecodeError, }, sys::{ - nfgenmsg, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END, + nfgenmsg, nlattr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END, NFNL_SUBSYS_NFTABLES, }, MsgType, ProtocolFamily, @@ -87,7 +88,11 @@ impl<'a> NfNetlinkWriter<'a> { } pub trait AttributeDecoder { - fn decode_attribute(attr_type: u16, buf: &[u8]) -> Result<AttributeType, DecodeError>; + fn decode_attribute( + attrs: &NfNetlinkAttributes, + attr_type: u16, + buf: &[u8], + ) -> Result<AttributeType, DecodeError>; } pub trait NfNetlinkDeserializable: Sized { @@ -141,3 +146,27 @@ impl NfNetlinkAttributes { } } } + +impl NfNetlinkAttribute for NfNetlinkAttributes { + fn get_size(&self) -> usize { + let mut size = 0; + + for (_type, attr) in self.attributes.iter() { + // Attribute header + attribute value + size += pad_netlink_object::<nlattr>() + + pad_netlink_object_with_variable_size(attr.get_size()); + } + + size + } + + unsafe fn write_payload(&self, mut addr: *mut u8) { + for (ty, attr) in self.attributes.iter() { + debug!("writing attribute {} - {:?}", ty, attr); + write_attribute(*ty, attr, addr); + let size = pad_netlink_object::<nlattr>() + + pad_netlink_object_with_variable_size(attr.get_size()); + addr = addr.offset(size as isize); + } + } +} diff --git a/src/parser.rs b/src/parser.rs index 25033d2..42bcb00 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -9,6 +9,7 @@ use std::{ use thiserror::Error; use crate::{ + expr::ExpressionHolder, nlmsg::{ AttributeDecoder, NetlinkType, NfNetlinkAttribute, NfNetlinkAttributes, NfNetlinkDeserializable, NfNetlinkWriter, @@ -32,6 +33,9 @@ pub enum DecodeError { #[error("The message holds unexpected data")] InvalidDataSize, + #[error("Missing information in the chain to create a rule")] + MissingChainInformationError, + #[error("Invalid subsystem, expected NFTABLES")] InvalidSubsystem(u8), @@ -59,6 +63,18 @@ pub enum DecodeError { #[error("Invalid policy for a chain")] UnknownChainPolicy, + #[error("Invalid value for a register")] + UnknownRegisterValue, + + #[error("Invalid type for a verdict expression")] + UnknownExpressionVerdictType, + + #[error("The object does not contain a name for the expression being parsed")] + MissingExpressionName, + + #[error("The expression name was not recognized")] + UnknownExpressionName(String), + #[error("Unsupported attribute type")] UnsupportedAttributeType(u16), @@ -192,7 +208,7 @@ pub fn parse_nlmsg<'a>(buf: &'a [u8]) -> Result<(nlmsghdr, NlMsg<'a>), DecodeErr /// Write the attribute, preceded by a `libc::nlattr` // rewrite of `mnl_attr_put` -unsafe fn write_attribute<'a>(ty: NetlinkType, obj: &AttributeType, mut buf: *mut u8) { +pub unsafe fn write_attribute<'a>(ty: NetlinkType, obj: &AttributeType, mut buf: *mut u8) { let header_len = pad_netlink_object::<libc::nlattr>(); // copy the header *(buf as *mut nlattr) = nlattr { @@ -330,32 +346,6 @@ impl NfNetlinkDeserializable for ProtocolFamily { } } -pub type NestedAttribute = NfNetlinkAttributes; - -// parts of the NfNetlinkAttribute trait we need for handling nested objects -impl NfNetlinkAttribute for NestedAttribute { - fn get_size(&self) -> usize { - let mut size = 0; - - for (_type, attr) in self.attributes.iter() { - // Attribute header + attribute value - size += pad_netlink_object::<nlattr>() - + pad_netlink_object_with_variable_size(attr.get_size()); - } - - size - } - - unsafe fn write_payload(&self, mut addr: *mut u8) { - for (ty, attr) in self.attributes.iter() { - write_attribute(*ty, attr, addr); - let size = pad_netlink_object::<nlattr>() - + pad_netlink_object_with_variable_size(attr.get_size()); - addr = addr.offset(size as isize); - } - } -} - pub struct NfNetlinkAttributeReader<'a> { buf: &'a [u8], pos: usize, @@ -384,23 +374,28 @@ impl<'a> NfNetlinkAttributeReader<'a> { pub fn decode<T: AttributeDecoder + 'static>( mut self, ) -> Result<NfNetlinkAttributes, DecodeError> { + debug!( + "Calling NfNetlinkAttributeReader::decode() on {}", + std::any::type_name::<T>() + ); while self.remaining_size > pad_netlink_object::<nlattr>() { let nlattr = unsafe { *transmute::<*const u8, *const nlattr>(self.buf[self.pos..].as_ptr()) }; - // TODO: ignore the byteorder and nested attributes for now + // ignore the byteorder and nested attributes let nla_type = nlattr.nla_type & NLA_TYPE_MASK as u16; self.pos += pad_netlink_object::<nlattr>(); let attr_remaining_size = nlattr.nla_len as usize - pad_netlink_object::<nlattr>(); match T::decode_attribute( + &self.attrs, nla_type, &self.buf[self.pos..self.pos + attr_remaining_size], ) { Ok(x) => self.attrs.set_attr(nla_type, x), Err(DecodeError::UnsupportedAttributeType(t)) => info!( - "Ignore attribute type {} for type identified by {:?}", + "Ignoring unsupported attribute type {} for type {}", t, - TypeId::of::<T>() + std::any::type_name::<T>() ), Err(e) => return Err(e), } @@ -417,6 +412,7 @@ impl<'a> NfNetlinkAttributeReader<'a> { } } +#[macro_export] macro_rules! impl_attribute_holder { ($enum_name:ident, $([$internal_name:ident, $type:ty]),+) => { #[derive(Debug, Clone, PartialEq, Eq)] @@ -485,12 +481,21 @@ impl_attribute_holder!( [ChainHook, crate::chain::Hook], [ChainPolicy, crate::chain::ChainPolicy], [ChainType, crate::chain::ChainType], - [ProtocolFamily, crate::ProtocolFamily] + [ProtocolFamily, crate::ProtocolFamily], + [Expression, crate::expr::ExpressionHolder], + [ExpressionVariant, crate::expr::ExpressionVariant], + [ExpressionList, crate::expr::ExpressionList], + [ExprLog, crate::expr::Log], + [ExprImmediate, crate::expr::ImmediateData], + [ExprImmediateData, crate::expr::ImmediateData], + [ExprVerdictAttribute, crate::expr::VerdictAttribute], + [ExprVerdictType, crate::expr::VerdictType], + [Register, crate::expr::Register] ); #[macro_export] macro_rules! impl_attr_getters_and_setters { - ($struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty $(, $nested:literal)?)),+]) => { + (without_decoder $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { impl $struct { $( #[allow(dead_code)] @@ -512,26 +517,6 @@ macro_rules! impl_attr_getters_and_setters { )+ } - impl $crate::nlmsg::AttributeDecoder for $struct { - #[allow(dead_code)] - fn decode_attribute(attr_type: u16, buf: &[u8]) -> Result<$crate::parser::AttributeType, $crate::parser::DecodeError> { - use $crate::nlmsg::NfNetlinkDeserializable; - match attr_type { - $( - x if x == $attr_name => { - let (val, remaining) = <$type>::deserialize(buf)?; - if remaining.len() != 0 { - return Err($crate::parser::DecodeError::InvalidDataSize); - } - Ok($crate::parser::AttributeType::$internal_name(val)) - }, - )+ - _ => Err($crate::parser::DecodeError::UnsupportedAttributeType(attr_type)), - } - } - } - - impl $crate::parser::InnerFormat for $struct { fn inner_format_struct<'a, 'b: 'a>(&'a self, mut s: std::fmt::DebugStruct<'a, 'b>) -> Result<std::fmt::DebugStruct<'a, 'b>, std::fmt::Error> { $( @@ -552,6 +537,33 @@ macro_rules! impl_attr_getters_and_setters { Ok(s) } } + + }; + (decoder $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { + impl $crate::nlmsg::AttributeDecoder for $struct { + #[allow(dead_code)] + fn decode_attribute(_attrs: &$crate::nlmsg::NfNetlinkAttributes, attr_type: u16, buf: &[u8]) -> Result<$crate::parser::AttributeType, $crate::parser::DecodeError> { + use $crate::nlmsg::NfNetlinkDeserializable; + debug!("Decoding attribute {} in type {}", attr_type, std::any::type_name::<$struct>()); + match attr_type { + $( + x if x == $attr_name => { + debug!("Calling {}::deserialize()", std::any::type_name::<$type>()); + let (val, remaining) = <$type>::deserialize(buf)?; + if remaining.len() != 0 { + return Err($crate::parser::DecodeError::InvalidDataSize); + } + Ok($crate::parser::AttributeType::$internal_name(val)) + }, + )+ + _ => Err($crate::parser::DecodeError::UnsupportedAttributeType(attr_type)), + } + } + } + }; + ($struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { + $crate::impl_attr_getters_and_setters!(without_decoder $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); + $crate::impl_attr_getters_and_setters!(decoder $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); }; } @@ -560,6 +572,7 @@ pub fn parse_object<T: AttributeDecoder + 'static>( add_obj: u32, del_obj: u32, ) -> Result<(NfNetlinkAttributes, nfgenmsg, &[u8]), DecodeError> { + debug!("parse_object() running"); let (hdr, msg) = parse_nlmsg(buf)?; let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as u32; diff --git a/src/query.rs b/src/query.rs index da886c0..8ea7b89 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,7 +1,7 @@ use std::os::unix::prelude::RawFd; use crate::{ - nlmsg::{NfNetlinkObject, NfNetlinkWriter}, + nlmsg::{NfNetlinkAttributes, NfNetlinkObject, NfNetlinkWriter}, parser::{nft_nlmsg_maxsize, pad_netlink_object_with_variable_size}, sys::{nlmsgerr, NLM_F_DUMP, NLM_F_MULTI}, ProtocolFamily, @@ -77,14 +77,10 @@ pub(crate) fn recv_and_process<'a, T>( break; } - debug!("calling parse_nlmsg"); + debug!("Calling parse_nlmsg"); let (nlmsghdr, msg) = parse_nlmsg(&buf)?; debug!("Got a valid netlink message: {:?} {:?}", nlmsghdr, msg); - // we cannot know when a message will end if we are not receiving messages ending with an - // NlMsg::Done marker, and if a maximum sequence number wasn't specified either - if max_seq.is_none() && nlmsghdr.nlmsg_flags & NLM_F_MULTI as u16 == 0 { - return Err(Error::UndecidableMessageTermination); - } + match msg { NlMsg::Done => { return Ok(()); @@ -102,6 +98,12 @@ pub(crate) fn recv_and_process<'a, T>( } } + // we cannot know when a sequence of messages will end if the messages do not end + // with an NlMsg::Done marker while if a maximum sequence number wasn't specified + if max_seq.is_none() && nlmsghdr.nlmsg_flags & NLM_F_MULTI as u16 == 0 { + return Err(Error::UndecidableMessageTermination); + } + // retrieve the next message if let Some(max_seq) = max_seq { if nlmsghdr.nlmsg_seq >= max_seq { @@ -150,10 +152,11 @@ where /// Returns a buffer containing a netlink message which requests a list of all the netfilter /// matching objects (e.g. tables, chains, rules, ...). /// Supply the type of objects to retrieve (e.g. libc::NFT_MSG_GETTABLE), and a search filter. -pub fn get_list_of_objects<T>(msg_type: u16, seq: u32, filter: Option<&T>) -> Result<Vec<u8>, Error> -where - T: NfNetlinkObject, -{ +pub fn get_list_of_objects( + msg_type: u16, + seq: u32, + filter: Option<&NfNetlinkAttributes>, +) -> Result<Vec<u8>, Error> { let mut buffer = Vec::new(); let mut writer = NfNetlinkWriter::new(&mut buffer); writer.write_header( @@ -163,10 +166,10 @@ where seq, None, ); - writer.finalize_writing_object(); if let Some(filter) = filter { - filter.add_or_remove(&mut writer, crate::MsgType::Add, 0); + filter.serialize(&mut writer); } + writer.finalize_writing_object(); Ok(buffer) } @@ -177,13 +180,13 @@ where pub fn list_objects_with_data<'a, Object, Accumulator>( data_type: u16, cb: &dyn Fn(Object, &mut Accumulator) -> Result<(), Error>, - filter: Option<&Object>, + filter: Option<&NfNetlinkAttributes>, working_data: &'a mut Accumulator, ) -> Result<(), Error> where Object: NfNetlinkObject, { - debug!("listing objects of kind {}", data_type); + debug!("Listing objects of kind {}", data_type); let sock = socket::socket( AddressFamily::Netlink, SockType::Raw, @@ -203,6 +206,7 @@ where sock, None, Some(&|buf: &[u8], working_data: &mut Accumulator| { + debug!("Calling Object::deserialize()"); cb(Object::deserialize(buf)?.0, working_data) }), working_data, diff --git a/src/rule.rs b/src/rule.rs index 80ca0c7..a596fce 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -1,73 +1,63 @@ -use crate::expr::ExpressionWrapper; -use crate::nlmsg::NlMsg; -#[cfg(feature = "query")] -use crate::query::{Nfgenmsg, ParseError}; -use crate::sys::{self, libc}; -use crate::{chain::Chain, expr::Expression, MsgType}; -use std::ffi::{c_void, CStr, CString}; +use crate::expr::ExpressionList; +use crate::nlmsg::{ + NfNetlinkAttributes, NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkWriter, +}; +use crate::parser::InnerFormat; +use crate::parser::{parse_object, DecodeError}; +use crate::query::list_objects_with_data; +use crate::sys::{self, NFT_MSG_DELRULE, NFT_MSG_NEWRULE, NLM_F_ACK, NLM_F_CREATE}; +use crate::{chain::Chain, MsgType}; +use crate::{impl_attr_getters_and_setters, ProtocolFamily}; +use std::convert::TryFrom; use std::fmt::Debug; -use std::os::raw::c_char; -use std::rc::Rc; /// A nftables firewall rule. +#[derive(PartialEq, Eq)] pub struct Rule { - pub(crate) rule: *mut sys::nftnl_rule, - pub(crate) chain: Rc<Chain>, + inner: NfNetlinkAttributes, + family: ProtocolFamily, } impl Rule { /// Creates a new rule object in the given [`Chain`]. /// /// [`Chain`]: struct.Chain.html - pub fn new(chain: Rc<Chain>) -> Rule { - unsafe { - let rule = try_alloc!(sys::nftnl_rule_alloc()); - sys::nftnl_rule_set_u32( - rule, - sys::NFTNL_RULE_FAMILY as u16, - chain.get_table().get_family() as u32, - ); - sys::nftnl_rule_set_str( - rule, - sys::NFTNL_RULE_TABLE as u16, - chain.get_table().get_name().as_ptr(), - ); - sys::nftnl_rule_set_str( - rule, - sys::NFTNL_RULE_CHAIN as u16, - chain.get_name().as_ptr(), - ); - - Rule { rule, chain } + pub fn new(chain: &Chain) -> Result<Rule, DecodeError> { + let inner = NfNetlinkAttributes::new(); + Ok(Rule { + inner, + family: chain.get_family(), } + .with_table( + chain + .get_table() + .ok_or(DecodeError::MissingChainInformationError)?, + ) + .with_chain( + chain + .get_name() + .ok_or(DecodeError::MissingChainInformationError)?, + )) } - pub unsafe fn from_raw(rule: *mut sys::nftnl_rule, chain: Rc<Chain>) -> Self { - Rule { rule, chain } + pub fn get_family(&self) -> ProtocolFamily { + self.family } - pub fn get_position(&self) -> u64 { - unsafe { sys::nftnl_rule_get_u64(self.rule, sys::NFTNL_RULE_POSITION as u16) } + pub fn set_family(&mut self, family: ProtocolFamily) { + self.family = family; } - /// Sets the position of this rule within the chain it lives in. By default a new rule is added - /// to the end of the chain. - pub fn set_position(&mut self, position: u64) { - unsafe { - sys::nftnl_rule_set_u64(self.rule, sys::NFTNL_RULE_POSITION as u16, position); - } - } - - pub fn get_handle(&self) -> u64 { - unsafe { sys::nftnl_rule_get_u64(self.rule, sys::NFTNL_RULE_HANDLE as u16) } + pub fn with_family(mut self, family: ProtocolFamily) -> Self { + self.set_family(family); + self } - pub fn set_handle(&mut self, handle: u64) { - unsafe { - sys::nftnl_rule_set_u64(self.rule, sys::NFTNL_RULE_HANDLE as u16, handle); - } + fn raw_attributes(&self) -> &NfNetlinkAttributes { + &self.inner } + /* /// Adds an expression to this rule. Expressions are evaluated from first to last added. /// As soon as an expression does not match the packet it's being evaluated for, evaluation /// stops and the packet is evaluated against the next rule in the chain. @@ -121,18 +111,6 @@ impl Rule { RuleExprsIter::new(self.clone()) } - #[cfg(feature = "unsafe-raw-handles")] - /// Returns the raw handle. - pub fn as_ptr(&self) -> *const sys::nftnl_rule { - self.rule as *const sys::nftnl_rule - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns a mutable version of the raw handle. - pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_rule { - self.rule - } - /// Performs a deep comparizon of rules, by checking they have the same expressions inside. /// This is not enabled by default in our PartialEq implementation because of the difficulty to /// compare an expression generated by the library with the expressions returned by the kernel @@ -185,14 +163,78 @@ impl Rule { } } } + */ } impl Debug for Rule { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.get_str()) + let mut res = f.debug_struct("Rule"); + res.field("family", &self.family); + self.inner_format_struct(res)?.finish() + } +} + +impl NfNetlinkObject for Rule { + fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32) { + let raw_msg_type = match msg_type { + MsgType::Add => NFT_MSG_NEWRULE, + MsgType::Del => NFT_MSG_DELRULE, + } as u16; + writer.write_header( + raw_msg_type, + self.family, + (if let MsgType::Add = msg_type { + NLM_F_CREATE + } else { + 0 + } | NLM_F_ACK) as u16, + seq, + None, + ); + self.inner.serialize(writer); + writer.finalize_writing_object(); + } +} + +impl NfNetlinkDeserializable for Rule { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let (inner, nfgenmsg, remaining_data) = + parse_object::<Self>(buf, NFT_MSG_NEWRULE, NFT_MSG_DELRULE)?; + + Ok(( + Self { + inner, + family: ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?, + }, + remaining_data, + )) } } +impl_attr_getters_and_setters!( + Rule, + [ + (get_id, set_id, with_id, sys::NFTA_RULE_ID, U32, u32), + (get_handle, set_handle, with_handle, sys::NFTA_RULE_HANDLE, U64, u64), + // Sets the position of this rule within the chain it lives in. By default a new rule is added + // to the end of the chain. + (get_position, set_position, with_position, sys::NFTA_RULE_POSITION, U64, u64), + (get_table, set_table, with_table, sys::NFTA_RULE_TABLE, String, String), + (get_chain, set_chain, with_chain, sys::NFTA_RULE_CHAIN, String, String), + ( + get_userdata, + set_userdata, + with_userdata, + sys::NFTA_RULE_USERDATA, + VecU8, + Vec<u8> + ), + (get_expressions, set_expressions, with_expressions, sys::NFTA_RULE_EXPRESSIONS, ExpressionList, ExpressionList) + ] +); + +/* + impl PartialEq for Rule { fn eq(&self, other: &Self) -> bool { if self.get_chain() != other.get_chain() { @@ -285,74 +327,19 @@ impl Drop for RuleExprsIter { unsafe { sys::nftnl_expr_iter_destroy(self.iter) }; } } +*/ -#[cfg(feature = "query")] -pub fn get_rules_cb( - header: &libc::nlmsghdr, - _genmsg: &Nfgenmsg, - _data: &[u8], - (chain, rules): &mut (&Rc<Chain>, &mut Vec<Rule>), -) -> Result<(), crate::query::Error> { - unsafe { - let rule = sys::nftnl_rule_alloc(); - if rule == std::ptr::null_mut() { - return Err(ParseError::Custom(Box::new(std::io::Error::new( - std::io::ErrorKind::Other, - "Rule allocation failed", - ))) - .into()); - } - let err = sys::nftnl_rule_nlmsg_parse(header, rule); - if err < 0 { - sys::nftnl_rule_free(rule); - return Err(ParseError::Custom(Box::new(std::io::Error::new( - std::io::ErrorKind::Other, - "The netlink table couldn't be parsed !?", - ))) - .into()); - } - - rules.push(Rule::from_raw(rule, chain.clone())); - } - - Ok(()) -} - -#[cfg(feature = "query")] -pub fn list_rules_for_chain(chain: &Rc<Chain>) -> Result<Vec<Rule>, crate::query::Error> { +pub fn list_rules_for_chain(chain: &Chain) -> Result<Vec<Rule>, crate::query::Error> { let mut result = Vec::new(); - crate::query::list_objects_with_data( + list_objects_with_data( libc::NFT_MSG_GETRULE as u16, - &get_rules_cb, - &mut (chain, &mut result), - // only retrieve rules from the currently targetted chain - Some(&|hdr| unsafe { - let rule = sys::nftnl_rule_alloc(); - if rule as *const _ == std::ptr::null() { - return Err(crate::query::Error::NetlinkAllocationFailed); - } - - sys::nftnl_rule_set_str( - rule, - sys::NFTNL_RULE_TABLE as u16, - chain.get_table().get_name().as_ptr(), - ); - sys::nftnl_rule_set_u32( - rule, - sys::NFTNL_RULE_FAMILY as u16, - chain.get_table().get_family() as u32, - ); - sys::nftnl_rule_set_str( - rule, - sys::NFTNL_RULE_CHAIN as u16, - chain.get_name().as_ptr(), - ); - - sys::nftnl_rule_nlmsg_build_payload(hdr, rule); - - sys::nftnl_rule_free(rule); + &|rule: Rule, rules: &mut Vec<Rule>| { + rules.push(rule); Ok(()) - }), + }, + // only retrieve rules from the currently targetted chain + Some(&Rule::new(chain)?.raw_attributes()), + &mut result, )?; Ok(result) } diff --git a/src/table.rs b/src/table.rs index 768eedd..5074ac9 100644 --- a/src/table.rs +++ b/src/table.rs @@ -7,7 +7,7 @@ use crate::nlmsg::{ use crate::parser::{parse_object, DecodeError, InnerFormat}; use crate::sys::{ self, NFNL_SUBSYS_NFTABLES, NFTA_OBJ_TABLE, NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, - NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, NFT_MSG_NEWTABLE, NLM_F_ACK, + NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, NFT_MSG_NEWTABLE, NLM_F_ACK, NLM_F_CREATE, }; use crate::{impl_attr_getters_and_setters, MsgType, ProtocolFamily}; @@ -72,7 +72,17 @@ impl NfNetlinkObject for Table { MsgType::Add => NFT_MSG_NEWTABLE, MsgType::Del => NFT_MSG_DELTABLE, } as u16; - writer.write_header(raw_msg_type, self.family, NLM_F_ACK as u16, seq, None); + writer.write_header( + raw_msg_type, + self.family, + (if let MsgType::Add = msg_type { + NLM_F_CREATE + } else { + 0 + } | NLM_F_ACK) as u16, + seq, + None, + ); self.inner.serialize(writer); writer.finalize_writing_object(); } |