diff options
Diffstat (limited to 'src/query.rs')
-rw-r--r-- | src/query.rs | 196 |
1 files changed, 25 insertions, 171 deletions
diff --git a/src/query.rs b/src/query.rs index 1c81cdd..80fdc75 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,163 +1,14 @@ use std::mem::size_of; -use crate::{nft_nlmsg_maxsize, sys, ProtoFamily}; +use crate::{ + nlmsg::NfNetlinkWriter, + parser::{nft_nlmsg_maxsize, Nfgenmsg}, + sys, ProtoFamily, +}; use libc::{ nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES, NLMSG_DONE, NLMSG_ERROR, NLMSG_MIN_TYPE, NLMSG_NOOP, NLM_F_DUMP_INTR, }; -use sys::libc; - -#[repr(C)] -#[derive(Debug, Clone, Copy)] -pub struct Nfgenmsg { - pub family: u8, /* AF_xxx */ - pub version: u8, /* nfnetlink version */ - pub res_id: u16, /* resource id */ -} - -#[derive(thiserror::Error, Debug)] -pub enum ParseError { - #[error("The buffer is too small to hold a valid message")] - BufTooSmall, - - #[error("The message is too small")] - NlMsgTooSmall, - - #[error("Invalid subsystem, expected NFTABLES")] - InvalidSubsystem(u8), - - #[error("Invalid version, expected NFNETLINK_V0")] - InvalidVersion(u8), - - #[error("Invalid port ID")] - InvalidPortId(u32), - - #[error("Invalid sequence number")] - InvalidSeq(u32), - - #[error("The generation number was bumped in the kernel while the operation was running, interrupting it")] - ConcurrentGenerationUpdate, - - #[error("Unsupported message type")] - UnsupportedType(u16), - - #[error("A custom error occured")] - Custom(Box<dyn std::error::Error + 'static>), -} - -pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 { - ((x & 0xff00) >> 8) as u8 -} - -pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 { - (x & 0x00ff) as u8 -} - -pub unsafe fn get_nlmsghdr( - buf: &[u8], - expected_seq: u32, - expected_port_id: u32, -) -> Result<&nlmsghdr, ParseError> { - let size_of_hdr = size_of::<nlmsghdr>(); - - if buf.len() < size_of_hdr { - return Err(ParseError::BufTooSmall); - } - - let nlmsghdr_ptr = buf[0..size_of_hdr].as_ptr() as *const nlmsghdr; - let nlmsghdr = *nlmsghdr_ptr; - - if nlmsghdr.nlmsg_len as usize > buf.len() || (nlmsghdr.nlmsg_len as usize) < size_of_hdr { - println!("a: {}, {}", buf.len(), nlmsghdr.nlmsg_len); - return Err(ParseError::NlMsgTooSmall); - } - - if nlmsghdr.nlmsg_pid != 0 && expected_port_id != 0 && nlmsghdr.nlmsg_pid != expected_port_id { - return Err(ParseError::InvalidPortId(nlmsghdr.nlmsg_pid)); - } - - if nlmsghdr.nlmsg_seq != 0 && expected_seq != 0 && nlmsghdr.nlmsg_seq != expected_seq { - return Err(ParseError::InvalidSeq(nlmsghdr.nlmsg_seq)); - } - - if nlmsghdr.nlmsg_flags & NLM_F_DUMP_INTR as u16 != 0 { - return Err(ParseError::ConcurrentGenerationUpdate); - } - - Ok(&*nlmsghdr_ptr as &nlmsghdr) -} - -pub enum NlMsg<'a> { - Done, - Noop, - Error(nlmsgerr), - NfGenMsg(&'a Nfgenmsg, &'a [u8]), -} - -pub unsafe fn parse_nlmsg<'a>( - buf: &'a [u8], - expected_seq: u32, - expected_port_id: u32, -) -> Result<(&'a nlmsghdr, NlMsg<'a>), ParseError> { - // in theory the message is composed of the following parts: - // - nlmsghdr (contains the message size and type) - // - struct nlmsgerr OR nfgenmsg (nftables header that describes the message family) - // - the raw value that we want to validate (if the previous part is nfgenmsg) - let nlmsghdr = get_nlmsghdr(buf, expected_seq, expected_port_id)?; - - let size_of_hdr = size_of::<nlmsghdr>(); - - if nlmsghdr.nlmsg_type < NLMSG_MIN_TYPE as u16 { - match nlmsghdr.nlmsg_type as libc::c_int { - NLMSG_NOOP => return Ok((nlmsghdr, NlMsg::Noop)), - NLMSG_ERROR => { - if nlmsghdr.nlmsg_len as usize > buf.len() - || (nlmsghdr.nlmsg_len as usize) < size_of_hdr + size_of::<nlmsgerr>() - { - println!("b: {}, {}", buf.len(), nlmsghdr.nlmsg_len); - return Err(ParseError::NlMsgTooSmall); - } - let mut err = *(buf[size_of_hdr..size_of_hdr + size_of::<nlmsgerr>()].as_ptr() - as *const nlmsgerr); - // some APIs return negative values, while other return positive values - err.error = err.error.abs(); - return Ok((nlmsghdr, NlMsg::Error(err))); - } - NLMSG_DONE => return Ok((nlmsghdr, NlMsg::Done)), - x => return Err(ParseError::UnsupportedType(x as u16)), - } - } - - let subsys = get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type); - if subsys != NFNL_SUBSYS_NFTABLES as u8 { - return Err(ParseError::InvalidSubsystem(subsys)); - } - - let size_of_nfgenmsg = size_of::<Nfgenmsg>(); - if nlmsghdr.nlmsg_len as usize > buf.len() - || (nlmsghdr.nlmsg_len as usize) < size_of_hdr + size_of_nfgenmsg - { - println!("c: {}, {}", buf.len(), nlmsghdr.nlmsg_len); - return Err(ParseError::NlMsgTooSmall); - } - - let nfgenmsg_ptr = buf[size_of_hdr..size_of_hdr + size_of_nfgenmsg].as_ptr() as *const Nfgenmsg; - let nfgenmsg = *nfgenmsg_ptr; - let subsys = get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type); - if subsys != NFNL_SUBSYS_NFTABLES as u8 { - return Err(ParseError::InvalidSubsystem(subsys)); - } - if nfgenmsg.version != NFNETLINK_V0 as u8 { - return Err(ParseError::InvalidVersion(nfgenmsg.version)); - } - - let raw_value = &buf[size_of_hdr + size_of_nfgenmsg..nlmsghdr.nlmsg_len as usize]; - - Ok(( - nlmsghdr, - NlMsg::NfGenMsg(&*nfgenmsg_ptr as &Nfgenmsg, raw_value), - )) -} /// Returns a buffer containing a netlink message which requests a list of all the netfilter /// matching objects (e.g. tables, chains, rules, ...). @@ -165,22 +16,23 @@ pub unsafe fn parse_nlmsg<'a>( /// to execute on the header, to set parameters for example. /// To pass arbitrary data inside that callback, please use a closure. pub fn get_list_of_objects<Error>( + msg_type: u16, seq: u32, - target: u16, setup_cb: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>, ) -> Result<Vec<u8>, Error> { let mut buffer = vec![0; nft_nlmsg_maxsize() as usize]; - let hdr = unsafe { - &mut *sys::nftnl_nlmsg_build_hdr( - buffer.as_mut_ptr() as *mut libc::c_char, - target, - ProtoFamily::Unspec as u16, - (libc::NLM_F_ROOT | libc::NLM_F_MATCH) as u16, - seq, - ) - }; + let mut writer = &mut NfNetlinkWriter::new(&mut buffer); + writer.write_header( + msg_type, + ProtoFamily::Unspec, + (libc::NLM_F_ROOT | libc::NLM_F_MATCH) as u16, + seq, + None, + ); if let Some(cb) = setup_cb { - cb(hdr)?; + cb(writer + .get_current_header() + .expect("Fatal error: mising header"))?; } Ok(buffer) } @@ -191,12 +43,12 @@ mod inner { use nix::{ errno::Errno, - sys::socket::{ - self, AddressFamily, MsgFlags, NetlinkAddr, SockAddr, SockFlag, SockProtocol, SockType, - }, + sys::socket::{self, AddressFamily, MsgFlags, SockFlag, SockProtocol, SockType}, }; - use crate::FinalizedBatch; + //use crate::FinalizedBatch; + + use crate::nlmsg::{parse_nlmsg, DecodeError, NlMsg}; use super::*; @@ -212,7 +64,7 @@ mod inner { NetlinkRecvError(#[source] nix::Error), #[error("Error while processing an incoming netlink message")] - ProcessNetlinkError(#[from] ParseError), + ProcessNetlinkError(#[from] DecodeError), #[error("Error received from the kernel")] NetlinkError(nlmsgerr), @@ -316,7 +168,7 @@ mod inner { let seq = 0; let portid = 0; - let chains_buf = get_list_of_objects(seq, data_type, req_hdr_customize)?; + let chains_buf = get_list_of_objects(data_type, seq, req_hdr_customize)?; socket::send(sock, &chains_buf, MsgFlags::empty()).map_err(Error::NetlinkSendError)?; Ok(socket_close_wrapper(sock, move |sock| { @@ -324,6 +176,7 @@ mod inner { })?) } + /* pub fn send_batch(batch: &mut FinalizedBatch) -> Result<(), Error> { let sock = socket::socket( AddressFamily::Netlink, @@ -357,6 +210,7 @@ mod inner { recv_and_process(sock, &|_, _, _, _| Ok(()), &mut (), seq, portid) })?) } + */ } #[cfg(feature = "query")] |