diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/chain.rs | 8 | ||||
-rw-r--r-- | src/chain_methods.rs | 40 | ||||
-rw-r--r-- | src/data_type.rs | 9 | ||||
-rw-r--r-- | src/error.rs | 6 | ||||
-rw-r--r-- | src/expr/cmp.rs | 5 | ||||
-rw-r--r-- | src/expr/ct.rs | 4 | ||||
-rw-r--r-- | src/expr/log.rs | 14 | ||||
-rw-r--r-- | src/expr/mod.rs | 30 | ||||
-rw-r--r-- | src/lib.rs | 9 | ||||
-rw-r--r-- | src/nlmsg.rs | 4 | ||||
-rw-r--r-- | src/rule.rs | 26 | ||||
-rw-r--r-- | src/rule_methods.rs | 355 | ||||
-rw-r--r-- | src/set.rs | 3 | ||||
-rw-r--r-- | src/table.rs | 8 | ||||
-rw-r--r-- | src/tests/expr.rs | 46 | ||||
-rw-r--r-- | src/tests/mod.rs | 4 | ||||
-rw-r--r-- | src/tests/set.rs | 19 |
17 files changed, 270 insertions, 320 deletions
diff --git a/src/chain.rs b/src/chain.rs index 0ce0ad8..37e4cb3 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -8,7 +8,7 @@ use crate::sys::{ NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN, }; -use crate::{ProtocolFamily, Table}; +use crate::{Batch, ProtocolFamily, Table}; use std::fmt::Debug; pub type ChainPriority = i32; @@ -169,6 +169,12 @@ impl Chain { chain } + + /// Appends this chain to `batch` + pub fn add_to_batch(self, batch: &mut Batch) -> Self { + batch.add(&self, crate::MsgType::Add); + self + } } impl NfNetlinkObject for Chain { diff --git a/src/chain_methods.rs b/src/chain_methods.rs deleted file mode 100644 index d384c35..0000000 --- a/src/chain_methods.rs +++ /dev/null @@ -1,40 +0,0 @@ -use crate::{Batch, Chain, Hook, MsgType, Policy, Table}; -use std::ffi::CString; -use std::rc::Rc; - - -/// A helper trait over [`crate::Chain`]. -pub trait ChainMethods { - /// Creates a new Chain instance from a [`crate::Hook`] over a [`crate::Table`]. - fn from_hook(hook: Hook, table: Rc<Table>) -> Self - where Self: std::marker::Sized; - /// Adds a [`crate::Policy`] to the current Chain. - fn verdict(self, policy: Policy) -> Self; - fn add_to_batch(self, batch: &mut Batch) -> Self; -} - - -impl ChainMethods for Chain { - fn from_hook(hook: Hook, table: Rc<Table>) -> Self { - let chain_name = match hook { - Hook::PreRouting => "prerouting", - Hook::Out => "out", - Hook::PostRouting => "postrouting", - Hook::Forward => "forward", - Hook::In => "in", - }; - let chain_name = CString::new(chain_name).unwrap(); - let mut chain = Chain::new(&chain_name, table); - chain.set_hook(hook, 0); - chain - } - fn verdict(mut self, policy: Policy) -> Self { - self.set_policy(policy); - self - } - fn add_to_batch(self, batch: &mut Batch) -> Self { - batch.add(&self, MsgType::Add); - self - } -} - diff --git a/src/data_type.rs b/src/data_type.rs index f9c97cb..43a7f1a 100644 --- a/src/data_type.rs +++ b/src/data_type.rs @@ -1,4 +1,4 @@ -use std::net::{Ipv4Addr, Ipv6Addr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; pub trait DataType { const TYPE: u32; @@ -33,3 +33,10 @@ impl<const N: usize> DataType for [u8; N] { self.to_vec() } } + +pub fn ip_to_vec(ip: IpAddr) -> Vec<u8> { + match ip { + IpAddr::V4(x) => x.octets().to_vec(), + IpAddr::V6(x) => x.octets().to_vec(), + } +} diff --git a/src/error.rs b/src/error.rs index eae6898..f6b6247 100644 --- a/src/error.rs +++ b/src/error.rs @@ -129,6 +129,12 @@ pub enum BuilderError { #[error("Missing name for the set")] MissingSetName, + + #[error("The interface name is too long to be written")] + InterfaceNameTooLong, + + #[error("The log prefix string is more than 127 characters long")] + TooLongLogPrefix, } #[derive(thiserror::Error, Debug)] diff --git a/src/expr/cmp.rs b/src/expr/cmp.rs index 223902f..86d3587 100644 --- a/src/expr/cmp.rs +++ b/src/expr/cmp.rs @@ -1,7 +1,6 @@ use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; use crate::{ - data_type::DataType, parser_impls::NfNetlinkData, sys::{ NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFT_CMP_EQ, NFT_CMP_GT, NFT_CMP_GTE, NFT_CMP_LT, @@ -44,11 +43,11 @@ pub struct Cmp { impl Cmp { /// Returns a new comparison expression comparing the value loaded in the register with the /// data in `data` using the comparison operator `op`. - pub fn new(op: CmpOp, data: impl DataType) -> Self { + pub fn new(op: CmpOp, data: impl Into<Vec<u8>>) -> Self { Cmp { sreg: Some(Register::Reg1), op: Some(op), - data: Some(NfNetlinkData::default().with_value(data.data())), + data: Some(NfNetlinkData::default().with_value(data.into())), } } } diff --git a/src/expr/ct.rs b/src/expr/ct.rs index ccf61e1..ad76989 100644 --- a/src/expr/ct.rs +++ b/src/expr/ct.rs @@ -43,6 +43,10 @@ impl Expression for Conntrack { } impl Conntrack { + pub fn new(key: ConntrackKey) -> Self { + Self::default().with_dreg(Register::Reg1).with_key(key) + } + pub fn set_mark_value(&mut self, reg: Register) { self.set_sreg(reg); self.set_key(ConntrackKey::Mark); diff --git a/src/expr/log.rs b/src/expr/log.rs index 80bb7a9..cc2728e 100644 --- a/src/expr/log.rs +++ b/src/expr/log.rs @@ -1,7 +1,10 @@ use rustables_macros::nfnetlink_struct; -use super::{Expression, ExpressionError}; -use crate::sys::{NFTA_LOG_GROUP, NFTA_LOG_PREFIX}; +use super::Expression; +use crate::{ + error::BuilderError, + sys::{NFTA_LOG_GROUP, NFTA_LOG_PREFIX}, +}; #[derive(Clone, PartialEq, Eq, Default, Debug)] #[nfnetlink_struct] @@ -14,10 +17,7 @@ pub struct Log { } impl Log { - pub fn new( - group: Option<u16>, - prefix: Option<impl Into<String>>, - ) -> Result<Log, ExpressionError> { + pub fn new(group: Option<u16>, prefix: Option<impl Into<String>>) -> Result<Log, BuilderError> { let mut res = Log::default(); if let Some(group) = group { res.set_group(group); @@ -26,7 +26,7 @@ impl Log { let prefix = prefix.into(); if prefix.bytes().count() > 127 { - return Err(ExpressionError::TooLongLogPrefix); + return Err(BuilderError::TooLongLogPrefix); } res.set_prefix(prefix); } diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 979ebb2..058b0cb 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -6,7 +6,6 @@ use std::fmt::Debug; use rustables_macros::nfnetlink_struct; -use thiserror::Error; use crate::error::DecodeError; use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable}; @@ -55,35 +54,6 @@ pub use self::register::Register; mod verdict; pub use self::verdict::*; -#[derive(Debug, Error)] -pub enum ExpressionError { - #[error("The log prefix string is more than 127 characters long")] - /// The log prefix string is more than 127 characters long - TooLongLogPrefix, - - #[error("The expected expression type doesn't match the name of the raw expression")] - /// The expected expression type doesn't match the name of the raw expression. - InvalidExpressionKind, - - #[error("Deserializing the requested type isn't implemented yet")] - /// Deserializing the requested type isn't implemented yet. - NotImplemented, - - #[error("The expression value cannot be deserialized to the requested type")] - /// The expression value cannot be deserialized to the requested type. - InvalidValue, - - #[error("A pointer was null while a non-null pointer was expected")] - /// A pointer was null while a non-null pointer was expected. - NullPointer, - - #[error( - "The size of a raw value was incoherent with the expected type of the deserialized value" - )] - /// The size of a raw value was incoherent with the expected type of the deserialized value/ - InvalidDataSize, -} - pub trait Expression { fn get_name() -> &'static str; } @@ -53,7 +53,7 @@ use std::convert::TryFrom; mod batch; pub use batch::{default_batch_page_size, Batch}; -mod data_type; +pub mod data_type; mod table; pub use table::list_tables; @@ -65,9 +65,6 @@ pub use chain::{Chain, ChainPolicy, ChainPriority, ChainType, Hook, HookClass}; pub mod error; -//mod chain_methods; -//pub use chain_methods::ChainMethods; - pub mod query; pub(crate) mod nlmsg; @@ -80,8 +77,8 @@ pub use rule::Rule; pub mod expr; -//mod rule_methods; -//pub use rule_methods::{iface_index, Error as MatchError, Protocol, RuleMethods}; +mod rule_methods; +pub use rule_methods::{iface_index, Protocol}; pub mod set; pub use set::Set; diff --git a/src/nlmsg.rs b/src/nlmsg.rs index b3710bf..1c5b519 100644 --- a/src/nlmsg.rs +++ b/src/nlmsg.rs @@ -62,10 +62,6 @@ impl<'a> NfNetlinkWriter<'a> { &mut self.buf[start..start + size] } - pub fn extract_buffer(self) -> &'a mut Vec<u8> { - self.buf - } - // rewrite of `__nftnl_nlmsg_build_hdr` pub fn write_header( &mut self, diff --git a/src/rule.rs b/src/rule.rs index 7f732d3..858b9ce 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -4,7 +4,7 @@ use rustables_macros::nfnetlink_struct; use crate::chain::Chain; use crate::error::{BuilderError, QueryError}; -use crate::expr::ExpressionList; +use crate::expr::{ExpressionList, RawExpression}; use crate::nlmsg::NfNetlinkObject; use crate::query::list_objects_with_data; use crate::sys::{ @@ -12,7 +12,7 @@ use crate::sys::{ NFTA_RULE_TABLE, NFTA_RULE_USERDATA, NFT_MSG_DELRULE, NFT_MSG_NEWRULE, NLM_F_APPEND, NLM_F_CREATE, }; -use crate::ProtocolFamily; +use crate::{Batch, ProtocolFamily}; /// A nftables firewall rule. #[derive(Clone, PartialEq, Eq, Default, Debug)] @@ -53,6 +53,28 @@ impl Rule { .ok_or(BuilderError::MissingChainInformationError)?, )) } + + pub fn add_expr(&mut self, e: impl Into<RawExpression>) { + let exprs = match self.get_mut_expressions() { + Some(x) => x, + None => { + self.set_expressions(ExpressionList::default()); + self.get_mut_expressions().unwrap() + } + }; + exprs.add_value(e); + } + + pub fn with_expr(mut self, e: impl Into<RawExpression>) -> Self { + self.add_expr(e); + self + } + + /// Appends this rule to `batch` + pub fn add_to_batch(self, batch: &mut Batch) -> Self { + batch.add(&self, crate::MsgType::Add); + self + } } impl NfNetlinkObject for Rule { diff --git a/src/rule_methods.rs b/src/rule_methods.rs index d7145d7..dff9bf6 100644 --- a/src/rule_methods.rs +++ b/src/rule_methods.rs @@ -1,230 +1,211 @@ -use crate::{Batch, Rule, nft_expr, sys::libc}; -use crate::expr::{LogGroup, LogPrefix}; -use ipnetwork::IpNetwork; -use std::ffi::{CString, NulError}; +use std::ffi::CString; use std::net::IpAddr; -use std::num::ParseIntError; - -#[derive(thiserror::Error, Debug)] -pub enum Error { - #[error("Unable to open netlink socket to netfilter")] - NetlinkOpenError(#[source] std::io::Error), - #[error("Firewall is already started")] - AlreadyDone, - #[error("Error converting from a C string to a string")] - CStringError(#[from] NulError), - #[error("no interface found under that name")] - NoSuchIface, - #[error("Error converting from a string to an integer")] - ParseError(#[from] ParseIntError), - #[error("the interface name is too long")] - NameTooLong, -} +use ipnetwork::IpNetwork; +use crate::data_type::ip_to_vec; +use crate::error::BuilderError; +use crate::expr::ct::{ConnTrackState, Conntrack, ConntrackKey}; +use crate::expr::{ + Bitwise, Cmp, CmpOp, HighLevelPayload, IPv4HeaderField, IPv6HeaderField, Immediate, Meta, + MetaType, NetworkHeaderField, TCPHeaderField, TransportHeaderField, UDPHeaderField, + VerdictKind, +}; +use crate::Rule; /// Simple protocol description. Note that it does not implement other layer 4 protocols as /// IGMP et al. See [`Rule::igmp`] for a workaround. -#[derive(Debug, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum Protocol { TCP, - UDP + UDP, } -/// A RuleMethods trait over [`crate::Rule`], to make it match some criteria, and give it a -/// verdict. Mostly adapted from [talpid-core's firewall]. -/// All methods return the rule itself, allowing them to be chained. Usage example : -/// ```rust -/// use rustables::{Batch, Chain, ChainMethods, Protocol, ProtoFamily, Rule, RuleMethods, Table, MsgType, Hook}; -/// use std::ffi::CString; -/// use std::rc::Rc; -/// let table = Rc::new(Table::new(&CString::new("main_table").unwrap(), ProtoFamily::Inet)); -/// let mut batch = Batch::new(); -/// batch.add(&table, MsgType::Add); -/// let inbound = Rc::new(Chain::from_hook(Hook::In, Rc::clone(&table)) -/// .add_to_batch(&mut batch)); -/// let rule = Rule::new(inbound) -/// .dport("80", &Protocol::TCP).unwrap() -/// .accept() -/// .add_to_batch(&mut batch); -/// ``` -/// [talpid-core's firewall]: -/// https://github.com/mullvad/mullvadvpn-app/blob/d92376b4d1df9b547930c68aa9bae9640ff2a022/talpid-core/src/firewall/linux.rs -pub trait RuleMethods { - /// Matches ICMP packets. - fn icmp(self) -> Self; - /// Matches IGMP packets. - fn igmp(self) -> Self; - /// Matches packets to destination `port` and `protocol`. - fn dport(self, port: &str, protocol: &Protocol) -> Result<Self, Error> - where Self: std::marker::Sized; - /// Matches packets on `protocol`. - fn protocol(self, protocol: Protocol) -> Result<Self, Error> - where Self: std::marker::Sized; - /// Matches packets in an already established connection. - fn established(self) -> Self where Self: std::marker::Sized; - /// Matches packets going through `iface_index`. Interface indexes can be queried with - /// `iface_index()`. - fn iface_id(self, iface_index: libc::c_uint) -> Result<Self, Error> - where Self: std::marker::Sized; - /// Matches packets going through `iface_name`, an interface name, as in "wlan0" or "lo". - fn iface(self, iface_name: &str) -> Result<Self, Error> - where Self: std::marker::Sized; - /// Adds a log instruction to the rule. `group` is the NFLog group, `prefix` is a prefix - /// appended to each log line. - fn log(self, group: Option<LogGroup>, prefix: Option<LogPrefix>) -> Self; - /// Matches packets whose source IP address is `saddr`. - fn saddr(self, ip: IpAddr) -> Self; - /// Matches packets whose source network is `snet`. - fn snetwork(self, ip: IpNetwork) -> Self; - /// Adds the `Accept` verdict to the rule. The packet will be sent to destination. - fn accept(self) -> Self; - /// Adds the `Drop` verdict to the rule. The packet will be dropped. - fn drop(self) -> Self; - /// Appends this rule to `batch`. - fn add_to_batch(self, batch: &mut Batch) -> Self; -} - -/// A trait to add helper functions to match some criterium over `crate::Rule`. -impl RuleMethods for Rule { - fn icmp(mut self) -> Self { - self.add_expr(&nft_expr!(meta l4proto)); - //self.add_expr(&nft_expr!(cmp == libc::IPPROTO_ICMPV6 as u8)); - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_ICMP as u8)); - self - } - fn igmp(mut self) -> Self { - self.add_expr(&nft_expr!(meta l4proto)); - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_IGMP as u8)); +impl Rule { + fn match_port(mut self, port: u16, protocol: Protocol, source: bool) -> Self { + self = self.protocol(protocol); + self.add_expr( + HighLevelPayload::Transport(match protocol { + Protocol::TCP => TransportHeaderField::Tcp(if source { + TCPHeaderField::Sport + } else { + TCPHeaderField::Dport + }), + Protocol::UDP => TransportHeaderField::Udp(if source { + UDPHeaderField::Sport + } else { + UDPHeaderField::Dport + }), + }) + .build(), + ); + self.add_expr(Cmp::new(CmpOp::Eq, port.to_be_bytes())); self } - fn dport(mut self, port: &str, protocol: &Protocol) -> Result<Self, Error> { - self.add_expr(&nft_expr!(meta l4proto)); - match protocol { - &Protocol::TCP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_TCP as u8)); - self.add_expr(&nft_expr!(payload tcp dport)); - }, - &Protocol::UDP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_UDP as u8)); - self.add_expr(&nft_expr!(payload udp dport)); - } - } - // Convert the port to Big-Endian number spelling. - // See https://github.com/mullvad/mullvadvpn-app/blob/d92376b4d1df9b547930c68aa9bae9640ff2a022/talpid-core/src/firewall/linux.rs#L969 - self.add_expr(&nft_expr!(cmp == port.parse::<u16>()?.to_be())); - Ok(self) - } - fn protocol(mut self, protocol: Protocol) -> Result<Self, Error> { - self.add_expr(&nft_expr!(meta l4proto)); - match protocol { - Protocol::TCP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_TCP as u8)); - }, - Protocol::UDP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_UDP as u8)); - } - } - Ok(self) - } - fn established(mut self) -> Self { - let allowed_states = crate::expr::ct::States::ESTABLISHED.bits(); - self.add_expr(&nft_expr!(ct state)); - self.add_expr(&nft_expr!(bitwise mask allowed_states, xor 0u32)); - self.add_expr(&nft_expr!(cmp != 0u32)); - self - } - fn iface_id(mut self, iface_index: libc::c_uint) -> Result<Self, Error> { - self.add_expr(&nft_expr!(meta iif)); - self.add_expr(&nft_expr!(cmp == iface_index)); - Ok(self) - } - fn iface(mut self, iface_name: &str) -> Result<Self, Error> { - if iface_name.len() >= libc::IFNAMSIZ { - return Err(Error::NameTooLong); - } - let mut name_arr = [0u8; libc::IFNAMSIZ]; - for (pos, i) in iface_name.bytes().enumerate() { - name_arr[pos] = i; - } - self.add_expr(&nft_expr!(meta iifname)); - self.add_expr(&nft_expr!(cmp == name_arr.as_ref())); - Ok(self) - } - fn saddr(mut self, ip: IpAddr) -> Self { - self.add_expr(&nft_expr!(meta nfproto)); + pub fn match_ip(mut self, ip: IpAddr, source: bool) -> Self { + self.add_expr(Meta::new(MetaType::NfProto)); match ip { IpAddr::V4(addr) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV4 as u8)); - self.add_expr(&nft_expr!(payload ipv4 saddr)); - self.add_expr(&nft_expr!(cmp == addr)) - }, + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV4 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv4(if source { + IPv4HeaderField::Saddr + } else { + IPv4HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Cmp::new(CmpOp::Eq, addr.octets())); + } IpAddr::V6(addr) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV6 as u8)); - self.add_expr(&nft_expr!(payload ipv6 saddr)); - self.add_expr(&nft_expr!(cmp == addr)) + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv6(if source { + IPv6HeaderField::Saddr + } else { + IPv6HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Cmp::new(CmpOp::Eq, addr.octets())); } } self } - fn snetwork(mut self, net: IpNetwork) -> Self { - self.add_expr(&nft_expr!(meta nfproto)); + + pub fn match_network(mut self, net: IpNetwork, source: bool) -> Result<Self, BuilderError> { + self.add_expr(Meta::new(MetaType::NfProto)); match net { IpNetwork::V4(_) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV4 as u8)); - self.add_expr(&nft_expr!(payload ipv4 saddr)); - self.add_expr(&nft_expr!(bitwise mask net.mask(), xor 0u32)); - self.add_expr(&nft_expr!(cmp == net.network())); - }, + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV4 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv4(if source { + IPv4HeaderField::Saddr + } else { + IPv4HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Bitwise::new(ip_to_vec(net.mask()), 0u32.to_be_bytes())?); + } IpNetwork::V6(_) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV6 as u8)); - self.add_expr(&nft_expr!(payload ipv6 saddr)); - self.add_expr(&nft_expr!(bitwise mask net.mask(), xor &[0u16; 8][..])); - self.add_expr(&nft_expr!(cmp == net.network())); + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv6(if source { + IPv6HeaderField::Saddr + } else { + IPv6HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Bitwise::new(ip_to_vec(net.mask()), 0u128.to_be_bytes())?); } } + self.add_expr(Cmp::new(CmpOp::Eq, ip_to_vec(net.network()))); + Ok(self) + } +} + +impl Rule { + /// Matches ICMP packets. + pub fn icmp(mut self) -> Self { + // quid of icmpv6? + self.add_expr(Meta::new(MetaType::L4Proto)); + self.add_expr(Cmp::new(CmpOp::Eq, [libc::IPPROTO_ICMP as u8])); self } - fn log(mut self, group: Option<LogGroup>, prefix: Option<LogPrefix>) -> Self { - match (group.is_some(), prefix.is_some()) { - (true, true) => { - self.add_expr(&nft_expr!(log group group prefix prefix)); - }, - (false, true) => { - self.add_expr(&nft_expr!(log prefix prefix)); - }, - (true, false) => { - self.add_expr(&nft_expr!(log group group)); - }, - (false, false) => { - self.add_expr(&nft_expr!(log)); - } - } + /// Matches IGMP packets. + pub fn igmp(mut self) -> Self { + self.add_expr(Meta::new(MetaType::L4Proto)); + self.add_expr(Cmp::new(CmpOp::Eq, [libc::IPPROTO_IGMP as u8])); self } - fn accept(mut self) -> Self { - self.add_expr(&nft_expr!(verdict accept)); + /// Matches packets from source `port` and `protocol`. + pub fn sport(self, port: u16, protocol: Protocol) -> Self { + self.match_port(port, protocol, false) + } + /// Matches packets to destination `port` and `protocol`. + pub fn dport(self, port: u16, protocol: Protocol) -> Self { + self.match_port(port, protocol, false) + } + /// Matches packets on `protocol`. + pub fn protocol(mut self, protocol: Protocol) -> Self { + self.add_expr(Meta::new(MetaType::L4Proto)); + self.add_expr(Cmp::new( + CmpOp::Eq, + [match protocol { + Protocol::TCP => libc::IPPROTO_TCP, + Protocol::UDP => libc::IPPROTO_UDP, + } as u8], + )); + self + } + /// Matches packets in an already established connection. + pub fn established(mut self) -> Result<Self, BuilderError> { + let allowed_states = ConnTrackState::ESTABLISHED.bits(); + self.add_expr(Conntrack::new(ConntrackKey::State)); + self.add_expr(Bitwise::new( + allowed_states.to_le_bytes(), + 0u32.to_be_bytes(), + )?); + self.add_expr(Cmp::new(CmpOp::Neq, 0u32.to_be_bytes())); + Ok(self) + } + /// Matches packets going through `iface_index`. Interface indexes can be queried with + /// `iface_index()`. + pub fn iface_id(mut self, iface_index: libc::c_uint) -> Self { + self.add_expr(Meta::new(MetaType::Iif)); + self.add_expr(Cmp::new(CmpOp::Eq, iface_index.to_be_bytes())); self } - fn drop(mut self) -> Self { - self.add_expr(&nft_expr!(verdict drop)); + /// Matches packets going through `iface_name`, an interface name, as in "wlan0" or "lo" + pub fn iface(mut self, iface_name: &str) -> Result<Self, BuilderError> { + if iface_name.len() >= libc::IFNAMSIZ { + return Err(BuilderError::InterfaceNameTooLong); + } + let mut iface_vec = iface_name.as_bytes().to_vec(); + // null terminator + iface_vec.push(0u8); + + self.add_expr(Meta::new(MetaType::IifName)); + self.add_expr(Cmp::new(CmpOp::Eq, iface_vec)); + Ok(self) + } + /// Matches packets whose source IP address is `saddr`. + pub fn saddr(self, ip: IpAddr) -> Self { + self.match_ip(ip, true) + } + /// Matches packets whose destination IP address is `saddr`. + pub fn daddr(self, ip: IpAddr) -> Self { + self.match_ip(ip, false) + } + /// Matches packets whose source network is `net`. + pub fn snetwork(self, net: IpNetwork) -> Result<Self, BuilderError> { + self.match_network(net, true) + } + /// Matches packets whose destination network is `net`. + pub fn dnetwork(self, net: IpNetwork) -> Result<Self, BuilderError> { + self.match_network(net, false) + } + /// Adds the `Accept` verdict to the rule. The packet will be sent to destination. + pub fn accept(mut self) -> Self { + self.add_expr(Immediate::new_verdict(VerdictKind::Accept)); self } - fn add_to_batch(self, batch: &mut Batch) -> Self { - batch.add(&self, crate::MsgType::Add); + /// Adds the `Drop` verdict to the rule. The packet will be dropped. + pub fn drop(mut self) -> Self { + self.add_expr(Immediate::new_verdict(VerdictKind::Drop)); self } } /// Looks up the interface index for a given interface name. -pub fn iface_index(name: &str) -> Result<libc::c_uint, Error> { +pub fn iface_index(name: &str) -> Result<libc::c_uint, std::io::Error> { let c_name = CString::new(name)?; let index = unsafe { libc::if_nametoindex(c_name.as_ptr()) }; match index { - 0 => Err(Error::NoSuchIface), - _ => Ok(index) + 0 => Err(std::io::Error::last_os_error()), + _ => Ok(index), } } - - @@ -55,11 +55,10 @@ pub struct SetBuilder<K: DataType> { } impl<K: DataType> SetBuilder<K> { - pub fn new(name: impl Into<String>, id: u32, table: &Table) -> Result<Self, BuilderError> { + pub fn new(name: impl Into<String>, table: &Table) -> Result<Self, BuilderError> { let table_name = table.get_name().ok_or(BuilderError::MissingTableName)?; let set_name = name.into(); let set = Set::default() - .with_id(id) .with_key_type(K::TYPE) .with_key_len(K::LEN) .with_table(table_name) diff --git a/src/table.rs b/src/table.rs index 63bf669..81a26ef 100644 --- a/src/table.rs +++ b/src/table.rs @@ -8,7 +8,7 @@ use crate::sys::{ NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFTA_TABLE_USERDATA, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, NFT_MSG_NEWTABLE, }; -use crate::ProtocolFamily; +use crate::{Batch, ProtocolFamily}; /// Abstraction of a `nftnl_table`, the top level container in netfilter. A table has a protocol /// family and contains [`Chain`]s that in turn hold the rules. @@ -32,6 +32,12 @@ impl Table { res.family = family; res } + + /// Appends this rule to `batch` + pub fn add_to_batch(self, batch: &mut Batch) -> Self { + batch.add(&self, crate::MsgType::Add); + self + } } impl NfNetlinkObject for Table { diff --git a/src/tests/expr.rs b/src/tests/expr.rs index 141f6ac..35c4fea 100644 --- a/src/tests/expr.rs +++ b/src/tests/expr.rs @@ -5,21 +5,23 @@ use libc::NF_DROP; use crate::{ expr::{ Bitwise, Cmp, CmpOp, Conntrack, ConntrackKey, Counter, ExpressionList, HeaderField, - HighLevelPayload, IcmpCode, Immediate, Log, Masquerade, Meta, MetaType, Nat, NatType, - Register, Reject, RejectType, TCPHeaderField, TransportHeaderField, VerdictKind, + HighLevelPayload, IcmpCode, Immediate, Log, Lookup, Masquerade, Meta, MetaType, Nat, + NatType, Register, Reject, RejectType, TCPHeaderField, TransportHeaderField, VerdictKind, }, + set::SetBuilder, sys::{ NFTA_BITWISE_DREG, NFTA_BITWISE_LEN, NFTA_BITWISE_MASK, NFTA_BITWISE_SREG, NFTA_BITWISE_XOR, NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFTA_COUNTER_BYTES, NFTA_COUNTER_PACKETS, NFTA_CT_DREG, NFTA_CT_KEY, NFTA_DATA_VALUE, NFTA_DATA_VERDICT, NFTA_EXPR_DATA, NFTA_EXPR_NAME, NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG, NFTA_LIST_ELEM, - NFTA_LOG_GROUP, NFTA_LOG_PREFIX, NFTA_META_DREG, NFTA_META_KEY, NFTA_NAT_FAMILY, - NFTA_NAT_REG_ADDR_MIN, NFTA_NAT_TYPE, NFTA_PAYLOAD_BASE, NFTA_PAYLOAD_DREG, - NFTA_PAYLOAD_LEN, NFTA_PAYLOAD_OFFSET, NFTA_REJECT_ICMP_CODE, NFTA_REJECT_TYPE, - NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, NFTA_RULE_TABLE, NFTA_VERDICT_CODE, NFT_CMP_EQ, - NFT_CT_STATE, NFT_META_PROTOCOL, NFT_NAT_SNAT, NFT_PAYLOAD_TRANSPORT_HEADER, NFT_REG_1, - NFT_REG_VERDICT, NFT_REJECT_ICMPX_UNREACH, + NFTA_LOG_GROUP, NFTA_LOG_PREFIX, NFTA_LOOKUP_SET, NFTA_LOOKUP_SREG, NFTA_META_DREG, + NFTA_META_KEY, NFTA_NAT_FAMILY, NFTA_NAT_REG_ADDR_MIN, NFTA_NAT_TYPE, NFTA_PAYLOAD_BASE, + NFTA_PAYLOAD_DREG, NFTA_PAYLOAD_LEN, NFTA_PAYLOAD_OFFSET, NFTA_REJECT_ICMP_CODE, + NFTA_REJECT_TYPE, NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, NFTA_RULE_TABLE, + NFTA_VERDICT_CODE, NFT_CMP_EQ, NFT_CT_STATE, NFT_META_PROTOCOL, NFT_NAT_SNAT, + NFT_PAYLOAD_TRANSPORT_HEADER, NFT_REG_1, NFT_REG_VERDICT, NFT_REJECT_ICMPX_UNREACH, }, + tests::{get_test_table, SET_NAME}, ProtocolFamily, }; @@ -283,39 +285,40 @@ fn log_expr_is_valid() { ); } -/* #[test] fn lookup_expr_is_valid() { - let set_name = &CStr::from_bytes_with_nul(b"mockset\0").unwrap(); - let mut rule = get_test_rule(); - let table = rule.get_chain().get_table(); - let mut set = Set::new(set_name, 0, table); + let table = get_test_table(); + let mut set_builder = SetBuilder::new(SET_NAME, &table).unwrap(); let address: Ipv4Addr = [8, 8, 8, 8].into(); - set.add(&address); + set_builder.add(&address); + let (set, _set_elements) = set_builder.finish(); let lookup = Lookup::new(&set).unwrap(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &lookup); - assert_eq!(nlmsghdr.nlmsg_len, 104); + + let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(lookup)); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 96); assert_eq!( raw_expr, NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), NetlinkExpr::Nested( NFTA_RULE_EXPRESSIONS, vec![NetlinkExpr::Nested( NFTA_LIST_ELEM, vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"lookup\0".to_vec()), + NetlinkExpr::Final(NFTA_EXPR_NAME, b"lookup".to_vec()), NetlinkExpr::Nested( NFTA_EXPR_DATA, vec![ + NetlinkExpr::Final(NFTA_LOOKUP_SET, b"mockset".to_vec()), NetlinkExpr::Final( NFTA_LOOKUP_SREG, NFT_REG_1.to_be_bytes().to_vec() ), - NetlinkExpr::Final(NFTA_LOOKUP_SET, b"mockset\0".to_vec()), - NetlinkExpr::Final(NFTA_LOOKUP_SET_ID, 0u32.to_be_bytes().to_vec()), ] ) ] @@ -325,7 +328,6 @@ fn lookup_expr_is_valid() { .to_raw() ); } -*/ #[test] fn masquerade_expr_is_valid() { diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 3693d35..75fe8b0 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -20,8 +20,6 @@ pub const CHAIN_USERDATA: &'static str = "mockchaindata"; pub const RULE_USERDATA: &'static str = "mockruledata"; pub const SET_USERDATA: &'static str = "mocksetdata"; -pub const SET_ID: u32 = 123456; - type NetLinkType = u16; #[derive(Debug, thiserror::Error)] @@ -157,7 +155,7 @@ pub fn get_test_rule() -> Rule { } pub fn get_test_set<K: DataType>() -> Set { - SetBuilder::<K>::new(SET_NAME, SET_ID, &get_test_table()) + SetBuilder::<K>::new(SET_NAME, &get_test_table()) .expect("Couldn't create a set") .finish() .0 diff --git a/src/tests/set.rs b/src/tests/set.rs index db27ced..6c8247c 100644 --- a/src/tests/set.rs +++ b/src/tests/set.rs @@ -6,16 +6,16 @@ use crate::{ set::SetBuilder, sys::{ NFTA_DATA_VALUE, NFTA_LIST_ELEM, NFTA_SET_ELEM_KEY, NFTA_SET_ELEM_LIST_ELEMENTS, - NFTA_SET_ELEM_LIST_SET, NFTA_SET_ELEM_LIST_TABLE, NFTA_SET_ID, NFTA_SET_KEY_LEN, - NFTA_SET_KEY_TYPE, NFTA_SET_NAME, NFTA_SET_TABLE, NFTA_SET_USERDATA, NFT_MSG_DELSET, - NFT_MSG_NEWSET, NFT_MSG_NEWSETELEM, + NFTA_SET_ELEM_LIST_SET, NFTA_SET_ELEM_LIST_TABLE, NFTA_SET_KEY_LEN, NFTA_SET_KEY_TYPE, + NFTA_SET_NAME, NFTA_SET_TABLE, NFTA_SET_USERDATA, NFT_MSG_DELSET, NFT_MSG_NEWSET, + NFT_MSG_NEWSETELEM, }, MsgType, }; use super::{ get_test_nlmsg, get_test_nlmsg_with_msg_type, get_test_set, get_test_table, NetlinkExpr, - SET_ID, SET_NAME, SET_USERDATA, TABLE_NAME, + SET_NAME, SET_USERDATA, TABLE_NAME, }; #[test] @@ -28,7 +28,7 @@ fn new_empty_set() { get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), NFT_MSG_NEWSET as u8 ); - assert_eq!(nlmsghdr.nlmsg_len, 88); + assert_eq!(nlmsghdr.nlmsg_len, 80); assert_eq!( raw_expr, @@ -37,7 +37,6 @@ fn new_empty_set() { NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.as_bytes().to_vec()), NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv4Addr::TYPE.to_be_bytes().to_vec()), NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv4Addr::LEN.to_be_bytes().to_vec()), - NetlinkExpr::Final(NFTA_SET_ID, SET_ID.to_be_bytes().to_vec()), NetlinkExpr::Final(NFTA_SET_USERDATA, SET_USERDATA.as_bytes().to_vec()), ]) .to_raw() @@ -55,7 +54,7 @@ fn delete_empty_set() { get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), NFT_MSG_DELSET as u8 ); - assert_eq!(nlmsghdr.nlmsg_len, 88); + assert_eq!(nlmsghdr.nlmsg_len, 80); assert_eq!( raw_expr, @@ -64,7 +63,6 @@ fn delete_empty_set() { NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.as_bytes().to_vec()), NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv6Addr::TYPE.to_be_bytes().to_vec()), NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv6Addr::LEN.to_be_bytes().to_vec()), - NetlinkExpr::Final(NFTA_SET_ID, SET_ID.to_be_bytes().to_vec()), NetlinkExpr::Final(NFTA_SET_USERDATA, SET_USERDATA.as_bytes().to_vec()), ]) .to_raw() @@ -75,9 +73,8 @@ fn delete_empty_set() { fn new_set_with_data() { let ip1 = Ipv4Addr::new(127, 0, 0, 1); let ip2 = Ipv4Addr::new(1, 1, 1, 1); - let mut set_builder = - SetBuilder::<Ipv4Addr>::new(SET_NAME.to_string(), SET_ID, &get_test_table()) - .expect("Couldn't create a set"); + let mut set_builder = SetBuilder::<Ipv4Addr>::new(SET_NAME.to_string(), &get_test_table()) + .expect("Couldn't create a set"); set_builder.add(&ip1); set_builder.add(&ip2); |