aboutsummaryrefslogtreecommitdiff
path: root/src/rule_methods.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/rule_methods.rs')
-rw-r--r--src/rule_methods.rs355
1 files changed, 168 insertions, 187 deletions
diff --git a/src/rule_methods.rs b/src/rule_methods.rs
index d7145d7..dff9bf6 100644
--- a/src/rule_methods.rs
+++ b/src/rule_methods.rs
@@ -1,230 +1,211 @@
-use crate::{Batch, Rule, nft_expr, sys::libc};
-use crate::expr::{LogGroup, LogPrefix};
-use ipnetwork::IpNetwork;
-use std::ffi::{CString, NulError};
+use std::ffi::CString;
use std::net::IpAddr;
-use std::num::ParseIntError;
-
-#[derive(thiserror::Error, Debug)]
-pub enum Error {
- #[error("Unable to open netlink socket to netfilter")]
- NetlinkOpenError(#[source] std::io::Error),
- #[error("Firewall is already started")]
- AlreadyDone,
- #[error("Error converting from a C string to a string")]
- CStringError(#[from] NulError),
- #[error("no interface found under that name")]
- NoSuchIface,
- #[error("Error converting from a string to an integer")]
- ParseError(#[from] ParseIntError),
- #[error("the interface name is too long")]
- NameTooLong,
-}
+use ipnetwork::IpNetwork;
+use crate::data_type::ip_to_vec;
+use crate::error::BuilderError;
+use crate::expr::ct::{ConnTrackState, Conntrack, ConntrackKey};
+use crate::expr::{
+ Bitwise, Cmp, CmpOp, HighLevelPayload, IPv4HeaderField, IPv6HeaderField, Immediate, Meta,
+ MetaType, NetworkHeaderField, TCPHeaderField, TransportHeaderField, UDPHeaderField,
+ VerdictKind,
+};
+use crate::Rule;
/// Simple protocol description. Note that it does not implement other layer 4 protocols as
/// IGMP et al. See [`Rule::igmp`] for a workaround.
-#[derive(Debug, Clone)]
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Protocol {
TCP,
- UDP
+ UDP,
}
-/// A RuleMethods trait over [`crate::Rule`], to make it match some criteria, and give it a
-/// verdict. Mostly adapted from [talpid-core's firewall].
-/// All methods return the rule itself, allowing them to be chained. Usage example :
-/// ```rust
-/// use rustables::{Batch, Chain, ChainMethods, Protocol, ProtoFamily, Rule, RuleMethods, Table, MsgType, Hook};
-/// use std::ffi::CString;
-/// use std::rc::Rc;
-/// let table = Rc::new(Table::new(&CString::new("main_table").unwrap(), ProtoFamily::Inet));
-/// let mut batch = Batch::new();
-/// batch.add(&table, MsgType::Add);
-/// let inbound = Rc::new(Chain::from_hook(Hook::In, Rc::clone(&table))
-/// .add_to_batch(&mut batch));
-/// let rule = Rule::new(inbound)
-/// .dport("80", &Protocol::TCP).unwrap()
-/// .accept()
-/// .add_to_batch(&mut batch);
-/// ```
-/// [talpid-core's firewall]:
-/// https://github.com/mullvad/mullvadvpn-app/blob/d92376b4d1df9b547930c68aa9bae9640ff2a022/talpid-core/src/firewall/linux.rs
-pub trait RuleMethods {
- /// Matches ICMP packets.
- fn icmp(self) -> Self;
- /// Matches IGMP packets.
- fn igmp(self) -> Self;
- /// Matches packets to destination `port` and `protocol`.
- fn dport(self, port: &str, protocol: &Protocol) -> Result<Self, Error>
- where Self: std::marker::Sized;
- /// Matches packets on `protocol`.
- fn protocol(self, protocol: Protocol) -> Result<Self, Error>
- where Self: std::marker::Sized;
- /// Matches packets in an already established connection.
- fn established(self) -> Self where Self: std::marker::Sized;
- /// Matches packets going through `iface_index`. Interface indexes can be queried with
- /// `iface_index()`.
- fn iface_id(self, iface_index: libc::c_uint) -> Result<Self, Error>
- where Self: std::marker::Sized;
- /// Matches packets going through `iface_name`, an interface name, as in "wlan0" or "lo".
- fn iface(self, iface_name: &str) -> Result<Self, Error>
- where Self: std::marker::Sized;
- /// Adds a log instruction to the rule. `group` is the NFLog group, `prefix` is a prefix
- /// appended to each log line.
- fn log(self, group: Option<LogGroup>, prefix: Option<LogPrefix>) -> Self;
- /// Matches packets whose source IP address is `saddr`.
- fn saddr(self, ip: IpAddr) -> Self;
- /// Matches packets whose source network is `snet`.
- fn snetwork(self, ip: IpNetwork) -> Self;
- /// Adds the `Accept` verdict to the rule. The packet will be sent to destination.
- fn accept(self) -> Self;
- /// Adds the `Drop` verdict to the rule. The packet will be dropped.
- fn drop(self) -> Self;
- /// Appends this rule to `batch`.
- fn add_to_batch(self, batch: &mut Batch) -> Self;
-}
-
-/// A trait to add helper functions to match some criterium over `crate::Rule`.
-impl RuleMethods for Rule {
- fn icmp(mut self) -> Self {
- self.add_expr(&nft_expr!(meta l4proto));
- //self.add_expr(&nft_expr!(cmp == libc::IPPROTO_ICMPV6 as u8));
- self.add_expr(&nft_expr!(cmp == libc::IPPROTO_ICMP as u8));
- self
- }
- fn igmp(mut self) -> Self {
- self.add_expr(&nft_expr!(meta l4proto));
- self.add_expr(&nft_expr!(cmp == libc::IPPROTO_IGMP as u8));
+impl Rule {
+ fn match_port(mut self, port: u16, protocol: Protocol, source: bool) -> Self {
+ self = self.protocol(protocol);
+ self.add_expr(
+ HighLevelPayload::Transport(match protocol {
+ Protocol::TCP => TransportHeaderField::Tcp(if source {
+ TCPHeaderField::Sport
+ } else {
+ TCPHeaderField::Dport
+ }),
+ Protocol::UDP => TransportHeaderField::Udp(if source {
+ UDPHeaderField::Sport
+ } else {
+ UDPHeaderField::Dport
+ }),
+ })
+ .build(),
+ );
+ self.add_expr(Cmp::new(CmpOp::Eq, port.to_be_bytes()));
self
}
- fn dport(mut self, port: &str, protocol: &Protocol) -> Result<Self, Error> {
- self.add_expr(&nft_expr!(meta l4proto));
- match protocol {
- &Protocol::TCP => {
- self.add_expr(&nft_expr!(cmp == libc::IPPROTO_TCP as u8));
- self.add_expr(&nft_expr!(payload tcp dport));
- },
- &Protocol::UDP => {
- self.add_expr(&nft_expr!(cmp == libc::IPPROTO_UDP as u8));
- self.add_expr(&nft_expr!(payload udp dport));
- }
- }
- // Convert the port to Big-Endian number spelling.
- // See https://github.com/mullvad/mullvadvpn-app/blob/d92376b4d1df9b547930c68aa9bae9640ff2a022/talpid-core/src/firewall/linux.rs#L969
- self.add_expr(&nft_expr!(cmp == port.parse::<u16>()?.to_be()));
- Ok(self)
- }
- fn protocol(mut self, protocol: Protocol) -> Result<Self, Error> {
- self.add_expr(&nft_expr!(meta l4proto));
- match protocol {
- Protocol::TCP => {
- self.add_expr(&nft_expr!(cmp == libc::IPPROTO_TCP as u8));
- },
- Protocol::UDP => {
- self.add_expr(&nft_expr!(cmp == libc::IPPROTO_UDP as u8));
- }
- }
- Ok(self)
- }
- fn established(mut self) -> Self {
- let allowed_states = crate::expr::ct::States::ESTABLISHED.bits();
- self.add_expr(&nft_expr!(ct state));
- self.add_expr(&nft_expr!(bitwise mask allowed_states, xor 0u32));
- self.add_expr(&nft_expr!(cmp != 0u32));
- self
- }
- fn iface_id(mut self, iface_index: libc::c_uint) -> Result<Self, Error> {
- self.add_expr(&nft_expr!(meta iif));
- self.add_expr(&nft_expr!(cmp == iface_index));
- Ok(self)
- }
- fn iface(mut self, iface_name: &str) -> Result<Self, Error> {
- if iface_name.len() >= libc::IFNAMSIZ {
- return Err(Error::NameTooLong);
- }
- let mut name_arr = [0u8; libc::IFNAMSIZ];
- for (pos, i) in iface_name.bytes().enumerate() {
- name_arr[pos] = i;
- }
- self.add_expr(&nft_expr!(meta iifname));
- self.add_expr(&nft_expr!(cmp == name_arr.as_ref()));
- Ok(self)
- }
- fn saddr(mut self, ip: IpAddr) -> Self {
- self.add_expr(&nft_expr!(meta nfproto));
+ pub fn match_ip(mut self, ip: IpAddr, source: bool) -> Self {
+ self.add_expr(Meta::new(MetaType::NfProto));
match ip {
IpAddr::V4(addr) => {
- self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV4 as u8));
- self.add_expr(&nft_expr!(payload ipv4 saddr));
- self.add_expr(&nft_expr!(cmp == addr))
- },
+ self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV4 as u8]));
+ self.add_expr(
+ HighLevelPayload::Network(NetworkHeaderField::IPv4(if source {
+ IPv4HeaderField::Saddr
+ } else {
+ IPv4HeaderField::Daddr
+ }))
+ .build(),
+ );
+ self.add_expr(Cmp::new(CmpOp::Eq, addr.octets()));
+ }
IpAddr::V6(addr) => {
- self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV6 as u8));
- self.add_expr(&nft_expr!(payload ipv6 saddr));
- self.add_expr(&nft_expr!(cmp == addr))
+ self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8]));
+ self.add_expr(
+ HighLevelPayload::Network(NetworkHeaderField::IPv6(if source {
+ IPv6HeaderField::Saddr
+ } else {
+ IPv6HeaderField::Daddr
+ }))
+ .build(),
+ );
+ self.add_expr(Cmp::new(CmpOp::Eq, addr.octets()));
}
}
self
}
- fn snetwork(mut self, net: IpNetwork) -> Self {
- self.add_expr(&nft_expr!(meta nfproto));
+
+ pub fn match_network(mut self, net: IpNetwork, source: bool) -> Result<Self, BuilderError> {
+ self.add_expr(Meta::new(MetaType::NfProto));
match net {
IpNetwork::V4(_) => {
- self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV4 as u8));
- self.add_expr(&nft_expr!(payload ipv4 saddr));
- self.add_expr(&nft_expr!(bitwise mask net.mask(), xor 0u32));
- self.add_expr(&nft_expr!(cmp == net.network()));
- },
+ self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV4 as u8]));
+ self.add_expr(
+ HighLevelPayload::Network(NetworkHeaderField::IPv4(if source {
+ IPv4HeaderField::Saddr
+ } else {
+ IPv4HeaderField::Daddr
+ }))
+ .build(),
+ );
+ self.add_expr(Bitwise::new(ip_to_vec(net.mask()), 0u32.to_be_bytes())?);
+ }
IpNetwork::V6(_) => {
- self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV6 as u8));
- self.add_expr(&nft_expr!(payload ipv6 saddr));
- self.add_expr(&nft_expr!(bitwise mask net.mask(), xor &[0u16; 8][..]));
- self.add_expr(&nft_expr!(cmp == net.network()));
+ self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8]));
+ self.add_expr(
+ HighLevelPayload::Network(NetworkHeaderField::IPv6(if source {
+ IPv6HeaderField::Saddr
+ } else {
+ IPv6HeaderField::Daddr
+ }))
+ .build(),
+ );
+ self.add_expr(Bitwise::new(ip_to_vec(net.mask()), 0u128.to_be_bytes())?);
}
}
+ self.add_expr(Cmp::new(CmpOp::Eq, ip_to_vec(net.network())));
+ Ok(self)
+ }
+}
+
+impl Rule {
+ /// Matches ICMP packets.
+ pub fn icmp(mut self) -> Self {
+ // quid of icmpv6?
+ self.add_expr(Meta::new(MetaType::L4Proto));
+ self.add_expr(Cmp::new(CmpOp::Eq, [libc::IPPROTO_ICMP as u8]));
self
}
- fn log(mut self, group: Option<LogGroup>, prefix: Option<LogPrefix>) -> Self {
- match (group.is_some(), prefix.is_some()) {
- (true, true) => {
- self.add_expr(&nft_expr!(log group group prefix prefix));
- },
- (false, true) => {
- self.add_expr(&nft_expr!(log prefix prefix));
- },
- (true, false) => {
- self.add_expr(&nft_expr!(log group group));
- },
- (false, false) => {
- self.add_expr(&nft_expr!(log));
- }
- }
+ /// Matches IGMP packets.
+ pub fn igmp(mut self) -> Self {
+ self.add_expr(Meta::new(MetaType::L4Proto));
+ self.add_expr(Cmp::new(CmpOp::Eq, [libc::IPPROTO_IGMP as u8]));
self
}
- fn accept(mut self) -> Self {
- self.add_expr(&nft_expr!(verdict accept));
+ /// Matches packets from source `port` and `protocol`.
+ pub fn sport(self, port: u16, protocol: Protocol) -> Self {
+ self.match_port(port, protocol, false)
+ }
+ /// Matches packets to destination `port` and `protocol`.
+ pub fn dport(self, port: u16, protocol: Protocol) -> Self {
+ self.match_port(port, protocol, false)
+ }
+ /// Matches packets on `protocol`.
+ pub fn protocol(mut self, protocol: Protocol) -> Self {
+ self.add_expr(Meta::new(MetaType::L4Proto));
+ self.add_expr(Cmp::new(
+ CmpOp::Eq,
+ [match protocol {
+ Protocol::TCP => libc::IPPROTO_TCP,
+ Protocol::UDP => libc::IPPROTO_UDP,
+ } as u8],
+ ));
+ self
+ }
+ /// Matches packets in an already established connection.
+ pub fn established(mut self) -> Result<Self, BuilderError> {
+ let allowed_states = ConnTrackState::ESTABLISHED.bits();
+ self.add_expr(Conntrack::new(ConntrackKey::State));
+ self.add_expr(Bitwise::new(
+ allowed_states.to_le_bytes(),
+ 0u32.to_be_bytes(),
+ )?);
+ self.add_expr(Cmp::new(CmpOp::Neq, 0u32.to_be_bytes()));
+ Ok(self)
+ }
+ /// Matches packets going through `iface_index`. Interface indexes can be queried with
+ /// `iface_index()`.
+ pub fn iface_id(mut self, iface_index: libc::c_uint) -> Self {
+ self.add_expr(Meta::new(MetaType::Iif));
+ self.add_expr(Cmp::new(CmpOp::Eq, iface_index.to_be_bytes()));
self
}
- fn drop(mut self) -> Self {
- self.add_expr(&nft_expr!(verdict drop));
+ /// Matches packets going through `iface_name`, an interface name, as in "wlan0" or "lo"
+ pub fn iface(mut self, iface_name: &str) -> Result<Self, BuilderError> {
+ if iface_name.len() >= libc::IFNAMSIZ {
+ return Err(BuilderError::InterfaceNameTooLong);
+ }
+ let mut iface_vec = iface_name.as_bytes().to_vec();
+ // null terminator
+ iface_vec.push(0u8);
+
+ self.add_expr(Meta::new(MetaType::IifName));
+ self.add_expr(Cmp::new(CmpOp::Eq, iface_vec));
+ Ok(self)
+ }
+ /// Matches packets whose source IP address is `saddr`.
+ pub fn saddr(self, ip: IpAddr) -> Self {
+ self.match_ip(ip, true)
+ }
+ /// Matches packets whose destination IP address is `saddr`.
+ pub fn daddr(self, ip: IpAddr) -> Self {
+ self.match_ip(ip, false)
+ }
+ /// Matches packets whose source network is `net`.
+ pub fn snetwork(self, net: IpNetwork) -> Result<Self, BuilderError> {
+ self.match_network(net, true)
+ }
+ /// Matches packets whose destination network is `net`.
+ pub fn dnetwork(self, net: IpNetwork) -> Result<Self, BuilderError> {
+ self.match_network(net, false)
+ }
+ /// Adds the `Accept` verdict to the rule. The packet will be sent to destination.
+ pub fn accept(mut self) -> Self {
+ self.add_expr(Immediate::new_verdict(VerdictKind::Accept));
self
}
- fn add_to_batch(self, batch: &mut Batch) -> Self {
- batch.add(&self, crate::MsgType::Add);
+ /// Adds the `Drop` verdict to the rule. The packet will be dropped.
+ pub fn drop(mut self) -> Self {
+ self.add_expr(Immediate::new_verdict(VerdictKind::Drop));
self
}
}
/// Looks up the interface index for a given interface name.
-pub fn iface_index(name: &str) -> Result<libc::c_uint, Error> {
+pub fn iface_index(name: &str) -> Result<libc::c_uint, std::io::Error> {
let c_name = CString::new(name)?;
let index = unsafe { libc::if_nametoindex(c_name.as_ptr()) };
match index {
- 0 => Err(Error::NoSuchIface),
- _ => Ok(index)
+ 0 => Err(std::io::Error::last_os_error()),
+ _ => Ok(index),
}
}
-
-