aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/batch.rs1
-rw-r--r--src/chain.rs11
-rw-r--r--src/error.rs6
-rw-r--r--src/expr/mod.rs4
-rw-r--r--src/expr/verdict.rs7
-rw-r--r--src/nlmsg.rs19
-rw-r--r--src/parser.rs77
-rw-r--r--src/parser_impls.rs50
-rw-r--r--src/query.rs14
-rw-r--r--src/table.rs6
10 files changed, 85 insertions, 110 deletions
diff --git a/src/batch.rs b/src/batch.rs
index b5c88b8..980194b 100644
--- a/src/batch.rs
+++ b/src/batch.rs
@@ -33,6 +33,7 @@ impl Batch {
pub fn new() -> Self {
// TODO: use a pinned Box ?
let mut buf = Box::new(Vec::with_capacity(default_batch_page_size() as usize));
+ // Safe because we hold onto the buffer for as long as `writer` exists
let mut writer = NfNetlinkWriter::new(unsafe {
std::mem::transmute(Box::as_mut(&mut buf) as *mut Vec<u8>)
});
diff --git a/src/chain.rs b/src/chain.rs
index 37e4cb3..53ac595 100644
--- a/src/chain.rs
+++ b/src/chain.rs
@@ -5,8 +5,7 @@ use crate::error::{DecodeError, QueryError};
use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject};
use crate::sys::{
NFTA_CHAIN_FLAGS, NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_POLICY, NFTA_CHAIN_TABLE,
- NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN,
- NFT_MSG_NEWCHAIN,
+ NFTA_CHAIN_TYPE, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN,
};
use crate::{Batch, ProtocolFamily, Table};
use std::fmt::Debug;
@@ -63,7 +62,7 @@ impl NfNetlinkAttribute for ChainPolicy {
(*self as i32).get_size()
}
- unsafe fn write_payload(&self, addr: *mut u8) {
+ fn write_payload(&self, addr: &mut [u8]) {
(*self as i32).write_payload(addr);
}
}
@@ -111,7 +110,7 @@ impl NfNetlinkAttribute for ChainType {
self.as_str().len()
}
- unsafe fn write_payload(&self, addr: *mut u8) {
+ fn write_payload(&self, addr: &mut [u8]) {
self.as_str().to_string().write_payload(addr);
}
}
@@ -135,8 +134,8 @@ impl NfNetlinkDeserializable for ChainType {
///
/// [`Table`]: struct.Table.html
/// [`Rule`]: struct.Rule.html
-#[derive(PartialEq, Eq, Default, Debug)]
#[nfnetlink_struct(derive_deserialize = false)]
+#[derive(PartialEq, Eq, Default, Debug)]
pub struct Chain {
family: ProtocolFamily,
#[field(NFTA_CHAIN_TABLE)]
@@ -151,7 +150,7 @@ pub struct Chain {
chain_type: ChainType,
#[field(NFTA_CHAIN_FLAGS)]
flags: u32,
- #[field(NFTA_CHAIN_USERDATA)]
+ #[field(optional = true, crate::sys::NFTA_CHAIN_USERDATA)]
userdata: Vec<u8>,
}
diff --git a/src/error.rs b/src/error.rs
index f6b6247..80f06d7 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -111,9 +111,6 @@ pub enum DecodeError {
#[error("Invalid value for a protocol family")]
UnknownProtocolFamily(i32),
-
- #[error("A custom error occured")]
- Custom(Box<dyn std::error::Error + 'static>),
}
#[derive(thiserror::Error, Debug)]
@@ -157,9 +154,6 @@ pub enum QueryError {
#[error("Error received from the kernel")]
NetlinkError(nlmsgerr),
- #[error("Custom error when customizing the query")]
- InitError(#[from] Box<dyn std::error::Error + Send + 'static>),
-
#[error("Couldn't allocate a netlink object, out of memory ?")]
NetlinkAllocationFailed,
diff --git a/src/expr/mod.rs b/src/expr/mod.rs
index 058b0cb..af29460 100644
--- a/src/expr/mod.rs
+++ b/src/expr/mod.rs
@@ -101,7 +101,7 @@ macro_rules! create_expr_variant {
}
}
- unsafe fn write_payload(&self, addr: *mut u8) {
+ fn write_payload(&self, addr: &mut [u8]) {
match self {
$(
$enum::$name(val) => val.write_payload(addr),
@@ -194,7 +194,7 @@ impl NfNetlinkAttribute for ExpressionRaw {
self.0.get_size()
}
- unsafe fn write_payload(&self, addr: *mut u8) {
+ fn write_payload(&self, addr: &mut [u8]) {
self.0.write_payload(addr);
}
}
diff --git a/src/expr/verdict.rs b/src/expr/verdict.rs
index 7edf7cd..c42ad32 100644
--- a/src/expr/verdict.rs
+++ b/src/expr/verdict.rs
@@ -4,8 +4,7 @@ use libc::{NF_ACCEPT, NF_DROP, NF_QUEUE};
use rustables_macros::{nfnetlink_enum, nfnetlink_struct};
use crate::sys::{
- NFTA_VERDICT_CHAIN, NFTA_VERDICT_CHAIN_ID, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE,
- NFT_GOTO, NFT_JUMP, NFT_RETURN,
+ NFTA_VERDICT_CHAIN, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE, NFT_GOTO, NFT_JUMP, NFT_RETURN,
};
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
@@ -21,14 +20,14 @@ pub enum VerdictType {
Return = NFT_RETURN,
}
-#[derive(Clone, PartialEq, Eq, Default, Debug)]
#[nfnetlink_struct(nested = true)]
+#[derive(Clone, PartialEq, Eq, Default, Debug)]
pub struct Verdict {
#[field(NFTA_VERDICT_CODE)]
code: VerdictType,
#[field(NFTA_VERDICT_CHAIN)]
chain: String,
- #[field(NFTA_VERDICT_CHAIN_ID)]
+ #[field(optional = true, crate::sys::NFTA_VERDICT_CHAIN_ID)]
chain_id: u32,
}
diff --git a/src/nlmsg.rs b/src/nlmsg.rs
index 1c5b519..b8fa857 100644
--- a/src/nlmsg.rs
+++ b/src/nlmsg.rs
@@ -39,6 +39,8 @@ pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 {
pub struct NfNetlinkWriter<'a> {
buf: &'a mut Vec<u8>,
+ // hold the position of the nlmsghdr and nfgenmsg structures for the object currently being
+ // written
headers: Option<(usize, usize)>,
}
@@ -52,6 +54,7 @@ impl<'a> NfNetlinkWriter<'a> {
let start = self.buf.len();
self.buf.resize(start + padded_size, 0);
+ // if we are *inside* an object begin written, extend the netlink object size
if let Some((msghdr_idx, _nfgenmsg_idx)) = self.headers {
let mut hdr: &mut nlmsghdr = unsafe {
std::mem::transmute(self.buf[msghdr_idx..].as_mut_ptr() as *mut nlmsghdr)
@@ -78,6 +81,7 @@ impl<'a> NfNetlinkWriter<'a> {
let nlmsghdr_len = pad_netlink_object::<nlmsghdr>();
let nfgenmsg_len = pad_netlink_object::<nfgenmsg>();
+ // serialize the nlmsghdr
let nlmsghdr_buf = self.add_data_zeroed(nlmsghdr_len);
let mut hdr: &mut nlmsghdr =
unsafe { std::mem::transmute(nlmsghdr_buf.as_mut_ptr() as *mut nlmsghdr) };
@@ -90,6 +94,7 @@ impl<'a> NfNetlinkWriter<'a> {
hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16 | flags;
hdr.nlmsg_seq = seq;
+ // serialize the nfgenmsg
let nfgenmsg_buf = self.add_data_zeroed(nfgenmsg_len);
let mut nfgenmsg: &mut nfgenmsg =
unsafe { std::mem::transmute(nfgenmsg_buf.as_mut_ptr() as *mut nfgenmsg) };
@@ -108,8 +113,10 @@ impl<'a> NfNetlinkWriter<'a> {
}
}
+pub type NetlinkType = u16;
+
pub trait AttributeDecoder {
- fn decode_attribute(&mut self, attr_type: u16, buf: &[u8]) -> Result<(), DecodeError>;
+ fn decode_attribute(&mut self, attr_type: NetlinkType, buf: &[u8]) -> Result<(), DecodeError>;
}
pub trait NfNetlinkDeserializable: Sized {
@@ -139,9 +146,7 @@ pub trait NfNetlinkObject:
None,
);
let buf = writer.add_data_zeroed(self.get_size());
- unsafe {
- self.write_payload(buf.as_mut_ptr());
- }
+ self.write_payload(buf);
writer.finalize_writing_object();
}
@@ -165,8 +170,6 @@ pub trait NfNetlinkObject:
}
}
-pub type NetlinkType = u16;
-
pub trait NfNetlinkAttribute: Debug + Sized {
// is it a nested argument that must be marked with a NLA_F_NESTED flag?
fn is_nested(&self) -> bool {
@@ -177,6 +180,6 @@ pub trait NfNetlinkAttribute: Debug + Sized {
size_of::<Self>()
}
- // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr, self.get_size());
- unsafe fn write_payload(&self, addr: *mut u8);
+ // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr.as_mut_ptr(), self.get_size());
+ fn write_payload(&self, addr: &mut [u8]);
}
diff --git a/src/parser.rs b/src/parser.rs
index 6ea34c1..82dd27e 100644
--- a/src/parser.rs
+++ b/src/parser.rs
@@ -105,14 +105,10 @@ pub fn parse_nlmsg<'a>(buf: &'a [u8]) -> Result<(nlmsghdr, NlMsg<'a>), DecodeErr
/// Write the attribute, preceded by a `libc::nlattr`
// rewrite of `mnl_attr_put`
-pub unsafe fn write_attribute<'a>(
- ty: NetlinkType,
- obj: &impl NfNetlinkAttribute,
- mut buf: *mut u8,
-) {
- let header_len = pad_netlink_object::<libc::nlattr>();
+pub fn write_attribute<'a>(ty: NetlinkType, obj: &impl NfNetlinkAttribute, mut buf: &mut [u8]) {
+ let header_len = pad_netlink_object::<nlattr>();
// copy the header
- *(buf as *mut nlattr) = nlattr {
+ let header = nlattr {
// nla_len contains the header size + the unpadded attribute length
nla_len: (header_len + obj.get_size() as usize) as u16,
nla_type: if obj.is_nested() {
@@ -121,7 +117,12 @@ pub unsafe fn write_attribute<'a>(
ty
},
};
- buf = buf.offset(pad_netlink_object::<nlattr>() as isize);
+
+ unsafe {
+ *(buf.as_mut_ptr() as *mut nlattr) = header;
+ }
+
+ buf = &mut buf[header_len..];
// copy the attribute data itself
obj.write_payload(buf);
}
@@ -169,48 +170,30 @@ pub trait InnerFormat {
) -> Result<DebugStruct<'a, 'b>, std::fmt::Error>;
}
-pub trait Parsable
-where
- Self: Sized,
-{
- fn parse_object(
- buf: &[u8],
- add_obj: u32,
- del_obj: u32,
- ) -> Result<(Self, nfgenmsg, &[u8]), DecodeError>;
-}
+pub(crate) fn parse_object<T: AttributeDecoder + Default + Sized>(
+ buf: &[u8],
+ add_obj: u32,
+ del_obj: u32,
+) -> Result<(T, nfgenmsg, &[u8]), DecodeError> {
+ debug!("parse_object() started");
+ let (hdr, msg) = parse_nlmsg(buf)?;
-impl<T> Parsable for T
-where
- T: AttributeDecoder + Default + Sized,
-{
- fn parse_object(
- buf: &[u8],
- add_obj: u32,
- del_obj: u32,
- ) -> Result<(Self, nfgenmsg, &[u8]), DecodeError> {
- debug!("parse_object() started");
- let (hdr, msg) = parse_nlmsg(buf)?;
-
- let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as u32;
-
- if op != add_obj && op != del_obj {
- return Err(DecodeError::UnexpectedType(hdr.nlmsg_type));
- }
+ let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as u32;
- let obj_size = hdr.nlmsg_len as usize
- - pad_netlink_object_with_variable_size(size_of::<nlmsghdr>() + size_of::<nfgenmsg>());
+ if op != add_obj && op != del_obj {
+ return Err(DecodeError::UnexpectedType(hdr.nlmsg_type));
+ }
- let remaining_data_offset = pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize);
- let remaining_data = &buf[remaining_data_offset..];
+ let obj_size = hdr.nlmsg_len as usize
+ - pad_netlink_object_with_variable_size(size_of::<nlmsghdr>() + size_of::<nfgenmsg>());
- let (nfgenmsg, res) = match msg {
- NlMsg::NfGenMsg(nfgenmsg, content) => {
- (nfgenmsg, read_attributes(&content[..obj_size])?)
- }
- _ => return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)),
- };
+ let remaining_data_offset = pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize);
+ let remaining_data = &buf[remaining_data_offset..];
- Ok((res, nfgenmsg, remaining_data))
- }
+ let (nfgenmsg, res) = match msg {
+ NlMsg::NfGenMsg(nfgenmsg, content) => (nfgenmsg, read_attributes(&content[..obj_size])?),
+ _ => return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)),
+ };
+
+ Ok((res, nfgenmsg, remaining_data))
}
diff --git a/src/parser_impls.rs b/src/parser_impls.rs
index b2681bb..c49c876 100644
--- a/src/parser_impls.rs
+++ b/src/parser_impls.rs
@@ -1,4 +1,7 @@
-use std::{fmt::Debug, mem::transmute};
+use std::{
+ fmt::Debug,
+ mem::{size_of, transmute},
+};
use rustables_macros::nfnetlink_struct;
@@ -6,17 +9,17 @@ use crate::{
error::DecodeError,
expr::Verdict,
nlmsg::{
- pad_netlink_object, pad_netlink_object_with_variable_size, NfNetlinkAttribute,
- NfNetlinkDeserializable, NfNetlinkObject,
+ pad_netlink_object, pad_netlink_object_with_variable_size, AttributeDecoder,
+ NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject,
},
- parser::{write_attribute, Parsable},
+ parser::{parse_object, write_attribute},
sys::{nlattr, NFTA_DATA_VALUE, NFTA_DATA_VERDICT, NFTA_LIST_ELEM, NLA_TYPE_MASK},
ProtocolFamily,
};
impl NfNetlinkAttribute for u8 {
- unsafe fn write_payload(&self, addr: *mut u8) {
- *addr = *self;
+ fn write_payload(&self, addr: &mut [u8]) {
+ addr[0] = *self;
}
}
@@ -27,8 +30,8 @@ impl NfNetlinkDeserializable for u8 {
}
impl NfNetlinkAttribute for u16 {
- unsafe fn write_payload(&self, addr: *mut u8) {
- *(addr as *mut Self) = self.to_be();
+ fn write_payload(&self, addr: &mut [u8]) {
+ addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes());
}
}
@@ -39,8 +42,8 @@ impl NfNetlinkDeserializable for u16 {
}
impl NfNetlinkAttribute for i32 {
- unsafe fn write_payload(&self, addr: *mut u8) {
- *(addr as *mut Self) = self.to_be();
+ fn write_payload(&self, addr: &mut [u8]) {
+ addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes());
}
}
@@ -54,8 +57,8 @@ impl NfNetlinkDeserializable for i32 {
}
impl NfNetlinkAttribute for u32 {
- unsafe fn write_payload(&self, addr: *mut u8) {
- *(addr as *mut Self) = self.to_be();
+ fn write_payload(&self, addr: &mut [u8]) {
+ addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes());
}
}
@@ -69,8 +72,8 @@ impl NfNetlinkDeserializable for u32 {
}
impl NfNetlinkAttribute for u64 {
- unsafe fn write_payload(&self, addr: *mut u8) {
- *(addr as *mut Self) = self.to_be();
+ fn write_payload(&self, addr: &mut [u8]) {
+ addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes());
}
}
@@ -90,8 +93,8 @@ impl NfNetlinkAttribute for String {
self.len()
}
- unsafe fn write_payload(&self, addr: *mut u8) {
- std::ptr::copy_nonoverlapping(self.as_bytes().as_ptr(), addr, self.len());
+ fn write_payload(&self, addr: &mut [u8]) {
+ addr[0..self.len()].copy_from_slice(&self.as_bytes());
}
}
@@ -110,8 +113,8 @@ impl NfNetlinkAttribute for Vec<u8> {
self.len()
}
- unsafe fn write_payload(&self, addr: *mut u8) {
- std::ptr::copy_nonoverlapping(self.as_ptr(), addr, self.len());
+ fn write_payload(&self, addr: &mut [u8]) {
+ addr[0..self.len()].copy_from_slice(&self.as_slice());
}
}
@@ -170,10 +173,11 @@ where
})
}
- unsafe fn write_payload(&self, mut addr: *mut u8) {
+ fn write_payload(&self, mut addr: &mut [u8]) {
for item in &self.objs {
write_attribute(NFTA_LIST_ELEM, item, addr);
- addr = addr.offset((pad_netlink_object::<nlattr>() + item.get_size()) as isize);
+ let offset = pad_netlink_object::<nlattr>() + item.get_size();
+ addr = &mut addr[offset..];
}
}
}
@@ -228,10 +232,10 @@ where
impl<T> NfNetlinkDeserializable for T
where
- T: NfNetlinkObject + Parsable,
+ T: NfNetlinkObject + AttributeDecoder + Default + Sized,
{
- fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
- let (mut obj, nfgenmsg, remaining_data) = Self::parse_object(
+ fn deserialize(buf: &[u8]) -> Result<(T, &[u8]), DecodeError> {
+ let (mut obj, nfgenmsg, remaining_data) = parse_object::<T>(
buf,
<T as NfNetlinkObject>::MSG_TYPE_ADD,
<T as NfNetlinkObject>::MSG_TYPE_DEL,
diff --git a/src/query.rs b/src/query.rs
index 7cf5050..3548d2a 100644
--- a/src/query.rs
+++ b/src/query.rs
@@ -59,7 +59,7 @@ pub(crate) fn recv_and_process<'a, T>(
}
// we cannot know when a sequence of messages will end if the messages do not end
- // with an NlMsg::Done marker while if a maximum sequence number wasn't specified
+ // with an NlMsg::Done marker if a maximum sequence number wasn't specified
if max_seq.is_none() && nlmsghdr.nlmsg_flags & NLM_F_MULTI as u16 == 0 {
return Err(QueryError::UndecidableMessageTermination);
}
@@ -79,13 +79,7 @@ pub(crate) fn recv_and_process<'a, T>(
// We achieve this by relocating the buffer content at the beginning of the buffer
if end_pos >= nft_nlmsg_maxsize() as usize {
if buf_start < end_pos {
- unsafe {
- std::ptr::copy(
- msg_buffer[buf_start..end_pos].as_ptr(),
- msg_buffer.as_mut_ptr(),
- end_pos - buf_start,
- );
- }
+ msg_buffer.copy_within(buf_start..end_pos, 0);
}
end_pos = end_pos - buf_start;
buf_start = 0;
@@ -128,9 +122,7 @@ pub fn get_list_of_objects<T: NfNetlinkAttribute>(
);
if let Some(filter) = filter {
let buf = writer.add_data_zeroed(filter.get_size());
- unsafe {
- filter.write_payload(buf.as_mut_ptr());
- }
+ filter.write_payload(buf);
}
writer.finalize_writing_object();
Ok(buffer)
diff --git a/src/table.rs b/src/table.rs
index 81a26ef..1d19abe 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -5,7 +5,7 @@ use rustables_macros::nfnetlink_struct;
use crate::error::QueryError;
use crate::nlmsg::NfNetlinkObject;
use crate::sys::{
- NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFTA_TABLE_USERDATA, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE,
+ NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE,
NFT_MSG_NEWTABLE,
};
use crate::{Batch, ProtocolFamily};
@@ -14,15 +14,15 @@ use crate::{Batch, ProtocolFamily};
/// family and contains [`Chain`]s that in turn hold the rules.
///
/// [`Chain`]: struct.Chain.html
-#[derive(Default, PartialEq, Eq, Debug)]
#[nfnetlink_struct(derive_deserialize = false)]
+#[derive(Default, PartialEq, Eq, Debug)]
pub struct Table {
family: ProtocolFamily,
#[field(NFTA_TABLE_NAME)]
name: String,
#[field(NFTA_TABLE_FLAGS)]
flags: u32,
- #[field(NFTA_TABLE_USERDATA)]
+ #[field(optional = true, crate::sys::NFTA_TABLE_USERDATA)]
userdata: Vec<u8>,
}