diff options
author | Simon THOBY <git@nightmared.fr> | 2021-10-23 23:02:22 +0200 |
---|---|---|
committer | Simon THOBY <git@nightmared.fr> | 2021-11-02 22:18:12 +0100 |
commit | 7f7b2c3af6e6f7a596a85ada823408bdd0b02118 (patch) | |
tree | 48908226b5252d0e86758fe36d05c1491f080ac1 | |
parent | 82ebb702c1358ac4af40c7ee43efa6f364fa6d50 (diff) |
replace Optionnals by Results for a better error propagation when deserializing expressions
-rw-r--r-- | rustables/src/expr/bitwise.rs | 3 | ||||
-rw-r--r-- | rustables/src/expr/cmp.rs | 116 | ||||
-rw-r--r-- | rustables/src/expr/counter.rs | 6 | ||||
-rw-r--r-- | rustables/src/expr/ct.rs | 10 | ||||
-rw-r--r-- | rustables/src/expr/immediate.rs | 54 | ||||
-rw-r--r-- | rustables/src/expr/log.rs | 8 | ||||
-rw-r--r-- | rustables/src/expr/lookup.rs | 8 | ||||
-rw-r--r-- | rustables/src/expr/masquerade.rs | 6 | ||||
-rw-r--r-- | rustables/src/expr/meta.rs | 46 | ||||
-rw-r--r-- | rustables/src/expr/mod.rs | 122 | ||||
-rw-r--r-- | rustables/src/expr/nat.rs | 39 | ||||
-rw-r--r-- | rustables/src/expr/payload.rs | 80 | ||||
-rw-r--r-- | rustables/src/expr/register.rs | 16 | ||||
-rw-r--r-- | rustables/src/expr/reject.rs | 23 | ||||
-rw-r--r-- | rustables/src/expr/verdict.rs | 37 | ||||
-rw-r--r-- | rustables/src/expr/wrapper.rs | 8 |
16 files changed, 285 insertions, 297 deletions
diff --git a/rustables/src/expr/bitwise.rs b/rustables/src/expr/bitwise.rs index 0c6c33c..a5d9343 100644 --- a/rustables/src/expr/bitwise.rs +++ b/rustables/src/expr/bitwise.rs @@ -1,5 +1,4 @@ -use super::{Expression, Rule}; -use crate::expr::cmp::ToSlice; +use super::{Expression, Rule, ToSlice}; use rustables_sys::{self as sys, libc}; use std::ffi::c_void; use std::os::raw::c_char; diff --git a/rustables/src/expr/cmp.rs b/rustables/src/expr/cmp.rs index 11825d6..f24b4ad 100644 --- a/rustables/src/expr/cmp.rs +++ b/rustables/src/expr/cmp.rs @@ -1,11 +1,9 @@ -use super::{Expression, Rule}; +use super::{DeserializationError, Expression, Rule, ToSlice}; use rustables_sys::{self as sys, libc}; use std::{ borrow::Cow, ffi::{c_void, CString}, - net::{IpAddr, Ipv4Addr, Ipv6Addr}, os::raw::c_char, - slice, }; /// Comparison operator. @@ -39,16 +37,16 @@ impl CmpOp { } } - pub fn from_raw(val: u32) -> Option<Self> { + pub fn from_raw(val: u32) -> Result<Self, DeserializationError> { use self::CmpOp::*; match val as i32 { - libc::NFT_CMP_EQ => Some(Eq), - libc::NFT_CMP_NEQ => Some(Neq), - libc::NFT_CMP_LT => Some(Lt), - libc::NFT_CMP_LTE => Some(Lte), - libc::NFT_CMP_GT => Some(Gt), - libc::NFT_CMP_GTE => Some(Gte), - _ => None, + libc::NFT_CMP_EQ => Ok(Eq), + libc::NFT_CMP_NEQ => Ok(Neq), + libc::NFT_CMP_LT => Ok(Lt), + libc::NFT_CMP_LTE => Ok(Lte), + libc::NFT_CMP_GT => Ok(Gt), + libc::NFT_CMP_GTE => Ok(Gte), + _ => Err(DeserializationError::InvalidValue), } } } @@ -89,7 +87,7 @@ impl<T: ToSlice> Expression for Cmp<T> { sys::nftnl_expr_set( expr, sys::NFTNL_EXPR_CMP_DATA as u16, - data.as_ref() as *const _ as *const c_void, + data.as_ptr() as *const c_void, data.len() as u32, ); @@ -107,7 +105,7 @@ impl<const N: usize> Expression for Cmp<[u8; N]> { // be *extremely* dangerous in terms of memory safety, // rustables only accept to deserialize expressions with variable-size data // to arrays of bytes, so that the memory layout cannot be invalid. - fn from_expr(expr: *const sys::nftnl_expr) -> Option<Self> { + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { unsafe { let ref_len = std::mem::size_of::<[u8; N]>() as u32; let mut data_len = 0; @@ -118,23 +116,22 @@ impl<const N: usize> Expression for Cmp<[u8; N]> { ); if data.is_null() { - return None; + return Err(DeserializationError::NullPointer); } else if data_len != ref_len { - debug!("Invalid size requested for deserializing a 'cmp' expression: expected {} bytes, got {}", ref_len, data_len); - return None; + return Err(DeserializationError::InvalidDataSize); } let data = *(data as *const [u8; N]); - let op = CmpOp::from_raw(sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_CMP_OP as u16)); - op.map(|op| Cmp { op, data }) + let op = CmpOp::from_raw(sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_CMP_OP as u16))?; + Ok(Cmp { op, data }) } } // call to the other implementation to generate the expression fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr { Cmp { - data: self.data.as_ref(), + data: &self.data as &[u8], op: self.op, } .to_expr(rule) @@ -166,87 +163,6 @@ macro_rules! nft_expr_cmp { }; } -/// A type that can be converted into a byte buffer. -pub trait ToSlice { - /// Returns the data this type represents. - fn to_slice(&self) -> Cow<'_, [u8]>; -} - -impl<'a> ToSlice for &'a [u8] { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Borrowed(self) - } -} - -impl<'a> ToSlice for &'a [u16] { - fn to_slice(&self) -> Cow<'_, [u8]> { - let ptr = self.as_ptr() as *const u8; - let len = self.len() * 2; - Cow::Borrowed(unsafe { slice::from_raw_parts(ptr, len) }) - } -} - -impl ToSlice for IpAddr { - fn to_slice(&self) -> Cow<'_, [u8]> { - match *self { - IpAddr::V4(ref addr) => addr.to_slice(), - IpAddr::V6(ref addr) => addr.to_slice(), - } - } -} - -impl ToSlice for Ipv4Addr { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Owned(self.octets().to_vec()) - } -} - -impl ToSlice for Ipv6Addr { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Owned(self.octets().to_vec()) - } -} - -impl ToSlice for u8 { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Owned(vec![*self]) - } -} - -impl ToSlice for u16 { - fn to_slice(&self) -> Cow<'_, [u8]> { - let b0 = (*self & 0x00ff) as u8; - let b1 = (*self >> 8) as u8; - Cow::Owned(vec![b0, b1]) - } -} - -impl ToSlice for u32 { - fn to_slice(&self) -> Cow<'_, [u8]> { - let b0 = *self as u8; - let b1 = (*self >> 8) as u8; - let b2 = (*self >> 16) as u8; - let b3 = (*self >> 24) as u8; - Cow::Owned(vec![b0, b1, b2, b3]) - } -} - -impl ToSlice for i32 { - fn to_slice(&self) -> Cow<'_, [u8]> { - let b0 = *self as u8; - let b1 = (*self >> 8) as u8; - let b2 = (*self >> 16) as u8; - let b3 = (*self >> 24) as u8; - Cow::Owned(vec![b0, b1, b2, b3]) - } -} - -impl<'a> ToSlice for &'a str { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::from(self.as_bytes()) - } -} - /// Can be used to compare the value loaded by [`Meta::IifName`] and [`Meta::OifName`]. Please /// note that it is faster to check interface index than name. /// diff --git a/rustables/src/expr/counter.rs b/rustables/src/expr/counter.rs index b507e80..099e7fa 100644 --- a/rustables/src/expr/counter.rs +++ b/rustables/src/expr/counter.rs @@ -1,4 +1,4 @@ -use super::{Expression, Rule}; +use super::{DeserializationError, Expression, Rule}; use rustables_sys as sys; use std::os::raw::c_char; @@ -24,11 +24,11 @@ impl Expression for Counter { b"counter\0" as *const _ as *const c_char } - fn from_expr(expr: *const sys::nftnl_expr) -> Option<Self> { + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { unsafe { let nb_bytes = sys::nftnl_expr_get_u64(expr, sys::NFTNL_EXPR_CTR_BYTES as u16); let nb_packets = sys::nftnl_expr_get_u64(expr, sys::NFTNL_EXPR_CTR_PACKETS as u16); - Some(Counter { + Ok(Counter { nb_bytes, nb_packets, }) diff --git a/rustables/src/expr/ct.rs b/rustables/src/expr/ct.rs index 9d58591..001aef8 100644 --- a/rustables/src/expr/ct.rs +++ b/rustables/src/expr/ct.rs @@ -1,4 +1,4 @@ -use super::{Expression, Rule}; +use super::{DeserializationError, Expression, Rule}; use rustables_sys::{self as sys, libc}; use std::os::raw::c_char; @@ -31,7 +31,7 @@ impl Expression for Conntrack { b"ct\0" as *const _ as *const c_char } - fn from_expr(expr: *const sys::nftnl_expr) -> Option<Self> + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> where Self: Sized, { @@ -40,11 +40,11 @@ impl Expression for Conntrack { let ct_sreg_is_set = sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_CT_SREG as u16); match ct_key as i32 { - libc::NFT_CT_STATE => Some(Conntrack::State), - libc::NFT_CT_MARK => Some(Conntrack::Mark { + libc::NFT_CT_STATE => Ok(Conntrack::State), + libc::NFT_CT_MARK => Ok(Conntrack::Mark { set: ct_sreg_is_set, }), - _ => None, + _ => Err(DeserializationError::InvalidValue), } } } diff --git a/rustables/src/expr/immediate.rs b/rustables/src/expr/immediate.rs index 1196211..b5be101 100644 --- a/rustables/src/expr/immediate.rs +++ b/rustables/src/expr/immediate.rs @@ -1,7 +1,6 @@ -use super::{Expression, Register, Rule, ToSlice}; +use super::{DeserializationError, Expression, Register, Rule, ToSlice}; use rustables_sys as sys; use std::ffi::c_void; -use std::mem::size_of_val; use std::os::raw::c_char; /// An immediate expression. Used to set immediate data. @@ -33,11 +32,12 @@ impl<T: ToSlice> Expression for Immediate<T> { self.register.to_raw(), ); + let data = self.data.to_slice(); sys::nftnl_expr_set( expr, sys::NFTNL_EXPR_IMM_DATA as u16, - &self.data.to_slice() as *const _ as *const c_void, - size_of_val(&self.data) as u32, + data.as_ptr() as *const c_void, + data.len() as u32, ); expr @@ -54,7 +54,7 @@ impl<const N: usize> Expression for Immediate<[u8; N]> { // be *extremely* dangerous in terms of memory safety, // rustables only accept to deserialize expressions with variable-size data // to arrays of bytes, so that the memory layout cannot be invalid. - fn from_expr(expr: *const sys::nftnl_expr) -> Option<Self> { + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { unsafe { let ref_len = std::mem::size_of::<[u8; N]>() as u32; let mut data_len = 0; @@ -65,10 +65,9 @@ impl<const N: usize> Expression for Immediate<[u8; N]> { ); if data.is_null() { - return None; + return Err(DeserializationError::NullPointer); } else if data_len != ref_len { - debug!("Invalid size requested for deserializing an 'immediate' expression: expected {} bytes, got {}", ref_len, data_len); - return None; + return Err(DeserializationError::InvalidDataSize); } let data = *(data as *const [u8; N]); @@ -76,9 +75,9 @@ impl<const N: usize> Expression for Immediate<[u8; N]> { let register = Register::from_raw(sys::nftnl_expr_get_u32( expr, sys::NFTNL_EXPR_IMM_DREG as u16, - )); + ))?; - register.map(|register| Immediate { data, register }) + Ok(Immediate { data, register }) } } @@ -86,44 +85,11 @@ impl<const N: usize> Expression for Immediate<[u8; N]> { fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr { Immediate { register: self.register, - data: self.data.as_ref(), + data: &self.data as &[u8], } .to_expr(rule) } } -// As casting bytes to any type of the same size as the input would -// be *extremely* dangerous in terms of memory safety, -// rustables only accept to deserialize expressions with variable-size data -// to arrays of bytes, so that the memory layout cannot be invalid. -impl<const N: usize> Immediate<[u8; N]> { - pub fn from_expr(expr: *const sys::nftnl_expr) -> Option<Self> { - unsafe { - let ref_len = std::mem::size_of::<[u8; N]>() as u32; - let mut data_len = 0; - let data = sys::nftnl_expr_get( - expr, - sys::NFTNL_EXPR_IMM_DATA as u16, - &mut data_len as *mut u32, - ); - - if data.is_null() { - return None; - } else if data_len != ref_len { - debug!("Invalid size requested for deserializing an 'immediate' expression: expected {} bytes, got {}", ref_len, data_len); - return None; - } - - let data = *(data as *const [u8; N]); - - let register = Register::from_raw(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_IMM_DREG as u16, - )); - - register.map(|register| Immediate { data, register }) - } - } -} #[macro_export] macro_rules! nft_expr_immediate { diff --git a/rustables/src/expr/log.rs b/rustables/src/expr/log.rs index ba1244a..db96ba9 100644 --- a/rustables/src/expr/log.rs +++ b/rustables/src/expr/log.rs @@ -1,4 +1,4 @@ -use super::{Expression, Rule}; +use super::{DeserializationError, Expression, Rule}; use rustables_sys as sys; use std::ffi::{CStr, CString}; use std::os::raw::c_char; @@ -16,7 +16,7 @@ impl Expression for Log { b"log\0" as *const _ as *const c_char } - fn from_expr(expr: *const sys::nftnl_expr) -> Option<Self> + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> where Self: Sized, { @@ -32,12 +32,12 @@ impl Expression for Log { if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16) { let raw_prefix = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16); if raw_prefix.is_null() { - trace!("Unexpected empty prefix in a 'log' expression"); + return Err(DeserializationError::NullPointer); } else { prefix = Some(LogPrefix(CStr::from_ptr(raw_prefix).to_owned())); } } - Some(Log { group, prefix }) + Ok(Log { group, prefix }) } } diff --git a/rustables/src/expr/lookup.rs b/rustables/src/expr/lookup.rs index b9a03d9..7796b29 100644 --- a/rustables/src/expr/lookup.rs +++ b/rustables/src/expr/lookup.rs @@ -1,4 +1,4 @@ -use super::{Expression, Rule}; +use super::{DeserializationError, Expression, Rule}; use crate::set::Set; use rustables_sys::{self as sys, libc}; use std::ffi::{CStr, CString}; @@ -26,7 +26,7 @@ impl Expression for Lookup { b"lookup\0" as *const _ as *const c_char } - fn from_expr(expr: *const sys::nftnl_expr) -> Option<Self> + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> where Self: Sized, { @@ -35,12 +35,12 @@ impl Expression for Lookup { let set_id = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_LOOKUP_SET_ID as u16); if set_name.is_null() { - return None; + return Err(DeserializationError::NullPointer); } let set_name = CStr::from_ptr(set_name).to_owned(); - Some(Lookup { set_id, set_name }) + Ok(Lookup { set_id, set_name }) } } diff --git a/rustables/src/expr/masquerade.rs b/rustables/src/expr/masquerade.rs index bf4e0de..40565d5 100644 --- a/rustables/src/expr/masquerade.rs +++ b/rustables/src/expr/masquerade.rs @@ -1,4 +1,4 @@ -use super::{Expression, Rule}; +use super::{DeserializationError, Expression, Rule}; use rustables_sys as sys; use std::os::raw::c_char; @@ -11,11 +11,11 @@ impl Expression for Masquerade { b"masq\0" as *const _ as *const c_char } - fn from_expr(_expr: *const sys::nftnl_expr) -> Option<Self> + fn from_expr(_expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> where Self: Sized, { - Some(Masquerade) + Ok(Masquerade) } fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { diff --git a/rustables/src/expr/meta.rs b/rustables/src/expr/meta.rs index ba803ac..199f3d3 100644 --- a/rustables/src/expr/meta.rs +++ b/rustables/src/expr/meta.rs @@ -1,4 +1,4 @@ -use super::{Expression, Rule}; +use super::{DeserializationError, Expression, Rule}; use rustables_sys::{self as sys, libc}; use std::os::raw::c_char; @@ -58,24 +58,23 @@ impl Meta { } } - fn from_raw(val: u32) -> Option<Self> { + fn from_raw(val: u32) -> Result<Self, DeserializationError> { match val as i32 { - libc::NFT_META_PROTOCOL => Some(Self::Protocol), - libc::NFT_META_MARK => Some(Self::Mark { set: false }), - libc::NFT_META_IIF => Some(Self::Iif), - libc::NFT_META_OIF => Some(Self::Oif), - libc::NFT_META_IIFNAME => Some(Self::IifName), - libc::NFT_META_OIFNAME => Some(Self::OifName), - libc::NFT_META_IIFTYPE => Some(Self::IifType), - libc::NFT_META_OIFTYPE => Some(Self::OifType), - libc::NFT_META_SKUID => Some(Self::SkUid), - libc::NFT_META_SKGID => Some(Self::SkGid), - libc::NFT_META_NFPROTO => Some(Self::NfProto), - libc::NFT_META_L4PROTO => Some(Self::L4Proto), - libc::NFT_META_CGROUP => Some(Self::Cgroup), - libc::NFT_META_PRANDOM => Some(Self::PRandom), - - _ => None, + libc::NFT_META_PROTOCOL => Ok(Self::Protocol), + libc::NFT_META_MARK => Ok(Self::Mark { set: false }), + libc::NFT_META_IIF => Ok(Self::Iif), + libc::NFT_META_OIF => Ok(Self::Oif), + libc::NFT_META_IIFNAME => Ok(Self::IifName), + libc::NFT_META_OIFNAME => Ok(Self::OifName), + libc::NFT_META_IIFTYPE => Ok(Self::IifType), + libc::NFT_META_OIFTYPE => Ok(Self::OifType), + libc::NFT_META_SKUID => Ok(Self::SkUid), + libc::NFT_META_SKGID => Ok(Self::SkGid), + libc::NFT_META_NFPROTO => Ok(Self::NfProto), + libc::NFT_META_L4PROTO => Ok(Self::L4Proto), + libc::NFT_META_CGROUP => Ok(Self::Cgroup), + libc::NFT_META_PRANDOM => Ok(Self::PRandom), + _ => Err(DeserializationError::InvalidValue), } } } @@ -85,24 +84,21 @@ impl Expression for Meta { b"meta\0" as *const _ as *const c_char } - fn from_expr(expr: *const sys::nftnl_expr) -> Option<Self> + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> where Self: Sized, { unsafe { - let mut ret = match Self::from_raw(sys::nftnl_expr_get_u32( + let mut ret = Self::from_raw(sys::nftnl_expr_get_u32( expr, sys::NFTNL_EXPR_META_KEY as u16, - )) { - Some(x) => x, - None => return None, - }; + ))?; if let Self::Mark { ref mut set } = ret { *set = sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_META_SREG as u16); } - Some(ret) + Ok(ret) } } diff --git a/rustables/src/expr/mod.rs b/rustables/src/expr/mod.rs index 431a0b9..b20a752 100644 --- a/rustables/src/expr/mod.rs +++ b/rustables/src/expr/mod.rs @@ -3,8 +3,14 @@ //! //! [`Rule`]: struct.Rule.html +use std::borrow::Cow; +use std::net::IpAddr; +use std::net::Ipv4Addr; +use std::net::Ipv6Addr; + use super::rule::Rule; use rustables_sys::{self as sys, libc}; +use thiserror::Error; mod bitwise; pub use self::bitwise::*; @@ -51,18 +57,47 @@ pub use self::verdict::*; mod wrapper; pub use self::wrapper::ExpressionWrapper; +#[derive(Debug, Error)] +pub enum DeserializationError { + #[error("The expected expression type doesn't match the name of the raw expression")] + /// The expected expression type doesn't match the name of the raw expression + InvalidExpressionKind, + + #[error("Deserializing the requested type isn't implemented yet")] + /// Deserializing the requested type isn't implemented yet + NotImplemented, + + #[error("The expression value cannot be deserialized to the requested type")] + /// The expression value cannot be deserialized to the requested type + InvalidValue, + + #[error("A pointer was null while a non-null pointer was expected")] + /// A pointer was null while a non-null pointer was expected + NullPointer, + + #[error( + "The size of a raw value was incoherent with the expected type of the deserialized value" + )] + /// The size of a raw value was incoherent with the expected type of the deserialized value + InvalidDataSize, + + #[error(transparent)] + /// Couldn't find a matching protocol + InvalidProtolFamily(#[from] super::InvalidProtocolFamily), +} + /// Trait for every safe wrapper of an nftables expression. pub trait Expression { /// Returns the raw name used by nftables to identify the rule. fn get_raw_name() -> *const libc::c_char; /// Try to parse the expression from a raw nftables expression, - /// returning None if the attempted parsing failed. - fn from_expr(_expr: *const sys::nftnl_expr) -> Option<Self> + /// returning a [DeserializationError] if the attempted parsing failed. + fn from_expr(_expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> where Self: Sized, { - None + Err(DeserializationError::NotImplemented) } /// Allocates and returns the low level `nftnl_expr` representation of this expression. @@ -70,6 +105,87 @@ pub trait Expression { fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr; } +/// A type that can be converted into a byte buffer. +pub trait ToSlice { + /// Returns the data this type represents. + fn to_slice(&self) -> Cow<'_, [u8]>; +} + +impl<'a> ToSlice for &'a [u8] { + fn to_slice(&self) -> Cow<'_, [u8]> { + Cow::Borrowed(self) + } +} + +impl<'a> ToSlice for &'a [u16] { + fn to_slice(&self) -> Cow<'_, [u8]> { + let ptr = self.as_ptr() as *const u8; + let len = self.len() * 2; + Cow::Borrowed(unsafe { std::slice::from_raw_parts(ptr, len) }) + } +} + +impl ToSlice for IpAddr { + fn to_slice(&self) -> Cow<'_, [u8]> { + match *self { + IpAddr::V4(ref addr) => addr.to_slice(), + IpAddr::V6(ref addr) => addr.to_slice(), + } + } +} + +impl ToSlice for Ipv4Addr { + fn to_slice(&self) -> Cow<'_, [u8]> { + Cow::Owned(self.octets().to_vec()) + } +} + +impl ToSlice for Ipv6Addr { + fn to_slice(&self) -> Cow<'_, [u8]> { + Cow::Owned(self.octets().to_vec()) + } +} + +impl ToSlice for u8 { + fn to_slice(&self) -> Cow<'_, [u8]> { + Cow::Owned(vec![*self]) + } +} + +impl ToSlice for u16 { + fn to_slice(&self) -> Cow<'_, [u8]> { + let b0 = (*self & 0x00ff) as u8; + let b1 = (*self >> 8) as u8; + Cow::Owned(vec![b0, b1]) + } +} + +impl ToSlice for u32 { + fn to_slice(&self) -> Cow<'_, [u8]> { + let b0 = *self as u8; + let b1 = (*self >> 8) as u8; + let b2 = (*self >> 16) as u8; + let b3 = (*self >> 24) as u8; + Cow::Owned(vec![b0, b1, b2, b3]) + } +} + +impl ToSlice for i32 { + fn to_slice(&self) -> Cow<'_, [u8]> { + let b0 = *self as u8; + let b1 = (*self >> 8) as u8; + let b2 = (*self >> 16) as u8; + let b3 = (*self >> 24) as u8; + Cow::Owned(vec![b0, b1, b2, b3]) + } +} + +impl<'a> ToSlice for &'a str { + fn to_slice(&self) -> Cow<'_, [u8]> { + Cow::from(self.as_bytes()) + } +} + #[macro_export(local_inner_macros)] macro_rules! nft_expr { (bitwise mask $mask:expr,xor $xor:expr) => { diff --git a/rustables/src/expr/nat.rs b/rustables/src/expr/nat.rs index 6bd0619..51f439f 100644 --- a/rustables/src/expr/nat.rs +++ b/rustables/src/expr/nat.rs @@ -1,4 +1,4 @@ -use super::{Expression, Register, Rule}; +use super::{DeserializationError, Expression, Register, Rule}; use crate::ProtoFamily; use rustables_sys::{self as sys, libc}; use std::{convert::TryFrom, os::raw::c_char}; @@ -13,11 +13,11 @@ pub enum NatType { } impl NatType { - fn from_raw(val: u32) -> Option<Self> { + fn from_raw(val: u32) -> Result<Self, DeserializationError> { match val as i32 { - libc::NFT_NAT_SNAT => Some(NatType::SNat), - libc::NFT_NAT_DNAT => Some(NatType::DNat), - _ => None, + libc::NFT_NAT_SNAT => Ok(NatType::SNat), + libc::NFT_NAT_DNAT => Ok(NatType::DNat), + _ => Err(DeserializationError::InvalidValue), } } } @@ -37,7 +37,7 @@ impl Expression for Nat { b"nat\0" as *const _ as *const c_char } - fn from_expr(expr: *const sys::nftnl_expr) -> Option<Self> + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> where Self: Sized, { @@ -45,42 +45,27 @@ impl Expression for Nat { let nat_type = NatType::from_raw(sys::nftnl_expr_get_u32( expr, sys::NFTNL_EXPR_NAT_TYPE as u16, - )); - let nat_type = match nat_type { - Some(x) => x, - None => return None, - }; + ))?; let family = ProtoFamily::try_from(sys::nftnl_expr_get_u32( expr, sys::NFTNL_EXPR_NAT_FAMILY as u16, - ) as i32); - let family = match family { - Ok(x) => x, - Err(_) => return None, - }; + ) as i32)?; let ip_register = Register::from_raw(sys::nftnl_expr_get_u32( expr, sys::NFTNL_EXPR_NAT_REG_ADDR_MIN as u16, - )); - let ip_register = match ip_register { - Some(x) => x, - None => return None, - }; + ))?; let mut port_register = None; if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_NAT_REG_PROTO_MIN as u16) { - port_register = Register::from_raw(sys::nftnl_expr_get_u32( + port_register = Some(Register::from_raw(sys::nftnl_expr_get_u32( expr, sys::NFTNL_EXPR_NAT_REG_PROTO_MIN as u16, - )); - if port_register.is_none() { - trace!("Invalid register in expression 'nat'"); - } + ))?); } - Some(Nat { + Ok(Nat { ip_register, nat_type, family, diff --git a/rustables/src/expr/payload.rs b/rustables/src/expr/payload.rs index a6b5ddf..25a71ad 100644 --- a/rustables/src/expr/payload.rs +++ b/rustables/src/expr/payload.rs @@ -1,4 +1,4 @@ -use super::{Expression, Rule}; +use super::{DeserializationError, Expression, Rule}; use rustables_sys::{self as sys, libc}; use std::os::raw::c_char; @@ -22,11 +22,11 @@ impl Payload { offset: f.offset(), len: f.len(), }), - Payload::Network(ref f) => RawPayload::LinkLayer(RawPayloadData { + Payload::Network(ref f) => RawPayload::Network(RawPayloadData { offset: f.offset(), len: f.len(), }), - Payload::Transport(ref f) => RawPayload::LinkLayer(RawPayloadData { + Payload::Transport(ref f) => RawPayload::Transport(RawPayloadData { offset: f.offset(), len: f.offset(), }), @@ -88,23 +88,21 @@ impl Expression for RawPayload { b"payload\0" as *const _ as *const c_char } - fn from_expr(expr: *const sys::nftnl_expr) -> Option<Self> { + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { unsafe { let base = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_BASE as u16); let offset = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_OFFSET as u16); let len = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_LEN as u16); match base as i32 { - libc::NFT_PAYLOAD_LL_HEADER => { - Some(Self::LinkLayer(RawPayloadData { offset, len })) - } + libc::NFT_PAYLOAD_LL_HEADER => Ok(Self::LinkLayer(RawPayloadData { offset, len })), libc::NFT_PAYLOAD_NETWORK_HEADER => { - Some(Self::Network(RawPayloadData { offset, len })) + Ok(Self::Network(RawPayloadData { offset, len })) } libc::NFT_PAYLOAD_TRANSPORT_HEADER => { - Some(Self::Transport(RawPayloadData { offset, len })) + Ok(Self::Transport(RawPayloadData { offset, len })) } - _ => return None, + _ => return Err(DeserializationError::InvalidValue), } } } @@ -156,18 +154,18 @@ impl HeaderField for LLHeaderField { } impl LLHeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Option<Self> { + pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { let off = data.offset; let len = data.len; if off == 0 && len == 6 { - Some(Self::Daddr) + Ok(Self::Daddr) } else if off == 6 && len == 6 { - Some(Self::Saddr) + Ok(Self::Saddr) } else if off == 12 && len == 2 { - Some(Self::EtherType) + Ok(Self::EtherType) } else { - None + Err(DeserializationError::InvalidValue) } } } @@ -228,20 +226,20 @@ impl HeaderField for Ipv4HeaderField { } impl Ipv4HeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Option<Self> { + pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { let off = data.offset; let len = data.len; if off == 8 && len == 1 { - Some(Self::Ttl) + Ok(Self::Ttl) } else if off == 9 && len == 1 { - Some(Self::Protocol) + Ok(Self::Protocol) } else if off == 12 && len == 4 { - Some(Self::Saddr) + Ok(Self::Saddr) } else if off == 16 && len == 4 { - Some(Self::Daddr) + Ok(Self::Daddr) } else { - None + Err(DeserializationError::InvalidValue) } } } @@ -278,20 +276,20 @@ impl HeaderField for Ipv6HeaderField { } impl Ipv6HeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Option<Self> { + pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { let off = data.offset; let len = data.len; if off == 6 && len == 1 { - Some(Self::NextHeader) + Ok(Self::NextHeader) } else if off == 7 && len == 1 { - Some(Self::HopLimit) + Ok(Self::HopLimit) } else if off == 8 && len == 16 { - Some(Self::Saddr) + Ok(Self::Saddr) } else if off == 24 && len == 16 { - Some(Self::Daddr) + Ok(Self::Daddr) } else { - None + Err(DeserializationError::InvalidValue) } } } @@ -350,16 +348,16 @@ impl HeaderField for TcpHeaderField { } impl TcpHeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Option<Self> { + pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { let off = data.offset; let len = data.len; if off == 0 && len == 2 { - Some(Self::Sport) + Ok(Self::Sport) } else if off == 2 && len == 2 { - Some(Self::Dport) + Ok(Self::Dport) } else { - None + Err(DeserializationError::InvalidValue) } } } @@ -393,18 +391,18 @@ impl HeaderField for UdpHeaderField { } impl UdpHeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Option<Self> { + pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { let off = data.offset; let len = data.len; if off == 0 && len == 2 { - Some(Self::Sport) + Ok(Self::Sport) } else if off == 2 && len == 2 { - Some(Self::Dport) + Ok(Self::Dport) } else if off == 4 && len == 2 { - Some(Self::Len) + Ok(Self::Len) } else { - None + Err(DeserializationError::InvalidValue) } } } @@ -438,18 +436,18 @@ impl HeaderField for Icmpv6HeaderField { } impl Icmpv6HeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Option<Self> { + pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { let off = data.offset; let len = data.len; if off == 0 && len == 1 { - Some(Self::Type) + Ok(Self::Type) } else if off == 1 && len == 1 { - Some(Self::Code) + Ok(Self::Code) } else if off == 2 && len == 2 { - Some(Self::Checksum) + Ok(Self::Checksum) } else { - None + Err(DeserializationError::InvalidValue) } } } diff --git a/rustables/src/expr/register.rs b/rustables/src/expr/register.rs index 3013451..2cfcc3b 100644 --- a/rustables/src/expr/register.rs +++ b/rustables/src/expr/register.rs @@ -2,6 +2,8 @@ use std::fmt::Debug; use rustables_sys::libc; +use super::DeserializationError; + /// A netfilter data register. The expressions store and read data to and from these /// when evaluating rule statements. #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] @@ -19,14 +21,14 @@ impl Register { self as u32 } - pub fn from_raw(val: u32) -> Option<Self> { + pub fn from_raw(val: u32) -> Result<Self, DeserializationError> { match val as i32 { - libc::NFT_REG_VERDICT => Some(Self::Verdict), - libc::NFT_REG_1 => Some(Self::Reg1), - libc::NFT_REG_2 => Some(Self::Reg2), - libc::NFT_REG_3 => Some(Self::Reg3), - libc::NFT_REG_4 => Some(Self::Reg4), - _ => None, + libc::NFT_REG_VERDICT => Ok(Self::Verdict), + libc::NFT_REG_1 => Ok(Self::Reg1), + libc::NFT_REG_2 => Ok(Self::Reg2), + libc::NFT_REG_3 => Ok(Self::Reg3), + libc::NFT_REG_4 => Ok(Self::Reg4), + _ => Err(DeserializationError::InvalidValue), } } } diff --git a/rustables/src/expr/reject.rs b/rustables/src/expr/reject.rs index f94079b..550a287 100644 --- a/rustables/src/expr/reject.rs +++ b/rustables/src/expr/reject.rs @@ -1,4 +1,4 @@ -use super::{Expression, Rule}; +use super::{DeserializationError, Expression, Rule}; use crate::ProtoFamily; use rustables_sys::{ self as sys, @@ -34,7 +34,7 @@ impl Expression for Reject { b"reject\0" as *const _ as *const c_char } - fn from_expr(expr: *const sys::nftnl_expr) -> Option<Self> + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> where Self: Sized, { @@ -42,13 +42,12 @@ impl Expression for Reject { if sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_REJECT_TYPE as u16) == libc::NFT_REJECT_TCP_RST as u32 { - Some(Self::TcpRst) + Ok(Self::TcpRst) } else { - IcmpCode::from_raw(sys::nftnl_expr_get_u8( + Ok(Self::Icmp(IcmpCode::from_raw(sys::nftnl_expr_get_u8( expr, sys::NFTNL_EXPR_REJECT_CODE as u16, - )) - .map(Self::Icmp) + ))?)) } } } @@ -88,13 +87,13 @@ pub enum IcmpCode { } impl IcmpCode { - fn from_raw(code: u8) -> Option<Self> { + fn from_raw(code: u8) -> Result<Self, DeserializationError> { match code as i32 { - libc::NFT_REJECT_ICMPX_NO_ROUTE => Some(Self::NoRoute), - libc::NFT_REJECT_ICMPX_PORT_UNREACH => Some(Self::PortUnreach), - libc::NFT_REJECT_ICMPX_HOST_UNREACH => Some(Self::HostUnreach), - libc::NFT_REJECT_ICMPX_ADMIN_PROHIBITED => Some(Self::AdminProhibited), - _ => None, + libc::NFT_REJECT_ICMPX_NO_ROUTE => Ok(Self::NoRoute), + libc::NFT_REJECT_ICMPX_PORT_UNREACH => Ok(Self::PortUnreach), + libc::NFT_REJECT_ICMPX_HOST_UNREACH => Ok(Self::HostUnreach), + libc::NFT_REJECT_ICMPX_ADMIN_PROHIBITED => Ok(Self::AdminProhibited), + _ => Err(DeserializationError::InvalidValue), } } } diff --git a/rustables/src/expr/verdict.rs b/rustables/src/expr/verdict.rs index 772da52..6a6b802 100644 --- a/rustables/src/expr/verdict.rs +++ b/rustables/src/expr/verdict.rs @@ -1,4 +1,4 @@ -use super::{Expression, Rule}; +use super::{DeserializationError, Expression, Rule}; use rustables_sys::{ self as sys, libc::{self, c_char}, @@ -40,15 +40,14 @@ impl Expression for Verdict { b"immediate\0" as *const _ as *const c_char } - fn from_expr(expr: *const sys::nftnl_expr) -> Option<Self> { + fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { unsafe { let mut chain = None; if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16) { let raw_chain = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16); if raw_chain.is_null() { - trace!("Unexpected empty chain name when deserializing 'verdict' expression"); - return None; + return Err(DeserializationError::NullPointer); } chain = Some(CStr::from_ptr(raw_chain).to_owned()); } @@ -56,15 +55,27 @@ impl Expression for Verdict { let verdict = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_IMM_VERDICT as u16); match verdict as i32 { - libc::NF_DROP => Some(Verdict::Drop), - libc::NF_ACCEPT => Some(Verdict::Accept), - libc::NF_QUEUE => Some(Verdict::Queue), - libc::NFT_CONTINUE => Some(Verdict::Continue), - libc::NFT_BREAK => Some(Verdict::Break), - libc::NFT_JUMP => chain.map(|chain| Verdict::Jump { chain }), - libc::NFT_GOTO => chain.map(|chain| Verdict::Goto { chain }), - libc::NFT_RETURN => Some(Verdict::Return), - _ => None, + libc::NF_DROP => Ok(Verdict::Drop), + libc::NF_ACCEPT => Ok(Verdict::Accept), + libc::NF_QUEUE => Ok(Verdict::Queue), + libc::NFT_CONTINUE => Ok(Verdict::Continue), + libc::NFT_BREAK => Ok(Verdict::Break), + libc::NFT_JUMP => { + if let Some(chain) = chain { + Ok(Verdict::Jump { chain }) + } else { + Err(DeserializationError::InvalidValue) + } + } + libc::NFT_GOTO => { + if let Some(chain) = chain { + Ok(Verdict::Goto { chain }) + } else { + Err(DeserializationError::InvalidValue) + } + } + libc::NFT_RETURN => Ok(Verdict::Return), + _ => Err(DeserializationError::InvalidValue), } } } diff --git a/rustables/src/expr/wrapper.rs b/rustables/src/expr/wrapper.rs index b9b90b3..5162c21 100644 --- a/rustables/src/expr/wrapper.rs +++ b/rustables/src/expr/wrapper.rs @@ -3,6 +3,7 @@ use std::ffi::CString; use std::fmt::Debug; use std::rc::Rc; +use super::DeserializationError; use super::Expression; use crate::Rule; use rustables_sys as sys; @@ -48,15 +49,14 @@ impl ExpressionWrapper { } } - /// Attempt to decode the expression as the type T, returning None if such - /// conversion is not possible or failed. - pub fn decode_expr<T: Expression>(&self) -> Option<T> { + /// Attempt to decode the expression as the type T. + pub fn decode_expr<T: Expression>(&self) -> Result<T, DeserializationError> { if let Some(kind) = self.get_kind() { let raw_name = unsafe { CStr::from_ptr(T::get_raw_name()) }; if kind == raw_name { return T::from_expr(self.expr); } } - None + Err(DeserializationError::InvalidExpressionKind) } } |