diff options
author | Simon THOBY <git@nightmared.fr> | 2023-01-09 18:54:11 +0000 |
---|---|---|
committer | Simon THOBY <git@nightmared.fr> | 2023-01-09 18:54:11 +0000 |
commit | d5b9ec5185a27414286ee303eb3d21ce3069db09 (patch) | |
tree | 369eb90e8a2da307d7cd8f0b15a3318bbdba0003 /src/expr | |
parent | 3e48e7efa516183d623f80d2e4e393cecc2acde9 (diff) | |
parent | c3e3773cccd01f80f2d72a7691e0654d304e6b2d (diff) |
Merge branch 'no_mnl' into 'master'
experimental support for a full-rust rewrite of the codebase (no libnftnl/libmnl anymore)
See merge request rustwall/rustables!16
Diffstat (limited to 'src/expr')
-rw-r--r-- | src/expr/bitwise.rs | 100 | ||||
-rw-r--r-- | src/expr/cmp.rs | 204 | ||||
-rw-r--r-- | src/expr/counter.rs | 43 | ||||
-rw-r--r-- | src/expr/ct.rs | 108 | ||||
-rw-r--r-- | src/expr/immediate.rs | 154 | ||||
-rw-r--r-- | src/expr/log.rs | 127 | ||||
-rw-r--r-- | src/expr/lookup.rs | 94 | ||||
-rw-r--r-- | src/expr/masquerade.rs | 28 | ||||
-rw-r--r-- | src/expr/meta.rs | 183 | ||||
-rw-r--r-- | src/expr/mod.rs | 314 | ||||
-rw-r--r-- | src/expr/nat.rs | 102 | ||||
-rw-r--r-- | src/expr/payload.rs | 443 | ||||
-rw-r--r-- | src/expr/register.rs | 33 | ||||
-rw-r--r-- | src/expr/reject.rs | 109 | ||||
-rw-r--r-- | src/expr/verdict.rs | 169 | ||||
-rw-r--r-- | src/expr/wrapper.rs | 61 |
16 files changed, 654 insertions, 1618 deletions
diff --git a/src/expr/bitwise.rs b/src/expr/bitwise.rs index d34d22c..fb40a04 100644 --- a/src/expr/bitwise.rs +++ b/src/expr/bitwise.rs @@ -1,69 +1,47 @@ -use super::{Expression, Rule, ToSlice}; -use crate::sys::{self, libc}; -use std::ffi::c_void; -use std::os::raw::c_char; - -/// Expression for performing bitwise masking and XOR on the data in a register. -pub struct Bitwise<M: ToSlice, X: ToSlice> { - mask: M, - xor: X, +use rustables_macros::nfnetlink_struct; + +use super::{Expression, Register}; +use crate::error::BuilderError; +use crate::parser_impls::NfNetlinkData; +use crate::sys::{ + NFTA_BITWISE_DREG, NFTA_BITWISE_LEN, NFTA_BITWISE_MASK, NFTA_BITWISE_SREG, NFTA_BITWISE_XOR, +}; + +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] +pub struct Bitwise { + #[field(NFTA_BITWISE_SREG)] + sreg: Register, + #[field(NFTA_BITWISE_DREG)] + dreg: Register, + #[field(NFTA_BITWISE_LEN)] + len: u32, + #[field(NFTA_BITWISE_MASK)] + mask: NfNetlinkData, + #[field(NFTA_BITWISE_XOR)] + xor: NfNetlinkData, } -impl<M: ToSlice, X: ToSlice> Bitwise<M, X> { - /// Returns a new `Bitwise` instance that first masks the value it's applied to with `mask` and - /// then performs xor with the value in `xor`. - pub fn new(mask: M, xor: X) -> Self { - Self { mask, xor } +impl Expression for Bitwise { + fn get_name() -> &'static str { + "bitwise" } } -impl<M: ToSlice, X: ToSlice> Expression for Bitwise<M, X> { - fn get_raw_name() -> *const c_char { - b"bitwise\0" as *const _ as *const c_char - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - let mask = self.mask.to_slice(); - let xor = self.xor.to_slice(); - assert!(mask.len() == xor.len()); - let len = mask.len() as u32; - - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_BITWISE_SREG as u16, - libc::NFT_REG_1 as u32, - ); - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_BITWISE_DREG as u16, - libc::NFT_REG_1 as u32, - ); - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_BITWISE_LEN as u16, len); - - sys::nftnl_expr_set( - expr, - sys::NFTNL_EXPR_BITWISE_MASK as u16, - mask.as_ref() as *const _ as *const c_void, - len, - ); - sys::nftnl_expr_set( - expr, - sys::NFTNL_EXPR_BITWISE_XOR as u16, - xor.as_ref() as *const _ as *const c_void, - len, - ); - - expr +impl Bitwise { + /// Returns a new `Bitwise` instance that first masks the value it's applied to with `mask` and + /// then performs xor with the value in `xor` + pub fn new(mask: impl Into<Vec<u8>>, xor: impl Into<Vec<u8>>) -> Result<Self, BuilderError> { + let mask = mask.into(); + let xor = xor.into(); + if mask.len() != xor.len() { + return Err(BuilderError::IncompatibleLength); } + Ok(Bitwise::default() + .with_sreg(Register::Reg1) + .with_dreg(Register::Reg1) + .with_len(mask.len() as u32) + .with_xor(NfNetlinkData::default().with_value(xor)) + .with_mask(NfNetlinkData::default().with_value(mask))) } } - -#[macro_export] -macro_rules! nft_expr_bitwise { - (mask $mask:expr,xor $xor:expr) => { - $crate::expr::Bitwise::new($mask, $xor) - }; -} diff --git a/src/expr/cmp.rs b/src/expr/cmp.rs index f6ea900..86d3587 100644 --- a/src/expr/cmp.rs +++ b/src/expr/cmp.rs @@ -1,187 +1,64 @@ -use super::{DeserializationError, Expression, Rule, ToSlice}; -use crate::sys::{self, libc}; -use std::{ - borrow::Cow, - ffi::{c_void, CString}, - os::raw::c_char, +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; + +use crate::{ + parser_impls::NfNetlinkData, + sys::{ + NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFT_CMP_EQ, NFT_CMP_GT, NFT_CMP_GTE, NFT_CMP_LT, + NFT_CMP_LTE, NFT_CMP_NEQ, + }, }; +use super::{Expression, Register}; + /// Comparison operator. #[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[nfnetlink_enum(u32, nested = true)] pub enum CmpOp { /// Equals. - Eq, + Eq = NFT_CMP_EQ, /// Not equal. - Neq, + Neq = NFT_CMP_NEQ, /// Less than. - Lt, + Lt = NFT_CMP_LT, /// Less than, or equal. - Lte, + Lte = NFT_CMP_LTE, /// Greater than. - Gt, + Gt = NFT_CMP_GT, /// Greater than, or equal. - Gte, -} - -impl CmpOp { - /// Returns the corresponding `NFT_*` constant for this comparison operation. - pub fn to_raw(self) -> u32 { - use self::CmpOp::*; - match self { - Eq => libc::NFT_CMP_EQ as u32, - Neq => libc::NFT_CMP_NEQ as u32, - Lt => libc::NFT_CMP_LT as u32, - Lte => libc::NFT_CMP_LTE as u32, - Gt => libc::NFT_CMP_GT as u32, - Gte => libc::NFT_CMP_GTE as u32, - } - } - - pub fn from_raw(val: u32) -> Result<Self, DeserializationError> { - use self::CmpOp::*; - match val as i32 { - libc::NFT_CMP_EQ => Ok(Eq), - libc::NFT_CMP_NEQ => Ok(Neq), - libc::NFT_CMP_LT => Ok(Lt), - libc::NFT_CMP_LTE => Ok(Lte), - libc::NFT_CMP_GT => Ok(Gt), - libc::NFT_CMP_GTE => Ok(Gte), - _ => Err(DeserializationError::InvalidValue), - } - } + Gte = NFT_CMP_GTE, } /// Comparator expression. Allows comparing the content of the netfilter register with any value. -#[derive(Debug, PartialEq)] -pub struct Cmp<T> { +#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[nfnetlink_struct] +pub struct Cmp { + #[field(NFTA_CMP_SREG)] + sreg: Register, + #[field(NFTA_CMP_OP)] op: CmpOp, - data: T, + #[field(NFTA_CMP_DATA)] + data: NfNetlinkData, } -impl<T: ToSlice> Cmp<T> { +impl Cmp { /// Returns a new comparison expression comparing the value loaded in the register with the /// data in `data` using the comparison operator `op`. - pub fn new(op: CmpOp, data: T) -> Self { - Cmp { op, data } - } -} - -impl<T: ToSlice> Expression for Cmp<T> { - fn get_raw_name() -> *const c_char { - b"cmp\0" as *const _ as *const c_char - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - let data = self.data.to_slice(); - trace!("Creating a cmp expr comparing with data {:?}", data); - - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_CMP_SREG as u16, - libc::NFT_REG_1 as u32, - ); - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_CMP_OP as u16, self.op.to_raw()); - sys::nftnl_expr_set( - expr, - sys::NFTNL_EXPR_CMP_DATA as u16, - data.as_ptr() as *const c_void, - data.len() as u32, - ); - - expr - } - } -} - -impl<const N: usize> Expression for Cmp<[u8; N]> { - fn get_raw_name() -> *const c_char { - Cmp::<u8>::get_raw_name() - } - - /// The raw data contained inside `Cmp` expressions can only be deserialized to arrays of - /// bytes, to ensure that the memory layout of retrieved data cannot be violated. It is your - /// responsibility to provide the correct length of the byte data. If the data size is invalid, - /// you will get the error `DeserializationError::InvalidDataSize`. - /// - /// Example (warning, no error checking!): - /// ```rust - /// use std::ffi::CString; - /// use std::net::Ipv4Addr; - /// use std::rc::Rc; - /// - /// use rustables::{Chain, expr::{Cmp, CmpOp}, ProtoFamily, Rule, Table}; - /// - /// let table = Rc::new(Table::new(&CString::new("mytable").unwrap(), ProtoFamily::Inet)); - /// let chain = Rc::new(Chain::new(&CString::new("mychain").unwrap(), table)); - /// let mut rule = Rule::new(chain); - /// rule.add_expr(&Cmp::new(CmpOp::Eq, 1337u16)); - /// for expr in Rc::new(rule).get_exprs() { - /// println!("{:?}", expr.decode_expr::<Cmp<[u8; 2]>>().unwrap()); - /// } - /// ``` - /// These limitations occur because casting bytes to any type of the same size - /// as the raw input would be *extremely* dangerous in terms of memory safety. - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { - unsafe { - let ref_len = std::mem::size_of::<[u8; N]>() as u32; - let mut data_len = 0; - let data = sys::nftnl_expr_get( - expr, - sys::NFTNL_EXPR_CMP_DATA as u16, - &mut data_len as *mut u32, - ); - - if data.is_null() { - return Err(DeserializationError::NullPointer); - } else if data_len != ref_len { - return Err(DeserializationError::InvalidDataSize); - } - - let data = *(data as *const [u8; N]); - - let op = CmpOp::from_raw(sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_CMP_OP as u16))?; - Ok(Cmp { op, data }) - } - } - - // call to the other implementation to generate the expression - fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr { + pub fn new(op: CmpOp, data: impl Into<Vec<u8>>) -> Self { Cmp { - data: &self.data as &[u8], - op: self.op, + sreg: Some(Register::Reg1), + op: Some(op), + data: Some(NfNetlinkData::default().with_value(data.into())), } - .to_expr(rule) } } -#[macro_export(local_inner_macros)] -macro_rules! nft_expr_cmp { - (@cmp_op ==) => { - $crate::expr::CmpOp::Eq - }; - (@cmp_op !=) => { - $crate::expr::CmpOp::Neq - }; - (@cmp_op <) => { - $crate::expr::CmpOp::Lt - }; - (@cmp_op <=) => { - $crate::expr::CmpOp::Lte - }; - (@cmp_op >) => { - $crate::expr::CmpOp::Gt - }; - (@cmp_op >=) => { - $crate::expr::CmpOp::Gte - }; - ($op:tt $data:expr) => { - $crate::expr::Cmp::new(nft_expr_cmp!(@cmp_op $op), $data) - }; +impl Expression for Cmp { + fn get_name() -> &'static str { + "cmp" + } } +/* /// Can be used to compare the value loaded by [`Meta::IifName`] and [`Meta::OifName`]. Please note /// that it is faster to check interface index than name. /// @@ -207,13 +84,4 @@ impl ToSlice for InterfaceName { Cow::from(bytes) } } - -impl<'a> ToSlice for &'a InterfaceName { - fn to_slice(&self) -> Cow<'_, [u8]> { - let bytes = match *self { - InterfaceName::Exact(ref name) => name.as_bytes_with_nul(), - InterfaceName::StartingWith(ref name) => name.as_bytes(), - }; - Cow::from(bytes) - } -} +*/ diff --git a/src/expr/counter.rs b/src/expr/counter.rs index 4732e85..d22fb8a 100644 --- a/src/expr/counter.rs +++ b/src/expr/counter.rs @@ -1,46 +1,21 @@ -use super::{DeserializationError, Expression, Rule}; +use rustables_macros::nfnetlink_struct; + +use super::Expression; use crate::sys; -use std::os::raw::c_char; /// A counter expression adds a counter to the rule that is incremented to count number of packets /// and number of bytes for all packets that have matched the rule. -#[derive(Debug, PartialEq)] +#[derive(Default, Clone, Debug, PartialEq, Eq)] +#[nfnetlink_struct] pub struct Counter { + #[field(sys::NFTA_COUNTER_BYTES)] pub nb_bytes: u64, + #[field(sys::NFTA_COUNTER_PACKETS)] pub nb_packets: u64, } -impl Counter { - pub fn new() -> Self { - Self { - nb_bytes: 0, - nb_packets: 0, - } - } -} - impl Expression for Counter { - fn get_raw_name() -> *const c_char { - b"counter\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { - unsafe { - let nb_bytes = sys::nftnl_expr_get_u64(expr, sys::NFTNL_EXPR_CTR_BYTES as u16); - let nb_packets = sys::nftnl_expr_get_u64(expr, sys::NFTNL_EXPR_CTR_PACKETS as u16); - Ok(Counter { - nb_bytes, - nb_packets, - }) - } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - sys::nftnl_expr_set_u64(expr, sys::NFTNL_EXPR_CTR_BYTES as u16, self.nb_bytes); - sys::nftnl_expr_set_u64(expr, sys::NFTNL_EXPR_CTR_PACKETS as u16, self.nb_packets); - expr - } + fn get_name() -> &'static str { + "counter" } } diff --git a/src/expr/ct.rs b/src/expr/ct.rs index 7d6614c..ad76989 100644 --- a/src/expr/ct.rs +++ b/src/expr/ct.rs @@ -1,9 +1,13 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys::{self, libc}; -use std::os::raw::c_char; +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; + +use crate::sys::{ + NFTA_CT_DIRECTION, NFTA_CT_DREG, NFTA_CT_KEY, NFTA_CT_SREG, NFT_CT_MARK, NFT_CT_STATE, +}; + +use super::{Expression, Register}; bitflags::bitflags! { - pub struct States: u32 { + pub struct ConnTrackState: u32 { const INVALID = 1; const ESTABLISHED = 2; const RELATED = 4; @@ -12,76 +16,54 @@ bitflags::bitflags! { } } -pub enum Conntrack { - State, - Mark { set: bool }, +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[nfnetlink_enum(u32, nested = true)] +pub enum ConntrackKey { + State = NFT_CT_STATE, + Mark = NFT_CT_MARK, } -impl Conntrack { - fn raw_key(&self) -> u32 { - match *self { - Conntrack::State => libc::NFT_CT_STATE as u32, - Conntrack::Mark { .. } => libc::NFT_CT_MARK as u32, - } - } +#[derive(Default, Clone, Debug, PartialEq, Eq)] +#[nfnetlink_struct(nested = true)] +pub struct Conntrack { + #[field(NFTA_CT_DREG)] + pub dreg: Register, + #[field(NFTA_CT_KEY)] + pub key: ConntrackKey, + #[field(NFTA_CT_DIRECTION)] + pub direction: u8, + #[field(NFTA_CT_SREG)] + pub sreg: Register, } impl Expression for Conntrack { - fn get_raw_name() -> *const c_char { - b"ct\0" as *const _ as *const c_char + fn get_name() -> &'static str { + "ct" } +} - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - unsafe { - let ct_key = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_CT_KEY as u16); - let ct_sreg_is_set = sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_CT_SREG as u16); - - match ct_key as i32 { - libc::NFT_CT_STATE => Ok(Conntrack::State), - libc::NFT_CT_MARK => Ok(Conntrack::Mark { - set: ct_sreg_is_set, - }), - _ => Err(DeserializationError::InvalidValue), - } - } +impl Conntrack { + pub fn new(key: ConntrackKey) -> Self { + Self::default().with_dreg(Register::Reg1).with_key(key) } - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); + pub fn set_mark_value(&mut self, reg: Register) { + self.set_sreg(reg); + self.set_key(ConntrackKey::Mark); + } - if let Conntrack::Mark { set: true } = self { - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_CT_SREG as u16, - libc::NFT_REG_1 as u32, - ); - } else { - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_CT_DREG as u16, - libc::NFT_REG_1 as u32, - ); - } - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_CT_KEY as u16, self.raw_key()); + pub fn with_mark_value(mut self, reg: Register) -> Self { + self.set_mark_value(reg); + self + } - expr - } + pub fn retrieve_value(&mut self, key: ConntrackKey) { + self.set_key(key); + self.set_dreg(Register::Reg1); } -} -#[macro_export] -macro_rules! nft_expr_ct { - (state) => { - $crate::expr::Conntrack::State - }; - (mark set) => { - $crate::expr::Conntrack::Mark { set: true } - }; - (mark) => { - $crate::expr::Conntrack::Mark { set: false } - }; + pub fn with_retrieve_value(mut self, key: ConntrackKey) -> Self { + self.retrieve_value(key); + self + } } diff --git a/src/expr/immediate.rs b/src/expr/immediate.rs index 71453b3..2fd9bd5 100644 --- a/src/expr/immediate.rs +++ b/src/expr/immediate.rs @@ -1,124 +1,50 @@ -use super::{DeserializationError, Expression, Register, Rule, ToSlice}; -use crate::sys; -use std::ffi::c_void; -use std::os::raw::c_char; - -/// An immediate expression. Used to set immediate data. Verdicts are handled separately by -/// [crate::expr::Verdict]. -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub struct Immediate<T> { - pub data: T, - pub register: Register, +use rustables_macros::nfnetlink_struct; + +use super::{Expression, Register, Verdict, VerdictKind, VerdictType}; +use crate::{ + parser_impls::NfNetlinkData, + sys::{NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG}, +}; + +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] +pub struct Immediate { + #[field(NFTA_IMMEDIATE_DREG)] + dreg: Register, + #[field(NFTA_IMMEDIATE_DATA)] + data: NfNetlinkData, } -impl<T> Immediate<T> { - pub fn new(data: T, register: Register) -> Self { - Self { data, register } +impl Immediate { + pub fn new_data(data: Vec<u8>, register: Register) -> Self { + Immediate::default() + .with_dreg(register) + .with_data(NfNetlinkData::default().with_value(data)) } -} - -impl<T: ToSlice> Expression for Immediate<T> { - fn get_raw_name() -> *const c_char { - b"immediate\0" as *const _ as *const c_char - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_IMM_DREG as u16, - self.register.to_raw(), - ); - - let data = self.data.to_slice(); - sys::nftnl_expr_set( - expr, - sys::NFTNL_EXPR_IMM_DATA as u16, - data.as_ptr() as *const c_void, - data.len() as u32, - ); - - expr + pub fn new_verdict(kind: VerdictKind) -> Self { + let code = match kind { + VerdictKind::Drop => VerdictType::Drop, + VerdictKind::Accept => VerdictType::Accept, + VerdictKind::Queue => VerdictType::Queue, + VerdictKind::Continue => VerdictType::Continue, + VerdictKind::Break => VerdictType::Break, + VerdictKind::Jump { .. } => VerdictType::Jump, + VerdictKind::Goto { .. } => VerdictType::Goto, + VerdictKind::Return => VerdictType::Return, + }; + let mut data = Verdict::default().with_code(code); + if let VerdictKind::Jump { chain } | VerdictKind::Goto { chain } = kind { + data.set_chain(chain); } + Immediate::default() + .with_dreg(Register::Verdict) + .with_data(NfNetlinkData::default().with_verdict(data)) } } -impl<const N: usize> Expression for Immediate<[u8; N]> { - fn get_raw_name() -> *const c_char { - Immediate::<u8>::get_raw_name() - } - - /// The raw data contained inside `Immediate` expressions can only be deserialized to arrays of - /// bytes, to ensure that the memory layout of retrieved data cannot be violated. It is your - /// responsibility to provide the correct length of the byte data. If the data size is invalid, - /// you will get the error `DeserializationError::InvalidDataSize`. - /// - /// Example (warning, no error checking!): - /// ```rust - /// use std::ffi::CString; - /// use std::net::Ipv4Addr; - /// use std::rc::Rc; - /// - /// use rustables::{Chain, expr::{Immediate, Register}, ProtoFamily, Rule, Table}; - /// - /// let table = Rc::new(Table::new(&CString::new("mytable").unwrap(), ProtoFamily::Inet)); - /// let chain = Rc::new(Chain::new(&CString::new("mychain").unwrap(), table)); - /// let mut rule = Rule::new(chain); - /// rule.add_expr(&Immediate::new(42u8, Register::Reg1)); - /// for expr in Rc::new(rule).get_exprs() { - /// println!("{:?}", expr.decode_expr::<Immediate<[u8; 1]>>().unwrap()); - /// } - /// ``` - /// These limitations occur because casting bytes to any type of the same size as the raw input - /// would be *extremely* dangerous in terms of memory safety. - // As casting bytes to any type of the same size as the input would be *extremely* dangerous in - // terms of memory safety, rustables only accept to deserialize expressions with variable-size - // data to arrays of bytes, so that the memory layout cannot be invalid. - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { - unsafe { - let ref_len = std::mem::size_of::<[u8; N]>() as u32; - let mut data_len = 0; - let data = sys::nftnl_expr_get( - expr, - sys::NFTNL_EXPR_IMM_DATA as u16, - &mut data_len as *mut u32, - ); - - if data.is_null() { - return Err(DeserializationError::NullPointer); - } else if data_len != ref_len { - return Err(DeserializationError::InvalidDataSize); - } - - let data = *(data as *const [u8; N]); - - let register = Register::from_raw(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_IMM_DREG as u16, - ))?; - - Ok(Immediate { data, register }) - } - } - - // call to the other implementation to generate the expression - fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr { - Immediate { - register: self.register, - data: &self.data as &[u8], - } - .to_expr(rule) +impl Expression for Immediate { + fn get_name() -> &'static str { + "immediate" } } - -#[macro_export] -macro_rules! nft_expr_immediate { - (data $value:expr) => { - $crate::expr::Immediate { - data: $value, - register: $crate::expr::Register::Reg1, - } - }; -} diff --git a/src/expr/log.rs b/src/expr/log.rs index 8d20b48..cc2728e 100644 --- a/src/expr/log.rs +++ b/src/expr/log.rs @@ -1,112 +1,41 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys; -use std::ffi::{CStr, CString}; -use std::os::raw::c_char; -use thiserror::Error; +use rustables_macros::nfnetlink_struct; +use super::Expression; +use crate::{ + error::BuilderError, + sys::{NFTA_LOG_GROUP, NFTA_LOG_PREFIX}, +}; + +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] /// A Log expression will log all packets that match the rule. -#[derive(Debug, PartialEq)] pub struct Log { - pub group: Option<LogGroup>, - pub prefix: Option<LogPrefix>, + #[field(NFTA_LOG_GROUP)] + group: u16, + #[field(NFTA_LOG_PREFIX)] + prefix: String, } -impl Expression for Log { - fn get_raw_name() -> *const sys::libc::c_char { - b"log\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - unsafe { - let mut group = None; - if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_LOG_GROUP as u16) { - group = Some(LogGroup(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_LOG_GROUP as u16, - ) as u16)); - } - let mut prefix = None; - if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16) { - let raw_prefix = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16); - if raw_prefix.is_null() { - return Err(DeserializationError::NullPointer); - } else { - prefix = Some(LogPrefix(CStr::from_ptr(raw_prefix).to_owned())); - } - } - Ok(Log { group, prefix }) +impl Log { + pub fn new(group: Option<u16>, prefix: Option<impl Into<String>>) -> Result<Log, BuilderError> { + let mut res = Log::default(); + if let Some(group) = group { + res.set_group(group); } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(b"log\0" as *const _ as *const c_char)); - if let Some(log_group) = self.group { - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOG_GROUP as u16, log_group.0 as u32); - }; - if let Some(LogPrefix(prefix)) = &self.prefix { - sys::nftnl_expr_set_str(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16, prefix.as_ptr()); - }; + if let Some(prefix) = prefix { + let prefix = prefix.into(); - expr + if prefix.bytes().count() > 127 { + return Err(BuilderError::TooLongLogPrefix); + } + res.set_prefix(prefix); } + Ok(res) } } -#[derive(Error, Debug)] -pub enum LogPrefixError { - #[error("The log prefix string is more than 128 characters long")] - TooLongPrefix, - #[error("The log prefix string contains an invalid Nul character.")] - PrefixContainsANul(#[from] std::ffi::NulError), -} - -/// The NFLOG group that will be assigned to each log line. -#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] -pub struct LogGroup(pub u16); - -/// A prefix that will get prepended to each log line. -#[derive(Debug, Clone, PartialEq)] -pub struct LogPrefix(CString); - -impl LogPrefix { - /// Creates a new LogPrefix from a String. Converts it to CString as needed by nftnl. Note that - /// LogPrefix should not be more than 127 characters long. - pub fn new(prefix: &str) -> Result<Self, LogPrefixError> { - if prefix.chars().count() > 127 { - return Err(LogPrefixError::TooLongPrefix); - } - Ok(LogPrefix(CString::new(prefix)?)) +impl Expression for Log { + fn get_name() -> &'static str { + "log" } } - -#[macro_export] -macro_rules! nft_expr_log { - (group $group:ident prefix $prefix:expr) => { - $crate::expr::Log { - group: $group, - prefix: $prefix, - } - }; - (prefix $prefix:expr) => { - $crate::expr::Log { - group: None, - prefix: $prefix, - } - }; - (group $group:ident) => { - $crate::expr::Log { - group: $group, - prefix: None, - } - }; - () => { - $crate::expr::Log { - group: None, - prefix: None, - } - }; -} diff --git a/src/expr/lookup.rs b/src/expr/lookup.rs index a0cc021..2ef830e 100644 --- a/src/expr/lookup.rs +++ b/src/expr/lookup.rs @@ -1,78 +1,40 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::set::Set; -use crate::sys::{self, libc}; -use std::ffi::{CStr, CString}; -use std::os::raw::c_char; +use rustables_macros::nfnetlink_struct; -#[derive(Debug, PartialEq)] +use super::{Expression, Register}; +use crate::error::BuilderError; +use crate::sys::{NFTA_LOOKUP_DREG, NFTA_LOOKUP_SET, NFTA_LOOKUP_SET_ID, NFTA_LOOKUP_SREG}; +use crate::Set; + +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] pub struct Lookup { - set_name: CString, + #[field(NFTA_LOOKUP_SET)] + set: String, + #[field(NFTA_LOOKUP_SREG)] + sreg: Register, + #[field(NFTA_LOOKUP_DREG)] + dreg: Register, + #[field(NFTA_LOOKUP_SET_ID)] set_id: u32, } impl Lookup { - /// Creates a new lookup entry. May return None if the set has no name. - pub fn new<K>(set: &Set<K>) -> Option<Self> { - set.get_name().map(|set_name| Lookup { - set_name: set_name.to_owned(), - set_id: set.get_id(), - }) - } -} - -impl Expression for Lookup { - fn get_raw_name() -> *const libc::c_char { - b"lookup\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - unsafe { - let set_name = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_LOOKUP_SET as u16); - let set_id = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_LOOKUP_SET_ID as u16); - - if set_name.is_null() { - return Err(DeserializationError::NullPointer); - } - - let set_name = CStr::from_ptr(set_name).to_owned(); - - Ok(Lookup { set_id, set_name }) + /// Creates a new lookup entry. May return BuilderError::MissingSetName if the set has no name. + pub fn new(set: &Set) -> Result<Self, BuilderError> { + let mut res = Lookup::default() + .with_set(set.get_name().ok_or(BuilderError::MissingSetName)?) + .with_sreg(Register::Reg1); + + if let Some(id) = set.get_id() { + res.set_set_id(*id); } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_LOOKUP_SREG as u16, - libc::NFT_REG_1 as u32, - ); - sys::nftnl_expr_set_str( - expr, - sys::NFTNL_EXPR_LOOKUP_SET as u16, - self.set_name.as_ptr() as *const _ as *const c_char, - ); - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOOKUP_SET_ID as u16, self.set_id); - // This code is left here since it's quite likely we need it again when we get further - // if self.reverse { - // sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOOKUP_FLAGS as u16, - // libc::NFT_LOOKUP_F_INV as u32); - // } - - expr - } + Ok(res) } } -#[macro_export] -macro_rules! nft_expr_lookup { - ($set:expr) => { - $crate::expr::Lookup::new($set) - }; +impl Expression for Lookup { + fn get_name() -> &'static str { + "lookup" + } } diff --git a/src/expr/masquerade.rs b/src/expr/masquerade.rs index c1a06de..dce787f 100644 --- a/src/expr/masquerade.rs +++ b/src/expr/masquerade.rs @@ -1,24 +1,20 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys; -use std::os::raw::c_char; +use rustables_macros::nfnetlink_struct; + +use super::Expression; /// Sets the source IP to that of the output interface. -#[derive(Debug, PartialEq)] +#[derive(Default, Debug, PartialEq, Eq)] +#[nfnetlink_struct(nested = true)] pub struct Masquerade; -impl Expression for Masquerade { - fn get_raw_name() -> *const sys::libc::c_char { - b"masq\0" as *const _ as *const c_char - } - - fn from_expr(_expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - Ok(Masquerade) +impl Clone for Masquerade { + fn clone(&self) -> Self { + Masquerade {} } +} - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - try_alloc!(unsafe { sys::nftnl_expr_alloc(Self::get_raw_name()) }) +impl Expression for Masquerade { + fn get_name() -> &'static str { + "masq" } } diff --git a/src/expr/meta.rs b/src/expr/meta.rs index a015f65..3ecb1d1 100644 --- a/src/expr/meta.rs +++ b/src/expr/meta.rs @@ -1,175 +1,62 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys::{self, libc}; -use std::os::raw::c_char; +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; + +use super::{Expression, Register}; +use crate::sys; /// A meta expression refers to meta data associated with a packet. #[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[nfnetlink_enum(u32)] #[non_exhaustive] -pub enum Meta { +pub enum MetaType { /// Packet ethertype protocol (skb->protocol), invalid in OUTPUT. - Protocol, + Protocol = sys::NFT_META_PROTOCOL, /// Packet mark. - Mark { set: bool }, + Mark = sys::NFT_META_MARK, /// Packet input interface index (dev->ifindex). - Iif, + Iif = sys::NFT_META_IIF, /// Packet output interface index (dev->ifindex). - Oif, + Oif = sys::NFT_META_OIF, /// Packet input interface name (dev->name). - IifName, + IifName = sys::NFT_META_IIFNAME, /// Packet output interface name (dev->name). - OifName, + OifName = sys::NFT_META_OIFNAME, /// Packet input interface type (dev->type). - IifType, + IifType = libc::NFT_META_IIFTYPE, /// Packet output interface type (dev->type). - OifType, + OifType = sys::NFT_META_OIFTYPE, /// Originating socket UID (fsuid). - SkUid, + SkUid = sys::NFT_META_SKUID, /// Originating socket GID (fsgid). - SkGid, + SkGid = sys::NFT_META_SKGID, /// Netfilter protocol (Transport layer protocol). - NfProto, + NfProto = sys::NFT_META_NFPROTO, /// Layer 4 protocol number. - L4Proto, + L4Proto = sys::NFT_META_L4PROTO, /// Socket control group (skb->sk->sk_classid). - Cgroup, + Cgroup = sys::NFT_META_CGROUP, /// A 32bit pseudo-random number. - PRandom, + PRandom = sys::NFT_META_PRANDOM, } -impl Meta { - /// Returns the corresponding `NFT_*` constant for this meta expression. - pub fn to_raw_key(&self) -> u32 { - use Meta::*; - match *self { - Protocol => libc::NFT_META_PROTOCOL as u32, - Mark { .. } => libc::NFT_META_MARK as u32, - Iif => libc::NFT_META_IIF as u32, - Oif => libc::NFT_META_OIF as u32, - IifName => libc::NFT_META_IIFNAME as u32, - OifName => libc::NFT_META_OIFNAME as u32, - IifType => libc::NFT_META_IIFTYPE as u32, - OifType => libc::NFT_META_OIFTYPE as u32, - SkUid => libc::NFT_META_SKUID as u32, - SkGid => libc::NFT_META_SKGID as u32, - NfProto => libc::NFT_META_NFPROTO as u32, - L4Proto => libc::NFT_META_L4PROTO as u32, - Cgroup => libc::NFT_META_CGROUP as u32, - PRandom => libc::NFT_META_PRANDOM as u32, - } - } +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] +pub struct Meta { + #[field(sys::NFTA_META_DREG)] + dreg: Register, + #[field(sys::NFTA_META_KEY)] + key: MetaType, + #[field(sys::NFTA_META_SREG)] + sreg: Register, +} - fn from_raw(val: u32) -> Result<Self, DeserializationError> { - match val as i32 { - libc::NFT_META_PROTOCOL => Ok(Self::Protocol), - libc::NFT_META_MARK => Ok(Self::Mark { set: false }), - libc::NFT_META_IIF => Ok(Self::Iif), - libc::NFT_META_OIF => Ok(Self::Oif), - libc::NFT_META_IIFNAME => Ok(Self::IifName), - libc::NFT_META_OIFNAME => Ok(Self::OifName), - libc::NFT_META_IIFTYPE => Ok(Self::IifType), - libc::NFT_META_OIFTYPE => Ok(Self::OifType), - libc::NFT_META_SKUID => Ok(Self::SkUid), - libc::NFT_META_SKGID => Ok(Self::SkGid), - libc::NFT_META_NFPROTO => Ok(Self::NfProto), - libc::NFT_META_L4PROTO => Ok(Self::L4Proto), - libc::NFT_META_CGROUP => Ok(Self::Cgroup), - libc::NFT_META_PRANDOM => Ok(Self::PRandom), - _ => Err(DeserializationError::InvalidValue), - } +impl Meta { + pub fn new(ty: MetaType) -> Self { + Meta::default().with_dreg(Register::Reg1).with_key(ty) } } impl Expression for Meta { - fn get_raw_name() -> *const libc::c_char { - b"meta\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - unsafe { - let mut ret = Self::from_raw(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_META_KEY as u16, - ))?; - - if let Self::Mark { ref mut set } = ret { - *set = sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_META_SREG as u16); - } - - Ok(ret) - } + fn get_name() -> &'static str { + "meta" } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - if let Meta::Mark { set: true } = self { - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_META_SREG as u16, - libc::NFT_REG_1 as u32, - ); - } else { - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_META_DREG as u16, - libc::NFT_REG_1 as u32, - ); - } - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_META_KEY as u16, self.to_raw_key()); - expr - } - } -} - -#[macro_export] -macro_rules! nft_expr_meta { - (proto) => { - $crate::expr::Meta::Protocol - }; - (mark set) => { - $crate::expr::Meta::Mark { set: true } - }; - (mark) => { - $crate::expr::Meta::Mark { set: false } - }; - (iif) => { - $crate::expr::Meta::Iif - }; - (oif) => { - $crate::expr::Meta::Oif - }; - (iifname) => { - $crate::expr::Meta::IifName - }; - (oifname) => { - $crate::expr::Meta::OifName - }; - (iiftype) => { - $crate::expr::Meta::IifType - }; - (oiftype) => { - $crate::expr::Meta::OifType - }; - (skuid) => { - $crate::expr::Meta::SkUid - }; - (skgid) => { - $crate::expr::Meta::SkGid - }; - (nfproto) => { - $crate::expr::Meta::NfProto - }; - (l4proto) => { - $crate::expr::Meta::L4Proto - }; - (cgroup) => { - $crate::expr::Meta::Cgroup - }; - (random) => { - $crate::expr::Meta::PRandom - }; } diff --git a/src/expr/mod.rs b/src/expr/mod.rs index dc59507..058b0cb 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -3,14 +3,14 @@ //! //! [`Rule`]: struct.Rule.html -use std::borrow::Cow; -use std::net::IpAddr; -use std::net::Ipv4Addr; -use std::net::Ipv6Addr; +use std::fmt::Debug; -use super::rule::Rule; -use crate::sys::{self, libc}; -use thiserror::Error; +use rustables_macros::nfnetlink_struct; + +use crate::error::DecodeError; +use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable}; +use crate::parser_impls::NfNetlinkList; +use crate::sys::{self, NFTA_EXPR_DATA, NFTA_EXPR_NAME}; mod bitwise; pub use self::bitwise::*; @@ -46,7 +46,7 @@ mod payload; pub use self::payload::*; mod reject; -pub use self::reject::{IcmpCode, Reject}; +pub use self::reject::{IcmpCode, Reject, RejectType}; mod register; pub use self::register::Register; @@ -54,189 +54,161 @@ pub use self::register::Register; mod verdict; pub use self::verdict::*; -mod wrapper; -pub use self::wrapper::ExpressionWrapper; - -#[derive(Debug, Error)] -pub enum DeserializationError { - #[error("The expected expression type doesn't match the name of the raw expression")] - /// The expected expression type doesn't match the name of the raw expression. - InvalidExpressionKind, - - #[error("Deserializing the requested type isn't implemented yet")] - /// Deserializing the requested type isn't implemented yet. - NotImplemented, - - #[error("The expression value cannot be deserialized to the requested type")] - /// The expression value cannot be deserialized to the requested type. - InvalidValue, - - #[error("A pointer was null while a non-null pointer was expected")] - /// A pointer was null while a non-null pointer was expected. - NullPointer, - - #[error( - "The size of a raw value was incoherent with the expected type of the deserialized value" - )] - /// The size of a raw value was incoherent with the expected type of the deserialized value/ - InvalidDataSize, - - #[error(transparent)] - /// Couldn't find a matching protocol. - InvalidProtolFamily(#[from] super::InvalidProtocolFamily), -} - -/// Trait for every safe wrapper of an nftables expression. pub trait Expression { - /// Returns the raw name used by nftables to identify the rule. - fn get_raw_name() -> *const libc::c_char; - - /// Try to parse the expression from a raw nftables expression, returning a - /// [DeserializationError] if the attempted parsing failed. - fn from_expr(_expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - Err(DeserializationError::NotImplemented) - } - - /// Allocates and returns the low level `nftnl_expr` representation of this expression. The - /// caller to this method is responsible for freeing the expression. - fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr; + fn get_name() -> &'static str; } -/// A type that can be converted into a byte buffer. -pub trait ToSlice { - /// Returns the data this type represents. - fn to_slice(&self) -> Cow<'_, [u8]>; +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(nested = true, derive_decoder = false)] +pub struct RawExpression { + #[field(NFTA_EXPR_NAME)] + name: String, + #[field(NFTA_EXPR_DATA)] + data: ExpressionVariant, } -impl<'a> ToSlice for &'a [u8] { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Borrowed(self) +impl<T> From<T> for RawExpression +where + T: Expression, + ExpressionVariant: From<T>, +{ + fn from(val: T) -> Self { + RawExpression::default() + .with_name(T::get_name()) + .with_data(ExpressionVariant::from(val)) } } -impl<'a> ToSlice for &'a [u16] { - fn to_slice(&self) -> Cow<'_, [u8]> { - let ptr = self.as_ptr() as *const u8; - let len = self.len() * 2; - Cow::Borrowed(unsafe { std::slice::from_raw_parts(ptr, len) }) - } -} - -impl ToSlice for IpAddr { - fn to_slice(&self) -> Cow<'_, [u8]> { - match *self { - IpAddr::V4(ref addr) => addr.to_slice(), - IpAddr::V6(ref addr) => addr.to_slice(), +macro_rules! create_expr_variant { + ($enum:ident $(, [$name:ident, $type:ty])+) => { + #[derive(Debug, Clone, PartialEq, Eq)] + pub enum $enum { + $( + $name($type), + )+ } - } -} -impl ToSlice for Ipv4Addr { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Owned(self.octets().to_vec()) - } -} - -impl ToSlice for Ipv6Addr { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Owned(self.octets().to_vec()) - } -} + impl $crate::nlmsg::NfNetlinkAttribute for $enum { + fn is_nested(&self) -> bool { + true + } + + fn get_size(&self) -> usize { + match self { + $( + $enum::$name(val) => val.get_size(), + )+ + } + } + + unsafe fn write_payload(&self, addr: *mut u8) { + match self { + $( + $enum::$name(val) => val.write_payload(addr), + )+ + } + } + } -impl ToSlice for u8 { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Owned(vec![*self]) - } + $( + impl From<$type> for $enum { + fn from(val: $type) -> Self { + $enum::$name(val) + } + } + )+ + + impl $crate::nlmsg::AttributeDecoder for RawExpression { + fn decode_attribute( + &mut self, + attr_type: u16, + buf: &[u8], + ) -> Result<(), $crate::error::DecodeError> { + debug!("Decoding attribute {} in an expression", attr_type); + match attr_type { + x if x == sys::NFTA_EXPR_NAME => { + debug!("Calling {}::deserialize()", std::any::type_name::<String>()); + let (val, remaining) = String::deserialize(buf)?; + if remaining.len() != 0 { + return Err($crate::error::DecodeError::InvalidDataSize); + } + self.name = Some(val); + Ok(()) + }, + x if x == sys::NFTA_EXPR_DATA => { + // we can assume we have already the name parsed, as that's how we identify the + // type of expression + let name = self.name.as_ref() + .ok_or($crate::error::DecodeError::MissingExpressionName)?; + match name { + $( + x if x == <$type>::get_name() => { + debug!("Calling {}::deserialize()", std::any::type_name::<$type>()); + let (res, remaining) = <$type>::deserialize(buf)?; + if remaining.len() != 0 { + return Err($crate::error::DecodeError::InvalidDataSize); + } + self.data = Some(ExpressionVariant::from(res)); + Ok(()) + }, + )+ + name => { + info!("Unrecognized expression '{}', generating an ExpressionRaw", name); + self.data = Some(ExpressionVariant::ExpressionRaw(ExpressionRaw::deserialize(buf)?.0)); + Ok(()) + } + } + }, + _ => Err(DecodeError::UnsupportedAttributeType(attr_type)), + } + } + } + }; } -impl ToSlice for u16 { - fn to_slice(&self) -> Cow<'_, [u8]> { - let b0 = (*self & 0x00ff) as u8; - let b1 = (*self >> 8) as u8; - Cow::Owned(vec![b0, b1]) +create_expr_variant!( + ExpressionVariant, + [Bitwise, Bitwise], + [Cmp, Cmp], + [Conntrack, Conntrack], + [Counter, Counter], + [ExpressionRaw, ExpressionRaw], + [Immediate, Immediate], + [Log, Log], + [Lookup, Lookup], + [Masquerade, Masquerade], + [Meta, Meta], + [Nat, Nat], + [Payload, Payload], + [Reject, Reject] +); + +pub type ExpressionList = NfNetlinkList<RawExpression>; + +// default type for expressions that we do not handle yet +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ExpressionRaw(Vec<u8>); + +impl NfNetlinkAttribute for ExpressionRaw { + fn get_size(&self) -> usize { + self.0.get_size() } -} -impl ToSlice for u32 { - fn to_slice(&self) -> Cow<'_, [u8]> { - let b0 = *self as u8; - let b1 = (*self >> 8) as u8; - let b2 = (*self >> 16) as u8; - let b3 = (*self >> 24) as u8; - Cow::Owned(vec![b0, b1, b2, b3]) + unsafe fn write_payload(&self, addr: *mut u8) { + self.0.write_payload(addr); } } -impl ToSlice for i32 { - fn to_slice(&self) -> Cow<'_, [u8]> { - let b0 = *self as u8; - let b1 = (*self >> 8) as u8; - let b2 = (*self >> 16) as u8; - let b3 = (*self >> 24) as u8; - Cow::Owned(vec![b0, b1, b2, b3]) +impl NfNetlinkDeserializable for ExpressionRaw { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok((ExpressionRaw(buf.to_vec()), &[])) } } -impl<'a> ToSlice for &'a str { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::from(self.as_bytes()) +// Because we loose the name of the expression when parsing, this is the only expression +// where deserializing a message and then reserializing it is invalid +impl Expression for ExpressionRaw { + fn get_name() -> &'static str { + "unknown_expression" } } - -#[macro_export(local_inner_macros)] -macro_rules! nft_expr { - (bitwise mask $mask:expr,xor $xor:expr) => { - nft_expr_bitwise!(mask $mask, xor $xor) - }; - (cmp $op:tt $data:expr) => { - nft_expr_cmp!($op $data) - }; - (counter) => { - $crate::expr::Counter { nb_bytes: 0, nb_packets: 0} - }; - (ct $key:ident set) => { - nft_expr_ct!($key set) - }; - (ct $key:ident) => { - nft_expr_ct!($key) - }; - (immediate $expr:ident $value:expr) => { - nft_expr_immediate!($expr $value) - }; - (log group $group:ident prefix $prefix:expr) => { - nft_expr_log!(group $group prefix $prefix) - }; - (log group $group:ident) => { - nft_expr_log!(group $group) - }; - (log prefix $prefix:expr) => { - nft_expr_log!(prefix $prefix) - }; - (log) => { - nft_expr_log!() - }; - (lookup $set:expr) => { - nft_expr_lookup!($set) - }; - (masquerade) => { - $crate::expr::Masquerade - }; - (meta $expr:ident set) => { - nft_expr_meta!($expr set) - }; - (meta $expr:ident) => { - nft_expr_meta!($expr) - }; - (payload $proto:ident $field:ident) => { - nft_expr_payload!($proto $field) - }; - (verdict $verdict:ident) => { - nft_expr_verdict!($verdict) - }; - (verdict $verdict:ident $chain:expr) => { - nft_expr_verdict!($verdict $chain) - }; -} diff --git a/src/expr/nat.rs b/src/expr/nat.rs index ce6b881..406b2e6 100644 --- a/src/expr/nat.rs +++ b/src/expr/nat.rs @@ -1,99 +1,37 @@ -use super::{DeserializationError, Expression, Register, Rule}; -use crate::ProtoFamily; -use crate::sys::{self, libc}; -use std::{convert::TryFrom, os::raw::c_char}; +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; + +use super::{Expression, Register}; +use crate::{ + sys::{self, NFT_NAT_DNAT, NFT_NAT_SNAT}, + ProtocolFamily, +}; #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[repr(i32)] +#[nfnetlink_enum(i32)] pub enum NatType { /// Source NAT. Changes the source address of a packet. - SNat = libc::NFT_NAT_SNAT, + SNat = NFT_NAT_SNAT, /// Destination NAT. Changes the destination address of a packet. - DNat = libc::NFT_NAT_DNAT, -} - -impl NatType { - fn from_raw(val: u32) -> Result<Self, DeserializationError> { - match val as i32 { - libc::NFT_NAT_SNAT => Ok(NatType::SNat), - libc::NFT_NAT_DNAT => Ok(NatType::DNat), - _ => Err(DeserializationError::InvalidValue), - } - } + DNat = NFT_NAT_DNAT, } /// A source or destination NAT statement. Modifies the source or destination address (and possibly /// port) of packets. -#[derive(Debug, PartialEq)] +#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[nfnetlink_struct(nested = true)] pub struct Nat { + #[field(sys::NFTA_NAT_TYPE)] pub nat_type: NatType, - pub family: ProtoFamily, + #[field(sys::NFTA_NAT_FAMILY)] + pub family: ProtocolFamily, + #[field(sys::NFTA_NAT_REG_ADDR_MIN)] pub ip_register: Register, - pub port_register: Option<Register>, + #[field(sys::NFTA_NAT_REG_PROTO_MIN)] + pub port_register: Register, } impl Expression for Nat { - fn get_raw_name() -> *const libc::c_char { - b"nat\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - unsafe { - let nat_type = NatType::from_raw(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_NAT_TYPE as u16, - ))?; - - let family = ProtoFamily::try_from(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_NAT_FAMILY as u16, - ) as i32)?; - - let ip_register = Register::from_raw(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_NAT_REG_ADDR_MIN as u16, - ))?; - - let mut port_register = None; - if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_NAT_REG_PROTO_MIN as u16) { - port_register = Some(Register::from_raw(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_NAT_REG_PROTO_MIN as u16, - ))?); - } - - Ok(Nat { - ip_register, - nat_type, - family, - port_register, - }) - } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - let expr = try_alloc!(unsafe { sys::nftnl_expr_alloc(Self::get_raw_name()) }); - - unsafe { - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_NAT_TYPE as u16, self.nat_type as u32); - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_NAT_FAMILY as u16, self.family as u32); - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_NAT_REG_ADDR_MIN as u16, - self.ip_register.to_raw(), - ); - if let Some(port_register) = self.port_register { - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_NAT_REG_PROTO_MIN as u16, - port_register.to_raw(), - ); - } - } - - expr + fn get_name() -> &'static str { + "nat" } } diff --git a/src/expr/payload.rs b/src/expr/payload.rs index a108fe8..d0b2cea 100644 --- a/src/expr/payload.rs +++ b/src/expr/payload.rs @@ -1,128 +1,96 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys::{self, libc}; -use std::os::raw::c_char; +use rustables_macros::nfnetlink_struct; -pub trait HeaderField { - fn offset(&self) -> u32; - fn len(&self) -> u32; +use super::{Expression, Register}; +use crate::{ + error::DecodeError, + sys::{self, NFT_PAYLOAD_LL_HEADER, NFT_PAYLOAD_NETWORK_HEADER, NFT_PAYLOAD_TRANSPORT_HEADER}, +}; + +/// Payload expressions refer to data from the packet's payload. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +#[nfnetlink_struct(nested = true)] +pub struct Payload { + #[field(sys::NFTA_PAYLOAD_DREG)] + dreg: Register, + #[field(sys::NFTA_PAYLOAD_BASE)] + base: u32, + #[field(sys::NFTA_PAYLOAD_OFFSET)] + offset: u32, + #[field(sys::NFTA_PAYLOAD_LEN)] + len: u32, + #[field(sys::NFTA_PAYLOAD_SREG)] + sreg: Register, +} + +impl Expression for Payload { + fn get_name() -> &'static str { + "payload" + } } /// Payload expressions refer to data from the packet's payload. #[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub enum Payload { +pub enum HighLevelPayload { LinkLayer(LLHeaderField), Network(NetworkHeaderField), Transport(TransportHeaderField), } -impl Payload { - pub fn build(&self) -> RawPayload { +impl HighLevelPayload { + pub fn build(&self) -> Payload { match *self { - Payload::LinkLayer(ref f) => RawPayload::LinkLayer(RawPayloadData { - offset: f.offset(), - len: f.len(), - }), - Payload::Network(ref f) => RawPayload::Network(RawPayloadData { - offset: f.offset(), - len: f.len(), - }), - Payload::Transport(ref f) => RawPayload::Transport(RawPayloadData { - offset: f.offset(), - len: f.len(), - }), + HighLevelPayload::LinkLayer(ref f) => Payload::default() + .with_base(NFT_PAYLOAD_LL_HEADER) + .with_offset(f.offset()) + .with_len(f.len()), + HighLevelPayload::Network(ref f) => Payload::default() + .with_base(NFT_PAYLOAD_NETWORK_HEADER) + .with_offset(f.offset()) + .with_len(f.len()), + HighLevelPayload::Transport(ref f) => Payload::default() + .with_base(NFT_PAYLOAD_TRANSPORT_HEADER) + .with_offset(f.offset()) + .with_len(f.len()), } + .with_dreg(Register::Reg1) } } -impl Expression for Payload { - fn get_raw_name() -> *const libc::c_char { - RawPayload::get_raw_name() - } - - fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr { - self.build().to_expr(rule) - } -} - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct RawPayloadData { - offset: u32, - len: u32, -} - -/// Because deserializing a `Payload` expression is not possible (there is not enough information -/// in the expression itself), this enum should be used to deserialize payloads. +/// Payload expressions refer to data from the packet's payload. #[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub enum RawPayload { - LinkLayer(RawPayloadData), - Network(RawPayloadData), - Transport(RawPayloadData), +pub enum PayloadType { + LinkLayer(LLHeaderField), + Network, + Transport, } -impl RawPayload { - fn base(&self) -> u32 { - match self { - Self::LinkLayer(_) => libc::NFT_PAYLOAD_LL_HEADER as u32, - Self::Network(_) => libc::NFT_PAYLOAD_NETWORK_HEADER as u32, - Self::Transport(_) => libc::NFT_PAYLOAD_TRANSPORT_HEADER as u32, +impl PayloadType { + pub fn parse_from_payload(raw: &Payload) -> Result<Self, DecodeError> { + if raw.base.is_none() { + return Err(DecodeError::PayloadMissingBase); } - } -} - -impl HeaderField for RawPayload { - fn offset(&self) -> u32 { - match self { - Self::LinkLayer(ref f) | Self::Network(ref f) | Self::Transport(ref f) => f.offset, + if raw.len.is_none() { + return Err(DecodeError::PayloadMissingLen); } - } - - fn len(&self) -> u32 { - match self { - Self::LinkLayer(ref f) | Self::Network(ref f) | Self::Transport(ref f) => f.len, + if raw.offset.is_none() { + return Err(DecodeError::PayloadMissingOffset); } + Ok(match raw.base { + Some(NFT_PAYLOAD_LL_HEADER) => PayloadType::LinkLayer(LLHeaderField::from_raw_data( + raw.offset.unwrap(), + raw.len.unwrap(), + )?), + Some(NFT_PAYLOAD_NETWORK_HEADER) => PayloadType::Network, + Some(NFT_PAYLOAD_TRANSPORT_HEADER) => PayloadType::Transport, + Some(v) => return Err(DecodeError::UnknownPayloadType(v)), + None => return Err(DecodeError::PayloadMissingBase), + }) } } -impl Expression for RawPayload { - fn get_raw_name() -> *const libc::c_char { - b"payload\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { - unsafe { - let base = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_BASE as u16); - let offset = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_OFFSET as u16); - let len = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_LEN as u16); - match base as i32 { - libc::NFT_PAYLOAD_LL_HEADER => Ok(Self::LinkLayer(RawPayloadData { offset, len })), - libc::NFT_PAYLOAD_NETWORK_HEADER => { - Ok(Self::Network(RawPayloadData { offset, len })) - } - libc::NFT_PAYLOAD_TRANSPORT_HEADER => { - Ok(Self::Transport(RawPayloadData { offset, len })) - } - - _ => return Err(DeserializationError::InvalidValue), - } - } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_PAYLOAD_BASE as u16, self.base()); - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_PAYLOAD_OFFSET as u16, self.offset()); - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_PAYLOAD_LEN as u16, self.len()); - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_PAYLOAD_DREG as u16, - libc::NFT_REG_1 as u32, - ); - - expr - } - } +pub trait HeaderField { + fn offset(&self) -> u32; + fn len(&self) -> u32; } #[derive(Debug, Copy, Clone, Eq, PartialEq)] @@ -154,58 +122,52 @@ impl HeaderField for LLHeaderField { } impl LLHeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { - let off = data.offset; - let len = data.len; - - if off == 0 && len == 6 { - Ok(Self::Daddr) - } else if off == 6 && len == 6 { - Ok(Self::Saddr) - } else if off == 12 && len == 2 { - Ok(Self::EtherType) - } else { - Err(DeserializationError::InvalidValue) - } + pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> { + Ok(match (offset, len) { + (0, 6) => Self::Daddr, + (6, 6) => Self::Saddr, + (12, 2) => Self::EtherType, + _ => return Err(DecodeError::UnknownLinkLayerHeaderField(offset, len)), + }) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum NetworkHeaderField { - Ipv4(Ipv4HeaderField), - Ipv6(Ipv6HeaderField), + IPv4(IPv4HeaderField), + IPv6(IPv6HeaderField), } impl HeaderField for NetworkHeaderField { fn offset(&self) -> u32 { use self::NetworkHeaderField::*; match *self { - Ipv4(ref f) => f.offset(), - Ipv6(ref f) => f.offset(), + IPv4(ref f) => f.offset(), + IPv6(ref f) => f.offset(), } } fn len(&self) -> u32 { use self::NetworkHeaderField::*; match *self { - Ipv4(ref f) => f.len(), - Ipv6(ref f) => f.len(), + IPv4(ref f) => f.len(), + IPv6(ref f) => f.len(), } } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[non_exhaustive] -pub enum Ipv4HeaderField { +pub enum IPv4HeaderField { Ttl, Protocol, Saddr, Daddr, } -impl HeaderField for Ipv4HeaderField { +impl HeaderField for IPv4HeaderField { fn offset(&self) -> u32 { - use self::Ipv4HeaderField::*; + use self::IPv4HeaderField::*; match *self { Ttl => 8, Protocol => 9, @@ -215,7 +177,7 @@ impl HeaderField for Ipv4HeaderField { } fn len(&self) -> u32 { - use self::Ipv4HeaderField::*; + use self::IPv4HeaderField::*; match *self { Ttl => 1, Protocol => 1, @@ -225,37 +187,30 @@ impl HeaderField for Ipv4HeaderField { } } -impl Ipv4HeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { - let off = data.offset; - let len = data.len; - - if off == 8 && len == 1 { - Ok(Self::Ttl) - } else if off == 9 && len == 1 { - Ok(Self::Protocol) - } else if off == 12 && len == 4 { - Ok(Self::Saddr) - } else if off == 16 && len == 4 { - Ok(Self::Daddr) - } else { - Err(DeserializationError::InvalidValue) - } +impl IPv4HeaderField { + pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> { + Ok(match (offset, len) { + (8, 1) => Self::Ttl, + (9, 1) => Self::Protocol, + (12, 4) => Self::Saddr, + (16, 4) => Self::Daddr, + _ => return Err(DecodeError::UnknownIPv4HeaderField(offset, len)), + }) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[non_exhaustive] -pub enum Ipv6HeaderField { +pub enum IPv6HeaderField { NextHeader, HopLimit, Saddr, Daddr, } -impl HeaderField for Ipv6HeaderField { +impl HeaderField for IPv6HeaderField { fn offset(&self) -> u32 { - use self::Ipv6HeaderField::*; + use self::IPv6HeaderField::*; match *self { NextHeader => 6, HopLimit => 7, @@ -265,7 +220,7 @@ impl HeaderField for Ipv6HeaderField { } fn len(&self) -> u32 { - use self::Ipv6HeaderField::*; + use self::IPv6HeaderField::*; match *self { NextHeader => 1, HopLimit => 1, @@ -275,31 +230,24 @@ impl HeaderField for Ipv6HeaderField { } } -impl Ipv6HeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { - let off = data.offset; - let len = data.len; - - if off == 6 && len == 1 { - Ok(Self::NextHeader) - } else if off == 7 && len == 1 { - Ok(Self::HopLimit) - } else if off == 8 && len == 16 { - Ok(Self::Saddr) - } else if off == 24 && len == 16 { - Ok(Self::Daddr) - } else { - Err(DeserializationError::InvalidValue) - } +impl IPv6HeaderField { + pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> { + Ok(match (offset, len) { + (6, 1) => Self::NextHeader, + (7, 1) => Self::HopLimit, + (8, 16) => Self::Saddr, + (24, 16) => Self::Daddr, + _ => return Err(DecodeError::UnknownIPv6HeaderField(offset, len)), + }) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[non_exhaustive] pub enum TransportHeaderField { - Tcp(TcpHeaderField), - Udp(UdpHeaderField), - Icmpv6(Icmpv6HeaderField), + Tcp(TCPHeaderField), + Udp(UDPHeaderField), + ICMPv6(ICMPv6HeaderField), } impl HeaderField for TransportHeaderField { @@ -308,7 +256,7 @@ impl HeaderField for TransportHeaderField { match *self { Tcp(ref f) => f.offset(), Udp(ref f) => f.offset(), - Icmpv6(ref f) => f.offset(), + ICMPv6(ref f) => f.offset(), } } @@ -317,21 +265,21 @@ impl HeaderField for TransportHeaderField { match *self { Tcp(ref f) => f.len(), Udp(ref f) => f.len(), - Icmpv6(ref f) => f.len(), + ICMPv6(ref f) => f.len(), } } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[non_exhaustive] -pub enum TcpHeaderField { +pub enum TCPHeaderField { Sport, Dport, } -impl HeaderField for TcpHeaderField { +impl HeaderField for TCPHeaderField { fn offset(&self) -> u32 { - use self::TcpHeaderField::*; + use self::TCPHeaderField::*; match *self { Sport => 0, Dport => 2, @@ -339,7 +287,7 @@ impl HeaderField for TcpHeaderField { } fn len(&self) -> u32 { - use self::TcpHeaderField::*; + use self::TCPHeaderField::*; match *self { Sport => 2, Dport => 2, @@ -347,32 +295,27 @@ impl HeaderField for TcpHeaderField { } } -impl TcpHeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { - let off = data.offset; - let len = data.len; - - if off == 0 && len == 2 { - Ok(Self::Sport) - } else if off == 2 && len == 2 { - Ok(Self::Dport) - } else { - Err(DeserializationError::InvalidValue) - } +impl TCPHeaderField { + pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> { + Ok(match (offset, len) { + (0, 2) => Self::Sport, + (2, 2) => Self::Dport, + _ => return Err(DecodeError::UnknownTCPHeaderField(offset, len)), + }) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[non_exhaustive] -pub enum UdpHeaderField { +pub enum UDPHeaderField { Sport, Dport, Len, } -impl HeaderField for UdpHeaderField { +impl HeaderField for UDPHeaderField { fn offset(&self) -> u32 { - use self::UdpHeaderField::*; + use self::UDPHeaderField::*; match *self { Sport => 0, Dport => 2, @@ -381,7 +324,7 @@ impl HeaderField for UdpHeaderField { } fn len(&self) -> u32 { - use self::UdpHeaderField::*; + use self::UDPHeaderField::*; match *self { Sport => 2, Dport => 2, @@ -390,34 +333,28 @@ impl HeaderField for UdpHeaderField { } } -impl UdpHeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { - let off = data.offset; - let len = data.len; - - if off == 0 && len == 2 { - Ok(Self::Sport) - } else if off == 2 && len == 2 { - Ok(Self::Dport) - } else if off == 4 && len == 2 { - Ok(Self::Len) - } else { - Err(DeserializationError::InvalidValue) - } +impl UDPHeaderField { + pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> { + Ok(match (offset, len) { + (0, 2) => Self::Sport, + (2, 2) => Self::Dport, + (4, 2) => Self::Len, + _ => return Err(DecodeError::UnknownUDPHeaderField(offset, len)), + }) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[non_exhaustive] -pub enum Icmpv6HeaderField { +pub enum ICMPv6HeaderField { Type, Code, Checksum, } -impl HeaderField for Icmpv6HeaderField { +impl HeaderField for ICMPv6HeaderField { fn offset(&self) -> u32 { - use self::Icmpv6HeaderField::*; + use self::ICMPv6HeaderField::*; match *self { Type => 0, Code => 1, @@ -426,7 +363,7 @@ impl HeaderField for Icmpv6HeaderField { } fn len(&self) -> u32 { - use self::Icmpv6HeaderField::*; + use self::ICMPv6HeaderField::*; match *self { Type => 1, Code => 1, @@ -435,97 +372,13 @@ impl HeaderField for Icmpv6HeaderField { } } -impl Icmpv6HeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { - let off = data.offset; - let len = data.len; - - if off == 0 && len == 1 { - Ok(Self::Type) - } else if off == 1 && len == 1 { - Ok(Self::Code) - } else if off == 2 && len == 2 { - Ok(Self::Checksum) - } else { - Err(DeserializationError::InvalidValue) - } +impl ICMPv6HeaderField { + pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> { + Ok(match (offset, len) { + (0, 1) => Self::Type, + (1, 1) => Self::Code, + (2, 2) => Self::Checksum, + _ => return Err(DecodeError::UnknownICMPv6HeaderField(offset, len)), + }) } } - -#[macro_export(local_inner_macros)] -macro_rules! nft_expr_payload { - (@ipv4_field ttl) => { - $crate::expr::Ipv4HeaderField::Ttl - }; - (@ipv4_field protocol) => { - $crate::expr::Ipv4HeaderField::Protocol - }; - (@ipv4_field saddr) => { - $crate::expr::Ipv4HeaderField::Saddr - }; - (@ipv4_field daddr) => { - $crate::expr::Ipv4HeaderField::Daddr - }; - - (@ipv6_field nextheader) => { - $crate::expr::Ipv6HeaderField::NextHeader - }; - (@ipv6_field hoplimit) => { - $crate::expr::Ipv6HeaderField::HopLimit - }; - (@ipv6_field saddr) => { - $crate::expr::Ipv6HeaderField::Saddr - }; - (@ipv6_field daddr) => { - $crate::expr::Ipv6HeaderField::Daddr - }; - - (@tcp_field sport) => { - $crate::expr::TcpHeaderField::Sport - }; - (@tcp_field dport) => { - $crate::expr::TcpHeaderField::Dport - }; - - (@udp_field sport) => { - $crate::expr::UdpHeaderField::Sport - }; - (@udp_field dport) => { - $crate::expr::UdpHeaderField::Dport - }; - (@udp_field len) => { - $crate::expr::UdpHeaderField::Len - }; - - (ethernet daddr) => { - $crate::expr::Payload::LinkLayer($crate::expr::LLHeaderField::Daddr) - }; - (ethernet saddr) => { - $crate::expr::Payload::LinkLayer($crate::expr::LLHeaderField::Saddr) - }; - (ethernet ethertype) => { - $crate::expr::Payload::LinkLayer($crate::expr::LLHeaderField::EtherType) - }; - - (ipv4 $field:ident) => { - $crate::expr::Payload::Network($crate::expr::NetworkHeaderField::Ipv4( - nft_expr_payload!(@ipv4_field $field), - )) - }; - (ipv6 $field:ident) => { - $crate::expr::Payload::Network($crate::expr::NetworkHeaderField::Ipv6( - nft_expr_payload!(@ipv6_field $field), - )) - }; - - (tcp $field:ident) => { - $crate::expr::Payload::Transport($crate::expr::TransportHeaderField::Tcp( - nft_expr_payload!(@tcp_field $field), - )) - }; - (udp $field:ident) => { - $crate::expr::Payload::Transport($crate::expr::TransportHeaderField::Udp( - nft_expr_payload!(@udp_field $field), - )) - }; -} diff --git a/src/expr/register.rs b/src/expr/register.rs index a05af7e..9cc1bee 100644 --- a/src/expr/register.rs +++ b/src/expr/register.rs @@ -1,34 +1,17 @@ use std::fmt::Debug; -use crate::sys::libc; +use rustables_macros::nfnetlink_enum; -use super::DeserializationError; +use crate::sys::{NFT_REG_1, NFT_REG_2, NFT_REG_3, NFT_REG_4, NFT_REG_VERDICT}; /// A netfilter data register. The expressions store and read data to and from these when /// evaluating rule statements. #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[repr(i32)] +#[nfnetlink_enum(u32)] pub enum Register { - Verdict = libc::NFT_REG_VERDICT, - Reg1 = libc::NFT_REG_1, - Reg2 = libc::NFT_REG_2, - Reg3 = libc::NFT_REG_3, - Reg4 = libc::NFT_REG_4, -} - -impl Register { - pub fn to_raw(self) -> u32 { - self as u32 - } - - pub fn from_raw(val: u32) -> Result<Self, DeserializationError> { - match val as i32 { - libc::NFT_REG_VERDICT => Ok(Self::Verdict), - libc::NFT_REG_1 => Ok(Self::Reg1), - libc::NFT_REG_2 => Ok(Self::Reg2), - libc::NFT_REG_3 => Ok(Self::Reg3), - libc::NFT_REG_4 => Ok(Self::Reg4), - _ => Err(DeserializationError::InvalidValue), - } - } + Verdict = NFT_REG_VERDICT, + Reg1 = NFT_REG_1, + Reg2 = NFT_REG_2, + Reg3 = NFT_REG_3, + Reg4 = NFT_REG_4, } diff --git a/src/expr/reject.rs b/src/expr/reject.rs index 19752ce..83fd843 100644 --- a/src/expr/reject.rs +++ b/src/expr/reject.rs @@ -1,95 +1,40 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::ProtoFamily; -use crate::sys::{self, libc::{self, c_char}}; +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; -/// A reject expression that defines the type of rejection message sent when discarding a packet. -#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] -pub enum Reject { - /// Returns an ICMP unreachable packet. - Icmp(IcmpCode), - /// Rejects by sending a TCP RST packet. - TcpRst, -} +use crate::sys; -impl Reject { - fn to_raw(&self, family: ProtoFamily) -> u32 { - use libc::*; - let value = match *self { - Self::Icmp(..) => match family { - ProtoFamily::Bridge | ProtoFamily::Inet => NFT_REJECT_ICMPX_UNREACH, - _ => NFT_REJECT_ICMP_UNREACH, - }, - Self::TcpRst => NFT_REJECT_TCP_RST, - }; - value as u32 - } -} +use super::Expression; impl Expression for Reject { - fn get_raw_name() -> *const libc::c_char { - b"reject\0" as *const _ as *const c_char + fn get_name() -> &'static str { + "reject" } +} - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - unsafe { - if sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_REJECT_TYPE as u16) - == libc::NFT_REJECT_TCP_RST as u32 - { - Ok(Self::TcpRst) - } else { - Ok(Self::Icmp(IcmpCode::from_raw(sys::nftnl_expr_get_u8( - expr, - sys::NFTNL_EXPR_REJECT_CODE as u16, - ))?)) - } - } - } - - fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr { - let family = rule.get_chain().get_table().get_family(); - - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_REJECT_TYPE as u16, - self.to_raw(family), - ); - - let reject_code = match *self { - Reject::Icmp(code) => code as u8, - Reject::TcpRst => 0, - }; - - sys::nftnl_expr_set_u8(expr, sys::NFTNL_EXPR_REJECT_CODE as u16, reject_code); - - expr - } - } +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] +/// A reject expression that defines the type of rejection message sent when discarding a packet. +pub struct Reject { + #[field(sys::NFTA_REJECT_TYPE, name_in_functions = "type")] + reject_type: RejectType, + #[field(sys::NFTA_REJECT_ICMP_CODE)] + icmp_code: IcmpCode, } /// An ICMP reject code. #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] -#[repr(u8)] -pub enum IcmpCode { - NoRoute = libc::NFT_REJECT_ICMPX_NO_ROUTE as u8, - PortUnreach = libc::NFT_REJECT_ICMPX_PORT_UNREACH as u8, - HostUnreach = libc::NFT_REJECT_ICMPX_HOST_UNREACH as u8, - AdminProhibited = libc::NFT_REJECT_ICMPX_ADMIN_PROHIBITED as u8, +#[nfnetlink_enum(u32)] +pub enum RejectType { + IcmpUnreach = sys::NFT_REJECT_ICMP_UNREACH, + TcpRst = sys::NFT_REJECT_TCP_RST, + IcmpxUnreach = sys::NFT_REJECT_ICMPX_UNREACH, } -impl IcmpCode { - fn from_raw(code: u8) -> Result<Self, DeserializationError> { - match code as i32 { - libc::NFT_REJECT_ICMPX_NO_ROUTE => Ok(Self::NoRoute), - libc::NFT_REJECT_ICMPX_PORT_UNREACH => Ok(Self::PortUnreach), - libc::NFT_REJECT_ICMPX_HOST_UNREACH => Ok(Self::HostUnreach), - libc::NFT_REJECT_ICMPX_ADMIN_PROHIBITED => Ok(Self::AdminProhibited), - _ => Err(DeserializationError::InvalidValue), - } - } +/// An ICMP reject code. +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] +#[nfnetlink_enum(u8)] +pub enum IcmpCode { + NoRoute = sys::NFT_REJECT_ICMPX_NO_ROUTE, + PortUnreach = sys::NFT_REJECT_ICMPX_PORT_UNREACH, + HostUnreach = sys::NFT_REJECT_ICMPX_HOST_UNREACH, + AdminProhibited = sys::NFT_REJECT_ICMPX_ADMIN_PROHIBITED, } diff --git a/src/expr/verdict.rs b/src/expr/verdict.rs index 3c4c374..7edf7cd 100644 --- a/src/expr/verdict.rs +++ b/src/expr/verdict.rs @@ -1,11 +1,39 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys::{self, libc::{self, c_char}}; -use std::ffi::{CStr, CString}; +use std::fmt::Debug; + +use libc::{NF_ACCEPT, NF_DROP, NF_QUEUE}; +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; + +use crate::sys::{ + NFTA_VERDICT_CHAIN, NFTA_VERDICT_CHAIN_ID, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE, + NFT_GOTO, NFT_JUMP, NFT_RETURN, +}; + +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[nfnetlink_enum(i32)] +pub enum VerdictType { + Drop = NF_DROP, + Accept = NF_ACCEPT, + Queue = NF_QUEUE, + Continue = NFT_CONTINUE, + Break = NFT_BREAK, + Jump = NFT_JUMP, + Goto = NFT_GOTO, + Return = NFT_RETURN, +} + +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(nested = true)] +pub struct Verdict { + #[field(NFTA_VERDICT_CODE)] + code: VerdictType, + #[field(NFTA_VERDICT_CHAIN)] + chain: String, + #[field(NFTA_VERDICT_CHAIN_ID)] + chain_id: u32, +} -/// A verdict expression. In the background, this is usually an "Immediate" expression in nftnl -/// terms, but here it is simplified to only represent a verdict. #[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub enum Verdict { +pub enum VerdictKind { /// Silently drop the packet. Drop, /// Accept the packet and let it pass. @@ -14,135 +42,10 @@ pub enum Verdict { Continue, Break, Jump { - chain: CString, + chain: String, }, Goto { - chain: CString, + chain: String, }, Return, } - -impl Verdict { - fn chain(&self) -> Option<&CStr> { - match *self { - Verdict::Jump { ref chain } => Some(chain.as_c_str()), - Verdict::Goto { ref chain } => Some(chain.as_c_str()), - _ => None, - } - } -} - -impl Expression for Verdict { - fn get_raw_name() -> *const libc::c_char { - b"immediate\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { - unsafe { - let mut chain = None; - if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16) { - let raw_chain = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16); - - if raw_chain.is_null() { - return Err(DeserializationError::NullPointer); - } - chain = Some(CStr::from_ptr(raw_chain).to_owned()); - } - - let verdict = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_IMM_VERDICT as u16); - - match verdict as i32 { - libc::NF_DROP => Ok(Verdict::Drop), - libc::NF_ACCEPT => Ok(Verdict::Accept), - libc::NF_QUEUE => Ok(Verdict::Queue), - libc::NFT_CONTINUE => Ok(Verdict::Continue), - libc::NFT_BREAK => Ok(Verdict::Break), - libc::NFT_JUMP => { - if let Some(chain) = chain { - Ok(Verdict::Jump { chain }) - } else { - Err(DeserializationError::InvalidValue) - } - } - libc::NFT_GOTO => { - if let Some(chain) = chain { - Ok(Verdict::Goto { chain }) - } else { - Err(DeserializationError::InvalidValue) - } - } - libc::NFT_RETURN => Ok(Verdict::Return), - _ => Err(DeserializationError::InvalidValue), - } - } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - let immediate_const = match *self { - Verdict::Drop => libc::NF_DROP, - Verdict::Accept => libc::NF_ACCEPT, - Verdict::Queue => libc::NF_QUEUE, - Verdict::Continue => libc::NFT_CONTINUE, - Verdict::Break => libc::NFT_BREAK, - Verdict::Jump { .. } => libc::NFT_JUMP, - Verdict::Goto { .. } => libc::NFT_GOTO, - Verdict::Return => libc::NFT_RETURN, - }; - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc( - b"immediate\0" as *const _ as *const c_char - )); - - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_IMM_DREG as u16, - libc::NFT_REG_VERDICT as u32, - ); - - if let Some(chain) = self.chain() { - sys::nftnl_expr_set_str(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16, chain.as_ptr()); - } - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_IMM_VERDICT as u16, - immediate_const as u32, - ); - - expr - } - } -} - -#[macro_export] -macro_rules! nft_expr_verdict { - (drop) => { - $crate::expr::Verdict::Drop - }; - (accept) => { - $crate::expr::Verdict::Accept - }; - (reject icmp $code:expr) => { - $crate::expr::Verdict::Reject(RejectionType::Icmp($code)) - }; - (reject tcp-rst) => { - $crate::expr::Verdict::Reject(RejectionType::TcpRst) - }; - (queue) => { - $crate::expr::Verdict::Queue - }; - (continue) => { - $crate::expr::Verdict::Continue - }; - (break) => { - $crate::expr::Verdict::Break - }; - (jump $chain:expr) => { - $crate::expr::Verdict::Jump { chain: $chain } - }; - (goto $chain:expr) => { - $crate::expr::Verdict::Goto { chain: $chain } - }; - (return) => { - $crate::expr::Verdict::Return - }; -} diff --git a/src/expr/wrapper.rs b/src/expr/wrapper.rs deleted file mode 100644 index 12ef60b..0000000 --- a/src/expr/wrapper.rs +++ /dev/null @@ -1,61 +0,0 @@ -use std::ffi::CStr; -use std::ffi::CString; -use std::fmt::Debug; -use std::rc::Rc; -use std::os::raw::c_char; - -use super::{DeserializationError, Expression}; -use crate::{sys, Rule}; - -pub struct ExpressionWrapper { - pub(crate) expr: *const sys::nftnl_expr, - // we also need the rule here to ensure that the rule lives as long as the `expr` pointer - #[allow(dead_code)] - pub(crate) rule: Rc<Rule>, -} - -impl Debug for ExpressionWrapper { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.get_str()) - } -} - -impl ExpressionWrapper { - /// Retrieves a textual description of the expression. - pub fn get_str(&self) -> CString { - let mut descr_buf = vec![0i8; 4096]; - unsafe { - sys::nftnl_expr_snprintf( - descr_buf.as_mut_ptr() as *mut c_char, - (descr_buf.len() - 1) as u64, - self.expr, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); - CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned() - } - } - - /// Retrieves the type of expression ("log", "counter", ...). - pub fn get_kind(&self) -> Option<&CStr> { - unsafe { - let ptr = sys::nftnl_expr_get_str(self.expr, sys::NFTNL_EXPR_NAME as u16); - if !ptr.is_null() { - Some(CStr::from_ptr(ptr)) - } else { - None - } - } - } - - /// Attempts to decode the expression as the type T. - pub fn decode_expr<T: Expression>(&self) -> Result<T, DeserializationError> { - if let Some(kind) = self.get_kind() { - let raw_name = unsafe { CStr::from_ptr(T::get_raw_name()) }; - if kind == raw_name { - return T::from_expr(self.expr); - } - } - Err(DeserializationError::InvalidExpressionKind) - } -} |