From 1e33e3ab0790d977add329e9686b4b9e5570ba3c Mon Sep 17 00:00:00 2001 From: lafleur Date: Sun, 7 Nov 2021 14:19:37 +0100 Subject: call bindgen on build in rustables --- rustables/Cargo.toml | 11 +- rustables/build.rs | 132 ++++++- rustables/examples/add-rules.rs | 2 +- rustables/examples/filter-ethernet.rs | 2 +- rustables/src/batch.rs | 4 +- rustables/src/chain.rs | 2 +- rustables/src/expr/bitwise.rs | 2 +- rustables/src/expr/cmp.rs | 2 +- rustables/src/expr/counter.rs | 2 +- rustables/src/expr/ct.rs | 2 +- rustables/src/expr/immediate.rs | 2 +- rustables/src/expr/log.rs | 2 +- rustables/src/expr/lookup.rs | 2 +- rustables/src/expr/masquerade.rs | 2 +- rustables/src/expr/meta.rs | 2 +- rustables/src/expr/mod.rs | 2 +- rustables/src/expr/nat.rs | 2 +- rustables/src/expr/payload.rs | 2 +- rustables/src/expr/register.rs | 2 +- rustables/src/expr/reject.rs | 5 +- rustables/src/expr/verdict.rs | 5 +- rustables/src/expr/wrapper.rs | 6 +- rustables/src/lib.rs | 7 +- rustables/src/query.rs | 2 +- rustables/src/rule.rs | 2 +- rustables/src/set.rs | 2 +- rustables/src/table.rs | 2 +- rustables/src/tests/mod.rs | 702 ---------------------------------- rustables/tests/expr.rs | 700 +++++++++++++++++++++++++++++++++ rustables/tests_wrapper.h | 1 + rustables/wrapper.h | 13 +- 31 files changed, 865 insertions(+), 761 deletions(-) delete mode 100644 rustables/src/tests/mod.rs create mode 100644 rustables/tests/expr.rs create mode 100644 rustables/tests_wrapper.h (limited to 'rustables') diff --git a/rustables/Cargo.toml b/rustables/Cargo.toml index d948055..5ac819b 100644 --- a/rustables/Cargo.toml +++ b/rustables/Cargo.toml @@ -13,20 +13,12 @@ edition = "2018" [features] query = [] unsafe-raw-handles = [] -nftnl-1-0-7 = ["rustables-sys/nftnl-1-0-7"] -nftnl-1-0-8 = ["rustables-sys/nftnl-1-0-8"] -nftnl-1-0-9 = ["rustables-sys/nftnl-1-0-9"] -nftnl-1-1-0 = ["rustables-sys/nftnl-1-1-0"] -nftnl-1-1-1 = ["rustables-sys/nftnl-1-1-1"] -nftnl-1-1-2 = ["rustables-sys/nftnl-1-1-2"] -nftnl-1-2-0 = ["rustables-sys/nftnl-1-2-0"] -default = ["nftnl-1-2-0"] [dependencies] bitflags = "1.0" thiserror = "1.0" log = "0.4" -rustables-sys = { path = "../rustables-sys", version = "0.7" } +libc = "0.2.43" mnl = "0.2" [dev-dependencies] @@ -34,6 +26,7 @@ ipnetwork = "0.16" [build-dependencies] bindgen = "0.53.1" +pkg-config = "0.3" regex = "1.5.4" lazy_static = "1.4.0" diff --git a/rustables/build.rs b/rustables/build.rs index fee34ff..180e06b 100644 --- a/rustables/build.rs +++ b/rustables/build.rs @@ -1,13 +1,64 @@ use bindgen; - +use lazy_static::lazy_static; +use regex::{Captures, Regex}; +use pkg_config; +use std::env; use std::fs::File; use std::io::Write; use std::path::PathBuf; use std::borrow::Cow; -use lazy_static::lazy_static; -use regex::{Captures, Regex}; +const SYS_HEADER_FILE: &str = "wrapper.h"; +const SYS_BINDINGS_FILE: &str = "src/sys.rs"; +const TESTS_HEADER_FILE: &str = "tests_wrapper.h"; +const TESTS_BINDINGS_FILE: &str = "tests/sys.rs"; +const MIN_LIBNFTNL_VERSION: &str = "1.0.6"; + + +fn get_env(var: &'static str) -> Option { + println!("cargo:rerun-if-env-changed={}", var); + env::var_os(var).map(PathBuf::from) +} + +/// Set env vars to help rustc find linked libraries. +fn setup_libs() { + if let Some(lib_dir) = get_env("LIBNFTNL_LIB_DIR") { + if !lib_dir.is_dir() { + panic!( + "libnftnl library directory does not exist: {}", + lib_dir.display() + ); + } + println!("cargo:rustc-link-search=native={}", lib_dir.display()); + println!("cargo:rustc-link-lib=nftnl"); + } else { + // Trying with pkg-config instead + println!("Minimum libnftnl version: {}", MIN_LIBNFTNL_VERSION); + pkg_config::Config::new() + .atleast_version(MIN_LIBNFTNL_VERSION) + .probe("libnftnl") + .unwrap(); + } + + if let Some(lib_dir) = get_env("LIBMNL_LIB_DIR") { + if !lib_dir.is_dir() { + panic!( + "libmnl library directory does not exist: {}", + lib_dir.display() + ); + } + println!("cargo:rustc-link-search=native={}", lib_dir.display()); + println!("cargo:rustc-link-lib=mnl"); + } else { + // Trying with pkg-config instead + pkg_config::Config::new() + .atleast_version("1.0.0") + .probe("libmnl") + .unwrap(); + } +} + /// Recast nft_*_attributes from u32 to u16 in header file `before`. fn reformat_units(before: &str) -> Cow { lazy_static! { @@ -18,31 +69,92 @@ fn reformat_units(before: &str) -> Cow { }) } -fn main() { +fn generate_consts() { // Tell cargo to invalidate the built crate whenever the headers change. - println!("cargo:rerun-if-changed=wrapper.h"); + println!("cargo:rerun-if-changed={}", SYS_HEADER_FILE); let bindings = bindgen::Builder::default() - .header("wrapper.h") + .header(SYS_HEADER_FILE) .generate_comments(false) .prepend_enum_name(false) + .use_core() + .whitelist_function("^nftnl_.+$") + .whitelist_type("^nftnl_.+$") + .whitelist_var("^nftnl_.+$") + .whitelist_var("^NFTNL_.+$") + .blacklist_type("(FILE|iovec)") + .blacklist_type("^_IO_.+$") + .blacklist_type("^__.+$") + .blacklist_type("nlmsghdr") + .raw_line("#![allow(non_camel_case_types)]\n\n") + .raw_line("pub use libc;") + .raw_line("use libc::{c_char, c_int, c_ulong, c_void, iovec, nlmsghdr, FILE};") + .raw_line("use core::option::Option;") + .ctypes_prefix("libc") // Tell cargo to invalidate the built crate whenever any of the // included header files changed. .parse_callbacks(Box::new(bindgen::CargoCallbacks)) // Finish the builder and generate the bindings. .generate() // Unwrap the Result and panic on failure. - .expect("Unable to generate bindings"); + .expect("Error: unable to generate bindings"); + + let mut s = bindings.to_string() + // Add newlines because in alpine bindgen doesn't add them after + // statements. + .replace(" ; ", ";\n") + .replace("#[derive(Debug, Copy, Clone)]", ""); + let re = Regex::new(r"libc::(c_[a-z]*)").unwrap(); + s = re.replace_all(&s, "$1").into(); + let re = Regex::new(r"::core::option::(Option)").unwrap(); + s = re.replace_all(&s, "$1").into(); + let re = Regex::new(r"_bindgen_ty_[0-9]+").unwrap(); + s = re.replace_all(&s, "u32").into(); + // Change struct bodies to c_void. + let re = Regex::new(r"(pub struct .*) \{\n *_unused: \[u8; 0\],\n\}\n").unwrap(); + s = re.replace_all(&s, "$1(c_void);\n").into(); + let re = Regex::new(r"pub type u32 = u32;\n").unwrap(); + s = re.replace_all(&s, "").into(); + + // Write the bindings to the rust header file. + let out_path = PathBuf::from(SYS_BINDINGS_FILE); + File::create(out_path) + .expect("Error: could not create rust header file.") + .write_all(&s.as_bytes()) + .expect("Error: could not write to the rust header file."); +} + +fn generate_test_consts() { + // Tell cargo to invalidate the built crate whenever the headers change. + println!("cargo:rerun-if-changed={}", TESTS_HEADER_FILE); + + let bindings = bindgen::Builder::default() + .header(TESTS_HEADER_FILE) + .generate_comments(false) + .prepend_enum_name(false) + .raw_line("#![allow(non_camel_case_types, dead_code)]\n\n") + // Tell cargo to invalidate the built crate whenever any of the + // included header files changed. + .parse_callbacks(Box::new(bindgen::CargoCallbacks)) + // Finish the builder and generate the bindings. + .generate() + // Unwrap the Result and panic on failure. + .expect("Error: unable to generate bindings needed for tests."); // Add newlines because in alpine bindgen doesn't add them after statements. let s = bindings.to_string().replace(" ; ", ";\n"); let s = reformat_units(&s); - let h = String::from("#![allow(non_camel_case_types, dead_code)]\n\n") + &s; // Write the bindings to the rust header file. - let out_path = PathBuf::from("src/tests/bindings.rs"); + let out_path = PathBuf::from(TESTS_BINDINGS_FILE); File::create(out_path) .expect("Error: could not create rust header file.") - .write_all(&h.as_bytes()) + .write_all(&s.as_bytes()) .expect("Error: could not write to the rust header file."); } + +fn main() { + setup_libs(); + generate_consts(); + generate_test_consts(); +} diff --git a/rustables/examples/add-rules.rs b/rustables/examples/add-rules.rs index 4fea491..3aae7ee 100644 --- a/rustables/examples/add-rules.rs +++ b/rustables/examples/add-rules.rs @@ -37,7 +37,7 @@ //! ``` use ipnetwork::{IpNetwork, Ipv4Network}; -use rustables::{nft_expr, rustables_sys::libc, Batch, Chain, FinalizedBatch, ProtoFamily, Rule, Table}; +use rustables::{nft_expr, sys::libc, Batch, Chain, FinalizedBatch, ProtoFamily, Rule, Table}; use std::{ ffi::{self, CString}, io, diff --git a/rustables/examples/filter-ethernet.rs b/rustables/examples/filter-ethernet.rs index 23be8a1..b16c49e 100644 --- a/rustables/examples/filter-ethernet.rs +++ b/rustables/examples/filter-ethernet.rs @@ -22,7 +22,7 @@ //! # nft delete table inet example-filter-ethernet //! ``` -use rustables::{nft_expr, rustables_sys::libc, Batch, Chain, FinalizedBatch, ProtoFamily, Rule, Table}; +use rustables::{nft_expr, sys::libc, Batch, Chain, FinalizedBatch, ProtoFamily, Rule, Table}; use std::{ffi::CString, io, rc::Rc}; const TABLE_NAME: &str = "example-filter-ethernet"; diff --git a/rustables/src/batch.rs b/rustables/src/batch.rs index 3cdd52b..c8ec5aa 100644 --- a/rustables/src/batch.rs +++ b/rustables/src/batch.rs @@ -1,5 +1,5 @@ use crate::{MsgType, NlMsg}; -use rustables_sys::{self as sys, libc}; +use crate::sys::{self as sys, libc}; use std::ffi::c_void; use std::os::raw::c_char; use std::ptr; @@ -157,7 +157,7 @@ impl FinalizedBatch { }; num_pages ]; - let iovecs_ptr = iovecs.as_mut_ptr() as *mut [u8; 0]; + let iovecs_ptr = iovecs.as_mut_ptr(); unsafe { sys::nftnl_batch_iovec(self.batch.batch, iovecs_ptr, num_pages as u32); } diff --git a/rustables/src/chain.rs b/rustables/src/chain.rs index ac9c57d..20043ac 100644 --- a/rustables/src/chain.rs +++ b/rustables/src/chain.rs @@ -1,5 +1,5 @@ use crate::{MsgType, Table}; -use rustables_sys::{self as sys, libc}; +use crate::sys::{self as sys, libc}; #[cfg(feature = "query")] use std::convert::TryFrom; use std::{ diff --git a/rustables/src/expr/bitwise.rs b/rustables/src/expr/bitwise.rs index a5d9343..59ef41b 100644 --- a/rustables/src/expr/bitwise.rs +++ b/rustables/src/expr/bitwise.rs @@ -1,5 +1,5 @@ use super::{Expression, Rule, ToSlice}; -use rustables_sys::{self as sys, libc}; +use crate::sys::{self, libc}; use std::ffi::c_void; use std::os::raw::c_char; diff --git a/rustables/src/expr/cmp.rs b/rustables/src/expr/cmp.rs index 747974d..384f0b4 100644 --- a/rustables/src/expr/cmp.rs +++ b/rustables/src/expr/cmp.rs @@ -1,5 +1,5 @@ use super::{DeserializationError, Expression, Rule, ToSlice}; -use rustables_sys::{self as sys, libc}; +use crate::sys::{self, libc}; use std::{ borrow::Cow, ffi::{c_void, CString}, diff --git a/rustables/src/expr/counter.rs b/rustables/src/expr/counter.rs index 099e7fa..71064df 100644 --- a/rustables/src/expr/counter.rs +++ b/rustables/src/expr/counter.rs @@ -1,5 +1,5 @@ use super::{DeserializationError, Expression, Rule}; -use rustables_sys as sys; +use crate::sys; use std::os::raw::c_char; /// A counter expression adds a counter to the rule that is incremented to count number of packets diff --git a/rustables/src/expr/ct.rs b/rustables/src/expr/ct.rs index 001aef8..7d6614c 100644 --- a/rustables/src/expr/ct.rs +++ b/rustables/src/expr/ct.rs @@ -1,5 +1,5 @@ use super::{DeserializationError, Expression, Rule}; -use rustables_sys::{self as sys, libc}; +use crate::sys::{self, libc}; use std::os::raw::c_char; bitflags::bitflags! { diff --git a/rustables/src/expr/immediate.rs b/rustables/src/expr/immediate.rs index ff4ad04..0787e06 100644 --- a/rustables/src/expr/immediate.rs +++ b/rustables/src/expr/immediate.rs @@ -1,5 +1,5 @@ use super::{DeserializationError, Expression, Register, Rule, ToSlice}; -use rustables_sys as sys; +use crate::sys; use std::ffi::c_void; use std::os::raw::c_char; diff --git a/rustables/src/expr/log.rs b/rustables/src/expr/log.rs index db96ba9..5c06897 100644 --- a/rustables/src/expr/log.rs +++ b/rustables/src/expr/log.rs @@ -1,5 +1,5 @@ use super::{DeserializationError, Expression, Rule}; -use rustables_sys as sys; +use crate::sys; use std::ffi::{CStr, CString}; use std::os::raw::c_char; use thiserror::Error; diff --git a/rustables/src/expr/lookup.rs b/rustables/src/expr/lookup.rs index 7796b29..8e288a0 100644 --- a/rustables/src/expr/lookup.rs +++ b/rustables/src/expr/lookup.rs @@ -1,6 +1,6 @@ use super::{DeserializationError, Expression, Rule}; use crate::set::Set; -use rustables_sys::{self as sys, libc}; +use crate::sys::{self, libc}; use std::ffi::{CStr, CString}; use std::os::raw::c_char; diff --git a/rustables/src/expr/masquerade.rs b/rustables/src/expr/masquerade.rs index 40565d5..c1a06de 100644 --- a/rustables/src/expr/masquerade.rs +++ b/rustables/src/expr/masquerade.rs @@ -1,5 +1,5 @@ use super::{DeserializationError, Expression, Rule}; -use rustables_sys as sys; +use crate::sys; use std::os::raw::c_char; /// Sets the source IP to that of the output interface. diff --git a/rustables/src/expr/meta.rs b/rustables/src/expr/meta.rs index 199f3d3..bf77774 100644 --- a/rustables/src/expr/meta.rs +++ b/rustables/src/expr/meta.rs @@ -1,5 +1,5 @@ use super::{DeserializationError, Expression, Rule}; -use rustables_sys::{self as sys, libc}; +use crate::sys::{self, libc}; use std::os::raw::c_char; /// A meta expression refers to meta data associated with a packet. diff --git a/rustables/src/expr/mod.rs b/rustables/src/expr/mod.rs index b20a752..fbf49d6 100644 --- a/rustables/src/expr/mod.rs +++ b/rustables/src/expr/mod.rs @@ -9,7 +9,7 @@ use std::net::Ipv4Addr; use std::net::Ipv6Addr; use super::rule::Rule; -use rustables_sys::{self as sys, libc}; +use crate::sys::{self, libc}; use thiserror::Error; mod bitwise; diff --git a/rustables/src/expr/nat.rs b/rustables/src/expr/nat.rs index 51f439f..8beaa30 100644 --- a/rustables/src/expr/nat.rs +++ b/rustables/src/expr/nat.rs @@ -1,6 +1,6 @@ use super::{DeserializationError, Expression, Register, Rule}; use crate::ProtoFamily; -use rustables_sys::{self as sys, libc}; +use crate::sys::{self, libc}; use std::{convert::TryFrom, os::raw::c_char}; #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] diff --git a/rustables/src/expr/payload.rs b/rustables/src/expr/payload.rs index 334c939..7612fd9 100644 --- a/rustables/src/expr/payload.rs +++ b/rustables/src/expr/payload.rs @@ -1,5 +1,5 @@ use super::{DeserializationError, Expression, Rule}; -use rustables_sys::{self as sys, libc}; +use crate::sys::{self, libc}; use std::os::raw::c_char; pub trait HeaderField { diff --git a/rustables/src/expr/register.rs b/rustables/src/expr/register.rs index 2cfcc3b..f0aed94 100644 --- a/rustables/src/expr/register.rs +++ b/rustables/src/expr/register.rs @@ -1,6 +1,6 @@ use std::fmt::Debug; -use rustables_sys::libc; +use crate::sys::libc; use super::DeserializationError; diff --git a/rustables/src/expr/reject.rs b/rustables/src/expr/reject.rs index 550a287..2ea0cbf 100644 --- a/rustables/src/expr/reject.rs +++ b/rustables/src/expr/reject.rs @@ -1,9 +1,6 @@ use super::{DeserializationError, Expression, Rule}; use crate::ProtoFamily; -use rustables_sys::{ - self as sys, - libc::{self, c_char}, -}; +use crate::sys::{self, libc::{self, c_char}}; /// A reject expression that defines the type of rejection message sent /// when discarding a packet. diff --git a/rustables/src/expr/verdict.rs b/rustables/src/expr/verdict.rs index 6a6b802..3c4c374 100644 --- a/rustables/src/expr/verdict.rs +++ b/rustables/src/expr/verdict.rs @@ -1,8 +1,5 @@ use super::{DeserializationError, Expression, Rule}; -use rustables_sys::{ - self as sys, - libc::{self, c_char}, -}; +use crate::sys::{self, libc::{self, c_char}}; use std::ffi::{CStr, CString}; /// A verdict expression. In the background, this is usually an "Immediate" expression in nftnl diff --git a/rustables/src/expr/wrapper.rs b/rustables/src/expr/wrapper.rs index 5162c21..1bcc520 100644 --- a/rustables/src/expr/wrapper.rs +++ b/rustables/src/expr/wrapper.rs @@ -3,10 +3,8 @@ use std::ffi::CString; use std::fmt::Debug; use std::rc::Rc; -use super::DeserializationError; -use super::Expression; -use crate::Rule; -use rustables_sys as sys; +use super::{DeserializationError, Expression}; +use crate::{sys, Rule}; pub struct ExpressionWrapper { pub(crate) expr: *const sys::nftnl_expr, diff --git a/rustables/src/lib.rs b/rustables/src/lib.rs index d94c753..6eedf9f 100644 --- a/rustables/src/lib.rs +++ b/rustables/src/lib.rs @@ -76,8 +76,8 @@ use thiserror::Error; #[macro_use] extern crate log; -pub use rustables_sys; -use rustables_sys::libc; +pub mod sys; +use sys::libc; use std::{convert::TryFrom, ffi::c_void, ops::Deref}; macro_rules! try_alloc { @@ -118,9 +118,6 @@ pub use rule::{get_rules_cb, list_rules_for_chain}; pub mod set; -#[cfg(test)] -mod tests; - /// The type of the message as it's sent to netfilter. A message consists of an object, such as a /// [`Table`], [`Chain`] or [`Rule`] for example, and a [`MsgType`] to describe what to do with /// that object. If a [`Table`] object is sent with `MsgType::Add` then that table will be added diff --git a/rustables/src/query.rs b/rustables/src/query.rs index 8d2c281..02c4082 100644 --- a/rustables/src/query.rs +++ b/rustables/src/query.rs @@ -1,4 +1,4 @@ -use crate::{nft_nlmsg_maxsize, rustables_sys as sys, ProtoFamily}; +use crate::{nft_nlmsg_maxsize, sys, ProtoFamily}; use sys::libc; /// Returns a buffer containing a netlink message which requests a list of all the netfilter diff --git a/rustables/src/rule.rs b/rustables/src/rule.rs index b315daf..c8cb90d 100644 --- a/rustables/src/rule.rs +++ b/rustables/src/rule.rs @@ -1,6 +1,6 @@ use crate::expr::ExpressionWrapper; use crate::{chain::Chain, expr::Expression, MsgType}; -use rustables_sys::{self as sys, libc}; +use crate::sys::{self, libc}; use std::ffi::{c_void, CStr, CString}; use std::fmt::Debug; use std::os::raw::c_char; diff --git a/rustables/src/set.rs b/rustables/src/set.rs index aef74db..d6b9514 100644 --- a/rustables/src/set.rs +++ b/rustables/src/set.rs @@ -1,5 +1,5 @@ use crate::{table::Table, MsgType, ProtoFamily}; -use rustables_sys::{self as sys, libc}; +use crate::sys::{self, libc}; use std::{ cell::Cell, ffi::{c_void, CStr, CString}, diff --git a/rustables/src/table.rs b/rustables/src/table.rs index 7cc475f..2f21453 100644 --- a/rustables/src/table.rs +++ b/rustables/src/table.rs @@ -1,5 +1,5 @@ use crate::{MsgType, ProtoFamily}; -use rustables_sys::{self as sys, libc}; +use crate::sys::{self, libc}; #[cfg(feature = "query")] use std::convert::TryFrom; use std::{ diff --git a/rustables/src/tests/mod.rs b/rustables/src/tests/mod.rs deleted file mode 100644 index 9a76606..0000000 --- a/rustables/src/tests/mod.rs +++ /dev/null @@ -1,702 +0,0 @@ -use crate::expr::{ - Bitwise, Cmp, CmpOp, Conntrack, Counter, Expression, HeaderField, IcmpCode, Immediate, Log, - LogGroup, LogPrefix, Lookup, Meta, Nat, NatType, Payload, Register, Reject, TcpHeaderField, - TransportHeaderField, -}; -use crate::set::Set; -use crate::{nft_nlmsg_maxsize, Chain, MsgType, NlMsg, ProtoFamily, Rule, Table}; -use rustables_sys::libc::{nlmsghdr, AF_UNIX, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES}; -use std::ffi::{c_void, CStr}; -use std::mem::size_of; -use std::net::Ipv4Addr; -use std::rc::Rc; -use thiserror::Error; - -mod bindings; -use bindings::*; - -fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 { - ((x & 0xff00) >> 8) as u8 -} -fn get_operation_from_nlmsghdr_type(x: u16) -> u8 { - (x & 0x00ff) as u8 -} - -const TABLE_NAME: &[u8; 10] = b"mocktable\0"; -const CHAIN_NAME: &[u8; 10] = b"mockchain\0"; - -type NetLinkType = u16; - -#[derive(Debug, PartialEq)] -enum NetlinkExpr { - Nested(NetLinkType, Vec), - Final(NetLinkType, Vec), - List(Vec), -} - -#[derive(Debug, Error)] -#[error("empty data")] -struct EmptyDataError; - -impl NetlinkExpr { - fn to_raw(self) -> Vec { - match self { - NetlinkExpr::Final(ty, val) => { - let len = val.len() + 4; - let mut res = Vec::with_capacity(len); - - res.extend(&(len as u16).to_le_bytes()); - res.extend(&ty.to_le_bytes()); - res.extend(val); - // alignment - while res.len() % 4 != 0 { - res.push(0); - } - - res - } - NetlinkExpr::Nested(ty, exprs) => { - // some heuristic to decrease allocations (even though this is - // only useful for testing so performance is not an objective) - let mut sub = Vec::with_capacity(exprs.len() * 50); - - for expr in exprs { - sub.append(&mut expr.to_raw()); - } - - let len = sub.len() + 4; - let mut res = Vec::with_capacity(len); - - // set the "NESTED" flag - res.extend(&(len as u16).to_le_bytes()); - res.extend(&(ty | 0x8000).to_le_bytes()); - res.extend(sub); - - res - } - NetlinkExpr::List(exprs) => { - // some heuristic to decrease allocations (even though this is - // only useful for testing so performance is not an objective) - let mut list = Vec::with_capacity(exprs.len() * 50); - - for expr in exprs { - list.append(&mut expr.to_raw()); - } - - list - } - } - } -} - -#[repr(C)] -#[derive(Clone, Copy)] -struct Nfgenmsg { - family: u8, /* AF_xxx */ - version: u8, /* nfnetlink version */ - res_id: u16, /* resource id */ -} - -fn get_test_rule() -> Rule { - let table = Rc::new(Table::new( - &CStr::from_bytes_with_nul(TABLE_NAME).unwrap(), - ProtoFamily::Inet, - )); - let chain = Rc::new(Chain::new( - &CStr::from_bytes_with_nul(CHAIN_NAME).unwrap(), - Rc::clone(&table), - )); - let rule = Rule::new(Rc::clone(&chain)); - rule -} - -fn get_test_nlmsg_from_expr( - rule: &mut Rule, - expr: &impl Expression, -) -> (nlmsghdr, Nfgenmsg, Vec) { - rule.add_expr(expr); - - let mut buf = vec![0u8; nft_nlmsg_maxsize() as usize]; - unsafe { - rule.write(buf.as_mut_ptr() as *mut c_void, 0, MsgType::Add); - - // right now the message is composed of the following parts: - // - nlmsghdr (contains the message size and type) - // - nfgenmsg (nftables header that describes the message family) - // - the raw expression that we want to validate - - let size_of_hdr = size_of::(); - let size_of_nfgenmsg = size_of::(); - let nlmsghdr = *(buf[0..size_of_hdr].as_ptr() as *const nlmsghdr); - let nfgenmsg = - *(buf[size_of_hdr..size_of_hdr + size_of_nfgenmsg].as_ptr() as *const Nfgenmsg); - let raw_expr = buf[size_of_hdr + size_of_nfgenmsg..nlmsghdr.nlmsg_len as usize] - .iter() - .map(|x| *x) - .collect(); - - // sanity checks on the global message (this should be very similar/factorisable for the - // most part in other tests) - // TODO: check the messages flags - assert_eq!( - get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type), - NFNL_SUBSYS_NFTABLES as u8 - ); - assert_eq!( - get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), - NFT_MSG_NEWRULE as u8 - ); - assert_eq!(nlmsghdr.nlmsg_seq, 0); - assert_eq!(nlmsghdr.nlmsg_pid, 0); - assert_eq!(nfgenmsg.family, AF_UNIX as u8); - assert_eq!(nfgenmsg.version, NFNETLINK_V0 as u8); - assert_eq!(nfgenmsg.res_id.to_be(), 0); - - (nlmsghdr, nfgenmsg, raw_expr) - } -} - -#[test] -fn bitwise_expr_is_valid() { - let netmask = Ipv4Addr::new(255, 255, 255, 0); - let bitwise = Bitwise::new(netmask, 0); - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &bitwise); - assert_eq!(nlmsghdr.nlmsg_len, 124); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), - NetlinkExpr::Nested( - NFTA_RULE_EXPRESSIONS, - vec![NetlinkExpr::Nested( - NFTA_LIST_ELEM, - vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"bitwise\0".to_vec()), - NetlinkExpr::Nested( - NFTA_EXPR_DATA, - vec![ - NetlinkExpr::Final( - NFTA_BITWISE_SREG, - NFT_REG_1.to_be_bytes().to_vec() - ), - NetlinkExpr::Final( - NFTA_BITWISE_DREG, - NFT_REG_1.to_be_bytes().to_vec() - ), - NetlinkExpr::Final(NFTA_BITWISE_LEN, 4u32.to_be_bytes().to_vec()), - NetlinkExpr::Nested( - NFTA_BITWISE_MASK, - vec![NetlinkExpr::Final( - NFTA_DATA_VALUE, - vec![255, 255, 255, 0] - )] - ), - NetlinkExpr::Nested( - NFTA_BITWISE_XOR, - vec![NetlinkExpr::Final( - NFTA_DATA_VALUE, - 0u32.to_be_bytes().to_vec() - )] - ) - ] - ) - ] - )] - ) - ]) - .to_raw() - ); -} - -#[test] -fn cmp_expr_is_valid() { - let cmp = Cmp::new(CmpOp::Eq, 0); - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &cmp); - assert_eq!(nlmsghdr.nlmsg_len, 100); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), - NetlinkExpr::Nested( - NFTA_RULE_EXPRESSIONS, - vec![NetlinkExpr::Nested( - NFTA_LIST_ELEM, - vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"cmp\0".to_vec()), - NetlinkExpr::Nested( - NFTA_EXPR_DATA, - vec![ - NetlinkExpr::Final(NFTA_CMP_SREG, NFT_REG_1.to_be_bytes().to_vec()), - NetlinkExpr::Final(NFTA_CMP_OP, NFT_CMP_EQ.to_be_bytes().to_vec()), - NetlinkExpr::Nested( - NFTA_CMP_DATA, - vec![NetlinkExpr::Final(1u16, 0u32.to_be_bytes().to_vec())] - ) - ] - ) - ] - )] - ) - ]) - .to_raw() - ); -} - -#[test] -fn counter_expr_is_valid() { - let nb_bytes = 123456u64; - let nb_packets = 987u64; - let mut counter = Counter::new(); - counter.nb_bytes = nb_bytes; - counter.nb_packets = nb_packets; - - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &counter); - assert_eq!(nlmsghdr.nlmsg_len, 100); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), - NetlinkExpr::Nested( - NFTA_RULE_EXPRESSIONS, - vec![NetlinkExpr::Nested( - NFTA_LIST_ELEM, - vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"counter\0".to_vec()), - NetlinkExpr::Nested( - NFTA_EXPR_DATA, - vec![ - NetlinkExpr::Final( - NFTA_COUNTER_BYTES, - nb_bytes.to_be_bytes().to_vec() - ), - NetlinkExpr::Final( - NFTA_COUNTER_PACKETS, - nb_packets.to_be_bytes().to_vec() - ) - ] - ) - ] - )] - ) - ]) - .to_raw() - ); -} - -#[test] -fn ct_expr_is_valid() { - let ct = Conntrack::State; - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &ct); - assert_eq!(nlmsghdr.nlmsg_len, 88); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), - NetlinkExpr::Nested( - NFTA_RULE_EXPRESSIONS, - vec![NetlinkExpr::Nested( - NFTA_LIST_ELEM, - vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"ct\0".to_vec()), - NetlinkExpr::Nested( - NFTA_EXPR_DATA, - vec![ - NetlinkExpr::Final( - NFTA_CT_KEY, - NFT_CT_STATE.to_be_bytes().to_vec() - ), - NetlinkExpr::Final(NFTA_CT_DREG, NFT_REG_1.to_be_bytes().to_vec()) - ] - ) - ] - )] - ) - ]) - .to_raw() - ) -} - -#[test] -fn immediate_expr_is_valid() { - let immediate = Immediate::new(42u8, Register::Reg1); - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &immediate); - assert_eq!(nlmsghdr.nlmsg_len, 100); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), - NetlinkExpr::Nested( - NFTA_RULE_EXPRESSIONS, - vec![NetlinkExpr::Nested( - NFTA_LIST_ELEM, - vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"immediate\0".to_vec()), - NetlinkExpr::Nested( - NFTA_EXPR_DATA, - vec![ - NetlinkExpr::Final( - NFTA_IMMEDIATE_DREG, - NFT_REG_1.to_be_bytes().to_vec() - ), - NetlinkExpr::Nested( - NFTA_IMMEDIATE_DATA, - vec![NetlinkExpr::Final(1u16, 42u8.to_be_bytes().to_vec())] - ) - ] - ) - ] - )] - ) - ]) - .to_raw() - ); -} - -#[test] -fn log_expr_is_valid() { - let log = Log { - group: Some(LogGroup(1)), - prefix: Some(LogPrefix::new("mockprefix").unwrap()), - }; - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &log); - assert_eq!(nlmsghdr.nlmsg_len, 96); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), - NetlinkExpr::Nested( - NFTA_RULE_EXPRESSIONS, - vec![NetlinkExpr::Nested( - NFTA_LIST_ELEM, - vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"log\0".to_vec()), - NetlinkExpr::Nested( - NFTA_EXPR_DATA, - vec![ - NetlinkExpr::Final(NFTA_LOG_PREFIX, b"mockprefix\0".to_vec()), - NetlinkExpr::Final(NFTA_LOG_GROUP, 1u16.to_be_bytes().to_vec()) - ] - ) - ] - )] - ) - ]) - .to_raw() - ); -} - -#[test] -fn lookup_expr_is_valid() { - let set_name = &CStr::from_bytes_with_nul(b"mockset\0").unwrap(); - let mut rule = get_test_rule(); - let table = rule.get_chain().get_table(); - let mut set = Set::new(set_name, 0, &table, ProtoFamily::Inet); - let address: Ipv4Addr = [8, 8, 8, 8].into(); - set.add(&address); - let lookup = Lookup::new(&set).unwrap(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &lookup); - assert_eq!(nlmsghdr.nlmsg_len, 104); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), - NetlinkExpr::Nested( - NFTA_RULE_EXPRESSIONS, - vec![NetlinkExpr::Nested( - NFTA_LIST_ELEM, - vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"lookup\0".to_vec()), - NetlinkExpr::Nested( - NFTA_EXPR_DATA, - vec![ - NetlinkExpr::Final( - NFTA_LOOKUP_SREG, - NFT_REG_1.to_be_bytes().to_vec() - ), - NetlinkExpr::Final(NFTA_LOOKUP_SET, b"mockset\0".to_vec()), - NetlinkExpr::Final(NFTA_LOOKUP_SET_ID, 0u32.to_be_bytes().to_vec()), - ] - ) - ] - )] - ) - ]) - .to_raw() - ); -} - -use crate::expr::Masquerade; -#[test] -fn masquerade_expr_is_valid() { - let masquerade = Masquerade; - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &masquerade); - assert_eq!(nlmsghdr.nlmsg_len, 76); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), - NetlinkExpr::Nested( - NFTA_RULE_EXPRESSIONS, - vec![NetlinkExpr::Nested( - NFTA_LIST_ELEM, - vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"masq\0".to_vec()), - NetlinkExpr::Nested(NFTA_EXPR_DATA, vec![]), - ] - )] - ) - ]) - .to_raw() - ); -} - -#[test] -fn meta_expr_is_valid() { - let meta = Meta::Protocol; - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &meta); - assert_eq!(nlmsghdr.nlmsg_len, 92); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), - NetlinkExpr::Nested( - NFTA_RULE_EXPRESSIONS, - vec![NetlinkExpr::Nested( - NFTA_LIST_ELEM, - vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"meta\0".to_vec()), - NetlinkExpr::Nested( - NFTA_EXPR_DATA, - vec![ - NetlinkExpr::Final( - NFTA_META_KEY, - NFT_META_PROTOCOL.to_be_bytes().to_vec() - ), - NetlinkExpr::Final( - NFTA_META_DREG, - NFT_REG_1.to_be_bytes().to_vec() - ) - ] - ) - ] - )] - ) - ]) - .to_raw() - ); -} - -#[test] -fn nat_expr_is_valid() { - let nat = Nat { - nat_type: NatType::SNat, - family: ProtoFamily::Ipv4, - ip_register: Register::Reg1, - port_register: None, - }; - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &nat); - assert_eq!(nlmsghdr.nlmsg_len, 96); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), - NetlinkExpr::Nested( - NFTA_RULE_EXPRESSIONS, - vec![NetlinkExpr::Nested( - NFTA_LIST_ELEM, - vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"nat\0".to_vec()), - NetlinkExpr::Nested( - NFTA_EXPR_DATA, - vec![ - NetlinkExpr::Final( - NFTA_NAT_TYPE, - NFT_NAT_SNAT.to_be_bytes().to_vec() - ), - NetlinkExpr::Final( - NFTA_NAT_FAMILY, - // TODO find the right value to substitute here. - //(ProtoFamily::Ipv4 as u16).to_le_bytes().to_vec() - 2u32.to_be_bytes().to_vec() - ), - NetlinkExpr::Final( - NFTA_NAT_REG_ADDR_MIN, - NFT_REG_1.to_be_bytes().to_vec() - ) - ] - ) - ] - )] - ) - ]) - .to_raw() - ); -} - -#[test] -fn payload_expr_is_valid() { - // TODO test loaded payload ? - let tcp_header_field = TcpHeaderField::Sport; - let transport_header_field = TransportHeaderField::Tcp(tcp_header_field); - let payload = Payload::Transport(transport_header_field); - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &payload); - assert_eq!(nlmsghdr.nlmsg_len, 108); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), - NetlinkExpr::Nested( - NFTA_RULE_EXPRESSIONS, - vec![NetlinkExpr::Nested( - NFTA_LIST_ELEM, - vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"payload\0".to_vec()), - NetlinkExpr::Nested( - NFTA_EXPR_DATA, - vec![ - NetlinkExpr::Final( - NFTA_PAYLOAD_DREG, - NFT_REG_1.to_be_bytes().to_vec() - ), - NetlinkExpr::Final( - NFTA_PAYLOAD_BASE, - NFT_PAYLOAD_TRANSPORT_HEADER.to_be_bytes().to_vec() - ), - NetlinkExpr::Final( - NFTA_PAYLOAD_OFFSET, - tcp_header_field.offset().to_be_bytes().to_vec() - ), - NetlinkExpr::Final( - NFTA_PAYLOAD_LEN, - //tcp_header_field.len().to_be_bytes().to_vec() - 0u32.to_be_bytes().to_vec() - ), - ] - ) - ] - )] - ) - ]) - .to_raw() - ); -} - -#[test] -fn reject_expr_is_valid() { - let code = IcmpCode::NoRoute; - let reject = Reject::Icmp(code); - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &reject); - assert_eq!(nlmsghdr.nlmsg_len, 92); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), - NetlinkExpr::Nested( - NFTA_RULE_EXPRESSIONS, - vec![NetlinkExpr::Nested( - NFTA_LIST_ELEM, - vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"reject\0".to_vec()), - NetlinkExpr::Nested( - NFTA_EXPR_DATA, - vec![ - NetlinkExpr::Final( - NFTA_REJECT_TYPE, - NFT_REJECT_ICMPX_UNREACH.to_be_bytes().to_vec() - ), - NetlinkExpr::Final( - NFTA_REJECT_ICMP_CODE, - (code as u8).to_be_bytes().to_vec() - ), - ] - ) - ] - )] - ) - ]) - .to_raw() - ); -} - -use crate::expr::Verdict; -use rustables_sys::libc::NF_DROP; -#[test] -fn verdict_expr_is_valid() { - let verdict = Verdict::Drop; - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &verdict); - assert_eq!(nlmsghdr.nlmsg_len, 104); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), - NetlinkExpr::Nested( - NFTA_RULE_EXPRESSIONS, - vec![NetlinkExpr::Nested( - NFTA_LIST_ELEM, - vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"immediate\0".to_vec()), - // TODO find the right arrangement for Verdict's data. - NetlinkExpr::Nested( - NFTA_EXPR_DATA, - vec![ - NetlinkExpr::Final( - NFTA_IMMEDIATE_DREG, - NFT_REG_VERDICT.to_be_bytes().to_vec() - ), - NetlinkExpr::Nested( - NFTA_IMMEDIATE_DATA, - vec![NetlinkExpr::Nested( - NFTA_DATA_VERDICT, - vec![NetlinkExpr::Final( - NFTA_VERDICT_CODE, - NF_DROP.to_be_bytes().to_vec() //0u32.to_be_bytes().to_vec() - ),] - )], - ), - ] - ) - ] - )] - ) - ]) - .to_raw() - ); -} diff --git a/rustables/tests/expr.rs b/rustables/tests/expr.rs new file mode 100644 index 0000000..5c27119 --- /dev/null +++ b/rustables/tests/expr.rs @@ -0,0 +1,700 @@ +use rustables::{nft_nlmsg_maxsize, Chain, MsgType, NlMsg, ProtoFamily, Rule, Table}; +use rustables::set::Set; +use rustables::sys::libc::{nlmsghdr, AF_UNIX, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES, NF_DROP}; +use rustables::expr::{ + Bitwise, Cmp, CmpOp, Conntrack, Counter, Expression, HeaderField, IcmpCode, Immediate, Log, + LogGroup, LogPrefix, Lookup, Meta, Nat, NatType, Payload, Register, Reject, TcpHeaderField, + TransportHeaderField, Verdict +}; +use std::ffi::{c_void, CStr}; +use std::mem::size_of; +use std::net::Ipv4Addr; +use std::rc::Rc; +use thiserror::Error; + +mod sys; +use sys::*; + +fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 { + ((x & 0xff00) >> 8) as u8 +} +fn get_operation_from_nlmsghdr_type(x: u16) -> u8 { + (x & 0x00ff) as u8 +} + +const TABLE_NAME: &[u8; 10] = b"mocktable\0"; +const CHAIN_NAME: &[u8; 10] = b"mockchain\0"; + +type NetLinkType = u16; + +#[derive(Debug, PartialEq)] +enum NetlinkExpr { + Nested(NetLinkType, Vec), + Final(NetLinkType, Vec), + List(Vec), +} + +#[derive(Debug, Error)] +#[error("empty data")] +struct EmptyDataError; + +impl NetlinkExpr { + fn to_raw(self) -> Vec { + match self { + NetlinkExpr::Final(ty, val) => { + let len = val.len() + 4; + let mut res = Vec::with_capacity(len); + + res.extend(&(len as u16).to_le_bytes()); + res.extend(&ty.to_le_bytes()); + res.extend(val); + // alignment + while res.len() % 4 != 0 { + res.push(0); + } + + res + } + NetlinkExpr::Nested(ty, exprs) => { + // some heuristic to decrease allocations (even though this is + // only useful for testing so performance is not an objective) + let mut sub = Vec::with_capacity(exprs.len() * 50); + + for expr in exprs { + sub.append(&mut expr.to_raw()); + } + + let len = sub.len() + 4; + let mut res = Vec::with_capacity(len); + + // set the "NESTED" flag + res.extend(&(len as u16).to_le_bytes()); + res.extend(&(ty | 0x8000).to_le_bytes()); + res.extend(sub); + + res + } + NetlinkExpr::List(exprs) => { + // some heuristic to decrease allocations (even though this is + // only useful for testing so performance is not an objective) + let mut list = Vec::with_capacity(exprs.len() * 50); + + for expr in exprs { + list.append(&mut expr.to_raw()); + } + + list + } + } + } +} + +#[repr(C)] +#[derive(Clone, Copy)] +struct Nfgenmsg { + family: u8, /* AF_xxx */ + version: u8, /* nfnetlink version */ + res_id: u16, /* resource id */ +} + +fn get_test_rule() -> Rule { + let table = Rc::new(Table::new( + &CStr::from_bytes_with_nul(TABLE_NAME).unwrap(), + ProtoFamily::Inet, + )); + let chain = Rc::new(Chain::new( + &CStr::from_bytes_with_nul(CHAIN_NAME).unwrap(), + Rc::clone(&table), + )); + let rule = Rule::new(Rc::clone(&chain)); + rule +} + +fn get_test_nlmsg_from_expr( + rule: &mut Rule, + expr: &impl Expression, +) -> (nlmsghdr, Nfgenmsg, Vec) { + rule.add_expr(expr); + + let mut buf = vec![0u8; nft_nlmsg_maxsize() as usize]; + unsafe { + rule.write(buf.as_mut_ptr() as *mut c_void, 0, MsgType::Add); + + // right now the message is composed of the following parts: + // - nlmsghdr (contains the message size and type) + // - nfgenmsg (nftables header that describes the message family) + // - the raw expression that we want to validate + + let size_of_hdr = size_of::(); + let size_of_nfgenmsg = size_of::(); + let nlmsghdr = *(buf[0..size_of_hdr].as_ptr() as *const nlmsghdr); + let nfgenmsg = + *(buf[size_of_hdr..size_of_hdr + size_of_nfgenmsg].as_ptr() as *const Nfgenmsg); + let raw_expr = buf[size_of_hdr + size_of_nfgenmsg..nlmsghdr.nlmsg_len as usize] + .iter() + .map(|x| *x) + .collect(); + + // sanity checks on the global message (this should be very similar/factorisable for the + // most part in other tests) + // TODO: check the messages flags + assert_eq!( + get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFNL_SUBSYS_NFTABLES as u8 + ); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWRULE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_seq, 0); + assert_eq!(nlmsghdr.nlmsg_pid, 0); + assert_eq!(nfgenmsg.family, AF_UNIX as u8); + assert_eq!(nfgenmsg.version, NFNETLINK_V0 as u8); + assert_eq!(nfgenmsg.res_id.to_be(), 0); + + (nlmsghdr, nfgenmsg, raw_expr) + } +} + +#[test] +fn bitwise_expr_is_valid() { + let netmask = Ipv4Addr::new(255, 255, 255, 0); + let bitwise = Bitwise::new(netmask, 0); + let mut rule = get_test_rule(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &bitwise); + assert_eq!(nlmsghdr.nlmsg_len, 124); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"bitwise\0".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_BITWISE_SREG, + NFT_REG_1.to_be_bytes().to_vec() + ), + NetlinkExpr::Final( + NFTA_BITWISE_DREG, + NFT_REG_1.to_be_bytes().to_vec() + ), + NetlinkExpr::Final(NFTA_BITWISE_LEN, 4u32.to_be_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_BITWISE_MASK, + vec![NetlinkExpr::Final( + NFTA_DATA_VALUE, + vec![255, 255, 255, 0] + )] + ), + NetlinkExpr::Nested( + NFTA_BITWISE_XOR, + vec![NetlinkExpr::Final( + NFTA_DATA_VALUE, + 0u32.to_be_bytes().to_vec() + )] + ) + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn cmp_expr_is_valid() { + let cmp = Cmp::new(CmpOp::Eq, 0); + let mut rule = get_test_rule(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &cmp); + assert_eq!(nlmsghdr.nlmsg_len, 100); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"cmp\0".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final(NFTA_CMP_SREG, NFT_REG_1.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_CMP_OP, NFT_CMP_EQ.to_be_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_CMP_DATA, + vec![NetlinkExpr::Final(1u16, 0u32.to_be_bytes().to_vec())] + ) + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn counter_expr_is_valid() { + let nb_bytes = 123456u64; + let nb_packets = 987u64; + let mut counter = Counter::new(); + counter.nb_bytes = nb_bytes; + counter.nb_packets = nb_packets; + + let mut rule = get_test_rule(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &counter); + assert_eq!(nlmsghdr.nlmsg_len, 100); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"counter\0".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_COUNTER_BYTES, + nb_bytes.to_be_bytes().to_vec() + ), + NetlinkExpr::Final( + NFTA_COUNTER_PACKETS, + nb_packets.to_be_bytes().to_vec() + ) + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn ct_expr_is_valid() { + let ct = Conntrack::State; + let mut rule = get_test_rule(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &ct); + assert_eq!(nlmsghdr.nlmsg_len, 88); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"ct\0".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_CT_KEY, + NFT_CT_STATE.to_be_bytes().to_vec() + ), + NetlinkExpr::Final(NFTA_CT_DREG, NFT_REG_1.to_be_bytes().to_vec()) + ] + ) + ] + )] + ) + ]) + .to_raw() + ) +} + +#[test] +fn immediate_expr_is_valid() { + let immediate = Immediate::new(42u8, Register::Reg1); + let mut rule = get_test_rule(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &immediate); + assert_eq!(nlmsghdr.nlmsg_len, 100); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"immediate\0".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_IMMEDIATE_DREG, + NFT_REG_1.to_be_bytes().to_vec() + ), + NetlinkExpr::Nested( + NFTA_IMMEDIATE_DATA, + vec![NetlinkExpr::Final(1u16, 42u8.to_be_bytes().to_vec())] + ) + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn log_expr_is_valid() { + let log = Log { + group: Some(LogGroup(1)), + prefix: Some(LogPrefix::new("mockprefix").unwrap()), + }; + let mut rule = get_test_rule(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &log); + assert_eq!(nlmsghdr.nlmsg_len, 96); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"log\0".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final(NFTA_LOG_PREFIX, b"mockprefix\0".to_vec()), + NetlinkExpr::Final(NFTA_LOG_GROUP, 1u16.to_be_bytes().to_vec()) + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn lookup_expr_is_valid() { + let set_name = &CStr::from_bytes_with_nul(b"mockset\0").unwrap(); + let mut rule = get_test_rule(); + let table = rule.get_chain().get_table(); + let mut set = Set::new(set_name, 0, &table, ProtoFamily::Inet); + let address: Ipv4Addr = [8, 8, 8, 8].into(); + set.add(&address); + let lookup = Lookup::new(&set).unwrap(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &lookup); + assert_eq!(nlmsghdr.nlmsg_len, 104); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"lookup\0".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_LOOKUP_SREG, + NFT_REG_1.to_be_bytes().to_vec() + ), + NetlinkExpr::Final(NFTA_LOOKUP_SET, b"mockset\0".to_vec()), + NetlinkExpr::Final(NFTA_LOOKUP_SET_ID, 0u32.to_be_bytes().to_vec()), + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +use rustables::expr::Masquerade; +#[test] +fn masquerade_expr_is_valid() { + let masquerade = Masquerade; + let mut rule = get_test_rule(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &masquerade); + assert_eq!(nlmsghdr.nlmsg_len, 76); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"masq\0".to_vec()), + NetlinkExpr::Nested(NFTA_EXPR_DATA, vec![]), + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn meta_expr_is_valid() { + let meta = Meta::Protocol; + let mut rule = get_test_rule(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &meta); + assert_eq!(nlmsghdr.nlmsg_len, 92); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"meta\0".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_META_KEY, + NFT_META_PROTOCOL.to_be_bytes().to_vec() + ), + NetlinkExpr::Final( + NFTA_META_DREG, + NFT_REG_1.to_be_bytes().to_vec() + ) + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn nat_expr_is_valid() { + let nat = Nat { + nat_type: NatType::SNat, + family: ProtoFamily::Ipv4, + ip_register: Register::Reg1, + port_register: None, + }; + let mut rule = get_test_rule(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &nat); + assert_eq!(nlmsghdr.nlmsg_len, 96); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"nat\0".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_NAT_TYPE, + NFT_NAT_SNAT.to_be_bytes().to_vec() + ), + NetlinkExpr::Final( + NFTA_NAT_FAMILY, + // TODO find the right value to substitute here. + //(ProtoFamily::Ipv4 as u16).to_le_bytes().to_vec() + 2u32.to_be_bytes().to_vec() + ), + NetlinkExpr::Final( + NFTA_NAT_REG_ADDR_MIN, + NFT_REG_1.to_be_bytes().to_vec() + ) + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn payload_expr_is_valid() { + // TODO test loaded payload ? + let tcp_header_field = TcpHeaderField::Sport; + let transport_header_field = TransportHeaderField::Tcp(tcp_header_field); + let payload = Payload::Transport(transport_header_field); + let mut rule = get_test_rule(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &payload); + assert_eq!(nlmsghdr.nlmsg_len, 108); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"payload\0".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_PAYLOAD_DREG, + NFT_REG_1.to_be_bytes().to_vec() + ), + NetlinkExpr::Final( + NFTA_PAYLOAD_BASE, + NFT_PAYLOAD_TRANSPORT_HEADER.to_be_bytes().to_vec() + ), + NetlinkExpr::Final( + NFTA_PAYLOAD_OFFSET, + tcp_header_field.offset().to_be_bytes().to_vec() + ), + NetlinkExpr::Final( + NFTA_PAYLOAD_LEN, + //tcp_header_field.len().to_be_bytes().to_vec() + 0u32.to_be_bytes().to_vec() + ), + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn reject_expr_is_valid() { + let code = IcmpCode::NoRoute; + let reject = Reject::Icmp(code); + let mut rule = get_test_rule(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &reject); + assert_eq!(nlmsghdr.nlmsg_len, 92); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"reject\0".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_REJECT_TYPE, + NFT_REJECT_ICMPX_UNREACH.to_be_bytes().to_vec() + ), + NetlinkExpr::Final( + NFTA_REJECT_ICMP_CODE, + (code as u8).to_be_bytes().to_vec() + ), + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn verdict_expr_is_valid() { + let verdict = Verdict::Drop; + let mut rule = get_test_rule(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &verdict); + assert_eq!(nlmsghdr.nlmsg_len, 104); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"immediate\0".to_vec()), + // TODO find the right arrangement for Verdict's data. + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_IMMEDIATE_DREG, + NFT_REG_VERDICT.to_be_bytes().to_vec() + ), + NetlinkExpr::Nested( + NFTA_IMMEDIATE_DATA, + vec![NetlinkExpr::Nested( + NFTA_DATA_VERDICT, + vec![NetlinkExpr::Final( + NFTA_VERDICT_CODE, + NF_DROP.to_be_bytes().to_vec() //0u32.to_be_bytes().to_vec() + ),] + )], + ), + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} diff --git a/rustables/tests_wrapper.h b/rustables/tests_wrapper.h new file mode 100644 index 0000000..8f976e8 --- /dev/null +++ b/rustables/tests_wrapper.h @@ -0,0 +1 @@ +#include "linux/netfilter/nf_tables.h" diff --git a/rustables/wrapper.h b/rustables/wrapper.h index 8f976e8..e6eb221 100644 --- a/rustables/wrapper.h +++ b/rustables/wrapper.h @@ -1 +1,12 @@ -#include "linux/netfilter/nf_tables.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include -- cgit v1.2.3