aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml2
-rw-r--r--examples/add-rules.rs66
-rw-r--r--examples/filter-ethernet.rs55
-rw-r--r--src/chain.rs43
-rw-r--r--src/lib.rs3
-rw-r--r--src/query.rs332
-rw-r--r--src/rule.rs32
-rw-r--r--src/table.rs43
-rw-r--r--tests/chain.rs2
-rw-r--r--tests/expr.rs1
-rw-r--r--tests/lib.rs49
-rw-r--r--tests/rule.rs2
-rw-r--r--tests/set.rs2
-rw-r--r--tests/table.rs2
14 files changed, 408 insertions, 226 deletions
diff --git a/Cargo.toml b/Cargo.toml
index e2a2039..70bb4d2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,7 +20,7 @@ bitflags = "1.0"
thiserror = "1.0"
log = "0.4"
libc = "0.2.43"
-mnl = "0.2"
+nix = "0.23"
ipnetwork = "0.16"
serde = { version = "1.0", features = ["derive"] }
diff --git a/examples/add-rules.rs b/examples/add-rules.rs
index 3aae7ee..6354c8d 100644
--- a/examples/add-rules.rs
+++ b/examples/add-rules.rs
@@ -37,13 +37,8 @@
//! ```
use ipnetwork::{IpNetwork, Ipv4Network};
-use rustables::{nft_expr, sys::libc, Batch, Chain, FinalizedBatch, ProtoFamily, Rule, Table};
-use std::{
- ffi::{self, CString},
- io,
- net::Ipv4Addr,
- rc::Rc
-};
+use rustables::{nft_expr, query::send_batch, sys::libc, Batch, Chain, ProtoFamily, Rule, Table};
+use std::{ffi::CString, io, net::Ipv4Addr, rc::Rc};
const TABLE_NAME: &str = "example-table";
const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets";
@@ -56,7 +51,10 @@ fn main() -> Result<(), Error> {
let mut batch = Batch::new();
// Create a netfilter table operating on both IPv4 and IPv6 (ProtoFamily::Inet)
- let table = Rc::new(Table::new(&CString::new(TABLE_NAME).unwrap(), ProtoFamily::Inet));
+ let table = Rc::new(Table::new(
+ &CString::new(TABLE_NAME).unwrap(),
+ ProtoFamily::Inet,
+ ));
// Add the table to the batch with the `MsgType::Add` type, thus instructing netfilter to add
// this table under its `ProtoFamily::Inet` ruleset.
batch.add(&Rc::clone(&table), rustables::MsgType::Add);
@@ -178,10 +176,10 @@ fn main() -> Result<(), Error> {
match batch.finalize() {
Some(mut finalized_batch) => {
// Send the entire batch and process any returned messages.
- send_and_process(&mut finalized_batch)?;
+ send_batch(&mut finalized_batch)?;
Ok(())
- },
- None => todo!()
+ }
+ None => todo!(),
}
}
@@ -196,53 +194,11 @@ fn iface_index(name: &str) -> Result<libc::c_uint, Error> {
}
}
-fn send_and_process(batch: &mut FinalizedBatch) -> Result<(), Error> {
- // Create a netlink socket to netfilter.
- let socket = mnl::Socket::new(mnl::Bus::Netfilter)?;
- // Send all the bytes in the batch.
- socket.send_all(&mut *batch)?;
-
- // Try to parse the messages coming back from netfilter. This part is still very unclear.
- let portid = socket.portid();
- let mut buffer = vec![0; rustables::nft_nlmsg_maxsize() as usize];
- let very_unclear_what_this_is_for = 2;
- while let Some(message) = socket_recv(&socket, &mut buffer[..])? {
- match mnl::cb_run(message, very_unclear_what_this_is_for, portid)? {
- mnl::CbResult::Stop => {
- break;
- }
- mnl::CbResult::Ok => (),
- }
- }
- Ok(())
-}
-
-fn socket_recv<'a>(socket: &mnl::Socket, buf: &'a mut [u8]) -> Result<Option<&'a [u8]>, Error> {
- let ret = socket.recv(buf)?;
- if ret > 0 {
- Ok(Some(&buf[..ret]))
- } else {
- Ok(None)
- }
-}
-
#[derive(Debug)]
struct Error(String);
-impl From<io::Error> for Error {
- fn from(error: io::Error) -> Self {
- Error(error.to_string())
- }
-}
-
-impl From<ffi::NulError> for Error {
- fn from(error: ffi::NulError) -> Self {
- Error(error.to_string())
- }
-}
-
-impl From<ipnetwork::IpNetworkError> for Error {
- fn from(error: ipnetwork::IpNetworkError) -> Self {
+impl<T: std::error::Error> From<T> for Error {
+ fn from(error: T) -> Self {
Error(error.to_string())
}
}
diff --git a/examples/filter-ethernet.rs b/examples/filter-ethernet.rs
index b16c49e..41454c9 100644
--- a/examples/filter-ethernet.rs
+++ b/examples/filter-ethernet.rs
@@ -22,19 +22,22 @@
//! # nft delete table inet example-filter-ethernet
//! ```
-use rustables::{nft_expr, sys::libc, Batch, Chain, FinalizedBatch, ProtoFamily, Rule, Table};
-use std::{ffi::CString, io, rc::Rc};
+use rustables::{nft_expr, query::send_batch, sys::libc, Batch, Chain, ProtoFamily, Rule, Table};
+use std::{ffi::CString, rc::Rc};
const TABLE_NAME: &str = "example-filter-ethernet";
const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets";
const BLOCK_THIS_MAC: &[u8] = &[0, 0, 0, 0, 0, 0];
-fn main() -> Result<(), Error> {
+fn main() {
// For verbose explanations of what all these lines up until the rule creation does, see the
// `add-rules` example.
let mut batch = Batch::new();
- let table = Rc::new(Table::new(&CString::new(TABLE_NAME).unwrap(), ProtoFamily::Inet));
+ let table = Rc::new(Table::new(
+ &CString::new(TABLE_NAME).unwrap(),
+ ProtoFamily::Inet,
+ ));
batch.add(&Rc::clone(&table), rustables::MsgType::Add);
let mut out_chain = Chain::new(&CString::new(OUT_CHAIN_NAME).unwrap(), Rc::clone(&table));
@@ -86,48 +89,8 @@ fn main() -> Result<(), Error> {
match batch.finalize() {
Some(mut finalized_batch) => {
- send_and_process(&mut finalized_batch)?;
- Ok(())
- },
- None => todo!()
- }
-}
-
-fn send_and_process(batch: &mut FinalizedBatch) -> Result<(), Error> {
- // Create a netlink socket to netfilter.
- let socket = mnl::Socket::new(mnl::Bus::Netfilter)?;
- // Send all the bytes in the batch.
- socket.send_all(&mut *batch)?;
-
- // Try to parse the messages coming back from netfilter. This part is still very unclear.
- let portid = socket.portid();
- let mut buffer = vec![0; rustables::nft_nlmsg_maxsize() as usize];
- let very_unclear_what_this_is_for = 2;
- while let Some(message) = socket_recv(&socket, &mut buffer[..])? {
- match mnl::cb_run(message, very_unclear_what_this_is_for, portid)? {
- mnl::CbResult::Stop => {
- break;
- }
- mnl::CbResult::Ok => (),
+ send_batch(&mut finalized_batch).expect("Couldn't process the batch");
}
- }
- Ok(())
-}
-
-fn socket_recv<'a>(socket: &mnl::Socket, buf: &'a mut [u8]) -> Result<Option<&'a [u8]>, Error> {
- let ret = socket.recv(buf)?;
- if ret > 0 {
- Ok(Some(&buf[..ret]))
- } else {
- Ok(None)
- }
-}
-
-#[derive(Debug)]
-struct Error(String);
-
-impl From<io::Error> for Error {
- fn from(error: io::Error) -> Self {
- Error(error.to_string())
+ None => todo!(),
}
}
diff --git a/src/chain.rs b/src/chain.rs
index a942a37..18e3c64 100644
--- a/src/chain.rs
+++ b/src/chain.rs
@@ -1,5 +1,7 @@
-use crate::{MsgType, Table};
+#[cfg(feature = "query")]
+use crate::query::{Nfgenmsg, ParseError};
use crate::sys::{self as sys, libc};
+use crate::{MsgType, Table};
#[cfg(feature = "query")]
use std::convert::TryFrom;
use std::{
@@ -243,18 +245,27 @@ impl Drop for Chain {
#[cfg(feature = "query")]
pub fn get_chains_cb<'a>(
header: &libc::nlmsghdr,
+ _genmsg: &Nfgenmsg,
+ _data: &[u8],
(table, chains): &mut (&Rc<Table>, &mut Vec<Chain>),
-) -> libc::c_int {
+) -> Result<(), crate::query::Error> {
unsafe {
let chain = sys::nftnl_chain_alloc();
if chain == std::ptr::null_mut() {
- return mnl::mnl_sys::MNL_CB_ERROR;
+ return Err(ParseError::Custom(Box::new(std::io::Error::new(
+ std::io::ErrorKind::Other,
+ "Chain allocation failed",
+ )))
+ .into());
}
let err = sys::nftnl_chain_nlmsg_parse(header, chain);
if err < 0 {
- error!("Failed to parse nelink chain message - {}", err);
sys::nftnl_chain_free(chain);
- return err;
+ return Err(ParseError::Custom(Box::new(std::io::Error::new(
+ std::io::ErrorKind::Other,
+ "The netlink chain couldn't be parsed !?",
+ )))
+ .into());
}
let table_name = CStr::from_ptr(sys::nftnl_chain_get_str(
@@ -267,26 +278,38 @@ pub fn get_chains_cb<'a>(
Err(crate::InvalidProtocolFamily) => {
error!("The netlink table didn't have a valid protocol family !?");
sys::nftnl_chain_free(chain);
- return mnl::mnl_sys::MNL_CB_ERROR;
+ return Err(ParseError::Custom(Box::new(std::io::Error::new(
+ std::io::ErrorKind::Other,
+ "The netlink table didn't have a valid protocol family !?",
+ )))
+ .into());
}
};
if table_name != table.get_name() {
sys::nftnl_chain_free(chain);
- return mnl::mnl_sys::MNL_CB_OK;
+ return Ok(());
}
if family != crate::ProtoFamily::Unspec && family != table.get_family() {
sys::nftnl_chain_free(chain);
- return mnl::mnl_sys::MNL_CB_OK;
+ return Ok(());
}
chains.push(Chain::from_raw(chain, table.clone()));
}
- mnl::mnl_sys::MNL_CB_OK
+
+ Ok(())
}
#[cfg(feature = "query")]
pub fn list_chains_for_table(table: Rc<Table>) -> Result<Vec<Chain>, crate::query::Error> {
- crate::query::list_objects_with_data(libc::NFT_MSG_GETCHAIN as u16, get_chains_cb, &table, None)
+ let mut result = Vec::new();
+ crate::query::list_objects_with_data(
+ libc::NFT_MSG_GETCHAIN as u16,
+ &get_chains_cb,
+ &mut (&table, &mut result),
+ None,
+ )?;
+ Ok(result)
}
diff --git a/src/lib.rs b/src/lib.rs
index fbb96f3..5d40c5a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -119,7 +119,7 @@ pub use rule::Rule;
pub use rule::{get_rules_cb, list_rules_for_chain};
mod rule_methods;
-pub use rule_methods::{iface_index, Protocol, RuleMethods, Error as MatchError};
+pub use rule_methods::{iface_index, Error as MatchError, Protocol, RuleMethods};
pub mod set;
pub use set::Set;
@@ -155,6 +155,7 @@ pub enum ProtoFamily {
Ipv6 = libc::NFPROTO_IPV6 as u16,
DecNet = libc::NFPROTO_DECNET as u16,
}
+
#[derive(Error, Debug)]
#[error("Couldn't find a matching protocol")]
pub struct InvalidProtocolFamily;
diff --git a/src/query.rs b/src/query.rs
index d7574e8..1c81cdd 100644
--- a/src/query.rs
+++ b/src/query.rs
@@ -1,6 +1,164 @@
+use std::mem::size_of;
+
use crate::{nft_nlmsg_maxsize, sys, ProtoFamily};
+use libc::{
+ nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES, NLMSG_DONE, NLMSG_ERROR,
+ NLMSG_MIN_TYPE, NLMSG_NOOP, NLM_F_DUMP_INTR,
+};
use sys::libc;
+#[repr(C)]
+#[derive(Debug, Clone, Copy)]
+pub struct Nfgenmsg {
+ pub family: u8, /* AF_xxx */
+ pub version: u8, /* nfnetlink version */
+ pub res_id: u16, /* resource id */
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum ParseError {
+ #[error("The buffer is too small to hold a valid message")]
+ BufTooSmall,
+
+ #[error("The message is too small")]
+ NlMsgTooSmall,
+
+ #[error("Invalid subsystem, expected NFTABLES")]
+ InvalidSubsystem(u8),
+
+ #[error("Invalid version, expected NFNETLINK_V0")]
+ InvalidVersion(u8),
+
+ #[error("Invalid port ID")]
+ InvalidPortId(u32),
+
+ #[error("Invalid sequence number")]
+ InvalidSeq(u32),
+
+ #[error("The generation number was bumped in the kernel while the operation was running, interrupting it")]
+ ConcurrentGenerationUpdate,
+
+ #[error("Unsupported message type")]
+ UnsupportedType(u16),
+
+ #[error("A custom error occured")]
+ Custom(Box<dyn std::error::Error + 'static>),
+}
+
+pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 {
+ ((x & 0xff00) >> 8) as u8
+}
+
+pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 {
+ (x & 0x00ff) as u8
+}
+
+pub unsafe fn get_nlmsghdr(
+ buf: &[u8],
+ expected_seq: u32,
+ expected_port_id: u32,
+) -> Result<&nlmsghdr, ParseError> {
+ let size_of_hdr = size_of::<nlmsghdr>();
+
+ if buf.len() < size_of_hdr {
+ return Err(ParseError::BufTooSmall);
+ }
+
+ let nlmsghdr_ptr = buf[0..size_of_hdr].as_ptr() as *const nlmsghdr;
+ let nlmsghdr = *nlmsghdr_ptr;
+
+ if nlmsghdr.nlmsg_len as usize > buf.len() || (nlmsghdr.nlmsg_len as usize) < size_of_hdr {
+ println!("a: {}, {}", buf.len(), nlmsghdr.nlmsg_len);
+ return Err(ParseError::NlMsgTooSmall);
+ }
+
+ if nlmsghdr.nlmsg_pid != 0 && expected_port_id != 0 && nlmsghdr.nlmsg_pid != expected_port_id {
+ return Err(ParseError::InvalidPortId(nlmsghdr.nlmsg_pid));
+ }
+
+ if nlmsghdr.nlmsg_seq != 0 && expected_seq != 0 && nlmsghdr.nlmsg_seq != expected_seq {
+ return Err(ParseError::InvalidSeq(nlmsghdr.nlmsg_seq));
+ }
+
+ if nlmsghdr.nlmsg_flags & NLM_F_DUMP_INTR as u16 != 0 {
+ return Err(ParseError::ConcurrentGenerationUpdate);
+ }
+
+ Ok(&*nlmsghdr_ptr as &nlmsghdr)
+}
+
+pub enum NlMsg<'a> {
+ Done,
+ Noop,
+ Error(nlmsgerr),
+ NfGenMsg(&'a Nfgenmsg, &'a [u8]),
+}
+
+pub unsafe fn parse_nlmsg<'a>(
+ buf: &'a [u8],
+ expected_seq: u32,
+ expected_port_id: u32,
+) -> Result<(&'a nlmsghdr, NlMsg<'a>), ParseError> {
+ // in theory the message is composed of the following parts:
+ // - nlmsghdr (contains the message size and type)
+ // - struct nlmsgerr OR nfgenmsg (nftables header that describes the message family)
+ // - the raw value that we want to validate (if the previous part is nfgenmsg)
+ let nlmsghdr = get_nlmsghdr(buf, expected_seq, expected_port_id)?;
+
+ let size_of_hdr = size_of::<nlmsghdr>();
+
+ if nlmsghdr.nlmsg_type < NLMSG_MIN_TYPE as u16 {
+ match nlmsghdr.nlmsg_type as libc::c_int {
+ NLMSG_NOOP => return Ok((nlmsghdr, NlMsg::Noop)),
+ NLMSG_ERROR => {
+ if nlmsghdr.nlmsg_len as usize > buf.len()
+ || (nlmsghdr.nlmsg_len as usize) < size_of_hdr + size_of::<nlmsgerr>()
+ {
+ println!("b: {}, {}", buf.len(), nlmsghdr.nlmsg_len);
+ return Err(ParseError::NlMsgTooSmall);
+ }
+ let mut err = *(buf[size_of_hdr..size_of_hdr + size_of::<nlmsgerr>()].as_ptr()
+ as *const nlmsgerr);
+ // some APIs return negative values, while other return positive values
+ err.error = err.error.abs();
+ return Ok((nlmsghdr, NlMsg::Error(err)));
+ }
+ NLMSG_DONE => return Ok((nlmsghdr, NlMsg::Done)),
+ x => return Err(ParseError::UnsupportedType(x as u16)),
+ }
+ }
+
+ let subsys = get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type);
+ if subsys != NFNL_SUBSYS_NFTABLES as u8 {
+ return Err(ParseError::InvalidSubsystem(subsys));
+ }
+
+ let size_of_nfgenmsg = size_of::<Nfgenmsg>();
+ if nlmsghdr.nlmsg_len as usize > buf.len()
+ || (nlmsghdr.nlmsg_len as usize) < size_of_hdr + size_of_nfgenmsg
+ {
+ println!("c: {}, {}", buf.len(), nlmsghdr.nlmsg_len);
+ return Err(ParseError::NlMsgTooSmall);
+ }
+
+ let nfgenmsg_ptr = buf[size_of_hdr..size_of_hdr + size_of_nfgenmsg].as_ptr() as *const Nfgenmsg;
+ let nfgenmsg = *nfgenmsg_ptr;
+ let subsys = get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type);
+ if subsys != NFNL_SUBSYS_NFTABLES as u8 {
+ return Err(ParseError::InvalidSubsystem(subsys));
+ }
+ if nfgenmsg.version != NFNETLINK_V0 as u8 {
+ return Err(ParseError::InvalidVersion(nfgenmsg.version));
+ }
+
+ let raw_value = &buf[size_of_hdr + size_of_nfgenmsg..nlmsghdr.nlmsg_len as usize];
+
+ Ok((
+ nlmsghdr,
+ NlMsg::NfGenMsg(&*nfgenmsg_ptr as &Nfgenmsg, raw_value),
+ ))
+}
+
/// Returns a buffer containing a netlink message which requests a list of all the netfilter
/// matching objects (e.g. tables, chains, rules, ...).
/// Supply the type of objects to retrieve (e.g. libc::NFT_MSG_GETTABLE), and optionally a callback
@@ -29,6 +187,15 @@ pub fn get_list_of_objects<Error>(
#[cfg(feature = "query")]
mod inner {
+ use std::os::unix::prelude::RawFd;
+
+ use nix::{
+ errno::Errno,
+ sys::socket::{
+ self, AddressFamily, MsgFlags, NetlinkAddr, SockAddr, SockFlag, SockProtocol, SockType,
+ },
+ };
+
use crate::FinalizedBatch;
use super::*;
@@ -36,92 +203,159 @@ mod inner {
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("Unable to open netlink socket to netfilter")]
- NetlinkOpenError(#[source] std::io::Error),
+ NetlinkOpenError(#[source] nix::Error),
#[error("Unable to send netlink command to netfilter")]
- NetlinkSendError(#[source] std::io::Error),
+ NetlinkSendError(#[source] nix::Error),
#[error("Error while reading from netlink socket")]
- NetlinkRecvError(#[source] std::io::Error),
+ NetlinkRecvError(#[source] nix::Error),
#[error("Error while processing an incoming netlink message")]
- ProcessNetlinkError(#[source] std::io::Error),
+ ProcessNetlinkError(#[from] ParseError),
+
+ #[error("Error received from the kernel")]
+ NetlinkError(nlmsgerr),
#[error("Custom error when customizing the query")]
InitError(#[from] Box<dyn std::error::Error + Send + 'static>),
#[error("Couldn't allocate a netlink object, out of memory ?")]
NetlinkAllocationFailed,
+
+ #[error("This socket is not a netlink socket")]
+ NotNetlinkSocket,
+
+ #[error("Couldn't retrieve information on a socket")]
+ RetrievingSocketInfoFailed,
+
+ #[error("Only a part of the message was sent")]
+ TruncatedSend,
+
+ #[error("Couldn't close the socket")]
+ CloseFailed(#[source] Errno),
+ }
+
+ fn recv_and_process<'a, T>(
+ sock: RawFd,
+ cb: &dyn Fn(&nlmsghdr, &Nfgenmsg, &[u8], &mut T) -> Result<(), Error>,
+ working_data: &'a mut T,
+ seq: u32,
+ portid: u32,
+ ) -> Result<(), Error> {
+ let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize];
+
+ loop {
+ let nb_recv = socket::recv(sock, &mut msg_buffer, MsgFlags::empty())
+ .map_err(Error::NetlinkRecvError)?;
+ if nb_recv <= 0 {
+ return Ok(());
+ }
+ let mut buf = &msg_buffer.as_slice()[0..nb_recv];
+ loop {
+ let (nlmsghdr, msg) = unsafe { parse_nlmsg(&buf, seq, portid) }?;
+ match msg {
+ NlMsg::Done => {
+ return Ok(());
+ }
+ NlMsg::Error(e) => {
+ if e.error != 0 {
+ return Err(Error::NetlinkError(e));
+ }
+ }
+ NlMsg::Noop => {}
+ NlMsg::NfGenMsg(genmsg, data) => cb(&nlmsghdr, &genmsg, &data, working_data)?,
+ }
+
+ // netlink messages are 4bytes aligned
+ let aligned_length = ((nlmsghdr.nlmsg_len + 3) & !3u32) as usize;
+
+ // retrieve the next message
+ buf = &buf[aligned_length..];
+
+ // exit the loop when we consumed all the buffer
+ if buf.len() == 0 {
+ break;
+ }
+ }
+ }
+ }
+
+ fn socket_close_wrapper(
+ sock: RawFd,
+ cb: impl FnOnce(RawFd) -> Result<(), Error>,
+ ) -> Result<(), Error> {
+ let ret = cb(sock);
+
+ // we don't need to shutdown the socket (in fact, Linux doesn't support that operation;
+ // and return EOPNOTSUPP if we try)
+ nix::unistd::close(sock).map_err(Error::CloseFailed)?;
+
+ ret
}
/// Lists objects of a certain type (e.g. libc::NFT_MSG_GETTABLE) with the help of a helper
/// function called by mnl::cb_run2.
/// The callback expects a tuple of additional data (supplied as an argument to this function)
/// and of the output vector, to which it should append the parsed object it received.
- pub fn list_objects_with_data<'a, A, T>(
+ pub fn list_objects_with_data<'a, T>(
data_type: u16,
- cb: fn(&libc::nlmsghdr, &mut (&'a A, &mut Vec<T>)) -> libc::c_int,
- additional_data: &'a A,
+ cb: &dyn Fn(&libc::nlmsghdr, &Nfgenmsg, &[u8], &mut T) -> Result<(), Error>,
+ working_data: &'a mut T,
req_hdr_customize: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>,
- ) -> Result<Vec<T>, Error>
- where
- T: 'a,
- {
+ ) -> Result<(), Error> {
debug!("listing objects of kind {}", data_type);
- let socket = mnl::Socket::new(mnl::Bus::Netfilter).map_err(Error::NetlinkOpenError)?;
+ let sock = socket::socket(
+ AddressFamily::Netlink,
+ SockType::Raw,
+ SockFlag::empty(),
+ SockProtocol::NetlinkNetFilter,
+ )
+ .map_err(Error::NetlinkOpenError)?;
let seq = 0;
let portid = 0;
let chains_buf = get_list_of_objects(seq, data_type, req_hdr_customize)?;
- socket.send(&chains_buf).map_err(Error::NetlinkSendError)?;
+ socket::send(sock, &chains_buf, MsgFlags::empty()).map_err(Error::NetlinkSendError)?;
- let mut res = Vec::new();
-
- let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize];
- while socket
- .recv(&mut msg_buffer)
- .map_err(Error::NetlinkRecvError)?
- > 0
- {
- if let mnl::CbResult::Stop = mnl::cb_run2(
- &msg_buffer,
- seq,
- portid,
- cb,
- &mut (additional_data, &mut res),
- )
- .map_err(Error::ProcessNetlinkError)?
- {
- break;
- }
- }
-
- Ok(res)
+ Ok(socket_close_wrapper(sock, move |sock| {
+ recv_and_process(sock, cb, working_data, seq, portid)
+ })?)
}
pub fn send_batch(batch: &mut FinalizedBatch) -> Result<(), Error> {
- let socket = mnl::Socket::new(mnl::Bus::Netfilter).map_err(Error::NetlinkOpenError)?;
+ let sock = socket::socket(
+ AddressFamily::Netlink,
+ SockType::Raw,
+ SockFlag::empty(),
+ SockProtocol::NetlinkNetFilter,
+ )
+ .map_err(Error::NetlinkOpenError)?;
let seq = 0;
- let portid = socket.portid();
+ let portid = 0;
- socket.send_all(batch).map_err(Error::NetlinkSendError)?;
- debug!("sent");
+ // while this bind() is not strictly necessary, strace have trouble decoding the messages
+ // if we don't
+ let addr = SockAddr::Netlink(NetlinkAddr::new(portid, 0));
+ socket::bind(sock, &addr).expect("bind");
+ //match socket::getsockname(sock).map_err(|_| Error::RetrievingSocketInfoFailed)? {
+ // SockAddr::Netlink(addr) => addr.0.nl_pid,
+ // _ => return Err(Error::NotNetlinkSocket),
+ //};
- let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize];
- while socket
- .recv(&mut msg_buffer)
- .map_err(Error::NetlinkRecvError)?
- > 0
- {
- if let mnl::CbResult::Stop =
- mnl::cb_run(&msg_buffer, seq, portid).map_err(Error::ProcessNetlinkError)?
+ for data in batch {
+ if socket::send(sock, data, MsgFlags::empty()).map_err(Error::NetlinkSendError)?
+ < data.len()
{
- break;
+ return Err(Error::TruncatedSend);
}
}
- Ok(())
+
+ Ok(socket_close_wrapper(sock, move |sock| {
+ recv_and_process(sock, &|_, _, _, _| Ok(()), &mut (), seq, portid)
+ })?)
}
}
diff --git a/src/rule.rs b/src/rule.rs
index 2ee5308..66beef8 100644
--- a/src/rule.rs
+++ b/src/rule.rs
@@ -1,6 +1,8 @@
use crate::expr::ExpressionWrapper;
-use crate::{chain::Chain, expr::Expression, MsgType};
+#[cfg(feature = "query")]
+use crate::query::{Nfgenmsg, ParseError};
use crate::sys::{self, libc};
+use crate::{chain::Chain, expr::Expression, MsgType};
use std::ffi::{c_void, CStr, CString};
use std::fmt::Debug;
use std::os::raw::c_char;
@@ -284,31 +286,42 @@ impl Drop for RuleExprsIter {
#[cfg(feature = "query")]
pub fn get_rules_cb(
header: &libc::nlmsghdr,
+ _genmsg: &Nfgenmsg,
+ _data: &[u8],
(chain, rules): &mut (&Rc<Chain>, &mut Vec<Rule>),
-) -> libc::c_int {
+) -> Result<(), crate::query::Error> {
unsafe {
let rule = sys::nftnl_rule_alloc();
if rule == std::ptr::null_mut() {
- return mnl::mnl_sys::MNL_CB_ERROR;
+ return Err(ParseError::Custom(Box::new(std::io::Error::new(
+ std::io::ErrorKind::Other,
+ "Rule allocation failed",
+ )))
+ .into());
}
let err = sys::nftnl_rule_nlmsg_parse(header, rule);
if err < 0 {
- error!("Failed to parse nelink rule message - {}", err);
sys::nftnl_rule_free(rule);
- return err;
+ return Err(ParseError::Custom(Box::new(std::io::Error::new(
+ std::io::ErrorKind::Other,
+ "The netlink table couldn't be parsed !?",
+ )))
+ .into());
}
rules.push(Rule::from_raw(rule, chain.clone()));
}
- mnl::mnl_sys::MNL_CB_OK
+
+ Ok(())
}
#[cfg(feature = "query")]
pub fn list_rules_for_chain(chain: &Rc<Chain>) -> Result<Vec<Rule>, crate::query::Error> {
+ let mut result = Vec::new();
crate::query::list_objects_with_data(
libc::NFT_MSG_GETRULE as u16,
- get_rules_cb,
- &chain,
+ &get_rules_cb,
+ &mut (chain, &mut result),
// only retrieve rules from the currently targetted chain
Some(&|hdr| unsafe {
let rule = sys::nftnl_rule_alloc();
@@ -337,5 +350,6 @@ pub fn list_rules_for_chain(chain: &Rc<Chain>) -> Result<Vec<Rule>, crate::query
sys::nftnl_rule_free(rule);
Ok(())
}),
- )
+ )?;
+ Ok(result)
}
diff --git a/src/table.rs b/src/table.rs
index 593fffb..332cc99 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -1,5 +1,7 @@
-use crate::{MsgType, ProtoFamily};
+#[cfg(feature = "query")]
+use crate::query::{Nfgenmsg, ParseError};
use crate::sys::{self, libc};
+use crate::{MsgType, ProtoFamily};
#[cfg(feature = "query")]
use std::convert::TryFrom;
use std::{
@@ -39,7 +41,7 @@ impl Table {
unsafe {
let ptr = sys::nftnl_table_get_str(self.table, sys::NFTNL_TABLE_NAME as u16);
if ptr.is_null() {
- panic!("Impossible situation: retrieving the name of a chain failed")
+ panic!("Impossible situation: retrieving the name of a table failed")
} else {
CStr::from_ptr(ptr)
}
@@ -137,29 +139,41 @@ impl Drop for Table {
/// A callback to parse the response for messages created with `get_tables_nlmsg`.
pub fn get_tables_cb(
header: &libc::nlmsghdr,
- (_, tables): &mut (&(), &mut Vec<Table>),
-) -> libc::c_int {
+ _genmsg: &Nfgenmsg,
+ _data: &[u8],
+ tables: &mut Vec<Table>,
+) -> Result<(), crate::query::Error> {
unsafe {
let table = sys::nftnl_table_alloc();
if table == std::ptr::null_mut() {
- return mnl::mnl_sys::MNL_CB_ERROR;
+ return Err(ParseError::Custom(Box::new(std::io::Error::new(
+ std::io::ErrorKind::Other,
+ "Table allocation failed",
+ )))
+ .into());
}
let err = sys::nftnl_table_nlmsg_parse(header, table);
if err < 0 {
- error!("Failed to parse nelink table message - {}", err);
sys::nftnl_table_free(table);
- return err;
+ return Err(ParseError::Custom(Box::new(std::io::Error::new(
+ std::io::ErrorKind::Other,
+ "The netlink table couldn't be parsed !?",
+ )))
+ .into());
}
let family = sys::nftnl_table_get_u32(table, sys::NFTNL_TABLE_FAMILY as u16);
match crate::ProtoFamily::try_from(family as i32) {
Ok(family) => {
tables.push(Table::from_raw(table, family));
- mnl::mnl_sys::MNL_CB_OK
+ Ok(())
}
Err(crate::InvalidProtocolFamily) => {
- error!("The netlink table didn't have a valid protocol family !?");
sys::nftnl_table_free(table);
- mnl::mnl_sys::MNL_CB_ERROR
+ Err(ParseError::Custom(Box::new(std::io::Error::new(
+ std::io::ErrorKind::Other,
+ "The netlink table didn't have a valid protocol family !?",
+ )))
+ .into())
}
}
}
@@ -167,5 +181,12 @@ pub fn get_tables_cb(
#[cfg(feature = "query")]
pub fn list_tables() -> Result<Vec<Table>, crate::query::Error> {
- crate::query::list_objects_with_data(libc::NFT_MSG_GETTABLE as u16, get_tables_cb, &(), None)
+ let mut result = Vec::new();
+ crate::query::list_objects_with_data(
+ libc::NFT_MSG_GETTABLE as u16,
+ &get_tables_cb,
+ &mut result,
+ None,
+ )?;
+ Ok(result)
}
diff --git a/tests/chain.rs b/tests/chain.rs
index 4b6da91..809936a 100644
--- a/tests/chain.rs
+++ b/tests/chain.rs
@@ -1,7 +1,7 @@
use std::ffi::CStr;
mod sys;
-use rustables::MsgType;
+use rustables::{query::get_operation_from_nlmsghdr_type, MsgType};
use sys::*;
mod lib;
diff --git a/tests/expr.rs b/tests/expr.rs
index 7950df3..cfc9c7d 100644
--- a/tests/expr.rs
+++ b/tests/expr.rs
@@ -3,6 +3,7 @@ use rustables::expr::{
LogGroup, LogPrefix, Lookup, Meta, Nat, NatType, Payload, Register, Reject, TcpHeaderField,
TransportHeaderField, Verdict,
};
+use rustables::query::{get_operation_from_nlmsghdr_type, Nfgenmsg};
use rustables::set::Set;
use rustables::sys::libc::{nlmsghdr, NF_DROP};
use rustables::{ProtoFamily, Rule};
diff --git a/tests/lib.rs b/tests/lib.rs
index 29c61b3..9b44a88 100644
--- a/tests/lib.rs
+++ b/tests/lib.rs
@@ -1,19 +1,11 @@
#![allow(dead_code)]
-use libc::{nlmsghdr, AF_UNIX, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES};
+use libc::{nlmsghdr, AF_UNIX};
+use rustables::query::{parse_nlmsg, Nfgenmsg};
use rustables::set::SetKey;
use rustables::{nft_nlmsg_maxsize, Chain, MsgType, NlMsg, ProtoFamily, Rule, Set, Table};
use std::ffi::{c_void, CStr};
-use std::mem::size_of;
use std::rc::Rc;
-pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 {
- ((x & 0xff00) >> 8) as u8
-}
-
-pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 {
- (x & 0x00ff) as u8
-}
-
pub const TABLE_NAME: &[u8; 10] = b"mocktable\0";
pub const CHAIN_NAME: &[u8; 10] = b"mockchain\0";
pub const SET_NAME: &[u8; 8] = b"mockset\0";
@@ -89,14 +81,6 @@ impl NetlinkExpr {
}
}
-#[repr(C)]
-#[derive(Clone, Copy)]
-pub struct Nfgenmsg {
- family: u8, /* AF_xxx */
- version: u8, /* nfnetlink version */
- res_id: u16, /* resource id */
-}
-
pub fn get_test_table() -> Table {
Table::new(
&CStr::from_bytes_with_nul(TABLE_NAME).unwrap(),
@@ -131,35 +115,20 @@ pub fn get_test_nlmsg_with_msg_type(
unsafe {
obj.write(buf.as_mut_ptr() as *mut c_void, 0, msg_type);
- // right now the message is composed of the following parts:
- // - nlmsghdr (contains the message size and type)
- // - nfgenmsg (nftables header that describes the message family)
- // - the raw value that we want to validate
-
- let size_of_hdr = size_of::<nlmsghdr>();
- let size_of_nfgenmsg = size_of::<Nfgenmsg>();
- let nlmsghdr = *(buf[0..size_of_hdr].as_ptr() as *const nlmsghdr);
- let nfgenmsg =
- *(buf[size_of_hdr..size_of_hdr + size_of_nfgenmsg].as_ptr() as *const Nfgenmsg);
- let raw_value = buf[size_of_hdr + size_of_nfgenmsg..nlmsghdr.nlmsg_len as usize]
- .iter()
- .map(|x| *x)
- .collect();
+ let (nlmsghdr, msg) = parse_nlmsg(&buf, 0, 0).expect("Couldn't parse the message");
+
+ let (nfgenmsg, raw_value) = match msg {
+ rustables::query::NlMsg::NfGenMsg(nfgenmsg, raw_value) => (nfgenmsg, raw_value),
+ _ => panic!("Invalid return value type, expected a valid message"),
+ };
// sanity checks on the global message (this should be very similar/factorisable for the
// most part in other tests)
// TODO: check the messages flags
- assert_eq!(
- get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
- NFNL_SUBSYS_NFTABLES as u8
- );
- assert_eq!(nlmsghdr.nlmsg_seq, 0);
- assert_eq!(nlmsghdr.nlmsg_pid, 0);
assert_eq!(nfgenmsg.family, AF_UNIX as u8);
- assert_eq!(nfgenmsg.version, NFNETLINK_V0 as u8);
assert_eq!(nfgenmsg.res_id.to_be(), 0);
- (nlmsghdr, nfgenmsg, raw_value)
+ (*nlmsghdr, *nfgenmsg, raw_value.to_owned())
}
}
diff --git a/tests/rule.rs b/tests/rule.rs
index b601a61..517db47 100644
--- a/tests/rule.rs
+++ b/tests/rule.rs
@@ -1,7 +1,7 @@
use std::ffi::CStr;
mod sys;
-use rustables::MsgType;
+use rustables::{query::get_operation_from_nlmsghdr_type, MsgType};
use sys::*;
mod lib;
diff --git a/tests/set.rs b/tests/set.rs
index d5b2ad7..4b79988 100644
--- a/tests/set.rs
+++ b/tests/set.rs
@@ -1,7 +1,7 @@
mod sys;
use std::net::{Ipv4Addr, Ipv6Addr};
-use rustables::{set::SetKey, MsgType};
+use rustables::{query::get_operation_from_nlmsghdr_type, set::SetKey, MsgType};
use sys::*;
mod lib;
diff --git a/tests/table.rs b/tests/table.rs
index 3d8957c..971d58b 100644
--- a/tests/table.rs
+++ b/tests/table.rs
@@ -1,7 +1,7 @@
use std::ffi::CStr;
mod sys;
-use rustables::MsgType;
+use rustables::{query::get_operation_from_nlmsghdr_type, MsgType};
use sys::*;
mod lib;