aboutsummaryrefslogtreecommitdiff
path: root/src/query.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/query.rs')
-rw-r--r--src/query.rs196
1 files changed, 25 insertions, 171 deletions
diff --git a/src/query.rs b/src/query.rs
index 1c81cdd..80fdc75 100644
--- a/src/query.rs
+++ b/src/query.rs
@@ -1,163 +1,14 @@
use std::mem::size_of;
-use crate::{nft_nlmsg_maxsize, sys, ProtoFamily};
+use crate::{
+ nlmsg::NfNetlinkWriter,
+ parser::{nft_nlmsg_maxsize, Nfgenmsg},
+ sys, ProtoFamily,
+};
use libc::{
nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES, NLMSG_DONE, NLMSG_ERROR,
NLMSG_MIN_TYPE, NLMSG_NOOP, NLM_F_DUMP_INTR,
};
-use sys::libc;
-
-#[repr(C)]
-#[derive(Debug, Clone, Copy)]
-pub struct Nfgenmsg {
- pub family: u8, /* AF_xxx */
- pub version: u8, /* nfnetlink version */
- pub res_id: u16, /* resource id */
-}
-
-#[derive(thiserror::Error, Debug)]
-pub enum ParseError {
- #[error("The buffer is too small to hold a valid message")]
- BufTooSmall,
-
- #[error("The message is too small")]
- NlMsgTooSmall,
-
- #[error("Invalid subsystem, expected NFTABLES")]
- InvalidSubsystem(u8),
-
- #[error("Invalid version, expected NFNETLINK_V0")]
- InvalidVersion(u8),
-
- #[error("Invalid port ID")]
- InvalidPortId(u32),
-
- #[error("Invalid sequence number")]
- InvalidSeq(u32),
-
- #[error("The generation number was bumped in the kernel while the operation was running, interrupting it")]
- ConcurrentGenerationUpdate,
-
- #[error("Unsupported message type")]
- UnsupportedType(u16),
-
- #[error("A custom error occured")]
- Custom(Box<dyn std::error::Error + 'static>),
-}
-
-pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 {
- ((x & 0xff00) >> 8) as u8
-}
-
-pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 {
- (x & 0x00ff) as u8
-}
-
-pub unsafe fn get_nlmsghdr(
- buf: &[u8],
- expected_seq: u32,
- expected_port_id: u32,
-) -> Result<&nlmsghdr, ParseError> {
- let size_of_hdr = size_of::<nlmsghdr>();
-
- if buf.len() < size_of_hdr {
- return Err(ParseError::BufTooSmall);
- }
-
- let nlmsghdr_ptr = buf[0..size_of_hdr].as_ptr() as *const nlmsghdr;
- let nlmsghdr = *nlmsghdr_ptr;
-
- if nlmsghdr.nlmsg_len as usize > buf.len() || (nlmsghdr.nlmsg_len as usize) < size_of_hdr {
- println!("a: {}, {}", buf.len(), nlmsghdr.nlmsg_len);
- return Err(ParseError::NlMsgTooSmall);
- }
-
- if nlmsghdr.nlmsg_pid != 0 && expected_port_id != 0 && nlmsghdr.nlmsg_pid != expected_port_id {
- return Err(ParseError::InvalidPortId(nlmsghdr.nlmsg_pid));
- }
-
- if nlmsghdr.nlmsg_seq != 0 && expected_seq != 0 && nlmsghdr.nlmsg_seq != expected_seq {
- return Err(ParseError::InvalidSeq(nlmsghdr.nlmsg_seq));
- }
-
- if nlmsghdr.nlmsg_flags & NLM_F_DUMP_INTR as u16 != 0 {
- return Err(ParseError::ConcurrentGenerationUpdate);
- }
-
- Ok(&*nlmsghdr_ptr as &nlmsghdr)
-}
-
-pub enum NlMsg<'a> {
- Done,
- Noop,
- Error(nlmsgerr),
- NfGenMsg(&'a Nfgenmsg, &'a [u8]),
-}
-
-pub unsafe fn parse_nlmsg<'a>(
- buf: &'a [u8],
- expected_seq: u32,
- expected_port_id: u32,
-) -> Result<(&'a nlmsghdr, NlMsg<'a>), ParseError> {
- // in theory the message is composed of the following parts:
- // - nlmsghdr (contains the message size and type)
- // - struct nlmsgerr OR nfgenmsg (nftables header that describes the message family)
- // - the raw value that we want to validate (if the previous part is nfgenmsg)
- let nlmsghdr = get_nlmsghdr(buf, expected_seq, expected_port_id)?;
-
- let size_of_hdr = size_of::<nlmsghdr>();
-
- if nlmsghdr.nlmsg_type < NLMSG_MIN_TYPE as u16 {
- match nlmsghdr.nlmsg_type as libc::c_int {
- NLMSG_NOOP => return Ok((nlmsghdr, NlMsg::Noop)),
- NLMSG_ERROR => {
- if nlmsghdr.nlmsg_len as usize > buf.len()
- || (nlmsghdr.nlmsg_len as usize) < size_of_hdr + size_of::<nlmsgerr>()
- {
- println!("b: {}, {}", buf.len(), nlmsghdr.nlmsg_len);
- return Err(ParseError::NlMsgTooSmall);
- }
- let mut err = *(buf[size_of_hdr..size_of_hdr + size_of::<nlmsgerr>()].as_ptr()
- as *const nlmsgerr);
- // some APIs return negative values, while other return positive values
- err.error = err.error.abs();
- return Ok((nlmsghdr, NlMsg::Error(err)));
- }
- NLMSG_DONE => return Ok((nlmsghdr, NlMsg::Done)),
- x => return Err(ParseError::UnsupportedType(x as u16)),
- }
- }
-
- let subsys = get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type);
- if subsys != NFNL_SUBSYS_NFTABLES as u8 {
- return Err(ParseError::InvalidSubsystem(subsys));
- }
-
- let size_of_nfgenmsg = size_of::<Nfgenmsg>();
- if nlmsghdr.nlmsg_len as usize > buf.len()
- || (nlmsghdr.nlmsg_len as usize) < size_of_hdr + size_of_nfgenmsg
- {
- println!("c: {}, {}", buf.len(), nlmsghdr.nlmsg_len);
- return Err(ParseError::NlMsgTooSmall);
- }
-
- let nfgenmsg_ptr = buf[size_of_hdr..size_of_hdr + size_of_nfgenmsg].as_ptr() as *const Nfgenmsg;
- let nfgenmsg = *nfgenmsg_ptr;
- let subsys = get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type);
- if subsys != NFNL_SUBSYS_NFTABLES as u8 {
- return Err(ParseError::InvalidSubsystem(subsys));
- }
- if nfgenmsg.version != NFNETLINK_V0 as u8 {
- return Err(ParseError::InvalidVersion(nfgenmsg.version));
- }
-
- let raw_value = &buf[size_of_hdr + size_of_nfgenmsg..nlmsghdr.nlmsg_len as usize];
-
- Ok((
- nlmsghdr,
- NlMsg::NfGenMsg(&*nfgenmsg_ptr as &Nfgenmsg, raw_value),
- ))
-}
/// Returns a buffer containing a netlink message which requests a list of all the netfilter
/// matching objects (e.g. tables, chains, rules, ...).
@@ -165,22 +16,23 @@ pub unsafe fn parse_nlmsg<'a>(
/// to execute on the header, to set parameters for example.
/// To pass arbitrary data inside that callback, please use a closure.
pub fn get_list_of_objects<Error>(
+ msg_type: u16,
seq: u32,
- target: u16,
setup_cb: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>,
) -> Result<Vec<u8>, Error> {
let mut buffer = vec![0; nft_nlmsg_maxsize() as usize];
- let hdr = unsafe {
- &mut *sys::nftnl_nlmsg_build_hdr(
- buffer.as_mut_ptr() as *mut libc::c_char,
- target,
- ProtoFamily::Unspec as u16,
- (libc::NLM_F_ROOT | libc::NLM_F_MATCH) as u16,
- seq,
- )
- };
+ let mut writer = &mut NfNetlinkWriter::new(&mut buffer);
+ writer.write_header(
+ msg_type,
+ ProtoFamily::Unspec,
+ (libc::NLM_F_ROOT | libc::NLM_F_MATCH) as u16,
+ seq,
+ None,
+ );
if let Some(cb) = setup_cb {
- cb(hdr)?;
+ cb(writer
+ .get_current_header()
+ .expect("Fatal error: mising header"))?;
}
Ok(buffer)
}
@@ -191,12 +43,12 @@ mod inner {
use nix::{
errno::Errno,
- sys::socket::{
- self, AddressFamily, MsgFlags, NetlinkAddr, SockAddr, SockFlag, SockProtocol, SockType,
- },
+ sys::socket::{self, AddressFamily, MsgFlags, SockFlag, SockProtocol, SockType},
};
- use crate::FinalizedBatch;
+ //use crate::FinalizedBatch;
+
+ use crate::nlmsg::{parse_nlmsg, DecodeError, NlMsg};
use super::*;
@@ -212,7 +64,7 @@ mod inner {
NetlinkRecvError(#[source] nix::Error),
#[error("Error while processing an incoming netlink message")]
- ProcessNetlinkError(#[from] ParseError),
+ ProcessNetlinkError(#[from] DecodeError),
#[error("Error received from the kernel")]
NetlinkError(nlmsgerr),
@@ -316,7 +168,7 @@ mod inner {
let seq = 0;
let portid = 0;
- let chains_buf = get_list_of_objects(seq, data_type, req_hdr_customize)?;
+ let chains_buf = get_list_of_objects(data_type, seq, req_hdr_customize)?;
socket::send(sock, &chains_buf, MsgFlags::empty()).map_err(Error::NetlinkSendError)?;
Ok(socket_close_wrapper(sock, move |sock| {
@@ -324,6 +176,7 @@ mod inner {
})?)
}
+ /*
pub fn send_batch(batch: &mut FinalizedBatch) -> Result<(), Error> {
let sock = socket::socket(
AddressFamily::Netlink,
@@ -357,6 +210,7 @@ mod inner {
recv_and_process(sock, &|_, _, _, _| Ok(()), &mut (), seq, portid)
})?)
}
+ */
}
#[cfg(feature = "query")]