diff options
author | Simon THOBY <git@nightmared.fr> | 2022-10-02 17:51:51 +0200 |
---|---|---|
committer | Simon THOBY <git@nightmared.fr> | 2022-10-02 17:52:08 +0200 |
commit | dc2c0bc1ba921f113d5f90a05245cfccab9dbdaa (patch) | |
tree | 43b243157e6aaa03e9a98e74c9ffdf5acf2db997 | |
parent | 3371865506cad4a795f07bce4495eb00d199f4a6 (diff) |
special case the handling of batch messages
-rw-r--r-- | examples/add-rules.rs | 54 | ||||
-rw-r--r-- | src/nlmsg.rs | 9 | ||||
-rw-r--r-- | src/parser.rs | 47 | ||||
-rw-r--r-- | src/query.rs | 6 | ||||
-rw-r--r-- | tests/batch.rs | 12 |
5 files changed, 63 insertions, 65 deletions
diff --git a/examples/add-rules.rs b/examples/add-rules.rs index 0dee080..229db97 100644 --- a/examples/add-rules.rs +++ b/examples/add-rules.rs @@ -37,25 +37,26 @@ //! ``` use ipnetwork::{IpNetwork, Ipv4Network}; +use rustables::{query::send_batch, Batch, ProtoFamily, Table}; //use rustables::{nft_expr, query::send_batch, sys::libc, Batch, Chain, ProtoFamily, Rule, Table}; use std::{ffi::CString, io, net::Ipv4Addr, rc::Rc}; -// -//const TABLE_NAME: &str = "example-table"; -//const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets"; -//const IN_CHAIN_NAME: &str = "chain-for-incoming-packets"; + +const TABLE_NAME: &str = "example-table"; +const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets"; +const IN_CHAIN_NAME: &str = "chain-for-incoming-packets"; fn main() -> Result<(), Error> { - // // Create a batch. This is used to store all the netlink messages we will later send. - // // Creating a new batch also automatically writes the initial batch begin message needed - // // to tell netlink this is a single transaction that might arrive over multiple netlink packets. - // let mut batch = Batch::new(); - // - // // Create a netfilter table operating on both IPv4 and IPv6 (ProtoFamily::Inet) - // let table = Rc::new(Table::new(TABLE_NAME, ProtoFamily::Inet)); - // // Add the table to the batch with the `MsgType::Add` type, thus instructing netfilter to add - // // this table under its `ProtoFamily::Inet` ruleset. - // batch.add(&Rc::clone(&table), rustables::MsgType::Add); - // + // Create a batch. This is used to store all the netlink messages we will later send. + // Creating a new batch also automatically writes the initial batch begin message needed + // to tell netlink this is a single transaction that might arrive over multiple netlink packets. + let mut batch = Batch::new(); + + // Create a netfilter table operating on both IPv4 and IPv6 (ProtoFamily::Inet) + let table = Table::new(TABLE_NAME, ProtoFamily::Inet); + // Add the table to the batch with the `MsgType::Add` type, thus instructing netfilter to add + // this table under its `ProtoFamily::Inet` ruleset. + batch.add(&table, rustables::MsgType::Add); + // // Create input and output chains under the table we created above. // // Hook the chains to the input and output event hooks, with highest priority (priority zero). // // See the `Chain::set_hook` documentation for details. @@ -163,22 +164,13 @@ fn main() -> Result<(), Error> { // allow_router_solicitation.add_expr(&nft_expr!(verdict accept)); // // batch.add(&allow_router_solicitation, rustables::MsgType::Add); - // - // // === FINALIZE THE TRANSACTION AND SEND THE DATA TO NETFILTER === - // - // // Finalize the batch. This means the batch end message is written into the batch, telling - // // netfilter the we reached the end of the transaction message. It's also converted to a type - // // that implements `IntoIterator<Item = &'a [u8]>`, thus allowing us to get the raw netlink data - // // out so it can be sent over a netlink socket to netfilter. - // match batch.finalize() { - // Some(mut finalized_batch) => { - // // Send the entire batch and process any returned messages. - // send_batch(&mut finalized_batch)?; - // Ok(()) - // } - // None => todo!(), - // } - Ok(()) + + // === FINALIZE THE TRANSACTION AND SEND THE DATA TO NETFILTER === + + // Finalize the batch and send it. This means the batch end message is written into the batch, telling + // netfilter the we reached the end of the transaction message. It's also converted to a + // Vec<u8>, containing the raw netlink data so it can be sent over a netlink socket to netfilter. + Ok(send_batch(batch)?) } // Look up the interface index for a given interface name. diff --git a/src/nlmsg.rs b/src/nlmsg.rs index 97b3b02..868560a 100644 --- a/src/nlmsg.rs +++ b/src/nlmsg.rs @@ -1,7 +1,8 @@ use std::{collections::HashMap, fmt::Debug, marker::PhantomData, mem::size_of, ops::Deref}; use libc::{ - nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES, NLMSG_MIN_TYPE, NLM_F_DUMP_INTR, + nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END, + NFNL_SUBSYS_NFTABLES, NLMSG_MIN_TYPE, NLM_F_DUMP_INTR, }; use thiserror::Error; @@ -58,7 +59,11 @@ impl<'a> NfNetlinkWriter<'a> { //let mut hdr = &mut unsafe { *(nlmsghdr_buf.as_mut_ptr() as *mut nlmsghdr) }; hdr.nlmsg_len = (nlmsghdr_len + nfgenmsg_len) as u32; - hdr.nlmsg_type = ((NFNL_SUBSYS_NFTABLES as u16) << 8) | msg_type; + hdr.nlmsg_type = msg_type; + // batch messages are not specific to the nftables subsystem + if msg_type != NFNL_MSG_BATCH_BEGIN as u16 && msg_type != NFNL_MSG_BATCH_END as u16 { + hdr.nlmsg_type |= (NFNL_SUBSYS_NFTABLES as u16) << 8; + } hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16 | flags; hdr.nlmsg_seq = seq; diff --git a/src/parser.rs b/src/parser.rs index a8df855..23b5213 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -7,7 +7,8 @@ use std::{ }; use libc::{ - nlattr, nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES, NLA_TYPE_MASK, NLMSG_MIN_TYPE, + nlattr, nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END, + NFNL_SUBSYS_NFTABLES, NLA_TYPE_MASK, NLMSG_DONE, NLMSG_ERROR, NLMSG_MIN_TYPE, NLMSG_NOOP, NLM_F_DUMP_INTR, }; use thiserror::Error; @@ -119,6 +120,8 @@ pub fn get_nlmsghdr(buf: &[u8]) -> Result<nlmsghdr, DecodeError> { Ok(nlmsghdr) } + +#[derive(Debug)] pub enum NlMsg<'a> { Done, Noop, @@ -131,16 +134,16 @@ pub fn parse_nlmsg<'a>(buf: &'a [u8]) -> Result<(nlmsghdr, NlMsg<'a>), DecodeErr // - nlmsghdr (contains the message size and type) // - struct nlmsgerr OR nfgenmsg (nftables header that describes the message family) // - the raw value that we want to validate (if the previous part is nfgenmsg) - let nlmsghdr = get_nlmsghdr(buf)?; + let hdr = get_nlmsghdr(buf)?; let size_of_hdr = pad_netlink_object::<nlmsghdr>(); - if nlmsghdr.nlmsg_type < NLMSG_MIN_TYPE as u16 { - match nlmsghdr.nlmsg_type as libc::c_int { - NLMSG_NOOP => return Ok((nlmsghdr, NlMsg::Noop)), - NLMSG_ERROR => { - if nlmsghdr.nlmsg_len as usize > buf.len() - || (nlmsghdr.nlmsg_len as usize) < size_of_hdr + size_of::<nlmsgerr>() + if hdr.nlmsg_type < NLMSG_MIN_TYPE as u16 { + match hdr.nlmsg_type as libc::c_int { + x if x == NLMSG_NOOP => return Ok((hdr, NlMsg::Noop)), + x if x == NLMSG_ERROR => { + if hdr.nlmsg_len as usize > buf.len() + || (hdr.nlmsg_len as usize) < size_of_hdr + size_of::<nlmsgerr>() { return Err(DecodeError::NlMsgTooSmall); } @@ -150,38 +153,40 @@ pub fn parse_nlmsg<'a>(buf: &'a [u8]) -> Result<(nlmsghdr, NlMsg<'a>), DecodeErr }; // some APIs return negative values, while other return positive values err.error = err.error.abs(); - return Ok((nlmsghdr, NlMsg::Error(err))); + return Ok((hdr, NlMsg::Error(err))); } - NLMSG_DONE => return Ok((nlmsghdr, NlMsg::Done)), + x if x == NLMSG_DONE => return Ok((hdr, NlMsg::Done)), x => return Err(DecodeError::UnsupportedType(x as u16)), } } - let subsys = get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type); - if subsys != NFNL_SUBSYS_NFTABLES as u8 { - return Err(DecodeError::InvalidSubsystem(subsys)); + // batch messages are not specific to the nftables subsystem + if hdr.nlmsg_type != NFNL_MSG_BATCH_BEGIN as u16 && hdr.nlmsg_type != NFNL_MSG_BATCH_END as u16 + { + // verify that we are decoding nftables messages + let subsys = get_subsystem_from_nlmsghdr_type(hdr.nlmsg_type); + if subsys != NFNL_SUBSYS_NFTABLES as u8 { + return Err(DecodeError::InvalidSubsystem(subsys)); + } } let size_of_nfgenmsg = pad_netlink_object::<Nfgenmsg>(); - if nlmsghdr.nlmsg_len as usize > buf.len() - || (nlmsghdr.nlmsg_len as usize) < size_of_hdr + size_of_nfgenmsg + if hdr.nlmsg_len as usize > buf.len() + || (hdr.nlmsg_len as usize) < size_of_hdr + size_of_nfgenmsg { return Err(DecodeError::NlMsgTooSmall); } let nfgenmsg_ptr = buf[size_of_hdr..size_of_hdr + size_of_nfgenmsg].as_ptr() as *const Nfgenmsg; let nfgenmsg = unsafe { *nfgenmsg_ptr }; - let subsys = get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type); - if subsys != NFNL_SUBSYS_NFTABLES as u8 { - return Err(DecodeError::InvalidSubsystem(subsys)); - } + if nfgenmsg.version != NFNETLINK_V0 as u8 { return Err(DecodeError::InvalidVersion(nfgenmsg.version)); } - let raw_value = &buf[size_of_hdr + size_of_nfgenmsg..nlmsghdr.nlmsg_len as usize]; + let raw_value = &buf[size_of_hdr + size_of_nfgenmsg..hdr.nlmsg_len as usize]; - Ok((nlmsghdr, NlMsg::NfGenMsg(nfgenmsg, raw_value))) + Ok((hdr, NlMsg::NfGenMsg(nfgenmsg, raw_value))) } pub type NetlinkType = u16; diff --git a/src/query.rs b/src/query.rs index 5065436..3fea40d 100644 --- a/src/query.rs +++ b/src/query.rs @@ -21,7 +21,7 @@ pub fn get_list_of_objects<Error>( setup_cb: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>, ) -> Result<Vec<u8>, Error> { let mut buffer = vec![0; nft_nlmsg_maxsize() as usize]; - let mut writer = &mut NfNetlinkWriter::new(&mut buffer); + let mut writer = NfNetlinkWriter::new(&mut buffer); writer.write_header( msg_type, ProtoFamily::Unspec, @@ -106,7 +106,7 @@ mod inner { } let mut buf = &msg_buffer.as_slice()[0..nb_recv]; loop { - let (nlmsghdr, msg) = unsafe { parse_nlmsg(&buf) }?; + let (nlmsghdr, msg) = parse_nlmsg(&buf)?; match msg { NlMsg::Done => { return Ok(()); @@ -119,7 +119,7 @@ mod inner { NlMsg::Noop => {} NlMsg::NfGenMsg(genmsg, data) => { if let Some(cb) = cb { - cb(&nlmsghdr, &genmsg, &data, working_data); + cb(&nlmsghdr, &genmsg, &data, working_data)?; } } } diff --git a/tests/batch.rs b/tests/batch.rs index dbf444f..081ee97 100644 --- a/tests/batch.rs +++ b/tests/batch.rs @@ -14,16 +14,14 @@ fn batch_empty() { let buf = batch.finalize(); let (hdr, msg) = parse_nlmsg(&buf).expect("Invalid nlmsg message"); - let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as i32; - assert_eq!(op, NFNL_MSG_BATCH_BEGIN); + assert_eq!(hdr.nlmsg_type, NFNL_MSG_BATCH_BEGIN as u16); let (_nfgenmsg, attrs, remaining_data) = parse_object(hdr, msg, &buf).expect("Could not parse the batch message"); assert_eq!(attrs.get_raw_data(), []); let (hdr, msg) = parse_nlmsg(&remaining_data).expect("Invalid nlmsg message"); - let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as i32; - assert_eq!(op, NFNL_MSG_BATCH_END); + assert_eq!(hdr.nlmsg_type, NFNL_MSG_BATCH_END as u16); let (_nfgenmsg, attrs, remaining_data) = parse_object(hdr, msg, &remaining_data).expect("Could not parse the batch message"); @@ -55,8 +53,7 @@ fn batch_with_objects() { let buf = batch.finalize(); let (hdr, msg) = parse_nlmsg(&buf).expect("Invalid nlmsg message"); - let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as i32; - assert_eq!(op, NFNL_MSG_BATCH_BEGIN); + assert_eq!(hdr.nlmsg_type, NFNL_MSG_BATCH_BEGIN as u16); let (_nfgenmsg, attrs, mut remaining_data) = parse_object(hdr, msg, &buf).expect("Could not parse the batch message"); @@ -71,8 +68,7 @@ fn batch_with_objects() { } let (hdr, msg) = parse_nlmsg(&remaining_data).expect("Invalid nlmsg message"); - let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as i32; - assert_eq!(op, NFNL_MSG_BATCH_END); + assert_eq!(hdr.nlmsg_type, NFNL_MSG_BATCH_END as u16); let (_nfgenmsg, attrs, remaining_data) = parse_object(hdr, msg, &remaining_data).expect("Could not parse the batch message"); |