aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon THOBY <git@nightmared.fr>2022-10-02 17:51:51 +0200
committerSimon THOBY <git@nightmared.fr>2022-10-02 17:52:08 +0200
commitdc2c0bc1ba921f113d5f90a05245cfccab9dbdaa (patch)
tree43b243157e6aaa03e9a98e74c9ffdf5acf2db997
parent3371865506cad4a795f07bce4495eb00d199f4a6 (diff)
special case the handling of batch messages
-rw-r--r--examples/add-rules.rs54
-rw-r--r--src/nlmsg.rs9
-rw-r--r--src/parser.rs47
-rw-r--r--src/query.rs6
-rw-r--r--tests/batch.rs12
5 files changed, 63 insertions, 65 deletions
diff --git a/examples/add-rules.rs b/examples/add-rules.rs
index 0dee080..229db97 100644
--- a/examples/add-rules.rs
+++ b/examples/add-rules.rs
@@ -37,25 +37,26 @@
//! ```
use ipnetwork::{IpNetwork, Ipv4Network};
+use rustables::{query::send_batch, Batch, ProtoFamily, Table};
//use rustables::{nft_expr, query::send_batch, sys::libc, Batch, Chain, ProtoFamily, Rule, Table};
use std::{ffi::CString, io, net::Ipv4Addr, rc::Rc};
-//
-//const TABLE_NAME: &str = "example-table";
-//const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets";
-//const IN_CHAIN_NAME: &str = "chain-for-incoming-packets";
+
+const TABLE_NAME: &str = "example-table";
+const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets";
+const IN_CHAIN_NAME: &str = "chain-for-incoming-packets";
fn main() -> Result<(), Error> {
- // // Create a batch. This is used to store all the netlink messages we will later send.
- // // Creating a new batch also automatically writes the initial batch begin message needed
- // // to tell netlink this is a single transaction that might arrive over multiple netlink packets.
- // let mut batch = Batch::new();
- //
- // // Create a netfilter table operating on both IPv4 and IPv6 (ProtoFamily::Inet)
- // let table = Rc::new(Table::new(TABLE_NAME, ProtoFamily::Inet));
- // // Add the table to the batch with the `MsgType::Add` type, thus instructing netfilter to add
- // // this table under its `ProtoFamily::Inet` ruleset.
- // batch.add(&Rc::clone(&table), rustables::MsgType::Add);
- //
+ // Create a batch. This is used to store all the netlink messages we will later send.
+ // Creating a new batch also automatically writes the initial batch begin message needed
+ // to tell netlink this is a single transaction that might arrive over multiple netlink packets.
+ let mut batch = Batch::new();
+
+ // Create a netfilter table operating on both IPv4 and IPv6 (ProtoFamily::Inet)
+ let table = Table::new(TABLE_NAME, ProtoFamily::Inet);
+ // Add the table to the batch with the `MsgType::Add` type, thus instructing netfilter to add
+ // this table under its `ProtoFamily::Inet` ruleset.
+ batch.add(&table, rustables::MsgType::Add);
+
// // Create input and output chains under the table we created above.
// // Hook the chains to the input and output event hooks, with highest priority (priority zero).
// // See the `Chain::set_hook` documentation for details.
@@ -163,22 +164,13 @@ fn main() -> Result<(), Error> {
// allow_router_solicitation.add_expr(&nft_expr!(verdict accept));
//
// batch.add(&allow_router_solicitation, rustables::MsgType::Add);
- //
- // // === FINALIZE THE TRANSACTION AND SEND THE DATA TO NETFILTER ===
- //
- // // Finalize the batch. This means the batch end message is written into the batch, telling
- // // netfilter the we reached the end of the transaction message. It's also converted to a type
- // // that implements `IntoIterator<Item = &'a [u8]>`, thus allowing us to get the raw netlink data
- // // out so it can be sent over a netlink socket to netfilter.
- // match batch.finalize() {
- // Some(mut finalized_batch) => {
- // // Send the entire batch and process any returned messages.
- // send_batch(&mut finalized_batch)?;
- // Ok(())
- // }
- // None => todo!(),
- // }
- Ok(())
+
+ // === FINALIZE THE TRANSACTION AND SEND THE DATA TO NETFILTER ===
+
+ // Finalize the batch and send it. This means the batch end message is written into the batch, telling
+ // netfilter the we reached the end of the transaction message. It's also converted to a
+ // Vec<u8>, containing the raw netlink data so it can be sent over a netlink socket to netfilter.
+ Ok(send_batch(batch)?)
}
// Look up the interface index for a given interface name.
diff --git a/src/nlmsg.rs b/src/nlmsg.rs
index 97b3b02..868560a 100644
--- a/src/nlmsg.rs
+++ b/src/nlmsg.rs
@@ -1,7 +1,8 @@
use std::{collections::HashMap, fmt::Debug, marker::PhantomData, mem::size_of, ops::Deref};
use libc::{
- nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES, NLMSG_MIN_TYPE, NLM_F_DUMP_INTR,
+ nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END,
+ NFNL_SUBSYS_NFTABLES, NLMSG_MIN_TYPE, NLM_F_DUMP_INTR,
};
use thiserror::Error;
@@ -58,7 +59,11 @@ impl<'a> NfNetlinkWriter<'a> {
//let mut hdr = &mut unsafe { *(nlmsghdr_buf.as_mut_ptr() as *mut nlmsghdr) };
hdr.nlmsg_len = (nlmsghdr_len + nfgenmsg_len) as u32;
- hdr.nlmsg_type = ((NFNL_SUBSYS_NFTABLES as u16) << 8) | msg_type;
+ hdr.nlmsg_type = msg_type;
+ // batch messages are not specific to the nftables subsystem
+ if msg_type != NFNL_MSG_BATCH_BEGIN as u16 && msg_type != NFNL_MSG_BATCH_END as u16 {
+ hdr.nlmsg_type |= (NFNL_SUBSYS_NFTABLES as u16) << 8;
+ }
hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16 | flags;
hdr.nlmsg_seq = seq;
diff --git a/src/parser.rs b/src/parser.rs
index a8df855..23b5213 100644
--- a/src/parser.rs
+++ b/src/parser.rs
@@ -7,7 +7,8 @@ use std::{
};
use libc::{
- nlattr, nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES, NLA_TYPE_MASK, NLMSG_MIN_TYPE,
+ nlattr, nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END,
+ NFNL_SUBSYS_NFTABLES, NLA_TYPE_MASK, NLMSG_DONE, NLMSG_ERROR, NLMSG_MIN_TYPE, NLMSG_NOOP,
NLM_F_DUMP_INTR,
};
use thiserror::Error;
@@ -119,6 +120,8 @@ pub fn get_nlmsghdr(buf: &[u8]) -> Result<nlmsghdr, DecodeError> {
Ok(nlmsghdr)
}
+
+#[derive(Debug)]
pub enum NlMsg<'a> {
Done,
Noop,
@@ -131,16 +134,16 @@ pub fn parse_nlmsg<'a>(buf: &'a [u8]) -> Result<(nlmsghdr, NlMsg<'a>), DecodeErr
// - nlmsghdr (contains the message size and type)
// - struct nlmsgerr OR nfgenmsg (nftables header that describes the message family)
// - the raw value that we want to validate (if the previous part is nfgenmsg)
- let nlmsghdr = get_nlmsghdr(buf)?;
+ let hdr = get_nlmsghdr(buf)?;
let size_of_hdr = pad_netlink_object::<nlmsghdr>();
- if nlmsghdr.nlmsg_type < NLMSG_MIN_TYPE as u16 {
- match nlmsghdr.nlmsg_type as libc::c_int {
- NLMSG_NOOP => return Ok((nlmsghdr, NlMsg::Noop)),
- NLMSG_ERROR => {
- if nlmsghdr.nlmsg_len as usize > buf.len()
- || (nlmsghdr.nlmsg_len as usize) < size_of_hdr + size_of::<nlmsgerr>()
+ if hdr.nlmsg_type < NLMSG_MIN_TYPE as u16 {
+ match hdr.nlmsg_type as libc::c_int {
+ x if x == NLMSG_NOOP => return Ok((hdr, NlMsg::Noop)),
+ x if x == NLMSG_ERROR => {
+ if hdr.nlmsg_len as usize > buf.len()
+ || (hdr.nlmsg_len as usize) < size_of_hdr + size_of::<nlmsgerr>()
{
return Err(DecodeError::NlMsgTooSmall);
}
@@ -150,38 +153,40 @@ pub fn parse_nlmsg<'a>(buf: &'a [u8]) -> Result<(nlmsghdr, NlMsg<'a>), DecodeErr
};
// some APIs return negative values, while other return positive values
err.error = err.error.abs();
- return Ok((nlmsghdr, NlMsg::Error(err)));
+ return Ok((hdr, NlMsg::Error(err)));
}
- NLMSG_DONE => return Ok((nlmsghdr, NlMsg::Done)),
+ x if x == NLMSG_DONE => return Ok((hdr, NlMsg::Done)),
x => return Err(DecodeError::UnsupportedType(x as u16)),
}
}
- let subsys = get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type);
- if subsys != NFNL_SUBSYS_NFTABLES as u8 {
- return Err(DecodeError::InvalidSubsystem(subsys));
+ // batch messages are not specific to the nftables subsystem
+ if hdr.nlmsg_type != NFNL_MSG_BATCH_BEGIN as u16 && hdr.nlmsg_type != NFNL_MSG_BATCH_END as u16
+ {
+ // verify that we are decoding nftables messages
+ let subsys = get_subsystem_from_nlmsghdr_type(hdr.nlmsg_type);
+ if subsys != NFNL_SUBSYS_NFTABLES as u8 {
+ return Err(DecodeError::InvalidSubsystem(subsys));
+ }
}
let size_of_nfgenmsg = pad_netlink_object::<Nfgenmsg>();
- if nlmsghdr.nlmsg_len as usize > buf.len()
- || (nlmsghdr.nlmsg_len as usize) < size_of_hdr + size_of_nfgenmsg
+ if hdr.nlmsg_len as usize > buf.len()
+ || (hdr.nlmsg_len as usize) < size_of_hdr + size_of_nfgenmsg
{
return Err(DecodeError::NlMsgTooSmall);
}
let nfgenmsg_ptr = buf[size_of_hdr..size_of_hdr + size_of_nfgenmsg].as_ptr() as *const Nfgenmsg;
let nfgenmsg = unsafe { *nfgenmsg_ptr };
- let subsys = get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type);
- if subsys != NFNL_SUBSYS_NFTABLES as u8 {
- return Err(DecodeError::InvalidSubsystem(subsys));
- }
+
if nfgenmsg.version != NFNETLINK_V0 as u8 {
return Err(DecodeError::InvalidVersion(nfgenmsg.version));
}
- let raw_value = &buf[size_of_hdr + size_of_nfgenmsg..nlmsghdr.nlmsg_len as usize];
+ let raw_value = &buf[size_of_hdr + size_of_nfgenmsg..hdr.nlmsg_len as usize];
- Ok((nlmsghdr, NlMsg::NfGenMsg(nfgenmsg, raw_value)))
+ Ok((hdr, NlMsg::NfGenMsg(nfgenmsg, raw_value)))
}
pub type NetlinkType = u16;
diff --git a/src/query.rs b/src/query.rs
index 5065436..3fea40d 100644
--- a/src/query.rs
+++ b/src/query.rs
@@ -21,7 +21,7 @@ pub fn get_list_of_objects<Error>(
setup_cb: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>,
) -> Result<Vec<u8>, Error> {
let mut buffer = vec![0; nft_nlmsg_maxsize() as usize];
- let mut writer = &mut NfNetlinkWriter::new(&mut buffer);
+ let mut writer = NfNetlinkWriter::new(&mut buffer);
writer.write_header(
msg_type,
ProtoFamily::Unspec,
@@ -106,7 +106,7 @@ mod inner {
}
let mut buf = &msg_buffer.as_slice()[0..nb_recv];
loop {
- let (nlmsghdr, msg) = unsafe { parse_nlmsg(&buf) }?;
+ let (nlmsghdr, msg) = parse_nlmsg(&buf)?;
match msg {
NlMsg::Done => {
return Ok(());
@@ -119,7 +119,7 @@ mod inner {
NlMsg::Noop => {}
NlMsg::NfGenMsg(genmsg, data) => {
if let Some(cb) = cb {
- cb(&nlmsghdr, &genmsg, &data, working_data);
+ cb(&nlmsghdr, &genmsg, &data, working_data)?;
}
}
}
diff --git a/tests/batch.rs b/tests/batch.rs
index dbf444f..081ee97 100644
--- a/tests/batch.rs
+++ b/tests/batch.rs
@@ -14,16 +14,14 @@ fn batch_empty() {
let buf = batch.finalize();
let (hdr, msg) = parse_nlmsg(&buf).expect("Invalid nlmsg message");
- let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as i32;
- assert_eq!(op, NFNL_MSG_BATCH_BEGIN);
+ assert_eq!(hdr.nlmsg_type, NFNL_MSG_BATCH_BEGIN as u16);
let (_nfgenmsg, attrs, remaining_data) =
parse_object(hdr, msg, &buf).expect("Could not parse the batch message");
assert_eq!(attrs.get_raw_data(), []);
let (hdr, msg) = parse_nlmsg(&remaining_data).expect("Invalid nlmsg message");
- let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as i32;
- assert_eq!(op, NFNL_MSG_BATCH_END);
+ assert_eq!(hdr.nlmsg_type, NFNL_MSG_BATCH_END as u16);
let (_nfgenmsg, attrs, remaining_data) =
parse_object(hdr, msg, &remaining_data).expect("Could not parse the batch message");
@@ -55,8 +53,7 @@ fn batch_with_objects() {
let buf = batch.finalize();
let (hdr, msg) = parse_nlmsg(&buf).expect("Invalid nlmsg message");
- let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as i32;
- assert_eq!(op, NFNL_MSG_BATCH_BEGIN);
+ assert_eq!(hdr.nlmsg_type, NFNL_MSG_BATCH_BEGIN as u16);
let (_nfgenmsg, attrs, mut remaining_data) =
parse_object(hdr, msg, &buf).expect("Could not parse the batch message");
@@ -71,8 +68,7 @@ fn batch_with_objects() {
}
let (hdr, msg) = parse_nlmsg(&remaining_data).expect("Invalid nlmsg message");
- let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as i32;
- assert_eq!(op, NFNL_MSG_BATCH_END);
+ assert_eq!(hdr.nlmsg_type, NFNL_MSG_BATCH_END as u16);
let (_nfgenmsg, attrs, remaining_data) =
parse_object(hdr, msg, &remaining_data).expect("Could not parse the batch message");