diff options
-rw-r--r-- | Cargo.toml | 2 | ||||
-rw-r--r-- | macros/src/lib.rs | 6 | ||||
-rw-r--r-- | src/batch.rs | 1 | ||||
-rw-r--r-- | src/chain.rs | 4 | ||||
-rw-r--r-- | src/expr/mod.rs | 4 | ||||
-rw-r--r-- | src/nlmsg.rs | 8 | ||||
-rw-r--r-- | src/parser.rs | 17 | ||||
-rw-r--r-- | src/parser_impls.rs | 38 | ||||
-rw-r--r-- | src/query.rs | 14 |
9 files changed, 45 insertions, 49 deletions
@@ -20,7 +20,7 @@ log = "0.4" libc = "0.2.43" nix = "0.23" ipnetwork = { version = "0.20", default-features = false } -rustables-macros = "0.1.0" +rustables-macros = { path = "macros" } [dev-dependencies] env_logger = "0.9" diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 39f0d01..bfb5099 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -276,7 +276,7 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { { let size = crate::nlmsg::pad_netlink_object::<crate::sys::nlattr>() + crate::nlmsg::pad_netlink_object_with_variable_size(val.get_size()); - addr = addr.offset(size as isize); + addr = &mut addr[size..]; } } ) @@ -296,7 +296,7 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { size } - unsafe fn write_payload(&self, mut addr: *mut u8) { + fn write_payload(&self, mut addr: &mut [u8]) { use crate::nlmsg::NfNetlinkAttribute; #(#write_entries) * @@ -483,7 +483,7 @@ pub fn nfnetlink_enum(attrs: TokenStream, item: TokenStream) -> TokenStream { (*self as #repr_type).get_size() } - unsafe fn write_payload(&self, addr: *mut u8) { + fn write_payload(&self, addr: &mut [u8]) { (*self as #repr_type).write_payload(addr); } } diff --git a/src/batch.rs b/src/batch.rs index b5c88b8..980194b 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -33,6 +33,7 @@ impl Batch { pub fn new() -> Self { // TODO: use a pinned Box ? let mut buf = Box::new(Vec::with_capacity(default_batch_page_size() as usize)); + // Safe because we hold onto the buffer for as long as `writer` exists let mut writer = NfNetlinkWriter::new(unsafe { std::mem::transmute(Box::as_mut(&mut buf) as *mut Vec<u8>) }); diff --git a/src/chain.rs b/src/chain.rs index 37e4cb3..7d365a1 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -63,7 +63,7 @@ impl NfNetlinkAttribute for ChainPolicy { (*self as i32).get_size() } - unsafe fn write_payload(&self, addr: *mut u8) { + fn write_payload(&self, addr: &mut [u8]) { (*self as i32).write_payload(addr); } } @@ -111,7 +111,7 @@ impl NfNetlinkAttribute for ChainType { self.as_str().len() } - unsafe fn write_payload(&self, addr: *mut u8) { + fn write_payload(&self, addr: &mut [u8]) { self.as_str().to_string().write_payload(addr); } } diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 058b0cb..af29460 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -101,7 +101,7 @@ macro_rules! create_expr_variant { } } - unsafe fn write_payload(&self, addr: *mut u8) { + fn write_payload(&self, addr: &mut [u8]) { match self { $( $enum::$name(val) => val.write_payload(addr), @@ -194,7 +194,7 @@ impl NfNetlinkAttribute for ExpressionRaw { self.0.get_size() } - unsafe fn write_payload(&self, addr: *mut u8) { + fn write_payload(&self, addr: &mut [u8]) { self.0.write_payload(addr); } } diff --git a/src/nlmsg.rs b/src/nlmsg.rs index 1c5b519..0e32588 100644 --- a/src/nlmsg.rs +++ b/src/nlmsg.rs @@ -139,9 +139,7 @@ pub trait NfNetlinkObject: None, ); let buf = writer.add_data_zeroed(self.get_size()); - unsafe { - self.write_payload(buf.as_mut_ptr()); - } + self.write_payload(buf); writer.finalize_writing_object(); } @@ -177,6 +175,6 @@ pub trait NfNetlinkAttribute: Debug + Sized { size_of::<Self>() } - // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr, self.get_size()); - unsafe fn write_payload(&self, addr: *mut u8); + // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr.as_mut_ptr(), self.get_size()); + fn write_payload(&self, addr: &mut [u8]); } diff --git a/src/parser.rs b/src/parser.rs index 6ea34c1..c8667e3 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -105,14 +105,10 @@ pub fn parse_nlmsg<'a>(buf: &'a [u8]) -> Result<(nlmsghdr, NlMsg<'a>), DecodeErr /// Write the attribute, preceded by a `libc::nlattr` // rewrite of `mnl_attr_put` -pub unsafe fn write_attribute<'a>( - ty: NetlinkType, - obj: &impl NfNetlinkAttribute, - mut buf: *mut u8, -) { - let header_len = pad_netlink_object::<libc::nlattr>(); +pub fn write_attribute<'a>(ty: NetlinkType, obj: &impl NfNetlinkAttribute, mut buf: &mut [u8]) { + let header_len = pad_netlink_object::<nlattr>(); // copy the header - *(buf as *mut nlattr) = nlattr { + let header = nlattr { // nla_len contains the header size + the unpadded attribute length nla_len: (header_len + obj.get_size() as usize) as u16, nla_type: if obj.is_nested() { @@ -121,7 +117,12 @@ pub unsafe fn write_attribute<'a>( ty }, }; - buf = buf.offset(pad_netlink_object::<nlattr>() as isize); + + unsafe { + *(buf.as_mut_ptr() as *mut nlattr) = header; + } + + buf = &mut buf[header_len..]; // copy the attribute data itself obj.write_payload(buf); } diff --git a/src/parser_impls.rs b/src/parser_impls.rs index b2681bb..887a9a2 100644 --- a/src/parser_impls.rs +++ b/src/parser_impls.rs @@ -1,4 +1,7 @@ -use std::{fmt::Debug, mem::transmute}; +use std::{ + fmt::Debug, + mem::{size_of, transmute}, +}; use rustables_macros::nfnetlink_struct; @@ -15,8 +18,8 @@ use crate::{ }; impl NfNetlinkAttribute for u8 { - unsafe fn write_payload(&self, addr: *mut u8) { - *addr = *self; + fn write_payload(&self, addr: &mut [u8]) { + addr[0] = *self; } } @@ -27,8 +30,8 @@ impl NfNetlinkDeserializable for u8 { } impl NfNetlinkAttribute for u16 { - unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = self.to_be(); + fn write_payload(&self, addr: &mut [u8]) { + addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes()); } } @@ -39,8 +42,8 @@ impl NfNetlinkDeserializable for u16 { } impl NfNetlinkAttribute for i32 { - unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = self.to_be(); + fn write_payload(&self, addr: &mut [u8]) { + addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes()); } } @@ -54,8 +57,8 @@ impl NfNetlinkDeserializable for i32 { } impl NfNetlinkAttribute for u32 { - unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = self.to_be(); + fn write_payload(&self, addr: &mut [u8]) { + addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes()); } } @@ -69,8 +72,8 @@ impl NfNetlinkDeserializable for u32 { } impl NfNetlinkAttribute for u64 { - unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = self.to_be(); + fn write_payload(&self, addr: &mut [u8]) { + addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes()); } } @@ -90,8 +93,8 @@ impl NfNetlinkAttribute for String { self.len() } - unsafe fn write_payload(&self, addr: *mut u8) { - std::ptr::copy_nonoverlapping(self.as_bytes().as_ptr(), addr, self.len()); + fn write_payload(&self, addr: &mut [u8]) { + addr[0..self.len()].copy_from_slice(&self.as_bytes()); } } @@ -110,8 +113,8 @@ impl NfNetlinkAttribute for Vec<u8> { self.len() } - unsafe fn write_payload(&self, addr: *mut u8) { - std::ptr::copy_nonoverlapping(self.as_ptr(), addr, self.len()); + fn write_payload(&self, addr: &mut [u8]) { + addr[0..self.len()].copy_from_slice(&self.as_slice()); } } @@ -170,10 +173,11 @@ where }) } - unsafe fn write_payload(&self, mut addr: *mut u8) { + fn write_payload(&self, mut addr: &mut [u8]) { for item in &self.objs { write_attribute(NFTA_LIST_ELEM, item, addr); - addr = addr.offset((pad_netlink_object::<nlattr>() + item.get_size()) as isize); + let offset = pad_netlink_object::<nlattr>() + item.get_size(); + addr = &mut addr[offset..]; } } } diff --git a/src/query.rs b/src/query.rs index 7cf5050..3548d2a 100644 --- a/src/query.rs +++ b/src/query.rs @@ -59,7 +59,7 @@ pub(crate) fn recv_and_process<'a, T>( } // we cannot know when a sequence of messages will end if the messages do not end - // with an NlMsg::Done marker while if a maximum sequence number wasn't specified + // with an NlMsg::Done marker if a maximum sequence number wasn't specified if max_seq.is_none() && nlmsghdr.nlmsg_flags & NLM_F_MULTI as u16 == 0 { return Err(QueryError::UndecidableMessageTermination); } @@ -79,13 +79,7 @@ pub(crate) fn recv_and_process<'a, T>( // We achieve this by relocating the buffer content at the beginning of the buffer if end_pos >= nft_nlmsg_maxsize() as usize { if buf_start < end_pos { - unsafe { - std::ptr::copy( - msg_buffer[buf_start..end_pos].as_ptr(), - msg_buffer.as_mut_ptr(), - end_pos - buf_start, - ); - } + msg_buffer.copy_within(buf_start..end_pos, 0); } end_pos = end_pos - buf_start; buf_start = 0; @@ -128,9 +122,7 @@ pub fn get_list_of_objects<T: NfNetlinkAttribute>( ); if let Some(filter) = filter { let buf = writer.add_data_zeroed(filter.get_size()); - unsafe { - filter.write_payload(buf.as_mut_ptr()); - } + filter.write_payload(buf); } writer.finalize_writing_object(); Ok(buffer) |