aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/chain.rs8
-rw-r--r--src/chain_methods.rs40
-rw-r--r--src/data_type.rs9
-rw-r--r--src/error.rs6
-rw-r--r--src/expr/cmp.rs5
-rw-r--r--src/expr/ct.rs4
-rw-r--r--src/expr/log.rs14
-rw-r--r--src/expr/mod.rs30
-rw-r--r--src/lib.rs9
-rw-r--r--src/nlmsg.rs4
-rw-r--r--src/rule.rs26
-rw-r--r--src/rule_methods.rs355
-rw-r--r--src/set.rs3
-rw-r--r--src/table.rs8
-rw-r--r--src/tests/expr.rs46
-rw-r--r--src/tests/mod.rs4
-rw-r--r--src/tests/set.rs19
17 files changed, 270 insertions, 320 deletions
diff --git a/src/chain.rs b/src/chain.rs
index 0ce0ad8..37e4cb3 100644
--- a/src/chain.rs
+++ b/src/chain.rs
@@ -8,7 +8,7 @@ use crate::sys::{
NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN,
NFT_MSG_NEWCHAIN,
};
-use crate::{ProtocolFamily, Table};
+use crate::{Batch, ProtocolFamily, Table};
use std::fmt::Debug;
pub type ChainPriority = i32;
@@ -169,6 +169,12 @@ impl Chain {
chain
}
+
+ /// Appends this chain to `batch`
+ pub fn add_to_batch(self, batch: &mut Batch) -> Self {
+ batch.add(&self, crate::MsgType::Add);
+ self
+ }
}
impl NfNetlinkObject for Chain {
diff --git a/src/chain_methods.rs b/src/chain_methods.rs
deleted file mode 100644
index d384c35..0000000
--- a/src/chain_methods.rs
+++ /dev/null
@@ -1,40 +0,0 @@
-use crate::{Batch, Chain, Hook, MsgType, Policy, Table};
-use std::ffi::CString;
-use std::rc::Rc;
-
-
-/// A helper trait over [`crate::Chain`].
-pub trait ChainMethods {
- /// Creates a new Chain instance from a [`crate::Hook`] over a [`crate::Table`].
- fn from_hook(hook: Hook, table: Rc<Table>) -> Self
- where Self: std::marker::Sized;
- /// Adds a [`crate::Policy`] to the current Chain.
- fn verdict(self, policy: Policy) -> Self;
- fn add_to_batch(self, batch: &mut Batch) -> Self;
-}
-
-
-impl ChainMethods for Chain {
- fn from_hook(hook: Hook, table: Rc<Table>) -> Self {
- let chain_name = match hook {
- Hook::PreRouting => "prerouting",
- Hook::Out => "out",
- Hook::PostRouting => "postrouting",
- Hook::Forward => "forward",
- Hook::In => "in",
- };
- let chain_name = CString::new(chain_name).unwrap();
- let mut chain = Chain::new(&chain_name, table);
- chain.set_hook(hook, 0);
- chain
- }
- fn verdict(mut self, policy: Policy) -> Self {
- self.set_policy(policy);
- self
- }
- fn add_to_batch(self, batch: &mut Batch) -> Self {
- batch.add(&self, MsgType::Add);
- self
- }
-}
-
diff --git a/src/data_type.rs b/src/data_type.rs
index f9c97cb..43a7f1a 100644
--- a/src/data_type.rs
+++ b/src/data_type.rs
@@ -1,4 +1,4 @@
-use std::net::{Ipv4Addr, Ipv6Addr};
+use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
pub trait DataType {
const TYPE: u32;
@@ -33,3 +33,10 @@ impl<const N: usize> DataType for [u8; N] {
self.to_vec()
}
}
+
+pub fn ip_to_vec(ip: IpAddr) -> Vec<u8> {
+ match ip {
+ IpAddr::V4(x) => x.octets().to_vec(),
+ IpAddr::V6(x) => x.octets().to_vec(),
+ }
+}
diff --git a/src/error.rs b/src/error.rs
index eae6898..f6b6247 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -129,6 +129,12 @@ pub enum BuilderError {
#[error("Missing name for the set")]
MissingSetName,
+
+ #[error("The interface name is too long to be written")]
+ InterfaceNameTooLong,
+
+ #[error("The log prefix string is more than 127 characters long")]
+ TooLongLogPrefix,
}
#[derive(thiserror::Error, Debug)]
diff --git a/src/expr/cmp.rs b/src/expr/cmp.rs
index 223902f..86d3587 100644
--- a/src/expr/cmp.rs
+++ b/src/expr/cmp.rs
@@ -1,7 +1,6 @@
use rustables_macros::{nfnetlink_enum, nfnetlink_struct};
use crate::{
- data_type::DataType,
parser_impls::NfNetlinkData,
sys::{
NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFT_CMP_EQ, NFT_CMP_GT, NFT_CMP_GTE, NFT_CMP_LT,
@@ -44,11 +43,11 @@ pub struct Cmp {
impl Cmp {
/// Returns a new comparison expression comparing the value loaded in the register with the
/// data in `data` using the comparison operator `op`.
- pub fn new(op: CmpOp, data: impl DataType) -> Self {
+ pub fn new(op: CmpOp, data: impl Into<Vec<u8>>) -> Self {
Cmp {
sreg: Some(Register::Reg1),
op: Some(op),
- data: Some(NfNetlinkData::default().with_value(data.data())),
+ data: Some(NfNetlinkData::default().with_value(data.into())),
}
}
}
diff --git a/src/expr/ct.rs b/src/expr/ct.rs
index ccf61e1..ad76989 100644
--- a/src/expr/ct.rs
+++ b/src/expr/ct.rs
@@ -43,6 +43,10 @@ impl Expression for Conntrack {
}
impl Conntrack {
+ pub fn new(key: ConntrackKey) -> Self {
+ Self::default().with_dreg(Register::Reg1).with_key(key)
+ }
+
pub fn set_mark_value(&mut self, reg: Register) {
self.set_sreg(reg);
self.set_key(ConntrackKey::Mark);
diff --git a/src/expr/log.rs b/src/expr/log.rs
index 80bb7a9..cc2728e 100644
--- a/src/expr/log.rs
+++ b/src/expr/log.rs
@@ -1,7 +1,10 @@
use rustables_macros::nfnetlink_struct;
-use super::{Expression, ExpressionError};
-use crate::sys::{NFTA_LOG_GROUP, NFTA_LOG_PREFIX};
+use super::Expression;
+use crate::{
+ error::BuilderError,
+ sys::{NFTA_LOG_GROUP, NFTA_LOG_PREFIX},
+};
#[derive(Clone, PartialEq, Eq, Default, Debug)]
#[nfnetlink_struct]
@@ -14,10 +17,7 @@ pub struct Log {
}
impl Log {
- pub fn new(
- group: Option<u16>,
- prefix: Option<impl Into<String>>,
- ) -> Result<Log, ExpressionError> {
+ pub fn new(group: Option<u16>, prefix: Option<impl Into<String>>) -> Result<Log, BuilderError> {
let mut res = Log::default();
if let Some(group) = group {
res.set_group(group);
@@ -26,7 +26,7 @@ impl Log {
let prefix = prefix.into();
if prefix.bytes().count() > 127 {
- return Err(ExpressionError::TooLongLogPrefix);
+ return Err(BuilderError::TooLongLogPrefix);
}
res.set_prefix(prefix);
}
diff --git a/src/expr/mod.rs b/src/expr/mod.rs
index 979ebb2..058b0cb 100644
--- a/src/expr/mod.rs
+++ b/src/expr/mod.rs
@@ -6,7 +6,6 @@
use std::fmt::Debug;
use rustables_macros::nfnetlink_struct;
-use thiserror::Error;
use crate::error::DecodeError;
use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable};
@@ -55,35 +54,6 @@ pub use self::register::Register;
mod verdict;
pub use self::verdict::*;
-#[derive(Debug, Error)]
-pub enum ExpressionError {
- #[error("The log prefix string is more than 127 characters long")]
- /// The log prefix string is more than 127 characters long
- TooLongLogPrefix,
-
- #[error("The expected expression type doesn't match the name of the raw expression")]
- /// The expected expression type doesn't match the name of the raw expression.
- InvalidExpressionKind,
-
- #[error("Deserializing the requested type isn't implemented yet")]
- /// Deserializing the requested type isn't implemented yet.
- NotImplemented,
-
- #[error("The expression value cannot be deserialized to the requested type")]
- /// The expression value cannot be deserialized to the requested type.
- InvalidValue,
-
- #[error("A pointer was null while a non-null pointer was expected")]
- /// A pointer was null while a non-null pointer was expected.
- NullPointer,
-
- #[error(
- "The size of a raw value was incoherent with the expected type of the deserialized value"
- )]
- /// The size of a raw value was incoherent with the expected type of the deserialized value/
- InvalidDataSize,
-}
-
pub trait Expression {
fn get_name() -> &'static str;
}
diff --git a/src/lib.rs b/src/lib.rs
index 1ad1eed..dec5b76 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -53,7 +53,7 @@ use std::convert::TryFrom;
mod batch;
pub use batch::{default_batch_page_size, Batch};
-mod data_type;
+pub mod data_type;
mod table;
pub use table::list_tables;
@@ -65,9 +65,6 @@ pub use chain::{Chain, ChainPolicy, ChainPriority, ChainType, Hook, HookClass};
pub mod error;
-//mod chain_methods;
-//pub use chain_methods::ChainMethods;
-
pub mod query;
pub(crate) mod nlmsg;
@@ -80,8 +77,8 @@ pub use rule::Rule;
pub mod expr;
-//mod rule_methods;
-//pub use rule_methods::{iface_index, Error as MatchError, Protocol, RuleMethods};
+mod rule_methods;
+pub use rule_methods::{iface_index, Protocol};
pub mod set;
pub use set::Set;
diff --git a/src/nlmsg.rs b/src/nlmsg.rs
index b3710bf..1c5b519 100644
--- a/src/nlmsg.rs
+++ b/src/nlmsg.rs
@@ -62,10 +62,6 @@ impl<'a> NfNetlinkWriter<'a> {
&mut self.buf[start..start + size]
}
- pub fn extract_buffer(self) -> &'a mut Vec<u8> {
- self.buf
- }
-
// rewrite of `__nftnl_nlmsg_build_hdr`
pub fn write_header(
&mut self,
diff --git a/src/rule.rs b/src/rule.rs
index 7f732d3..858b9ce 100644
--- a/src/rule.rs
+++ b/src/rule.rs
@@ -4,7 +4,7 @@ use rustables_macros::nfnetlink_struct;
use crate::chain::Chain;
use crate::error::{BuilderError, QueryError};
-use crate::expr::ExpressionList;
+use crate::expr::{ExpressionList, RawExpression};
use crate::nlmsg::NfNetlinkObject;
use crate::query::list_objects_with_data;
use crate::sys::{
@@ -12,7 +12,7 @@ use crate::sys::{
NFTA_RULE_TABLE, NFTA_RULE_USERDATA, NFT_MSG_DELRULE, NFT_MSG_NEWRULE, NLM_F_APPEND,
NLM_F_CREATE,
};
-use crate::ProtocolFamily;
+use crate::{Batch, ProtocolFamily};
/// A nftables firewall rule.
#[derive(Clone, PartialEq, Eq, Default, Debug)]
@@ -53,6 +53,28 @@ impl Rule {
.ok_or(BuilderError::MissingChainInformationError)?,
))
}
+
+ pub fn add_expr(&mut self, e: impl Into<RawExpression>) {
+ let exprs = match self.get_mut_expressions() {
+ Some(x) => x,
+ None => {
+ self.set_expressions(ExpressionList::default());
+ self.get_mut_expressions().unwrap()
+ }
+ };
+ exprs.add_value(e);
+ }
+
+ pub fn with_expr(mut self, e: impl Into<RawExpression>) -> Self {
+ self.add_expr(e);
+ self
+ }
+
+ /// Appends this rule to `batch`
+ pub fn add_to_batch(self, batch: &mut Batch) -> Self {
+ batch.add(&self, crate::MsgType::Add);
+ self
+ }
}
impl NfNetlinkObject for Rule {
diff --git a/src/rule_methods.rs b/src/rule_methods.rs
index d7145d7..dff9bf6 100644
--- a/src/rule_methods.rs
+++ b/src/rule_methods.rs
@@ -1,230 +1,211 @@
-use crate::{Batch, Rule, nft_expr, sys::libc};
-use crate::expr::{LogGroup, LogPrefix};
-use ipnetwork::IpNetwork;
-use std::ffi::{CString, NulError};
+use std::ffi::CString;
use std::net::IpAddr;
-use std::num::ParseIntError;
-
-#[derive(thiserror::Error, Debug)]
-pub enum Error {
- #[error("Unable to open netlink socket to netfilter")]
- NetlinkOpenError(#[source] std::io::Error),
- #[error("Firewall is already started")]
- AlreadyDone,
- #[error("Error converting from a C string to a string")]
- CStringError(#[from] NulError),
- #[error("no interface found under that name")]
- NoSuchIface,
- #[error("Error converting from a string to an integer")]
- ParseError(#[from] ParseIntError),
- #[error("the interface name is too long")]
- NameTooLong,
-}
+use ipnetwork::IpNetwork;
+use crate::data_type::ip_to_vec;
+use crate::error::BuilderError;
+use crate::expr::ct::{ConnTrackState, Conntrack, ConntrackKey};
+use crate::expr::{
+ Bitwise, Cmp, CmpOp, HighLevelPayload, IPv4HeaderField, IPv6HeaderField, Immediate, Meta,
+ MetaType, NetworkHeaderField, TCPHeaderField, TransportHeaderField, UDPHeaderField,
+ VerdictKind,
+};
+use crate::Rule;
/// Simple protocol description. Note that it does not implement other layer 4 protocols as
/// IGMP et al. See [`Rule::igmp`] for a workaround.
-#[derive(Debug, Clone)]
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Protocol {
TCP,
- UDP
+ UDP,
}
-/// A RuleMethods trait over [`crate::Rule`], to make it match some criteria, and give it a
-/// verdict. Mostly adapted from [talpid-core's firewall].
-/// All methods return the rule itself, allowing them to be chained. Usage example :
-/// ```rust
-/// use rustables::{Batch, Chain, ChainMethods, Protocol, ProtoFamily, Rule, RuleMethods, Table, MsgType, Hook};
-/// use std::ffi::CString;
-/// use std::rc::Rc;
-/// let table = Rc::new(Table::new(&CString::new("main_table").unwrap(), ProtoFamily::Inet));
-/// let mut batch = Batch::new();
-/// batch.add(&table, MsgType::Add);
-/// let inbound = Rc::new(Chain::from_hook(Hook::In, Rc::clone(&table))
-/// .add_to_batch(&mut batch));
-/// let rule = Rule::new(inbound)
-/// .dport("80", &Protocol::TCP).unwrap()
-/// .accept()
-/// .add_to_batch(&mut batch);
-/// ```
-/// [talpid-core's firewall]:
-/// https://github.com/mullvad/mullvadvpn-app/blob/d92376b4d1df9b547930c68aa9bae9640ff2a022/talpid-core/src/firewall/linux.rs
-pub trait RuleMethods {
- /// Matches ICMP packets.
- fn icmp(self) -> Self;
- /// Matches IGMP packets.
- fn igmp(self) -> Self;
- /// Matches packets to destination `port` and `protocol`.
- fn dport(self, port: &str, protocol: &Protocol) -> Result<Self, Error>
- where Self: std::marker::Sized;
- /// Matches packets on `protocol`.
- fn protocol(self, protocol: Protocol) -> Result<Self, Error>
- where Self: std::marker::Sized;
- /// Matches packets in an already established connection.
- fn established(self) -> Self where Self: std::marker::Sized;
- /// Matches packets going through `iface_index`. Interface indexes can be queried with
- /// `iface_index()`.
- fn iface_id(self, iface_index: libc::c_uint) -> Result<Self, Error>
- where Self: std::marker::Sized;
- /// Matches packets going through `iface_name`, an interface name, as in "wlan0" or "lo".
- fn iface(self, iface_name: &str) -> Result<Self, Error>
- where Self: std::marker::Sized;
- /// Adds a log instruction to the rule. `group` is the NFLog group, `prefix` is a prefix
- /// appended to each log line.
- fn log(self, group: Option<LogGroup>, prefix: Option<LogPrefix>) -> Self;
- /// Matches packets whose source IP address is `saddr`.
- fn saddr(self, ip: IpAddr) -> Self;
- /// Matches packets whose source network is `snet`.
- fn snetwork(self, ip: IpNetwork) -> Self;
- /// Adds the `Accept` verdict to the rule. The packet will be sent to destination.
- fn accept(self) -> Self;
- /// Adds the `Drop` verdict to the rule. The packet will be dropped.
- fn drop(self) -> Self;
- /// Appends this rule to `batch`.
- fn add_to_batch(self, batch: &mut Batch) -> Self;
-}
-
-/// A trait to add helper functions to match some criterium over `crate::Rule`.
-impl RuleMethods for Rule {
- fn icmp(mut self) -> Self {
- self.add_expr(&nft_expr!(meta l4proto));
- //self.add_expr(&nft_expr!(cmp == libc::IPPROTO_ICMPV6 as u8));
- self.add_expr(&nft_expr!(cmp == libc::IPPROTO_ICMP as u8));
- self
- }
- fn igmp(mut self) -> Self {
- self.add_expr(&nft_expr!(meta l4proto));
- self.add_expr(&nft_expr!(cmp == libc::IPPROTO_IGMP as u8));
+impl Rule {
+ fn match_port(mut self, port: u16, protocol: Protocol, source: bool) -> Self {
+ self = self.protocol(protocol);
+ self.add_expr(
+ HighLevelPayload::Transport(match protocol {
+ Protocol::TCP => TransportHeaderField::Tcp(if source {
+ TCPHeaderField::Sport
+ } else {
+ TCPHeaderField::Dport
+ }),
+ Protocol::UDP => TransportHeaderField::Udp(if source {
+ UDPHeaderField::Sport
+ } else {
+ UDPHeaderField::Dport
+ }),
+ })
+ .build(),
+ );
+ self.add_expr(Cmp::new(CmpOp::Eq, port.to_be_bytes()));
self
}
- fn dport(mut self, port: &str, protocol: &Protocol) -> Result<Self, Error> {
- self.add_expr(&nft_expr!(meta l4proto));
- match protocol {
- &Protocol::TCP => {
- self.add_expr(&nft_expr!(cmp == libc::IPPROTO_TCP as u8));
- self.add_expr(&nft_expr!(payload tcp dport));
- },
- &Protocol::UDP => {
- self.add_expr(&nft_expr!(cmp == libc::IPPROTO_UDP as u8));
- self.add_expr(&nft_expr!(payload udp dport));
- }
- }
- // Convert the port to Big-Endian number spelling.
- // See https://github.com/mullvad/mullvadvpn-app/blob/d92376b4d1df9b547930c68aa9bae9640ff2a022/talpid-core/src/firewall/linux.rs#L969
- self.add_expr(&nft_expr!(cmp == port.parse::<u16>()?.to_be()));
- Ok(self)
- }
- fn protocol(mut self, protocol: Protocol) -> Result<Self, Error> {
- self.add_expr(&nft_expr!(meta l4proto));
- match protocol {
- Protocol::TCP => {
- self.add_expr(&nft_expr!(cmp == libc::IPPROTO_TCP as u8));
- },
- Protocol::UDP => {
- self.add_expr(&nft_expr!(cmp == libc::IPPROTO_UDP as u8));
- }
- }
- Ok(self)
- }
- fn established(mut self) -> Self {
- let allowed_states = crate::expr::ct::States::ESTABLISHED.bits();
- self.add_expr(&nft_expr!(ct state));
- self.add_expr(&nft_expr!(bitwise mask allowed_states, xor 0u32));
- self.add_expr(&nft_expr!(cmp != 0u32));
- self
- }
- fn iface_id(mut self, iface_index: libc::c_uint) -> Result<Self, Error> {
- self.add_expr(&nft_expr!(meta iif));
- self.add_expr(&nft_expr!(cmp == iface_index));
- Ok(self)
- }
- fn iface(mut self, iface_name: &str) -> Result<Self, Error> {
- if iface_name.len() >= libc::IFNAMSIZ {
- return Err(Error::NameTooLong);
- }
- let mut name_arr = [0u8; libc::IFNAMSIZ];
- for (pos, i) in iface_name.bytes().enumerate() {
- name_arr[pos] = i;
- }
- self.add_expr(&nft_expr!(meta iifname));
- self.add_expr(&nft_expr!(cmp == name_arr.as_ref()));
- Ok(self)
- }
- fn saddr(mut self, ip: IpAddr) -> Self {
- self.add_expr(&nft_expr!(meta nfproto));
+ pub fn match_ip(mut self, ip: IpAddr, source: bool) -> Self {
+ self.add_expr(Meta::new(MetaType::NfProto));
match ip {
IpAddr::V4(addr) => {
- self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV4 as u8));
- self.add_expr(&nft_expr!(payload ipv4 saddr));
- self.add_expr(&nft_expr!(cmp == addr))
- },
+ self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV4 as u8]));
+ self.add_expr(
+ HighLevelPayload::Network(NetworkHeaderField::IPv4(if source {
+ IPv4HeaderField::Saddr
+ } else {
+ IPv4HeaderField::Daddr
+ }))
+ .build(),
+ );
+ self.add_expr(Cmp::new(CmpOp::Eq, addr.octets()));
+ }
IpAddr::V6(addr) => {
- self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV6 as u8));
- self.add_expr(&nft_expr!(payload ipv6 saddr));
- self.add_expr(&nft_expr!(cmp == addr))
+ self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8]));
+ self.add_expr(
+ HighLevelPayload::Network(NetworkHeaderField::IPv6(if source {
+ IPv6HeaderField::Saddr
+ } else {
+ IPv6HeaderField::Daddr
+ }))
+ .build(),
+ );
+ self.add_expr(Cmp::new(CmpOp::Eq, addr.octets()));
}
}
self
}
- fn snetwork(mut self, net: IpNetwork) -> Self {
- self.add_expr(&nft_expr!(meta nfproto));
+
+ pub fn match_network(mut self, net: IpNetwork, source: bool) -> Result<Self, BuilderError> {
+ self.add_expr(Meta::new(MetaType::NfProto));
match net {
IpNetwork::V4(_) => {
- self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV4 as u8));
- self.add_expr(&nft_expr!(payload ipv4 saddr));
- self.add_expr(&nft_expr!(bitwise mask net.mask(), xor 0u32));
- self.add_expr(&nft_expr!(cmp == net.network()));
- },
+ self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV4 as u8]));
+ self.add_expr(
+ HighLevelPayload::Network(NetworkHeaderField::IPv4(if source {
+ IPv4HeaderField::Saddr
+ } else {
+ IPv4HeaderField::Daddr
+ }))
+ .build(),
+ );
+ self.add_expr(Bitwise::new(ip_to_vec(net.mask()), 0u32.to_be_bytes())?);
+ }
IpNetwork::V6(_) => {
- self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV6 as u8));
- self.add_expr(&nft_expr!(payload ipv6 saddr));
- self.add_expr(&nft_expr!(bitwise mask net.mask(), xor &[0u16; 8][..]));
- self.add_expr(&nft_expr!(cmp == net.network()));
+ self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8]));
+ self.add_expr(
+ HighLevelPayload::Network(NetworkHeaderField::IPv6(if source {
+ IPv6HeaderField::Saddr
+ } else {
+ IPv6HeaderField::Daddr
+ }))
+ .build(),
+ );
+ self.add_expr(Bitwise::new(ip_to_vec(net.mask()), 0u128.to_be_bytes())?);
}
}
+ self.add_expr(Cmp::new(CmpOp::Eq, ip_to_vec(net.network())));
+ Ok(self)
+ }
+}
+
+impl Rule {
+ /// Matches ICMP packets.
+ pub fn icmp(mut self) -> Self {
+ // quid of icmpv6?
+ self.add_expr(Meta::new(MetaType::L4Proto));
+ self.add_expr(Cmp::new(CmpOp::Eq, [libc::IPPROTO_ICMP as u8]));
self
}
- fn log(mut self, group: Option<LogGroup>, prefix: Option<LogPrefix>) -> Self {
- match (group.is_some(), prefix.is_some()) {
- (true, true) => {
- self.add_expr(&nft_expr!(log group group prefix prefix));
- },
- (false, true) => {
- self.add_expr(&nft_expr!(log prefix prefix));
- },
- (true, false) => {
- self.add_expr(&nft_expr!(log group group));
- },
- (false, false) => {
- self.add_expr(&nft_expr!(log));
- }
- }
+ /// Matches IGMP packets.
+ pub fn igmp(mut self) -> Self {
+ self.add_expr(Meta::new(MetaType::L4Proto));
+ self.add_expr(Cmp::new(CmpOp::Eq, [libc::IPPROTO_IGMP as u8]));
self
}
- fn accept(mut self) -> Self {
- self.add_expr(&nft_expr!(verdict accept));
+ /// Matches packets from source `port` and `protocol`.
+ pub fn sport(self, port: u16, protocol: Protocol) -> Self {
+ self.match_port(port, protocol, false)
+ }
+ /// Matches packets to destination `port` and `protocol`.
+ pub fn dport(self, port: u16, protocol: Protocol) -> Self {
+ self.match_port(port, protocol, false)
+ }
+ /// Matches packets on `protocol`.
+ pub fn protocol(mut self, protocol: Protocol) -> Self {
+ self.add_expr(Meta::new(MetaType::L4Proto));
+ self.add_expr(Cmp::new(
+ CmpOp::Eq,
+ [match protocol {
+ Protocol::TCP => libc::IPPROTO_TCP,
+ Protocol::UDP => libc::IPPROTO_UDP,
+ } as u8],
+ ));
+ self
+ }
+ /// Matches packets in an already established connection.
+ pub fn established(mut self) -> Result<Self, BuilderError> {
+ let allowed_states = ConnTrackState::ESTABLISHED.bits();
+ self.add_expr(Conntrack::new(ConntrackKey::State));
+ self.add_expr(Bitwise::new(
+ allowed_states.to_le_bytes(),
+ 0u32.to_be_bytes(),
+ )?);
+ self.add_expr(Cmp::new(CmpOp::Neq, 0u32.to_be_bytes()));
+ Ok(self)
+ }
+ /// Matches packets going through `iface_index`. Interface indexes can be queried with
+ /// `iface_index()`.
+ pub fn iface_id(mut self, iface_index: libc::c_uint) -> Self {
+ self.add_expr(Meta::new(MetaType::Iif));
+ self.add_expr(Cmp::new(CmpOp::Eq, iface_index.to_be_bytes()));
self
}
- fn drop(mut self) -> Self {
- self.add_expr(&nft_expr!(verdict drop));
+ /// Matches packets going through `iface_name`, an interface name, as in "wlan0" or "lo"
+ pub fn iface(mut self, iface_name: &str) -> Result<Self, BuilderError> {
+ if iface_name.len() >= libc::IFNAMSIZ {
+ return Err(BuilderError::InterfaceNameTooLong);
+ }
+ let mut iface_vec = iface_name.as_bytes().to_vec();
+ // null terminator
+ iface_vec.push(0u8);
+
+ self.add_expr(Meta::new(MetaType::IifName));
+ self.add_expr(Cmp::new(CmpOp::Eq, iface_vec));
+ Ok(self)
+ }
+ /// Matches packets whose source IP address is `saddr`.
+ pub fn saddr(self, ip: IpAddr) -> Self {
+ self.match_ip(ip, true)
+ }
+ /// Matches packets whose destination IP address is `saddr`.
+ pub fn daddr(self, ip: IpAddr) -> Self {
+ self.match_ip(ip, false)
+ }
+ /// Matches packets whose source network is `net`.
+ pub fn snetwork(self, net: IpNetwork) -> Result<Self, BuilderError> {
+ self.match_network(net, true)
+ }
+ /// Matches packets whose destination network is `net`.
+ pub fn dnetwork(self, net: IpNetwork) -> Result<Self, BuilderError> {
+ self.match_network(net, false)
+ }
+ /// Adds the `Accept` verdict to the rule. The packet will be sent to destination.
+ pub fn accept(mut self) -> Self {
+ self.add_expr(Immediate::new_verdict(VerdictKind::Accept));
self
}
- fn add_to_batch(self, batch: &mut Batch) -> Self {
- batch.add(&self, crate::MsgType::Add);
+ /// Adds the `Drop` verdict to the rule. The packet will be dropped.
+ pub fn drop(mut self) -> Self {
+ self.add_expr(Immediate::new_verdict(VerdictKind::Drop));
self
}
}
/// Looks up the interface index for a given interface name.
-pub fn iface_index(name: &str) -> Result<libc::c_uint, Error> {
+pub fn iface_index(name: &str) -> Result<libc::c_uint, std::io::Error> {
let c_name = CString::new(name)?;
let index = unsafe { libc::if_nametoindex(c_name.as_ptr()) };
match index {
- 0 => Err(Error::NoSuchIface),
- _ => Ok(index)
+ 0 => Err(std::io::Error::last_os_error()),
+ _ => Ok(index),
}
}
-
-
diff --git a/src/set.rs b/src/set.rs
index 32d1666..ab29770 100644
--- a/src/set.rs
+++ b/src/set.rs
@@ -55,11 +55,10 @@ pub struct SetBuilder<K: DataType> {
}
impl<K: DataType> SetBuilder<K> {
- pub fn new(name: impl Into<String>, id: u32, table: &Table) -> Result<Self, BuilderError> {
+ pub fn new(name: impl Into<String>, table: &Table) -> Result<Self, BuilderError> {
let table_name = table.get_name().ok_or(BuilderError::MissingTableName)?;
let set_name = name.into();
let set = Set::default()
- .with_id(id)
.with_key_type(K::TYPE)
.with_key_len(K::LEN)
.with_table(table_name)
diff --git a/src/table.rs b/src/table.rs
index 63bf669..81a26ef 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -8,7 +8,7 @@ use crate::sys::{
NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFTA_TABLE_USERDATA, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE,
NFT_MSG_NEWTABLE,
};
-use crate::ProtocolFamily;
+use crate::{Batch, ProtocolFamily};
/// Abstraction of a `nftnl_table`, the top level container in netfilter. A table has a protocol
/// family and contains [`Chain`]s that in turn hold the rules.
@@ -32,6 +32,12 @@ impl Table {
res.family = family;
res
}
+
+ /// Appends this rule to `batch`
+ pub fn add_to_batch(self, batch: &mut Batch) -> Self {
+ batch.add(&self, crate::MsgType::Add);
+ self
+ }
}
impl NfNetlinkObject for Table {
diff --git a/src/tests/expr.rs b/src/tests/expr.rs
index 141f6ac..35c4fea 100644
--- a/src/tests/expr.rs
+++ b/src/tests/expr.rs
@@ -5,21 +5,23 @@ use libc::NF_DROP;
use crate::{
expr::{
Bitwise, Cmp, CmpOp, Conntrack, ConntrackKey, Counter, ExpressionList, HeaderField,
- HighLevelPayload, IcmpCode, Immediate, Log, Masquerade, Meta, MetaType, Nat, NatType,
- Register, Reject, RejectType, TCPHeaderField, TransportHeaderField, VerdictKind,
+ HighLevelPayload, IcmpCode, Immediate, Log, Lookup, Masquerade, Meta, MetaType, Nat,
+ NatType, Register, Reject, RejectType, TCPHeaderField, TransportHeaderField, VerdictKind,
},
+ set::SetBuilder,
sys::{
NFTA_BITWISE_DREG, NFTA_BITWISE_LEN, NFTA_BITWISE_MASK, NFTA_BITWISE_SREG,
NFTA_BITWISE_XOR, NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFTA_COUNTER_BYTES,
NFTA_COUNTER_PACKETS, NFTA_CT_DREG, NFTA_CT_KEY, NFTA_DATA_VALUE, NFTA_DATA_VERDICT,
NFTA_EXPR_DATA, NFTA_EXPR_NAME, NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG, NFTA_LIST_ELEM,
- NFTA_LOG_GROUP, NFTA_LOG_PREFIX, NFTA_META_DREG, NFTA_META_KEY, NFTA_NAT_FAMILY,
- NFTA_NAT_REG_ADDR_MIN, NFTA_NAT_TYPE, NFTA_PAYLOAD_BASE, NFTA_PAYLOAD_DREG,
- NFTA_PAYLOAD_LEN, NFTA_PAYLOAD_OFFSET, NFTA_REJECT_ICMP_CODE, NFTA_REJECT_TYPE,
- NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, NFTA_RULE_TABLE, NFTA_VERDICT_CODE, NFT_CMP_EQ,
- NFT_CT_STATE, NFT_META_PROTOCOL, NFT_NAT_SNAT, NFT_PAYLOAD_TRANSPORT_HEADER, NFT_REG_1,
- NFT_REG_VERDICT, NFT_REJECT_ICMPX_UNREACH,
+ NFTA_LOG_GROUP, NFTA_LOG_PREFIX, NFTA_LOOKUP_SET, NFTA_LOOKUP_SREG, NFTA_META_DREG,
+ NFTA_META_KEY, NFTA_NAT_FAMILY, NFTA_NAT_REG_ADDR_MIN, NFTA_NAT_TYPE, NFTA_PAYLOAD_BASE,
+ NFTA_PAYLOAD_DREG, NFTA_PAYLOAD_LEN, NFTA_PAYLOAD_OFFSET, NFTA_REJECT_ICMP_CODE,
+ NFTA_REJECT_TYPE, NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, NFTA_RULE_TABLE,
+ NFTA_VERDICT_CODE, NFT_CMP_EQ, NFT_CT_STATE, NFT_META_PROTOCOL, NFT_NAT_SNAT,
+ NFT_PAYLOAD_TRANSPORT_HEADER, NFT_REG_1, NFT_REG_VERDICT, NFT_REJECT_ICMPX_UNREACH,
},
+ tests::{get_test_table, SET_NAME},
ProtocolFamily,
};
@@ -283,39 +285,40 @@ fn log_expr_is_valid() {
);
}
-/*
#[test]
fn lookup_expr_is_valid() {
- let set_name = &CStr::from_bytes_with_nul(b"mockset\0").unwrap();
- let mut rule = get_test_rule();
- let table = rule.get_chain().get_table();
- let mut set = Set::new(set_name, 0, table);
+ let table = get_test_table();
+ let mut set_builder = SetBuilder::new(SET_NAME, &table).unwrap();
let address: Ipv4Addr = [8, 8, 8, 8].into();
- set.add(&address);
+ set_builder.add(&address);
+ let (set, _set_elements) = set_builder.finish();
let lookup = Lookup::new(&set).unwrap();
- let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &lookup);
- assert_eq!(nlmsghdr.nlmsg_len, 104);
+
+ let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(lookup));
+
+ let mut buf = Vec::new();
+ let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule);
+ assert_eq!(nlmsghdr.nlmsg_len, 96);
assert_eq!(
raw_expr,
NetlinkExpr::List(vec![
- NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()),
- NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()),
+ NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()),
NetlinkExpr::Nested(
NFTA_RULE_EXPRESSIONS,
vec![NetlinkExpr::Nested(
NFTA_LIST_ELEM,
vec![
- NetlinkExpr::Final(NFTA_EXPR_NAME, b"lookup\0".to_vec()),
+ NetlinkExpr::Final(NFTA_EXPR_NAME, b"lookup".to_vec()),
NetlinkExpr::Nested(
NFTA_EXPR_DATA,
vec![
+ NetlinkExpr::Final(NFTA_LOOKUP_SET, b"mockset".to_vec()),
NetlinkExpr::Final(
NFTA_LOOKUP_SREG,
NFT_REG_1.to_be_bytes().to_vec()
),
- NetlinkExpr::Final(NFTA_LOOKUP_SET, b"mockset\0".to_vec()),
- NetlinkExpr::Final(NFTA_LOOKUP_SET_ID, 0u32.to_be_bytes().to_vec()),
]
)
]
@@ -325,7 +328,6 @@ fn lookup_expr_is_valid() {
.to_raw()
);
}
-*/
#[test]
fn masquerade_expr_is_valid() {
diff --git a/src/tests/mod.rs b/src/tests/mod.rs
index 3693d35..75fe8b0 100644
--- a/src/tests/mod.rs
+++ b/src/tests/mod.rs
@@ -20,8 +20,6 @@ pub const CHAIN_USERDATA: &'static str = "mockchaindata";
pub const RULE_USERDATA: &'static str = "mockruledata";
pub const SET_USERDATA: &'static str = "mocksetdata";
-pub const SET_ID: u32 = 123456;
-
type NetLinkType = u16;
#[derive(Debug, thiserror::Error)]
@@ -157,7 +155,7 @@ pub fn get_test_rule() -> Rule {
}
pub fn get_test_set<K: DataType>() -> Set {
- SetBuilder::<K>::new(SET_NAME, SET_ID, &get_test_table())
+ SetBuilder::<K>::new(SET_NAME, &get_test_table())
.expect("Couldn't create a set")
.finish()
.0
diff --git a/src/tests/set.rs b/src/tests/set.rs
index db27ced..6c8247c 100644
--- a/src/tests/set.rs
+++ b/src/tests/set.rs
@@ -6,16 +6,16 @@ use crate::{
set::SetBuilder,
sys::{
NFTA_DATA_VALUE, NFTA_LIST_ELEM, NFTA_SET_ELEM_KEY, NFTA_SET_ELEM_LIST_ELEMENTS,
- NFTA_SET_ELEM_LIST_SET, NFTA_SET_ELEM_LIST_TABLE, NFTA_SET_ID, NFTA_SET_KEY_LEN,
- NFTA_SET_KEY_TYPE, NFTA_SET_NAME, NFTA_SET_TABLE, NFTA_SET_USERDATA, NFT_MSG_DELSET,
- NFT_MSG_NEWSET, NFT_MSG_NEWSETELEM,
+ NFTA_SET_ELEM_LIST_SET, NFTA_SET_ELEM_LIST_TABLE, NFTA_SET_KEY_LEN, NFTA_SET_KEY_TYPE,
+ NFTA_SET_NAME, NFTA_SET_TABLE, NFTA_SET_USERDATA, NFT_MSG_DELSET, NFT_MSG_NEWSET,
+ NFT_MSG_NEWSETELEM,
},
MsgType,
};
use super::{
get_test_nlmsg, get_test_nlmsg_with_msg_type, get_test_set, get_test_table, NetlinkExpr,
- SET_ID, SET_NAME, SET_USERDATA, TABLE_NAME,
+ SET_NAME, SET_USERDATA, TABLE_NAME,
};
#[test]
@@ -28,7 +28,7 @@ fn new_empty_set() {
get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
NFT_MSG_NEWSET as u8
);
- assert_eq!(nlmsghdr.nlmsg_len, 88);
+ assert_eq!(nlmsghdr.nlmsg_len, 80);
assert_eq!(
raw_expr,
@@ -37,7 +37,6 @@ fn new_empty_set() {
NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.as_bytes().to_vec()),
NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv4Addr::TYPE.to_be_bytes().to_vec()),
NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv4Addr::LEN.to_be_bytes().to_vec()),
- NetlinkExpr::Final(NFTA_SET_ID, SET_ID.to_be_bytes().to_vec()),
NetlinkExpr::Final(NFTA_SET_USERDATA, SET_USERDATA.as_bytes().to_vec()),
])
.to_raw()
@@ -55,7 +54,7 @@ fn delete_empty_set() {
get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type),
NFT_MSG_DELSET as u8
);
- assert_eq!(nlmsghdr.nlmsg_len, 88);
+ assert_eq!(nlmsghdr.nlmsg_len, 80);
assert_eq!(
raw_expr,
@@ -64,7 +63,6 @@ fn delete_empty_set() {
NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.as_bytes().to_vec()),
NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv6Addr::TYPE.to_be_bytes().to_vec()),
NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv6Addr::LEN.to_be_bytes().to_vec()),
- NetlinkExpr::Final(NFTA_SET_ID, SET_ID.to_be_bytes().to_vec()),
NetlinkExpr::Final(NFTA_SET_USERDATA, SET_USERDATA.as_bytes().to_vec()),
])
.to_raw()
@@ -75,9 +73,8 @@ fn delete_empty_set() {
fn new_set_with_data() {
let ip1 = Ipv4Addr::new(127, 0, 0, 1);
let ip2 = Ipv4Addr::new(1, 1, 1, 1);
- let mut set_builder =
- SetBuilder::<Ipv4Addr>::new(SET_NAME.to_string(), SET_ID, &get_test_table())
- .expect("Couldn't create a set");
+ let mut set_builder = SetBuilder::<Ipv4Addr>::new(SET_NAME.to_string(), &get_test_table())
+ .expect("Couldn't create a set");
set_builder.add(&ip1);
set_builder.add(&ip2);