diff options
-rw-r--r-- | Cargo.toml | 2 | ||||
-rw-r--r-- | examples/add-rules.rs | 66 | ||||
-rw-r--r-- | examples/filter-ethernet.rs | 55 | ||||
-rw-r--r-- | src/chain.rs | 43 | ||||
-rw-r--r-- | src/lib.rs | 3 | ||||
-rw-r--r-- | src/query.rs | 332 | ||||
-rw-r--r-- | src/rule.rs | 32 | ||||
-rw-r--r-- | src/table.rs | 43 | ||||
-rw-r--r-- | tests/chain.rs | 2 | ||||
-rw-r--r-- | tests/expr.rs | 1 | ||||
-rw-r--r-- | tests/lib.rs | 49 | ||||
-rw-r--r-- | tests/rule.rs | 2 | ||||
-rw-r--r-- | tests/set.rs | 2 | ||||
-rw-r--r-- | tests/table.rs | 2 |
14 files changed, 408 insertions, 226 deletions
@@ -20,7 +20,7 @@ bitflags = "1.0" thiserror = "1.0" log = "0.4" libc = "0.2.43" -mnl = "0.2" +nix = "0.23" ipnetwork = "0.16" serde = { version = "1.0", features = ["derive"] } diff --git a/examples/add-rules.rs b/examples/add-rules.rs index 3aae7ee..6354c8d 100644 --- a/examples/add-rules.rs +++ b/examples/add-rules.rs @@ -37,13 +37,8 @@ //! ``` use ipnetwork::{IpNetwork, Ipv4Network}; -use rustables::{nft_expr, sys::libc, Batch, Chain, FinalizedBatch, ProtoFamily, Rule, Table}; -use std::{ - ffi::{self, CString}, - io, - net::Ipv4Addr, - rc::Rc -}; +use rustables::{nft_expr, query::send_batch, sys::libc, Batch, Chain, ProtoFamily, Rule, Table}; +use std::{ffi::CString, io, net::Ipv4Addr, rc::Rc}; const TABLE_NAME: &str = "example-table"; const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets"; @@ -56,7 +51,10 @@ fn main() -> Result<(), Error> { let mut batch = Batch::new(); // Create a netfilter table operating on both IPv4 and IPv6 (ProtoFamily::Inet) - let table = Rc::new(Table::new(&CString::new(TABLE_NAME).unwrap(), ProtoFamily::Inet)); + let table = Rc::new(Table::new( + &CString::new(TABLE_NAME).unwrap(), + ProtoFamily::Inet, + )); // Add the table to the batch with the `MsgType::Add` type, thus instructing netfilter to add // this table under its `ProtoFamily::Inet` ruleset. batch.add(&Rc::clone(&table), rustables::MsgType::Add); @@ -178,10 +176,10 @@ fn main() -> Result<(), Error> { match batch.finalize() { Some(mut finalized_batch) => { // Send the entire batch and process any returned messages. - send_and_process(&mut finalized_batch)?; + send_batch(&mut finalized_batch)?; Ok(()) - }, - None => todo!() + } + None => todo!(), } } @@ -196,53 +194,11 @@ fn iface_index(name: &str) -> Result<libc::c_uint, Error> { } } -fn send_and_process(batch: &mut FinalizedBatch) -> Result<(), Error> { - // Create a netlink socket to netfilter. - let socket = mnl::Socket::new(mnl::Bus::Netfilter)?; - // Send all the bytes in the batch. - socket.send_all(&mut *batch)?; - - // Try to parse the messages coming back from netfilter. This part is still very unclear. - let portid = socket.portid(); - let mut buffer = vec![0; rustables::nft_nlmsg_maxsize() as usize]; - let very_unclear_what_this_is_for = 2; - while let Some(message) = socket_recv(&socket, &mut buffer[..])? { - match mnl::cb_run(message, very_unclear_what_this_is_for, portid)? { - mnl::CbResult::Stop => { - break; - } - mnl::CbResult::Ok => (), - } - } - Ok(()) -} - -fn socket_recv<'a>(socket: &mnl::Socket, buf: &'a mut [u8]) -> Result<Option<&'a [u8]>, Error> { - let ret = socket.recv(buf)?; - if ret > 0 { - Ok(Some(&buf[..ret])) - } else { - Ok(None) - } -} - #[derive(Debug)] struct Error(String); -impl From<io::Error> for Error { - fn from(error: io::Error) -> Self { - Error(error.to_string()) - } -} - -impl From<ffi::NulError> for Error { - fn from(error: ffi::NulError) -> Self { - Error(error.to_string()) - } -} - -impl From<ipnetwork::IpNetworkError> for Error { - fn from(error: ipnetwork::IpNetworkError) -> Self { +impl<T: std::error::Error> From<T> for Error { + fn from(error: T) -> Self { Error(error.to_string()) } } diff --git a/examples/filter-ethernet.rs b/examples/filter-ethernet.rs index b16c49e..41454c9 100644 --- a/examples/filter-ethernet.rs +++ b/examples/filter-ethernet.rs @@ -22,19 +22,22 @@ //! # nft delete table inet example-filter-ethernet //! ``` -use rustables::{nft_expr, sys::libc, Batch, Chain, FinalizedBatch, ProtoFamily, Rule, Table}; -use std::{ffi::CString, io, rc::Rc}; +use rustables::{nft_expr, query::send_batch, sys::libc, Batch, Chain, ProtoFamily, Rule, Table}; +use std::{ffi::CString, rc::Rc}; const TABLE_NAME: &str = "example-filter-ethernet"; const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets"; const BLOCK_THIS_MAC: &[u8] = &[0, 0, 0, 0, 0, 0]; -fn main() -> Result<(), Error> { +fn main() { // For verbose explanations of what all these lines up until the rule creation does, see the // `add-rules` example. let mut batch = Batch::new(); - let table = Rc::new(Table::new(&CString::new(TABLE_NAME).unwrap(), ProtoFamily::Inet)); + let table = Rc::new(Table::new( + &CString::new(TABLE_NAME).unwrap(), + ProtoFamily::Inet, + )); batch.add(&Rc::clone(&table), rustables::MsgType::Add); let mut out_chain = Chain::new(&CString::new(OUT_CHAIN_NAME).unwrap(), Rc::clone(&table)); @@ -86,48 +89,8 @@ fn main() -> Result<(), Error> { match batch.finalize() { Some(mut finalized_batch) => { - send_and_process(&mut finalized_batch)?; - Ok(()) - }, - None => todo!() - } -} - -fn send_and_process(batch: &mut FinalizedBatch) -> Result<(), Error> { - // Create a netlink socket to netfilter. - let socket = mnl::Socket::new(mnl::Bus::Netfilter)?; - // Send all the bytes in the batch. - socket.send_all(&mut *batch)?; - - // Try to parse the messages coming back from netfilter. This part is still very unclear. - let portid = socket.portid(); - let mut buffer = vec![0; rustables::nft_nlmsg_maxsize() as usize]; - let very_unclear_what_this_is_for = 2; - while let Some(message) = socket_recv(&socket, &mut buffer[..])? { - match mnl::cb_run(message, very_unclear_what_this_is_for, portid)? { - mnl::CbResult::Stop => { - break; - } - mnl::CbResult::Ok => (), + send_batch(&mut finalized_batch).expect("Couldn't process the batch"); } - } - Ok(()) -} - -fn socket_recv<'a>(socket: &mnl::Socket, buf: &'a mut [u8]) -> Result<Option<&'a [u8]>, Error> { - let ret = socket.recv(buf)?; - if ret > 0 { - Ok(Some(&buf[..ret])) - } else { - Ok(None) - } -} - -#[derive(Debug)] -struct Error(String); - -impl From<io::Error> for Error { - fn from(error: io::Error) -> Self { - Error(error.to_string()) + None => todo!(), } } diff --git a/src/chain.rs b/src/chain.rs index a942a37..18e3c64 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -1,5 +1,7 @@ -use crate::{MsgType, Table}; +#[cfg(feature = "query")] +use crate::query::{Nfgenmsg, ParseError}; use crate::sys::{self as sys, libc}; +use crate::{MsgType, Table}; #[cfg(feature = "query")] use std::convert::TryFrom; use std::{ @@ -243,18 +245,27 @@ impl Drop for Chain { #[cfg(feature = "query")] pub fn get_chains_cb<'a>( header: &libc::nlmsghdr, + _genmsg: &Nfgenmsg, + _data: &[u8], (table, chains): &mut (&Rc<Table>, &mut Vec<Chain>), -) -> libc::c_int { +) -> Result<(), crate::query::Error> { unsafe { let chain = sys::nftnl_chain_alloc(); if chain == std::ptr::null_mut() { - return mnl::mnl_sys::MNL_CB_ERROR; + return Err(ParseError::Custom(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "Chain allocation failed", + ))) + .into()); } let err = sys::nftnl_chain_nlmsg_parse(header, chain); if err < 0 { - error!("Failed to parse nelink chain message - {}", err); sys::nftnl_chain_free(chain); - return err; + return Err(ParseError::Custom(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "The netlink chain couldn't be parsed !?", + ))) + .into()); } let table_name = CStr::from_ptr(sys::nftnl_chain_get_str( @@ -267,26 +278,38 @@ pub fn get_chains_cb<'a>( Err(crate::InvalidProtocolFamily) => { error!("The netlink table didn't have a valid protocol family !?"); sys::nftnl_chain_free(chain); - return mnl::mnl_sys::MNL_CB_ERROR; + return Err(ParseError::Custom(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "The netlink table didn't have a valid protocol family !?", + ))) + .into()); } }; if table_name != table.get_name() { sys::nftnl_chain_free(chain); - return mnl::mnl_sys::MNL_CB_OK; + return Ok(()); } if family != crate::ProtoFamily::Unspec && family != table.get_family() { sys::nftnl_chain_free(chain); - return mnl::mnl_sys::MNL_CB_OK; + return Ok(()); } chains.push(Chain::from_raw(chain, table.clone())); } - mnl::mnl_sys::MNL_CB_OK + + Ok(()) } #[cfg(feature = "query")] pub fn list_chains_for_table(table: Rc<Table>) -> Result<Vec<Chain>, crate::query::Error> { - crate::query::list_objects_with_data(libc::NFT_MSG_GETCHAIN as u16, get_chains_cb, &table, None) + let mut result = Vec::new(); + crate::query::list_objects_with_data( + libc::NFT_MSG_GETCHAIN as u16, + &get_chains_cb, + &mut (&table, &mut result), + None, + )?; + Ok(result) } @@ -119,7 +119,7 @@ pub use rule::Rule; pub use rule::{get_rules_cb, list_rules_for_chain}; mod rule_methods; -pub use rule_methods::{iface_index, Protocol, RuleMethods, Error as MatchError}; +pub use rule_methods::{iface_index, Error as MatchError, Protocol, RuleMethods}; pub mod set; pub use set::Set; @@ -155,6 +155,7 @@ pub enum ProtoFamily { Ipv6 = libc::NFPROTO_IPV6 as u16, DecNet = libc::NFPROTO_DECNET as u16, } + #[derive(Error, Debug)] #[error("Couldn't find a matching protocol")] pub struct InvalidProtocolFamily; diff --git a/src/query.rs b/src/query.rs index d7574e8..1c81cdd 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,6 +1,164 @@ +use std::mem::size_of; + use crate::{nft_nlmsg_maxsize, sys, ProtoFamily}; +use libc::{ + nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES, NLMSG_DONE, NLMSG_ERROR, + NLMSG_MIN_TYPE, NLMSG_NOOP, NLM_F_DUMP_INTR, +}; use sys::libc; +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct Nfgenmsg { + pub family: u8, /* AF_xxx */ + pub version: u8, /* nfnetlink version */ + pub res_id: u16, /* resource id */ +} + +#[derive(thiserror::Error, Debug)] +pub enum ParseError { + #[error("The buffer is too small to hold a valid message")] + BufTooSmall, + + #[error("The message is too small")] + NlMsgTooSmall, + + #[error("Invalid subsystem, expected NFTABLES")] + InvalidSubsystem(u8), + + #[error("Invalid version, expected NFNETLINK_V0")] + InvalidVersion(u8), + + #[error("Invalid port ID")] + InvalidPortId(u32), + + #[error("Invalid sequence number")] + InvalidSeq(u32), + + #[error("The generation number was bumped in the kernel while the operation was running, interrupting it")] + ConcurrentGenerationUpdate, + + #[error("Unsupported message type")] + UnsupportedType(u16), + + #[error("A custom error occured")] + Custom(Box<dyn std::error::Error + 'static>), +} + +pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 { + ((x & 0xff00) >> 8) as u8 +} + +pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 { + (x & 0x00ff) as u8 +} + +pub unsafe fn get_nlmsghdr( + buf: &[u8], + expected_seq: u32, + expected_port_id: u32, +) -> Result<&nlmsghdr, ParseError> { + let size_of_hdr = size_of::<nlmsghdr>(); + + if buf.len() < size_of_hdr { + return Err(ParseError::BufTooSmall); + } + + let nlmsghdr_ptr = buf[0..size_of_hdr].as_ptr() as *const nlmsghdr; + let nlmsghdr = *nlmsghdr_ptr; + + if nlmsghdr.nlmsg_len as usize > buf.len() || (nlmsghdr.nlmsg_len as usize) < size_of_hdr { + println!("a: {}, {}", buf.len(), nlmsghdr.nlmsg_len); + return Err(ParseError::NlMsgTooSmall); + } + + if nlmsghdr.nlmsg_pid != 0 && expected_port_id != 0 && nlmsghdr.nlmsg_pid != expected_port_id { + return Err(ParseError::InvalidPortId(nlmsghdr.nlmsg_pid)); + } + + if nlmsghdr.nlmsg_seq != 0 && expected_seq != 0 && nlmsghdr.nlmsg_seq != expected_seq { + return Err(ParseError::InvalidSeq(nlmsghdr.nlmsg_seq)); + } + + if nlmsghdr.nlmsg_flags & NLM_F_DUMP_INTR as u16 != 0 { + return Err(ParseError::ConcurrentGenerationUpdate); + } + + Ok(&*nlmsghdr_ptr as &nlmsghdr) +} + +pub enum NlMsg<'a> { + Done, + Noop, + Error(nlmsgerr), + NfGenMsg(&'a Nfgenmsg, &'a [u8]), +} + +pub unsafe fn parse_nlmsg<'a>( + buf: &'a [u8], + expected_seq: u32, + expected_port_id: u32, +) -> Result<(&'a nlmsghdr, NlMsg<'a>), ParseError> { + // in theory the message is composed of the following parts: + // - nlmsghdr (contains the message size and type) + // - struct nlmsgerr OR nfgenmsg (nftables header that describes the message family) + // - the raw value that we want to validate (if the previous part is nfgenmsg) + let nlmsghdr = get_nlmsghdr(buf, expected_seq, expected_port_id)?; + + let size_of_hdr = size_of::<nlmsghdr>(); + + if nlmsghdr.nlmsg_type < NLMSG_MIN_TYPE as u16 { + match nlmsghdr.nlmsg_type as libc::c_int { + NLMSG_NOOP => return Ok((nlmsghdr, NlMsg::Noop)), + NLMSG_ERROR => { + if nlmsghdr.nlmsg_len as usize > buf.len() + || (nlmsghdr.nlmsg_len as usize) < size_of_hdr + size_of::<nlmsgerr>() + { + println!("b: {}, {}", buf.len(), nlmsghdr.nlmsg_len); + return Err(ParseError::NlMsgTooSmall); + } + let mut err = *(buf[size_of_hdr..size_of_hdr + size_of::<nlmsgerr>()].as_ptr() + as *const nlmsgerr); + // some APIs return negative values, while other return positive values + err.error = err.error.abs(); + return Ok((nlmsghdr, NlMsg::Error(err))); + } + NLMSG_DONE => return Ok((nlmsghdr, NlMsg::Done)), + x => return Err(ParseError::UnsupportedType(x as u16)), + } + } + + let subsys = get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type); + if subsys != NFNL_SUBSYS_NFTABLES as u8 { + return Err(ParseError::InvalidSubsystem(subsys)); + } + + let size_of_nfgenmsg = size_of::<Nfgenmsg>(); + if nlmsghdr.nlmsg_len as usize > buf.len() + || (nlmsghdr.nlmsg_len as usize) < size_of_hdr + size_of_nfgenmsg + { + println!("c: {}, {}", buf.len(), nlmsghdr.nlmsg_len); + return Err(ParseError::NlMsgTooSmall); + } + + let nfgenmsg_ptr = buf[size_of_hdr..size_of_hdr + size_of_nfgenmsg].as_ptr() as *const Nfgenmsg; + let nfgenmsg = *nfgenmsg_ptr; + let subsys = get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type); + if subsys != NFNL_SUBSYS_NFTABLES as u8 { + return Err(ParseError::InvalidSubsystem(subsys)); + } + if nfgenmsg.version != NFNETLINK_V0 as u8 { + return Err(ParseError::InvalidVersion(nfgenmsg.version)); + } + + let raw_value = &buf[size_of_hdr + size_of_nfgenmsg..nlmsghdr.nlmsg_len as usize]; + + Ok(( + nlmsghdr, + NlMsg::NfGenMsg(&*nfgenmsg_ptr as &Nfgenmsg, raw_value), + )) +} + /// Returns a buffer containing a netlink message which requests a list of all the netfilter /// matching objects (e.g. tables, chains, rules, ...). /// Supply the type of objects to retrieve (e.g. libc::NFT_MSG_GETTABLE), and optionally a callback @@ -29,6 +187,15 @@ pub fn get_list_of_objects<Error>( #[cfg(feature = "query")] mod inner { + use std::os::unix::prelude::RawFd; + + use nix::{ + errno::Errno, + sys::socket::{ + self, AddressFamily, MsgFlags, NetlinkAddr, SockAddr, SockFlag, SockProtocol, SockType, + }, + }; + use crate::FinalizedBatch; use super::*; @@ -36,92 +203,159 @@ mod inner { #[derive(thiserror::Error, Debug)] pub enum Error { #[error("Unable to open netlink socket to netfilter")] - NetlinkOpenError(#[source] std::io::Error), + NetlinkOpenError(#[source] nix::Error), #[error("Unable to send netlink command to netfilter")] - NetlinkSendError(#[source] std::io::Error), + NetlinkSendError(#[source] nix::Error), #[error("Error while reading from netlink socket")] - NetlinkRecvError(#[source] std::io::Error), + NetlinkRecvError(#[source] nix::Error), #[error("Error while processing an incoming netlink message")] - ProcessNetlinkError(#[source] std::io::Error), + ProcessNetlinkError(#[from] ParseError), + + #[error("Error received from the kernel")] + NetlinkError(nlmsgerr), #[error("Custom error when customizing the query")] InitError(#[from] Box<dyn std::error::Error + Send + 'static>), #[error("Couldn't allocate a netlink object, out of memory ?")] NetlinkAllocationFailed, + + #[error("This socket is not a netlink socket")] + NotNetlinkSocket, + + #[error("Couldn't retrieve information on a socket")] + RetrievingSocketInfoFailed, + + #[error("Only a part of the message was sent")] + TruncatedSend, + + #[error("Couldn't close the socket")] + CloseFailed(#[source] Errno), + } + + fn recv_and_process<'a, T>( + sock: RawFd, + cb: &dyn Fn(&nlmsghdr, &Nfgenmsg, &[u8], &mut T) -> Result<(), Error>, + working_data: &'a mut T, + seq: u32, + portid: u32, + ) -> Result<(), Error> { + let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize]; + + loop { + let nb_recv = socket::recv(sock, &mut msg_buffer, MsgFlags::empty()) + .map_err(Error::NetlinkRecvError)?; + if nb_recv <= 0 { + return Ok(()); + } + let mut buf = &msg_buffer.as_slice()[0..nb_recv]; + loop { + let (nlmsghdr, msg) = unsafe { parse_nlmsg(&buf, seq, portid) }?; + match msg { + NlMsg::Done => { + return Ok(()); + } + NlMsg::Error(e) => { + if e.error != 0 { + return Err(Error::NetlinkError(e)); + } + } + NlMsg::Noop => {} + NlMsg::NfGenMsg(genmsg, data) => cb(&nlmsghdr, &genmsg, &data, working_data)?, + } + + // netlink messages are 4bytes aligned + let aligned_length = ((nlmsghdr.nlmsg_len + 3) & !3u32) as usize; + + // retrieve the next message + buf = &buf[aligned_length..]; + + // exit the loop when we consumed all the buffer + if buf.len() == 0 { + break; + } + } + } + } + + fn socket_close_wrapper( + sock: RawFd, + cb: impl FnOnce(RawFd) -> Result<(), Error>, + ) -> Result<(), Error> { + let ret = cb(sock); + + // we don't need to shutdown the socket (in fact, Linux doesn't support that operation; + // and return EOPNOTSUPP if we try) + nix::unistd::close(sock).map_err(Error::CloseFailed)?; + + ret } /// Lists objects of a certain type (e.g. libc::NFT_MSG_GETTABLE) with the help of a helper /// function called by mnl::cb_run2. /// The callback expects a tuple of additional data (supplied as an argument to this function) /// and of the output vector, to which it should append the parsed object it received. - pub fn list_objects_with_data<'a, A, T>( + pub fn list_objects_with_data<'a, T>( data_type: u16, - cb: fn(&libc::nlmsghdr, &mut (&'a A, &mut Vec<T>)) -> libc::c_int, - additional_data: &'a A, + cb: &dyn Fn(&libc::nlmsghdr, &Nfgenmsg, &[u8], &mut T) -> Result<(), Error>, + working_data: &'a mut T, req_hdr_customize: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>, - ) -> Result<Vec<T>, Error> - where - T: 'a, - { + ) -> Result<(), Error> { debug!("listing objects of kind {}", data_type); - let socket = mnl::Socket::new(mnl::Bus::Netfilter).map_err(Error::NetlinkOpenError)?; + let sock = socket::socket( + AddressFamily::Netlink, + SockType::Raw, + SockFlag::empty(), + SockProtocol::NetlinkNetFilter, + ) + .map_err(Error::NetlinkOpenError)?; let seq = 0; let portid = 0; let chains_buf = get_list_of_objects(seq, data_type, req_hdr_customize)?; - socket.send(&chains_buf).map_err(Error::NetlinkSendError)?; + socket::send(sock, &chains_buf, MsgFlags::empty()).map_err(Error::NetlinkSendError)?; - let mut res = Vec::new(); - - let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize]; - while socket - .recv(&mut msg_buffer) - .map_err(Error::NetlinkRecvError)? - > 0 - { - if let mnl::CbResult::Stop = mnl::cb_run2( - &msg_buffer, - seq, - portid, - cb, - &mut (additional_data, &mut res), - ) - .map_err(Error::ProcessNetlinkError)? - { - break; - } - } - - Ok(res) + Ok(socket_close_wrapper(sock, move |sock| { + recv_and_process(sock, cb, working_data, seq, portid) + })?) } pub fn send_batch(batch: &mut FinalizedBatch) -> Result<(), Error> { - let socket = mnl::Socket::new(mnl::Bus::Netfilter).map_err(Error::NetlinkOpenError)?; + let sock = socket::socket( + AddressFamily::Netlink, + SockType::Raw, + SockFlag::empty(), + SockProtocol::NetlinkNetFilter, + ) + .map_err(Error::NetlinkOpenError)?; let seq = 0; - let portid = socket.portid(); + let portid = 0; - socket.send_all(batch).map_err(Error::NetlinkSendError)?; - debug!("sent"); + // while this bind() is not strictly necessary, strace have trouble decoding the messages + // if we don't + let addr = SockAddr::Netlink(NetlinkAddr::new(portid, 0)); + socket::bind(sock, &addr).expect("bind"); + //match socket::getsockname(sock).map_err(|_| Error::RetrievingSocketInfoFailed)? { + // SockAddr::Netlink(addr) => addr.0.nl_pid, + // _ => return Err(Error::NotNetlinkSocket), + //}; - let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize]; - while socket - .recv(&mut msg_buffer) - .map_err(Error::NetlinkRecvError)? - > 0 - { - if let mnl::CbResult::Stop = - mnl::cb_run(&msg_buffer, seq, portid).map_err(Error::ProcessNetlinkError)? + for data in batch { + if socket::send(sock, data, MsgFlags::empty()).map_err(Error::NetlinkSendError)? + < data.len() { - break; + return Err(Error::TruncatedSend); } } - Ok(()) + + Ok(socket_close_wrapper(sock, move |sock| { + recv_and_process(sock, &|_, _, _, _| Ok(()), &mut (), seq, portid) + })?) } } diff --git a/src/rule.rs b/src/rule.rs index 2ee5308..66beef8 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -1,6 +1,8 @@ use crate::expr::ExpressionWrapper; -use crate::{chain::Chain, expr::Expression, MsgType}; +#[cfg(feature = "query")] +use crate::query::{Nfgenmsg, ParseError}; use crate::sys::{self, libc}; +use crate::{chain::Chain, expr::Expression, MsgType}; use std::ffi::{c_void, CStr, CString}; use std::fmt::Debug; use std::os::raw::c_char; @@ -284,31 +286,42 @@ impl Drop for RuleExprsIter { #[cfg(feature = "query")] pub fn get_rules_cb( header: &libc::nlmsghdr, + _genmsg: &Nfgenmsg, + _data: &[u8], (chain, rules): &mut (&Rc<Chain>, &mut Vec<Rule>), -) -> libc::c_int { +) -> Result<(), crate::query::Error> { unsafe { let rule = sys::nftnl_rule_alloc(); if rule == std::ptr::null_mut() { - return mnl::mnl_sys::MNL_CB_ERROR; + return Err(ParseError::Custom(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "Rule allocation failed", + ))) + .into()); } let err = sys::nftnl_rule_nlmsg_parse(header, rule); if err < 0 { - error!("Failed to parse nelink rule message - {}", err); sys::nftnl_rule_free(rule); - return err; + return Err(ParseError::Custom(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "The netlink table couldn't be parsed !?", + ))) + .into()); } rules.push(Rule::from_raw(rule, chain.clone())); } - mnl::mnl_sys::MNL_CB_OK + + Ok(()) } #[cfg(feature = "query")] pub fn list_rules_for_chain(chain: &Rc<Chain>) -> Result<Vec<Rule>, crate::query::Error> { + let mut result = Vec::new(); crate::query::list_objects_with_data( libc::NFT_MSG_GETRULE as u16, - get_rules_cb, - &chain, + &get_rules_cb, + &mut (chain, &mut result), // only retrieve rules from the currently targetted chain Some(&|hdr| unsafe { let rule = sys::nftnl_rule_alloc(); @@ -337,5 +350,6 @@ pub fn list_rules_for_chain(chain: &Rc<Chain>) -> Result<Vec<Rule>, crate::query sys::nftnl_rule_free(rule); Ok(()) }), - ) + )?; + Ok(result) } diff --git a/src/table.rs b/src/table.rs index 593fffb..332cc99 100644 --- a/src/table.rs +++ b/src/table.rs @@ -1,5 +1,7 @@ -use crate::{MsgType, ProtoFamily}; +#[cfg(feature = "query")] +use crate::query::{Nfgenmsg, ParseError}; use crate::sys::{self, libc}; +use crate::{MsgType, ProtoFamily}; #[cfg(feature = "query")] use std::convert::TryFrom; use std::{ @@ -39,7 +41,7 @@ impl Table { unsafe { let ptr = sys::nftnl_table_get_str(self.table, sys::NFTNL_TABLE_NAME as u16); if ptr.is_null() { - panic!("Impossible situation: retrieving the name of a chain failed") + panic!("Impossible situation: retrieving the name of a table failed") } else { CStr::from_ptr(ptr) } @@ -137,29 +139,41 @@ impl Drop for Table { /// A callback to parse the response for messages created with `get_tables_nlmsg`. pub fn get_tables_cb( header: &libc::nlmsghdr, - (_, tables): &mut (&(), &mut Vec<Table>), -) -> libc::c_int { + _genmsg: &Nfgenmsg, + _data: &[u8], + tables: &mut Vec<Table>, +) -> Result<(), crate::query::Error> { unsafe { let table = sys::nftnl_table_alloc(); if table == std::ptr::null_mut() { - return mnl::mnl_sys::MNL_CB_ERROR; + return Err(ParseError::Custom(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "Table allocation failed", + ))) + .into()); } let err = sys::nftnl_table_nlmsg_parse(header, table); if err < 0 { - error!("Failed to parse nelink table message - {}", err); sys::nftnl_table_free(table); - return err; + return Err(ParseError::Custom(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "The netlink table couldn't be parsed !?", + ))) + .into()); } let family = sys::nftnl_table_get_u32(table, sys::NFTNL_TABLE_FAMILY as u16); match crate::ProtoFamily::try_from(family as i32) { Ok(family) => { tables.push(Table::from_raw(table, family)); - mnl::mnl_sys::MNL_CB_OK + Ok(()) } Err(crate::InvalidProtocolFamily) => { - error!("The netlink table didn't have a valid protocol family !?"); sys::nftnl_table_free(table); - mnl::mnl_sys::MNL_CB_ERROR + Err(ParseError::Custom(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "The netlink table didn't have a valid protocol family !?", + ))) + .into()) } } } @@ -167,5 +181,12 @@ pub fn get_tables_cb( #[cfg(feature = "query")] pub fn list_tables() -> Result<Vec<Table>, crate::query::Error> { - crate::query::list_objects_with_data(libc::NFT_MSG_GETTABLE as u16, get_tables_cb, &(), None) + let mut result = Vec::new(); + crate::query::list_objects_with_data( + libc::NFT_MSG_GETTABLE as u16, + &get_tables_cb, + &mut result, + None, + )?; + Ok(result) } diff --git a/tests/chain.rs b/tests/chain.rs index 4b6da91..809936a 100644 --- a/tests/chain.rs +++ b/tests/chain.rs @@ -1,7 +1,7 @@ use std::ffi::CStr; mod sys; -use rustables::MsgType; +use rustables::{query::get_operation_from_nlmsghdr_type, MsgType}; use sys::*; mod lib; diff --git a/tests/expr.rs b/tests/expr.rs index 7950df3..cfc9c7d 100644 --- a/tests/expr.rs +++ b/tests/expr.rs @@ -3,6 +3,7 @@ use rustables::expr::{ LogGroup, LogPrefix, Lookup, Meta, Nat, NatType, Payload, Register, Reject, TcpHeaderField, TransportHeaderField, Verdict, }; +use rustables::query::{get_operation_from_nlmsghdr_type, Nfgenmsg}; use rustables::set::Set; use rustables::sys::libc::{nlmsghdr, NF_DROP}; use rustables::{ProtoFamily, Rule}; diff --git a/tests/lib.rs b/tests/lib.rs index 29c61b3..9b44a88 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -1,19 +1,11 @@ #![allow(dead_code)] -use libc::{nlmsghdr, AF_UNIX, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES}; +use libc::{nlmsghdr, AF_UNIX}; +use rustables::query::{parse_nlmsg, Nfgenmsg}; use rustables::set::SetKey; use rustables::{nft_nlmsg_maxsize, Chain, MsgType, NlMsg, ProtoFamily, Rule, Set, Table}; use std::ffi::{c_void, CStr}; -use std::mem::size_of; use std::rc::Rc; -pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 { - ((x & 0xff00) >> 8) as u8 -} - -pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 { - (x & 0x00ff) as u8 -} - pub const TABLE_NAME: &[u8; 10] = b"mocktable\0"; pub const CHAIN_NAME: &[u8; 10] = b"mockchain\0"; pub const SET_NAME: &[u8; 8] = b"mockset\0"; @@ -89,14 +81,6 @@ impl NetlinkExpr { } } -#[repr(C)] -#[derive(Clone, Copy)] -pub struct Nfgenmsg { - family: u8, /* AF_xxx */ - version: u8, /* nfnetlink version */ - res_id: u16, /* resource id */ -} - pub fn get_test_table() -> Table { Table::new( &CStr::from_bytes_with_nul(TABLE_NAME).unwrap(), @@ -131,35 +115,20 @@ pub fn get_test_nlmsg_with_msg_type( unsafe { obj.write(buf.as_mut_ptr() as *mut c_void, 0, msg_type); - // right now the message is composed of the following parts: - // - nlmsghdr (contains the message size and type) - // - nfgenmsg (nftables header that describes the message family) - // - the raw value that we want to validate - - let size_of_hdr = size_of::<nlmsghdr>(); - let size_of_nfgenmsg = size_of::<Nfgenmsg>(); - let nlmsghdr = *(buf[0..size_of_hdr].as_ptr() as *const nlmsghdr); - let nfgenmsg = - *(buf[size_of_hdr..size_of_hdr + size_of_nfgenmsg].as_ptr() as *const Nfgenmsg); - let raw_value = buf[size_of_hdr + size_of_nfgenmsg..nlmsghdr.nlmsg_len as usize] - .iter() - .map(|x| *x) - .collect(); + let (nlmsghdr, msg) = parse_nlmsg(&buf, 0, 0).expect("Couldn't parse the message"); + + let (nfgenmsg, raw_value) = match msg { + rustables::query::NlMsg::NfGenMsg(nfgenmsg, raw_value) => (nfgenmsg, raw_value), + _ => panic!("Invalid return value type, expected a valid message"), + }; // sanity checks on the global message (this should be very similar/factorisable for the // most part in other tests) // TODO: check the messages flags - assert_eq!( - get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type), - NFNL_SUBSYS_NFTABLES as u8 - ); - assert_eq!(nlmsghdr.nlmsg_seq, 0); - assert_eq!(nlmsghdr.nlmsg_pid, 0); assert_eq!(nfgenmsg.family, AF_UNIX as u8); - assert_eq!(nfgenmsg.version, NFNETLINK_V0 as u8); assert_eq!(nfgenmsg.res_id.to_be(), 0); - (nlmsghdr, nfgenmsg, raw_value) + (*nlmsghdr, *nfgenmsg, raw_value.to_owned()) } } diff --git a/tests/rule.rs b/tests/rule.rs index b601a61..517db47 100644 --- a/tests/rule.rs +++ b/tests/rule.rs @@ -1,7 +1,7 @@ use std::ffi::CStr; mod sys; -use rustables::MsgType; +use rustables::{query::get_operation_from_nlmsghdr_type, MsgType}; use sys::*; mod lib; diff --git a/tests/set.rs b/tests/set.rs index d5b2ad7..4b79988 100644 --- a/tests/set.rs +++ b/tests/set.rs @@ -1,7 +1,7 @@ mod sys; use std::net::{Ipv4Addr, Ipv6Addr}; -use rustables::{set::SetKey, MsgType}; +use rustables::{query::get_operation_from_nlmsghdr_type, set::SetKey, MsgType}; use sys::*; mod lib; diff --git a/tests/table.rs b/tests/table.rs index 3d8957c..971d58b 100644 --- a/tests/table.rs +++ b/tests/table.rs @@ -1,7 +1,7 @@ use std::ffi::CStr; mod sys; -use rustables::MsgType; +use rustables::{query::get_operation_from_nlmsghdr_type, MsgType}; use sys::*; mod lib; |