diff options
Diffstat (limited to 'src/query.rs')
-rw-r--r-- | src/query.rs | 45 |
1 files changed, 23 insertions, 22 deletions
diff --git a/src/query.rs b/src/query.rs index 80fdc75..5065436 100644 --- a/src/query.rs +++ b/src/query.rs @@ -43,12 +43,15 @@ mod inner { use nix::{ errno::Errno, - sys::socket::{self, AddressFamily, MsgFlags, SockFlag, SockProtocol, SockType}, + sys::socket::{ + self, AddressFamily, MsgFlags, NetlinkAddr, SockAddr, SockFlag, SockProtocol, SockType, + }, }; - //use crate::FinalizedBatch; - - use crate::nlmsg::{parse_nlmsg, DecodeError, NlMsg}; + use crate::{ + batch::Batch, + parser::{parse_nlmsg, DecodeError, NlMsg}, + }; use super::*; @@ -90,10 +93,8 @@ mod inner { fn recv_and_process<'a, T>( sock: RawFd, - cb: &dyn Fn(&nlmsghdr, &Nfgenmsg, &[u8], &mut T) -> Result<(), Error>, + cb: Option<&dyn Fn(&nlmsghdr, &Nfgenmsg, &[u8], &mut T) -> Result<(), Error>>, working_data: &'a mut T, - seq: u32, - portid: u32, ) -> Result<(), Error> { let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize]; @@ -105,7 +106,7 @@ mod inner { } let mut buf = &msg_buffer.as_slice()[0..nb_recv]; loop { - let (nlmsghdr, msg) = unsafe { parse_nlmsg(&buf, seq, portid) }?; + let (nlmsghdr, msg) = unsafe { parse_nlmsg(&buf) }?; match msg { NlMsg::Done => { return Ok(()); @@ -116,7 +117,11 @@ mod inner { } } NlMsg::Noop => {} - NlMsg::NfGenMsg(genmsg, data) => cb(&nlmsghdr, &genmsg, &data, working_data)?, + NlMsg::NfGenMsg(genmsg, data) => { + if let Some(cb) = cb { + cb(&nlmsghdr, &genmsg, &data, working_data); + } + } } // netlink messages are 4bytes aligned @@ -166,18 +171,16 @@ mod inner { .map_err(Error::NetlinkOpenError)?; let seq = 0; - let portid = 0; let chains_buf = get_list_of_objects(data_type, seq, req_hdr_customize)?; socket::send(sock, &chains_buf, MsgFlags::empty()).map_err(Error::NetlinkSendError)?; Ok(socket_close_wrapper(sock, move |sock| { - recv_and_process(sock, cb, working_data, seq, portid) + recv_and_process(sock, Some(cb), working_data) })?) } - /* - pub fn send_batch(batch: &mut FinalizedBatch) -> Result<(), Error> { + pub fn send_batch(batch: Batch) -> Result<(), Error> { let sock = socket::socket( AddressFamily::Netlink, SockType::Raw, @@ -189,28 +192,26 @@ mod inner { let seq = 0; let portid = 0; + let addr = SockAddr::Netlink(NetlinkAddr::new(portid, 0)); // while this bind() is not strictly necessary, strace have trouble decoding the messages // if we don't - let addr = SockAddr::Netlink(NetlinkAddr::new(portid, 0)); socket::bind(sock, &addr).expect("bind"); //match socket::getsockname(sock).map_err(|_| Error::RetrievingSocketInfoFailed)? { // SockAddr::Netlink(addr) => addr.0.nl_pid, // _ => return Err(Error::NotNetlinkSocket), //}; - for data in batch { - if socket::send(sock, data, MsgFlags::empty()).map_err(Error::NetlinkSendError)? - < data.len() - { - return Err(Error::TruncatedSend); - } + let to_send = batch.finalize(); + let mut sent = 0; + while sent != to_send.len() { + sent += socket::send(sock, &to_send[sent..], MsgFlags::empty()) + .map_err(Error::NetlinkSendError)?; } Ok(socket_close_wrapper(sock, move |sock| { - recv_and_process(sock, &|_, _, _, _| Ok(()), &mut (), seq, portid) + recv_and_process(sock, None, &mut ()) })?) } - */ } #[cfg(feature = "query")] |