aboutsummaryrefslogtreecommitdiff
path: root/src/batch.rs
blob: 198e8d0ddd1abc1239d2c87f0ba1c2adcbb6adf5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
use crate::{MsgType, NlMsg};
use crate::sys::{self as sys, libc};
use std::ffi::c_void;
use std::os::raw::c_char;
use std::ptr;
use thiserror::Error;

/// Error while communicating with netlink.
#[derive(Error, Debug)]
#[error("Error while communicating with netlink")]
pub struct NetlinkError(());

#[cfg(feature = "query")]
/// Asks libnftnl whether the running kernel accepts batched netlink messages for netfilter.
///
/// libnftnl answering `1` maps to `Ok(true)` and `0` maps to `Ok(false)`. Any other value is
/// reported as a [`NetlinkError`].
pub fn batch_is_supported() -> std::result::Result<bool, NetlinkError> {
    let status = unsafe { sys::nftnl_batch_is_supported() };
    if status == 1 {
        Ok(true)
    } else if status == 0 {
        Ok(false)
    } else {
        Err(NetlinkError(()))
    }
}

/// A set of netfilter messages meant to be applied together as one atomic operation.
///
/// Thin owner of libnftnl's `nftnl_batch`.
pub struct Batch {
    /// Raw libnftnl batch handle owned by this value.
    pub(crate) batch: *mut sys::nftnl_batch,
    /// Sequence number to stamp on the next message written into the batch.
    pub(crate) seq: u32,
    /// Whether no user message has been added yet.
    ///
    /// Batches wrapped through `from_raw` start out as `false`.
    pub(crate) is_empty: bool,
}

impl Batch {
    /// Creates a new nftnl batch with the [default page size].
    ///
    /// [default page size]: fn.default_batch_page_size.html
    pub fn new() -> Self {
        Self::with_page_size(default_batch_page_size())
    }

    /// Wraps an existing raw `nftnl_batch` handle and continues numbering messages from `seq`.
    ///
    /// The contents of a raw batch cannot be inspected here, so the returned batch is assumed
    /// to be non-empty. As a result, [`Batch::finalize`] always returns `Some` for it.
    ///
    /// # Safety
    ///
    /// `batch` must be a valid, non-null pointer to an `nftnl_batch`, for example one obtained
    /// from `nftnl_batch_alloc`. Ownership of the handle moves into the returned `Batch`, which
    /// releases it with `nftnl_batch_free` when dropped. The caller must not free or reuse the
    /// pointer afterwards.
    pub unsafe fn from_raw(batch: *mut sys::nftnl_batch, seq: u32) -> Self {
        Batch {
            batch,
            seq,
            // we assume this batch is not empty by default
            is_empty: false,
        }
    }

    /// Creates a new nftnl batch with the given batch page size.
    ///
    /// The batch-begin message is written immediately.
    pub fn with_page_size(batch_page_size: u32) -> Self {
        let batch = try_alloc!(unsafe {
            sys::nftnl_batch_alloc(batch_page_size, crate::nft_nlmsg_maxsize())
        });
        let mut this = Batch {
            batch,
            seq: 0,
            is_empty: true,
        };
        this.write_begin_msg();
        this
    }

    /// Adds the given message to this batch, stamped with the current sequence number.
    pub fn add<T: NlMsg>(&mut self, msg: &T, msg_type: MsgType) {
        trace!("Writing NlMsg with seq {} to batch", self.seq);
        unsafe { msg.write(self.current(), self.seq, msg_type) };
        self.is_empty = false;
        self.next()
    }

    /// Adds every message yielded by `msg_iter` to this batch, in order, all with `msg_type`.
    ///
    /// Accepts anything implementing [`IntoIterator`], which includes every [`Iterator`].
    /// Adding a message cannot fail at this level, so nothing is returned.
    pub fn add_iter<T, I>(&mut self, msg_iter: I, msg_type: MsgType)
    where
        T: NlMsg,
        I: IntoIterator<Item = T>,
    {
        for msg in msg_iter {
            self.add(&msg, msg_type);
        }
    }

    /// Adds the final end message to the batch and returns a [`FinalizedBatch`] that can be
    /// used to send the messages to netfilter.
    ///
    /// Returns `None` if no message was added to the batch, because sending an empty batch
    /// could block forever.
    ///
    /// [`FinalizedBatch`]: struct.FinalizedBatch.html
    pub fn finalize(mut self) -> Option<FinalizedBatch> {
        self.write_end_msg();
        if self.is_empty {
            return None;
        }
        Some(FinalizedBatch { batch: self })
    }

    /// Pointer to the position in the batch where the next message is written.
    fn current(&self) -> *mut c_void {
        unsafe { sys::nftnl_batch_buffer(self.batch) }
    }

    /// Commits the message just written and advances to the next sequence number.
    ///
    /// Aborts the process if `nftnl_batch_update` reports failure.
    fn next(&mut self) {
        if unsafe { sys::nftnl_batch_update(self.batch) } < 0 {
            // See try_alloc definition.
            std::process::abort();
        }
        // Sequence numbers are 32-bit. Wrap explicitly so debug builds do not panic on
        // overflow; release builds already behaved this way.
        self.seq = self.seq.wrapping_add(1);
    }

    /// Writes the batch-begin marker at the current position and advances.
    fn write_begin_msg(&mut self) {
        unsafe { sys::nftnl_batch_begin(self.current() as *mut c_char, self.seq) };
        self.next();
    }

    /// Writes the batch-end marker at the current position and advances.
    fn write_end_msg(&mut self) {
        unsafe { sys::nftnl_batch_end(self.current() as *mut c_char, self.seq) };
        self.next();
    }

    #[cfg(feature = "unsafe-raw-handles")]
    /// Returns the raw handle.
    pub fn as_ptr(&self) -> *const sys::nftnl_batch {
        self.batch as *const sys::nftnl_batch
    }

    #[cfg(feature = "unsafe-raw-handles")]
    /// Returns a mutable version of the raw handle.
    pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_batch {
        self.batch
    }
}

impl Default for Batch {
    /// Same as [`Batch::new`]: a batch with the default page size.
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for Batch {
    /// Releases the underlying libnftnl batch handle.
    fn drop(&mut self) {
        let handle = self.batch;
        // SAFETY: `handle` is the batch handle owned by this `Batch`. It is either allocated in
        // `with_page_size` or handed over through `from_raw`, and it is released exactly once,
        // here.
        unsafe {
            sys::nftnl_batch_free(handle);
        }
    }
}

/// A completed [`Batch`]: it opens with a batch-begin message and closes with a batch-end
/// message.
///
/// Obtained from [`Batch::finalize`]. Iterating over it (via [`FinalizedBatch::iter`] or
/// `&mut FinalizedBatch`) yields the byte buffers to hand to netlink to execute the batch.
pub struct FinalizedBatch {
    batch: Batch,
}

impl FinalizedBatch {
    /// Builds an iterator over the byte buffers, one per batch page, that must be sent to
    /// netlink to execute this batch.
    ///
    /// The yielded slices borrow this batch, so it stays mutably borrowed while they are in use.
    pub fn iter(&mut self) -> Iter<'_> {
        let raw_batch = self.batch.batch;
        let page_count = unsafe { sys::nftnl_batch_iovec_len(raw_batch) } as usize;

        // Start from empty descriptors; libnftnl fills in base pointer and length for each page.
        let placeholder = libc::iovec {
            iov_base: ptr::null_mut(),
            iov_len: 0,
        };
        let mut buffers = vec![placeholder; page_count];
        unsafe { sys::nftnl_batch_iovec(raw_batch, buffers.as_mut_ptr(), page_count as u32) };

        Iter {
            iovecs: buffers.into_iter(),
            _marker: ::std::marker::PhantomData,
        }
    }
}

impl<'a> IntoIterator for &'a mut FinalizedBatch {
    type Item = &'a [u8];
    type IntoIter = Iter<'a>;

    fn into_iter(self) -> Iter<'a> {
        self.iter()
    }
}

/// Iterator over the netlink byte buffers of a [`FinalizedBatch`].
///
/// Created by [`FinalizedBatch::iter`]. Each yielded slice borrows memory owned by the batch,
/// so the iterator cannot outlive that borrow.
pub struct Iter<'a> {
    /// Page descriptors filled in by libnftnl, consumed one per `next` call.
    iovecs: ::std::vec::IntoIter<libc::iovec>,
    /// Ties the yielded slices to the lifetime of the batch borrow.
    _marker: ::std::marker::PhantomData<&'a ()>,
}

impl<'a> Iterator for Iter<'a> {
    type Item = &'a [u8];

    fn next(&mut self) -> Option<&'a [u8]> {
        self.iovecs.next().map(|iovec| unsafe {
            ::std::slice::from_raw_parts(iovec.iov_base as *const u8, iovec.iov_len)
        })
    }
}

/// Default batch page size: 32 system pages.
///
/// This is 128 KiB with 4 KiB pages and 256 KiB with 8 KiB pages. The size is chosen to load
/// a ruleset of half a million rules without hitting -EMSGSIZE due to a large iovec.
///
/// If the system page size cannot be queried (`sysconf` returns a non-positive value) or
/// does not fit in a `u32`, 4096 bytes is assumed instead. The result saturates rather than
/// overflowing.
pub fn default_batch_page_size() -> u32 {
    /// Page size assumed when `sysconf(_SC_PAGESIZE)` gives no usable answer.
    const FALLBACK_PAGE_SIZE: u32 = 4096;
    /// Number of system pages per batch page.
    const PAGES_PER_BATCH: u32 = 32;

    // `c_long` is 32 or 64 bits depending on the target; widen to i64 so the range check
    // below works on both.
    let raw = i64::from(unsafe { libc::sysconf(libc::_SC_PAGESIZE) });
    let page_size = if raw > 0 && raw <= i64::from(u32::MAX) {
        // Range-checked above, so this cast cannot truncate.
        raw as u32
    } else {
        FALLBACK_PAGE_SIZE
    };
    page_size.saturating_mul(PAGES_PER_BATCH)
}