aboutsummaryrefslogtreecommitdiff
path: root/src/batch.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/batch.rs')
-rw-r--r--src/batch.rs237
1 files changed, 82 insertions, 155 deletions
diff --git a/src/batch.rs b/src/batch.rs
index 198e8d0..b5c88b8 100644
--- a/src/batch.rs
+++ b/src/batch.rs
@@ -1,31 +1,29 @@
-use crate::{MsgType, NlMsg};
-use crate::sys::{self as sys, libc};
-use std::ffi::c_void;
-use std::os::raw::c_char;
-use std::ptr;
+use libc;
+
use thiserror::Error;
+use crate::error::QueryError;
+use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter};
+use crate::sys::NFNL_SUBSYS_NFTABLES;
+use crate::{MsgType, ProtocolFamily};
+
+use nix::sys::socket::{
+ self, AddressFamily, MsgFlags, NetlinkAddr, SockAddr, SockFlag, SockProtocol, SockType,
+};
+
/// Error while communicating with netlink.
#[derive(Error, Debug)]
#[error("Error while communicating with netlink")]
pub struct NetlinkError(());
-#[cfg(feature = "query")]
-/// Check if the kernel supports batched netlink messages to netfilter.
-pub fn batch_is_supported() -> std::result::Result<bool, NetlinkError> {
- match unsafe { sys::nftnl_batch_is_supported() } {
- 1 => Ok(true),
- 0 => Ok(false),
- _ => Err(NetlinkError(())),
- }
-}
-
-/// A batch of netfilter messages to be performed in one atomic operation. Corresponds to
-/// `nftnl_batch` in libnftnl.
+/// A batch of netfilter messages to be performed in one atomic operation.
pub struct Batch {
- pub(crate) batch: *mut sys::nftnl_batch,
- pub(crate) seq: u32,
- pub(crate) is_empty: bool,
+ buf: Box<Vec<u8>>,
+ // the 'static lifetime here is a cheat, as the writer can only be used as long
+ // as `self.buf` exists. This is why this member must never be exposed directly to
+ // the rest of the crate (let alone publicly).
+ writer: NfNetlinkWriter<'static>,
+ seq: u32,
}
impl Batch {
@@ -33,48 +31,40 @@ impl Batch {
///
/// [default page size]: fn.default_batch_page_size.html
pub fn new() -> Self {
- Self::with_page_size(default_batch_page_size())
- }
-
- pub unsafe fn from_raw(batch: *mut sys::nftnl_batch, seq: u32) -> Self {
- Batch {
- batch,
+ // TODO: use a pinned Box ?
+ let mut buf = Box::new(Vec::with_capacity(default_batch_page_size() as usize));
+ let mut writer = NfNetlinkWriter::new(unsafe {
+ std::mem::transmute(Box::as_mut(&mut buf) as *mut Vec<u8>)
+ });
+ let seq = 0;
+ writer.write_header(
+ libc::NFNL_MSG_BATCH_BEGIN as u16,
+ ProtocolFamily::Unspec,
+ 0,
seq,
- // we assume this batch is not empty by default
- is_empty: false,
+ Some(libc::NFNL_SUBSYS_NFTABLES as u16),
+ );
+ writer.finalize_writing_object();
+ Batch {
+ buf,
+ writer,
+ seq: seq + 1,
}
}
- /// Creates a new nftnl batch with the given batch size.
- pub fn with_page_size(batch_page_size: u32) -> Self {
- let batch = try_alloc!(unsafe {
- sys::nftnl_batch_alloc(batch_page_size, crate::nft_nlmsg_maxsize())
- });
- let mut this = Batch {
- batch,
- seq: 0,
- is_empty: true,
- };
- this.write_begin_msg();
- this
- }
-
/// Adds the given message to this batch.
- pub fn add<T: NlMsg>(&mut self, msg: &T, msg_type: MsgType) {
+ pub fn add<T: NfNetlinkObject>(&mut self, msg: &T, msg_type: MsgType) {
trace!("Writing NlMsg with seq {} to batch", self.seq);
- unsafe { msg.write(self.current(), self.seq, msg_type) };
- self.is_empty = false;
- self.next()
+ msg.add_or_remove(&mut self.writer, msg_type, self.seq);
+ self.seq += 1;
}
- /// Adds all the messages in the given iterator to this batch. If any message fails to be
- /// added the error for that failure is returned and all messages up until that message stay
- /// added to the batch.
- pub fn add_iter<T, I>(&mut self, msg_iter: I, msg_type: MsgType)
- where
- T: NlMsg,
- I: Iterator<Item = T>,
- {
+ /// Adds all the messages in the given iterator to this batch.
+ pub fn add_iter<T: NfNetlinkObject, I: Iterator<Item = T>>(
+ &mut self,
+ msg_iter: I,
+ msg_type: MsgType,
+ ) {
for msg in msg_iter {
self.add(&msg, msg_type);
}
@@ -86,109 +76,46 @@ impl Batch {
/// Return None if there is no object in the batch (this could block forever).
///
/// [`FinalizedBatch`]: struct.FinalizedBatch.html
- pub fn finalize(mut self) -> Option<FinalizedBatch> {
- self.write_end_msg();
- if self.is_empty {
- return None;
- }
- Some(FinalizedBatch { batch: self })
- }
-
- fn current(&self) -> *mut c_void {
- unsafe { sys::nftnl_batch_buffer(self.batch) }
- }
-
- fn next(&mut self) {
- if unsafe { sys::nftnl_batch_update(self.batch) } < 0 {
- // See try_alloc definition.
- std::process::abort();
- }
- self.seq += 1;
- }
-
- fn write_begin_msg(&mut self) {
- unsafe { sys::nftnl_batch_begin(self.current() as *mut c_char, self.seq) };
- self.next();
- }
-
- fn write_end_msg(&mut self) {
- unsafe { sys::nftnl_batch_end(self.current() as *mut c_char, self.seq) };
- self.next();
- }
-
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns the raw handle.
- pub fn as_ptr(&self) -> *const sys::nftnl_batch {
- self.batch as *const sys::nftnl_batch
- }
-
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns a mutable version of the raw handle.
- pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_batch {
- self.batch
- }
-}
-
-impl Drop for Batch {
- fn drop(&mut self) {
- unsafe { sys::nftnl_batch_free(self.batch) };
- }
-}
-
-/// A wrapper over [`Batch`], guaranteed to start with a proper batch begin and end with a proper
-/// batch end message. Created from [`Batch::finalize`].
-///
-/// Can be turned into an iterator of the byte buffers to send to netlink to execute this batch.
-///
-/// [`Batch`]: struct.Batch.html
-/// [`Batch::finalize`]: struct.Batch.html#method.finalize
-pub struct FinalizedBatch {
- batch: Batch,
-}
-
-impl FinalizedBatch {
- /// Returns the iterator over byte buffers to send to netlink.
- pub fn iter(&mut self) -> Iter<'_> {
- let num_pages = unsafe { sys::nftnl_batch_iovec_len(self.batch.batch) as usize };
- let mut iovecs = vec![
- libc::iovec {
- iov_base: ptr::null_mut(),
- iov_len: 0,
- };
- num_pages
- ];
- let iovecs_ptr = iovecs.as_mut_ptr();
- unsafe {
- sys::nftnl_batch_iovec(self.batch.batch, iovecs_ptr, num_pages as u32);
- }
- Iter {
- iovecs: iovecs.into_iter(),
- _marker: ::std::marker::PhantomData,
+ pub fn finalize(mut self) -> Vec<u8> {
+ self.writer.write_header(
+ libc::NFNL_MSG_BATCH_END as u16,
+ ProtocolFamily::Unspec,
+ 0,
+ self.seq,
+ Some(NFNL_SUBSYS_NFTABLES as u16),
+ );
+ self.writer.finalize_writing_object();
+ *self.buf
+ }
+
+ pub fn send(self) -> Result<(), QueryError> {
+ use crate::query::{recv_and_process, socket_close_wrapper};
+
+ let sock = socket::socket(
+ AddressFamily::Netlink,
+ SockType::Raw,
+ SockFlag::empty(),
+ SockProtocol::NetlinkNetFilter,
+ )
+ .map_err(QueryError::NetlinkOpenError)?;
+
+ let max_seq = self.seq - 1;
+
+ let addr = SockAddr::Netlink(NetlinkAddr::new(0, 0));
+ // while this bind() is not strictly necessary, strace have trouble decoding the messages
+ // if we don't
+ socket::bind(sock, &addr).expect("bind");
+
+ let to_send = self.finalize();
+ let mut sent = 0;
+ while sent != to_send.len() {
+ sent += socket::send(sock, &to_send[sent..], MsgFlags::empty())
+ .map_err(QueryError::NetlinkSendError)?;
}
- }
-}
-
-impl<'a> IntoIterator for &'a mut FinalizedBatch {
- type Item = &'a [u8];
- type IntoIter = Iter<'a>;
-
- fn into_iter(self) -> Iter<'a> {
- self.iter()
- }
-}
-
-pub struct Iter<'a> {
- iovecs: ::std::vec::IntoIter<libc::iovec>,
- _marker: ::std::marker::PhantomData<&'a ()>,
-}
-
-impl<'a> Iterator for Iter<'a> {
- type Item = &'a [u8];
- fn next(&mut self) -> Option<&'a [u8]> {
- self.iovecs.next().map(|iovec| unsafe {
- ::std::slice::from_raw_parts(iovec.iov_base as *const u8, iovec.iov_len)
- })
+ Ok(socket_close_wrapper(sock, move |sock| {
+ recv_and_process(sock, Some(max_seq), None, &mut ())
+ })?)
}
}