aboutsummaryrefslogtreecommitdiff
path: root/src/query.rs
blob: bc1d02eac585468360ecb867bb7dcd3b238da64a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
use crate::{nft_nlmsg_maxsize, sys, ProtoFamily};
use sys::libc;

/// Returns a buffer containing a netlink message which requests a list of all the netfilter
/// matching objects (e.g. tables, chains, rules, ...).
///
/// `target` is the type of objects to retrieve (e.g. `libc::NFT_MSG_GETTABLE`). The header is
/// built with the `NLM_F_ROOT | NLM_F_MATCH` flags, the `ProtoFamily::Unspec` family and the
/// sequence number `seq`.
///
/// `setup_cb` is an optional callback executed on the freshly built header, e.g. to set
/// parameters or append attributes. To pass arbitrary data inside that callback, use a closure.
///
/// The returned buffer is truncated to the header's `nlmsg_len` as it stands after `setup_cb`
/// has run. Only the encoded message is returned, not the whole scratch buffer it was built in.
///
/// # Errors
///
/// Propagates any error returned by `setup_cb`; nothing else in this function can fail.
pub fn get_list_of_objects<Error>(
    seq: u32,
    target: u16,
    setup_cb: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>,
) -> Result<Vec<u8>, Error> {
    let mut buffer = vec![0; nft_nlmsg_maxsize() as usize];
    // SAFETY: `buffer` is a live, zero-initialized allocation of `nft_nlmsg_maxsize()` bytes that
    // is not touched through any other path while `hdr` is in use, and `hdr` does not escape this
    // function. This assumes (libnftnl behaviour, not visible here) that `nftnl_nlmsg_build_hdr`
    // writes the header at the start of `buf` and returns a pointer to it.
    // NOTE(review): `Vec<u8>` gives no formal alignment guarantee for `nlmsghdr`; this relies on
    // the global allocator returning suitably aligned blocks for an allocation of this size.
    let hdr = unsafe {
        &mut *sys::nftnl_nlmsg_build_hdr(
            buffer.as_mut_ptr() as *mut libc::c_char,
            target,
            ProtoFamily::Unspec as u16,
            (libc::NLM_F_ROOT | libc::NLM_F_MATCH) as u16,
            seq,
        )
    };
    if let Some(cb) = setup_cb {
        cb(hdr)?;
    }
    // Read the final length only after the callback ran, since it may have appended payload.
    // Everything past `nlmsg_len` is unused zero padding; `truncate` is a no-op should the
    // length ever exceed the buffer.
    let msg_len = hdr.nlmsg_len as usize;
    buffer.truncate(msg_len);
    Ok(buffer)
}

#[cfg(feature = "query")]
mod inner {
    use crate::FinalizedBatch;

    use super::*;

    /// Errors that can occur while querying netfilter or sending a batch to it.
    #[derive(thiserror::Error, Debug)]
    pub enum Error {
        /// Creating the netfilter netlink socket failed.
        #[error("Unable to open netlink socket to netfilter")]
        NetlinkOpenError(#[source] std::io::Error),

        /// Writing the request (or the batch) to the socket failed.
        #[error("Unable to send netlink command to netfilter")]
        NetlinkSendError(#[source] std::io::Error),

        /// Reading a reply from the socket failed.
        #[error("Error while reading from netlink socket")]
        NetlinkRecvError(#[source] std::io::Error),

        /// `mnl::cb_run` / `mnl::cb_run2` returned an error while walking received messages.
        #[error("Error while processing an incoming netlink message")]
        ProcessNetlinkError(#[source] std::io::Error),

        /// Error produced by a caller-supplied request-customization callback.
        #[error("Custom error when customizing the query")]
        InitError(#[from] Box<dyn std::error::Error + 'static>),

        /// A netlink object could not be allocated (not raised by this module itself).
        #[error("Couldn't allocate a netlink object, out of memory ?")]
        NetlinkAllocationFailed,
    }

    /// Receives one datagram from `socket` into `buf`.
    ///
    /// Returns the filled prefix of `buf`, or `None` once `recv` reports 0 bytes. Slicing matters:
    /// bytes past the received length are leftovers from earlier, larger datagrams and must not
    /// be handed to the message parser.
    fn socket_recv<'a>(
        socket: &mnl::Socket,
        buf: &'a mut [u8],
    ) -> Result<Option<&'a [u8]>, Error> {
        let received = socket.recv(buf).map_err(Error::NetlinkRecvError)?;
        if received > 0 {
            Ok(Some(&buf[..received]))
        } else {
            Ok(None)
        }
    }

    /// Lists objects of a certain type (e.g. `libc::NFT_MSG_GETTABLE`) with the help of a helper
    /// function called by `mnl::cb_run2`.
    ///
    /// The callback `cb` receives each reply message together with a tuple of the additional
    /// data (supplied as `additional_data`) and the output vector. It should append the object it
    /// parsed to that vector.
    ///
    /// `req_hdr_customize` is forwarded to [`get_list_of_objects`] to customize the request
    /// header before it is sent.
    ///
    /// # Errors
    ///
    /// * [`Error::NetlinkOpenError`] if the socket cannot be opened.
    /// * Any error returned by `req_hdr_customize`.
    /// * [`Error::NetlinkSendError`] / [`Error::NetlinkRecvError`] on socket I/O failure.
    /// * [`Error::ProcessNetlinkError`] if `mnl::cb_run2` fails on a reply.
    pub fn list_objects_with_data<'a, A, T>(
        data_type: u16,
        cb: fn(&libc::nlmsghdr, &mut (&'a A, &mut Vec<T>)) -> libc::c_int,
        additional_data: &'a A,
        req_hdr_customize: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>,
    ) -> Result<Vec<T>, Error>
    where
        T: 'a,
    {
        debug!("listing objects of kind {}", data_type);
        let socket = mnl::Socket::new(mnl::Bus::Netfilter).map_err(Error::NetlinkOpenError)?;

        // NOTE(review): both are 0, which libmnl documents as "skip the sequence / port-id check"
        // for replies — confirm this is intended rather than using `socket.portid()`.
        let seq = 0;
        let portid = 0;

        let request = get_list_of_objects(seq, data_type, req_hdr_customize)?;
        socket.send(&request).map_err(Error::NetlinkSendError)?;

        let mut res = Vec::new();

        let mut msg_buffer = vec![0u8; nft_nlmsg_maxsize() as usize];
        while let Some(messages) = socket_recv(&socket, &mut msg_buffer)? {
            if let mnl::CbResult::Stop = mnl::cb_run2(
                messages,
                seq,
                portid,
                cb,
                &mut (additional_data, &mut res),
            )
            .map_err(Error::ProcessNetlinkError)?
            {
                break;
            }
        }

        Ok(res)
    }

    /// Sends every message of a finalized batch to netfilter and processes the replies.
    ///
    /// Replies are read until `mnl::cb_run` returns `CbResult::Stop` or `recv` returns 0 bytes.
    /// They are run through `mnl::cb_run` with the socket's own port id and a sequence number of 0.
    ///
    /// # Errors
    ///
    /// * [`Error::NetlinkOpenError`] if the socket cannot be opened.
    /// * [`Error::NetlinkSendError`] / [`Error::NetlinkRecvError`] on socket I/O failure.
    /// * [`Error::ProcessNetlinkError`] if `mnl::cb_run` fails on a reply.
    pub fn send_batch(batch: &mut FinalizedBatch) -> Result<(), Error> {
        let socket = mnl::Socket::new(mnl::Bus::Netfilter).map_err(Error::NetlinkOpenError)?;

        let seq = 0;
        let portid = socket.portid();

        socket.send_all(batch).map_err(Error::NetlinkSendError)?;
        debug!("sent");

        let mut msg_buffer = vec![0u8; nft_nlmsg_maxsize() as usize];
        while let Some(messages) = socket_recv(&socket, &mut msg_buffer)? {
            if let mnl::CbResult::Stop =
                mnl::cb_run(messages, seq, portid).map_err(Error::ProcessNetlinkError)?
            {
                break;
            }
        }
        Ok(())
    }
}

#[cfg(feature = "query")]
pub use inner::*;