From 22edb0197854bf4f504e833e69b0e545d382f065 Mon Sep 17 00:00:00 2001 From: Simon THOBY Date: Sun, 13 Nov 2022 18:32:22 +0100 Subject: wip: exprs --- src/expr/log.rs | 112 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 57 insertions(+), 55 deletions(-) (limited to 'src/expr/log.rs') diff --git a/src/expr/log.rs b/src/expr/log.rs index 8d20b48..cf50cb2 100644 --- a/src/expr/log.rs +++ b/src/expr/log.rs @@ -1,21 +1,61 @@ -use super::{DeserializationError, Expression, Rule}; +use super::{Expression, ExpressionError}; +use crate::create_expr_type; +use crate::nlmsg::NfNetlinkAttributes; use crate::sys; -use std::ffi::{CStr, CString}; -use std::os::raw::c_char; -use thiserror::Error; -/// A Log expression will log all packets that match the rule. -#[derive(Debug, PartialEq)] -pub struct Log { - pub group: Option, - pub prefix: Option, +// A Log expression will log all packets that match the rule. +create_expr_type!( + inline with_builder : Log, + [ + ( + get_group, + set_group, + with_group, + sys::NFTA_LOG_GROUP, + U32, + u32 + ), + ( + get_prefix, + set_prefix, + with_prefix, + sys::NFTA_LOG_PREFIX, + String, + String + ) + ] +); + +impl Log { + pub fn new( + group: Option, + prefix: Option>, + ) -> Result { + let mut res = Log { + inner: NfNetlinkAttributes::new(), + //pub group: Option, + //pub prefix: Option, + }; + if let Some(group) = group { + res.set_group(group); + } + if let Some(prefix) = prefix { + let prefix = prefix.into(); + + if prefix.bytes().count() > 127 { + return Err(ExpressionError::TooLongLogPrefix); + } + res.set_prefix(prefix); + } + Ok(res) + } } impl Expression for Log { - fn get_raw_name() -> *const sys::libc::c_char { - b"log\0" as *const _ as *const c_char + fn get_name() -> &'static str { + "log" } - + /* fn from_expr(expr: *const sys::nftnl_expr) -> Result where Self: Sized, @@ -54,59 +94,21 @@ impl Expression for Log { expr } } -} - -#[derive(Error, Debug)] -pub enum LogPrefixError { - #[error("The log prefix string is more than 128 characters long")] - TooLongPrefix, - #[error("The log prefix string contains an invalid Nul character.")] - PrefixContainsANul(#[from] std::ffi::NulError), -} - -/// The NFLOG group that will be assigned to each log line. -#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] -pub struct LogGroup(pub u16); - -/// A prefix that will get prepended to each log line. -#[derive(Debug, Clone, PartialEq)] -pub struct LogPrefix(CString); - -impl LogPrefix { - /// Creates a new LogPrefix from a String. Converts it to CString as needed by nftnl. Note that - /// LogPrefix should not be more than 127 characters long. - pub fn new(prefix: &str) -> Result { - if prefix.chars().count() > 127 { - return Err(LogPrefixError::TooLongPrefix); - } - Ok(LogPrefix(CString::new(prefix)?)) - } + */ } #[macro_export] macro_rules! nft_expr_log { (group $group:ident prefix $prefix:expr) => { - $crate::expr::Log { - group: $group, - prefix: $prefix, - } + $crate::expr::Log::new(Some($group), Some($prefix)) }; (prefix $prefix:expr) => { - $crate::expr::Log { - group: None, - prefix: $prefix, - } + $crate::expr::Log::new(None, Some($prefix)) }; (group $group:ident) => { - $crate::expr::Log { - group: $group, - prefix: None, - } + $crate::expr::Log::new(Some($group), None) }; () => { - $crate::expr::Log { - group: None, - prefix: None, - } + $crate::expr::Log::new(None, None) }; } -- cgit v1.2.3 From 9ff02d4e40113ae10b6244a8a3d94c6e0bad5427 Mon Sep 17 00:00:00 2001 From: Simon THOBY Date: Fri, 2 Dec 2022 23:58:52 +0100 Subject: refactor to remove the enum AttributeType --- Cargo.nix | 150 +++++++++----------- examples/add-rules.rs | 50 ++----- include/tests_wrapper.h | 0 src/chain.rs | 133 ++++++++---------- src/chain_methods.rs | 40 ------ src/expr/bitwise.rs | 39 +++--- src/expr/immediate.rs | 14 +- src/expr/log.rs | 73 +--------- src/expr/meta.rs | 230 +++++++++++------------------- src/expr/mod.rs | 175 +++++++---------------- src/expr/verdict.rs | 18 +-- src/lib.rs | 24 ++-- src/nlmsg.rs | 17 +-- src/parser.rs | 366 ++++++++++++++++++++++++++---------------------- src/query.rs | 15 +- src/rule.rs | 191 +++++++------------------ src/table.rs | 88 +++++------- tests/batch.rs | 24 +--- tests/chain.rs | 15 +- tests/common.rs | 199 ++++++++++++++++++++++++++ tests/expr.rs | 97 +++++++------ tests/lib.rs | 199 -------------------------- tests/rule.rs | 201 +++++++++++++------------- tests/table.rs | 7 +- 24 files changed, 997 insertions(+), 1368 deletions(-) delete mode 100644 include/tests_wrapper.h delete mode 100644 src/chain_methods.rs create mode 100644 tests/common.rs delete mode 100644 tests/lib.rs (limited to 'src/expr/log.rs') diff --git a/Cargo.nix b/Cargo.nix index d5647cb..667e886 100644 --- a/Cargo.nix +++ b/Cargo.nix @@ -204,7 +204,7 @@ rec { } { name = "env_logger"; - packageId = "env_logger"; + packageId = "env_logger 0.7.1"; optional = true; } { @@ -459,7 +459,7 @@ rec { }; resolvedDefaultFeatures = [ "ansi_term" "atty" "color" "default" "strsim" "suggestions" "vec_map" ]; }; - "env_logger" = rec { + "env_logger 0.7.1" = rec { crateName = "env_logger"; version = "0.7.1"; edition = "2018"; @@ -475,7 +475,7 @@ rec { } { name = "humantime"; - packageId = "humantime"; + packageId = "humantime 1.3.0"; optional = true; } { @@ -503,6 +503,49 @@ rec { }; resolvedDefaultFeatures = [ "atty" "default" "humantime" "regex" "termcolor" ]; }; + "env_logger 0.9.3" = rec { + crateName = "env_logger"; + version = "0.9.3"; + edition = "2018"; + sha256 = "1rq0kqpa8my6i1qcyhfqrn1g9xr5fbkwwbd42nqvlzn9qibncbm1"; + dependencies = [ + { + name = "atty"; + packageId = "atty"; + optional = true; + } + { + name = "humantime"; + packageId = "humantime 2.1.0"; + optional = true; + } + { + name = "log"; + packageId = "log"; + features = [ "std" ]; + } + { + name = "regex"; + packageId = "regex"; + optional = true; + usesDefaultFeatures = false; + features = [ "std" "perf" ]; + } + { + name = "termcolor"; + packageId = "termcolor"; + optional = true; + } + ]; + features = { + "atty" = [ "dep:atty" ]; + "default" = [ "termcolor" "atty" "humantime" "regex" ]; + "humantime" = [ "dep:humantime" ]; + "regex" = [ "dep:regex" ]; + "termcolor" = [ "dep:termcolor" ]; + }; + resolvedDefaultFeatures = [ "atty" "default" "humantime" "regex" "termcolor" ]; + }; "glob" = rec { crateName = "glob"; version = "0.3.0"; @@ -535,7 +578,7 @@ rec { }; resolvedDefaultFeatures = [ "default" ]; }; - "humantime" = rec { + "humantime 1.3.0" = rec { crateName = "humantime"; version = "1.3.0"; edition = "2015"; @@ -550,30 +593,31 @@ rec { } ]; + }; + "humantime 2.1.0" = rec { + crateName = "humantime"; + version = "2.1.0"; + edition = "2018"; + sha256 = "1r55pfkkf5v0ji1x6izrjwdq9v6sc7bv99xj6srywcar37xmnfls"; + authors = [ + "Paul Colomiets " + ]; + }; "ipnetwork" = rec { crateName = "ipnetwork"; - version = "0.16.0"; - edition = "2018"; - sha256 = "07nkh9djfmkkwd0phkgrv977kfmvw4hmrn1xxw4cjyx23psskv5q"; + version = "0.20.0"; + edition = "2021"; + sha256 = "03hhmxyimz0800z44wl3z1ak8iw91xcnk7sgx5p5jinmx50naimz"; authors = [ "Abhishek Chanda " "Linus Färnstrand " ]; - dependencies = [ - { - name = "serde"; - packageId = "serde"; - optional = true; - } - ]; features = { - "clippy" = [ "dep:clippy" ]; "default" = [ "serde" ]; - "dev" = [ "clippy" ]; + "schemars" = [ "dep:schemars" ]; "serde" = [ "dep:serde" ]; }; - resolvedDefaultFeatures = [ "default" "serde" ]; }; "lazy_static" = rec { crateName = "lazy_static"; @@ -915,6 +959,7 @@ rec { { name = "ipnetwork"; packageId = "ipnetwork"; + usesDefaultFeatures = false; } { name = "libc"; @@ -928,11 +973,6 @@ rec { name = "nix"; packageId = "nix"; } - { - name = "serde"; - packageId = "serde"; - features = [ "derive" ]; - } { name = "thiserror"; packageId = "thiserror"; @@ -954,14 +994,11 @@ rec { ]; devDependencies = [ { - name = "rustables"; - packageId = "rustables"; - features = [ "query" ]; + name = "env_logger"; + packageId = "env_logger 0.9.3"; } ]; - features = { - }; - resolvedDefaultFeatures = [ "query" "unsafe-raw-handles" ]; + }; "rustc-hash" = rec { crateName = "rustc-hash"; @@ -976,63 +1013,6 @@ rec { }; resolvedDefaultFeatures = [ "default" "std" ]; }; - "serde" = rec { - crateName = "serde"; - version = "1.0.147"; - edition = "2015"; - sha256 = "0rc9jj8bbhf3lkf07ln8kyljigyzc4kk90nzg4dc2gwqmsdxd4yi"; - authors = [ - "Erick Tryzelaar " - "David Tolnay " - ]; - dependencies = [ - { - name = "serde_derive"; - packageId = "serde_derive"; - optional = true; - } - ]; - devDependencies = [ - { - name = "serde_derive"; - packageId = "serde_derive"; - } - ]; - features = { - "default" = [ "std" ]; - "derive" = [ "serde_derive" ]; - "serde_derive" = [ "dep:serde_derive" ]; - }; - resolvedDefaultFeatures = [ "default" "derive" "serde_derive" "std" ]; - }; - "serde_derive" = rec { - crateName = "serde_derive"; - version = "1.0.147"; - edition = "2015"; - sha256 = "0ln8rqbybpxmk4fvh6lgm75acs1d8x90fi44fhx3x77wm0n3c7ag"; - procMacro = true; - authors = [ - "Erick Tryzelaar " - "David Tolnay " - ]; - dependencies = [ - { - name = "proc-macro2"; - packageId = "proc-macro2"; - } - { - name = "quote"; - packageId = "quote"; - } - { - name = "syn"; - packageId = "syn"; - } - ]; - features = { - }; - resolvedDefaultFeatures = [ "default" ]; - }; "shlex" = rec { crateName = "shlex"; version = "0.1.1"; diff --git a/examples/add-rules.rs b/examples/add-rules.rs index cb4e41c..cd7423c 100644 --- a/examples/add-rules.rs +++ b/examples/add-rules.rs @@ -87,30 +87,20 @@ fn main() -> Result<(), Error> { batch.add(&rule, MsgType::Add); - let rule = Rule::new(&in_chain)?.with_expressions( - ExpressionList::builder() - .with_expression(Immediate::new_data( - vec![1, 2, 3, 4], - rustables::expr::Register::Reg2, - )) - .with_expression(Immediate::new_verdict(VerdictKind::Continue)), - ); - - batch.add(&rule, MsgType::Add); - - // // === ADD RULE ALLOWING ALL TRAFFIC TO THE LOOPBACK DEVICE === - // - // // Create a new rule object under the input chain. - // let mut allow_loopback_in_rule = Rule::new(Rc::clone(&in_chain)); - // // Lookup the interface index of the loopback interface. - // let lo_iface_index = iface_index("lo")?; - // - // // First expression to be evaluated in this rule is load the meta information "iif" - // // (incoming interface index) into the comparison register of netfilter. - // // When an incoming network packet is processed by this rule it will first be processed by this - // // expression, which will load the interface index of the interface the packet came from into - // // a special "register" in netfilter. - // allow_loopback_in_rule.add_expr(&nft_expr!(meta iif)); + // === ADD RULE ALLOWING ALL TRAFFIC TO THE LOOPBACK DEVICE === + + // Create a new rule object under the input chain. + let mut allow_loopback_in_rule = Rule::new(&in_chain)?; + // Lookup the interface index of the loopback interface. + let lo_iface_index = iface_index("lo")?; + + // First expression to be evaluated in this rule is load the meta information "iif" + // (incoming interface index) into the comparison register of netfilter. + // When an incoming network packet is processed by this rule it will first be processed by this + // expression, which will load the interface index of the interface the packet came from into + // a special "register" in netfilter. + //allow_loopback_in_rule.set_expressions(ExpressionList::builder().with_expression()); + //add_expr(&nft_expr!(meta iif)); // // Next expression in the rule is to compare the value loaded into the register with our desired // // interface index, and succeed only if it's equal. For any packet processed where the equality // // does not hold the packet is said to not match this rule, and the packet moves on to be @@ -190,17 +180,7 @@ fn main() -> Result<(), Error> { // netfilter the we reached the end of the transaction message. It's also converted to a // Vec, containing the raw netlink data so it can be sent over a netlink socket to netfilter. // Finally, the batch is sent over to the kernel. - batch.send()?; - - let tables = list_tables()?; - let chains = list_chains_for_table(&tables[0])?; - let rules = list_rules_for_chain(&chains[1])?; - for rule in rules { - for expr in rule.get_expressions().unwrap().iter() { - println!("{:?}", expr); - } - } - Ok(()) + Ok(batch.send()?) } // Look up the interface index for a given interface name. diff --git a/include/tests_wrapper.h b/include/tests_wrapper.h deleted file mode 100644 index e69de29..0000000 diff --git a/src/chain.rs b/src/chain.rs index cce0fa9..eeedcd1 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -1,12 +1,12 @@ use libc::{NF_ACCEPT, NF_DROP}; -use crate::nlmsg::{ - NfNetlinkAttribute, NfNetlinkAttributes, NfNetlinkDeserializable, NfNetlinkObject, - NfNetlinkWriter, -}; -use crate::parser::{parse_object, DecodeError, InnerFormat, NfNetlinkAttributeReader}; +use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkWriter}; +use crate::parser::{DecodeError, InnerFormat, Parsable}; use crate::sys::{self, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN, NLM_F_ACK, NLM_F_CREATE}; -use crate::{impl_attr_getters_and_setters, MsgType, ProtocolFamily, Table}; +use crate::{ + create_wrapper_type, impl_attr_getters_and_setters, impl_nfnetlinkattribute, MsgType, + ProtocolFamily, Table, +}; use std::convert::TryFrom; use std::fmt::Debug; @@ -28,29 +28,8 @@ pub enum HookClass { PostRouting = libc::NF_INET_POST_ROUTING, } -#[derive(Clone, PartialEq, Eq)] -pub struct Hook { - inner: NfNetlinkAttributes, -} - -impl Hook { - pub fn new(class: HookClass, priority: ChainPriority) -> Self { - Hook { - inner: NfNetlinkAttributes::new(), - } - .with_class(class as u32) - .with_priority(priority as u32) - } -} - -impl Debug for Hook { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.inner_format_struct(f.debug_struct("Hook"))?.finish() - } -} - -impl_attr_getters_and_setters!( - Hook, +create_wrapper_type!( + nested: Hook, [ // Define the action netfilter will apply to packets processed by this chain, but that did not match any rules in it. ( @@ -58,7 +37,7 @@ impl_attr_getters_and_setters!( set_class, with_class, sys::NFTA_HOOK_HOOKNUM, - U32, + class, u32 ), ( @@ -66,31 +45,17 @@ impl_attr_getters_and_setters!( set_priority, with_priority, sys::NFTA_HOOK_PRIORITY, - U32, + priority, u32 ) ] ); -impl NfNetlinkAttribute for Hook { - fn is_nested(&self) -> bool { - true - } - - fn get_size(&self) -> usize { - self.inner.get_size() - } - - unsafe fn write_payload(&self, addr: *mut u8) { - self.inner.write_payload(addr) - } -} - -impl NfNetlinkDeserializable for Hook { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - let reader = NfNetlinkAttributeReader::new(buf, buf.len())?; - let inner = reader.decode::()?; - Ok((Hook { inner }, &[])) +impl Hook { + pub fn new(class: HookClass, priority: ChainPriority) -> Self { + Hook::default() + .with_class(class as u32) + .with_priority(priority as u32) } } @@ -186,10 +151,16 @@ impl NfNetlinkDeserializable for ChainType { /// [`Table`]: struct.Table.html /// [`Rule`]: struct.Rule.html /// [`set_hook`]: #method.set_hook -#[derive(PartialEq, Eq)] +#[derive(PartialEq, Eq, Default)] pub struct Chain { family: ProtocolFamily, - inner: NfNetlinkAttributes, + flags: Option, + name: Option, + hook: Option, + policy: Option, + table: Option, + chain_type: Option, + userdata: Option>, } impl Chain { @@ -197,10 +168,8 @@ impl Chain { /// /// [`Table`]: struct.Table.html pub fn new(table: &Table) -> Chain { - let mut chain = Chain { - family: table.get_family(), - inner: NfNetlinkAttributes::new(), - }; + let mut chain = Chain::default(); + chain.family = table.family; if let Some(table_name) = table.get_name() { chain.set_table(table_name); @@ -213,10 +182,6 @@ impl Chain { self.family } - fn raw_attributes(&self) -> &NfNetlinkAttributes { - &self.inner - } - /* /// Returns a textual description of the chain. pub fn get_str(&self) -> CString { @@ -268,31 +233,29 @@ impl NfNetlinkObject for Chain { seq, None, ); - self.inner.serialize(writer); + let buf = writer.add_data_zeroed(self.get_size()); + unsafe { + self.write_payload(buf.as_mut_ptr()); + } writer.finalize_writing_object(); } } impl NfNetlinkDeserializable for Chain { fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - let (inner, nfgenmsg, remaining_data) = - parse_object::(buf, NFT_MSG_NEWCHAIN, NFT_MSG_DELCHAIN)?; + let (mut obj, nfgenmsg, remaining_data) = + Self::parse_object(buf, NFT_MSG_NEWCHAIN, NFT_MSG_DELCHAIN)?; + obj.family = ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?; - Ok(( - Self { - inner, - family: ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?, - }, - remaining_data, - )) + Ok((obj, remaining_data)) } } impl_attr_getters_and_setters!( Chain, [ - (get_flags, set_flags, with_flags, sys::NFTA_CHAIN_FLAGS, U32, u32), - (get_name, set_name, with_name, sys::NFTA_CHAIN_NAME, String, String), + (get_table, set_table, with_table, sys::NFTA_CHAIN_TABLE, table, String), + (get_name, set_name, with_name, sys::NFTA_CHAIN_NAME, name, String), // Sets the hook and priority for this chain. Without calling this method the chain will // become a "regular chain" without any hook and will thus not receive any traffic unless // some rule forward packets to it via goto or jump verdicts. @@ -300,22 +263,38 @@ impl_attr_getters_and_setters!( // By calling `set_hook` with a hook the chain that is created will be registered with that // hook and is thus a "base chain". A "base chain" is an entry point for packets from the // networking stack. - (get_hook, set_hook, with_hook, sys::NFTA_CHAIN_HOOK, ChainHook, Hook), - (get_policy, set_policy, with_policy, sys::NFTA_CHAIN_POLICY, ChainPolicy, ChainPolicy), - (get_table, set_table, with_table, sys::NFTA_CHAIN_TABLE, String, String), + (get_hook, set_hook, with_hook, sys::NFTA_CHAIN_HOOK, hook, Hook), + (get_policy, set_policy, with_policy, sys::NFTA_CHAIN_POLICY, policy, ChainPolicy), // This only applies if the chain has been registered with a hook by calling `set_hook`. - (get_type, set_type, with_type, sys::NFTA_CHAIN_TYPE, ChainType, ChainType), + (get_type, set_type, with_type, sys::NFTA_CHAIN_TYPE, chain_type, ChainType), + (get_flags, set_flags, with_flags, sys::NFTA_CHAIN_FLAGS, flags, u32), ( get_userdata, set_userdata, with_userdata, sys::NFTA_CHAIN_USERDATA, - VecU8, + userdata, Vec ) ] ); +impl_nfnetlinkattribute!( + inline : Chain, + [ + (sys::NFTA_CHAIN_TABLE, table), + (sys::NFTA_CHAIN_NAME, name), + (sys::NFTA_CHAIN_HOOK, hook), + (sys::NFTA_CHAIN_POLICY, policy), + (sys::NFTA_CHAIN_TYPE, chain_type), + (sys::NFTA_CHAIN_FLAGS, flags), + ( + sys::NFTA_CHAIN_USERDATA, + userdata + ) + ] +); + pub fn list_chains_for_table(table: &Table) -> Result, crate::query::Error> { let mut result = Vec::new(); crate::query::list_objects_with_data( diff --git a/src/chain_methods.rs b/src/chain_methods.rs deleted file mode 100644 index d384c35..0000000 --- a/src/chain_methods.rs +++ /dev/null @@ -1,40 +0,0 @@ -use crate::{Batch, Chain, Hook, MsgType, Policy, Table}; -use std::ffi::CString; -use std::rc::Rc; - - -/// A helper trait over [`crate::Chain`]. -pub trait ChainMethods { - /// Creates a new Chain instance from a [`crate::Hook`] over a [`crate::Table`]. - fn from_hook(hook: Hook, table: Rc) -> Self - where Self: std::marker::Sized; - /// Adds a [`crate::Policy`] to the current Chain. - fn verdict(self, policy: Policy) -> Self; - fn add_to_batch(self, batch: &mut Batch) -> Self; -} - - -impl ChainMethods for Chain { - fn from_hook(hook: Hook, table: Rc
) -> Self { - let chain_name = match hook { - Hook::PreRouting => "prerouting", - Hook::Out => "out", - Hook::PostRouting => "postrouting", - Hook::Forward => "forward", - Hook::In => "in", - }; - let chain_name = CString::new(chain_name).unwrap(); - let mut chain = Chain::new(&chain_name, table); - chain.set_hook(hook, 0); - chain - } - fn verdict(mut self, policy: Policy) -> Self { - self.set_policy(policy); - self - } - fn add_to_batch(self, batch: &mut Batch) -> Self { - batch.add(&self, MsgType::Add); - self - } -} - diff --git a/src/expr/bitwise.rs b/src/expr/bitwise.rs index 38c0383..73c2467 100644 --- a/src/expr/bitwise.rs +++ b/src/expr/bitwise.rs @@ -1,41 +1,34 @@ use super::{Expression, ExpressionData, Register}; -use crate::create_expr_type; +use crate::create_wrapper_type; use crate::parser::DecodeError; use crate::sys; -create_expr_type!( - inline with_builder : Bitwise, +create_wrapper_type!( + inline: Bitwise, [ - ( - get_dreg, - set_dreg, - with_dreg, - sys::NFTA_BITWISE_DREG, - Register, - Register - ), ( get_sreg, set_sreg, with_sreg, sys::NFTA_BITWISE_SREG, - Register, + sreg, Register ), ( - get_len, - set_len, - with_len, - sys::NFTA_BITWISE_LEN, - U32, - u32 + get_dreg, + set_dreg, + with_dreg, + sys::NFTA_BITWISE_DREG, + dreg, + Register ), + (get_len, set_len, with_len, sys::NFTA_BITWISE_LEN, len, u32), ( get_mask, set_mask, with_mask, sys::NFTA_BITWISE_MASK, - ExprData, + mask, ExpressionData ), ( @@ -43,7 +36,7 @@ create_expr_type!( set_xor, with_xor, sys::NFTA_BITWISE_XOR, - ExprData, + xor, ExpressionData ) ] @@ -64,11 +57,11 @@ impl Bitwise { if mask.len() != xor.len() { return Err(DecodeError::IncompatibleLength); } - Ok(Self::builder() + Ok(Bitwise::default() .with_sreg(Register::Reg1) .with_dreg(Register::Reg1) .with_len(mask.len() as u32) - .with_xor(ExpressionData::builder().with_value(xor)) - .with_mask(ExpressionData::builder().with_value(mask))) + .with_xor(ExpressionData::default().with_value(xor)) + .with_mask(ExpressionData::default().with_value(mask))) } } diff --git a/src/expr/immediate.rs b/src/expr/immediate.rs index 6f26bc3..925ca06 100644 --- a/src/expr/immediate.rs +++ b/src/expr/immediate.rs @@ -1,15 +1,15 @@ use super::{Expression, ExpressionData, Register}; -use crate::{create_expr_type, sys}; +use crate::{create_wrapper_type, sys}; -create_expr_type!( - inline with_builder : Immediate, +create_wrapper_type!( + inline: Immediate, [ ( get_dreg, set_dreg, with_dreg, sys::NFTA_IMMEDIATE_DREG, - Register, + dreg, Register ), ( @@ -17,7 +17,7 @@ create_expr_type!( set_data, with_data, sys::NFTA_IMMEDIATE_DATA, - ExprData, + data, ExpressionData ) ] @@ -25,9 +25,9 @@ create_expr_type!( impl Immediate { pub fn new_data(data: Vec, register: Register) -> Self { - Immediate::builder() + Immediate::default() .with_dreg(register) - .with_data(ExpressionData::builder().with_value(data)) + .with_data(ExpressionData::default().with_value(data)) } } diff --git a/src/expr/log.rs b/src/expr/log.rs index cf50cb2..82c201d 100644 --- a/src/expr/log.rs +++ b/src/expr/log.rs @@ -1,18 +1,17 @@ use super::{Expression, ExpressionError}; -use crate::create_expr_type; -use crate::nlmsg::NfNetlinkAttributes; +use crate::create_wrapper_type; use crate::sys; // A Log expression will log all packets that match the rule. -create_expr_type!( - inline with_builder : Log, +create_wrapper_type!( + inline: Log, [ ( get_group, set_group, with_group, sys::NFTA_LOG_GROUP, - U32, + group, u32 ), ( @@ -20,7 +19,7 @@ create_expr_type!( set_prefix, with_prefix, sys::NFTA_LOG_PREFIX, - String, + prefix, String ) ] @@ -31,11 +30,7 @@ impl Log { group: Option, prefix: Option>, ) -> Result { - let mut res = Log { - inner: NfNetlinkAttributes::new(), - //pub group: Option, - //pub prefix: Option, - }; + let mut res = Log::default(); if let Some(group) = group { res.set_group(group); } @@ -55,60 +50,4 @@ impl Expression for Log { fn get_name() -> &'static str { "log" } - /* - fn from_expr(expr: *const sys::nftnl_expr) -> Result - where - Self: Sized, - { - unsafe { - let mut group = None; - if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_LOG_GROUP as u16) { - group = Some(LogGroup(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_LOG_GROUP as u16, - ) as u16)); - } - let mut prefix = None; - if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16) { - let raw_prefix = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16); - if raw_prefix.is_null() { - return Err(DeserializationError::NullPointer); - } else { - prefix = Some(LogPrefix(CStr::from_ptr(raw_prefix).to_owned())); - } - } - Ok(Log { group, prefix }) - } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(b"log\0" as *const _ as *const c_char)); - if let Some(log_group) = self.group { - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOG_GROUP as u16, log_group.0 as u32); - }; - if let Some(LogPrefix(prefix)) = &self.prefix { - sys::nftnl_expr_set_str(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16, prefix.as_ptr()); - }; - - expr - } - } - */ -} - -#[macro_export] -macro_rules! nft_expr_log { - (group $group:ident prefix $prefix:expr) => { - $crate::expr::Log::new(Some($group), Some($prefix)) - }; - (prefix $prefix:expr) => { - $crate::expr::Log::new(None, Some($prefix)) - }; - (group $group:ident) => { - $crate::expr::Log::new(Some($group), None) - }; - () => { - $crate::expr::Log::new(None, None) - }; } diff --git a/src/expr/meta.rs b/src/expr/meta.rs index a015f65..bb8023d 100644 --- a/src/expr/meta.rs +++ b/src/expr/meta.rs @@ -1,175 +1,115 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys::{self, libc}; -use std::os::raw::c_char; +use super::{Expression, Register, Rule}; +use crate::{ + create_wrapper_type, + nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable}, + parser::DecodeError, + sys, +}; +use std::convert::TryFrom; /// A meta expression refers to meta data associated with a packet. #[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[repr(u32)] #[non_exhaustive] -pub enum Meta { +pub enum MetaType { /// Packet ethertype protocol (skb->protocol), invalid in OUTPUT. - Protocol, + Protocol = sys::NFT_META_PROTOCOL, /// Packet mark. - Mark { set: bool }, + Mark = sys::NFT_META_MARK, /// Packet input interface index (dev->ifindex). - Iif, + Iif = sys::NFT_META_IIF, /// Packet output interface index (dev->ifindex). - Oif, + Oif = sys::NFT_META_OIF, /// Packet input interface name (dev->name). - IifName, + IifName = sys::NFT_META_IIFNAME, /// Packet output interface name (dev->name). - OifName, + OifName = sys::NFT_META_OIFNAME, /// Packet input interface type (dev->type). - IifType, + IifType = sys::NFT_META_IFTYPE, /// Packet output interface type (dev->type). - OifType, + OifType = sys::NFT_META_OIFTYPE, /// Originating socket UID (fsuid). - SkUid, + SkUid = sys::NFT_META_SKUID, /// Originating socket GID (fsgid). - SkGid, + SkGid = sys::NFT_META_SKGID, /// Netfilter protocol (Transport layer protocol). - NfProto, + NfProto = sys::NFT_META_NFPROTO, /// Layer 4 protocol number. - L4Proto, + L4Proto = sys::NFT_META_L4PROTO, /// Socket control group (skb->sk->sk_classid). - Cgroup, + Cgroup = sys::NFT_META_CGROUP, /// A 32bit pseudo-random number. - PRandom, + PRandom = sys::NFT_META_PRANDOM, } -impl Meta { - /// Returns the corresponding `NFT_*` constant for this meta expression. - pub fn to_raw_key(&self) -> u32 { - use Meta::*; - match *self { - Protocol => libc::NFT_META_PROTOCOL as u32, - Mark { .. } => libc::NFT_META_MARK as u32, - Iif => libc::NFT_META_IIF as u32, - Oif => libc::NFT_META_OIF as u32, - IifName => libc::NFT_META_IIFNAME as u32, - OifName => libc::NFT_META_OIFNAME as u32, - IifType => libc::NFT_META_IIFTYPE as u32, - OifType => libc::NFT_META_OIFTYPE as u32, - SkUid => libc::NFT_META_SKUID as u32, - SkGid => libc::NFT_META_SKGID as u32, - NfProto => libc::NFT_META_NFPROTO as u32, - L4Proto => libc::NFT_META_L4PROTO as u32, - Cgroup => libc::NFT_META_CGROUP as u32, - PRandom => libc::NFT_META_PRANDOM as u32, - } +impl NfNetlinkAttribute for MetaType { + fn get_size(&self) -> usize { + (*self as u32).get_size() } - fn from_raw(val: u32) -> Result { - match val as i32 { - libc::NFT_META_PROTOCOL => Ok(Self::Protocol), - libc::NFT_META_MARK => Ok(Self::Mark { set: false }), - libc::NFT_META_IIF => Ok(Self::Iif), - libc::NFT_META_OIF => Ok(Self::Oif), - libc::NFT_META_IIFNAME => Ok(Self::IifName), - libc::NFT_META_OIFNAME => Ok(Self::OifName), - libc::NFT_META_IIFTYPE => Ok(Self::IifType), - libc::NFT_META_OIFTYPE => Ok(Self::OifType), - libc::NFT_META_SKUID => Ok(Self::SkUid), - libc::NFT_META_SKGID => Ok(Self::SkGid), - libc::NFT_META_NFPROTO => Ok(Self::NfProto), - libc::NFT_META_L4PROTO => Ok(Self::L4Proto), - libc::NFT_META_CGROUP => Ok(Self::Cgroup), - libc::NFT_META_PRANDOM => Ok(Self::PRandom), - _ => Err(DeserializationError::InvalidValue), - } + unsafe fn write_payload(&self, addr: *mut u8) { + (*self as u32).write_payload(addr); } } -impl Expression for Meta { - fn get_raw_name() -> *const libc::c_char { - b"meta\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result - where - Self: Sized, - { - unsafe { - let mut ret = Self::from_raw(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_META_KEY as u16, - ))?; - - if let Self::Mark { ref mut set } = ret { - *set = sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_META_SREG as u16); - } - - Ok(ret) - } +impl NfNetlinkDeserializable for MetaType { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let (v, remaining_data) = u32::deserialize(buf)?; + Ok(( + match v { + sys::NFT_META_PROTOCOL => Self::Protocol, + sys::NFT_META_MARK => Self::Mark, + sys::NFT_META_IIF => Self::Iif, + sys::NFT_META_OIF => Self::Oif, + sys::NFT_META_IIFNAME => Self::IifName, + sys::NFT_META_OIFNAME => Self::OifName, + sys::NFT_META_IFTYPE => Self::IifType, + sys::NFT_META_OIFTYPE => Self::OifType, + sys::NFT_META_SKUID => Self::SkUid, + sys::NFT_META_SKGID => Self::SkGid, + sys::NFT_META_NFPROTO => Self::NfProto, + sys::NFT_META_L4PROTO => Self::L4Proto, + sys::NFT_META_CGROUP => Self::Cgroup, + sys::NFT_META_PRANDOM => Self::PRandom, + value => return Err(DecodeError::UnknownMetaType(value)), + }, + remaining_data, + )) } +} - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); +create_wrapper_type!( + inline: Meta, + [ + ( + get_dreg, + set_dreg, + with_dreg, + sys::NFTA_META_DREG, + dreg, + Register + ), + ( + get_key, + set_key, + with_key, + sys::NFTA_META_KEY, + key, + MetaType + ), + ( + get_sreg, + set_sreg, + with_sreg, + sys::NFTA_META_SREG, + sreg, + Register + ) + ] +); - if let Meta::Mark { set: true } = self { - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_META_SREG as u16, - libc::NFT_REG_1 as u32, - ); - } else { - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_META_DREG as u16, - libc::NFT_REG_1 as u32, - ); - } - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_META_KEY as u16, self.to_raw_key()); - expr - } +impl Expression for Meta { + fn get_name() -> &'static str { + "meta" } } - -#[macro_export] -macro_rules! nft_expr_meta { - (proto) => { - $crate::expr::Meta::Protocol - }; - (mark set) => { - $crate::expr::Meta::Mark { set: true } - }; - (mark) => { - $crate::expr::Meta::Mark { set: false } - }; - (iif) => { - $crate::expr::Meta::Iif - }; - (oif) => { - $crate::expr::Meta::Oif - }; - (iifname) => { - $crate::expr::Meta::IifName - }; - (oifname) => { - $crate::expr::Meta::OifName - }; - (iiftype) => { - $crate::expr::Meta::IifType - }; - (oiftype) => { - $crate::expr::Meta::OifType - }; - (skuid) => { - $crate::expr::Meta::SkUid - }; - (skgid) => { - $crate::expr::Meta::SkGid - }; - (nfproto) => { - $crate::expr::Meta::NfProto - }; - (l4proto) => { - $crate::expr::Meta::L4Proto - }; - (cgroup) => { - $crate::expr::Meta::Cgroup - }; - (random) => { - $crate::expr::Meta::PRandom - }; -} diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 78b1717..6dfa6c7 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -12,14 +12,13 @@ use std::net::Ipv6Addr; use std::slice::Iter; use super::rule::Rule; +use crate::create_wrapper_type; use crate::nlmsg::AttributeDecoder; use crate::nlmsg::NfNetlinkAttribute; -use crate::nlmsg::NfNetlinkAttributes; use crate::nlmsg::NfNetlinkDeserializable; use crate::parser::pad_netlink_object; use crate::parser::pad_netlink_object_with_variable_size; use crate::parser::write_attribute; -use crate::parser::AttributeType; use crate::parser::DecodeError; use crate::parser::InnerFormat; use crate::sys::{self, nlattr}; @@ -52,10 +51,12 @@ pub use self::lookup::*; mod masquerade; pub use self::masquerade::*; +*/ mod meta; pub use self::meta::*; +/* mod nat; pub use self::nat::*; @@ -111,91 +112,15 @@ pub trait Expression { fn get_name() -> &'static str; } -// wrapper for the general case, as we need to create many holder types given the depth of some -// netlink expressions -#[macro_export] -macro_rules! create_expr_type { - (without_decoder : $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { - #[derive(Clone, PartialEq, Eq)] - pub struct $struct { - inner: $crate::nlmsg::NfNetlinkAttributes, - } - - - $crate::impl_attr_getters_and_setters!(without_decoder $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); - - impl std::fmt::Debug for $struct { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - use $crate::parser::InnerFormat; - self.inner_format_struct(f.debug_struct(stringify!($struct)))? - .finish() - } - } - - - impl $crate::nlmsg::NfNetlinkDeserializable for $struct { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), $crate::parser::DecodeError> { - let reader = $crate::parser::NfNetlinkAttributeReader::new(buf, buf.len())?; - let inner = reader.decode::()?; - Ok(($struct { inner }, &[])) - } - } - - }; - ($struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { - create_expr_type!(without_decoder : $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); - $crate::impl_attr_getters_and_setters!(decoder $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); - }; - (with_builder : $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { - create_expr_type!($struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); - - impl $struct { - pub fn builder() -> Self { - Self { inner: $crate::nlmsg::NfNetlinkAttributes::new() } - } - } - }; - (inline $($($attrs:ident) +)? : $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { - create_expr_type!($($($attrs) + :)? $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); - - impl $crate::nlmsg::NfNetlinkAttribute for $struct { - fn get_size(&self) -> usize { - self.inner.get_size() - } - - unsafe fn write_payload(&self, addr: *mut u8) { - self.inner.write_payload(addr) - } - } - }; - (nested $($($attrs:ident) +)? : $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { - create_expr_type!($($($attrs) + :)? $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); - - impl $crate::nlmsg::NfNetlinkAttribute for $struct { - fn is_nested(&self) -> bool { - true - } - - fn get_size(&self) -> usize { - self.inner.get_size() - } - - unsafe fn write_payload(&self, addr: *mut u8) { - self.inner.write_payload(addr) - } - } - }; -} - -create_expr_type!( - nested without_decoder : ExpressionHolder, [ +create_wrapper_type!( + nested without_deser : RawExpression, [ // Define the action netfilter will apply to packets processed by this chain, but that did not match any rules in it. ( get_name, set_name, with_name, sys::NFTA_EXPR_NAME, - String, + name, String ), ( @@ -203,22 +128,20 @@ create_expr_type!( set_data, with_data, sys::NFTA_EXPR_DATA, - ExpressionVariant, + data, ExpressionVariant ) ]); -impl ExpressionHolder { +impl RawExpression { pub fn new(expr: T) -> Self where T: Expression, ExpressionVariant: From, { - ExpressionHolder { - inner: NfNetlinkAttributes::new(), - } - .with_name(T::get_name()) - .with_data(ExpressionVariant::from(expr)) + RawExpression::default() + .with_name(T::get_name()) + .with_data(ExpressionVariant::from(expr)) } } @@ -262,44 +185,45 @@ macro_rules! create_expr_variant { } )+ - impl AttributeDecoder for ExpressionHolder { + impl $crate::nlmsg::AttributeDecoder for RawExpression { fn decode_attribute( - attrs: &NfNetlinkAttributes, + &mut self, attr_type: u16, buf: &[u8], - ) -> Result { + ) -> Result<(), $crate::parser::DecodeError> { debug!("Decoding attribute {} in an expression", attr_type); match attr_type { x if x == sys::NFTA_EXPR_NAME => { debug!("Calling {}::deserialize()", std::any::type_name::()); let (val, remaining) = String::deserialize(buf)?; if remaining.len() != 0 { - return Err(DecodeError::InvalidDataSize); + return Err($crate::parser::DecodeError::InvalidDataSize); } - Ok(AttributeType::String(val)) + self.name = Some(val); + Ok(()) }, x if x == sys::NFTA_EXPR_DATA => { // we can assume we have already the name parsed, as that's how we identify the // type of expression - let name = attrs - .get_attr(sys::NFTA_EXPR_NAME) - .ok_or(DecodeError::MissingExpressionName)?; + let name = self.name.as_ref() + .ok_or($crate::parser::DecodeError::MissingExpressionName)?; match name { $( - AttributeType::String(x) if x == <$type>::get_name() => { + x if x == <$type>::get_name() => { debug!("Calling {}::deserialize()", std::any::type_name::<$type>()); let (res, remaining) = <$type>::deserialize(buf)?; if remaining.len() != 0 { return Err($crate::parser::DecodeError::InvalidDataSize); } - Ok(AttributeType::ExpressionVariant(ExpressionVariant::from(res))) + self.data = Some(ExpressionVariant::from(res)); + Ok(()) }, )+ - AttributeType::String(name) => { + name => { info!("Unrecognized expression '{}', generating an ExpressionRaw", name); - ExpressionRaw::deserialize(buf).map(|(res, _)| AttributeType::ExpressionVariant(ExpressionVariant::ExpressionRaw(res))) - }, - _ => unreachable!() + self.data = Some(ExpressionVariant::ExpressionRaw(ExpressionRaw::deserialize(buf)?.0)); + Ok(()) + } } }, _ => Err(DecodeError::UnsupportedAttributeType(attr_type)), @@ -314,12 +238,13 @@ create_expr_variant!( [Log, Log], [Immediate, Immediate], [Bitwise, Bitwise], - [ExpressionRaw, ExpressionRaw] + [ExpressionRaw, ExpressionRaw], + [Meta, Meta] ); -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct ExpressionList { - exprs: Vec, + exprs: Vec, } impl ExpressionList { @@ -327,9 +252,9 @@ impl ExpressionList { Self { exprs: Vec::new() } } - /// Useful to add raw expressions because ExpressionHolder cannot infer alone its type - pub fn add_raw_expression(&mut self, e: ExpressionHolder) { - self.exprs.push(AttributeType::Expression(e)); + /// Useful to add raw expressions because RawExpression cannot infer alone its type + pub fn add_raw_expression(&mut self, e: RawExpression) { + self.exprs.push(e); } pub fn add_expression(&mut self, e: T) @@ -337,8 +262,7 @@ impl ExpressionList { T: Expression, ExpressionVariant: From, { - self.exprs - .push(AttributeType::Expression(ExpressionHolder::new(e))); + self.exprs.push(RawExpression::new(e)); } pub fn with_expression(mut self, e: T) -> Self @@ -351,10 +275,7 @@ impl ExpressionList { } pub fn iter<'a>(&'a self) -> impl Iterator { - self.exprs.iter().map(|t| match t { - AttributeType::Expression(e) => e.get_data().unwrap(), - _ => unreachable!(), - }) + self.exprs.iter().map(|e| e.get_data().unwrap()) } } @@ -392,13 +313,13 @@ impl NfNetlinkDeserializable for ExpressionList { return Err(DecodeError::UnsupportedAttributeType(nla_type)); } - let (expr, remaining) = ExpressionHolder::deserialize( + let (expr, remaining) = RawExpression::deserialize( &buf[pos + pad_netlink_object::()..pos + nlattr.nla_len as usize], )?; if remaining.len() != 0 { return Err(DecodeError::InvalidDataSize); } - exprs.push(AttributeType::Expression(expr)); + exprs.push(expr); pos += pad_netlink_object_with_variable_size(nlattr.nla_len as usize); } @@ -411,15 +332,27 @@ impl NfNetlinkDeserializable for ExpressionList { } } -create_expr_type!( - nested with_builder : ExpressionData, +impl From> for ExpressionList +where + ExpressionVariant: From, + T: Expression, +{ + fn from(v: Vec) -> Self { + ExpressionList { + exprs: v.into_iter().map(RawExpression::new).collect(), + } + } +} + +create_wrapper_type!( + nested : ExpressionData, [ ( get_value, set_value, with_value, sys::NFTA_DATA_VALUE, - VecU8, + value, Vec ), ( @@ -427,7 +360,7 @@ create_expr_type!( set_verdict, with_verdict, sys::NFTA_DATA_VERDICT, - ExprVerdictAttribute, + verdict, VerdictAttribute ) ] @@ -454,7 +387,7 @@ impl NfNetlinkDeserializable for ExpressionRaw { } // Because we loose the name of the expression when parsing, this is the only expression -// where deserializing a message and the reserializing it alter its content +// where deserializing a message and then reserializing it is invalid impl Expression for ExpressionRaw { fn get_name() -> &'static str { "unknown_expression" diff --git a/src/expr/verdict.rs b/src/expr/verdict.rs index 7c27af6..547ba91 100644 --- a/src/expr/verdict.rs +++ b/src/expr/verdict.rs @@ -4,7 +4,7 @@ use libc::{NF_ACCEPT, NF_DROP, NF_QUEUE}; use super::{ExpressionData, Immediate, Register}; use crate::{ - create_expr_type, + create_wrapper_type, nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable}, parser::DecodeError, sys::{self, NFT_BREAK, NFT_CONTINUE, NFT_GOTO, NFT_JUMP, NFT_RETURN}, @@ -53,15 +53,15 @@ impl NfNetlinkDeserializable for VerdictType { } } -create_expr_type!( - nested with_builder : VerdictAttribute, +create_wrapper_type!( + nested: VerdictAttribute, [ ( get_code, set_code, with_code, sys::NFTA_VERDICT_CODE, - ExprVerdictType, + code, VerdictType ), ( @@ -69,7 +69,7 @@ create_expr_type!( set_chain, with_chain, sys::NFTA_VERDICT_CHAIN, - String, + chain, String ), ( @@ -77,7 +77,7 @@ create_expr_type!( set_chain_id, with_chain_id, sys::NFTA_VERDICT_CHAIN_ID, - U32, + chain_id, u32 ) ] @@ -113,12 +113,12 @@ impl Immediate { VerdictKind::Goto { .. } => VerdictType::Goto, VerdictKind::Return => VerdictType::Return, }; - let mut data = VerdictAttribute::builder().with_code(code); + let mut data = VerdictAttribute::default().with_code(code); if let VerdictKind::Jump { chain } | VerdictKind::Goto { chain } = kind { data.set_chain(chain); } - Immediate::builder() + Immediate::default() .with_dreg(Register::Verdict) - .with_data(ExpressionData::builder().with_verdict(data)) + .with_data(ExpressionData::default().with_verdict(data)) } } diff --git a/src/lib.rs b/src/lib.rs index d02c785..044030f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -78,25 +78,11 @@ extern crate log; pub mod sys; use libc; -use std::{convert::TryFrom, ffi::c_void, ops::Deref}; - -macro_rules! try_alloc { - ($e:expr) => {{ - let ptr = $e; - if ptr.is_null() { - // OOM, and the tried allocation was likely very small, - // so we are in a very tight situation. We do what libstd does, aborts. - std::process::abort(); - } - ptr - }}; -} +use std::convert::TryFrom; mod batch; pub use batch::{default_batch_page_size, Batch}; -pub mod expr; - mod table; pub use table::list_tables; pub use table::Table; @@ -117,6 +103,8 @@ mod rule; pub use rule::list_rules_for_chain; pub use rule::Rule; +pub mod expr; + //mod rule_methods; //pub use rule_methods::{iface_index, Error as MatchError, Protocol, RuleMethods}; @@ -155,6 +143,12 @@ pub enum ProtocolFamily { DecNet = libc::NFPROTO_DECNET, } +impl Default for ProtocolFamily { + fn default() -> Self { + Self::Unspec + } +} + impl TryFrom for ProtocolFamily { type Error = DecodeError; fn try_from(value: i32) -> Result { diff --git a/src/nlmsg.rs b/src/nlmsg.rs index 8960146..8563a37 100644 --- a/src/nlmsg.rs +++ b/src/nlmsg.rs @@ -1,12 +1,9 @@ -use std::{collections::BTreeMap, fmt::Debug, mem::size_of}; +use std::{fmt::Debug, mem::size_of}; use crate::{ - parser::{ - pad_netlink_object, pad_netlink_object_with_variable_size, write_attribute, AttributeType, - DecodeError, - }, + parser::{pad_netlink_object, pad_netlink_object_with_variable_size, DecodeError}, sys::{ - nfgenmsg, nlattr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END, + nfgenmsg, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END, NFNL_SUBSYS_NFTABLES, }, MsgType, ProtocolFamily, @@ -88,11 +85,7 @@ impl<'a> NfNetlinkWriter<'a> { } pub trait AttributeDecoder { - fn decode_attribute( - attrs: &NfNetlinkAttributes, - attr_type: u16, - buf: &[u8], - ) -> Result; + fn decode_attribute(&mut self, attr_type: u16, buf: &[u8]) -> Result<(), DecodeError>; } pub trait NfNetlinkDeserializable: Sized { @@ -119,6 +112,7 @@ pub trait NfNetlinkAttribute: Debug + Sized { unsafe fn write_payload(&self, addr: *mut u8); } +/* #[derive(Debug, Clone, PartialEq, Eq)] pub struct NfNetlinkAttributes { pub attributes: BTreeMap, @@ -170,3 +164,4 @@ impl NfNetlinkAttribute for NfNetlinkAttributes { } } } +*/ diff --git a/src/parser.rs b/src/parser.rs index 8b14d74..b7d0ac3 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -9,10 +9,9 @@ use std::{ use thiserror::Error; use crate::{ - expr::ExpressionHolder, + //expr::ExpressionHolder, nlmsg::{ - AttributeDecoder, NetlinkType, NfNetlinkAttribute, NfNetlinkAttributes, - NfNetlinkDeserializable, NfNetlinkWriter, + AttributeDecoder, NetlinkType, NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkWriter, }, sys::{ nfgenmsg, nlattr, nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, @@ -66,6 +65,9 @@ pub enum DecodeError { #[error("Invalid policy for a chain")] UnknownChainPolicy, + #[error("Unknown type for a Meta expression")] + UnknownMetaType(u32), + #[error("Invalid value for a register")] UnknownRegisterValue, @@ -208,7 +210,11 @@ pub fn parse_nlmsg<'a>(buf: &'a [u8]) -> Result<(nlmsghdr, NlMsg<'a>), DecodeErr /// Write the attribute, preceded by a `libc::nlattr` // rewrite of `mnl_attr_put` -pub unsafe fn write_attribute<'a>(ty: NetlinkType, obj: &AttributeType, mut buf: *mut u8) { +pub unsafe fn write_attribute<'a>( + ty: NetlinkType, + obj: &impl NfNetlinkAttribute, + mut buf: *mut u8, +) { let header_len = pad_netlink_object::(); // copy the header *(buf as *mut nlattr) = nlattr { @@ -296,7 +302,6 @@ impl NfNetlinkDeserializable for u64 { } } -// TODO: safe handling for null-delimited strings impl NfNetlinkAttribute for String { fn get_size(&self) -> usize { self.len() @@ -346,122 +351,42 @@ impl NfNetlinkDeserializable for ProtocolFamily { } } -pub struct NfNetlinkAttributeReader<'a> { - buf: &'a [u8], - pos: usize, - remaining_size: usize, - attrs: NfNetlinkAttributes, -} - -impl<'a> NfNetlinkAttributeReader<'a> { - pub fn new(buf: &'a [u8], remaining_size: usize) -> Result { - if buf.len() < remaining_size { - return Err(DecodeError::BufTooSmall); +pub(crate) fn read_attributes(buf: &[u8]) -> Result { + debug!( + "Calling <{} as NfNetlinkDeserialize>::deserialize()", + std::any::type_name::() + ); + let mut remaining_size = buf.len(); + let mut pos = 0; + let mut res = T::default(); + while remaining_size > pad_netlink_object::() { + let nlattr = unsafe { *transmute::<*const u8, *const nlattr>(buf[pos..].as_ptr()) }; + // ignore the byteorder and nested attributes + let nla_type = nlattr.nla_type & NLA_TYPE_MASK as u16; + + pos += pad_netlink_object::(); + let attr_remaining_size = nlattr.nla_len as usize - pad_netlink_object::(); + match T::decode_attribute(&mut res, nla_type, &buf[pos..pos + attr_remaining_size]) { + Ok(()) => {} + Err(DecodeError::UnsupportedAttributeType(t)) => info!( + "Ignoring unsupported attribute type {} for type {}", + t, + std::any::type_name::() + ), + Err(e) => return Err(e), } + pos += pad_netlink_object_with_variable_size(attr_remaining_size); - Ok(Self { - buf, - pos: 0, - remaining_size, - attrs: NfNetlinkAttributes::new(), - }) + remaining_size -= pad_netlink_object_with_variable_size(nlattr.nla_len as usize); } - pub fn get_raw_data(&self) -> &'a [u8] { - &self.buf[self.pos..] - } - - pub fn decode( - mut self, - ) -> Result { - debug!( - "Calling NfNetlinkAttributeReader::decode() on {}", - std::any::type_name::() - ); - while self.remaining_size > pad_netlink_object::() { - let nlattr = - unsafe { *transmute::<*const u8, *const nlattr>(self.buf[self.pos..].as_ptr()) }; - // ignore the byteorder and nested attributes - let nla_type = nlattr.nla_type & NLA_TYPE_MASK as u16; - - self.pos += pad_netlink_object::(); - let attr_remaining_size = nlattr.nla_len as usize - pad_netlink_object::(); - match T::decode_attribute( - &self.attrs, - nla_type, - &self.buf[self.pos..self.pos + attr_remaining_size], - ) { - Ok(x) => self.attrs.set_attr(nla_type, x), - Err(DecodeError::UnsupportedAttributeType(t)) => info!( - "Ignoring unsupported attribute type {} for type {}", - t, - std::any::type_name::() - ), - Err(e) => return Err(e), - } - self.pos += pad_netlink_object_with_variable_size(attr_remaining_size); - - self.remaining_size -= pad_netlink_object_with_variable_size(nlattr.nla_len as usize); - } - - if self.remaining_size != 0 { - Err(DecodeError::InvalidDataSize) - } else { - Ok(self.attrs) - } + if remaining_size != 0 { + Err(DecodeError::InvalidDataSize) + } else { + Ok(res) } } -#[macro_export] -macro_rules! impl_attribute_holder { - ($enum_name:ident, $([$internal_name:ident, $type:ty]),+) => { - #[derive(Debug, Clone, PartialEq, Eq)] - pub enum $enum_name { - $( - $internal_name($type), - )+ - } - - impl NfNetlinkAttribute for $enum_name { - fn is_nested(&self) -> bool { - match self { - $( - $enum_name::$internal_name(val) => val.is_nested() - ),+ - } - } - - fn get_size(&self) -> usize { - match self { - $( - $enum_name::$internal_name(val) => val.get_size() - ),+ - } - } - - unsafe fn write_payload(&self, addr: *mut u8) { - match self { - $( - $enum_name::$internal_name(val) => val.write_payload(addr) - ),+ - } - } - } - - impl $enum_name { - $( - #[allow(non_snake_case)] - pub fn $internal_name(&self) -> Option<&$type> { - match self { - $enum_name::$internal_name(val) => Some(val), - _ => None - } - } - )+ - } - }; -} - pub trait InnerFormat { fn inner_format_struct<'a, 'b: 'a>( &'a self, @@ -469,49 +394,24 @@ pub trait InnerFormat { ) -> Result, std::fmt::Error>; } -impl_attribute_holder!( - AttributeType, - [String, String], - [U8, u8], - [U16, u16], - [I32, i32], - [U32, u32], - [U64, u64], - [VecU8, Vec], - [ChainHook, crate::chain::Hook], - [ChainPolicy, crate::chain::ChainPolicy], - [ChainType, crate::chain::ChainType], - [ProtocolFamily, crate::ProtocolFamily], - [Expression, crate::expr::ExpressionHolder], - [ExpressionVariant, crate::expr::ExpressionVariant], - [ExpressionList, crate::expr::ExpressionList], - [ExprLog, crate::expr::Log], - [ExprImmediate, crate::expr::Immediate], - [ExprData, crate::expr::ExpressionData], - [ExprVerdictAttribute, crate::expr::VerdictAttribute], - [ExprVerdictType, crate::expr::VerdictType], - [Register, crate::expr::Register], - [ExprRaw, crate::expr::ExpressionRaw] -); - #[macro_export] macro_rules! impl_attr_getters_and_setters { - (without_decoder $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { + (without_deser $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { impl $struct { $( #[allow(dead_code)] pub fn $getter_name(&self) -> Option<&$type> { - self.inner.get_attr($attr_name as $crate::nlmsg::NetlinkType).map(|x| x.$internal_name()).flatten() + self.$internal_name.as_ref() } #[allow(dead_code)] pub fn $setter_name(&mut self, val: impl Into<$type>) { - self.inner.set_attr($attr_name as $crate::nlmsg::NetlinkType, $crate::parser::AttributeType::$internal_name(val.into())); + self.$internal_name = Some(val.into()); } #[allow(dead_code)] pub fn $in_place_edit_name(mut self, val: impl Into<$type>) -> Self { - self.inner.set_attr($attr_name as $crate::nlmsg::NetlinkType, $crate::parser::AttributeType::$internal_name(val.into())); + self.$internal_name = Some(val.into()); self } @@ -540,10 +440,10 @@ macro_rules! impl_attr_getters_and_setters { } }; - (decoder $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { + (deser $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { impl $crate::nlmsg::AttributeDecoder for $struct { #[allow(dead_code)] - fn decode_attribute(_attrs: &$crate::nlmsg::NfNetlinkAttributes, attr_type: u16, buf: &[u8]) -> Result<$crate::parser::AttributeType, $crate::parser::DecodeError> { + fn decode_attribute(&mut self, attr_type: u16, buf: &[u8]) -> Result<(), $crate::parser::DecodeError> { use $crate::nlmsg::NfNetlinkDeserializable; debug!("Decoding attribute {} in type {}", attr_type, std::any::type_name::<$struct>()); match attr_type { @@ -554,7 +454,8 @@ macro_rules! impl_attr_getters_and_setters { if remaining.len() != 0 { return Err($crate::parser::DecodeError::InvalidDataSize); } - Ok($crate::parser::AttributeType::$internal_name(val)) + self.$setter_name(val); + Ok(()) }, )+ _ => Err($crate::parser::DecodeError::UnsupportedAttributeType(attr_type)), @@ -563,39 +464,168 @@ macro_rules! impl_attr_getters_and_setters { } }; ($struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { - $crate::impl_attr_getters_and_setters!(without_decoder $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); - $crate::impl_attr_getters_and_setters!(decoder $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); + $crate::impl_attr_getters_and_setters!(without_deser $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); + $crate::impl_attr_getters_and_setters!(deser $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); }; } -pub fn parse_object( - buf: &[u8], - add_obj: u32, - del_obj: u32, -) -> Result<(NfNetlinkAttributes, nfgenmsg, &[u8]), DecodeError> { - debug!("parse_object() running"); - let (hdr, msg) = parse_nlmsg(buf)?; +pub trait Parsable +where + Self: Sized, +{ + fn parse_object( + buf: &[u8], + add_obj: u32, + del_obj: u32, + ) -> Result<(Self, nfgenmsg, &[u8]), DecodeError>; +} + +impl Parsable for T +where + T: AttributeDecoder + Default + Sized, +{ + fn parse_object( + buf: &[u8], + add_obj: u32, + del_obj: u32, + ) -> Result<(Self, nfgenmsg, &[u8]), DecodeError> { + debug!("parse_object() started"); + let (hdr, msg) = parse_nlmsg(buf)?; + + let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as u32; + + if op != add_obj && op != del_obj { + return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)); + } + + let obj_size = hdr.nlmsg_len as usize + - pad_netlink_object_with_variable_size(size_of::() + size_of::()); - let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as u32; + let remaining_data_offset = pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize); + let remaining_data = &buf[remaining_data_offset..]; - if op != add_obj && op != del_obj { - return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)); + let (nfgenmsg, res) = match msg { + NlMsg::NfGenMsg(nfgenmsg, content) => { + (nfgenmsg, read_attributes(&content[..obj_size])?) + } + _ => return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)), + }; + + Ok((res, nfgenmsg, remaining_data)) } +} + +#[macro_export] +macro_rules! impl_nfnetlinkattribute { + (__inner : $struct:ident, [$(($attr_name:expr, $internal_name:ident)),+]) => { + impl $struct { + fn inner_get_size(&self) -> usize { + use $crate::nlmsg::NfNetlinkAttribute; + use $crate::parser::{pad_netlink_object, pad_netlink_object_with_variable_size}; + use $crate::sys::nlattr; + let mut size = 0; + + $( + if let Some(val) = &self.$internal_name { + // Attribute header + attribute value + size += pad_netlink_object::() + + pad_netlink_object_with_variable_size(val.get_size()); + } + )+ + + size + } + + unsafe fn inner_write_payload(&self, mut addr: *mut u8) { + use $crate::nlmsg::NfNetlinkAttribute; + use $crate::parser::{pad_netlink_object, pad_netlink_object_with_variable_size}; + use $crate::sys::nlattr; + $( + if let Some(val) = &self.$internal_name { + debug!("writing attribute {} - {:?}", $attr_name, val); + + unsafe { + $crate::parser::write_attribute($attr_name, val, addr); + } + let size = pad_netlink_object::() + + pad_netlink_object_with_variable_size(val.get_size()); + #[allow(unused)] + { + addr = addr.offset(size as isize); + } + } + )+ + } + } + }; + (inline : $struct:ident, [$(($attr_name:expr, $internal_name:ident)),+]) => { + $crate::impl_nfnetlinkattribute!(__inner : $struct, [$(($attr_name, $internal_name)),+]); + + impl $crate::nlmsg::NfNetlinkAttribute for $struct { + fn get_size(&self) -> usize { + self.inner_get_size() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + self.inner_write_payload(addr); + } + } + }; + (nested : $struct:ident, [$(($attr_name:expr, $internal_name:ident)),+]) => { + $crate::impl_nfnetlinkattribute!(__inner : $struct, [$(($attr_name, $internal_name)),+]); - let obj_size = hdr.nlmsg_len as usize - - pad_netlink_object_with_variable_size(size_of::() + size_of::()); + impl $crate::nlmsg::NfNetlinkAttribute for $struct { + fn is_nested(&self) -> bool { + true + } - let remaining_data_offset = pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize); - let remaining_data = &buf[remaining_data_offset..]; + fn get_size(&self) -> usize { + self.inner_get_size() + } - let (nfgenmsg, attrs) = match msg { - NlMsg::NfGenMsg(nfgenmsg, content) => { - (nfgenmsg, NfNetlinkAttributeReader::new(content, obj_size)?) + unsafe fn write_payload(&self, addr: *mut u8) { + self.inner_write_payload(addr); + } } - _ => return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)), }; +} - let inner = attrs.decode::()?; +#[macro_export] +macro_rules! create_wrapper_type { + (without_deser : $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { + #[derive(Clone, PartialEq, Eq, Default)] + pub struct $struct { + $( + $internal_name: Option<$type> + ),+ + } - Ok((inner, nfgenmsg, remaining_data)) + $crate::impl_attr_getters_and_setters!(without_deser $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); + + impl std::fmt::Debug for $struct { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use $crate::parser::InnerFormat; + self.inner_format_struct(f.debug_struct(stringify!($struct)))? + .finish() + } + } + + impl $crate::nlmsg::NfNetlinkDeserializable for $struct { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), $crate::parser::DecodeError> { + Ok(($crate::parser::read_attributes(buf)?, &[])) + } + } + }; + ($struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { + create_wrapper_type!(without_deser : $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); + $crate::impl_attr_getters_and_setters!(deser $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); + }; + (inline $($($attrs:ident) +)? : $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { + create_wrapper_type!($($($attrs) + :)? $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); + $crate::impl_nfnetlinkattribute!(inline : $struct, [$(($attr_name, $internal_name)),+]); + }; + (nested $($($attrs:ident) +)? : $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { + create_wrapper_type!($($($attrs) + :)? $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); + $crate::impl_nfnetlinkattribute!(nested : $struct, [$(($attr_name, $internal_name)),+]); + }; } diff --git a/src/query.rs b/src/query.rs index 8ea7b89..294cbfe 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,7 +1,7 @@ use std::os::unix::prelude::RawFd; use crate::{ - nlmsg::{NfNetlinkAttributes, NfNetlinkObject, NfNetlinkWriter}, + nlmsg::{NfNetlinkAttribute, NfNetlinkObject, NfNetlinkWriter}, parser::{nft_nlmsg_maxsize, pad_netlink_object_with_variable_size}, sys::{nlmsgerr, NLM_F_DUMP, NLM_F_MULTI}, ProtocolFamily, @@ -152,10 +152,10 @@ where /// Returns a buffer containing a netlink message which requests a list of all the netfilter /// matching objects (e.g. tables, chains, rules, ...). /// Supply the type of objects to retrieve (e.g. libc::NFT_MSG_GETTABLE), and a search filter. -pub fn get_list_of_objects( +pub fn get_list_of_objects( msg_type: u16, seq: u32, - filter: Option<&NfNetlinkAttributes>, + filter: Option<&T>, ) -> Result, Error> { let mut buffer = Vec::new(); let mut writer = NfNetlinkWriter::new(&mut buffer); @@ -167,7 +167,10 @@ pub fn get_list_of_objects( None, ); if let Some(filter) = filter { - filter.serialize(&mut writer); + let buf = writer.add_data_zeroed(filter.get_size()); + unsafe { + filter.write_payload(buf.as_mut_ptr()); + } } writer.finalize_writing_object(); Ok(buffer) @@ -180,11 +183,11 @@ pub fn get_list_of_objects( pub fn list_objects_with_data<'a, Object, Accumulator>( data_type: u16, cb: &dyn Fn(Object, &mut Accumulator) -> Result<(), Error>, - filter: Option<&NfNetlinkAttributes>, + filter: Option<&Object>, working_data: &'a mut Accumulator, ) -> Result<(), Error> where - Object: NfNetlinkObject, + Object: NfNetlinkObject + NfNetlinkAttribute, { debug!("Listing objects of kind {}", data_type); let sock = socket::socket( diff --git a/src/rule.rs b/src/rule.rs index a596fce..5f2889e 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -1,20 +1,23 @@ use crate::expr::ExpressionList; -use crate::nlmsg::{ - NfNetlinkAttributes, NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkWriter, -}; -use crate::parser::InnerFormat; -use crate::parser::{parse_object, DecodeError}; +use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkWriter}; +use crate::parser::{DecodeError, InnerFormat, Parsable}; use crate::query::list_objects_with_data; use crate::sys::{self, NFT_MSG_DELRULE, NFT_MSG_NEWRULE, NLM_F_ACK, NLM_F_CREATE}; use crate::{chain::Chain, MsgType}; -use crate::{impl_attr_getters_and_setters, ProtocolFamily}; +use crate::{impl_attr_getters_and_setters, impl_nfnetlinkattribute, ProtocolFamily}; use std::convert::TryFrom; use std::fmt::Debug; /// A nftables firewall rule. -#[derive(PartialEq, Eq)] +#[derive(Clone, PartialEq, Eq, Default)] pub struct Rule { - inner: NfNetlinkAttributes, + id: Option, + handle: Option, + position: Option, + table: Option, + chain: Option, + userdata: Option>, + expressions: Option, family: ProtocolFamily, } @@ -23,21 +26,18 @@ impl Rule { /// /// [`Chain`]: struct.Chain.html pub fn new(chain: &Chain) -> Result { - let inner = NfNetlinkAttributes::new(); - Ok(Rule { - inner, - family: chain.get_family(), - } - .with_table( - chain - .get_table() - .ok_or(DecodeError::MissingChainInformationError)?, - ) - .with_chain( - chain - .get_name() - .ok_or(DecodeError::MissingChainInformationError)?, - )) + Ok(Rule::default() + .with_family(chain.get_family()) + .with_table( + chain + .get_table() + .ok_or(DecodeError::MissingChainInformationError)?, + ) + .with_chain( + chain + .get_name() + .ok_or(DecodeError::MissingChainInformationError)?, + )) } pub fn get_family(&self) -> ProtocolFamily { @@ -53,10 +53,6 @@ impl Rule { self } - fn raw_attributes(&self) -> &NfNetlinkAttributes { - &self.inner - } - /* /// Adds an expression to this rule. Expressions are evaluated from first to last added. /// As soon as an expression does not match the packet it's being evaluated for, evaluation @@ -191,143 +187,58 @@ impl NfNetlinkObject for Rule { seq, None, ); - self.inner.serialize(writer); + let buf = writer.add_data_zeroed(self.get_size()); + unsafe { + self.write_payload(buf.as_mut_ptr()); + } writer.finalize_writing_object(); } } impl NfNetlinkDeserializable for Rule { fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - let (inner, nfgenmsg, remaining_data) = - parse_object::(buf, NFT_MSG_NEWRULE, NFT_MSG_DELRULE)?; + let (mut obj, nfgenmsg, remaining_data) = + Self::parse_object(buf, NFT_MSG_NEWRULE, NFT_MSG_DELRULE)?; + obj.family = ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?; - Ok(( - Self { - inner, - family: ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?, - }, - remaining_data, - )) + Ok((obj, remaining_data)) } } impl_attr_getters_and_setters!( Rule, [ - (get_id, set_id, with_id, sys::NFTA_RULE_ID, U32, u32), - (get_handle, set_handle, with_handle, sys::NFTA_RULE_HANDLE, U64, u64), + (get_table, set_table, with_table, sys::NFTA_RULE_TABLE, table, String), + (get_chain, set_chain, with_chain, sys::NFTA_RULE_CHAIN, chain, String), + (get_handle, set_handle, with_handle, sys::NFTA_RULE_HANDLE, handle, u64), + (get_expressions, set_expressions, with_expressions, sys::NFTA_RULE_EXPRESSIONS, expressions, ExpressionList), // Sets the position of this rule within the chain it lives in. By default a new rule is added // to the end of the chain. - (get_position, set_position, with_position, sys::NFTA_RULE_POSITION, U64, u64), - (get_table, set_table, with_table, sys::NFTA_RULE_TABLE, String, String), - (get_chain, set_chain, with_chain, sys::NFTA_RULE_CHAIN, String, String), + (get_position, set_position, with_position, sys::NFTA_RULE_POSITION, position, u64), ( get_userdata, set_userdata, with_userdata, sys::NFTA_RULE_USERDATA, - VecU8, + userdata, Vec ), - (get_expressions, set_expressions, with_expressions, sys::NFTA_RULE_EXPRESSIONS, ExpressionList, ExpressionList) + (get_id, set_id, with_id, sys::NFTA_RULE_ID, id, u32) ] ); -/* - -impl PartialEq for Rule { - fn eq(&self, other: &Self) -> bool { - if self.get_chain() != other.get_chain() { - return false; - } - - unsafe { - if sys::nftnl_rule_is_set(self.rule, sys::NFTNL_RULE_HANDLE as u16) - && sys::nftnl_rule_is_set(other.rule, sys::NFTNL_RULE_HANDLE as u16) - { - if self.get_handle() != other.get_handle() { - return false; - } - } - if sys::nftnl_rule_is_set(self.rule, sys::NFTNL_RULE_POSITION as u16) - && sys::nftnl_rule_is_set(other.rule, sys::NFTNL_RULE_POSITION as u16) - { - if self.get_position() != other.get_position() { - return false; - } - } - } - - return false; - } -} - -unsafe impl NlMsg for Rule { - unsafe fn write(&self, buf: &mut Vec, seq: u32, msg_type: MsgType) { - let type_ = match msg_type { - MsgType::Add => libc::NFT_MSG_NEWRULE, - MsgType::Del => libc::NFT_MSG_DELRULE, - }; - let flags: u16 = match msg_type { - MsgType::Add => (libc::NLM_F_CREATE | libc::NLM_F_APPEND | libc::NLM_F_EXCL) as u16, - MsgType::Del => 0u16, - } | libc::NLM_F_ACK as u16; - /* - let header = sys::nftnl_nlmsg_build_hdr( - buf as *mut c_char, - type_ as u16, - self.chain.get_table().get_family() as u16, - flags, - seq, - ); - sys::nftnl_rule_nlmsg_build_payload(header, self.rule); - */ - } -} - -impl Drop for Rule { - fn drop(&mut self) { - unsafe { sys::nftnl_rule_free(self.rule) }; - } -} - -pub struct RuleExprsIter { - rule: Rc, - iter: *mut sys::nftnl_expr_iter, -} - -impl RuleExprsIter { - fn new(rule: Rc) -> Self { - let iter = - try_alloc!(unsafe { sys::nftnl_expr_iter_create(rule.rule as *const sys::nftnl_rule) }); - RuleExprsIter { rule, iter } - } -} - -impl Iterator for RuleExprsIter { - type Item = ExpressionWrapper; - - fn next(&mut self) -> Option { - let next = unsafe { sys::nftnl_expr_iter_next(self.iter) }; - if next.is_null() { - trace!("RulesExprsIter iterator ending"); - None - } else { - trace!("RulesExprsIter returning new expression"); - Some(ExpressionWrapper { - expr: next, - rule: self.rule.clone(), - }) - } - } -} - -impl Drop for RuleExprsIter { - fn drop(&mut self) { - unsafe { sys::nftnl_expr_iter_destroy(self.iter) }; - } -} -*/ +impl_nfnetlinkattribute!(inline : Rule, [ + (sys::NFTA_RULE_TABLE, table), + (sys::NFTA_RULE_CHAIN, chain), + (sys::NFTA_RULE_HANDLE, handle), + (sys::NFTA_RULE_EXPRESSIONS, expressions), + (sys::NFTA_RULE_POSITION, position), + ( + sys::NFTA_RULE_USERDATA, + userdata + ), + (sys::NFTA_RULE_ID, id) +]); pub fn list_rules_for_chain(chain: &Chain) -> Result, crate::query::Error> { let mut result = Vec::new(); @@ -338,7 +249,7 @@ pub fn list_rules_for_chain(chain: &Chain) -> Result, crate::query::Er Ok(()) }, // only retrieve rules from the currently targetted chain - Some(&Rule::new(chain)?.raw_attributes()), + Some(&Rule::new(chain)?), &mut result, )?; Ok(result) diff --git a/src/table.rs b/src/table.rs index 5074ac9..96a4964 100644 --- a/src/table.rs +++ b/src/table.rs @@ -1,62 +1,33 @@ use std::convert::TryFrom; use std::fmt::Debug; -use crate::nlmsg::{ - NfNetlinkAttributes, NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkWriter, -}; -use crate::parser::{parse_object, DecodeError, InnerFormat}; +use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkWriter}; +use crate::parser::Parsable; +use crate::parser::{DecodeError, InnerFormat}; use crate::sys::{ - self, NFNL_SUBSYS_NFTABLES, NFTA_OBJ_TABLE, NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, - NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, NFT_MSG_NEWTABLE, NLM_F_ACK, NLM_F_CREATE, + self, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, NFT_MSG_NEWTABLE, NLM_F_ACK, NLM_F_CREATE, }; -use crate::{impl_attr_getters_and_setters, MsgType, ProtocolFamily}; +use crate::{impl_attr_getters_and_setters, impl_nfnetlinkattribute, MsgType, ProtocolFamily}; /// Abstraction of a `nftnl_table`, the top level container in netfilter. A table has a protocol /// family and contains [`Chain`]s that in turn hold the rules. /// /// [`Chain`]: struct.Chain.html -#[derive(PartialEq, Eq)] +#[derive(Default, PartialEq, Eq)] pub struct Table { - inner: NfNetlinkAttributes, - family: ProtocolFamily, + flags: Option, + name: Option, + userdata: Option>, + pub family: ProtocolFamily, } impl Table { pub fn new(family: ProtocolFamily) -> Table { - Table { - inner: NfNetlinkAttributes::new(), - family, - } - } - - pub fn get_family(&self) -> ProtocolFamily { - self.family - } - - /* - /// Returns a textual description of the table. - pub fn get_str(&self) -> CString { - let mut descr_buf = vec![0i8; 4096]; - unsafe { - sys::nftnl_table_snprintf( - descr_buf.as_mut_ptr() as *mut c_char, - (descr_buf.len() - 1) as u64, - self.table, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); - CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned() - } + let mut res = Self::default(); + res.family = family; + res } - */ } -/* -impl PartialEq for Table { - fn eq(&self, other: &Self) -> bool { - self.get_name() == other.get_name() && self.family == other.family - } -} -*/ impl Debug for Table { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -83,42 +54,49 @@ impl NfNetlinkObject for Table { seq, None, ); - self.inner.serialize(writer); + let buf = writer.add_data_zeroed(self.get_size()); + unsafe { + self.write_payload(buf.as_mut_ptr()); + } writer.finalize_writing_object(); } } impl NfNetlinkDeserializable for Table { fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - let (inner, nfgenmsg, remaining_data) = - parse_object::(buf, NFT_MSG_NEWTABLE, NFT_MSG_DELTABLE)?; + let (mut obj, nfgenmsg, remaining_data) = + Self::parse_object(buf, NFT_MSG_NEWTABLE, NFT_MSG_DELTABLE)?; + obj.family = ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?; - Ok(( - Self { - inner, - family: ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?, - }, - remaining_data, - )) + Ok((obj, remaining_data)) } } impl_attr_getters_and_setters!( Table, [ - (get_flags, set_flags, with_flags, sys::NFTA_TABLE_FLAGS, U32, u32), - (get_name, set_name, with_name, sys::NFTA_TABLE_NAME, String, String), + (get_name, set_name, with_name, sys::NFTA_TABLE_NAME, name, String), + (get_flags, set_flags, with_flags, sys::NFTA_TABLE_FLAGS, flags, u32), ( get_userdata, set_userdata, with_userdata, sys::NFTA_TABLE_USERDATA, - VecU8, + userdata, Vec ) ] ); +impl_nfnetlinkattribute!( + inline : Table, + [ + (sys::NFTA_TABLE_NAME, name), + (sys::NFTA_TABLE_FLAGS, flags), + (sys::NFTA_TABLE_USERDATA, userdata) + ] +); + pub fn list_tables() -> Result, crate::query::Error> { let mut result = Vec::new(); crate::query::list_objects_with_data( diff --git a/tests/batch.rs b/tests/batch.rs index 5b3380c..5a766b0 100644 --- a/tests/batch.rs +++ b/tests/batch.rs @@ -1,27 +1,15 @@ -mod sys; use std::mem::size_of; -use libc::AF_NETLINK; -use libc::AF_UNIX; -use libc::AF_UNSPEC; -use libc::NFNL_MSG_BATCH_BEGIN; -use libc::NLM_F_ACK; -use libc::NLM_F_REQUEST; +use libc::{AF_UNSPEC, NFNL_MSG_BATCH_BEGIN, NLM_F_REQUEST}; use nix::libc::NFNL_MSG_BATCH_END; + use rustables::nlmsg::NfNetlinkDeserializable; -use rustables::nlmsg::NfNetlinkObject; -use rustables::parser::pad_netlink_object; -use rustables::parser::pad_netlink_object_with_variable_size; -use rustables::parser::NlMsg; -use rustables::parser::{get_operation_from_nlmsghdr_type, parse_nlmsg, parse_object}; -use rustables::sys::nfgenmsg; -use rustables::sys::nlmsghdr; -use rustables::sys::NFNETLINK_V0; -use rustables::sys::NFNL_SUBSYS_NFTABLES; +use rustables::parser::{pad_netlink_object_with_variable_size, parse_nlmsg, NlMsg}; +use rustables::sys::{nfgenmsg, nlmsghdr, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES}; use rustables::{Batch, MsgType, Table}; -mod lib; -use lib::*; +mod common; +use common::*; const HEADER_SIZE: u32 = pad_netlink_object_with_variable_size(size_of::() + size_of::()) as u32; diff --git a/tests/chain.rs b/tests/chain.rs index 09594f1..99347da 100644 --- a/tests/chain.rs +++ b/tests/chain.rs @@ -1,9 +1,14 @@ -mod sys; -use rustables::{parser::get_operation_from_nlmsghdr_type, ChainType, Hook, HookClass, MsgType}; -use sys::*; +use rustables::{ + parser::get_operation_from_nlmsghdr_type, + sys::{ + NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_TABLE, NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, + NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN, + }, + ChainType, Hook, HookClass, MsgType, +}; -mod lib; -use lib::*; +mod common; +use common::*; #[test] fn new_empty_chain() { diff --git a/tests/common.rs b/tests/common.rs new file mode 100644 index 0000000..99b5a6a --- /dev/null +++ b/tests/common.rs @@ -0,0 +1,199 @@ +#![allow(dead_code)] +use std::ffi::CString; + +use rustables::nlmsg::{NfNetlinkObject, NfNetlinkWriter}; +//use rustables::set::SetKey; +use rustables::{sys::*, Chain, MsgType, ProtocolFamily, Rule, Table}; + +//use rustables::{nft_nlmsg_maxsize, Chain, MsgType, NlMsg, Rule, Set, Table}; + +pub const TABLE_NAME: &'static str = "mocktable"; +pub const CHAIN_NAME: &'static str = "mockchain"; +pub const SET_NAME: &'static str = "mockset"; + +pub const TABLE_USERDATA: &'static str = "mocktabledata"; +pub const CHAIN_USERDATA: &'static str = "mockchaindata"; +pub const RULE_USERDATA: &'static str = "mockruledata"; +pub const SET_USERDATA: &'static str = "mocksetdata"; + +pub const SET_ID: u32 = 123456; + +type NetLinkType = u16; + +#[derive(Debug, thiserror::Error)] +#[error("empty data")] +pub struct EmptyDataError; + +#[derive(Debug, Clone, Eq, Ord)] +pub enum NetlinkExpr { + Nested(NetLinkType, Vec), + Final(NetLinkType, Vec), + List(Vec), +} + +impl NetlinkExpr { + pub fn to_raw(self) -> Vec { + match self.sort() { + NetlinkExpr::Final(ty, val) => { + let len = val.len() + 4; + let mut res = Vec::with_capacity(len); + + res.extend(&(len as u16).to_le_bytes()); + res.extend(&ty.to_le_bytes()); + res.extend(val); + // alignment + while res.len() % 4 != 0 { + res.push(0); + } + + res + } + NetlinkExpr::Nested(ty, exprs) => { + // some heuristic to decrease allocations (even though this is + // only useful for testing so performance is not an objective) + let mut sub = Vec::with_capacity(exprs.len() * 50); + + for expr in exprs { + sub.append(&mut expr.to_raw()); + } + + let len = sub.len() + 4; + let mut res = Vec::with_capacity(len); + + // set the "NESTED" flag + res.extend(&(len as u16).to_le_bytes()); + res.extend(&(ty | NLA_F_NESTED as u16).to_le_bytes()); + res.extend(sub); + + res + } + NetlinkExpr::List(exprs) => { + // some heuristic to decrease allocations (even though this is + // only useful for testing so performance is not an objective) + let mut list = Vec::with_capacity(exprs.len() * 50); + + for expr in exprs { + list.append(&mut expr.to_raw()); + } + + list + } + } + } + + pub fn sort(self) -> Self { + match self { + NetlinkExpr::Final(_, _) => self, + NetlinkExpr::Nested(ty, mut exprs) => { + exprs.sort(); + NetlinkExpr::Nested(ty, exprs) + } + NetlinkExpr::List(mut exprs) => { + exprs.sort(); + NetlinkExpr::List(exprs) + } + } + } +} + +impl PartialEq for NetlinkExpr { + fn eq(&self, other: &Self) -> bool { + match (self.clone().sort(), other.clone().sort()) { + (NetlinkExpr::Nested(k1, v1), NetlinkExpr::Nested(k2, v2)) => k1 == k2 && v1 == v2, + (NetlinkExpr::Final(k1, v1), NetlinkExpr::Final(k2, v2)) => k1 == k2 && v1 == v2, + (NetlinkExpr::List(v1), NetlinkExpr::List(v2)) => v1 == v2, + _ => false, + } + } +} + +impl PartialOrd for NetlinkExpr { + fn partial_cmp(&self, other: &Self) -> Option { + match (self, other) { + ( + NetlinkExpr::Nested(k1, _) | NetlinkExpr::Final(k1, _), + NetlinkExpr::Nested(k2, _) | NetlinkExpr::Final(k2, _), + ) => k1.partial_cmp(k2), + (NetlinkExpr::List(v1), NetlinkExpr::List(v2)) => v1.partial_cmp(v2), + (_, NetlinkExpr::List(_)) => Some(std::cmp::Ordering::Less), + (NetlinkExpr::List(_), _) => Some(std::cmp::Ordering::Greater), + } + } +} + +pub fn get_test_table() -> Table { + Table::new(ProtocolFamily::Inet) + .with_name(TABLE_NAME) + .with_flags(0u32) +} + +pub fn get_test_table_raw_expr() -> NetlinkExpr { + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_TABLE_FLAGS, 0u32.to_be_bytes().to_vec()), + NetlinkExpr::Final( + NFTA_TABLE_NAME, + CString::new(TABLE_NAME).unwrap().to_bytes().to_vec(), + ), + ]) + .sort() +} + +pub fn get_test_table_with_userdata_raw_expr() -> NetlinkExpr { + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_TABLE_FLAGS, 0u32.to_be_bytes().to_vec()), + NetlinkExpr::Final( + NFTA_TABLE_NAME, + CString::new(TABLE_NAME).unwrap().to_bytes().to_vec(), + ), + NetlinkExpr::Final( + NFTA_TABLE_USERDATA, + CString::new(TABLE_USERDATA).unwrap().to_bytes().to_vec(), + ), + ]) + .sort() +} + +pub fn get_test_chain() -> Chain { + Chain::new(&get_test_table()).with_name(CHAIN_NAME) +} + +pub fn get_test_rule() -> Rule { + Rule::new(&get_test_chain()).unwrap() +} + +/* +pub fn get_test_set() -> Set { + Set::new(SET_NAME, SET_ID, Rc::new(get_test_table())) +} +*/ + +pub fn get_test_nlmsg_with_msg_type<'a>( + buf: &'a mut Vec, + obj: &mut impl NfNetlinkObject, + msg_type: MsgType, +) -> (nlmsghdr, nfgenmsg, &'a [u8]) { + let mut writer = NfNetlinkWriter::new(buf); + obj.add_or_remove(&mut writer, msg_type, 0); + + let (hdr, msg) = + rustables::parser::parse_nlmsg(buf.as_slice()).expect("Couldn't parse the message"); + + let (nfgenmsg, raw_value) = match msg { + rustables::parser::NlMsg::NfGenMsg(nfgenmsg, raw_value) => (nfgenmsg, raw_value), + _ => panic!("Invalid return value type, expected a valid message"), + }; + + // sanity checks on the global message (this should be very similar/factorisable for the + // most part in other tests) + // TODO: check the messages flags + assert_eq!(nfgenmsg.res_id.to_be(), 0); + + (hdr, nfgenmsg, raw_value) +} + +pub fn get_test_nlmsg<'a>( + buf: &'a mut Vec, + obj: &mut impl NfNetlinkObject, +) -> (nlmsghdr, nfgenmsg, &'a [u8]) { + get_test_nlmsg_with_msg_type(buf, obj, MsgType::Add) +} diff --git a/tests/expr.rs b/tests/expr.rs index 46b50f0..c5ac8a2 100644 --- a/tests/expr.rs +++ b/tests/expr.rs @@ -1,4 +1,13 @@ -use rustables::expr::{Bitwise, ExpressionList, Immediate, Register, VerdictKind}; +use rustables::{ + expr::{Bitwise, ExpressionList, Immediate, Meta, MetaType, Register, VerdictKind}, + sys::{ + NFTA_BITWISE_DREG, NFTA_BITWISE_LEN, NFTA_BITWISE_MASK, NFTA_BITWISE_SREG, + NFTA_BITWISE_XOR, NFTA_DATA_VALUE, NFTA_DATA_VERDICT, NFTA_EXPR_DATA, NFTA_EXPR_NAME, + NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG, NFTA_LIST_ELEM, NFTA_META_DREG, NFTA_META_KEY, + NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, NFTA_RULE_TABLE, NFTA_VERDICT_CODE, + NFT_META_PROTOCOL, NFT_REG_1, NFT_REG_VERDICT, + }, +}; //use rustables::expr::{ // Bitwise, Cmp, CmpOp, Conntrack, Counter, Expression, HeaderField, IcmpCode, Immediate, Log, // LogGroup, LogPrefix, Lookup, Meta, Nat, NatType, Payload, Register, Reject, TcpHeaderField, @@ -11,12 +20,10 @@ use rustables::expr::{Bitwise, ExpressionList, Immediate, Register, VerdictKind} //use std::ffi::CStr; use std::net::Ipv4Addr; -mod sys; use libc::NF_DROP; -use sys::*; -mod lib; -use lib::*; +mod common; +use common::*; #[test] fn bitwise_expr_is_valid() { @@ -341,44 +348,48 @@ fn immediate_expr_is_valid() { // ); //} // -//#[test] -//fn meta_expr_is_valid() { -// let meta = Meta::Protocol; -// let mut rule = get_test_rule(); -// let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &meta); -// assert_eq!(nlmsghdr.nlmsg_len, 92); -// -// assert_eq!( -// raw_expr, -// NetlinkExpr::List(vec![ -// NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), -// NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), -// NetlinkExpr::Nested( -// NFTA_RULE_EXPRESSIONS, -// vec![NetlinkExpr::Nested( -// NFTA_LIST_ELEM, -// vec![ -// NetlinkExpr::Final(NFTA_EXPR_NAME, b"meta\0".to_vec()), -// NetlinkExpr::Nested( -// NFTA_EXPR_DATA, -// vec![ -// NetlinkExpr::Final( -// NFTA_META_KEY, -// NFT_META_PROTOCOL.to_be_bytes().to_vec() -// ), -// NetlinkExpr::Final( -// NFTA_META_DREG, -// NFT_REG_1.to_be_bytes().to_vec() -// ) -// ] -// ) -// ] -// )] -// ) -// ]) -// .to_raw() -// ); -//} +#[test] +fn meta_expr_is_valid() { + let meta = Meta::default() + .with_key(MetaType::Protocol) + .with_dreg(Register::Reg1); + let mut rule = get_test_rule().with_expressions(vec![meta]); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 88); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"meta".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_META_KEY, + NFT_META_PROTOCOL.to_be_bytes().to_vec() + ), + NetlinkExpr::Final( + NFTA_META_DREG, + NFT_REG_1.to_be_bytes().to_vec() + ) + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} // //#[test] //fn nat_expr_is_valid() { diff --git a/tests/lib.rs b/tests/lib.rs deleted file mode 100644 index 99b5a6a..0000000 --- a/tests/lib.rs +++ /dev/null @@ -1,199 +0,0 @@ -#![allow(dead_code)] -use std::ffi::CString; - -use rustables::nlmsg::{NfNetlinkObject, NfNetlinkWriter}; -//use rustables::set::SetKey; -use rustables::{sys::*, Chain, MsgType, ProtocolFamily, Rule, Table}; - -//use rustables::{nft_nlmsg_maxsize, Chain, MsgType, NlMsg, Rule, Set, Table}; - -pub const TABLE_NAME: &'static str = "mocktable"; -pub const CHAIN_NAME: &'static str = "mockchain"; -pub const SET_NAME: &'static str = "mockset"; - -pub const TABLE_USERDATA: &'static str = "mocktabledata"; -pub const CHAIN_USERDATA: &'static str = "mockchaindata"; -pub const RULE_USERDATA: &'static str = "mockruledata"; -pub const SET_USERDATA: &'static str = "mocksetdata"; - -pub const SET_ID: u32 = 123456; - -type NetLinkType = u16; - -#[derive(Debug, thiserror::Error)] -#[error("empty data")] -pub struct EmptyDataError; - -#[derive(Debug, Clone, Eq, Ord)] -pub enum NetlinkExpr { - Nested(NetLinkType, Vec), - Final(NetLinkType, Vec), - List(Vec), -} - -impl NetlinkExpr { - pub fn to_raw(self) -> Vec { - match self.sort() { - NetlinkExpr::Final(ty, val) => { - let len = val.len() + 4; - let mut res = Vec::with_capacity(len); - - res.extend(&(len as u16).to_le_bytes()); - res.extend(&ty.to_le_bytes()); - res.extend(val); - // alignment - while res.len() % 4 != 0 { - res.push(0); - } - - res - } - NetlinkExpr::Nested(ty, exprs) => { - // some heuristic to decrease allocations (even though this is - // only useful for testing so performance is not an objective) - let mut sub = Vec::with_capacity(exprs.len() * 50); - - for expr in exprs { - sub.append(&mut expr.to_raw()); - } - - let len = sub.len() + 4; - let mut res = Vec::with_capacity(len); - - // set the "NESTED" flag - res.extend(&(len as u16).to_le_bytes()); - res.extend(&(ty | NLA_F_NESTED as u16).to_le_bytes()); - res.extend(sub); - - res - } - NetlinkExpr::List(exprs) => { - // some heuristic to decrease allocations (even though this is - // only useful for testing so performance is not an objective) - let mut list = Vec::with_capacity(exprs.len() * 50); - - for expr in exprs { - list.append(&mut expr.to_raw()); - } - - list - } - } - } - - pub fn sort(self) -> Self { - match self { - NetlinkExpr::Final(_, _) => self, - NetlinkExpr::Nested(ty, mut exprs) => { - exprs.sort(); - NetlinkExpr::Nested(ty, exprs) - } - NetlinkExpr::List(mut exprs) => { - exprs.sort(); - NetlinkExpr::List(exprs) - } - } - } -} - -impl PartialEq for NetlinkExpr { - fn eq(&self, other: &Self) -> bool { - match (self.clone().sort(), other.clone().sort()) { - (NetlinkExpr::Nested(k1, v1), NetlinkExpr::Nested(k2, v2)) => k1 == k2 && v1 == v2, - (NetlinkExpr::Final(k1, v1), NetlinkExpr::Final(k2, v2)) => k1 == k2 && v1 == v2, - (NetlinkExpr::List(v1), NetlinkExpr::List(v2)) => v1 == v2, - _ => false, - } - } -} - -impl PartialOrd for NetlinkExpr { - fn partial_cmp(&self, other: &Self) -> Option { - match (self, other) { - ( - NetlinkExpr::Nested(k1, _) | NetlinkExpr::Final(k1, _), - NetlinkExpr::Nested(k2, _) | NetlinkExpr::Final(k2, _), - ) => k1.partial_cmp(k2), - (NetlinkExpr::List(v1), NetlinkExpr::List(v2)) => v1.partial_cmp(v2), - (_, NetlinkExpr::List(_)) => Some(std::cmp::Ordering::Less), - (NetlinkExpr::List(_), _) => Some(std::cmp::Ordering::Greater), - } - } -} - -pub fn get_test_table() -> Table { - Table::new(ProtocolFamily::Inet) - .with_name(TABLE_NAME) - .with_flags(0u32) -} - -pub fn get_test_table_raw_expr() -> NetlinkExpr { - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_TABLE_FLAGS, 0u32.to_be_bytes().to_vec()), - NetlinkExpr::Final( - NFTA_TABLE_NAME, - CString::new(TABLE_NAME).unwrap().to_bytes().to_vec(), - ), - ]) - .sort() -} - -pub fn get_test_table_with_userdata_raw_expr() -> NetlinkExpr { - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_TABLE_FLAGS, 0u32.to_be_bytes().to_vec()), - NetlinkExpr::Final( - NFTA_TABLE_NAME, - CString::new(TABLE_NAME).unwrap().to_bytes().to_vec(), - ), - NetlinkExpr::Final( - NFTA_TABLE_USERDATA, - CString::new(TABLE_USERDATA).unwrap().to_bytes().to_vec(), - ), - ]) - .sort() -} - -pub fn get_test_chain() -> Chain { - Chain::new(&get_test_table()).with_name(CHAIN_NAME) -} - -pub fn get_test_rule() -> Rule { - Rule::new(&get_test_chain()).unwrap() -} - -/* -pub fn get_test_set() -> Set { - Set::new(SET_NAME, SET_ID, Rc::new(get_test_table())) -} -*/ - -pub fn get_test_nlmsg_with_msg_type<'a>( - buf: &'a mut Vec, - obj: &mut impl NfNetlinkObject, - msg_type: MsgType, -) -> (nlmsghdr, nfgenmsg, &'a [u8]) { - let mut writer = NfNetlinkWriter::new(buf); - obj.add_or_remove(&mut writer, msg_type, 0); - - let (hdr, msg) = - rustables::parser::parse_nlmsg(buf.as_slice()).expect("Couldn't parse the message"); - - let (nfgenmsg, raw_value) = match msg { - rustables::parser::NlMsg::NfGenMsg(nfgenmsg, raw_value) => (nfgenmsg, raw_value), - _ => panic!("Invalid return value type, expected a valid message"), - }; - - // sanity checks on the global message (this should be very similar/factorisable for the - // most part in other tests) - // TODO: check the messages flags - assert_eq!(nfgenmsg.res_id.to_be(), 0); - - (hdr, nfgenmsg, raw_value) -} - -pub fn get_test_nlmsg<'a>( - buf: &'a mut Vec, - obj: &mut impl NfNetlinkObject, -) -> (nlmsghdr, nfgenmsg, &'a [u8]) { - get_test_nlmsg_with_msg_type(buf, obj, MsgType::Add) -} diff --git a/tests/rule.rs b/tests/rule.rs index f7e23a0..de5be3c 100644 --- a/tests/rule.rs +++ b/tests/rule.rs @@ -1,9 +1,14 @@ -mod sys; -use rustables::parser::get_operation_from_nlmsghdr_type; -use sys::*; +use rustables::{ + parser::get_operation_from_nlmsghdr_type, + sys::{ + NFTA_RULE_CHAIN, NFTA_RULE_HANDLE, NFTA_RULE_POSITION, NFTA_RULE_TABLE, NFTA_RULE_USERDATA, + NFT_MSG_DELRULE, NFT_MSG_NEWRULE, + }, + MsgType, +}; -mod lib; -use lib::*; +mod common; +use common::*; #[test] fn new_empty_rule() { @@ -27,93 +32,99 @@ fn new_empty_rule() { ); } -//#[test] -//fn new_empty_rule_with_userdata() { -// let mut rule = get_test_rule(); -// rule.set_userdata(CStr::from_bytes_with_nul(RULE_USERDATA).unwrap()); -// let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut rule); -// assert_eq!( -// get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), -// NFT_MSG_NEWRULE as u8 -// ); -// assert_eq!(nlmsghdr.nlmsg_len, 72); -// -// assert_eq!( -// raw_expr, -// NetlinkExpr::List(vec![ -// NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), -// NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), -// NetlinkExpr::Final(NFTA_RULE_USERDATA, RULE_USERDATA.to_vec()) -// ]) -// .to_raw() -// ); -//} -// -//#[test] -//fn new_empty_rule_with_position_and_handle() { -// let handle = 1337; -// let position = 42; -// let mut rule = get_test_rule(); -// rule.set_handle(handle); -// rule.set_position(position); -// let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut rule); -// assert_eq!( -// get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), -// NFT_MSG_NEWRULE as u8 -// ); -// assert_eq!(nlmsghdr.nlmsg_len, 76); -// -// assert_eq!( -// raw_expr, -// NetlinkExpr::List(vec![ -// NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), -// NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), -// NetlinkExpr::Final(NFTA_RULE_HANDLE, handle.to_be_bytes().to_vec()), -// NetlinkExpr::Final(NFTA_RULE_POSITION, position.to_be_bytes().to_vec()), -// ]) -// .to_raw() -// ); -//} -// -//#[test] -//fn delete_empty_rule() { -// let mut rule = get_test_rule(); -// let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_with_msg_type(&mut rule, MsgType::Del); -// assert_eq!( -// get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), -// NFT_MSG_DELRULE as u8 -// ); -// assert_eq!(nlmsghdr.nlmsg_len, 52); -// -// assert_eq!( -// raw_expr, -// NetlinkExpr::List(vec![ -// NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), -// NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), -// ]) -// .to_raw() -// ); -//} -// -//#[test] -//fn delete_empty_rule_with_handle() { -// let handle = 42; -// let mut rule = get_test_rule(); -// rule.set_handle(handle); -// let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_with_msg_type(&mut rule, MsgType::Del); -// assert_eq!( -// get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), -// NFT_MSG_DELRULE as u8 -// ); -// assert_eq!(nlmsghdr.nlmsg_len, 64); -// -// assert_eq!( -// raw_expr, -// NetlinkExpr::List(vec![ -// NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), -// NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), -// NetlinkExpr::Final(NFTA_RULE_HANDLE, handle.to_be_bytes().to_vec()), -// ]) -// .to_raw() -// ); -//} +#[test] +fn new_empty_rule_with_userdata() { + let mut rule = get_test_rule().with_userdata(RULE_USERDATA); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWRULE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 68); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_USERDATA, RULE_USERDATA.as_bytes().to_vec()) + ]) + .to_raw() + ); +} + +#[test] +fn new_empty_rule_with_position_and_handle() { + let handle: u64 = 1337; + let position: u64 = 42; + let mut rule = get_test_rule().with_handle(handle).with_position(position); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWRULE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 76); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_HANDLE, handle.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_POSITION, position.to_be_bytes().to_vec()), + ]) + .to_raw() + ); +} + +#[test] +fn delete_empty_rule() { + let mut rule = get_test_rule(); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = + get_test_nlmsg_with_msg_type(&mut buf, &mut rule, MsgType::Del); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_DELRULE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 52); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + ]) + .to_raw() + ); +} + +#[test] +fn delete_empty_rule_with_handle() { + let handle: u64 = 42; + let mut rule = get_test_rule().with_handle(handle); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = + get_test_nlmsg_with_msg_type(&mut buf, &mut rule, MsgType::Del); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_DELRULE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 64); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_HANDLE, handle.to_be_bytes().to_vec()), + ]) + .to_raw() + ); +} diff --git a/tests/table.rs b/tests/table.rs index 5961d65..44394c9 100644 --- a/tests/table.rs +++ b/tests/table.rs @@ -1,13 +1,12 @@ -mod sys; use rustables::{ nlmsg::NfNetlinkDeserializable, parser::{get_operation_from_nlmsghdr_type, nft_nlmsg_maxsize}, + sys::{NFT_MSG_DELTABLE, NFT_MSG_NEWTABLE}, MsgType, Table, }; -use sys::*; -mod lib; -use lib::*; +mod common; +use common::*; #[test] fn new_empty_table() { -- cgit v1.2.3 From 4b60b3cd41f5198c47a260ce69abf4c15b60ca92 Mon Sep 17 00:00:00 2001 From: Simon THOBY Date: Sat, 3 Dec 2022 22:47:53 +0100 Subject: convert the expressions to the new macros --- src/chain.rs | 4 +- src/expr/bitwise.rs | 59 +++++----------- src/expr/immediate.rs | 33 +++------ src/expr/log.rs | 36 ++++------ src/expr/meta.rs | 45 ++++-------- src/expr/mod.rs | 64 ++++++----------- src/expr/reject.rs | 34 +++------ src/expr/verdict.rs | 46 ++++-------- src/parser.rs | 190 -------------------------------------------------- 9 files changed, 103 insertions(+), 408 deletions(-) (limited to 'src/expr/log.rs') diff --git a/src/chain.rs b/src/chain.rs index 8bdf95b..7a62fb2 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -4,11 +4,11 @@ use rustables_macros::nfnetlink_struct; use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkWriter}; use crate::parser::{DecodeError, Parsable}; use crate::sys::{ - self, NFTA_CHAIN_FLAGS, NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_POLICY, NFTA_CHAIN_TABLE, + NFTA_CHAIN_FLAGS, NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_POLICY, NFTA_CHAIN_TABLE, NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN, NLM_F_ACK, NLM_F_CREATE, }; -use crate::{create_wrapper_type, MsgType, ProtocolFamily, Table}; +use crate::{MsgType, ProtocolFamily, Table}; use std::convert::TryFrom; use std::fmt::Debug; diff --git a/src/expr/bitwise.rs b/src/expr/bitwise.rs index 73c2467..29d2d63 100644 --- a/src/expr/bitwise.rs +++ b/src/expr/bitwise.rs @@ -1,46 +1,25 @@ +use rustables_macros::nfnetlink_struct; + use super::{Expression, ExpressionData, Register}; -use crate::create_wrapper_type; use crate::parser::DecodeError; -use crate::sys; +use crate::sys::{ + NFTA_BITWISE_DREG, NFTA_BITWISE_LEN, NFTA_BITWISE_MASK, NFTA_BITWISE_SREG, NFTA_BITWISE_XOR, +}; -create_wrapper_type!( - inline: Bitwise, - [ - ( - get_sreg, - set_sreg, - with_sreg, - sys::NFTA_BITWISE_SREG, - sreg, - Register - ), - ( - get_dreg, - set_dreg, - with_dreg, - sys::NFTA_BITWISE_DREG, - dreg, - Register - ), - (get_len, set_len, with_len, sys::NFTA_BITWISE_LEN, len, u32), - ( - get_mask, - set_mask, - with_mask, - sys::NFTA_BITWISE_MASK, - mask, - ExpressionData - ), - ( - get_xor, - set_xor, - with_xor, - sys::NFTA_BITWISE_XOR, - xor, - ExpressionData - ) - ] -); +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] +pub struct Bitwise { + #[field(NFTA_BITWISE_SREG)] + sreg: Register, + #[field(NFTA_BITWISE_DREG)] + dreg: Register, + #[field(NFTA_BITWISE_LEN)] + len: u32, + #[field(NFTA_BITWISE_MASK)] + mask: ExpressionData, + #[field(NFTA_BITWISE_XOR)] + xor: ExpressionData, +} impl Expression for Bitwise { fn get_name() -> &'static str { diff --git a/src/expr/immediate.rs b/src/expr/immediate.rs index 925ca06..134f7e1 100644 --- a/src/expr/immediate.rs +++ b/src/expr/immediate.rs @@ -1,27 +1,16 @@ +use rustables_macros::nfnetlink_struct; + use super::{Expression, ExpressionData, Register}; -use crate::{create_wrapper_type, sys}; +use crate::sys::{NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG}; -create_wrapper_type!( - inline: Immediate, - [ - ( - get_dreg, - set_dreg, - with_dreg, - sys::NFTA_IMMEDIATE_DREG, - dreg, - Register - ), - ( - get_data, - set_data, - with_data, - sys::NFTA_IMMEDIATE_DATA, - data, - ExpressionData - ) - ] -); +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] +pub struct Immediate { + #[field(NFTA_IMMEDIATE_DREG)] + dreg: Register, + #[field(NFTA_IMMEDIATE_DATA)] + data: ExpressionData, +} impl Immediate { pub fn new_data(data: Vec, register: Register) -> Self { diff --git a/src/expr/log.rs b/src/expr/log.rs index 82c201d..3c72257 100644 --- a/src/expr/log.rs +++ b/src/expr/log.rs @@ -1,29 +1,17 @@ +use rustables_macros::nfnetlink_struct; + use super::{Expression, ExpressionError}; -use crate::create_wrapper_type; -use crate::sys; +use crate::sys::{NFTA_LOG_GROUP, NFTA_LOG_PREFIX}; -// A Log expression will log all packets that match the rule. -create_wrapper_type!( - inline: Log, - [ - ( - get_group, - set_group, - with_group, - sys::NFTA_LOG_GROUP, - group, - u32 - ), - ( - get_prefix, - set_prefix, - with_prefix, - sys::NFTA_LOG_PREFIX, - prefix, - String - ) - ] -); +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] +/// A Log expression will log all packets that match the rule. +pub struct Log { + #[field(NFTA_LOG_GROUP)] + group: u32, + #[field(NFTA_LOG_PREFIX)] + prefix: String, +} impl Log { pub fn new( diff --git a/src/expr/meta.rs b/src/expr/meta.rs index bb8023d..c4c1adb 100644 --- a/src/expr/meta.rs +++ b/src/expr/meta.rs @@ -1,11 +1,11 @@ -use super::{Expression, Register, Rule}; +use rustables_macros::nfnetlink_struct; + +use super::{Expression, Register}; use crate::{ - create_wrapper_type, nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable}, parser::DecodeError, sys, }; -use std::convert::TryFrom; /// A meta expression refers to meta data associated with a packet. #[derive(Debug, Copy, Clone, Eq, PartialEq)] @@ -78,35 +78,16 @@ impl NfNetlinkDeserializable for MetaType { } } -create_wrapper_type!( - inline: Meta, - [ - ( - get_dreg, - set_dreg, - with_dreg, - sys::NFTA_META_DREG, - dreg, - Register - ), - ( - get_key, - set_key, - with_key, - sys::NFTA_META_KEY, - key, - MetaType - ), - ( - get_sreg, - set_sreg, - with_sreg, - sys::NFTA_META_SREG, - sreg, - Register - ) - ] -); +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] +pub struct Meta { + #[field(sys::NFTA_META_DREG)] + dreg: Register, + #[field(sys::NFTA_META_KEY)] + key: MetaType, + #[field(sys::NFTA_META_SREG)] + sreg: Register, +} impl Expression for Meta { fn get_name() -> &'static str { diff --git a/src/expr/mod.rs b/src/expr/mod.rs index e5c2729..63385e0 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -6,8 +6,6 @@ use std::fmt::Debug; use std::mem::transmute; -use super::rule::Rule; -use crate::create_wrapper_type; use crate::nlmsg::NfNetlinkAttribute; use crate::nlmsg::NfNetlinkDeserializable; use crate::parser::pad_netlink_object; @@ -15,7 +13,10 @@ use crate::parser::pad_netlink_object_with_variable_size; use crate::parser::write_attribute; use crate::parser::DecodeError; use crate::sys::{self, nlattr}; -use libc::NLA_TYPE_MASK; +use crate::sys::{ + NFTA_DATA_VALUE, NFTA_DATA_VERDICT, NFTA_EXPR_DATA, NFTA_EXPR_NAME, NLA_TYPE_MASK, +}; +use rustables_macros::nfnetlink_struct; use thiserror::Error; mod bitwise; @@ -105,26 +106,14 @@ pub trait Expression { fn get_name() -> &'static str; } -create_wrapper_type!( - nested without_deser : RawExpression, [ - // Define the action netfilter will apply to packets processed by this chain, but that did not match any rules in it. - ( - get_name, - set_name, - with_name, - sys::NFTA_EXPR_NAME, - name, - String - ), - ( - get_data, - set_data, - with_data, - sys::NFTA_EXPR_DATA, - data, - ExpressionVariant - ) -]); +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(nested = true, derive_decoder = false)] +pub struct RawExpression { + #[field(NFTA_EXPR_NAME)] + name: String, + #[field(NFTA_EXPR_DATA)] + data: ExpressionVariant, +} impl RawExpression { pub fn new(expr: T) -> Self @@ -338,27 +327,14 @@ where } } -create_wrapper_type!( - nested : ExpressionData, - [ - ( - get_value, - set_value, - with_value, - sys::NFTA_DATA_VALUE, - value, - Vec - ), - ( - get_verdict, - set_verdict, - with_verdict, - sys::NFTA_DATA_VERDICT, - verdict, - VerdictAttribute - ) - ] -); +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(nested = true)] +pub struct ExpressionData { + #[field(NFTA_DATA_VALUE)] + value: Vec, + #[field(NFTA_DATA_VERDICT)] + verdict: VerdictAttribute, +} // default type for expressions that we do not handle yet #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/src/expr/reject.rs b/src/expr/reject.rs index e15f905..10b95ea 100644 --- a/src/expr/reject.rs +++ b/src/expr/reject.rs @@ -1,5 +1,6 @@ +use rustables_macros::nfnetlink_struct; + use crate::{ - create_wrapper_type, nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable}, parser::DecodeError, sys, @@ -13,28 +14,15 @@ impl Expression for Reject { } } -// A reject expression that defines the type of rejection message sent when discarding a packet. -create_wrapper_type!( - inline: Reject, - [ - ( - get_type, - set_type, - with_type, - sys::NFTA_REJECT_TYPE, - reject_type, - RejectType - ), - ( - get_icmp_code, - set_icmp_code, - with_icmp_code, - sys::NFTA_REJECT_ICMP_CODE, - icmp_code, - IcmpCode - ) - ] -); +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] +/// A reject expression that defines the type of rejection message sent when discarding a packet. +pub struct Reject { + #[field(sys::NFTA_REJECT_TYPE, name_in_functions = "type")] + reject_type: RejectType, + #[field(sys::NFTA_REJECT_ICMP_CODE)] + icmp_code: IcmpCode, +} /// An ICMP reject code. #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] diff --git a/src/expr/verdict.rs b/src/expr/verdict.rs index 547ba91..fc13f8a 100644 --- a/src/expr/verdict.rs +++ b/src/expr/verdict.rs @@ -1,13 +1,16 @@ use std::fmt::Debug; use libc::{NF_ACCEPT, NF_DROP, NF_QUEUE}; +use rustables_macros::nfnetlink_struct; use super::{ExpressionData, Immediate, Register}; use crate::{ - create_wrapper_type, nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable}, parser::DecodeError, - sys::{self, NFT_BREAK, NFT_CONTINUE, NFT_GOTO, NFT_JUMP, NFT_RETURN}, + sys::{ + NFTA_VERDICT_CHAIN, NFTA_VERDICT_CHAIN_ID, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE, + NFT_GOTO, NFT_JUMP, NFT_RETURN, + }, }; #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] @@ -53,35 +56,16 @@ impl NfNetlinkDeserializable for VerdictType { } } -create_wrapper_type!( - nested: VerdictAttribute, - [ - ( - get_code, - set_code, - with_code, - sys::NFTA_VERDICT_CODE, - code, - VerdictType - ), - ( - get_chain, - set_chain, - with_chain, - sys::NFTA_VERDICT_CHAIN, - chain, - String - ), - ( - get_chain_id, - set_chain_id, - with_chain_id, - sys::NFTA_VERDICT_CHAIN_ID, - chain_id, - u32 - ) - ] -); +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(nested = true)] +pub struct VerdictAttribute { + #[field(NFTA_VERDICT_CODE)] + code: VerdictType, + #[field(NFTA_VERDICT_CHAIN)] + chain: String, + #[field(NFTA_VERDICT_CHAIN_ID)] + chain_id: u32, +} #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub enum VerdictKind { diff --git a/src/parser.rs b/src/parser.rs index 834874c..7d89a1e 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -400,81 +400,6 @@ pub trait InnerFormat { ) -> Result, std::fmt::Error>; } -#[macro_export] -macro_rules! impl_attr_getters_and_setters { - (without_deser $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { - impl $struct { - $( - #[allow(dead_code)] - pub fn $getter_name(&self) -> Option<&$type> { - self.$internal_name.as_ref() - } - - #[allow(dead_code)] - pub fn $setter_name(&mut self, val: impl Into<$type>) { - self.$internal_name = Some(val.into()); - } - - #[allow(dead_code)] - pub fn $in_place_edit_name(mut self, val: impl Into<$type>) -> Self { - self.$internal_name = Some(val.into()); - self - } - - )+ - } - - impl $crate::parser::InnerFormat for $struct { - fn inner_format_struct<'a, 'b: 'a>(&'a self, mut s: std::fmt::DebugStruct<'a, 'b>) -> Result, std::fmt::Error> { - $( - // Rewrite attributes names to be readable: 'sys::NFTA_CHAIN_NAME' -> 'name' - // Performance must be terrible, but this is the Debug impl anyway, so that - // must mean we can afford to be slow, right? ;) - if let Some(val) = self.$getter_name() { - let mut attr = stringify!($attr_name); - if let Some((nfta_idx, _match )) = attr.rmatch_indices("NFTA_").next() { - if let Some(underscore_idx) = &attr[nfta_idx+5..].find('_') { - attr = &attr[nfta_idx+underscore_idx+6..]; - } - } - let attr = attr.to_lowercase(); - s.field(&attr, val); - } - )+ - Ok(s) - } - } - - }; - (deser $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { - impl $crate::nlmsg::AttributeDecoder for $struct { - #[allow(dead_code)] - fn decode_attribute(&mut self, attr_type: u16, buf: &[u8]) -> Result<(), $crate::parser::DecodeError> { - use $crate::nlmsg::NfNetlinkDeserializable; - debug!("Decoding attribute {} in type {}", attr_type, std::any::type_name::<$struct>()); - match attr_type { - $( - x if x == $attr_name => { - debug!("Calling {}::deserialize()", std::any::type_name::<$type>()); - let (val, remaining) = <$type>::deserialize(buf)?; - if remaining.len() != 0 { - return Err($crate::parser::DecodeError::InvalidDataSize); - } - self.$setter_name(val); - Ok(()) - }, - )+ - _ => Err($crate::parser::DecodeError::UnsupportedAttributeType(attr_type)), - } - } - } - }; - ($struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { - $crate::impl_attr_getters_and_setters!(without_deser $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); - $crate::impl_attr_getters_and_setters!(deser $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); - }; -} - pub trait Parsable where Self: Sized, @@ -520,118 +445,3 @@ where Ok((res, nfgenmsg, remaining_data)) } } - -#[macro_export] -macro_rules! impl_nfnetlinkattribute { - (__inner : $struct:ident, [$(($attr_name:expr, $internal_name:ident)),+]) => { - impl $struct { - fn inner_get_size(&self) -> usize { - use $crate::nlmsg::NfNetlinkAttribute; - use $crate::parser::{pad_netlink_object, pad_netlink_object_with_variable_size}; - use $crate::sys::nlattr; - let mut size = 0; - - $( - if let Some(val) = &self.$internal_name { - // Attribute header + attribute value - size += pad_netlink_object::() - + pad_netlink_object_with_variable_size(val.get_size()); - } - )+ - - size - } - - unsafe fn inner_write_payload(&self, mut addr: *mut u8) { - use $crate::nlmsg::NfNetlinkAttribute; - use $crate::parser::{pad_netlink_object, pad_netlink_object_with_variable_size}; - use $crate::sys::nlattr; - $( - if let Some(val) = &self.$internal_name { - debug!("writing attribute {} - {:?}", $attr_name, val); - - unsafe { - $crate::parser::write_attribute($attr_name, val, addr); - } - let size = pad_netlink_object::() - + pad_netlink_object_with_variable_size(val.get_size()); - #[allow(unused)] - { - addr = addr.offset(size as isize); - } - } - )+ - } - } - }; - (inline : $struct:ident, [$(($attr_name:expr, $internal_name:ident)),+]) => { - $crate::impl_nfnetlinkattribute!(__inner : $struct, [$(($attr_name, $internal_name)),+]); - - impl $crate::nlmsg::NfNetlinkAttribute for $struct { - fn get_size(&self) -> usize { - self.inner_get_size() - } - - unsafe fn write_payload(&self, addr: *mut u8) { - self.inner_write_payload(addr); - } - } - }; - (nested : $struct:ident, [$(($attr_name:expr, $internal_name:ident)),+]) => { - $crate::impl_nfnetlinkattribute!(__inner : $struct, [$(($attr_name, $internal_name)),+]); - - impl $crate::nlmsg::NfNetlinkAttribute for $struct { - fn is_nested(&self) -> bool { - true - } - - fn get_size(&self) -> usize { - self.inner_get_size() - } - - unsafe fn write_payload(&self, addr: *mut u8) { - self.inner_write_payload(addr); - } - } - }; -} - -#[macro_export] -macro_rules! create_wrapper_type { - (without_deser : $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { - #[derive(Clone, PartialEq, Eq, Default)] - pub struct $struct { - $( - $internal_name: Option<$type> - ),+ - } - - $crate::impl_attr_getters_and_setters!(without_deser $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); - - impl std::fmt::Debug for $struct { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - use $crate::parser::InnerFormat; - self.inner_format_struct(f.debug_struct(stringify!($struct)))? - .finish() - } - } - - impl $crate::nlmsg::NfNetlinkDeserializable for $struct { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), $crate::parser::DecodeError> { - Ok(($crate::parser::read_attributes(buf)?, &[])) - } - } - }; - ($struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { - create_wrapper_type!(without_deser : $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); - $crate::impl_attr_getters_and_setters!(deser $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); - }; - (inline $($($attrs:ident) +)? : $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { - create_wrapper_type!($($($attrs) + :)? $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); - $crate::impl_nfnetlinkattribute!(inline : $struct, [$(($attr_name, $internal_name)),+]); - }; - (nested $($($attrs:ident) +)? : $struct:ident, [$(($getter_name:ident, $setter_name:ident, $in_place_edit_name:ident, $attr_name:expr, $internal_name:ident, $type:ty)),+]) => { - create_wrapper_type!($($($attrs) + :)? $struct, [$(($getter_name, $setter_name, $in_place_edit_name, $attr_name, $internal_name, $type)),+]); - $crate::impl_nfnetlinkattribute!(nested : $struct, [$(($attr_name, $internal_name)),+]); - }; -} -- cgit v1.2.3 From edb440a952320ea4f021c1d7063ff6d5f2f13818 Mon Sep 17 00:00:00 2001 From: Simon THOBY Date: Sat, 3 Dec 2022 22:53:23 +0100 Subject: Macros: introduce a macro to simplify enums --- macros/Cargo.toml | 2 +- macros/src/lib.rs | 194 +++++++++++++++++++++++++++++++++++++++++++++------ src/chain_methods.rs | 40 +++++++++++ src/expr/counter.rs | 43 +++--------- src/expr/log.rs | 2 +- src/expr/meta.rs | 46 +----------- src/expr/mod.rs | 5 +- src/expr/register.rs | 33 ++------- src/expr/reject.rs | 71 +++---------------- src/expr/verdict.rs | 44 ++---------- src/lib.rs | 1 - src/parser.rs | 10 +-- tests/expr.rs | 83 +++++++++++----------- 13 files changed, 293 insertions(+), 281 deletions(-) create mode 100644 src/chain_methods.rs (limited to 'src/expr/log.rs') diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 7d9167f..82c8ad6 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" proc-macro = true [dependencies] -syn = { version = "1.0", features = ["full", "extra-traits"] } +syn = { version = "1.0", features = ["full"] } quote = "1.0" proc-macro2 = "1.0" proc-macro-error = "1" diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 38cde50..11aedaf 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,28 +1,26 @@ use proc_macro::TokenStream; -use proc_macro2::Group; +use proc_macro2::{Group, Span}; use quote::quote; use proc_macro_error::{abort, proc_macro_error}; use syn::parse::Parser; use syn::punctuated::Punctuated; use syn::spanned::Spanned; -use syn::token::Struct; use syn::{ - parse, parse2, parse_macro_input, Attribute, Expr, ExprLit, FnArg, Ident, ItemFn, ItemStruct, - Lit, Meta, NestedMeta, Pat, PatIdent, Path, Result, ReturnType, Token, Type, TypePath, + parse, parse2, Attribute, Expr, ExprCast, Ident, ItemEnum, ItemStruct, Lit, Meta, Path, Result, + Token, Type, TypePath, Visibility, }; -use syn::{parse::Parse, PatReference}; -use syn::{parse::ParseStream, TypeReference}; struct Field<'a> { name: &'a Ident, ty: &'a Type, args: FieldArgs, netlink_type: Path, + vis: &'a Visibility, attrs: Vec<&'a Attribute>, } -#[derive(Debug, Default)] +#[derive(Default)] struct FieldArgs { netlink_type: Option, override_function_name: Option, @@ -68,7 +66,6 @@ fn parse_field_args(input: proc_macro2::TokenStream) -> Result { Ok(args) } -#[derive(Debug)] struct StructArgs { nested: bool, derive_decoder: bool, @@ -85,12 +82,10 @@ impl Default for StructArgs { } } -fn parse_struct_args(args: &mut StructArgs, input: TokenStream) -> Result<()> { - if input.is_empty() { - return Ok(()); - } +fn parse_struct_args(input: TokenStream) -> Result { + let mut args = StructArgs::default(); let parser = Punctuated::::parse_terminated; - let attribute_args = parser.parse(input)?; + let attribute_args = parser.parse(input.clone())?; for arg in attribute_args.iter() { if let Meta::NameValue(namevalue) = arg { let key = namevalue @@ -126,7 +121,7 @@ fn parse_struct_args(args: &mut StructArgs, input: TokenStream) -> Result<()> { abort!(arg.span(), "Unrecognized argument"); } } - Ok(()) + Ok(args) } #[proc_macro_error] @@ -135,8 +130,10 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { let ast: ItemStruct = parse(item).unwrap(); let name = ast.ident; - let mut args = StructArgs::default(); - parse_struct_args(&mut args, attrs).expect("Could not parse the macro arguments"); + let args = match parse_struct_args(attrs) { + Ok(x) => x, + Err(_) => abort!(Span::call_site(), "Could not parse the macro arguments"), + }; let mut fields = Vec::with_capacity(ast.fields.len()); let mut identical_fields = Vec::new(); @@ -145,15 +142,25 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { for attr in field.attrs.iter() { if let Some(id) = attr.path.get_ident() { if id == "field" { - let field_args = parse_field_args(attr.tokens.clone()) - .expect("Could not parse the field attributes"); + let field_args = match parse_field_args(attr.tokens.clone()) { + Ok(x) => x, + Err(_) => { + abort!(attr.tokens.span(), "Could not parse the field attributes") + } + }; if let Some(netlink_type) = field_args.netlink_type.clone() { fields.push(Field { name: field.ident.as_ref().expect("Should be a names struct"), ty: &field.ty, args: field_args, netlink_type, - attrs: field.attrs.iter().filter(|x| *x != attr).collect(), + vis: &field.vis, + // drop the "field" attribute + attrs: field + .attrs + .iter() + .filter(|x| x.path.get_ident() != attr.path.get_ident()) + .collect(), }); } else { abort!(attr.tokens.span(), "Missing Netlink Type in field"); @@ -297,7 +304,8 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { let name = field.name; let ty = field.ty; let attrs = &field.attrs; - quote!( #(#attrs) * #name: Option<#ty>, ) + let vis = &field.vis; + quote!( #(#attrs) * #vis #name: Option<#ty>, ) }); let nfnetlinkdeserialize_impl = if args.derive_deserialize { quote!( @@ -327,3 +335,149 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { res.into() } + +struct Variant<'a> { + inner: &'a syn::Variant, + name: &'a Ident, + value: &'a Path, +} + +#[derive(Default)] +struct EnumArgs { + nested: bool, + ty: Option, +} + +fn parse_enum_args(input: TokenStream) -> Result { + let mut args = EnumArgs::default(); + let parser = Punctuated::::parse_terminated; + let attribute_args = parser.parse(input)?; + for arg in attribute_args.iter() { + match arg { + Meta::Path(path) => { + if args.ty.is_none() { + args.ty = Some(path.clone()); + } else { + abort!(arg.span(), "A value can only have a single representation"); + } + } + Meta::NameValue(namevalue) => { + let key = namevalue + .path + .get_ident() + .expect("the macro parameter is not an ident?") + .to_string(); + match key.as_str() { + "nested" => { + if let Lit::Bool(boolean) = &namevalue.lit { + args.nested = boolean.value; + } else { + abort!(&namevalue.lit.span(), "Expected a boolean"); + } + } + _ => abort!(key.span(), "Unsupported macro parameter"), + } + } + _ => abort!(arg.span(), "Unrecognized argument"), + } + } + Ok(args) +} + +#[proc_macro_error] +#[proc_macro_attribute] +pub fn nfnetlink_enum(attrs: TokenStream, item: TokenStream) -> TokenStream { + let ast: ItemEnum = parse(item).unwrap(); + let name = ast.ident; + + let args = match parse_enum_args(attrs) { + Ok(x) => x, + Err(_) => abort!(Span::call_site(), "Could not parse the macro arguments"), + }; + + if args.ty.is_none() { + abort!( + Span::call_site(), + "The target type representation is unspecified" + ); + } + + let mut variants = Vec::with_capacity(ast.variants.len()); + + for variant in ast.variants.iter() { + if variant.discriminant.is_none() { + abort!(variant.ident.span(), "Missing value"); + } + let discriminant = variant.discriminant.as_ref().unwrap(); + if let syn::Expr::Path(path) = &discriminant.1 { + variants.push(Variant { + inner: variant, + name: &variant.ident, + value: &path.path, + }); + } else { + abort!(discriminant.1.span(), "Expected a path"); + } + } + + let repr_type = args.ty.unwrap(); + let match_entries = variants.iter().map(|variant| { + let variant_name = variant.name; + let variant_value = &variant.value; + quote!( x if x == (#variant_value as #repr_type) => Self::#variant_name, ) + }); + let unknown_type_ident = Ident::new(&format!("Unknown{}", name.to_string()), name.span()); + let nfnetlinkdeserialize_impl = quote!( + impl crate::nlmsg::NfNetlinkDeserializable for #name { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), crate::parser::DecodeError> { + let (v, remaining_data) = #repr_type::deserialize(buf)?; + Ok(( + match v { + #(#match_entries) * + value => return Err(crate::parser::DecodeError::#unknown_type_ident(value)) + }, + remaining_data, + )) + } + } + ); + let vis = &ast.vis; + let attrs = ast.attrs; + let original_variants = variants.into_iter().map(|x| { + let mut inner = x.inner.clone(); + let mut discriminant = inner.discriminant.as_mut().unwrap(); + let cur_value = discriminant.1.clone(); + let cast_value = Expr::Cast(ExprCast { + attrs: vec![], + expr: Box::new(cur_value), + as_token: Token![as](name.span()), + ty: Box::new(Type::Path(TypePath { + qself: None, + path: repr_type.clone(), + })), + }); + discriminant.1 = cast_value; + inner + }); + let res = quote! { + #[repr(#repr_type)] + #(#attrs) * #vis enum #name { + #(#original_variants),* + } + + impl crate::nlmsg::NfNetlinkAttribute for #name { + fn get_size(&self) -> usize { + (*self as #repr_type).get_size() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + (*self as #repr_type).write_payload(addr); + } + } + + #nfnetlinkdeserialize_impl + + }; + + res.into() +} diff --git a/src/chain_methods.rs b/src/chain_methods.rs new file mode 100644 index 0000000..d384c35 --- /dev/null +++ b/src/chain_methods.rs @@ -0,0 +1,40 @@ +use crate::{Batch, Chain, Hook, MsgType, Policy, Table}; +use std::ffi::CString; +use std::rc::Rc; + + +/// A helper trait over [`crate::Chain`]. +pub trait ChainMethods { + /// Creates a new Chain instance from a [`crate::Hook`] over a [`crate::Table`]. + fn from_hook(hook: Hook, table: Rc
) -> Self + where Self: std::marker::Sized; + /// Adds a [`crate::Policy`] to the current Chain. + fn verdict(self, policy: Policy) -> Self; + fn add_to_batch(self, batch: &mut Batch) -> Self; +} + + +impl ChainMethods for Chain { + fn from_hook(hook: Hook, table: Rc
) -> Self { + let chain_name = match hook { + Hook::PreRouting => "prerouting", + Hook::Out => "out", + Hook::PostRouting => "postrouting", + Hook::Forward => "forward", + Hook::In => "in", + }; + let chain_name = CString::new(chain_name).unwrap(); + let mut chain = Chain::new(&chain_name, table); + chain.set_hook(hook, 0); + chain + } + fn verdict(mut self, policy: Policy) -> Self { + self.set_policy(policy); + self + } + fn add_to_batch(self, batch: &mut Batch) -> Self { + batch.add(&self, MsgType::Add); + self + } +} + diff --git a/src/expr/counter.rs b/src/expr/counter.rs index 4732e85..d22fb8a 100644 --- a/src/expr/counter.rs +++ b/src/expr/counter.rs @@ -1,46 +1,21 @@ -use super::{DeserializationError, Expression, Rule}; +use rustables_macros::nfnetlink_struct; + +use super::Expression; use crate::sys; -use std::os::raw::c_char; /// A counter expression adds a counter to the rule that is incremented to count number of packets /// and number of bytes for all packets that have matched the rule. -#[derive(Debug, PartialEq)] +#[derive(Default, Clone, Debug, PartialEq, Eq)] +#[nfnetlink_struct] pub struct Counter { + #[field(sys::NFTA_COUNTER_BYTES)] pub nb_bytes: u64, + #[field(sys::NFTA_COUNTER_PACKETS)] pub nb_packets: u64, } -impl Counter { - pub fn new() -> Self { - Self { - nb_bytes: 0, - nb_packets: 0, - } - } -} - impl Expression for Counter { - fn get_raw_name() -> *const c_char { - b"counter\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result { - unsafe { - let nb_bytes = sys::nftnl_expr_get_u64(expr, sys::NFTNL_EXPR_CTR_BYTES as u16); - let nb_packets = sys::nftnl_expr_get_u64(expr, sys::NFTNL_EXPR_CTR_PACKETS as u16); - Ok(Counter { - nb_bytes, - nb_packets, - }) - } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - sys::nftnl_expr_set_u64(expr, sys::NFTNL_EXPR_CTR_BYTES as u16, self.nb_bytes); - sys::nftnl_expr_set_u64(expr, sys::NFTNL_EXPR_CTR_PACKETS as u16, self.nb_packets); - expr - } + fn get_name() -> &'static str { + "counter" } } diff --git a/src/expr/log.rs b/src/expr/log.rs index 3c72257..80bb7a9 100644 --- a/src/expr/log.rs +++ b/src/expr/log.rs @@ -8,7 +8,7 @@ use crate::sys::{NFTA_LOG_GROUP, NFTA_LOG_PREFIX}; /// A Log expression will log all packets that match the rule. pub struct Log { #[field(NFTA_LOG_GROUP)] - group: u32, + group: u16, #[field(NFTA_LOG_PREFIX)] prefix: String, } diff --git a/src/expr/meta.rs b/src/expr/meta.rs index c4c1adb..79016bd 100644 --- a/src/expr/meta.rs +++ b/src/expr/meta.rs @@ -1,15 +1,11 @@ -use rustables_macros::nfnetlink_struct; +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; use super::{Expression, Register}; -use crate::{ - nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable}, - parser::DecodeError, - sys, -}; +use crate::sys; /// A meta expression refers to meta data associated with a packet. #[derive(Debug, Copy, Clone, Eq, PartialEq)] -#[repr(u32)] +#[nfnetlink_enum(u32)] #[non_exhaustive] pub enum MetaType { /// Packet ethertype protocol (skb->protocol), invalid in OUTPUT. @@ -42,42 +38,6 @@ pub enum MetaType { PRandom = sys::NFT_META_PRANDOM, } -impl NfNetlinkAttribute for MetaType { - fn get_size(&self) -> usize { - (*self as u32).get_size() - } - - unsafe fn write_payload(&self, addr: *mut u8) { - (*self as u32).write_payload(addr); - } -} - -impl NfNetlinkDeserializable for MetaType { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - let (v, remaining_data) = u32::deserialize(buf)?; - Ok(( - match v { - sys::NFT_META_PROTOCOL => Self::Protocol, - sys::NFT_META_MARK => Self::Mark, - sys::NFT_META_IIF => Self::Iif, - sys::NFT_META_OIF => Self::Oif, - sys::NFT_META_IIFNAME => Self::IifName, - sys::NFT_META_OIFNAME => Self::OifName, - sys::NFT_META_IFTYPE => Self::IifType, - sys::NFT_META_OIFTYPE => Self::OifType, - sys::NFT_META_SKUID => Self::SkUid, - sys::NFT_META_SKGID => Self::SkGid, - sys::NFT_META_NFPROTO => Self::NfProto, - sys::NFT_META_L4PROTO => Self::L4Proto, - sys::NFT_META_CGROUP => Self::Cgroup, - sys::NFT_META_PRANDOM => Self::PRandom, - value => return Err(DecodeError::UnknownMetaType(value)), - }, - remaining_data, - )) - } -} - #[derive(Clone, PartialEq, Eq, Default, Debug)] #[nfnetlink_struct] pub struct Meta { diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 63385e0..d2cd917 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -25,9 +25,11 @@ pub use self::bitwise::*; /* mod cmp; pub use self::cmp::*; +*/ mod counter; pub use self::counter::*; +/* pub mod ct; pub use self::ct::*; @@ -222,7 +224,8 @@ create_expr_variant!( [Bitwise, Bitwise], [ExpressionRaw, ExpressionRaw], [Meta, Meta], - [Reject, Reject] + [Reject, Reject], + [Counter, Counter] ); #[derive(Debug, Clone, PartialEq, Eq, Default)] diff --git a/src/expr/register.rs b/src/expr/register.rs index def58a5..9cc1bee 100644 --- a/src/expr/register.rs +++ b/src/expr/register.rs @@ -1,15 +1,13 @@ use std::fmt::Debug; -use crate::{ - nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable}, - parser::DecodeError, - sys::{NFT_REG_1, NFT_REG_2, NFT_REG_3, NFT_REG_4, NFT_REG_VERDICT}, -}; +use rustables_macros::nfnetlink_enum; + +use crate::sys::{NFT_REG_1, NFT_REG_2, NFT_REG_3, NFT_REG_4, NFT_REG_VERDICT}; /// A netfilter data register. The expressions store and read data to and from these when /// evaluating rule statements. #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[repr(u32)] +#[nfnetlink_enum(u32)] pub enum Register { Verdict = NFT_REG_VERDICT, Reg1 = NFT_REG_1, @@ -17,26 +15,3 @@ pub enum Register { Reg3 = NFT_REG_3, Reg4 = NFT_REG_4, } - -impl NfNetlinkAttribute for Register { - unsafe fn write_payload(&self, addr: *mut u8) { - (*self as u32).write_payload(addr); - } -} - -impl NfNetlinkDeserializable for Register { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), crate::parser::DecodeError> { - let (val, remaining) = u32::deserialize(buf)?; - Ok(( - match val { - NFT_REG_VERDICT => Self::Verdict, - NFT_REG_1 => Self::Reg1, - NFT_REG_2 => Self::Reg2, - NFT_REG_3 => Self::Reg3, - NFT_REG_4 => Self::Reg4, - _ => return Err(DecodeError::UnknownRegisterValue), - }, - remaining, - )) - } -} diff --git a/src/expr/reject.rs b/src/expr/reject.rs index 10b95ea..83fd843 100644 --- a/src/expr/reject.rs +++ b/src/expr/reject.rs @@ -1,10 +1,6 @@ -use rustables_macros::nfnetlink_struct; +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; -use crate::{ - nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable}, - parser::DecodeError, - sys, -}; +use crate::sys; use super::Expression; @@ -26,70 +22,19 @@ pub struct Reject { /// An ICMP reject code. #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] -#[repr(u32)] +#[nfnetlink_enum(u32)] pub enum RejectType { IcmpUnreach = sys::NFT_REJECT_ICMP_UNREACH, TcpRst = sys::NFT_REJECT_TCP_RST, IcmpxUnreach = sys::NFT_REJECT_ICMPX_UNREACH, } -impl NfNetlinkDeserializable for RejectType { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - let (v, remaining_code) = u32::deserialize(buf)?; - Ok(( - match v { - sys::NFT_REJECT_ICMP_UNREACH => Self::IcmpUnreach, - sys::NFT_REJECT_TCP_RST => Self::TcpRst, - sys::NFT_REJECT_ICMPX_UNREACH => Self::IcmpxUnreach, - _ => return Err(DecodeError::UnknownRejectType(v)), - }, - remaining_code, - )) - } -} - -impl NfNetlinkAttribute for RejectType { - fn get_size(&self) -> usize { - (*self as u32).get_size() - } - - unsafe fn write_payload(&self, addr: *mut u8) { - (*self as u32).write_payload(addr); - } -} - /// An ICMP reject code. #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] -#[repr(u8)] +#[nfnetlink_enum(u8)] pub enum IcmpCode { - NoRoute = sys::NFT_REJECT_ICMPX_NO_ROUTE as u8, - PortUnreach = sys::NFT_REJECT_ICMPX_PORT_UNREACH as u8, - HostUnreach = sys::NFT_REJECT_ICMPX_HOST_UNREACH as u8, - AdminProhibited = sys::NFT_REJECT_ICMPX_ADMIN_PROHIBITED as u8, -} - -impl NfNetlinkDeserializable for IcmpCode { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - let (value, remaining_code) = u8::deserialize(buf)?; - Ok(( - match value as u32 { - sys::NFT_REJECT_ICMPX_NO_ROUTE => Self::NoRoute, - sys::NFT_REJECT_ICMPX_PORT_UNREACH => Self::PortUnreach, - sys::NFT_REJECT_ICMPX_HOST_UNREACH => Self::HostUnreach, - sys::NFT_REJECT_ICMPX_ADMIN_PROHIBITED => Self::AdminProhibited, - _ => return Err(DecodeError::UnknownIcmpCode(value)), - }, - remaining_code, - )) - } -} - -impl NfNetlinkAttribute for IcmpCode { - fn get_size(&self) -> usize { - (*self as u8).get_size() - } - - unsafe fn write_payload(&self, addr: *mut u8) { - (*self as u8).write_payload(addr); - } + NoRoute = sys::NFT_REJECT_ICMPX_NO_ROUTE, + PortUnreach = sys::NFT_REJECT_ICMPX_PORT_UNREACH, + HostUnreach = sys::NFT_REJECT_ICMPX_HOST_UNREACH, + AdminProhibited = sys::NFT_REJECT_ICMPX_ADMIN_PROHIBITED, } diff --git a/src/expr/verdict.rs b/src/expr/verdict.rs index fc13f8a..c4facfb 100644 --- a/src/expr/verdict.rs +++ b/src/expr/verdict.rs @@ -1,20 +1,16 @@ use std::fmt::Debug; use libc::{NF_ACCEPT, NF_DROP, NF_QUEUE}; -use rustables_macros::nfnetlink_struct; +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; use super::{ExpressionData, Immediate, Register}; -use crate::{ - nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable}, - parser::DecodeError, - sys::{ - NFTA_VERDICT_CHAIN, NFTA_VERDICT_CHAIN_ID, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE, - NFT_GOTO, NFT_JUMP, NFT_RETURN, - }, +use crate::sys::{ + NFTA_VERDICT_CHAIN, NFTA_VERDICT_CHAIN_ID, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE, + NFT_GOTO, NFT_JUMP, NFT_RETURN, }; #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[repr(i32)] +#[nfnetlink_enum(i32)] pub enum VerdictType { Drop = NF_DROP, Accept = NF_ACCEPT, @@ -26,36 +22,6 @@ pub enum VerdictType { Return = NFT_RETURN, } -impl NfNetlinkAttribute for VerdictType { - fn get_size(&self) -> usize { - (*self as i32).get_size() - } - - unsafe fn write_payload(&self, addr: *mut u8) { - (*self as i32).write_payload(addr); - } -} - -impl NfNetlinkDeserializable for VerdictType { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - let (v, remaining_data) = i32::deserialize(buf)?; - Ok(( - match v { - NF_DROP => VerdictType::Drop, - NF_ACCEPT => VerdictType::Accept, - NF_QUEUE => VerdictType::Queue, - NFT_CONTINUE => VerdictType::Continue, - NFT_BREAK => VerdictType::Break, - NFT_JUMP => VerdictType::Jump, - NFT_GOTO => VerdictType::Goto, - NFT_RETURN => VerdictType::Goto, - _ => return Err(DecodeError::UnknownExpressionVerdictType), - }, - remaining_data, - )) - } -} - #[derive(Clone, PartialEq, Eq, Default, Debug)] #[nfnetlink_struct(nested = true)] pub struct VerdictAttribute { diff --git a/src/lib.rs b/src/lib.rs index 044030f..fecbc83 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -71,7 +71,6 @@ //! [`build.rs`]: https://gitlab.com/rustwall/rustables/-/blob/master/build.rs use parser::DecodeError; -use thiserror::Error; #[macro_use] extern crate log; diff --git a/src/parser.rs b/src/parser.rs index 7d89a1e..55f1e1c 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1,5 +1,4 @@ use std::{ - any::TypeId, convert::TryFrom, fmt::{Debug, DebugStruct}, mem::{size_of, transmute}, @@ -9,10 +8,7 @@ use std::{ use thiserror::Error; use crate::{ - //expr::ExpressionHolder, - nlmsg::{ - AttributeDecoder, NetlinkType, NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkWriter, - }, + nlmsg::{AttributeDecoder, NetlinkType, NfNetlinkAttribute, NfNetlinkDeserializable}, sys::{ nfgenmsg, nlattr, nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END, NFNL_SUBSYS_NFTABLES, NLA_F_NESTED, NLA_TYPE_MASK, NLMSG_ALIGNTO, @@ -75,10 +71,10 @@ pub enum DecodeError { UnknownIcmpCode(u8), #[error("Invalid value for a register")] - UnknownRegisterValue, + UnknownRegister(u32), #[error("Invalid type for a verdict expression")] - UnknownExpressionVerdictType, + UnknownVerdictType(i32), #[error("The object does not contain a name for the expression being parsed")] MissingExpressionName, diff --git a/tests/expr.rs b/tests/expr.rs index 4a90309..5baec2a 100644 --- a/tests/expr.rs +++ b/tests/expr.rs @@ -1,15 +1,15 @@ use rustables::{ expr::{ - Bitwise, ExpressionList, IcmpCode, Immediate, Meta, MetaType, Register, Reject, RejectType, - VerdictKind, + Bitwise, ExpressionList, IcmpCode, Immediate, Log, Meta, MetaType, Register, Reject, + RejectType, VerdictKind, }, sys::{ NFTA_BITWISE_DREG, NFTA_BITWISE_LEN, NFTA_BITWISE_MASK, NFTA_BITWISE_SREG, NFTA_BITWISE_XOR, NFTA_DATA_VALUE, NFTA_DATA_VERDICT, NFTA_EXPR_DATA, NFTA_EXPR_NAME, - NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG, NFTA_LIST_ELEM, NFTA_META_DREG, NFTA_META_KEY, - NFTA_REJECT_ICMP_CODE, NFTA_REJECT_TYPE, NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, - NFTA_RULE_TABLE, NFTA_VERDICT_CODE, NFT_META_PROTOCOL, NFT_REG_1, NFT_REG_VERDICT, - NFT_REJECT_ICMPX_UNREACH, + NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG, NFTA_LIST_ELEM, NFTA_LOG_GROUP, NFTA_LOG_PREFIX, + NFTA_META_DREG, NFTA_META_KEY, NFTA_REJECT_ICMP_CODE, NFTA_REJECT_TYPE, NFTA_RULE_CHAIN, + NFTA_RULE_EXPRESSIONS, NFTA_RULE_TABLE, NFTA_VERDICT_CODE, NFT_META_PROTOCOL, NFT_REG_1, + NFT_REG_VERDICT, NFT_REJECT_ICMPX_UNREACH, }, }; //use rustables::expr::{ @@ -246,42 +246,41 @@ fn immediate_expr_is_valid() { ); } -//#[test] -//fn log_expr_is_valid() { -// let log = Log { -// group: Some(LogGroup(1)), -// prefix: Some(LogPrefix::new("mockprefix").unwrap()), -// }; -// let mut rule = get_test_rule(); -// let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &log); -// assert_eq!(nlmsghdr.nlmsg_len, 96); -// -// assert_eq!( -// raw_expr, -// NetlinkExpr::List(vec![ -// NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), -// NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), -// NetlinkExpr::Nested( -// NFTA_RULE_EXPRESSIONS, -// vec![NetlinkExpr::Nested( -// NFTA_LIST_ELEM, -// vec![ -// NetlinkExpr::Final(NFTA_EXPR_NAME, b"log\0".to_vec()), -// NetlinkExpr::Nested( -// NFTA_EXPR_DATA, -// vec![ -// NetlinkExpr::Final(NFTA_LOG_PREFIX, b"mockprefix\0".to_vec()), -// NetlinkExpr::Final(NFTA_LOG_GROUP, 1u16.to_be_bytes().to_vec()) -// ] -// ) -// ] -// )] -// ) -// ]) -// .to_raw() -// ); -//} -// +#[test] +fn log_expr_is_valid() { + let log = Log::new(Some(1337), Some("mockprefix")).expect("Could not build a log expression"); + let mut rule = get_test_rule().with_expressions(ExpressionList::builder().with_expression(log)); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 96); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"log".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final(NFTA_LOG_GROUP, 1337u16.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_LOG_PREFIX, b"mockprefix".to_vec()), + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + //#[test] //fn lookup_expr_is_valid() { // let set_name = &CStr::from_bytes_with_nul(b"mockset\0").unwrap(); -- cgit v1.2.3 From dc3c2ffab697b5d8fce7c69f76528fcfdf2edf38 Mon Sep 17 00:00:00 2001 From: Simon THOBY Date: Sun, 8 Jan 2023 22:24:40 +0100 Subject: rewrite the examples --- examples/add-rules.rs | 168 +++++++++------------ examples/filter-ethernet.rs | 145 +++++++++--------- examples/firewall.rs | 234 ++++++++++++++--------------- macros/src/lib.rs | 7 + src/chain.rs | 8 +- src/chain_methods.rs | 40 ----- src/data_type.rs | 9 +- src/error.rs | 6 + src/expr/cmp.rs | 5 +- src/expr/ct.rs | 4 + src/expr/log.rs | 14 +- src/expr/mod.rs | 30 ---- src/lib.rs | 9 +- src/nlmsg.rs | 4 - src/rule.rs | 26 +++- src/rule_methods.rs | 355 +++++++++++++++++++++----------------------- src/set.rs | 3 +- src/table.rs | 8 +- src/tests/expr.rs | 46 +++--- src/tests/mod.rs | 4 +- src/tests/set.rs | 19 +-- 21 files changed, 535 insertions(+), 609 deletions(-) delete mode 100644 src/chain_methods.rs (limited to 'src/expr/log.rs') diff --git a/examples/add-rules.rs b/examples/add-rules.rs index b145291..a2b9c9c 100644 --- a/examples/add-rules.rs +++ b/examples/add-rules.rs @@ -38,12 +38,14 @@ use ipnetwork::{IpNetwork, Ipv4Network}; use rustables::{ - expr::{Cmp, CmpOp, ExpressionList, Immediate, Meta, MetaType, Verdict, VerdictKind}, - list_chains_for_table, list_rules_for_chain, list_tables, Batch, Chain, ChainPolicy, Hook, - HookClass, MsgType, ProtocolFamily, Rule, Table, + data_type::ip_to_vec, + expr::{ + Bitwise, Cmp, CmpOp, Counter, HighLevelPayload, ICMPv6HeaderField, IPv4HeaderField, + IcmpCode, Immediate, Meta, MetaType, NetworkHeaderField, TransportHeaderField, VerdictKind, + }, + iface_index, Batch, Chain, ChainPolicy, Hook, HookClass, MsgType, ProtocolFamily, Rule, Table, }; -//use rustables::{nft_expr, query::send_batch, sys::libc, Batch, Chain, Rule, Table}; -use std::{ffi::CString, io, net::Ipv4Addr, rc::Rc}; +use std::net::Ipv4Addr; const TABLE_NAME: &str = "example-table"; const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets"; @@ -81,105 +83,94 @@ fn main() -> Result<(), Error> { batch.add(&out_chain, MsgType::Add); batch.add(&in_chain, MsgType::Add); - let rule = Rule::new(&in_chain)?.with_expressions( - ExpressionList::default().with_value(Immediate::new_verdict(VerdictKind::Accept)), - ); - batch.add(&rule, MsgType::Add); - - let rule = Rule::new(&in_chain)?.with_expressions( - ExpressionList::default().with_value(Immediate::new_verdict(VerdictKind::Continue)), - ); - - batch.add(&rule, MsgType::Add); - // === ADD RULE ALLOWING ALL TRAFFIC TO THE LOOPBACK DEVICE === - // Create a new rule object under the input chain. - let mut allow_loopback_in_rule = Rule::new(&in_chain)?; // Lookup the interface index of the loopback interface. let lo_iface_index = iface_index("lo")?; - allow_loopback_in_rule.set_expressions( - ExpressionList::default() + // Create a new rule object under the input chain. + let allow_loopback_in_rule = Rule::new(&in_chain)? // First expression to be evaluated in this rule is load the meta information "iif" // (incoming interface index) into the comparison register of netfilter. // When an incoming network packet is processed by this rule it will first be processed by this // expression, which will load the interface index of the interface the packet came from into // a special "register" in netfilter. - .with_value(Meta::new(MetaType::Iif)) + .with_expr(Meta::new(MetaType::Iif)) + // Next expression in the rule is to compare the value loaded into the register with our desired // interface index, and succeed only if it's equal. For any packet processed where the equality // does not hold the packet is said to not match this rule, and the packet moves on to be // processed by the next rule in the chain instead. - .with_value(Cmp::new(CmpOp::Eq, lo_iface_index.to_le_bytes())) + .with_expr(Cmp::new(CmpOp::Eq, lo_iface_index.to_le_bytes())) // Add a verdict expression to the rule. Any packet getting this far in the expression // processing without failing any expression will be given the verdict added here. - .with_value(Immediate::new_verdict(VerdictKind::Accept)), - ); + .with_expr(Immediate::new_verdict(VerdictKind::Accept)); // Add the rule to the batch. batch.add(&allow_loopback_in_rule, rustables::MsgType::Add); - // // === ADD A RULE ALLOWING (AND COUNTING) ALL PACKETS TO THE 10.1.0.0/24 NETWORK === - // - // let mut block_out_to_private_net_rule = Rule::new(Rc::clone(&out_chain)); - // let private_net_ip = Ipv4Addr::new(10, 1, 0, 0); - // let private_net_prefix = 24; - // let private_net = IpNetwork::V4(Ipv4Network::new(private_net_ip, private_net_prefix)?); - // - // // Load the `nfproto` metadata into the netfilter register. This metadata denotes which layer3 - // // protocol the packet being processed is using. - // block_out_to_private_net_rule.add_expr(&nft_expr!(meta nfproto)); - // // Check if the currently processed packet is an IPv4 packet. This must be done before payload - // // data assuming the packet uses IPv4 can be loaded in the next expression. - // block_out_to_private_net_rule.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV4 as u8)); - // - // // Load the IPv4 destination address into the netfilter register. - // block_out_to_private_net_rule.add_expr(&nft_expr!(payload ipv4 daddr)); - // // Mask out the part of the destination address that is not part of the network bits. The result - // // of this bitwise masking is stored back into the same netfilter register. - // block_out_to_private_net_rule.add_expr(&nft_expr!(bitwise mask private_net.mask(), xor 0)); - // // Compare the result of the masking with the IP of the network we are interested in. - // block_out_to_private_net_rule.add_expr(&nft_expr!(cmp == private_net.ip())); - // - // // Add a packet counter to the rule. Shows how many packets have been evaluated against this - // // expression. Since expressions are evaluated from first to last, putting this counter before - // // the above IP net check would make the counter increment on all packets also *not* matching - // // those expressions. Because the counter would then be evaluated before it fails a check. - // // Similarly, if the counter was added after the verdict it would always remain at zero. Since - // // when the packet hits the verdict expression any further processing of expressions stop. - // block_out_to_private_net_rule.add_expr(&nft_expr!(counter)); - // - // // Accept all the packets matching the rule so far. - // block_out_to_private_net_rule.add_expr(&nft_expr!(verdict accept)); - // - // // Add the rule to the batch. Without this nothing would be sent over netlink and netfilter, - // // and all the work on `block_out_to_private_net_rule` so far would go to waste. - // batch.add(&block_out_to_private_net_rule, rustables::MsgType::Add); - // - // // === ADD A RULE ALLOWING ALL OUTGOING ICMPv6 PACKETS WITH TYPE 133 AND CODE 0 === - // - // let mut allow_router_solicitation = Rule::new(Rc::clone(&out_chain)); - // - // // Check that the packet is IPv6 and ICMPv6 - // allow_router_solicitation.add_expr(&nft_expr!(meta nfproto)); - // allow_router_solicitation.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV6 as u8)); - // allow_router_solicitation.add_expr(&nft_expr!(meta l4proto)); - // allow_router_solicitation.add_expr(&nft_expr!(cmp == libc::IPPROTO_ICMPV6 as u8)); - // - // allow_router_solicitation.add_expr(&rustables::expr::Payload::Transport( - // rustables::expr::TransportHeaderField::Icmpv6(rustables::expr::Icmpv6HeaderField::Type), - // )); - // allow_router_solicitation.add_expr(&nft_expr!(cmp == 133u8)); - // allow_router_solicitation.add_expr(&rustables::expr::Payload::Transport( - // rustables::expr::TransportHeaderField::Icmpv6(rustables::expr::Icmpv6HeaderField::Code), - // )); - // allow_router_solicitation.add_expr(&nft_expr!(cmp == 0u8)); - // - // allow_router_solicitation.add_expr(&nft_expr!(verdict accept)); - // - // batch.add(&allow_router_solicitation, rustables::MsgType::Add); + // === ADD A RULE ALLOWING (AND COUNTING) ALL PACKETS TO THE 10.1.0.0/24 NETWORK === + + let private_net_ip = Ipv4Addr::new(10, 1, 0, 0); + let private_net_prefix = 24; + let private_net = IpNetwork::V4(Ipv4Network::new(private_net_ip, private_net_prefix)?); + + let block_out_to_private_net_rule = Rule::new(&out_chain)? + // Load the `nfproto` metadata into the netfilter register. This metadata denotes which layer3 + // protocol the packet being processed is using. + .with_expr(Meta::new(MetaType::NfProto)) + + // Check if the currently processed packet is an IPv4 packet. This must be done before payload + // data assuming the packet uses IPv4 can be loaded in the next expression. + .with_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV4 as u8])) + + // Load the IPv4 destination address into the netfilter register. + .with_expr(HighLevelPayload::Network(NetworkHeaderField::IPv4(IPv4HeaderField::Daddr)).build()) + + // Mask out the part of the destination address that is not part of the network bits. The result + // of this bitwise masking is stored back into the same netfilter register. + .with_expr(Bitwise::new(ip_to_vec(private_net.mask()), [0u8; 4])?) + + // Compare the result of the masking with the IP of the network we are interested in. + .with_expr(Cmp::new(CmpOp::Eq, ip_to_vec(private_net.ip()))) + + // Add a packet counter to the rule. Shows how many packets have been evaluated against this + // expression. Since expressions are evaluated from first to last, putting this counter before + // the above IP net check would make the counter increment on all packets also *not* matching + // those expressions. Because the counter would then be evaluated before it fails a check. + // Similarly, if the counter was added after the verdict it would always remain at zero. Since + // when the packet hits the verdict expression any further processing of expressions stop. + .with_expr(Counter::default()) + + // Accept all the packets matching the rule so far. + .with_expr(Immediate::new_verdict(VerdictKind::Accept)); + + // Add the rule to the batch. Without this nothing would be sent over netlink and netfilter, + // and all the work on `block_out_to_private_net_rule` so far would go to waste. + batch.add(&block_out_to_private_net_rule, rustables::MsgType::Add); + + // === ADD A RULE ALLOWING ALL OUTGOING ICMPv6 PACKETS WITH TYPE 133 AND CODE 0 === + + let allow_router_solicitation = Rule::new(&out_chain)? + // Check that the packet is IPv6 and ICMPv6 + .with_expr(Meta::new(MetaType::NfProto)) + .with_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8])) + .with_expr(Meta::new(MetaType::L4Proto)) + .with_expr(Cmp::new(CmpOp::Eq, [libc::IPPROTO_ICMPV6 as u8])) + .with_expr( + HighLevelPayload::Transport(TransportHeaderField::ICMPv6(ICMPv6HeaderField::Type)) + .build(), + ) + .with_expr(Cmp::new(CmpOp::Eq, [133u8])) + .with_expr( + HighLevelPayload::Transport(TransportHeaderField::ICMPv6(ICMPv6HeaderField::Code)) + .build(), + ) + .with_expr(Cmp::new(CmpOp::Eq, [IcmpCode::NoRoute as u8])) + .with_expr(Immediate::new_verdict(VerdictKind::Accept)); + + batch.add(&allow_router_solicitation, rustables::MsgType::Add); // === FINALIZE THE TRANSACTION AND SEND THE DATA TO NETFILTER === @@ -190,17 +181,6 @@ fn main() -> Result<(), Error> { Ok(batch.send()?) } -// Look up the interface index for a given interface name. -fn iface_index(name: &str) -> Result { - let c_name = CString::new(name)?; - let index = unsafe { libc::if_nametoindex(c_name.as_ptr()) }; - if index == 0 { - Err(Error::from(io::Error::last_os_error())) - } else { - Ok(index) - } -} - #[derive(Debug)] struct Error(String); diff --git a/examples/filter-ethernet.rs b/examples/filter-ethernet.rs index 732c8cb..a136731 100644 --- a/examples/filter-ethernet.rs +++ b/examples/filter-ethernet.rs @@ -10,7 +10,7 @@ ///! table inet example-filter-ethernet { ///! chain chain-for-outgoing-packets { ///! type filter hook output priority 3; policy accept; -///! ether daddr 00:00:00:00:00:00 drop +///! ether daddr 01:02:03:04:05:06 drop ///! counter packets 0 bytes 0 meta random > 2147483647 counter packets 0 bytes 0 ///! } ///! } @@ -21,75 +21,78 @@ ///! ```bash ///! # nft delete table inet example-filter-ethernet ///! ``` -//use rustables::{nft_expr, query::send_batch, sys::libc, Batch, Chain, ProtoFamily, Rule, Table}; -use std::{ffi::CString, rc::Rc}; -// -//const TABLE_NAME: &str = "example-filter-ethernet"; -//const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets"; -// -//const BLOCK_THIS_MAC: &[u8] = &[0, 0, 0, 0, 0, 0]; -// +use rustables::{ + expr::{ + Cmp, CmpOp, Counter, ExpressionList, HighLevelPayload, Immediate, LLHeaderField, Meta, + MetaType, VerdictKind, + }, + Batch, Chain, ChainPolicy, Hook, HookClass, ProtocolFamily, Rule, Table, +}; + +const TABLE_NAME: &str = "example-filter-ethernet"; +const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets"; + +const BLOCK_THIS_MAC: &[u8] = &[1, 2, 3, 4, 5, 6]; + fn main() { - // // For verbose explanations of what all these lines up until the rule creation does, see the - // // `add-rules` example. - // let mut batch = Batch::new(); - // let table = Rc::new(Table::new( - // &CString::new(TABLE_NAME).unwrap(), - // ProtoFamily::Inet, - // )); - // batch.add(&Rc::clone(&table), rustables::MsgType::Add); - // - // let mut out_chain = Chain::new(&CString::new(OUT_CHAIN_NAME).unwrap(), Rc::clone(&table)); - // out_chain.set_hook(rustables::Hook::Out, 3); - // out_chain.set_policy(rustables::Policy::Accept); - // let out_chain = Rc::new(out_chain); - // batch.add(&Rc::clone(&out_chain), rustables::MsgType::Add); - // - // // === ADD RULE DROPPING ALL TRAFFIC TO THE MAC ADDRESS IN `BLOCK_THIS_MAC` === - // - // let mut block_ethernet_rule = Rule::new(Rc::clone(&out_chain)); - // - // // Check that the interface type is an ethernet interface. Must be done before we can check - // // payload values in the ethernet header. - // block_ethernet_rule.add_expr(&nft_expr!(meta iiftype)); - // block_ethernet_rule.add_expr(&nft_expr!(cmp == libc::ARPHRD_ETHER)); - // - // // Compare the ethernet destination address against the MAC address we want to drop - // block_ethernet_rule.add_expr(&nft_expr!(payload ethernet daddr)); - // block_ethernet_rule.add_expr(&nft_expr!(cmp == BLOCK_THIS_MAC)); - // - // // Drop the matching packets. - // block_ethernet_rule.add_expr(&nft_expr!(verdict drop)); - // - // batch.add(&block_ethernet_rule, rustables::MsgType::Add); - // - // // === FOR FUN, ADD A PACKET THAT MATCHES 50% OF ALL PACKETS === - // - // // This packet has a counter before and after the check that has 50% chance of matching. - // // So after a number of packets has passed through this rule, the first counter should have a - // // value approximately double that of the second counter. This rule has no verdict, so it never - // // does anything with the matching packets. - // let mut random_rule = Rule::new(Rc::clone(&out_chain)); - // // This counter expression will be evaluated (and increment the counter) for all packets coming - // // through. - // random_rule.add_expr(&nft_expr!(counter)); - // - // // Load a pseudo-random 32 bit unsigned integer into the netfilter register. - // random_rule.add_expr(&nft_expr!(meta random)); - // // Check if the random integer is larger than `u32::MAX/2`, thus having 50% chance of success. - // random_rule.add_expr(&nft_expr!(cmp > (::std::u32::MAX / 2).to_be())); - // - // // Add a second counter. This will only be incremented for the packets passing the random check. - // random_rule.add_expr(&nft_expr!(counter)); - // - // batch.add(&random_rule, rustables::MsgType::Add); - // - // // === FINALIZE THE TRANSACTION AND SEND THE DATA TO NETFILTER === - // - // match batch.finalize() { - // Some(mut finalized_batch) => { - // send_batch(&mut finalized_batch).expect("Couldn't process the batch"); - // } - // None => todo!(), - // } + // For verbose explanations of what all these lines up until the rule creation does, see the + // `add-rules` example. + let mut batch = Batch::new(); + let table = Table::new(ProtocolFamily::Inet).with_name(TABLE_NAME); + batch.add(&table, rustables::MsgType::Add); + + let mut out_chain = Chain::new(&table).with_name(OUT_CHAIN_NAME); + out_chain.set_hook(Hook::new(HookClass::Out, 3)); + out_chain.set_policy(ChainPolicy::Accept); + batch.add(&out_chain, rustables::MsgType::Add); + + // === ADD RULE DROPPING ALL TRAFFIC TO THE MAC ADDRESS IN `BLOCK_THIS_MAC` === + + let mut block_ethernet_rule = Rule::new(&out_chain).unwrap(); + + block_ethernet_rule.set_expressions( + ExpressionList::default() + // Check that the interface type is an ethernet interface. Must be done before we can check + // payload values in the ethernet header. + .with_value(Meta::new(MetaType::IifType)) + .with_value(Cmp::new(CmpOp::Eq, (libc::ARPHRD_ETHER as u16).to_le_bytes())) + + // Compare the ethernet destination address against the MAC address we want to drop + .with_value(HighLevelPayload::LinkLayer(LLHeaderField::Daddr).build()) + .with_value(Cmp::new(CmpOp::Eq, BLOCK_THIS_MAC)) + + // Drop the matching packets. + .with_value(Immediate::new_verdict(VerdictKind::Drop)), + ); + + batch.add(&block_ethernet_rule, rustables::MsgType::Add); + + // === FOR FUN, ADD A PACKET THAT MATCHES 50% OF ALL PACKETS === + + // This packet has a counter before and after the check that has 50% chance of matching. + // So after a number of packets has passed through this rule, the first counter should have a + // value approximately double that of the second counter. This rule has no verdict, so it never + // does anything with the matching packets. + let mut random_rule = Rule::new(&out_chain).unwrap(); + + random_rule.set_expressions( + ExpressionList::default() + // This counter expression will be evaluated (and increment the counter) for all packets coming + // through. + .with_value(Counter::default()) + + // Load a pseudo-random 32 bit unsigned integer into the netfilter register. + .with_value(Meta::new(MetaType::PRandom)) + // Check if the random integer is larger than `u32::MAX/2`, thus having 50% chance of success. + .with_value(Cmp::new(CmpOp::Gt, (::std::u32::MAX / 2).to_be_bytes())) + + // Add a second counter. This will only be incremented for the packets passing the random check. + .with_value(Counter::default()), + ); + + batch.add(&random_rule, rustables::MsgType::Add); + + // === FINALIZE THE TRANSACTION AND SEND THE DATA TO NETFILTER === + + batch.send().unwrap(); } diff --git a/examples/firewall.rs b/examples/firewall.rs index fc25010..3169cdc 100644 --- a/examples/firewall.rs +++ b/examples/firewall.rs @@ -3,138 +3,124 @@ //use rustables::query::{send_batch, Error as QueryError}; //use rustables::expr::{LogGroup, LogPrefix, LogPrefixError}; use ipnetwork::IpNetwork; -use std::ffi::{CString, NulError}; -use std::rc::Rc; -use std::str::Utf8Error; +use rustables::error::{BuilderError, QueryError}; +use rustables::expr::Log; +use rustables::{ + Batch, Chain, ChainPolicy, Hook, HookClass, MsgType, Protocol, ProtocolFamily, Rule, Table, +}; #[derive(thiserror::Error, Debug)] pub enum Error { - #[error("Unable to open netlink socket to netfilter")] - NetlinkOpenError(#[source] std::io::Error), - #[error("Firewall is already started")] - AlreadyDone, - #[error("Error converting from a C String")] - NulError(#[from] NulError), - //#[error("Error creating match")] - //MatchError(#[from] MatchError), - #[error("Error converting to utf-8 string")] - Utf8Error(#[from] Utf8Error), + #[error("Error building a netlink object")] + BuildError(#[from] BuilderError), #[error("Error applying batch")] - BatchError(#[from] std::io::Error), - //#[error("Error applying batch")] - //QueryError(#[from] QueryError), - //#[error("Error encoding the prefix")] - //LogPrefixError(#[from] LogPrefixError), + QueryError(#[from] QueryError), } const TABLE_NAME: &str = "main-table"; +const INBOUND_CHAIN_NAME: &str = "in-chain"; +const FORWARD_CHAIN_NAME: &str = "forward-chain"; +const OUTBOUND_CHAIN_NAME: &str = "out-chain"; fn main() -> Result<(), Error> { - // let fw = Firewall::new()?; - // fw.start()?; + let fw = Firewall::new()?; + fw.start()?; Ok(()) } -// -// -///// An example firewall. See the source of its `start()` method. -//pub struct Firewall { -// batch: Batch, -// inbound: Rc, -// _outbound: Rc, -// _forward: Rc, -// table: Rc
, -//} -// -//impl Firewall { -// pub fn new() -> Result { -// let mut batch = Batch::new(); -// let table = Rc::new( -// Table::new(&CString::new(TABLE_NAME)?, ProtoFamily::Inet) -// ); -// batch.add(&table, MsgType::Add); -// -// // Create base chains. Base chains are hooked into a Direction/Hook. -// let inbound = Rc::new( -// Chain::from_hook(Hook::In, Rc::clone(&table)) -// .verdict(Policy::Drop) -// .add_to_batch(&mut batch) -// ); -// let _outbound = Rc::new( -// Chain::from_hook(Hook::Out, Rc::clone(&table)) -// .verdict(Policy::Accept) -// .add_to_batch(&mut batch) -// ); -// let _forward = Rc::new( -// Chain::from_hook(Hook::Forward, Rc::clone(&table)) -// .verdict(Policy::Accept) -// .add_to_batch(&mut batch) -// ); -// -// Ok(Firewall { -// table, -// batch, -// inbound, -// _outbound, -// _forward -// }) -// } -// /// Allow some common-sense exceptions to inbound drop, and accept outbound and forward. -// pub fn start(mut self) -> Result<(), Error> { -// // Allow all established connections to get in. -// Rule::new(Rc::clone(&self.inbound)) -// .established() -// .accept() -// .add_to_batch(&mut self.batch); -// // Allow all traffic on the loopback interface. -// Rule::new(Rc::clone(&self.inbound)) -// .iface("lo")? -// .accept() -// .add_to_batch(&mut self.batch); -// // Allow ssh from anywhere, and log to dmesg with a prefix. -// Rule::new(Rc::clone(&self.inbound)) -// .dport("22", &Protocol::TCP)? -// .accept() -// .log(None, Some(LogPrefix::new("allow ssh connection:")?)) -// .add_to_batch(&mut self.batch); -// -// // Allow http from all IPs in 192.168.1.255/24 . -// let local_net = IpNetwork::new([192, 168, 1, 0].into(), 24).unwrap(); -// Rule::new(Rc::clone(&self.inbound)) -// .dport("80", &Protocol::TCP)? -// .snetwork(local_net) -// .accept() -// .add_to_batch(&mut self.batch); -// -// // Allow ICMP traffic, drop IGMP. -// Rule::new(Rc::clone(&self.inbound)) -// .icmp() -// .accept() -// .add_to_batch(&mut self.batch); -// Rule::new(Rc::clone(&self.inbound)) -// .igmp() -// .drop() -// .add_to_batch(&mut self.batch); -// -// // Log all traffic not accepted to NF_LOG group 1, accessible with ulogd. -// Rule::new(Rc::clone(&self.inbound)) -// .log(Some(LogGroup(1)), None) -// .add_to_batch(&mut self.batch); -// -// let mut finalized_batch = self.batch.finalize().unwrap(); -// send_batch(&mut finalized_batch)?; -// println!("table {} commited", TABLE_NAME); -// Ok(()) -// } -// /// If there is any table with name TABLE_NAME, remove it. -// pub fn stop(mut self) -> Result<(), Error> { -// self.batch.add(&self.table, MsgType::Add); -// self.batch.add(&self.table, MsgType::Del); -// -// let mut finalized_batch = self.batch.finalize().unwrap(); -// send_batch(&mut finalized_batch)?; -// println!("table {} destroyed", TABLE_NAME); -// Ok(()) -// } -//} -// -// + +/// An example firewall. See the source of its `start()` method. +pub struct Firewall { + batch: Batch, + inbound: Chain, + _outbound: Chain, + _forward: Chain, + table: Table, +} + +impl Firewall { + pub fn new() -> Result { + let mut batch = Batch::new(); + let table = Table::new(ProtocolFamily::Inet).with_name(TABLE_NAME); + batch.add(&table, MsgType::Add); + + // Create base chains. Base chains are hooked into a Direction/Hook. + let inbound = Chain::new(&table) + .with_name(INBOUND_CHAIN_NAME) + .with_hook(Hook::new(HookClass::In, 0)) + .with_policy(ChainPolicy::Drop) + .add_to_batch(&mut batch); + let _outbound = Chain::new(&table) + .with_name(OUTBOUND_CHAIN_NAME) + .with_hook(Hook::new(HookClass::Out, 0)) + .with_policy(ChainPolicy::Accept) + .add_to_batch(&mut batch); + let _forward = Chain::new(&table) + .with_name(FORWARD_CHAIN_NAME) + .with_hook(Hook::new(HookClass::Forward, 0)) + .with_policy(ChainPolicy::Accept) + .add_to_batch(&mut batch); + + Ok(Firewall { + table, + batch, + inbound, + _outbound, + _forward, + }) + } + /// Allow some common-sense exceptions to inbound drop, and accept outbound and forward. + pub fn start(mut self) -> Result<(), Error> { + // Allow all established connections to get in. + Rule::new(&self.inbound)? + .established()? + .accept() + .add_to_batch(&mut self.batch); + // Allow all traffic on the loopback interface. + Rule::new(&self.inbound)? + .iface("lo")? + .accept() + .add_to_batch(&mut self.batch); + // Allow ssh from anywhere, and log to dmesg with a prefix. + Rule::new(&self.inbound)? + .dport(22, Protocol::TCP) + .accept() + .with_expr(Log::new(None, Some("allow ssh connection:"))?) + .add_to_batch(&mut self.batch); + + // Allow http from all IPs in 192.168.1.255/24 . + let local_net = IpNetwork::new([192, 168, 1, 0].into(), 24).unwrap(); + Rule::new(&self.inbound)? + .dport(80, Protocol::TCP) + .snetwork(local_net)? + .accept() + .add_to_batch(&mut self.batch); + + // Allow ICMP traffic, drop IGMP. + Rule::new(&self.inbound)? + .icmp() + .accept() + .add_to_batch(&mut self.batch); + Rule::new(&self.inbound)? + .igmp() + .drop() + .add_to_batch(&mut self.batch); + + // Log all traffic not accepted to NF_LOG group 1, accessible with ulogd. + Rule::new(&self.inbound)? + .with_expr(Log::new(Some(1), None::)?) + .add_to_batch(&mut self.batch); + + self.batch.send()?; + println!("table {} commited", TABLE_NAME); + Ok(()) + } + /// If there is any table with name TABLE_NAME, remove it. + pub fn stop(mut self) -> Result<(), Error> { + self.batch.add(&self.table, MsgType::Add); + self.batch.add(&self.table, MsgType::Del); + + self.batch.send()?; + println!("table {} destroyed", TABLE_NAME); + Ok(()) + } +} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 9170e82..39f0d01 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -187,6 +187,9 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { let getter_name = format!("get_{}", field_str); let getter_name = Ident::new(&getter_name, field.name.span()); + let muttable_getter_name = format!("get_mut_{}", field_str); + let muttable_getter_name = Ident::new(&muttable_getter_name, field.name.span()); + let setter_name = format!("set_{}", field_str); let setter_name = Ident::new(&setter_name, field.name.span()); @@ -199,6 +202,10 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { self.#field_name.as_ref() } + pub fn #muttable_getter_name(&mut self) -> Option<&mut #field_type> { + self.#field_name.as_mut() + } + pub fn #setter_name(&mut self, val: impl Into<#field_type>) { self.#field_name = Some(val.into()); } diff --git a/src/chain.rs b/src/chain.rs index 0ce0ad8..37e4cb3 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -8,7 +8,7 @@ use crate::sys::{ NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN, }; -use crate::{ProtocolFamily, Table}; +use crate::{Batch, ProtocolFamily, Table}; use std::fmt::Debug; pub type ChainPriority = i32; @@ -169,6 +169,12 @@ impl Chain { chain } + + /// Appends this chain to `batch` + pub fn add_to_batch(self, batch: &mut Batch) -> Self { + batch.add(&self, crate::MsgType::Add); + self + } } impl NfNetlinkObject for Chain { diff --git a/src/chain_methods.rs b/src/chain_methods.rs deleted file mode 100644 index d384c35..0000000 --- a/src/chain_methods.rs +++ /dev/null @@ -1,40 +0,0 @@ -use crate::{Batch, Chain, Hook, MsgType, Policy, Table}; -use std::ffi::CString; -use std::rc::Rc; - - -/// A helper trait over [`crate::Chain`]. -pub trait ChainMethods { - /// Creates a new Chain instance from a [`crate::Hook`] over a [`crate::Table`]. - fn from_hook(hook: Hook, table: Rc
) -> Self - where Self: std::marker::Sized; - /// Adds a [`crate::Policy`] to the current Chain. - fn verdict(self, policy: Policy) -> Self; - fn add_to_batch(self, batch: &mut Batch) -> Self; -} - - -impl ChainMethods for Chain { - fn from_hook(hook: Hook, table: Rc
) -> Self { - let chain_name = match hook { - Hook::PreRouting => "prerouting", - Hook::Out => "out", - Hook::PostRouting => "postrouting", - Hook::Forward => "forward", - Hook::In => "in", - }; - let chain_name = CString::new(chain_name).unwrap(); - let mut chain = Chain::new(&chain_name, table); - chain.set_hook(hook, 0); - chain - } - fn verdict(mut self, policy: Policy) -> Self { - self.set_policy(policy); - self - } - fn add_to_batch(self, batch: &mut Batch) -> Self { - batch.add(&self, MsgType::Add); - self - } -} - diff --git a/src/data_type.rs b/src/data_type.rs index f9c97cb..43a7f1a 100644 --- a/src/data_type.rs +++ b/src/data_type.rs @@ -1,4 +1,4 @@ -use std::net::{Ipv4Addr, Ipv6Addr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; pub trait DataType { const TYPE: u32; @@ -33,3 +33,10 @@ impl DataType for [u8; N] { self.to_vec() } } + +pub fn ip_to_vec(ip: IpAddr) -> Vec { + match ip { + IpAddr::V4(x) => x.octets().to_vec(), + IpAddr::V6(x) => x.octets().to_vec(), + } +} diff --git a/src/error.rs b/src/error.rs index eae6898..f6b6247 100644 --- a/src/error.rs +++ b/src/error.rs @@ -129,6 +129,12 @@ pub enum BuilderError { #[error("Missing name for the set")] MissingSetName, + + #[error("The interface name is too long to be written")] + InterfaceNameTooLong, + + #[error("The log prefix string is more than 127 characters long")] + TooLongLogPrefix, } #[derive(thiserror::Error, Debug)] diff --git a/src/expr/cmp.rs b/src/expr/cmp.rs index 223902f..86d3587 100644 --- a/src/expr/cmp.rs +++ b/src/expr/cmp.rs @@ -1,7 +1,6 @@ use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; use crate::{ - data_type::DataType, parser_impls::NfNetlinkData, sys::{ NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFT_CMP_EQ, NFT_CMP_GT, NFT_CMP_GTE, NFT_CMP_LT, @@ -44,11 +43,11 @@ pub struct Cmp { impl Cmp { /// Returns a new comparison expression comparing the value loaded in the register with the /// data in `data` using the comparison operator `op`. - pub fn new(op: CmpOp, data: impl DataType) -> Self { + pub fn new(op: CmpOp, data: impl Into>) -> Self { Cmp { sreg: Some(Register::Reg1), op: Some(op), - data: Some(NfNetlinkData::default().with_value(data.data())), + data: Some(NfNetlinkData::default().with_value(data.into())), } } } diff --git a/src/expr/ct.rs b/src/expr/ct.rs index ccf61e1..ad76989 100644 --- a/src/expr/ct.rs +++ b/src/expr/ct.rs @@ -43,6 +43,10 @@ impl Expression for Conntrack { } impl Conntrack { + pub fn new(key: ConntrackKey) -> Self { + Self::default().with_dreg(Register::Reg1).with_key(key) + } + pub fn set_mark_value(&mut self, reg: Register) { self.set_sreg(reg); self.set_key(ConntrackKey::Mark); diff --git a/src/expr/log.rs b/src/expr/log.rs index 80bb7a9..cc2728e 100644 --- a/src/expr/log.rs +++ b/src/expr/log.rs @@ -1,7 +1,10 @@ use rustables_macros::nfnetlink_struct; -use super::{Expression, ExpressionError}; -use crate::sys::{NFTA_LOG_GROUP, NFTA_LOG_PREFIX}; +use super::Expression; +use crate::{ + error::BuilderError, + sys::{NFTA_LOG_GROUP, NFTA_LOG_PREFIX}, +}; #[derive(Clone, PartialEq, Eq, Default, Debug)] #[nfnetlink_struct] @@ -14,10 +17,7 @@ pub struct Log { } impl Log { - pub fn new( - group: Option, - prefix: Option>, - ) -> Result { + pub fn new(group: Option, prefix: Option>) -> Result { let mut res = Log::default(); if let Some(group) = group { res.set_group(group); @@ -26,7 +26,7 @@ impl Log { let prefix = prefix.into(); if prefix.bytes().count() > 127 { - return Err(ExpressionError::TooLongLogPrefix); + return Err(BuilderError::TooLongLogPrefix); } res.set_prefix(prefix); } diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 979ebb2..058b0cb 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -6,7 +6,6 @@ use std::fmt::Debug; use rustables_macros::nfnetlink_struct; -use thiserror::Error; use crate::error::DecodeError; use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable}; @@ -55,35 +54,6 @@ pub use self::register::Register; mod verdict; pub use self::verdict::*; -#[derive(Debug, Error)] -pub enum ExpressionError { - #[error("The log prefix string is more than 127 characters long")] - /// The log prefix string is more than 127 characters long - TooLongLogPrefix, - - #[error("The expected expression type doesn't match the name of the raw expression")] - /// The expected expression type doesn't match the name of the raw expression. - InvalidExpressionKind, - - #[error("Deserializing the requested type isn't implemented yet")] - /// Deserializing the requested type isn't implemented yet. - NotImplemented, - - #[error("The expression value cannot be deserialized to the requested type")] - /// The expression value cannot be deserialized to the requested type. - InvalidValue, - - #[error("A pointer was null while a non-null pointer was expected")] - /// A pointer was null while a non-null pointer was expected. - NullPointer, - - #[error( - "The size of a raw value was incoherent with the expected type of the deserialized value" - )] - /// The size of a raw value was incoherent with the expected type of the deserialized value/ - InvalidDataSize, -} - pub trait Expression { fn get_name() -> &'static str; } diff --git a/src/lib.rs b/src/lib.rs index 1ad1eed..dec5b76 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -53,7 +53,7 @@ use std::convert::TryFrom; mod batch; pub use batch::{default_batch_page_size, Batch}; -mod data_type; +pub mod data_type; mod table; pub use table::list_tables; @@ -65,9 +65,6 @@ pub use chain::{Chain, ChainPolicy, ChainPriority, ChainType, Hook, HookClass}; pub mod error; -//mod chain_methods; -//pub use chain_methods::ChainMethods; - pub mod query; pub(crate) mod nlmsg; @@ -80,8 +77,8 @@ pub use rule::Rule; pub mod expr; -//mod rule_methods; -//pub use rule_methods::{iface_index, Error as MatchError, Protocol, RuleMethods}; +mod rule_methods; +pub use rule_methods::{iface_index, Protocol}; pub mod set; pub use set::Set; diff --git a/src/nlmsg.rs b/src/nlmsg.rs index b3710bf..1c5b519 100644 --- a/src/nlmsg.rs +++ b/src/nlmsg.rs @@ -62,10 +62,6 @@ impl<'a> NfNetlinkWriter<'a> { &mut self.buf[start..start + size] } - pub fn extract_buffer(self) -> &'a mut Vec { - self.buf - } - // rewrite of `__nftnl_nlmsg_build_hdr` pub fn write_header( &mut self, diff --git a/src/rule.rs b/src/rule.rs index 7f732d3..858b9ce 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -4,7 +4,7 @@ use rustables_macros::nfnetlink_struct; use crate::chain::Chain; use crate::error::{BuilderError, QueryError}; -use crate::expr::ExpressionList; +use crate::expr::{ExpressionList, RawExpression}; use crate::nlmsg::NfNetlinkObject; use crate::query::list_objects_with_data; use crate::sys::{ @@ -12,7 +12,7 @@ use crate::sys::{ NFTA_RULE_TABLE, NFTA_RULE_USERDATA, NFT_MSG_DELRULE, NFT_MSG_NEWRULE, NLM_F_APPEND, NLM_F_CREATE, }; -use crate::ProtocolFamily; +use crate::{Batch, ProtocolFamily}; /// A nftables firewall rule. #[derive(Clone, PartialEq, Eq, Default, Debug)] @@ -53,6 +53,28 @@ impl Rule { .ok_or(BuilderError::MissingChainInformationError)?, )) } + + pub fn add_expr(&mut self, e: impl Into) { + let exprs = match self.get_mut_expressions() { + Some(x) => x, + None => { + self.set_expressions(ExpressionList::default()); + self.get_mut_expressions().unwrap() + } + }; + exprs.add_value(e); + } + + pub fn with_expr(mut self, e: impl Into) -> Self { + self.add_expr(e); + self + } + + /// Appends this rule to `batch` + pub fn add_to_batch(self, batch: &mut Batch) -> Self { + batch.add(&self, crate::MsgType::Add); + self + } } impl NfNetlinkObject for Rule { diff --git a/src/rule_methods.rs b/src/rule_methods.rs index d7145d7..dff9bf6 100644 --- a/src/rule_methods.rs +++ b/src/rule_methods.rs @@ -1,230 +1,211 @@ -use crate::{Batch, Rule, nft_expr, sys::libc}; -use crate::expr::{LogGroup, LogPrefix}; -use ipnetwork::IpNetwork; -use std::ffi::{CString, NulError}; +use std::ffi::CString; use std::net::IpAddr; -use std::num::ParseIntError; - -#[derive(thiserror::Error, Debug)] -pub enum Error { - #[error("Unable to open netlink socket to netfilter")] - NetlinkOpenError(#[source] std::io::Error), - #[error("Firewall is already started")] - AlreadyDone, - #[error("Error converting from a C string to a string")] - CStringError(#[from] NulError), - #[error("no interface found under that name")] - NoSuchIface, - #[error("Error converting from a string to an integer")] - ParseError(#[from] ParseIntError), - #[error("the interface name is too long")] - NameTooLong, -} +use ipnetwork::IpNetwork; +use crate::data_type::ip_to_vec; +use crate::error::BuilderError; +use crate::expr::ct::{ConnTrackState, Conntrack, ConntrackKey}; +use crate::expr::{ + Bitwise, Cmp, CmpOp, HighLevelPayload, IPv4HeaderField, IPv6HeaderField, Immediate, Meta, + MetaType, NetworkHeaderField, TCPHeaderField, TransportHeaderField, UDPHeaderField, + VerdictKind, +}; +use crate::Rule; /// Simple protocol description. Note that it does not implement other layer 4 protocols as /// IGMP et al. See [`Rule::igmp`] for a workaround. -#[derive(Debug, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum Protocol { TCP, - UDP + UDP, } -/// A RuleMethods trait over [`crate::Rule`], to make it match some criteria, and give it a -/// verdict. Mostly adapted from [talpid-core's firewall]. -/// All methods return the rule itself, allowing them to be chained. Usage example : -/// ```rust -/// use rustables::{Batch, Chain, ChainMethods, Protocol, ProtoFamily, Rule, RuleMethods, Table, MsgType, Hook}; -/// use std::ffi::CString; -/// use std::rc::Rc; -/// let table = Rc::new(Table::new(&CString::new("main_table").unwrap(), ProtoFamily::Inet)); -/// let mut batch = Batch::new(); -/// batch.add(&table, MsgType::Add); -/// let inbound = Rc::new(Chain::from_hook(Hook::In, Rc::clone(&table)) -/// .add_to_batch(&mut batch)); -/// let rule = Rule::new(inbound) -/// .dport("80", &Protocol::TCP).unwrap() -/// .accept() -/// .add_to_batch(&mut batch); -/// ``` -/// [talpid-core's firewall]: -/// https://github.com/mullvad/mullvadvpn-app/blob/d92376b4d1df9b547930c68aa9bae9640ff2a022/talpid-core/src/firewall/linux.rs -pub trait RuleMethods { - /// Matches ICMP packets. - fn icmp(self) -> Self; - /// Matches IGMP packets. - fn igmp(self) -> Self; - /// Matches packets to destination `port` and `protocol`. - fn dport(self, port: &str, protocol: &Protocol) -> Result - where Self: std::marker::Sized; - /// Matches packets on `protocol`. - fn protocol(self, protocol: Protocol) -> Result - where Self: std::marker::Sized; - /// Matches packets in an already established connection. - fn established(self) -> Self where Self: std::marker::Sized; - /// Matches packets going through `iface_index`. Interface indexes can be queried with - /// `iface_index()`. - fn iface_id(self, iface_index: libc::c_uint) -> Result - where Self: std::marker::Sized; - /// Matches packets going through `iface_name`, an interface name, as in "wlan0" or "lo". - fn iface(self, iface_name: &str) -> Result - where Self: std::marker::Sized; - /// Adds a log instruction to the rule. `group` is the NFLog group, `prefix` is a prefix - /// appended to each log line. - fn log(self, group: Option, prefix: Option) -> Self; - /// Matches packets whose source IP address is `saddr`. - fn saddr(self, ip: IpAddr) -> Self; - /// Matches packets whose source network is `snet`. - fn snetwork(self, ip: IpNetwork) -> Self; - /// Adds the `Accept` verdict to the rule. The packet will be sent to destination. - fn accept(self) -> Self; - /// Adds the `Drop` verdict to the rule. The packet will be dropped. - fn drop(self) -> Self; - /// Appends this rule to `batch`. - fn add_to_batch(self, batch: &mut Batch) -> Self; -} - -/// A trait to add helper functions to match some criterium over `crate::Rule`. -impl RuleMethods for Rule { - fn icmp(mut self) -> Self { - self.add_expr(&nft_expr!(meta l4proto)); - //self.add_expr(&nft_expr!(cmp == libc::IPPROTO_ICMPV6 as u8)); - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_ICMP as u8)); - self - } - fn igmp(mut self) -> Self { - self.add_expr(&nft_expr!(meta l4proto)); - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_IGMP as u8)); +impl Rule { + fn match_port(mut self, port: u16, protocol: Protocol, source: bool) -> Self { + self = self.protocol(protocol); + self.add_expr( + HighLevelPayload::Transport(match protocol { + Protocol::TCP => TransportHeaderField::Tcp(if source { + TCPHeaderField::Sport + } else { + TCPHeaderField::Dport + }), + Protocol::UDP => TransportHeaderField::Udp(if source { + UDPHeaderField::Sport + } else { + UDPHeaderField::Dport + }), + }) + .build(), + ); + self.add_expr(Cmp::new(CmpOp::Eq, port.to_be_bytes())); self } - fn dport(mut self, port: &str, protocol: &Protocol) -> Result { - self.add_expr(&nft_expr!(meta l4proto)); - match protocol { - &Protocol::TCP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_TCP as u8)); - self.add_expr(&nft_expr!(payload tcp dport)); - }, - &Protocol::UDP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_UDP as u8)); - self.add_expr(&nft_expr!(payload udp dport)); - } - } - // Convert the port to Big-Endian number spelling. - // See https://github.com/mullvad/mullvadvpn-app/blob/d92376b4d1df9b547930c68aa9bae9640ff2a022/talpid-core/src/firewall/linux.rs#L969 - self.add_expr(&nft_expr!(cmp == port.parse::()?.to_be())); - Ok(self) - } - fn protocol(mut self, protocol: Protocol) -> Result { - self.add_expr(&nft_expr!(meta l4proto)); - match protocol { - Protocol::TCP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_TCP as u8)); - }, - Protocol::UDP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_UDP as u8)); - } - } - Ok(self) - } - fn established(mut self) -> Self { - let allowed_states = crate::expr::ct::States::ESTABLISHED.bits(); - self.add_expr(&nft_expr!(ct state)); - self.add_expr(&nft_expr!(bitwise mask allowed_states, xor 0u32)); - self.add_expr(&nft_expr!(cmp != 0u32)); - self - } - fn iface_id(mut self, iface_index: libc::c_uint) -> Result { - self.add_expr(&nft_expr!(meta iif)); - self.add_expr(&nft_expr!(cmp == iface_index)); - Ok(self) - } - fn iface(mut self, iface_name: &str) -> Result { - if iface_name.len() >= libc::IFNAMSIZ { - return Err(Error::NameTooLong); - } - let mut name_arr = [0u8; libc::IFNAMSIZ]; - for (pos, i) in iface_name.bytes().enumerate() { - name_arr[pos] = i; - } - self.add_expr(&nft_expr!(meta iifname)); - self.add_expr(&nft_expr!(cmp == name_arr.as_ref())); - Ok(self) - } - fn saddr(mut self, ip: IpAddr) -> Self { - self.add_expr(&nft_expr!(meta nfproto)); + pub fn match_ip(mut self, ip: IpAddr, source: bool) -> Self { + self.add_expr(Meta::new(MetaType::NfProto)); match ip { IpAddr::V4(addr) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV4 as u8)); - self.add_expr(&nft_expr!(payload ipv4 saddr)); - self.add_expr(&nft_expr!(cmp == addr)) - }, + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV4 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv4(if source { + IPv4HeaderField::Saddr + } else { + IPv4HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Cmp::new(CmpOp::Eq, addr.octets())); + } IpAddr::V6(addr) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV6 as u8)); - self.add_expr(&nft_expr!(payload ipv6 saddr)); - self.add_expr(&nft_expr!(cmp == addr)) + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv6(if source { + IPv6HeaderField::Saddr + } else { + IPv6HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Cmp::new(CmpOp::Eq, addr.octets())); } } self } - fn snetwork(mut self, net: IpNetwork) -> Self { - self.add_expr(&nft_expr!(meta nfproto)); + + pub fn match_network(mut self, net: IpNetwork, source: bool) -> Result { + self.add_expr(Meta::new(MetaType::NfProto)); match net { IpNetwork::V4(_) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV4 as u8)); - self.add_expr(&nft_expr!(payload ipv4 saddr)); - self.add_expr(&nft_expr!(bitwise mask net.mask(), xor 0u32)); - self.add_expr(&nft_expr!(cmp == net.network())); - }, + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV4 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv4(if source { + IPv4HeaderField::Saddr + } else { + IPv4HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Bitwise::new(ip_to_vec(net.mask()), 0u32.to_be_bytes())?); + } IpNetwork::V6(_) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV6 as u8)); - self.add_expr(&nft_expr!(payload ipv6 saddr)); - self.add_expr(&nft_expr!(bitwise mask net.mask(), xor &[0u16; 8][..])); - self.add_expr(&nft_expr!(cmp == net.network())); + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv6(if source { + IPv6HeaderField::Saddr + } else { + IPv6HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Bitwise::new(ip_to_vec(net.mask()), 0u128.to_be_bytes())?); } } + self.add_expr(Cmp::new(CmpOp::Eq, ip_to_vec(net.network()))); + Ok(self) + } +} + +impl Rule { + /// Matches ICMP packets. + pub fn icmp(mut self) -> Self { + // quid of icmpv6? + self.add_expr(Meta::new(MetaType::L4Proto)); + self.add_expr(Cmp::new(CmpOp::Eq, [libc::IPPROTO_ICMP as u8])); self } - fn log(mut self, group: Option, prefix: Option) -> Self { - match (group.is_some(), prefix.is_some()) { - (true, true) => { - self.add_expr(&nft_expr!(log group group prefix prefix)); - }, - (false, true) => { - self.add_expr(&nft_expr!(log prefix prefix)); - }, - (true, false) => { - self.add_expr(&nft_expr!(log group group)); - }, - (false, false) => { - self.add_expr(&nft_expr!(log)); - } - } + /// Matches IGMP packets. + pub fn igmp(mut self) -> Self { + self.add_expr(Meta::new(MetaType::L4Proto)); + self.add_expr(Cmp::new(CmpOp::Eq, [libc::IPPROTO_IGMP as u8])); self } - fn accept(mut self) -> Self { - self.add_expr(&nft_expr!(verdict accept)); + /// Matches packets from source `port` and `protocol`. + pub fn sport(self, port: u16, protocol: Protocol) -> Self { + self.match_port(port, protocol, false) + } + /// Matches packets to destination `port` and `protocol`. + pub fn dport(self, port: u16, protocol: Protocol) -> Self { + self.match_port(port, protocol, false) + } + /// Matches packets on `protocol`. + pub fn protocol(mut self, protocol: Protocol) -> Self { + self.add_expr(Meta::new(MetaType::L4Proto)); + self.add_expr(Cmp::new( + CmpOp::Eq, + [match protocol { + Protocol::TCP => libc::IPPROTO_TCP, + Protocol::UDP => libc::IPPROTO_UDP, + } as u8], + )); + self + } + /// Matches packets in an already established connection. + pub fn established(mut self) -> Result { + let allowed_states = ConnTrackState::ESTABLISHED.bits(); + self.add_expr(Conntrack::new(ConntrackKey::State)); + self.add_expr(Bitwise::new( + allowed_states.to_le_bytes(), + 0u32.to_be_bytes(), + )?); + self.add_expr(Cmp::new(CmpOp::Neq, 0u32.to_be_bytes())); + Ok(self) + } + /// Matches packets going through `iface_index`. Interface indexes can be queried with + /// `iface_index()`. + pub fn iface_id(mut self, iface_index: libc::c_uint) -> Self { + self.add_expr(Meta::new(MetaType::Iif)); + self.add_expr(Cmp::new(CmpOp::Eq, iface_index.to_be_bytes())); self } - fn drop(mut self) -> Self { - self.add_expr(&nft_expr!(verdict drop)); + /// Matches packets going through `iface_name`, an interface name, as in "wlan0" or "lo" + pub fn iface(mut self, iface_name: &str) -> Result { + if iface_name.len() >= libc::IFNAMSIZ { + return Err(BuilderError::InterfaceNameTooLong); + } + let mut iface_vec = iface_name.as_bytes().to_vec(); + // null terminator + iface_vec.push(0u8); + + self.add_expr(Meta::new(MetaType::IifName)); + self.add_expr(Cmp::new(CmpOp::Eq, iface_vec)); + Ok(self) + } + /// Matches packets whose source IP address is `saddr`. + pub fn saddr(self, ip: IpAddr) -> Self { + self.match_ip(ip, true) + } + /// Matches packets whose destination IP address is `saddr`. + pub fn daddr(self, ip: IpAddr) -> Self { + self.match_ip(ip, false) + } + /// Matches packets whose source network is `net`. + pub fn snetwork(self, net: IpNetwork) -> Result { + self.match_network(net, true) + } + /// Matches packets whose destination network is `net`. + pub fn dnetwork(self, net: IpNetwork) -> Result { + self.match_network(net, false) + } + /// Adds the `Accept` verdict to the rule. The packet will be sent to destination. + pub fn accept(mut self) -> Self { + self.add_expr(Immediate::new_verdict(VerdictKind::Accept)); self } - fn add_to_batch(self, batch: &mut Batch) -> Self { - batch.add(&self, crate::MsgType::Add); + /// Adds the `Drop` verdict to the rule. The packet will be dropped. + pub fn drop(mut self) -> Self { + self.add_expr(Immediate::new_verdict(VerdictKind::Drop)); self } } /// Looks up the interface index for a given interface name. -pub fn iface_index(name: &str) -> Result { +pub fn iface_index(name: &str) -> Result { let c_name = CString::new(name)?; let index = unsafe { libc::if_nametoindex(c_name.as_ptr()) }; match index { - 0 => Err(Error::NoSuchIface), - _ => Ok(index) + 0 => Err(std::io::Error::last_os_error()), + _ => Ok(index), } } - - diff --git a/src/set.rs b/src/set.rs index 32d1666..ab29770 100644 --- a/src/set.rs +++ b/src/set.rs @@ -55,11 +55,10 @@ pub struct SetBuilder { } impl SetBuilder { - pub fn new(name: impl Into, id: u32, table: &Table) -> Result { + pub fn new(name: impl Into, table: &Table) -> Result { let table_name = table.get_name().ok_or(BuilderError::MissingTableName)?; let set_name = name.into(); let set = Set::default() - .with_id(id) .with_key_type(K::TYPE) .with_key_len(K::LEN) .with_table(table_name) diff --git a/src/table.rs b/src/table.rs index 63bf669..81a26ef 100644 --- a/src/table.rs +++ b/src/table.rs @@ -8,7 +8,7 @@ use crate::sys::{ NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFTA_TABLE_USERDATA, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, NFT_MSG_NEWTABLE, }; -use crate::ProtocolFamily; +use crate::{Batch, ProtocolFamily}; /// Abstraction of a `nftnl_table`, the top level container in netfilter. A table has a protocol /// family and contains [`Chain`]s that in turn hold the rules. @@ -32,6 +32,12 @@ impl Table { res.family = family; res } + + /// Appends this rule to `batch` + pub fn add_to_batch(self, batch: &mut Batch) -> Self { + batch.add(&self, crate::MsgType::Add); + self + } } impl NfNetlinkObject for Table { diff --git a/src/tests/expr.rs b/src/tests/expr.rs index 141f6ac..35c4fea 100644 --- a/src/tests/expr.rs +++ b/src/tests/expr.rs @@ -5,21 +5,23 @@ use libc::NF_DROP; use crate::{ expr::{ Bitwise, Cmp, CmpOp, Conntrack, ConntrackKey, Counter, ExpressionList, HeaderField, - HighLevelPayload, IcmpCode, Immediate, Log, Masquerade, Meta, MetaType, Nat, NatType, - Register, Reject, RejectType, TCPHeaderField, TransportHeaderField, VerdictKind, + HighLevelPayload, IcmpCode, Immediate, Log, Lookup, Masquerade, Meta, MetaType, Nat, + NatType, Register, Reject, RejectType, TCPHeaderField, TransportHeaderField, VerdictKind, }, + set::SetBuilder, sys::{ NFTA_BITWISE_DREG, NFTA_BITWISE_LEN, NFTA_BITWISE_MASK, NFTA_BITWISE_SREG, NFTA_BITWISE_XOR, NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFTA_COUNTER_BYTES, NFTA_COUNTER_PACKETS, NFTA_CT_DREG, NFTA_CT_KEY, NFTA_DATA_VALUE, NFTA_DATA_VERDICT, NFTA_EXPR_DATA, NFTA_EXPR_NAME, NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG, NFTA_LIST_ELEM, - NFTA_LOG_GROUP, NFTA_LOG_PREFIX, NFTA_META_DREG, NFTA_META_KEY, NFTA_NAT_FAMILY, - NFTA_NAT_REG_ADDR_MIN, NFTA_NAT_TYPE, NFTA_PAYLOAD_BASE, NFTA_PAYLOAD_DREG, - NFTA_PAYLOAD_LEN, NFTA_PAYLOAD_OFFSET, NFTA_REJECT_ICMP_CODE, NFTA_REJECT_TYPE, - NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, NFTA_RULE_TABLE, NFTA_VERDICT_CODE, NFT_CMP_EQ, - NFT_CT_STATE, NFT_META_PROTOCOL, NFT_NAT_SNAT, NFT_PAYLOAD_TRANSPORT_HEADER, NFT_REG_1, - NFT_REG_VERDICT, NFT_REJECT_ICMPX_UNREACH, + NFTA_LOG_GROUP, NFTA_LOG_PREFIX, NFTA_LOOKUP_SET, NFTA_LOOKUP_SREG, NFTA_META_DREG, + NFTA_META_KEY, NFTA_NAT_FAMILY, NFTA_NAT_REG_ADDR_MIN, NFTA_NAT_TYPE, NFTA_PAYLOAD_BASE, + NFTA_PAYLOAD_DREG, NFTA_PAYLOAD_LEN, NFTA_PAYLOAD_OFFSET, NFTA_REJECT_ICMP_CODE, + NFTA_REJECT_TYPE, NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, NFTA_RULE_TABLE, + NFTA_VERDICT_CODE, NFT_CMP_EQ, NFT_CT_STATE, NFT_META_PROTOCOL, NFT_NAT_SNAT, + NFT_PAYLOAD_TRANSPORT_HEADER, NFT_REG_1, NFT_REG_VERDICT, NFT_REJECT_ICMPX_UNREACH, }, + tests::{get_test_table, SET_NAME}, ProtocolFamily, }; @@ -283,39 +285,40 @@ fn log_expr_is_valid() { ); } -/* #[test] fn lookup_expr_is_valid() { - let set_name = &CStr::from_bytes_with_nul(b"mockset\0").unwrap(); - let mut rule = get_test_rule(); - let table = rule.get_chain().get_table(); - let mut set = Set::new(set_name, 0, table); + let table = get_test_table(); + let mut set_builder = SetBuilder::new(SET_NAME, &table).unwrap(); let address: Ipv4Addr = [8, 8, 8, 8].into(); - set.add(&address); + set_builder.add(&address); + let (set, _set_elements) = set_builder.finish(); let lookup = Lookup::new(&set).unwrap(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &lookup); - assert_eq!(nlmsghdr.nlmsg_len, 104); + + let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(lookup)); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 96); assert_eq!( raw_expr, NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), NetlinkExpr::Nested( NFTA_RULE_EXPRESSIONS, vec![NetlinkExpr::Nested( NFTA_LIST_ELEM, vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"lookup\0".to_vec()), + NetlinkExpr::Final(NFTA_EXPR_NAME, b"lookup".to_vec()), NetlinkExpr::Nested( NFTA_EXPR_DATA, vec![ + NetlinkExpr::Final(NFTA_LOOKUP_SET, b"mockset".to_vec()), NetlinkExpr::Final( NFTA_LOOKUP_SREG, NFT_REG_1.to_be_bytes().to_vec() ), - NetlinkExpr::Final(NFTA_LOOKUP_SET, b"mockset\0".to_vec()), - NetlinkExpr::Final(NFTA_LOOKUP_SET_ID, 0u32.to_be_bytes().to_vec()), ] ) ] @@ -325,7 +328,6 @@ fn lookup_expr_is_valid() { .to_raw() ); } -*/ #[test] fn masquerade_expr_is_valid() { diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 3693d35..75fe8b0 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -20,8 +20,6 @@ pub const CHAIN_USERDATA: &'static str = "mockchaindata"; pub const RULE_USERDATA: &'static str = "mockruledata"; pub const SET_USERDATA: &'static str = "mocksetdata"; -pub const SET_ID: u32 = 123456; - type NetLinkType = u16; #[derive(Debug, thiserror::Error)] @@ -157,7 +155,7 @@ pub fn get_test_rule() -> Rule { } pub fn get_test_set() -> Set { - SetBuilder::::new(SET_NAME, SET_ID, &get_test_table()) + SetBuilder::::new(SET_NAME, &get_test_table()) .expect("Couldn't create a set") .finish() .0 diff --git a/src/tests/set.rs b/src/tests/set.rs index db27ced..6c8247c 100644 --- a/src/tests/set.rs +++ b/src/tests/set.rs @@ -6,16 +6,16 @@ use crate::{ set::SetBuilder, sys::{ NFTA_DATA_VALUE, NFTA_LIST_ELEM, NFTA_SET_ELEM_KEY, NFTA_SET_ELEM_LIST_ELEMENTS, - NFTA_SET_ELEM_LIST_SET, NFTA_SET_ELEM_LIST_TABLE, NFTA_SET_ID, NFTA_SET_KEY_LEN, - NFTA_SET_KEY_TYPE, NFTA_SET_NAME, NFTA_SET_TABLE, NFTA_SET_USERDATA, NFT_MSG_DELSET, - NFT_MSG_NEWSET, NFT_MSG_NEWSETELEM, + NFTA_SET_ELEM_LIST_SET, NFTA_SET_ELEM_LIST_TABLE, NFTA_SET_KEY_LEN, NFTA_SET_KEY_TYPE, + NFTA_SET_NAME, NFTA_SET_TABLE, NFTA_SET_USERDATA, NFT_MSG_DELSET, NFT_MSG_NEWSET, + NFT_MSG_NEWSETELEM, }, MsgType, }; use super::{ get_test_nlmsg, get_test_nlmsg_with_msg_type, get_test_set, get_test_table, NetlinkExpr, - SET_ID, SET_NAME, SET_USERDATA, TABLE_NAME, + SET_NAME, SET_USERDATA, TABLE_NAME, }; #[test] @@ -28,7 +28,7 @@ fn new_empty_set() { get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), NFT_MSG_NEWSET as u8 ); - assert_eq!(nlmsghdr.nlmsg_len, 88); + assert_eq!(nlmsghdr.nlmsg_len, 80); assert_eq!( raw_expr, @@ -37,7 +37,6 @@ fn new_empty_set() { NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.as_bytes().to_vec()), NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv4Addr::TYPE.to_be_bytes().to_vec()), NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv4Addr::LEN.to_be_bytes().to_vec()), - NetlinkExpr::Final(NFTA_SET_ID, SET_ID.to_be_bytes().to_vec()), NetlinkExpr::Final(NFTA_SET_USERDATA, SET_USERDATA.as_bytes().to_vec()), ]) .to_raw() @@ -55,7 +54,7 @@ fn delete_empty_set() { get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), NFT_MSG_DELSET as u8 ); - assert_eq!(nlmsghdr.nlmsg_len, 88); + assert_eq!(nlmsghdr.nlmsg_len, 80); assert_eq!( raw_expr, @@ -64,7 +63,6 @@ fn delete_empty_set() { NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.as_bytes().to_vec()), NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv6Addr::TYPE.to_be_bytes().to_vec()), NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv6Addr::LEN.to_be_bytes().to_vec()), - NetlinkExpr::Final(NFTA_SET_ID, SET_ID.to_be_bytes().to_vec()), NetlinkExpr::Final(NFTA_SET_USERDATA, SET_USERDATA.as_bytes().to_vec()), ]) .to_raw() @@ -75,9 +73,8 @@ fn delete_empty_set() { fn new_set_with_data() { let ip1 = Ipv4Addr::new(127, 0, 0, 1); let ip2 = Ipv4Addr::new(1, 1, 1, 1); - let mut set_builder = - SetBuilder::::new(SET_NAME.to_string(), SET_ID, &get_test_table()) - .expect("Couldn't create a set"); + let mut set_builder = SetBuilder::::new(SET_NAME.to_string(), &get_test_table()) + .expect("Couldn't create a set"); set_builder.add(&ip1); set_builder.add(&ip2); -- cgit v1.2.3