diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/chain.rs | 43 | ||||
-rw-r--r-- | src/lib.rs | 3 | ||||
-rw-r--r-- | src/query.rs | 332 | ||||
-rw-r--r-- | src/rule.rs | 32 | ||||
-rw-r--r-- | src/table.rs | 43 |
5 files changed, 373 insertions, 80 deletions
diff --git a/src/chain.rs b/src/chain.rs index a942a37..18e3c64 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -1,5 +1,7 @@ -use crate::{MsgType, Table}; +#[cfg(feature = "query")] +use crate::query::{Nfgenmsg, ParseError}; use crate::sys::{self as sys, libc}; +use crate::{MsgType, Table}; #[cfg(feature = "query")] use std::convert::TryFrom; use std::{ @@ -243,18 +245,27 @@ impl Drop for Chain { #[cfg(feature = "query")] pub fn get_chains_cb<'a>( header: &libc::nlmsghdr, + _genmsg: &Nfgenmsg, + _data: &[u8], (table, chains): &mut (&Rc<Table>, &mut Vec<Chain>), -) -> libc::c_int { +) -> Result<(), crate::query::Error> { unsafe { let chain = sys::nftnl_chain_alloc(); if chain == std::ptr::null_mut() { - return mnl::mnl_sys::MNL_CB_ERROR; + return Err(ParseError::Custom(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "Chain allocation failed", + ))) + .into()); } let err = sys::nftnl_chain_nlmsg_parse(header, chain); if err < 0 { - error!("Failed to parse nelink chain message - {}", err); sys::nftnl_chain_free(chain); - return err; + return Err(ParseError::Custom(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "The netlink chain couldn't be parsed !?", + ))) + .into()); } let table_name = CStr::from_ptr(sys::nftnl_chain_get_str( @@ -267,26 +278,38 @@ pub fn get_chains_cb<'a>( Err(crate::InvalidProtocolFamily) => { error!("The netlink table didn't have a valid protocol family !?"); sys::nftnl_chain_free(chain); - return mnl::mnl_sys::MNL_CB_ERROR; + return Err(ParseError::Custom(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "The netlink table didn't have a valid protocol family !?", + ))) + .into()); } }; if table_name != table.get_name() { sys::nftnl_chain_free(chain); - return mnl::mnl_sys::MNL_CB_OK; + return Ok(()); } if family != crate::ProtoFamily::Unspec && family != table.get_family() { sys::nftnl_chain_free(chain); - return mnl::mnl_sys::MNL_CB_OK; + return Ok(()); } chains.push(Chain::from_raw(chain, table.clone())); } - mnl::mnl_sys::MNL_CB_OK + + Ok(()) } #[cfg(feature = "query")] pub fn list_chains_for_table(table: Rc<Table>) -> Result<Vec<Chain>, crate::query::Error> { - crate::query::list_objects_with_data(libc::NFT_MSG_GETCHAIN as u16, get_chains_cb, &table, None) + let mut result = Vec::new(); + crate::query::list_objects_with_data( + libc::NFT_MSG_GETCHAIN as u16, + &get_chains_cb, + &mut (&table, &mut result), + None, + )?; + Ok(result) } @@ -119,7 +119,7 @@ pub use rule::Rule; pub use rule::{get_rules_cb, list_rules_for_chain}; mod rule_methods; -pub use rule_methods::{iface_index, Protocol, RuleMethods, Error as MatchError}; +pub use rule_methods::{iface_index, Error as MatchError, Protocol, RuleMethods}; pub mod set; pub use set::Set; @@ -155,6 +155,7 @@ pub enum ProtoFamily { Ipv6 = libc::NFPROTO_IPV6 as u16, DecNet = libc::NFPROTO_DECNET as u16, } + #[derive(Error, Debug)] #[error("Couldn't find a matching protocol")] pub struct InvalidProtocolFamily; diff --git a/src/query.rs b/src/query.rs index d7574e8..1c81cdd 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,6 +1,164 @@ +use std::mem::size_of; + use crate::{nft_nlmsg_maxsize, sys, ProtoFamily}; +use libc::{ + nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES, NLMSG_DONE, NLMSG_ERROR, + NLMSG_MIN_TYPE, NLMSG_NOOP, NLM_F_DUMP_INTR, +}; use sys::libc; +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct Nfgenmsg { + pub family: u8, /* AF_xxx */ + pub version: u8, /* nfnetlink version */ + pub res_id: u16, /* resource id */ +} + +#[derive(thiserror::Error, Debug)] +pub enum ParseError { + #[error("The buffer is too small to hold a valid message")] + BufTooSmall, + + #[error("The message is too small")] + NlMsgTooSmall, + + #[error("Invalid subsystem, expected NFTABLES")] + InvalidSubsystem(u8), + + #[error("Invalid version, expected NFNETLINK_V0")] + InvalidVersion(u8), + + #[error("Invalid port ID")] + InvalidPortId(u32), + + #[error("Invalid sequence number")] + InvalidSeq(u32), + + #[error("The generation number was bumped in the kernel while the operation was running, interrupting it")] + ConcurrentGenerationUpdate, + + #[error("Unsupported message type")] + UnsupportedType(u16), + + #[error("A custom error occured")] + Custom(Box<dyn std::error::Error + 'static>), +} + +pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 { + ((x & 0xff00) >> 8) as u8 +} + +pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 { + (x & 0x00ff) as u8 +} + +pub unsafe fn get_nlmsghdr( + buf: &[u8], + expected_seq: u32, + expected_port_id: u32, +) -> Result<&nlmsghdr, ParseError> { + let size_of_hdr = size_of::<nlmsghdr>(); + + if buf.len() < size_of_hdr { + return Err(ParseError::BufTooSmall); + } + + let nlmsghdr_ptr = buf[0..size_of_hdr].as_ptr() as *const nlmsghdr; + let nlmsghdr = *nlmsghdr_ptr; + + if nlmsghdr.nlmsg_len as usize > buf.len() || (nlmsghdr.nlmsg_len as usize) < size_of_hdr { + println!("a: {}, {}", buf.len(), nlmsghdr.nlmsg_len); + return Err(ParseError::NlMsgTooSmall); + } + + if nlmsghdr.nlmsg_pid != 0 && expected_port_id != 0 && nlmsghdr.nlmsg_pid != expected_port_id { + return Err(ParseError::InvalidPortId(nlmsghdr.nlmsg_pid)); + } + + if nlmsghdr.nlmsg_seq != 0 && expected_seq != 0 && nlmsghdr.nlmsg_seq != expected_seq { + return Err(ParseError::InvalidSeq(nlmsghdr.nlmsg_seq)); + } + + if nlmsghdr.nlmsg_flags & NLM_F_DUMP_INTR as u16 != 0 { + return Err(ParseError::ConcurrentGenerationUpdate); + } + + Ok(&*nlmsghdr_ptr as &nlmsghdr) +} + +pub enum NlMsg<'a> { + Done, + Noop, + Error(nlmsgerr), + NfGenMsg(&'a Nfgenmsg, &'a [u8]), +} + +pub unsafe fn parse_nlmsg<'a>( + buf: &'a [u8], + expected_seq: u32, + expected_port_id: u32, +) -> Result<(&'a nlmsghdr, NlMsg<'a>), ParseError> { + // in theory the message is composed of the following parts: + // - nlmsghdr (contains the message size and type) + // - struct nlmsgerr OR nfgenmsg (nftables header that describes the message family) + // - the raw value that we want to validate (if the previous part is nfgenmsg) + let nlmsghdr = get_nlmsghdr(buf, expected_seq, expected_port_id)?; + + let size_of_hdr = size_of::<nlmsghdr>(); + + if nlmsghdr.nlmsg_type < NLMSG_MIN_TYPE as u16 { + match nlmsghdr.nlmsg_type as libc::c_int { + NLMSG_NOOP => return Ok((nlmsghdr, NlMsg::Noop)), + NLMSG_ERROR => { + if nlmsghdr.nlmsg_len as usize > buf.len() + || (nlmsghdr.nlmsg_len as usize) < size_of_hdr + size_of::<nlmsgerr>() + { + println!("b: {}, {}", buf.len(), nlmsghdr.nlmsg_len); + return Err(ParseError::NlMsgTooSmall); + } + let mut err = *(buf[size_of_hdr..size_of_hdr + size_of::<nlmsgerr>()].as_ptr() + as *const nlmsgerr); + // some APIs return negative values, while other return positive values + err.error = err.error.abs(); + return Ok((nlmsghdr, NlMsg::Error(err))); + } + NLMSG_DONE => return Ok((nlmsghdr, NlMsg::Done)), + x => return Err(ParseError::UnsupportedType(x as u16)), + } + } + + let subsys = get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type); + if subsys != NFNL_SUBSYS_NFTABLES as u8 { + return Err(ParseError::InvalidSubsystem(subsys)); + } + + let size_of_nfgenmsg = size_of::<Nfgenmsg>(); + if nlmsghdr.nlmsg_len as usize > buf.len() + || (nlmsghdr.nlmsg_len as usize) < size_of_hdr + size_of_nfgenmsg + { + println!("c: {}, {}", buf.len(), nlmsghdr.nlmsg_len); + return Err(ParseError::NlMsgTooSmall); + } + + let nfgenmsg_ptr = buf[size_of_hdr..size_of_hdr + size_of_nfgenmsg].as_ptr() as *const Nfgenmsg; + let nfgenmsg = *nfgenmsg_ptr; + let subsys = get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type); + if subsys != NFNL_SUBSYS_NFTABLES as u8 { + return Err(ParseError::InvalidSubsystem(subsys)); + } + if nfgenmsg.version != NFNETLINK_V0 as u8 { + return Err(ParseError::InvalidVersion(nfgenmsg.version)); + } + + let raw_value = &buf[size_of_hdr + size_of_nfgenmsg..nlmsghdr.nlmsg_len as usize]; + + Ok(( + nlmsghdr, + NlMsg::NfGenMsg(&*nfgenmsg_ptr as &Nfgenmsg, raw_value), + )) +} + /// Returns a buffer containing a netlink message which requests a list of all the netfilter /// matching objects (e.g. tables, chains, rules, ...). /// Supply the type of objects to retrieve (e.g. libc::NFT_MSG_GETTABLE), and optionally a callback @@ -29,6 +187,15 @@ pub fn get_list_of_objects<Error>( #[cfg(feature = "query")] mod inner { + use std::os::unix::prelude::RawFd; + + use nix::{ + errno::Errno, + sys::socket::{ + self, AddressFamily, MsgFlags, NetlinkAddr, SockAddr, SockFlag, SockProtocol, SockType, + }, + }; + use crate::FinalizedBatch; use super::*; @@ -36,92 +203,159 @@ mod inner { #[derive(thiserror::Error, Debug)] pub enum Error { #[error("Unable to open netlink socket to netfilter")] - NetlinkOpenError(#[source] std::io::Error), + NetlinkOpenError(#[source] nix::Error), #[error("Unable to send netlink command to netfilter")] - NetlinkSendError(#[source] std::io::Error), + NetlinkSendError(#[source] nix::Error), #[error("Error while reading from netlink socket")] - NetlinkRecvError(#[source] std::io::Error), + NetlinkRecvError(#[source] nix::Error), #[error("Error while processing an incoming netlink message")] - ProcessNetlinkError(#[source] std::io::Error), + ProcessNetlinkError(#[from] ParseError), + + #[error("Error received from the kernel")] + NetlinkError(nlmsgerr), #[error("Custom error when customizing the query")] InitError(#[from] Box<dyn std::error::Error + Send + 'static>), #[error("Couldn't allocate a netlink object, out of memory ?")] NetlinkAllocationFailed, + + #[error("This socket is not a netlink socket")] + NotNetlinkSocket, + + #[error("Couldn't retrieve information on a socket")] + RetrievingSocketInfoFailed, + + #[error("Only a part of the message was sent")] + TruncatedSend, + + #[error("Couldn't close the socket")] + CloseFailed(#[source] Errno), + } + + fn recv_and_process<'a, T>( + sock: RawFd, + cb: &dyn Fn(&nlmsghdr, &Nfgenmsg, &[u8], &mut T) -> Result<(), Error>, + working_data: &'a mut T, + seq: u32, + portid: u32, + ) -> Result<(), Error> { + let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize]; + + loop { + let nb_recv = socket::recv(sock, &mut msg_buffer, MsgFlags::empty()) + .map_err(Error::NetlinkRecvError)?; + if nb_recv <= 0 { + return Ok(()); + } + let mut buf = &msg_buffer.as_slice()[0..nb_recv]; + loop { + let (nlmsghdr, msg) = unsafe { parse_nlmsg(&buf, seq, portid) }?; + match msg { + NlMsg::Done => { + return Ok(()); + } + NlMsg::Error(e) => { + if e.error != 0 { + return Err(Error::NetlinkError(e)); + } + } + NlMsg::Noop => {} + NlMsg::NfGenMsg(genmsg, data) => cb(&nlmsghdr, &genmsg, &data, working_data)?, + } + + // netlink messages are 4bytes aligned + let aligned_length = ((nlmsghdr.nlmsg_len + 3) & !3u32) as usize; + + // retrieve the next message + buf = &buf[aligned_length..]; + + // exit the loop when we consumed all the buffer + if buf.len() == 0 { + break; + } + } + } + } + + fn socket_close_wrapper( + sock: RawFd, + cb: impl FnOnce(RawFd) -> Result<(), Error>, + ) -> Result<(), Error> { + let ret = cb(sock); + + // we don't need to shutdown the socket (in fact, Linux doesn't support that operation; + // and return EOPNOTSUPP if we try) + nix::unistd::close(sock).map_err(Error::CloseFailed)?; + + ret } /// Lists objects of a certain type (e.g. libc::NFT_MSG_GETTABLE) with the help of a helper /// function called by mnl::cb_run2. /// The callback expects a tuple of additional data (supplied as an argument to this function) /// and of the output vector, to which it should append the parsed object it received. - pub fn list_objects_with_data<'a, A, T>( + pub fn list_objects_with_data<'a, T>( data_type: u16, - cb: fn(&libc::nlmsghdr, &mut (&'a A, &mut Vec<T>)) -> libc::c_int, - additional_data: &'a A, + cb: &dyn Fn(&libc::nlmsghdr, &Nfgenmsg, &[u8], &mut T) -> Result<(), Error>, + working_data: &'a mut T, req_hdr_customize: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>, - ) -> Result<Vec<T>, Error> - where - T: 'a, - { + ) -> Result<(), Error> { debug!("listing objects of kind {}", data_type); - let socket = mnl::Socket::new(mnl::Bus::Netfilter).map_err(Error::NetlinkOpenError)?; + let sock = socket::socket( + AddressFamily::Netlink, + SockType::Raw, + SockFlag::empty(), + SockProtocol::NetlinkNetFilter, + ) + .map_err(Error::NetlinkOpenError)?; let seq = 0; let portid = 0; let chains_buf = get_list_of_objects(seq, data_type, req_hdr_customize)?; - socket.send(&chains_buf).map_err(Error::NetlinkSendError)?; + socket::send(sock, &chains_buf, MsgFlags::empty()).map_err(Error::NetlinkSendError)?; - let mut res = Vec::new(); - - let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize]; - while socket - .recv(&mut msg_buffer) - .map_err(Error::NetlinkRecvError)? - > 0 - { - if let mnl::CbResult::Stop = mnl::cb_run2( - &msg_buffer, - seq, - portid, - cb, - &mut (additional_data, &mut res), - ) - .map_err(Error::ProcessNetlinkError)? - { - break; - } - } - - Ok(res) + Ok(socket_close_wrapper(sock, move |sock| { + recv_and_process(sock, cb, working_data, seq, portid) + })?) } pub fn send_batch(batch: &mut FinalizedBatch) -> Result<(), Error> { - let socket = mnl::Socket::new(mnl::Bus::Netfilter).map_err(Error::NetlinkOpenError)?; + let sock = socket::socket( + AddressFamily::Netlink, + SockType::Raw, + SockFlag::empty(), + SockProtocol::NetlinkNetFilter, + ) + .map_err(Error::NetlinkOpenError)?; let seq = 0; - let portid = socket.portid(); + let portid = 0; - socket.send_all(batch).map_err(Error::NetlinkSendError)?; - debug!("sent"); + // while this bind() is not strictly necessary, strace have trouble decoding the messages + // if we don't + let addr = SockAddr::Netlink(NetlinkAddr::new(portid, 0)); + socket::bind(sock, &addr).expect("bind"); + //match socket::getsockname(sock).map_err(|_| Error::RetrievingSocketInfoFailed)? { + // SockAddr::Netlink(addr) => addr.0.nl_pid, + // _ => return Err(Error::NotNetlinkSocket), + //}; - let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize]; - while socket - .recv(&mut msg_buffer) - .map_err(Error::NetlinkRecvError)? - > 0 - { - if let mnl::CbResult::Stop = - mnl::cb_run(&msg_buffer, seq, portid).map_err(Error::ProcessNetlinkError)? + for data in batch { + if socket::send(sock, data, MsgFlags::empty()).map_err(Error::NetlinkSendError)? + < data.len() { - break; + return Err(Error::TruncatedSend); } } - Ok(()) + + Ok(socket_close_wrapper(sock, move |sock| { + recv_and_process(sock, &|_, _, _, _| Ok(()), &mut (), seq, portid) + })?) } } diff --git a/src/rule.rs b/src/rule.rs index 2ee5308..66beef8 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -1,6 +1,8 @@ use crate::expr::ExpressionWrapper; -use crate::{chain::Chain, expr::Expression, MsgType}; +#[cfg(feature = "query")] +use crate::query::{Nfgenmsg, ParseError}; use crate::sys::{self, libc}; +use crate::{chain::Chain, expr::Expression, MsgType}; use std::ffi::{c_void, CStr, CString}; use std::fmt::Debug; use std::os::raw::c_char; @@ -284,31 +286,42 @@ impl Drop for RuleExprsIter { #[cfg(feature = "query")] pub fn get_rules_cb( header: &libc::nlmsghdr, + _genmsg: &Nfgenmsg, + _data: &[u8], (chain, rules): &mut (&Rc<Chain>, &mut Vec<Rule>), -) -> libc::c_int { +) -> Result<(), crate::query::Error> { unsafe { let rule = sys::nftnl_rule_alloc(); if rule == std::ptr::null_mut() { - return mnl::mnl_sys::MNL_CB_ERROR; + return Err(ParseError::Custom(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "Rule allocation failed", + ))) + .into()); } let err = sys::nftnl_rule_nlmsg_parse(header, rule); if err < 0 { - error!("Failed to parse nelink rule message - {}", err); sys::nftnl_rule_free(rule); - return err; + return Err(ParseError::Custom(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "The netlink table couldn't be parsed !?", + ))) + .into()); } rules.push(Rule::from_raw(rule, chain.clone())); } - mnl::mnl_sys::MNL_CB_OK + + Ok(()) } #[cfg(feature = "query")] pub fn list_rules_for_chain(chain: &Rc<Chain>) -> Result<Vec<Rule>, crate::query::Error> { + let mut result = Vec::new(); crate::query::list_objects_with_data( libc::NFT_MSG_GETRULE as u16, - get_rules_cb, - &chain, + &get_rules_cb, + &mut (chain, &mut result), // only retrieve rules from the currently targetted chain Some(&|hdr| unsafe { let rule = sys::nftnl_rule_alloc(); @@ -337,5 +350,6 @@ pub fn list_rules_for_chain(chain: &Rc<Chain>) -> Result<Vec<Rule>, crate::query sys::nftnl_rule_free(rule); Ok(()) }), - ) + )?; + Ok(result) } diff --git a/src/table.rs b/src/table.rs index 593fffb..332cc99 100644 --- a/src/table.rs +++ b/src/table.rs @@ -1,5 +1,7 @@ -use crate::{MsgType, ProtoFamily}; +#[cfg(feature = "query")] +use crate::query::{Nfgenmsg, ParseError}; use crate::sys::{self, libc}; +use crate::{MsgType, ProtoFamily}; #[cfg(feature = "query")] use std::convert::TryFrom; use std::{ @@ -39,7 +41,7 @@ impl Table { unsafe { let ptr = sys::nftnl_table_get_str(self.table, sys::NFTNL_TABLE_NAME as u16); if ptr.is_null() { - panic!("Impossible situation: retrieving the name of a chain failed") + panic!("Impossible situation: retrieving the name of a table failed") } else { CStr::from_ptr(ptr) } @@ -137,29 +139,41 @@ impl Drop for Table { /// A callback to parse the response for messages created with `get_tables_nlmsg`. pub fn get_tables_cb( header: &libc::nlmsghdr, - (_, tables): &mut (&(), &mut Vec<Table>), -) -> libc::c_int { + _genmsg: &Nfgenmsg, + _data: &[u8], + tables: &mut Vec<Table>, +) -> Result<(), crate::query::Error> { unsafe { let table = sys::nftnl_table_alloc(); if table == std::ptr::null_mut() { - return mnl::mnl_sys::MNL_CB_ERROR; + return Err(ParseError::Custom(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "Table allocation failed", + ))) + .into()); } let err = sys::nftnl_table_nlmsg_parse(header, table); if err < 0 { - error!("Failed to parse nelink table message - {}", err); sys::nftnl_table_free(table); - return err; + return Err(ParseError::Custom(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "The netlink table couldn't be parsed !?", + ))) + .into()); } let family = sys::nftnl_table_get_u32(table, sys::NFTNL_TABLE_FAMILY as u16); match crate::ProtoFamily::try_from(family as i32) { Ok(family) => { tables.push(Table::from_raw(table, family)); - mnl::mnl_sys::MNL_CB_OK + Ok(()) } Err(crate::InvalidProtocolFamily) => { - error!("The netlink table didn't have a valid protocol family !?"); sys::nftnl_table_free(table); - mnl::mnl_sys::MNL_CB_ERROR + Err(ParseError::Custom(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "The netlink table didn't have a valid protocol family !?", + ))) + .into()) } } } @@ -167,5 +181,12 @@ pub fn get_tables_cb( #[cfg(feature = "query")] pub fn list_tables() -> Result<Vec<Table>, crate::query::Error> { - crate::query::list_objects_with_data(libc::NFT_MSG_GETTABLE as u16, get_tables_cb, &(), None) + let mut result = Vec::new(); + crate::query::list_objects_with_data( + libc::NFT_MSG_GETTABLE as u16, + &get_tables_cb, + &mut result, + None, + )?; + Ok(result) } |