diff options
Diffstat (limited to 'src/parser_impls.rs')
-rw-r--r-- | src/parser_impls.rs | 243 |
1 files changed, 243 insertions, 0 deletions
diff --git a/src/parser_impls.rs b/src/parser_impls.rs new file mode 100644 index 0000000..b2681bb --- /dev/null +++ b/src/parser_impls.rs @@ -0,0 +1,243 @@ +use std::{fmt::Debug, mem::transmute}; + +use rustables_macros::nfnetlink_struct; + +use crate::{ + error::DecodeError, + expr::Verdict, + nlmsg::{ + pad_netlink_object, pad_netlink_object_with_variable_size, NfNetlinkAttribute, + NfNetlinkDeserializable, NfNetlinkObject, + }, + parser::{write_attribute, Parsable}, + sys::{nlattr, NFTA_DATA_VALUE, NFTA_DATA_VERDICT, NFTA_LIST_ELEM, NLA_TYPE_MASK}, + ProtocolFamily, +}; + +impl NfNetlinkAttribute for u8 { + unsafe fn write_payload(&self, addr: *mut u8) { + *addr = *self; + } +} + +impl NfNetlinkDeserializable for u8 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok((buf[0], &buf[1..])) + } +} + +impl NfNetlinkAttribute for u16 { + unsafe fn write_payload(&self, addr: *mut u8) { + *(addr as *mut Self) = self.to_be(); + } +} + +impl NfNetlinkDeserializable for u16 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok((u16::from_be_bytes([buf[0], buf[1]]), &buf[2..])) + } +} + +impl NfNetlinkAttribute for i32 { + unsafe fn write_payload(&self, addr: *mut u8) { + *(addr as *mut Self) = self.to_be(); + } +} + +impl NfNetlinkDeserializable for i32 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok(( + i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), + &buf[4..], + )) + } +} + +impl NfNetlinkAttribute for u32 { + unsafe fn write_payload(&self, addr: *mut u8) { + *(addr as *mut Self) = self.to_be(); + } +} + +impl NfNetlinkDeserializable for u32 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok(( + u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), + &buf[4..], + )) + } +} + +impl NfNetlinkAttribute for u64 { + unsafe fn write_payload(&self, addr: *mut u8) { + *(addr as *mut Self) = self.to_be(); + } +} + +impl NfNetlinkDeserializable for u64 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok(( + u64::from_be_bytes([ + buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], + ]), + &buf[8..], + )) + } +} + +impl NfNetlinkAttribute for String { + fn get_size(&self) -> usize { + self.len() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + std::ptr::copy_nonoverlapping(self.as_bytes().as_ptr(), addr, self.len()); + } +} + +impl NfNetlinkDeserializable for String { + fn deserialize(mut buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + // ignore the NULL byte terminator, if any + if buf.len() > 0 && buf[buf.len() - 1] == 0 { + buf = &buf[..buf.len() - 1]; + } + Ok((String::from_utf8(buf.to_vec())?, &[])) + } +} + +impl NfNetlinkAttribute for Vec<u8> { + fn get_size(&self) -> usize { + self.len() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + std::ptr::copy_nonoverlapping(self.as_ptr(), addr, self.len()); + } +} + +impl NfNetlinkDeserializable for Vec<u8> { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok((buf.to_vec(), &[])) + } +} +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(nested = true)] +pub struct NfNetlinkData { + #[field(NFTA_DATA_VALUE)] + value: Vec<u8>, + #[field(NFTA_DATA_VERDICT)] + verdict: Verdict, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct NfNetlinkList<T> +where + T: NfNetlinkDeserializable + NfNetlinkAttribute + Debug + Clone + Eq + Default, +{ + objs: Vec<T>, +} + +impl<T> NfNetlinkList<T> +where + T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default, +{ + pub fn add_value(&mut self, e: impl Into<T>) { + self.objs.push(e.into()); + } + + pub fn with_value(mut self, e: impl Into<T>) -> Self { + self.add_value(e); + self + } + + pub fn iter<'a>(&'a self) -> impl Iterator<Item = &'a T> { + self.objs.iter() + } +} + +impl<T> NfNetlinkAttribute for NfNetlinkList<T> +where + T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default, +{ + fn is_nested(&self) -> bool { + true + } + + fn get_size(&self) -> usize { + // one nlattr LIST_ELEM per object + self.objs.iter().fold(0, |acc, item| { + acc + item.get_size() + pad_netlink_object::<nlattr>() + }) + } + + unsafe fn write_payload(&self, mut addr: *mut u8) { + for item in &self.objs { + write_attribute(NFTA_LIST_ELEM, item, addr); + addr = addr.offset((pad_netlink_object::<nlattr>() + item.get_size()) as isize); + } + } +} + +impl<T> NfNetlinkDeserializable for NfNetlinkList<T> +where + T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default, +{ + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let mut objs = Vec::new(); + + let mut pos = 0; + while buf.len() - pos > pad_netlink_object::<nlattr>() { + let nlattr = unsafe { *transmute::<*const u8, *const nlattr>(buf[pos..].as_ptr()) }; + // ignore the byteorder and nested attributes + let nla_type = nlattr.nla_type & NLA_TYPE_MASK as u16; + + if nla_type != NFTA_LIST_ELEM { + return Err(DecodeError::UnsupportedAttributeType(nla_type)); + } + + let (obj, remaining) = T::deserialize( + &buf[pos + pad_netlink_object::<nlattr>()..pos + nlattr.nla_len as usize], + )?; + if remaining.len() != 0 { + return Err(DecodeError::InvalidDataSize); + } + objs.push(obj); + + pos += pad_netlink_object_with_variable_size(nlattr.nla_len as usize); + } + + if pos != buf.len() { + Err(DecodeError::InvalidDataSize) + } else { + Ok((Self { objs }, &[])) + } + } +} + +impl<O, T> From<Vec<O>> for NfNetlinkList<T> +where + T: From<O>, + T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default, +{ + fn from(v: Vec<O>) -> Self { + NfNetlinkList { + objs: v.into_iter().map(T::from).collect(), + } + } +} + +impl<T> NfNetlinkDeserializable for T +where + T: NfNetlinkObject + Parsable, +{ + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let (mut obj, nfgenmsg, remaining_data) = Self::parse_object( + buf, + <T as NfNetlinkObject>::MSG_TYPE_ADD, + <T as NfNetlinkObject>::MSG_TYPE_DEL, + )?; + obj.set_family(ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?); + + Ok((obj, remaining_data)) + } +} |