aboutsummaryrefslogtreecommitdiff
path: root/src/query.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/query.rs')
-rw-r--r--src/query.rs277
1 files changed, 163 insertions, 114 deletions
diff --git a/src/query.rs b/src/query.rs
index bc1d02e..7cf5050 100644
--- a/src/query.rs
+++ b/src/query.rs
@@ -1,129 +1,178 @@
-use crate::{nft_nlmsg_maxsize, sys, ProtoFamily};
-use sys::libc;
-
-/// Returns a buffer containing a netlink message which requests a list of all the netfilter
-/// matching objects (e.g. tables, chains, rules, ...).
-/// Supply the type of objects to retrieve (e.g. libc::NFT_MSG_GETTABLE), and optionally a callback
-/// to execute on the header, to set parameters for example.
-/// To pass arbitrary data inside that callback, please use a closure.
-pub fn get_list_of_objects<Error>(
- seq: u32,
- target: u16,
- setup_cb: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>,
-) -> Result<Vec<u8>, Error> {
- let mut buffer = vec![0; nft_nlmsg_maxsize() as usize];
- let hdr = unsafe {
- &mut *sys::nftnl_nlmsg_build_hdr(
- buffer.as_mut_ptr() as *mut libc::c_char,
- target,
- ProtoFamily::Unspec as u16,
- (libc::NLM_F_ROOT | libc::NLM_F_MATCH) as u16,
- seq,
- )
- };
- if let Some(cb) = setup_cb {
- cb(hdr)?;
- }
- Ok(buffer)
-}
-
-#[cfg(feature = "query")]
-mod inner {
- use crate::FinalizedBatch;
-
- use super::*;
-
- #[derive(thiserror::Error, Debug)]
- pub enum Error {
- #[error("Unable to open netlink socket to netfilter")]
- NetlinkOpenError(#[source] std::io::Error),
-
- #[error("Unable to send netlink command to netfilter")]
- NetlinkSendError(#[source] std::io::Error),
-
- #[error("Error while reading from netlink socket")]
- NetlinkRecvError(#[source] std::io::Error),
+use std::os::unix::prelude::RawFd;
+
+use nix::sys::socket::{self, AddressFamily, MsgFlags, SockFlag, SockProtocol, SockType};
+
+use crate::{
+ error::QueryError,
+ nlmsg::{
+ nft_nlmsg_maxsize, pad_netlink_object_with_variable_size, NfNetlinkAttribute,
+ NfNetlinkObject, NfNetlinkWriter,
+ },
+ parser::{parse_nlmsg, NlMsg},
+ sys::{NLM_F_DUMP, NLM_F_MULTI},
+ ProtocolFamily,
+};
+
+pub(crate) fn recv_and_process<'a, T>(
+ sock: RawFd,
+ max_seq: Option<u32>,
+ cb: Option<&dyn Fn(&[u8], &mut T) -> Result<(), QueryError>>,
+ working_data: &'a mut T,
+) -> Result<(), QueryError> {
+ let mut msg_buffer = vec![0; 2 * nft_nlmsg_maxsize() as usize];
+ let mut buf_start = 0;
+ let mut end_pos = 0;
+
+ loop {
+ let nb_recv = socket::recv(sock, &mut msg_buffer[end_pos..], MsgFlags::empty())
+ .map_err(QueryError::NetlinkRecvError)?;
+ if nb_recv <= 0 {
+ return Ok(());
+ }
+ end_pos += nb_recv;
+ loop {
+ let buf = &msg_buffer.as_slice()[buf_start..end_pos];
+ // exit the loop and try to receive further messages when we consumed all the buffer
+ if buf.len() == 0 {
+ break;
+ }
- #[error("Error while processing an incoming netlink message")]
- ProcessNetlinkError(#[source] std::io::Error),
+ debug!("Calling parse_nlmsg");
+ let (nlmsghdr, msg) = parse_nlmsg(&buf)?;
+ debug!("Got a valid netlink message: {:?} {:?}", nlmsghdr, msg);
+
+ match msg {
+ NlMsg::Done => {
+ return Ok(());
+ }
+ NlMsg::Error(e) => {
+ if e.error != 0 {
+ return Err(QueryError::NetlinkError(e));
+ }
+ }
+ NlMsg::Noop => {}
+ NlMsg::NfGenMsg(_genmsg, _data) => {
+ if let Some(cb) = cb {
+ cb(&buf[0..nlmsghdr.nlmsg_len as usize], working_data)?;
+ }
+ }
+ }
- #[error("Custom error when customizing the query")]
- InitError(#[from] Box<dyn std::error::Error + 'static>),
+ // we cannot know when a sequence of messages will end if the messages do not end
+ // with an NlMsg::Done marker while if a maximum sequence number wasn't specified
+ if max_seq.is_none() && nlmsghdr.nlmsg_flags & NLM_F_MULTI as u16 == 0 {
+ return Err(QueryError::UndecidableMessageTermination);
+ }
- #[error("Couldn't allocate a netlink object, out of memory ?")]
- NetlinkAllocationFailed,
- }
+ // retrieve the next message
+ if let Some(max_seq) = max_seq {
+ if nlmsghdr.nlmsg_seq >= max_seq {
+ return Ok(());
+ }
+ }
- /// Lists objects of a certain type (e.g. libc::NFT_MSG_GETTABLE) with the help of a helper
- /// function called by mnl::cb_run2.
- /// The callback expects a tuple of additional data (supplied as an argument to this function)
- /// and of the output vector, to which it should append the parsed object it received.
- pub fn list_objects_with_data<'a, A, T>(
- data_type: u16,
- cb: fn(&libc::nlmsghdr, &mut (&'a A, &mut Vec<T>)) -> libc::c_int,
- additional_data: &'a A,
- req_hdr_customize: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>,
- ) -> Result<Vec<T>, Error>
- where
- T: 'a,
- {
- debug!("listing objects of kind {}", data_type);
- let socket = mnl::Socket::new(mnl::Bus::Netfilter).map_err(Error::NetlinkOpenError)?;
-
- let seq = 0;
- let portid = 0;
-
- let chains_buf = get_list_of_objects(seq, data_type, req_hdr_customize)?;
- socket.send(&chains_buf).map_err(Error::NetlinkSendError)?;
-
- let mut res = Vec::new();
-
- let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize];
- while socket
- .recv(&mut msg_buffer)
- .map_err(Error::NetlinkRecvError)?
- > 0
- {
- if let mnl::CbResult::Stop = mnl::cb_run2(
- &msg_buffer,
- seq,
- portid,
- cb,
- &mut (additional_data, &mut res),
- )
- .map_err(Error::ProcessNetlinkError)?
- {
- break;
+ // netlink messages are 4bytes aligned
+ let aligned_length = pad_netlink_object_with_variable_size(nlmsghdr.nlmsg_len as usize);
+ buf_start += aligned_length;
+ }
+ // Ensure that we always have nft_nlmsg_maxsize() free space available in the buffer.
+ // We achieve this by relocating the buffer content at the beginning of the buffer
+ if end_pos >= nft_nlmsg_maxsize() as usize {
+ if buf_start < end_pos {
+ unsafe {
+ std::ptr::copy(
+ msg_buffer[buf_start..end_pos].as_ptr(),
+ msg_buffer.as_mut_ptr(),
+ end_pos - buf_start,
+ );
+ }
}
+ end_pos = end_pos - buf_start;
+ buf_start = 0;
}
-
- Ok(res)
}
+}
- pub fn send_batch(batch: &mut FinalizedBatch) -> Result<(), Error> {
- let socket = mnl::Socket::new(mnl::Bus::Netfilter).map_err(Error::NetlinkOpenError)?;
+pub(crate) fn socket_close_wrapper<E>(
+ sock: RawFd,
+ cb: impl FnOnce(RawFd) -> Result<(), E>,
+) -> Result<(), QueryError>
+where
+ QueryError: From<E>,
+{
+ let ret = cb(sock);
- let seq = 0;
- let portid = socket.portid();
+ // we don't need to shutdown the socket (in fact, Linux doesn't support that operation;
+ // and return EOPNOTSUPP if we try)
+ nix::unistd::close(sock).map_err(QueryError::CloseFailed)?;
- socket.send_all(batch).map_err(Error::NetlinkSendError)?;
- debug!("sent");
+ Ok(ret?)
+}
- let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize];
- while socket
- .recv(&mut msg_buffer)
- .map_err(Error::NetlinkRecvError)?
- > 0
- {
- if let mnl::CbResult::Stop =
- mnl::cb_run(&msg_buffer, seq, portid).map_err(Error::ProcessNetlinkError)?
- {
- break;
- }
+/// Returns a buffer containing a netlink message which requests a list of all the netfilter
+/// matching objects (e.g. tables, chains, rules, ...).
+/// Supply the type of objects to retrieve (e.g. libc::NFT_MSG_GETTABLE), and a search filter.
+pub fn get_list_of_objects<T: NfNetlinkAttribute>(
+ msg_type: u16,
+ seq: u32,
+ filter: Option<&T>,
+) -> Result<Vec<u8>, QueryError> {
+ let mut buffer = Vec::new();
+ let mut writer = NfNetlinkWriter::new(&mut buffer);
+ writer.write_header(
+ msg_type,
+ ProtocolFamily::Unspec,
+ NLM_F_DUMP as u16,
+ seq,
+ None,
+ );
+ if let Some(filter) = filter {
+ let buf = writer.add_data_zeroed(filter.get_size());
+ unsafe {
+ filter.write_payload(buf.as_mut_ptr());
}
- Ok(())
}
+ writer.finalize_writing_object();
+ Ok(buffer)
}
-#[cfg(feature = "query")]
-pub use inner::*;
+/// Lists objects of a certain type (e.g. libc::NFT_MSG_GETTABLE) with the help of a helper
+/// function called by mnl::cb_run2.
+/// The callback expects a tuple of additional data (supplied as an argument to this function)
+/// and of the output vector, to which it should append the parsed object it received.
+pub fn list_objects_with_data<'a, Object, Accumulator>(
+ data_type: u16,
+ cb: &dyn Fn(Object, &mut Accumulator) -> Result<(), QueryError>,
+ filter: Option<&Object>,
+ working_data: &'a mut Accumulator,
+) -> Result<(), QueryError>
+where
+ Object: NfNetlinkObject + NfNetlinkAttribute,
+{
+ debug!("Listing objects of kind {}", data_type);
+ let sock = socket::socket(
+ AddressFamily::Netlink,
+ SockType::Raw,
+ SockFlag::empty(),
+ SockProtocol::NetlinkNetFilter,
+ )
+ .map_err(QueryError::NetlinkOpenError)?;
+
+ let seq = 0;
+
+ let chains_buf = get_list_of_objects(data_type, seq, filter)?;
+ socket::send(sock, &chains_buf, MsgFlags::empty()).map_err(QueryError::NetlinkSendError)?;
+
+ socket_close_wrapper(sock, move |sock| {
+ // the kernel should return NLM_F_MULTI objects
+ recv_and_process(
+ sock,
+ None,
+ Some(&|buf: &[u8], working_data: &mut Accumulator| {
+ debug!("Calling Object::deserialize()");
+ cb(Object::deserialize(buf)?.0, working_data)
+ }),
+ working_data,
+ )
+ })
+}