diff options
-rw-r--r-- | examples/add-rules.rs | 8 | ||||
-rw-r--r-- | src/batch.rs | 41 | ||||
-rw-r--r-- | src/query.rs | 266 |
3 files changed, 161 insertions, 154 deletions
diff --git a/examples/add-rules.rs b/examples/add-rules.rs index 229db97..11e7b6f 100644 --- a/examples/add-rules.rs +++ b/examples/add-rules.rs @@ -37,7 +37,7 @@ //! ``` use ipnetwork::{IpNetwork, Ipv4Network}; -use rustables::{query::send_batch, Batch, ProtoFamily, Table}; +use rustables::{Batch, ProtoFamily, Table}; //use rustables::{nft_expr, query::send_batch, sys::libc, Batch, Chain, ProtoFamily, Rule, Table}; use std::{ffi::CString, io, net::Ipv4Addr, rc::Rc}; @@ -57,6 +57,10 @@ fn main() -> Result<(), Error> { // this table under its `ProtoFamily::Inet` ruleset. batch.add(&table, rustables::MsgType::Add); + let table = Table::new("lool", ProtoFamily::Inet); + + batch.add(&table, rustables::MsgType::Add); + // // Create input and output chains under the table we created above. // // Hook the chains to the input and output event hooks, with highest priority (priority zero). // // See the `Chain::set_hook` documentation for details. @@ -170,7 +174,7 @@ fn main() -> Result<(), Error> { // Finalize the batch and send it. This means the batch end message is written into the batch, telling // netfilter the we reached the end of the transaction message. It's also converted to a // Vec<u8>, containing the raw netlink data so it can be sent over a netlink socket to netfilter. - Ok(send_batch(batch)?) + Ok(batch.send()?) } // Look up the interface index for a given interface name. diff --git a/src/batch.rs b/src/batch.rs index a9529a3..c7fb8f3 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -5,6 +5,11 @@ use thiserror::Error; use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter}; use crate::{MsgType, ProtoFamily}; +use crate::query::Error; +use nix::sys::socket::{ + self, AddressFamily, MsgFlags, NetlinkAddr, SockAddr, SockFlag, SockProtocol, SockType, +}; + /// Error while communicating with netlink. #[derive(Error, Debug)] #[error("Error while communicating with netlink")] @@ -30,18 +35,19 @@ impl Batch { let mut writer = NfNetlinkWriter::new(unsafe { std::mem::transmute(Box::as_mut(&mut buf) as *mut Vec<u8>) }); + let seq = 0; writer.write_header( libc::NFNL_MSG_BATCH_BEGIN as u16, ProtoFamily::Unspec, 0, - 0, + seq, Some(libc::NFNL_SUBSYS_NFTABLES as u16), ); writer.finalize_writing_object(); Batch { buf, writer, - seq: 1, + seq: seq + 1, } } @@ -80,6 +86,37 @@ impl Batch { self.writer.finalize_writing_object(); *self.buf } + + #[cfg(feature = "query")] + pub fn send(mut self) -> Result<(), Error> { + use crate::query::{recv_and_process_until_seq, socket_close_wrapper}; + + let sock = socket::socket( + AddressFamily::Netlink, + SockType::Raw, + SockFlag::empty(), + SockProtocol::NetlinkNetFilter, + ) + .map_err(Error::NetlinkOpenError)?; + + let max_seq = self.seq - 1; + + let addr = SockAddr::Netlink(NetlinkAddr::new(0, 0)); + // while this bind() is not strictly necessary, strace have trouble decoding the messages + // if we don't + socket::bind(sock, &addr).expect("bind"); + + let to_send = self.finalize(); + let mut sent = 0; + while sent != to_send.len() { + sent += socket::send(sock, &to_send[sent..], MsgFlags::empty()) + .map_err(Error::NetlinkSendError)?; + } + + Ok(socket_close_wrapper(sock, move |sock| { + recv_and_process_until_seq(sock, max_seq, None, &mut ()) + })?) + } } /// Selected batch page is 256 Kbytes long to load ruleset of half a million rules without hitting diff --git a/src/query.rs b/src/query.rs index 3fea40d..d2409f2 100644 --- a/src/query.rs +++ b/src/query.rs @@ -37,182 +37,148 @@ pub fn get_list_of_objects<Error>( Ok(buffer) } -#[cfg(feature = "query")] -mod inner { - use std::os::unix::prelude::RawFd; +use std::os::unix::prelude::RawFd; - use nix::{ - errno::Errno, - sys::socket::{ - self, AddressFamily, MsgFlags, NetlinkAddr, SockAddr, SockFlag, SockProtocol, SockType, - }, - }; - - use crate::{ - batch::Batch, - parser::{parse_nlmsg, DecodeError, NlMsg}, - }; - - use super::*; +use nix::{ + errno::Errno, + sys::socket::{ + self, AddressFamily, MsgFlags, NetlinkAddr, SockAddr, SockFlag, SockProtocol, SockType, + }, +}; - #[derive(thiserror::Error, Debug)] - pub enum Error { - #[error("Unable to open netlink socket to netfilter")] - NetlinkOpenError(#[source] nix::Error), +use crate::{ + batch::Batch, + parser::{parse_nlmsg, DecodeError, NlMsg}, +}; - #[error("Unable to send netlink command to netfilter")] - NetlinkSendError(#[source] nix::Error), +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("Unable to open netlink socket to netfilter")] + NetlinkOpenError(#[source] nix::Error), - #[error("Error while reading from netlink socket")] - NetlinkRecvError(#[source] nix::Error), + #[error("Unable to send netlink command to netfilter")] + NetlinkSendError(#[source] nix::Error), - #[error("Error while processing an incoming netlink message")] - ProcessNetlinkError(#[from] DecodeError), + #[error("Error while reading from netlink socket")] + NetlinkRecvError(#[source] nix::Error), - #[error("Error received from the kernel")] - NetlinkError(nlmsgerr), + #[error("Error while processing an incoming netlink message")] + ProcessNetlinkError(#[from] DecodeError), - #[error("Custom error when customizing the query")] - InitError(#[from] Box<dyn std::error::Error + Send + 'static>), + #[error("Error received from the kernel")] + NetlinkError(nlmsgerr), - #[error("Couldn't allocate a netlink object, out of memory ?")] - NetlinkAllocationFailed, + #[error("Custom error when customizing the query")] + InitError(#[from] Box<dyn std::error::Error + Send + 'static>), - #[error("This socket is not a netlink socket")] - NotNetlinkSocket, + #[error("Couldn't allocate a netlink object, out of memory ?")] + NetlinkAllocationFailed, - #[error("Couldn't retrieve information on a socket")] - RetrievingSocketInfoFailed, + #[error("This socket is not a netlink socket")] + NotNetlinkSocket, - #[error("Only a part of the message was sent")] - TruncatedSend, + #[error("Couldn't retrieve information on a socket")] + RetrievingSocketInfoFailed, - #[error("Couldn't close the socket")] - CloseFailed(#[source] Errno), - } + #[error("Only a part of the message was sent")] + TruncatedSend, - fn recv_and_process<'a, T>( - sock: RawFd, - cb: Option<&dyn Fn(&nlmsghdr, &Nfgenmsg, &[u8], &mut T) -> Result<(), Error>>, - working_data: &'a mut T, - ) -> Result<(), Error> { - let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize]; + #[error("Couldn't close the socket")] + CloseFailed(#[source] Errno), +} +pub(crate) fn recv_and_process_until_seq<'a, T>( + sock: RawFd, + max_seq: u32, + cb: Option<&dyn Fn(&nlmsghdr, &Nfgenmsg, &[u8], &mut T) -> Result<(), Error>>, + working_data: &'a mut T, +) -> Result<(), Error> { + let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize]; + + loop { + let nb_recv = socket::recv(sock, &mut msg_buffer, MsgFlags::empty()) + .map_err(Error::NetlinkRecvError)?; + if nb_recv <= 0 { + return Ok(()); + } + let mut buf = &msg_buffer.as_slice()[0..nb_recv]; loop { - let nb_recv = socket::recv(sock, &mut msg_buffer, MsgFlags::empty()) - .map_err(Error::NetlinkRecvError)?; - if nb_recv <= 0 { - return Ok(()); - } - let mut buf = &msg_buffer.as_slice()[0..nb_recv]; - loop { - let (nlmsghdr, msg) = parse_nlmsg(&buf)?; - match msg { - NlMsg::Done => { - return Ok(()); - } - NlMsg::Error(e) => { - if e.error != 0 { - return Err(Error::NetlinkError(e)); - } + let (nlmsghdr, msg) = parse_nlmsg(&buf)?; + match msg { + NlMsg::Done => { + return Ok(()); + } + NlMsg::Error(e) => { + if e.error != 0 { + return Err(Error::NetlinkError(e)); } - NlMsg::Noop => {} - NlMsg::NfGenMsg(genmsg, data) => { - if let Some(cb) = cb { - cb(&nlmsghdr, &genmsg, &data, working_data)?; - } + } + NlMsg::Noop => {} + NlMsg::NfGenMsg(genmsg, data) => { + if let Some(cb) = cb { + cb(&nlmsghdr, &genmsg, &data, working_data)?; } } + } - // netlink messages are 4bytes aligned - let aligned_length = ((nlmsghdr.nlmsg_len + 3) & !3u32) as usize; + // netlink messages are 4bytes aligned + let aligned_length = ((nlmsghdr.nlmsg_len + 3) & !3u32) as usize; - // retrieve the next message - buf = &buf[aligned_length..]; + // retrieve the next message + buf = &buf[aligned_length..]; - // exit the loop when we consumed all the buffer - if buf.len() == 0 { - break; - } + if nlmsghdr.nlmsg_seq >= max_seq { + return Ok(()); } - } - } - fn socket_close_wrapper( - sock: RawFd, - cb: impl FnOnce(RawFd) -> Result<(), Error>, - ) -> Result<(), Error> { - let ret = cb(sock); - - // we don't need to shutdown the socket (in fact, Linux doesn't support that operation; - // and return EOPNOTSUPP if we try) - nix::unistd::close(sock).map_err(Error::CloseFailed)?; - - ret + // exit the loop and try to receive further messages when we consumed all the buffer + if buf.len() == 0 { + break; + } + } } +} - /// Lists objects of a certain type (e.g. libc::NFT_MSG_GETTABLE) with the help of a helper - /// function called by mnl::cb_run2. - /// The callback expects a tuple of additional data (supplied as an argument to this function) - /// and of the output vector, to which it should append the parsed object it received. - pub fn list_objects_with_data<'a, T>( - data_type: u16, - cb: &dyn Fn(&libc::nlmsghdr, &Nfgenmsg, &[u8], &mut T) -> Result<(), Error>, - working_data: &'a mut T, - req_hdr_customize: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>, - ) -> Result<(), Error> { - debug!("listing objects of kind {}", data_type); - let sock = socket::socket( - AddressFamily::Netlink, - SockType::Raw, - SockFlag::empty(), - SockProtocol::NetlinkNetFilter, - ) - .map_err(Error::NetlinkOpenError)?; - - let seq = 0; - - let chains_buf = get_list_of_objects(data_type, seq, req_hdr_customize)?; - socket::send(sock, &chains_buf, MsgFlags::empty()).map_err(Error::NetlinkSendError)?; - - Ok(socket_close_wrapper(sock, move |sock| { - recv_and_process(sock, Some(cb), working_data) - })?) - } +pub(crate) fn socket_close_wrapper( + sock: RawFd, + cb: impl FnOnce(RawFd) -> Result<(), Error>, +) -> Result<(), Error> { + let ret = cb(sock); - pub fn send_batch(batch: Batch) -> Result<(), Error> { - let sock = socket::socket( - AddressFamily::Netlink, - SockType::Raw, - SockFlag::empty(), - SockProtocol::NetlinkNetFilter, - ) - .map_err(Error::NetlinkOpenError)?; - - let seq = 0; - let portid = 0; - - let addr = SockAddr::Netlink(NetlinkAddr::new(portid, 0)); - // while this bind() is not strictly necessary, strace have trouble decoding the messages - // if we don't - socket::bind(sock, &addr).expect("bind"); - //match socket::getsockname(sock).map_err(|_| Error::RetrievingSocketInfoFailed)? { - // SockAddr::Netlink(addr) => addr.0.nl_pid, - // _ => return Err(Error::NotNetlinkSocket), - //}; - - let to_send = batch.finalize(); - let mut sent = 0; - while sent != to_send.len() { - sent += socket::send(sock, &to_send[sent..], MsgFlags::empty()) - .map_err(Error::NetlinkSendError)?; - } + // we don't need to shutdown the socket (in fact, Linux doesn't support that operation; + // and return EOPNOTSUPP if we try) + nix::unistd::close(sock).map_err(Error::CloseFailed)?; - Ok(socket_close_wrapper(sock, move |sock| { - recv_and_process(sock, None, &mut ()) - })?) - } + ret } -#[cfg(feature = "query")] -pub use inner::*; +/* +/// Lists objects of a certain type (e.g. libc::NFT_MSG_GETTABLE) with the help of a helper +/// function called by mnl::cb_run2. +/// The callback expects a tuple of additional data (supplied as an argument to this function) +/// and of the output vector, to which it should append the parsed object it received. +pub fn list_objects_with_data<'a, T>( + data_type: u16, + cb: &dyn Fn(&libc::nlmsghdr, &Nfgenmsg, &[u8], &mut T) -> Result<(), Error>, + working_data: &'a mut T, + req_hdr_customize: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>, +) -> Result<(), Error> { + debug!("listing objects of kind {}", data_type); + let sock = socket::socket( + AddressFamily::Netlink, + SockType::Raw, + SockFlag::empty(), + SockProtocol::NetlinkNetFilter, + ) + .map_err(Error::NetlinkOpenError)?; + + let seq = 0; + + let chains_buf = get_list_of_objects(data_type, seq, req_hdr_customize)?; + socket::send(sock, &chains_buf, MsgFlags::empty()).map_err(Error::NetlinkSendError)?; + + Ok(socket_close_wrapper(sock, move |sock| { + recv_and_process(sock, Some(cb), working_data) + })?) +} +*/ |