diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/batch.rs | 1 | ||||
-rw-r--r-- | src/chain.rs | 11 | ||||
-rw-r--r-- | src/error.rs | 6 | ||||
-rw-r--r-- | src/expr/mod.rs | 4 | ||||
-rw-r--r-- | src/expr/verdict.rs | 7 | ||||
-rw-r--r-- | src/nlmsg.rs | 19 | ||||
-rw-r--r-- | src/parser.rs | 77 | ||||
-rw-r--r-- | src/parser_impls.rs | 50 | ||||
-rw-r--r-- | src/query.rs | 14 | ||||
-rw-r--r-- | src/table.rs | 6 |
10 files changed, 85 insertions, 110 deletions
diff --git a/src/batch.rs b/src/batch.rs index b5c88b8..980194b 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -33,6 +33,7 @@ impl Batch { pub fn new() -> Self { // TODO: use a pinned Box ? let mut buf = Box::new(Vec::with_capacity(default_batch_page_size() as usize)); + // Safe because we hold onto the buffer for as long as `writer` exists let mut writer = NfNetlinkWriter::new(unsafe { std::mem::transmute(Box::as_mut(&mut buf) as *mut Vec<u8>) }); diff --git a/src/chain.rs b/src/chain.rs index 37e4cb3..53ac595 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -5,8 +5,7 @@ use crate::error::{DecodeError, QueryError}; use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject}; use crate::sys::{ NFTA_CHAIN_FLAGS, NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_POLICY, NFTA_CHAIN_TABLE, - NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, - NFT_MSG_NEWCHAIN, + NFTA_CHAIN_TYPE, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN, }; use crate::{Batch, ProtocolFamily, Table}; use std::fmt::Debug; @@ -63,7 +62,7 @@ impl NfNetlinkAttribute for ChainPolicy { (*self as i32).get_size() } - unsafe fn write_payload(&self, addr: *mut u8) { + fn write_payload(&self, addr: &mut [u8]) { (*self as i32).write_payload(addr); } } @@ -111,7 +110,7 @@ impl NfNetlinkAttribute for ChainType { self.as_str().len() } - unsafe fn write_payload(&self, addr: *mut u8) { + fn write_payload(&self, addr: &mut [u8]) { self.as_str().to_string().write_payload(addr); } } @@ -135,8 +134,8 @@ impl NfNetlinkDeserializable for ChainType { /// /// [`Table`]: struct.Table.html /// [`Rule`]: struct.Rule.html -#[derive(PartialEq, Eq, Default, Debug)] #[nfnetlink_struct(derive_deserialize = false)] +#[derive(PartialEq, Eq, Default, Debug)] pub struct Chain { family: ProtocolFamily, #[field(NFTA_CHAIN_TABLE)] @@ -151,7 +150,7 @@ pub struct Chain { chain_type: ChainType, #[field(NFTA_CHAIN_FLAGS)] flags: u32, - #[field(NFTA_CHAIN_USERDATA)] + #[field(optional = true, crate::sys::NFTA_CHAIN_USERDATA)] userdata: Vec<u8>, } diff --git a/src/error.rs b/src/error.rs index f6b6247..80f06d7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -111,9 +111,6 @@ pub enum DecodeError { #[error("Invalid value for a protocol family")] UnknownProtocolFamily(i32), - - #[error("A custom error occured")] - Custom(Box<dyn std::error::Error + 'static>), } #[derive(thiserror::Error, Debug)] @@ -157,9 +154,6 @@ pub enum QueryError { #[error("Error received from the kernel")] NetlinkError(nlmsgerr), - #[error("Custom error when customizing the query")] - InitError(#[from] Box<dyn std::error::Error + Send + 'static>), - #[error("Couldn't allocate a netlink object, out of memory ?")] NetlinkAllocationFailed, diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 058b0cb..af29460 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -101,7 +101,7 @@ macro_rules! create_expr_variant { } } - unsafe fn write_payload(&self, addr: *mut u8) { + fn write_payload(&self, addr: &mut [u8]) { match self { $( $enum::$name(val) => val.write_payload(addr), @@ -194,7 +194,7 @@ impl NfNetlinkAttribute for ExpressionRaw { self.0.get_size() } - unsafe fn write_payload(&self, addr: *mut u8) { + fn write_payload(&self, addr: &mut [u8]) { self.0.write_payload(addr); } } diff --git a/src/expr/verdict.rs b/src/expr/verdict.rs index 7edf7cd..c42ad32 100644 --- a/src/expr/verdict.rs +++ b/src/expr/verdict.rs @@ -4,8 +4,7 @@ use libc::{NF_ACCEPT, NF_DROP, NF_QUEUE}; use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; use crate::sys::{ - NFTA_VERDICT_CHAIN, NFTA_VERDICT_CHAIN_ID, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE, - NFT_GOTO, NFT_JUMP, NFT_RETURN, + NFTA_VERDICT_CHAIN, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE, NFT_GOTO, NFT_JUMP, NFT_RETURN, }; #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] @@ -21,14 +20,14 @@ pub enum VerdictType { Return = NFT_RETURN, } -#[derive(Clone, PartialEq, Eq, Default, Debug)] #[nfnetlink_struct(nested = true)] +#[derive(Clone, PartialEq, Eq, Default, Debug)] pub struct Verdict { #[field(NFTA_VERDICT_CODE)] code: VerdictType, #[field(NFTA_VERDICT_CHAIN)] chain: String, - #[field(NFTA_VERDICT_CHAIN_ID)] + #[field(optional = true, crate::sys::NFTA_VERDICT_CHAIN_ID)] chain_id: u32, } diff --git a/src/nlmsg.rs b/src/nlmsg.rs index 1c5b519..b8fa857 100644 --- a/src/nlmsg.rs +++ b/src/nlmsg.rs @@ -39,6 +39,8 @@ pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 { pub struct NfNetlinkWriter<'a> { buf: &'a mut Vec<u8>, + // hold the position of the nlmsghdr and nfgenmsg structures for the object currently being + // written headers: Option<(usize, usize)>, } @@ -52,6 +54,7 @@ impl<'a> NfNetlinkWriter<'a> { let start = self.buf.len(); self.buf.resize(start + padded_size, 0); + // if we are *inside* an object begin written, extend the netlink object size if let Some((msghdr_idx, _nfgenmsg_idx)) = self.headers { let mut hdr: &mut nlmsghdr = unsafe { std::mem::transmute(self.buf[msghdr_idx..].as_mut_ptr() as *mut nlmsghdr) @@ -78,6 +81,7 @@ impl<'a> NfNetlinkWriter<'a> { let nlmsghdr_len = pad_netlink_object::<nlmsghdr>(); let nfgenmsg_len = pad_netlink_object::<nfgenmsg>(); + // serialize the nlmsghdr let nlmsghdr_buf = self.add_data_zeroed(nlmsghdr_len); let mut hdr: &mut nlmsghdr = unsafe { std::mem::transmute(nlmsghdr_buf.as_mut_ptr() as *mut nlmsghdr) }; @@ -90,6 +94,7 @@ impl<'a> NfNetlinkWriter<'a> { hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16 | flags; hdr.nlmsg_seq = seq; + // serialize the nfgenmsg let nfgenmsg_buf = self.add_data_zeroed(nfgenmsg_len); let mut nfgenmsg: &mut nfgenmsg = unsafe { std::mem::transmute(nfgenmsg_buf.as_mut_ptr() as *mut nfgenmsg) }; @@ -108,8 +113,10 @@ impl<'a> NfNetlinkWriter<'a> { } } +pub type NetlinkType = u16; + pub trait AttributeDecoder { - fn decode_attribute(&mut self, attr_type: u16, buf: &[u8]) -> Result<(), DecodeError>; + fn decode_attribute(&mut self, attr_type: NetlinkType, buf: &[u8]) -> Result<(), DecodeError>; } pub trait NfNetlinkDeserializable: Sized { @@ -139,9 +146,7 @@ pub trait NfNetlinkObject: None, ); let buf = writer.add_data_zeroed(self.get_size()); - unsafe { - self.write_payload(buf.as_mut_ptr()); - } + self.write_payload(buf); writer.finalize_writing_object(); } @@ -165,8 +170,6 @@ pub trait NfNetlinkObject: } } -pub type NetlinkType = u16; - pub trait NfNetlinkAttribute: Debug + Sized { // is it a nested argument that must be marked with a NLA_F_NESTED flag? fn is_nested(&self) -> bool { @@ -177,6 +180,6 @@ pub trait NfNetlinkAttribute: Debug + Sized { size_of::<Self>() } - // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr, self.get_size()); - unsafe fn write_payload(&self, addr: *mut u8); + // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr.as_mut_ptr(), self.get_size()); + fn write_payload(&self, addr: &mut [u8]); } diff --git a/src/parser.rs b/src/parser.rs index 6ea34c1..82dd27e 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -105,14 +105,10 @@ pub fn parse_nlmsg<'a>(buf: &'a [u8]) -> Result<(nlmsghdr, NlMsg<'a>), DecodeErr /// Write the attribute, preceded by a `libc::nlattr` // rewrite of `mnl_attr_put` -pub unsafe fn write_attribute<'a>( - ty: NetlinkType, - obj: &impl NfNetlinkAttribute, - mut buf: *mut u8, -) { - let header_len = pad_netlink_object::<libc::nlattr>(); +pub fn write_attribute<'a>(ty: NetlinkType, obj: &impl NfNetlinkAttribute, mut buf: &mut [u8]) { + let header_len = pad_netlink_object::<nlattr>(); // copy the header - *(buf as *mut nlattr) = nlattr { + let header = nlattr { // nla_len contains the header size + the unpadded attribute length nla_len: (header_len + obj.get_size() as usize) as u16, nla_type: if obj.is_nested() { @@ -121,7 +117,12 @@ pub unsafe fn write_attribute<'a>( ty }, }; - buf = buf.offset(pad_netlink_object::<nlattr>() as isize); + + unsafe { + *(buf.as_mut_ptr() as *mut nlattr) = header; + } + + buf = &mut buf[header_len..]; // copy the attribute data itself obj.write_payload(buf); } @@ -169,48 +170,30 @@ pub trait InnerFormat { ) -> Result<DebugStruct<'a, 'b>, std::fmt::Error>; } -pub trait Parsable -where - Self: Sized, -{ - fn parse_object( - buf: &[u8], - add_obj: u32, - del_obj: u32, - ) -> Result<(Self, nfgenmsg, &[u8]), DecodeError>; -} +pub(crate) fn parse_object<T: AttributeDecoder + Default + Sized>( + buf: &[u8], + add_obj: u32, + del_obj: u32, +) -> Result<(T, nfgenmsg, &[u8]), DecodeError> { + debug!("parse_object() started"); + let (hdr, msg) = parse_nlmsg(buf)?; -impl<T> Parsable for T -where - T: AttributeDecoder + Default + Sized, -{ - fn parse_object( - buf: &[u8], - add_obj: u32, - del_obj: u32, - ) -> Result<(Self, nfgenmsg, &[u8]), DecodeError> { - debug!("parse_object() started"); - let (hdr, msg) = parse_nlmsg(buf)?; - - let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as u32; - - if op != add_obj && op != del_obj { - return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)); - } + let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as u32; - let obj_size = hdr.nlmsg_len as usize - - pad_netlink_object_with_variable_size(size_of::<nlmsghdr>() + size_of::<nfgenmsg>()); + if op != add_obj && op != del_obj { + return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)); + } - let remaining_data_offset = pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize); - let remaining_data = &buf[remaining_data_offset..]; + let obj_size = hdr.nlmsg_len as usize + - pad_netlink_object_with_variable_size(size_of::<nlmsghdr>() + size_of::<nfgenmsg>()); - let (nfgenmsg, res) = match msg { - NlMsg::NfGenMsg(nfgenmsg, content) => { - (nfgenmsg, read_attributes(&content[..obj_size])?) - } - _ => return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)), - }; + let remaining_data_offset = pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize); + let remaining_data = &buf[remaining_data_offset..]; - Ok((res, nfgenmsg, remaining_data)) - } + let (nfgenmsg, res) = match msg { + NlMsg::NfGenMsg(nfgenmsg, content) => (nfgenmsg, read_attributes(&content[..obj_size])?), + _ => return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)), + }; + + Ok((res, nfgenmsg, remaining_data)) } diff --git a/src/parser_impls.rs b/src/parser_impls.rs index b2681bb..c49c876 100644 --- a/src/parser_impls.rs +++ b/src/parser_impls.rs @@ -1,4 +1,7 @@ -use std::{fmt::Debug, mem::transmute}; +use std::{ + fmt::Debug, + mem::{size_of, transmute}, +}; use rustables_macros::nfnetlink_struct; @@ -6,17 +9,17 @@ use crate::{ error::DecodeError, expr::Verdict, nlmsg::{ - pad_netlink_object, pad_netlink_object_with_variable_size, NfNetlinkAttribute, - NfNetlinkDeserializable, NfNetlinkObject, + pad_netlink_object, pad_netlink_object_with_variable_size, AttributeDecoder, + NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject, }, - parser::{write_attribute, Parsable}, + parser::{parse_object, write_attribute}, sys::{nlattr, NFTA_DATA_VALUE, NFTA_DATA_VERDICT, NFTA_LIST_ELEM, NLA_TYPE_MASK}, ProtocolFamily, }; impl NfNetlinkAttribute for u8 { - unsafe fn write_payload(&self, addr: *mut u8) { - *addr = *self; + fn write_payload(&self, addr: &mut [u8]) { + addr[0] = *self; } } @@ -27,8 +30,8 @@ impl NfNetlinkDeserializable for u8 { } impl NfNetlinkAttribute for u16 { - unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = self.to_be(); + fn write_payload(&self, addr: &mut [u8]) { + addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes()); } } @@ -39,8 +42,8 @@ impl NfNetlinkDeserializable for u16 { } impl NfNetlinkAttribute for i32 { - unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = self.to_be(); + fn write_payload(&self, addr: &mut [u8]) { + addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes()); } } @@ -54,8 +57,8 @@ impl NfNetlinkDeserializable for i32 { } impl NfNetlinkAttribute for u32 { - unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = self.to_be(); + fn write_payload(&self, addr: &mut [u8]) { + addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes()); } } @@ -69,8 +72,8 @@ impl NfNetlinkDeserializable for u32 { } impl NfNetlinkAttribute for u64 { - unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = self.to_be(); + fn write_payload(&self, addr: &mut [u8]) { + addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes()); } } @@ -90,8 +93,8 @@ impl NfNetlinkAttribute for String { self.len() } - unsafe fn write_payload(&self, addr: *mut u8) { - std::ptr::copy_nonoverlapping(self.as_bytes().as_ptr(), addr, self.len()); + fn write_payload(&self, addr: &mut [u8]) { + addr[0..self.len()].copy_from_slice(&self.as_bytes()); } } @@ -110,8 +113,8 @@ impl NfNetlinkAttribute for Vec<u8> { self.len() } - unsafe fn write_payload(&self, addr: *mut u8) { - std::ptr::copy_nonoverlapping(self.as_ptr(), addr, self.len()); + fn write_payload(&self, addr: &mut [u8]) { + addr[0..self.len()].copy_from_slice(&self.as_slice()); } } @@ -170,10 +173,11 @@ where }) } - unsafe fn write_payload(&self, mut addr: *mut u8) { + fn write_payload(&self, mut addr: &mut [u8]) { for item in &self.objs { write_attribute(NFTA_LIST_ELEM, item, addr); - addr = addr.offset((pad_netlink_object::<nlattr>() + item.get_size()) as isize); + let offset = pad_netlink_object::<nlattr>() + item.get_size(); + addr = &mut addr[offset..]; } } } @@ -228,10 +232,10 @@ where impl<T> NfNetlinkDeserializable for T where - T: NfNetlinkObject + Parsable, + T: NfNetlinkObject + AttributeDecoder + Default + Sized, { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - let (mut obj, nfgenmsg, remaining_data) = Self::parse_object( + fn deserialize(buf: &[u8]) -> Result<(T, &[u8]), DecodeError> { + let (mut obj, nfgenmsg, remaining_data) = parse_object::<T>( buf, <T as NfNetlinkObject>::MSG_TYPE_ADD, <T as NfNetlinkObject>::MSG_TYPE_DEL, diff --git a/src/query.rs b/src/query.rs index 7cf5050..3548d2a 100644 --- a/src/query.rs +++ b/src/query.rs @@ -59,7 +59,7 @@ pub(crate) fn recv_and_process<'a, T>( } // we cannot know when a sequence of messages will end if the messages do not end - // with an NlMsg::Done marker while if a maximum sequence number wasn't specified + // with an NlMsg::Done marker if a maximum sequence number wasn't specified if max_seq.is_none() && nlmsghdr.nlmsg_flags & NLM_F_MULTI as u16 == 0 { return Err(QueryError::UndecidableMessageTermination); } @@ -79,13 +79,7 @@ pub(crate) fn recv_and_process<'a, T>( // We achieve this by relocating the buffer content at the beginning of the buffer if end_pos >= nft_nlmsg_maxsize() as usize { if buf_start < end_pos { - unsafe { - std::ptr::copy( - msg_buffer[buf_start..end_pos].as_ptr(), - msg_buffer.as_mut_ptr(), - end_pos - buf_start, - ); - } + msg_buffer.copy_within(buf_start..end_pos, 0); } end_pos = end_pos - buf_start; buf_start = 0; @@ -128,9 +122,7 @@ pub fn get_list_of_objects<T: NfNetlinkAttribute>( ); if let Some(filter) = filter { let buf = writer.add_data_zeroed(filter.get_size()); - unsafe { - filter.write_payload(buf.as_mut_ptr()); - } + filter.write_payload(buf); } writer.finalize_writing_object(); Ok(buffer) diff --git a/src/table.rs b/src/table.rs index 81a26ef..1d19abe 100644 --- a/src/table.rs +++ b/src/table.rs @@ -5,7 +5,7 @@ use rustables_macros::nfnetlink_struct; use crate::error::QueryError; use crate::nlmsg::NfNetlinkObject; use crate::sys::{ - NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFTA_TABLE_USERDATA, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, + NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, NFT_MSG_NEWTABLE, }; use crate::{Batch, ProtocolFamily}; @@ -14,15 +14,15 @@ use crate::{Batch, ProtocolFamily}; /// family and contains [`Chain`]s that in turn hold the rules. /// /// [`Chain`]: struct.Chain.html -#[derive(Default, PartialEq, Eq, Debug)] #[nfnetlink_struct(derive_deserialize = false)] +#[derive(Default, PartialEq, Eq, Debug)] pub struct Table { family: ProtocolFamily, #[field(NFTA_TABLE_NAME)] name: String, #[field(NFTA_TABLE_FLAGS)] flags: u32, - #[field(NFTA_TABLE_USERDATA)] + #[field(optional = true, crate::sys::NFTA_TABLE_USERDATA)] userdata: Vec<u8>, } |