diff options
author | Simon THOBY <git@nightmared.fr> | 2023-01-08 22:24:40 +0100 |
---|---|---|
committer | Simon THOBY <git@nightmared.fr> | 2023-01-08 22:28:42 +0100 |
commit | dc3c2ffab697b5d8fce7c69f76528fcfdf2edf38 (patch) | |
tree | af2dcb95c21a009933492ea80d71b25bdb0e24f6 | |
parent | 1d68fa40916295465be142b340f1a6381ea079a1 (diff) |
rewrite the examples
-rw-r--r-- | examples/add-rules.rs | 168 | ||||
-rw-r--r-- | examples/filter-ethernet.rs | 145 | ||||
-rw-r--r-- | examples/firewall.rs | 234 | ||||
-rw-r--r-- | macros/src/lib.rs | 7 | ||||
-rw-r--r-- | src/chain.rs | 8 | ||||
-rw-r--r-- | src/chain_methods.rs | 40 | ||||
-rw-r--r-- | src/data_type.rs | 9 | ||||
-rw-r--r-- | src/error.rs | 6 | ||||
-rw-r--r-- | src/expr/cmp.rs | 5 | ||||
-rw-r--r-- | src/expr/ct.rs | 4 | ||||
-rw-r--r-- | src/expr/log.rs | 14 | ||||
-rw-r--r-- | src/expr/mod.rs | 30 | ||||
-rw-r--r-- | src/lib.rs | 9 | ||||
-rw-r--r-- | src/nlmsg.rs | 4 | ||||
-rw-r--r-- | src/rule.rs | 26 | ||||
-rw-r--r-- | src/rule_methods.rs | 355 | ||||
-rw-r--r-- | src/set.rs | 3 | ||||
-rw-r--r-- | src/table.rs | 8 | ||||
-rw-r--r-- | src/tests/expr.rs | 46 | ||||
-rw-r--r-- | src/tests/mod.rs | 4 | ||||
-rw-r--r-- | src/tests/set.rs | 19 |
21 files changed, 535 insertions, 609 deletions
diff --git a/examples/add-rules.rs b/examples/add-rules.rs index b145291..a2b9c9c 100644 --- a/examples/add-rules.rs +++ b/examples/add-rules.rs @@ -38,12 +38,14 @@ use ipnetwork::{IpNetwork, Ipv4Network}; use rustables::{ - expr::{Cmp, CmpOp, ExpressionList, Immediate, Meta, MetaType, Verdict, VerdictKind}, - list_chains_for_table, list_rules_for_chain, list_tables, Batch, Chain, ChainPolicy, Hook, - HookClass, MsgType, ProtocolFamily, Rule, Table, + data_type::ip_to_vec, + expr::{ + Bitwise, Cmp, CmpOp, Counter, HighLevelPayload, ICMPv6HeaderField, IPv4HeaderField, + IcmpCode, Immediate, Meta, MetaType, NetworkHeaderField, TransportHeaderField, VerdictKind, + }, + iface_index, Batch, Chain, ChainPolicy, Hook, HookClass, MsgType, ProtocolFamily, Rule, Table, }; -//use rustables::{nft_expr, query::send_batch, sys::libc, Batch, Chain, Rule, Table}; -use std::{ffi::CString, io, net::Ipv4Addr, rc::Rc}; +use std::net::Ipv4Addr; const TABLE_NAME: &str = "example-table"; const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets"; @@ -81,105 +83,94 @@ fn main() -> Result<(), Error> { batch.add(&out_chain, MsgType::Add); batch.add(&in_chain, MsgType::Add); - let rule = Rule::new(&in_chain)?.with_expressions( - ExpressionList::default().with_value(Immediate::new_verdict(VerdictKind::Accept)), - ); - batch.add(&rule, MsgType::Add); - - let rule = Rule::new(&in_chain)?.with_expressions( - ExpressionList::default().with_value(Immediate::new_verdict(VerdictKind::Continue)), - ); - - batch.add(&rule, MsgType::Add); - // === ADD RULE ALLOWING ALL TRAFFIC TO THE LOOPBACK DEVICE === - // Create a new rule object under the input chain. - let mut allow_loopback_in_rule = Rule::new(&in_chain)?; // Lookup the interface index of the loopback interface. let lo_iface_index = iface_index("lo")?; - allow_loopback_in_rule.set_expressions( - ExpressionList::default() + // Create a new rule object under the input chain. + let allow_loopback_in_rule = Rule::new(&in_chain)? // First expression to be evaluated in this rule is load the meta information "iif" // (incoming interface index) into the comparison register of netfilter. // When an incoming network packet is processed by this rule it will first be processed by this // expression, which will load the interface index of the interface the packet came from into // a special "register" in netfilter. - .with_value(Meta::new(MetaType::Iif)) + .with_expr(Meta::new(MetaType::Iif)) + // Next expression in the rule is to compare the value loaded into the register with our desired // interface index, and succeed only if it's equal. For any packet processed where the equality // does not hold the packet is said to not match this rule, and the packet moves on to be // processed by the next rule in the chain instead. - .with_value(Cmp::new(CmpOp::Eq, lo_iface_index.to_le_bytes())) + .with_expr(Cmp::new(CmpOp::Eq, lo_iface_index.to_le_bytes())) // Add a verdict expression to the rule. Any packet getting this far in the expression // processing without failing any expression will be given the verdict added here. - .with_value(Immediate::new_verdict(VerdictKind::Accept)), - ); + .with_expr(Immediate::new_verdict(VerdictKind::Accept)); // Add the rule to the batch. batch.add(&allow_loopback_in_rule, rustables::MsgType::Add); - // // === ADD A RULE ALLOWING (AND COUNTING) ALL PACKETS TO THE 10.1.0.0/24 NETWORK === - // - // let mut block_out_to_private_net_rule = Rule::new(Rc::clone(&out_chain)); - // let private_net_ip = Ipv4Addr::new(10, 1, 0, 0); - // let private_net_prefix = 24; - // let private_net = IpNetwork::V4(Ipv4Network::new(private_net_ip, private_net_prefix)?); - // - // // Load the `nfproto` metadata into the netfilter register. This metadata denotes which layer3 - // // protocol the packet being processed is using. - // block_out_to_private_net_rule.add_expr(&nft_expr!(meta nfproto)); - // // Check if the currently processed packet is an IPv4 packet. This must be done before payload - // // data assuming the packet uses IPv4 can be loaded in the next expression. - // block_out_to_private_net_rule.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV4 as u8)); - // - // // Load the IPv4 destination address into the netfilter register. - // block_out_to_private_net_rule.add_expr(&nft_expr!(payload ipv4 daddr)); - // // Mask out the part of the destination address that is not part of the network bits. The result - // // of this bitwise masking is stored back into the same netfilter register. - // block_out_to_private_net_rule.add_expr(&nft_expr!(bitwise mask private_net.mask(), xor 0)); - // // Compare the result of the masking with the IP of the network we are interested in. - // block_out_to_private_net_rule.add_expr(&nft_expr!(cmp == private_net.ip())); - // - // // Add a packet counter to the rule. Shows how many packets have been evaluated against this - // // expression. Since expressions are evaluated from first to last, putting this counter before - // // the above IP net check would make the counter increment on all packets also *not* matching - // // those expressions. Because the counter would then be evaluated before it fails a check. - // // Similarly, if the counter was added after the verdict it would always remain at zero. Since - // // when the packet hits the verdict expression any further processing of expressions stop. - // block_out_to_private_net_rule.add_expr(&nft_expr!(counter)); - // - // // Accept all the packets matching the rule so far. - // block_out_to_private_net_rule.add_expr(&nft_expr!(verdict accept)); - // - // // Add the rule to the batch. Without this nothing would be sent over netlink and netfilter, - // // and all the work on `block_out_to_private_net_rule` so far would go to waste. - // batch.add(&block_out_to_private_net_rule, rustables::MsgType::Add); - // - // // === ADD A RULE ALLOWING ALL OUTGOING ICMPv6 PACKETS WITH TYPE 133 AND CODE 0 === - // - // let mut allow_router_solicitation = Rule::new(Rc::clone(&out_chain)); - // - // // Check that the packet is IPv6 and ICMPv6 - // allow_router_solicitation.add_expr(&nft_expr!(meta nfproto)); - // allow_router_solicitation.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV6 as u8)); - // allow_router_solicitation.add_expr(&nft_expr!(meta l4proto)); - // allow_router_solicitation.add_expr(&nft_expr!(cmp == libc::IPPROTO_ICMPV6 as u8)); - // - // allow_router_solicitation.add_expr(&rustables::expr::Payload::Transport( - // rustables::expr::TransportHeaderField::Icmpv6(rustables::expr::Icmpv6HeaderField::Type), - // )); - // allow_router_solicitation.add_expr(&nft_expr!(cmp == 133u8)); - // allow_router_solicitation.add_expr(&rustables::expr::Payload::Transport( - // rustables::expr::TransportHeaderField::Icmpv6(rustables::expr::Icmpv6HeaderField::Code), - // )); - // allow_router_solicitation.add_expr(&nft_expr!(cmp == 0u8)); - // - // allow_router_solicitation.add_expr(&nft_expr!(verdict accept)); - // - // batch.add(&allow_router_solicitation, rustables::MsgType::Add); + // === ADD A RULE ALLOWING (AND COUNTING) ALL PACKETS TO THE 10.1.0.0/24 NETWORK === + + let private_net_ip = Ipv4Addr::new(10, 1, 0, 0); + let private_net_prefix = 24; + let private_net = IpNetwork::V4(Ipv4Network::new(private_net_ip, private_net_prefix)?); + + let block_out_to_private_net_rule = Rule::new(&out_chain)? + // Load the `nfproto` metadata into the netfilter register. This metadata denotes which layer3 + // protocol the packet being processed is using. + .with_expr(Meta::new(MetaType::NfProto)) + + // Check if the currently processed packet is an IPv4 packet. This must be done before payload + // data assuming the packet uses IPv4 can be loaded in the next expression. + .with_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV4 as u8])) + + // Load the IPv4 destination address into the netfilter register. + .with_expr(HighLevelPayload::Network(NetworkHeaderField::IPv4(IPv4HeaderField::Daddr)).build()) + + // Mask out the part of the destination address that is not part of the network bits. The result + // of this bitwise masking is stored back into the same netfilter register. + .with_expr(Bitwise::new(ip_to_vec(private_net.mask()), [0u8; 4])?) + + // Compare the result of the masking with the IP of the network we are interested in. + .with_expr(Cmp::new(CmpOp::Eq, ip_to_vec(private_net.ip()))) + + // Add a packet counter to the rule. Shows how many packets have been evaluated against this + // expression. Since expressions are evaluated from first to last, putting this counter before + // the above IP net check would make the counter increment on all packets also *not* matching + // those expressions. Because the counter would then be evaluated before it fails a check. + // Similarly, if the counter was added after the verdict it would always remain at zero. Since + // when the packet hits the verdict expression any further processing of expressions stop. + .with_expr(Counter::default()) + + // Accept all the packets matching the rule so far. + .with_expr(Immediate::new_verdict(VerdictKind::Accept)); + + // Add the rule to the batch. Without this nothing would be sent over netlink and netfilter, + // and all the work on `block_out_to_private_net_rule` so far would go to waste. + batch.add(&block_out_to_private_net_rule, rustables::MsgType::Add); + + // === ADD A RULE ALLOWING ALL OUTGOING ICMPv6 PACKETS WITH TYPE 133 AND CODE 0 === + + let allow_router_solicitation = Rule::new(&out_chain)? + // Check that the packet is IPv6 and ICMPv6 + .with_expr(Meta::new(MetaType::NfProto)) + .with_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8])) + .with_expr(Meta::new(MetaType::L4Proto)) + .with_expr(Cmp::new(CmpOp::Eq, [libc::IPPROTO_ICMPV6 as u8])) + .with_expr( + HighLevelPayload::Transport(TransportHeaderField::ICMPv6(ICMPv6HeaderField::Type)) + .build(), + ) + .with_expr(Cmp::new(CmpOp::Eq, [133u8])) + .with_expr( + HighLevelPayload::Transport(TransportHeaderField::ICMPv6(ICMPv6HeaderField::Code)) + .build(), + ) + .with_expr(Cmp::new(CmpOp::Eq, [IcmpCode::NoRoute as u8])) + .with_expr(Immediate::new_verdict(VerdictKind::Accept)); + + batch.add(&allow_router_solicitation, rustables::MsgType::Add); // === FINALIZE THE TRANSACTION AND SEND THE DATA TO NETFILTER === @@ -190,17 +181,6 @@ fn main() -> Result<(), Error> { Ok(batch.send()?) } -// Look up the interface index for a given interface name. -fn iface_index(name: &str) -> Result<libc::c_uint, Error> { - let c_name = CString::new(name)?; - let index = unsafe { libc::if_nametoindex(c_name.as_ptr()) }; - if index == 0 { - Err(Error::from(io::Error::last_os_error())) - } else { - Ok(index) - } -} - #[derive(Debug)] struct Error(String); diff --git a/examples/filter-ethernet.rs b/examples/filter-ethernet.rs index 732c8cb..a136731 100644 --- a/examples/filter-ethernet.rs +++ b/examples/filter-ethernet.rs @@ -10,7 +10,7 @@ ///! table inet example-filter-ethernet { ///! chain chain-for-outgoing-packets { ///! type filter hook output priority 3; policy accept; -///! ether daddr 00:00:00:00:00:00 drop +///! ether daddr 01:02:03:04:05:06 drop ///! counter packets 0 bytes 0 meta random > 2147483647 counter packets 0 bytes 0 ///! } ///! } @@ -21,75 +21,78 @@ ///! ```bash ///! # nft delete table inet example-filter-ethernet ///! ``` -//use rustables::{nft_expr, query::send_batch, sys::libc, Batch, Chain, ProtoFamily, Rule, Table}; -use std::{ffi::CString, rc::Rc}; -// -//const TABLE_NAME: &str = "example-filter-ethernet"; -//const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets"; -// -//const BLOCK_THIS_MAC: &[u8] = &[0, 0, 0, 0, 0, 0]; -// +use rustables::{ + expr::{ + Cmp, CmpOp, Counter, ExpressionList, HighLevelPayload, Immediate, LLHeaderField, Meta, + MetaType, VerdictKind, + }, + Batch, Chain, ChainPolicy, Hook, HookClass, ProtocolFamily, Rule, Table, +}; + +const TABLE_NAME: &str = "example-filter-ethernet"; +const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets"; + +const BLOCK_THIS_MAC: &[u8] = &[1, 2, 3, 4, 5, 6]; + fn main() { - // // For verbose explanations of what all these lines up until the rule creation does, see the - // // `add-rules` example. - // let mut batch = Batch::new(); - // let table = Rc::new(Table::new( - // &CString::new(TABLE_NAME).unwrap(), - // ProtoFamily::Inet, - // )); - // batch.add(&Rc::clone(&table), rustables::MsgType::Add); - // - // let mut out_chain = Chain::new(&CString::new(OUT_CHAIN_NAME).unwrap(), Rc::clone(&table)); - // out_chain.set_hook(rustables::Hook::Out, 3); - // out_chain.set_policy(rustables::Policy::Accept); - // let out_chain = Rc::new(out_chain); - // batch.add(&Rc::clone(&out_chain), rustables::MsgType::Add); - // - // // === ADD RULE DROPPING ALL TRAFFIC TO THE MAC ADDRESS IN `BLOCK_THIS_MAC` === - // - // let mut block_ethernet_rule = Rule::new(Rc::clone(&out_chain)); - // - // // Check that the interface type is an ethernet interface. Must be done before we can check - // // payload values in the ethernet header. - // block_ethernet_rule.add_expr(&nft_expr!(meta iiftype)); - // block_ethernet_rule.add_expr(&nft_expr!(cmp == libc::ARPHRD_ETHER)); - // - // // Compare the ethernet destination address against the MAC address we want to drop - // block_ethernet_rule.add_expr(&nft_expr!(payload ethernet daddr)); - // block_ethernet_rule.add_expr(&nft_expr!(cmp == BLOCK_THIS_MAC)); - // - // // Drop the matching packets. - // block_ethernet_rule.add_expr(&nft_expr!(verdict drop)); - // - // batch.add(&block_ethernet_rule, rustables::MsgType::Add); - // - // // === FOR FUN, ADD A PACKET THAT MATCHES 50% OF ALL PACKETS === - // - // // This packet has a counter before and after the check that has 50% chance of matching. - // // So after a number of packets has passed through this rule, the first counter should have a - // // value approximately double that of the second counter. This rule has no verdict, so it never - // // does anything with the matching packets. - // let mut random_rule = Rule::new(Rc::clone(&out_chain)); - // // This counter expression will be evaluated (and increment the counter) for all packets coming - // // through. - // random_rule.add_expr(&nft_expr!(counter)); - // - // // Load a pseudo-random 32 bit unsigned integer into the netfilter register. - // random_rule.add_expr(&nft_expr!(meta random)); - // // Check if the random integer is larger than `u32::MAX/2`, thus having 50% chance of success. - // random_rule.add_expr(&nft_expr!(cmp > (::std::u32::MAX / 2).to_be())); - // - // // Add a second counter. This will only be incremented for the packets passing the random check. - // random_rule.add_expr(&nft_expr!(counter)); - // - // batch.add(&random_rule, rustables::MsgType::Add); - // - // // === FINALIZE THE TRANSACTION AND SEND THE DATA TO NETFILTER === - // - // match batch.finalize() { - // Some(mut finalized_batch) => { - // send_batch(&mut finalized_batch).expect("Couldn't process the batch"); - // } - // None => todo!(), - // } + // For verbose explanations of what all these lines up until the rule creation does, see the + // `add-rules` example. + let mut batch = Batch::new(); + let table = Table::new(ProtocolFamily::Inet).with_name(TABLE_NAME); + batch.add(&table, rustables::MsgType::Add); + + let mut out_chain = Chain::new(&table).with_name(OUT_CHAIN_NAME); + out_chain.set_hook(Hook::new(HookClass::Out, 3)); + out_chain.set_policy(ChainPolicy::Accept); + batch.add(&out_chain, rustables::MsgType::Add); + + // === ADD RULE DROPPING ALL TRAFFIC TO THE MAC ADDRESS IN `BLOCK_THIS_MAC` === + + let mut block_ethernet_rule = Rule::new(&out_chain).unwrap(); + + block_ethernet_rule.set_expressions( + ExpressionList::default() + // Check that the interface type is an ethernet interface. Must be done before we can check + // payload values in the ethernet header. + .with_value(Meta::new(MetaType::IifType)) + .with_value(Cmp::new(CmpOp::Eq, (libc::ARPHRD_ETHER as u16).to_le_bytes())) + + // Compare the ethernet destination address against the MAC address we want to drop + .with_value(HighLevelPayload::LinkLayer(LLHeaderField::Daddr).build()) + .with_value(Cmp::new(CmpOp::Eq, BLOCK_THIS_MAC)) + + // Drop the matching packets. + .with_value(Immediate::new_verdict(VerdictKind::Drop)), + ); + + batch.add(&block_ethernet_rule, rustables::MsgType::Add); + + // === FOR FUN, ADD A PACKET THAT MATCHES 50% OF ALL PACKETS === + + // This packet has a counter before and after the check that has 50% chance of matching. + // So after a number of packets has passed through this rule, the first counter should have a + // value approximately double that of the second counter. This rule has no verdict, so it never + // does anything with the matching packets. + let mut random_rule = Rule::new(&out_chain).unwrap(); + + random_rule.set_expressions( + ExpressionList::default() + // This counter expression will be evaluated (and increment the counter) for all packets coming + // through. + .with_value(Counter::default()) + + // Load a pseudo-random 32 bit unsigned integer into the netfilter register. + .with_value(Meta::new(MetaType::PRandom)) + // Check if the random integer is larger than `u32::MAX/2`, thus having 50% chance of success. + .with_value(Cmp::new(CmpOp::Gt, (::std::u32::MAX / 2).to_be_bytes())) + + // Add a second counter. This will only be incremented for the packets passing the random check. + .with_value(Counter::default()), + ); + + batch.add(&random_rule, rustables::MsgType::Add); + + // === FINALIZE THE TRANSACTION AND SEND THE DATA TO NETFILTER === + + batch.send().unwrap(); } diff --git a/examples/firewall.rs b/examples/firewall.rs index fc25010..3169cdc 100644 --- a/examples/firewall.rs +++ b/examples/firewall.rs @@ -3,138 +3,124 @@ //use rustables::query::{send_batch, Error as QueryError}; //use rustables::expr::{LogGroup, LogPrefix, LogPrefixError}; use ipnetwork::IpNetwork; -use std::ffi::{CString, NulError}; -use std::rc::Rc; -use std::str::Utf8Error; +use rustables::error::{BuilderError, QueryError}; +use rustables::expr::Log; +use rustables::{ + Batch, Chain, ChainPolicy, Hook, HookClass, MsgType, Protocol, ProtocolFamily, Rule, Table, +}; #[derive(thiserror::Error, Debug)] pub enum Error { - #[error("Unable to open netlink socket to netfilter")] - NetlinkOpenError(#[source] std::io::Error), - #[error("Firewall is already started")] - AlreadyDone, - #[error("Error converting from a C String")] - NulError(#[from] NulError), - //#[error("Error creating match")] - //MatchError(#[from] MatchError), - #[error("Error converting to utf-8 string")] - Utf8Error(#[from] Utf8Error), + #[error("Error building a netlink object")] + BuildError(#[from] BuilderError), #[error("Error applying batch")] - BatchError(#[from] std::io::Error), - //#[error("Error applying batch")] - //QueryError(#[from] QueryError), - //#[error("Error encoding the prefix")] - //LogPrefixError(#[from] LogPrefixError), + QueryError(#[from] QueryError), } const TABLE_NAME: &str = "main-table"; +const INBOUND_CHAIN_NAME: &str = "in-chain"; +const FORWARD_CHAIN_NAME: &str = "forward-chain"; +const OUTBOUND_CHAIN_NAME: &str = "out-chain"; fn main() -> Result<(), Error> { - // let fw = Firewall::new()?; - // fw.start()?; + let fw = Firewall::new()?; + fw.start()?; Ok(()) } -// -// -///// An example firewall. See the source of its `start()` method. -//pub struct Firewall { -// batch: Batch, -// inbound: Rc<Chain>, -// _outbound: Rc<Chain>, -// _forward: Rc<Chain>, -// table: Rc<Table>, -//} -// -//impl Firewall { -// pub fn new() -> Result<Self, Error> { -// let mut batch = Batch::new(); -// let table = Rc::new( -// Table::new(&CString::new(TABLE_NAME)?, ProtoFamily::Inet) -// ); -// batch.add(&table, MsgType::Add); -// -// // Create base chains. Base chains are hooked into a Direction/Hook. -// let inbound = Rc::new( -// Chain::from_hook(Hook::In, Rc::clone(&table)) -// .verdict(Policy::Drop) -// .add_to_batch(&mut batch) -// ); -// let _outbound = Rc::new( -// Chain::from_hook(Hook::Out, Rc::clone(&table)) -// .verdict(Policy::Accept) -// .add_to_batch(&mut batch) -// ); -// let _forward = Rc::new( -// Chain::from_hook(Hook::Forward, Rc::clone(&table)) -// .verdict(Policy::Accept) -// .add_to_batch(&mut batch) -// ); -// -// Ok(Firewall { -// table, -// batch, -// inbound, -// _outbound, -// _forward -// }) -// } -// /// Allow some common-sense exceptions to inbound drop, and accept outbound and forward. -// pub fn start(mut self) -> Result<(), Error> { -// // Allow all established connections to get in. -// Rule::new(Rc::clone(&self.inbound)) -// .established() -// .accept() -// .add_to_batch(&mut self.batch); -// // Allow all traffic on the loopback interface. -// Rule::new(Rc::clone(&self.inbound)) -// .iface("lo")? -// .accept() -// .add_to_batch(&mut self.batch); -// // Allow ssh from anywhere, and log to dmesg with a prefix. -// Rule::new(Rc::clone(&self.inbound)) -// .dport("22", &Protocol::TCP)? -// .accept() -// .log(None, Some(LogPrefix::new("allow ssh connection:")?)) -// .add_to_batch(&mut self.batch); -// -// // Allow http from all IPs in 192.168.1.255/24 . -// let local_net = IpNetwork::new([192, 168, 1, 0].into(), 24).unwrap(); -// Rule::new(Rc::clone(&self.inbound)) -// .dport("80", &Protocol::TCP)? -// .snetwork(local_net) -// .accept() -// .add_to_batch(&mut self.batch); -// -// // Allow ICMP traffic, drop IGMP. -// Rule::new(Rc::clone(&self.inbound)) -// .icmp() -// .accept() -// .add_to_batch(&mut self.batch); -// Rule::new(Rc::clone(&self.inbound)) -// .igmp() -// .drop() -// .add_to_batch(&mut self.batch); -// -// // Log all traffic not accepted to NF_LOG group 1, accessible with ulogd. -// Rule::new(Rc::clone(&self.inbound)) -// .log(Some(LogGroup(1)), None) -// .add_to_batch(&mut self.batch); -// -// let mut finalized_batch = self.batch.finalize().unwrap(); -// send_batch(&mut finalized_batch)?; -// println!("table {} commited", TABLE_NAME); -// Ok(()) -// } -// /// If there is any table with name TABLE_NAME, remove it. -// pub fn stop(mut self) -> Result<(), Error> { -// self.batch.add(&self.table, MsgType::Add); -// self.batch.add(&self.table, MsgType::Del); -// -// let mut finalized_batch = self.batch.finalize().unwrap(); -// send_batch(&mut finalized_batch)?; -// println!("table {} destroyed", TABLE_NAME); -// Ok(()) -// } -//} -// -// + +/// An example firewall. See the source of its `start()` method. +pub struct Firewall { + batch: Batch, + inbound: Chain, + _outbound: Chain, + _forward: Chain, + table: Table, +} + +impl Firewall { + pub fn new() -> Result<Self, Error> { + let mut batch = Batch::new(); + let table = Table::new(ProtocolFamily::Inet).with_name(TABLE_NAME); + batch.add(&table, MsgType::Add); + + // Create base chains. Base chains are hooked into a Direction/Hook. + let inbound = Chain::new(&table) + .with_name(INBOUND_CHAIN_NAME) + .with_hook(Hook::new(HookClass::In, 0)) + .with_policy(ChainPolicy::Drop) + .add_to_batch(&mut batch); + let _outbound = Chain::new(&table) + .with_name(OUTBOUND_CHAIN_NAME) + .with_hook(Hook::new(HookClass::Out, 0)) + .with_policy(ChainPolicy::Accept) + .add_to_batch(&mut batch); + let _forward = Chain::new(&table) + .with_name(FORWARD_CHAIN_NAME) + .with_hook(Hook::new(HookClass::Forward, 0)) + .with_policy(ChainPolicy::Accept) + .add_to_batch(&mut batch); + + Ok(Firewall { + table, + batch, + inbound, + _outbound, + _forward, + }) + } + /// Allow some common-sense exceptions to inbound drop, and accept outbound and forward. + pub fn start(mut self) -> Result<(), Error> { + // Allow all established connections to get in. + Rule::new(&self.inbound)? + .established()? + .accept() + .add_to_batch(&mut self.batch); + // Allow all traffic on the loopback interface. + Rule::new(&self.inbound)? + .iface("lo")? + .accept() + .add_to_batch(&mut self.batch); + // Allow ssh from anywhere, and log to dmesg with a prefix. + Rule::new(&self.inbound)? + .dport(22, Protocol::TCP) + .accept() + .with_expr(Log::new(None, Some("allow ssh connection:"))?) + .add_to_batch(&mut self.batch); + + // Allow http from all IPs in 192.168.1.255/24 . + let local_net = IpNetwork::new([192, 168, 1, 0].into(), 24).unwrap(); + Rule::new(&self.inbound)? + .dport(80, Protocol::TCP) + .snetwork(local_net)? + .accept() + .add_to_batch(&mut self.batch); + + // Allow ICMP traffic, drop IGMP. + Rule::new(&self.inbound)? + .icmp() + .accept() + .add_to_batch(&mut self.batch); + Rule::new(&self.inbound)? + .igmp() + .drop() + .add_to_batch(&mut self.batch); + + // Log all traffic not accepted to NF_LOG group 1, accessible with ulogd. + Rule::new(&self.inbound)? + .with_expr(Log::new(Some(1), None::<String>)?) + .add_to_batch(&mut self.batch); + + self.batch.send()?; + println!("table {} commited", TABLE_NAME); + Ok(()) + } + /// If there is any table with name TABLE_NAME, remove it. + pub fn stop(mut self) -> Result<(), Error> { + self.batch.add(&self.table, MsgType::Add); + self.batch.add(&self.table, MsgType::Del); + + self.batch.send()?; + println!("table {} destroyed", TABLE_NAME); + Ok(()) + } +} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 9170e82..39f0d01 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -187,6 +187,9 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { let getter_name = format!("get_{}", field_str); let getter_name = Ident::new(&getter_name, field.name.span()); + let muttable_getter_name = format!("get_mut_{}", field_str); + let muttable_getter_name = Ident::new(&muttable_getter_name, field.name.span()); + let setter_name = format!("set_{}", field_str); let setter_name = Ident::new(&setter_name, field.name.span()); @@ -199,6 +202,10 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { self.#field_name.as_ref() } + pub fn #muttable_getter_name(&mut self) -> Option<&mut #field_type> { + self.#field_name.as_mut() + } + pub fn #setter_name(&mut self, val: impl Into<#field_type>) { self.#field_name = Some(val.into()); } diff --git a/src/chain.rs b/src/chain.rs index 0ce0ad8..37e4cb3 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -8,7 +8,7 @@ use crate::sys::{ NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN, }; -use crate::{ProtocolFamily, Table}; +use crate::{Batch, ProtocolFamily, Table}; use std::fmt::Debug; pub type ChainPriority = i32; @@ -169,6 +169,12 @@ impl Chain { chain } + + /// Appends this chain to `batch` + pub fn add_to_batch(self, batch: &mut Batch) -> Self { + batch.add(&self, crate::MsgType::Add); + self + } } impl NfNetlinkObject for Chain { diff --git a/src/chain_methods.rs b/src/chain_methods.rs deleted file mode 100644 index d384c35..0000000 --- a/src/chain_methods.rs +++ /dev/null @@ -1,40 +0,0 @@ -use crate::{Batch, Chain, Hook, MsgType, Policy, Table}; -use std::ffi::CString; -use std::rc::Rc; - - -/// A helper trait over [`crate::Chain`]. -pub trait ChainMethods { - /// Creates a new Chain instance from a [`crate::Hook`] over a [`crate::Table`]. - fn from_hook(hook: Hook, table: Rc<Table>) -> Self - where Self: std::marker::Sized; - /// Adds a [`crate::Policy`] to the current Chain. - fn verdict(self, policy: Policy) -> Self; - fn add_to_batch(self, batch: &mut Batch) -> Self; -} - - -impl ChainMethods for Chain { - fn from_hook(hook: Hook, table: Rc<Table>) -> Self { - let chain_name = match hook { - Hook::PreRouting => "prerouting", - Hook::Out => "out", - Hook::PostRouting => "postrouting", - Hook::Forward => "forward", - Hook::In => "in", - }; - let chain_name = CString::new(chain_name).unwrap(); - let mut chain = Chain::new(&chain_name, table); - chain.set_hook(hook, 0); - chain - } - fn verdict(mut self, policy: Policy) -> Self { - self.set_policy(policy); - self - } - fn add_to_batch(self, batch: &mut Batch) -> Self { - batch.add(&self, MsgType::Add); - self - } -} - diff --git a/src/data_type.rs b/src/data_type.rs index f9c97cb..43a7f1a 100644 --- a/src/data_type.rs +++ b/src/data_type.rs @@ -1,4 +1,4 @@ -use std::net::{Ipv4Addr, Ipv6Addr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; pub trait DataType { const TYPE: u32; @@ -33,3 +33,10 @@ impl<const N: usize> DataType for [u8; N] { self.to_vec() } } + +pub fn ip_to_vec(ip: IpAddr) -> Vec<u8> { + match ip { + IpAddr::V4(x) => x.octets().to_vec(), + IpAddr::V6(x) => x.octets().to_vec(), + } +} diff --git a/src/error.rs b/src/error.rs index eae6898..f6b6247 100644 --- a/src/error.rs +++ b/src/error.rs @@ -129,6 +129,12 @@ pub enum BuilderError { #[error("Missing name for the set")] MissingSetName, + + #[error("The interface name is too long to be written")] + InterfaceNameTooLong, + + #[error("The log prefix string is more than 127 characters long")] + TooLongLogPrefix, } #[derive(thiserror::Error, Debug)] diff --git a/src/expr/cmp.rs b/src/expr/cmp.rs index 223902f..86d3587 100644 --- a/src/expr/cmp.rs +++ b/src/expr/cmp.rs @@ -1,7 +1,6 @@ use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; use crate::{ - data_type::DataType, parser_impls::NfNetlinkData, sys::{ NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFT_CMP_EQ, NFT_CMP_GT, NFT_CMP_GTE, NFT_CMP_LT, @@ -44,11 +43,11 @@ pub struct Cmp { impl Cmp { /// Returns a new comparison expression comparing the value loaded in the register with the /// data in `data` using the comparison operator `op`. - pub fn new(op: CmpOp, data: impl DataType) -> Self { + pub fn new(op: CmpOp, data: impl Into<Vec<u8>>) -> Self { Cmp { sreg: Some(Register::Reg1), op: Some(op), - data: Some(NfNetlinkData::default().with_value(data.data())), + data: Some(NfNetlinkData::default().with_value(data.into())), } } } diff --git a/src/expr/ct.rs b/src/expr/ct.rs index ccf61e1..ad76989 100644 --- a/src/expr/ct.rs +++ b/src/expr/ct.rs @@ -43,6 +43,10 @@ impl Expression for Conntrack { } impl Conntrack { + pub fn new(key: ConntrackKey) -> Self { + Self::default().with_dreg(Register::Reg1).with_key(key) + } + pub fn set_mark_value(&mut self, reg: Register) { self.set_sreg(reg); self.set_key(ConntrackKey::Mark); diff --git a/src/expr/log.rs b/src/expr/log.rs index 80bb7a9..cc2728e 100644 --- a/src/expr/log.rs +++ b/src/expr/log.rs @@ -1,7 +1,10 @@ use rustables_macros::nfnetlink_struct; -use super::{Expression, ExpressionError}; -use crate::sys::{NFTA_LOG_GROUP, NFTA_LOG_PREFIX}; +use super::Expression; +use crate::{ + error::BuilderError, + sys::{NFTA_LOG_GROUP, NFTA_LOG_PREFIX}, +}; #[derive(Clone, PartialEq, Eq, Default, Debug)] #[nfnetlink_struct] @@ -14,10 +17,7 @@ pub struct Log { } impl Log { - pub fn new( - group: Option<u16>, - prefix: Option<impl Into<String>>, - ) -> Result<Log, ExpressionError> { + pub fn new(group: Option<u16>, prefix: Option<impl Into<String>>) -> Result<Log, BuilderError> { let mut res = Log::default(); if let Some(group) = group { res.set_group(group); @@ -26,7 +26,7 @@ impl Log { let prefix = prefix.into(); if prefix.bytes().count() > 127 { - return Err(ExpressionError::TooLongLogPrefix); + return Err(BuilderError::TooLongLogPrefix); } res.set_prefix(prefix); } diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 979ebb2..058b0cb 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -6,7 +6,6 @@ use std::fmt::Debug; use rustables_macros::nfnetlink_struct; -use thiserror::Error; use crate::error::DecodeError; use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable}; @@ -55,35 +54,6 @@ pub use self::register::Register; mod verdict; pub use self::verdict::*; -#[derive(Debug, Error)] -pub enum ExpressionError { - #[error("The log prefix string is more than 127 characters long")] - /// The log prefix string is more than 127 characters long - TooLongLogPrefix, - - #[error("The expected expression type doesn't match the name of the raw expression")] - /// The expected expression type doesn't match the name of the raw expression. - InvalidExpressionKind, - - #[error("Deserializing the requested type isn't implemented yet")] - /// Deserializing the requested type isn't implemented yet. - NotImplemented, - - #[error("The expression value cannot be deserialized to the requested type")] - /// The expression value cannot be deserialized to the requested type. - InvalidValue, - - #[error("A pointer was null while a non-null pointer was expected")] - /// A pointer was null while a non-null pointer was expected. - NullPointer, - - #[error( - "The size of a raw value was incoherent with the expected type of the deserialized value" - )] - /// The size of a raw value was incoherent with the expected type of the deserialized value/ - InvalidDataSize, -} - pub trait Expression { fn get_name() -> &'static str; } @@ -53,7 +53,7 @@ use std::convert::TryFrom; mod batch; pub use batch::{default_batch_page_size, Batch}; -mod data_type; +pub mod data_type; mod table; pub use table::list_tables; @@ -65,9 +65,6 @@ pub use chain::{Chain, ChainPolicy, ChainPriority, ChainType, Hook, HookClass}; pub mod error; -//mod chain_methods; -//pub use chain_methods::ChainMethods; - pub mod query; pub(crate) mod nlmsg; @@ -80,8 +77,8 @@ pub use rule::Rule; pub mod expr; -//mod rule_methods; -//pub use rule_methods::{iface_index, Error as MatchError, Protocol, RuleMethods}; +mod rule_methods; +pub use rule_methods::{iface_index, Protocol}; pub mod set; pub use set::Set; diff --git a/src/nlmsg.rs b/src/nlmsg.rs index b3710bf..1c5b519 100644 --- a/src/nlmsg.rs +++ b/src/nlmsg.rs @@ -62,10 +62,6 @@ impl<'a> NfNetlinkWriter<'a> { &mut self.buf[start..start + size] } - pub fn extract_buffer(self) -> &'a mut Vec<u8> { - self.buf - } - // rewrite of `__nftnl_nlmsg_build_hdr` pub fn write_header( &mut self, diff --git a/src/rule.rs b/src/rule.rs index 7f732d3..858b9ce 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -4,7 +4,7 @@ use rustables_macros::nfnetlink_struct; use crate::chain::Chain; use crate::error::{BuilderError, QueryError}; -use crate::expr::ExpressionList; +use crate::expr::{ExpressionList, RawExpression}; use crate::nlmsg::NfNetlinkObject; use crate::query::list_objects_with_data; use crate::sys::{ @@ -12,7 +12,7 @@ use crate::sys::{ NFTA_RULE_TABLE, NFTA_RULE_USERDATA, NFT_MSG_DELRULE, NFT_MSG_NEWRULE, NLM_F_APPEND, NLM_F_CREATE, }; -use crate::ProtocolFamily; +use crate::{Batch, ProtocolFamily}; /// A nftables firewall rule. #[derive(Clone, PartialEq, Eq, Default, Debug)] @@ -53,6 +53,28 @@ impl Rule { .ok_or(BuilderError::MissingChainInformationError)?, )) } + + pub fn add_expr(&mut self, e: impl Into<RawExpression>) { + let exprs = match self.get_mut_expressions() { + Some(x) => x, + None => { + self.set_expressions(ExpressionList::default()); + self.get_mut_expressions().unwrap() + } + }; + exprs.add_value(e); + } + + pub fn with_expr(mut self, e: impl Into<RawExpression>) -> Self { + self.add_expr(e); + self + } + + /// Appends this rule to `batch` + pub fn add_to_batch(self, batch: &mut Batch) -> Self { + batch.add(&self, crate::MsgType::Add); + self + } } impl NfNetlinkObject for Rule { diff --git a/src/rule_methods.rs b/src/rule_methods.rs index d7145d7..dff9bf6 100644 --- a/src/rule_methods.rs +++ b/src/rule_methods.rs @@ -1,230 +1,211 @@ -use crate::{Batch, Rule, nft_expr, sys::libc}; -use crate::expr::{LogGroup, LogPrefix}; -use ipnetwork::IpNetwork; -use std::ffi::{CString, NulError}; +use std::ffi::CString; use std::net::IpAddr; -use std::num::ParseIntError; - -#[derive(thiserror::Error, Debug)] -pub enum Error { - #[error("Unable to open netlink socket to netfilter")] - NetlinkOpenError(#[source] std::io::Error), - #[error("Firewall is already started")] - AlreadyDone, - #[error("Error converting from a C string to a string")] - CStringError(#[from] NulError), - #[error("no interface found under that name")] - NoSuchIface, - #[error("Error converting from a string to an integer")] - ParseError(#[from] ParseIntError), - #[error("the interface name is too long")] - NameTooLong, -} +use ipnetwork::IpNetwork; +use crate::data_type::ip_to_vec; +use crate::error::BuilderError; +use crate::expr::ct::{ConnTrackState, Conntrack, ConntrackKey}; +use crate::expr::{ + Bitwise, Cmp, CmpOp, HighLevelPayload, IPv4HeaderField, IPv6HeaderField, Immediate, Meta, + MetaType, NetworkHeaderField, TCPHeaderField, TransportHeaderField, UDPHeaderField, + VerdictKind, +}; +use crate::Rule; /// Simple protocol description. Note that it does not implement other layer 4 protocols as /// IGMP et al. See [`Rule::igmp`] for a workaround. -#[derive(Debug, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum Protocol { TCP, - UDP + UDP, } -/// A RuleMethods trait over [`crate::Rule`], to make it match some criteria, and give it a -/// verdict. Mostly adapted from [talpid-core's firewall]. -/// All methods return the rule itself, allowing them to be chained. Usage example : -/// ```rust -/// use rustables::{Batch, Chain, ChainMethods, Protocol, ProtoFamily, Rule, RuleMethods, Table, MsgType, Hook}; -/// use std::ffi::CString; -/// use std::rc::Rc; -/// let table = Rc::new(Table::new(&CString::new("main_table").unwrap(), ProtoFamily::Inet)); -/// let mut batch = Batch::new(); -/// batch.add(&table, MsgType::Add); -/// let inbound = Rc::new(Chain::from_hook(Hook::In, Rc::clone(&table)) -/// .add_to_batch(&mut batch)); -/// let rule = Rule::new(inbound) -/// .dport("80", &Protocol::TCP).unwrap() -/// .accept() -/// .add_to_batch(&mut batch); -/// ``` -/// [talpid-core's firewall]: -/// https://github.com/mullvad/mullvadvpn-app/blob/d92376b4d1df9b547930c68aa9bae9640ff2a022/talpid-core/src/firewall/linux.rs -pub trait RuleMethods { - /// Matches ICMP packets. - fn icmp(self) -> Self; - /// Matches IGMP packets. - fn igmp(self) -> Self; - /// Matches packets to destination `port` and `protocol`. - fn dport(self, port: &str, protocol: &Protocol) -> Result<Self, Error> - where Self: std::marker::Sized; - /// Matches packets on `protocol`. - fn protocol(self, protocol: Protocol) -> Result<Self, Error> - where Self: std::marker::Sized; - /// Matches packets in an already established connection. - fn established(self) -> Self where Self: std::marker::Sized; - /// Matches packets going through `iface_index`. Interface indexes can be queried with - /// `iface_index()`. - fn iface_id(self, iface_index: libc::c_uint) -> Result<Self, Error> - where Self: std::marker::Sized; - /// Matches packets going through `iface_name`, an interface name, as in "wlan0" or "lo". - fn iface(self, iface_name: &str) -> Result<Self, Error> - where Self: std::marker::Sized; - /// Adds a log instruction to the rule. `group` is the NFLog group, `prefix` is a prefix - /// appended to each log line. - fn log(self, group: Option<LogGroup>, prefix: Option<LogPrefix>) -> Self; - /// Matches packets whose source IP address is `saddr`. - fn saddr(self, ip: IpAddr) -> Self; - /// Matches packets whose source network is `snet`. - fn snetwork(self, ip: IpNetwork) -> Self; - /// Adds the `Accept` verdict to the rule. The packet will be sent to destination. - fn accept(self) -> Self; - /// Adds the `Drop` verdict to the rule. The packet will be dropped. - fn drop(self) -> Self; - /// Appends this rule to `batch`. - fn add_to_batch(self, batch: &mut Batch) -> Self; -} - -/// A trait to add helper functions to match some criterium over `crate::Rule`. -impl RuleMethods for Rule { - fn icmp(mut self) -> Self { - self.add_expr(&nft_expr!(meta l4proto)); - //self.add_expr(&nft_expr!(cmp == libc::IPPROTO_ICMPV6 as u8)); - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_ICMP as u8)); - self - } - fn igmp(mut self) -> Self { - self.add_expr(&nft_expr!(meta l4proto)); - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_IGMP as u8)); +impl Rule { + fn match_port(mut self, port: u16, protocol: Protocol, source: bool) -> Self { + self = self.protocol(protocol); + self.add_expr( + HighLevelPayload::Transport(match protocol { + Protocol::TCP => TransportHeaderField::Tcp(if source { + TCPHeaderField::Sport + } else { + TCPHeaderField::Dport + }), + Protocol::UDP => TransportHeaderField::Udp(if source { + UDPHeaderField::Sport + } else { + UDPHeaderField::Dport + }), + }) + .build(), + ); + self.add_expr(Cmp::new(CmpOp::Eq, port.to_be_bytes())); self } - fn dport(mut self, port: &str, protocol: &Protocol) -> Result<Self, Error> { - self.add_expr(&nft_expr!(meta l4proto)); - match protocol { - &Protocol::TCP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_TCP as u8)); - self.add_expr(&nft_expr!(payload tcp dport)); - }, - &Protocol::UDP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_UDP as u8)); - self.add_expr(&nft_expr!(payload udp dport)); - } - } - // Convert the port to Big-Endian number spelling. - // See https://github.com/mullvad/mullvadvpn-app/blob/d92376b4d1df9b547930c68aa9bae9640ff2a022/talpid-core/src/firewall/linux.rs#L969 - self.add_expr(&nft_expr!(cmp == port.parse::<u16>()?.to_be())); - Ok(self) - } - fn protocol(mut self, protocol: Protocol) -> Result<Self, Error> { - self.add_expr(&nft_expr!(meta l4proto)); - match protocol { - Protocol::TCP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_TCP as u8)); - }, - Protocol::UDP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_UDP as u8)); - } - } - Ok(self) - } - fn established(mut self) -> Self { - let allowed_states = crate::expr::ct::States::ESTABLISHED.bits(); - self.add_expr(&nft_expr!(ct state)); - self.add_expr(&nft_expr!(bitwise mask allowed_states, xor 0u32)); - self.add_expr(&nft_expr!(cmp != 0u32)); - self - } - fn iface_id(mut self, iface_index: libc::c_uint) -> Result<Self, Error> { - self.add_expr(&nft_expr!(meta iif)); - self.add_expr(&nft_expr!(cmp == iface_index)); - Ok(self) - } - fn iface(mut self, iface_name: &str) -> Result<Self, Error> { - if iface_name.len() >= libc::IFNAMSIZ { - return Err(Error::NameTooLong); - } - let mut name_arr = [0u8; libc::IFNAMSIZ]; - for (pos, i) in iface_name.bytes().enumerate() { - name_arr[pos] = i; - } - self.add_expr(&nft_expr!(meta iifname)); - self.add_expr(&nft_expr!(cmp == name_arr.as_ref())); - Ok(self) - } - fn saddr(mut self, ip: IpAddr) -> Self { - self.add_expr(&nft_expr!(meta nfproto)); + pub fn match_ip(mut self, ip: IpAddr, source: bool) -> Self { + self.add_expr(Meta::new(MetaType::NfProto)); match ip { IpAddr::V4(addr) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV4 as u8)); - self.add_expr(&nft_expr!(payload ipv4 saddr)); - self.add_expr(&nft_expr!(cmp == addr)) - }, + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV4 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv4(if source { + IPv4HeaderField::Saddr + } else { + IPv4HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Cmp::new(CmpOp::Eq, addr.octets())); + } IpAddr::V6(addr) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV6 as u8)); - self.add_expr(&nft_expr!(payload ipv6 saddr)); - self.add_expr(&nft_expr!(cmp == addr)) + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv6(if source { + IPv6HeaderField::Saddr + } else { + IPv6HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Cmp::new(CmpOp::Eq, addr.octets())); } } self } - fn snetwork(mut self, net: IpNetwork) -> Self { - self.add_expr(&nft_expr!(meta nfproto)); + + pub fn match_network(mut self, net: IpNetwork, source: bool) -> Result<Self, BuilderError> { + self.add_expr(Meta::new(MetaType::NfProto)); match net { IpNetwork::V4(_) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV4 as u8)); - self.add_expr(&nft_expr!(payload ipv4 saddr)); - self.add_expr(&nft_expr!(bitwise mask net.mask(), xor 0u32)); - self.add_expr(&nft_expr!(cmp == net.network())); - }, + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV4 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv4(if source { + IPv4HeaderField::Saddr + } else { + IPv4HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Bitwise::new(ip_to_vec(net.mask()), 0u32.to_be_bytes())?); + } IpNetwork::V6(_) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV6 as u8)); - self.add_expr(&nft_expr!(payload ipv6 saddr)); - self.add_expr(&nft_expr!(bitwise mask net.mask(), xor &[0u16; 8][..])); - self.add_expr(&nft_expr!(cmp == net.network())); + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv6(if source { + IPv6HeaderField::Saddr + } else { + IPv6HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Bitwise::new(ip_to_vec(net.mask()), 0u128.to_be_bytes())?); } } + self.add_expr(Cmp::new(CmpOp::Eq, ip_to_vec(net.network()))); + Ok(self) + } +} + +impl Rule { + /// Matches ICMP packets. + pub fn icmp(mut self) -> Self { + // quid of icmpv6? + self.add_expr(Meta::new(MetaType::L4Proto)); + self.add_expr(Cmp::new(CmpOp::Eq, [libc::IPPROTO_ICMP as u8])); self } - fn log(mut self, group: Option<LogGroup>, prefix: Option<LogPrefix>) -> Self { - match (group.is_some(), prefix.is_some()) { - (true, true) => { - self.add_expr(&nft_expr!(log group group prefix prefix)); - }, - (false, true) => { - self.add_expr(&nft_expr!(log prefix prefix)); - }, - (true, false) => { - self.add_expr(&nft_expr!(log group group)); - }, - (false, false) => { - self.add_expr(&nft_expr!(log)); - } - } + /// Matches IGMP packets. + pub fn igmp(mut self) -> Self { + self.add_expr(Meta::new(MetaType::L4Proto)); + self.add_expr(Cmp::new(CmpOp::Eq, [libc::IPPROTO_IGMP as u8])); self } - fn accept(mut self) -> Self { - self.add_expr(&nft_expr!(verdict accept)); + /// Matches packets from source `port` and `protocol`. + pub fn sport(self, port: u16, protocol: Protocol) -> Self { + self.match_port(port, protocol, false) + } + /// Matches packets to destination `port` and `protocol`. + pub fn dport(self, port: u16, protocol: Protocol) -> Self { + self.match_port(port, protocol, false) + } + /// Matches packets on `protocol`. + pub fn protocol(mut self, protocol: Protocol) -> Self { + self.add_expr(Meta::new(MetaType::L4Proto)); + self.add_expr(Cmp::new( + CmpOp::Eq, + [match protocol { + Protocol::TCP => libc::IPPROTO_TCP, + Protocol::UDP => libc::IPPROTO_UDP, + } as u8], + )); + self + } + /// Matches packets in an already established connection. + pub fn established(mut self) -> Result<Self, BuilderError> { + let allowed_states = ConnTrackState::ESTABLISHED.bits(); + self.add_expr(Conntrack::new(ConntrackKey::State)); + self.add_expr(Bitwise::new( + allowed_states.to_le_bytes(), + 0u32.to_be_bytes(), + )?); + self.add_expr(Cmp::new(CmpOp::Neq, 0u32.to_be_bytes())); + Ok(self) + } + /// Matches packets going through `iface_index`. Interface indexes can be queried with + /// `iface_index()`. + pub fn iface_id(mut self, iface_index: libc::c_uint) -> Self { + self.add_expr(Meta::new(MetaType::Iif)); + self.add_expr(Cmp::new(CmpOp::Eq, iface_index.to_be_bytes())); self } - fn drop(mut self) -> Self { - self.add_expr(&nft_expr!(verdict drop)); + /// Matches packets going through `iface_name`, an interface name, as in "wlan0" or "lo" + pub fn iface(mut self, iface_name: &str) -> Result<Self, BuilderError> { + if iface_name.len() >= libc::IFNAMSIZ { + return Err(BuilderError::InterfaceNameTooLong); + } + let mut iface_vec = iface_name.as_bytes().to_vec(); + // null terminator + iface_vec.push(0u8); + + self.add_expr(Meta::new(MetaType::IifName)); + self.add_expr(Cmp::new(CmpOp::Eq, iface_vec)); + Ok(self) + } + /// Matches packets whose source IP address is `saddr`. + pub fn saddr(self, ip: IpAddr) -> Self { + self.match_ip(ip, true) + } + /// Matches packets whose destination IP address is `saddr`. + pub fn daddr(self, ip: IpAddr) -> Self { + self.match_ip(ip, false) + } + /// Matches packets whose source network is `net`. + pub fn snetwork(self, net: IpNetwork) -> Result<Self, BuilderError> { + self.match_network(net, true) + } + /// Matches packets whose destination network is `net`. + pub fn dnetwork(self, net: IpNetwork) -> Result<Self, BuilderError> { + self.match_network(net, false) + } + /// Adds the `Accept` verdict to the rule. The packet will be sent to destination. + pub fn accept(mut self) -> Self { + self.add_expr(Immediate::new_verdict(VerdictKind::Accept)); self } - fn add_to_batch(self, batch: &mut Batch) -> Self { - batch.add(&self, crate::MsgType::Add); + /// Adds the `Drop` verdict to the rule. The packet will be dropped. + pub fn drop(mut self) -> Self { + self.add_expr(Immediate::new_verdict(VerdictKind::Drop)); self } } /// Looks up the interface index for a given interface name. -pub fn iface_index(name: &str) -> Result<libc::c_uint, Error> { +pub fn iface_index(name: &str) -> Result<libc::c_uint, std::io::Error> { let c_name = CString::new(name)?; let index = unsafe { libc::if_nametoindex(c_name.as_ptr()) }; match index { - 0 => Err(Error::NoSuchIface), - _ => Ok(index) + 0 => Err(std::io::Error::last_os_error()), + _ => Ok(index), } } - - @@ -55,11 +55,10 @@ pub struct SetBuilder<K: DataType> { } impl<K: DataType> SetBuilder<K> { - pub fn new(name: impl Into<String>, id: u32, table: &Table) -> Result<Self, BuilderError> { + pub fn new(name: impl Into<String>, table: &Table) -> Result<Self, BuilderError> { let table_name = table.get_name().ok_or(BuilderError::MissingTableName)?; let set_name = name.into(); let set = Set::default() - .with_id(id) .with_key_type(K::TYPE) .with_key_len(K::LEN) .with_table(table_name) diff --git a/src/table.rs b/src/table.rs index 63bf669..81a26ef 100644 --- a/src/table.rs +++ b/src/table.rs @@ -8,7 +8,7 @@ use crate::sys::{ NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFTA_TABLE_USERDATA, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, NFT_MSG_NEWTABLE, }; -use crate::ProtocolFamily; +use crate::{Batch, ProtocolFamily}; /// Abstraction of a `nftnl_table`, the top level container in netfilter. A table has a protocol /// family and contains [`Chain`]s that in turn hold the rules. @@ -32,6 +32,12 @@ impl Table { res.family = family; res } + + /// Appends this rule to `batch` + pub fn add_to_batch(self, batch: &mut Batch) -> Self { + batch.add(&self, crate::MsgType::Add); + self + } } impl NfNetlinkObject for Table { diff --git a/src/tests/expr.rs b/src/tests/expr.rs index 141f6ac..35c4fea 100644 --- a/src/tests/expr.rs +++ b/src/tests/expr.rs @@ -5,21 +5,23 @@ use libc::NF_DROP; use crate::{ expr::{ Bitwise, Cmp, CmpOp, Conntrack, ConntrackKey, Counter, ExpressionList, HeaderField, - HighLevelPayload, IcmpCode, Immediate, Log, Masquerade, Meta, MetaType, Nat, NatType, - Register, Reject, RejectType, TCPHeaderField, TransportHeaderField, VerdictKind, + HighLevelPayload, IcmpCode, Immediate, Log, Lookup, Masquerade, Meta, MetaType, Nat, + NatType, Register, Reject, RejectType, TCPHeaderField, TransportHeaderField, VerdictKind, }, + set::SetBuilder, sys::{ NFTA_BITWISE_DREG, NFTA_BITWISE_LEN, NFTA_BITWISE_MASK, NFTA_BITWISE_SREG, NFTA_BITWISE_XOR, NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFTA_COUNTER_BYTES, NFTA_COUNTER_PACKETS, NFTA_CT_DREG, NFTA_CT_KEY, NFTA_DATA_VALUE, NFTA_DATA_VERDICT, NFTA_EXPR_DATA, NFTA_EXPR_NAME, NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG, NFTA_LIST_ELEM, - NFTA_LOG_GROUP, NFTA_LOG_PREFIX, NFTA_META_DREG, NFTA_META_KEY, NFTA_NAT_FAMILY, - NFTA_NAT_REG_ADDR_MIN, NFTA_NAT_TYPE, NFTA_PAYLOAD_BASE, NFTA_PAYLOAD_DREG, - NFTA_PAYLOAD_LEN, NFTA_PAYLOAD_OFFSET, NFTA_REJECT_ICMP_CODE, NFTA_REJECT_TYPE, - NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, NFTA_RULE_TABLE, NFTA_VERDICT_CODE, NFT_CMP_EQ, - NFT_CT_STATE, NFT_META_PROTOCOL, NFT_NAT_SNAT, NFT_PAYLOAD_TRANSPORT_HEADER, NFT_REG_1, - NFT_REG_VERDICT, NFT_REJECT_ICMPX_UNREACH, + NFTA_LOG_GROUP, NFTA_LOG_PREFIX, NFTA_LOOKUP_SET, NFTA_LOOKUP_SREG, NFTA_META_DREG, + NFTA_META_KEY, NFTA_NAT_FAMILY, NFTA_NAT_REG_ADDR_MIN, NFTA_NAT_TYPE, NFTA_PAYLOAD_BASE, + NFTA_PAYLOAD_DREG, NFTA_PAYLOAD_LEN, NFTA_PAYLOAD_OFFSET, NFTA_REJECT_ICMP_CODE, + NFTA_REJECT_TYPE, NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, NFTA_RULE_TABLE, + NFTA_VERDICT_CODE, NFT_CMP_EQ, NFT_CT_STATE, NFT_META_PROTOCOL, NFT_NAT_SNAT, + NFT_PAYLOAD_TRANSPORT_HEADER, NFT_REG_1, NFT_REG_VERDICT, NFT_REJECT_ICMPX_UNREACH, }, + tests::{get_test_table, SET_NAME}, ProtocolFamily, }; @@ -283,39 +285,40 @@ fn log_expr_is_valid() { ); } -/* #[test] fn lookup_expr_is_valid() { - let set_name = &CStr::from_bytes_with_nul(b"mockset\0").unwrap(); - let mut rule = get_test_rule(); - let table = rule.get_chain().get_table(); - let mut set = Set::new(set_name, 0, table); + let table = get_test_table(); + let mut set_builder = SetBuilder::new(SET_NAME, &table).unwrap(); let address: Ipv4Addr = [8, 8, 8, 8].into(); - set.add(&address); + set_builder.add(&address); + let (set, _set_elements) = set_builder.finish(); let lookup = Lookup::new(&set).unwrap(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &lookup); - assert_eq!(nlmsghdr.nlmsg_len, 104); + + let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(lookup)); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 96); assert_eq!( raw_expr, NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), NetlinkExpr::Nested( NFTA_RULE_EXPRESSIONS, vec![NetlinkExpr::Nested( NFTA_LIST_ELEM, vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"lookup\0".to_vec()), + NetlinkExpr::Final(NFTA_EXPR_NAME, b"lookup".to_vec()), NetlinkExpr::Nested( NFTA_EXPR_DATA, vec![ + NetlinkExpr::Final(NFTA_LOOKUP_SET, b"mockset".to_vec()), NetlinkExpr::Final( NFTA_LOOKUP_SREG, NFT_REG_1.to_be_bytes().to_vec() ), - NetlinkExpr::Final(NFTA_LOOKUP_SET, b"mockset\0".to_vec()), - NetlinkExpr::Final(NFTA_LOOKUP_SET_ID, 0u32.to_be_bytes().to_vec()), ] ) ] @@ -325,7 +328,6 @@ fn lookup_expr_is_valid() { .to_raw() ); } -*/ #[test] fn masquerade_expr_is_valid() { diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 3693d35..75fe8b0 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -20,8 +20,6 @@ pub const CHAIN_USERDATA: &'static str = "mockchaindata"; pub const RULE_USERDATA: &'static str = "mockruledata"; pub const SET_USERDATA: &'static str = "mocksetdata"; -pub const SET_ID: u32 = 123456; - type NetLinkType = u16; #[derive(Debug, thiserror::Error)] @@ -157,7 +155,7 @@ pub fn get_test_rule() -> Rule { } pub fn get_test_set<K: DataType>() -> Set { - SetBuilder::<K>::new(SET_NAME, SET_ID, &get_test_table()) + SetBuilder::<K>::new(SET_NAME, &get_test_table()) .expect("Couldn't create a set") .finish() .0 diff --git a/src/tests/set.rs b/src/tests/set.rs index db27ced..6c8247c 100644 --- a/src/tests/set.rs +++ b/src/tests/set.rs @@ -6,16 +6,16 @@ use crate::{ set::SetBuilder, sys::{ NFTA_DATA_VALUE, NFTA_LIST_ELEM, NFTA_SET_ELEM_KEY, NFTA_SET_ELEM_LIST_ELEMENTS, - NFTA_SET_ELEM_LIST_SET, NFTA_SET_ELEM_LIST_TABLE, NFTA_SET_ID, NFTA_SET_KEY_LEN, - NFTA_SET_KEY_TYPE, NFTA_SET_NAME, NFTA_SET_TABLE, NFTA_SET_USERDATA, NFT_MSG_DELSET, - NFT_MSG_NEWSET, NFT_MSG_NEWSETELEM, + NFTA_SET_ELEM_LIST_SET, NFTA_SET_ELEM_LIST_TABLE, NFTA_SET_KEY_LEN, NFTA_SET_KEY_TYPE, + NFTA_SET_NAME, NFTA_SET_TABLE, NFTA_SET_USERDATA, NFT_MSG_DELSET, NFT_MSG_NEWSET, + NFT_MSG_NEWSETELEM, }, MsgType, }; use super::{ get_test_nlmsg, get_test_nlmsg_with_msg_type, get_test_set, get_test_table, NetlinkExpr, - SET_ID, SET_NAME, SET_USERDATA, TABLE_NAME, + SET_NAME, SET_USERDATA, TABLE_NAME, }; #[test] @@ -28,7 +28,7 @@ fn new_empty_set() { get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), NFT_MSG_NEWSET as u8 ); - assert_eq!(nlmsghdr.nlmsg_len, 88); + assert_eq!(nlmsghdr.nlmsg_len, 80); assert_eq!( raw_expr, @@ -37,7 +37,6 @@ fn new_empty_set() { NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.as_bytes().to_vec()), NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv4Addr::TYPE.to_be_bytes().to_vec()), NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv4Addr::LEN.to_be_bytes().to_vec()), - NetlinkExpr::Final(NFTA_SET_ID, SET_ID.to_be_bytes().to_vec()), NetlinkExpr::Final(NFTA_SET_USERDATA, SET_USERDATA.as_bytes().to_vec()), ]) .to_raw() @@ -55,7 +54,7 @@ fn delete_empty_set() { get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), NFT_MSG_DELSET as u8 ); - assert_eq!(nlmsghdr.nlmsg_len, 88); + assert_eq!(nlmsghdr.nlmsg_len, 80); assert_eq!( raw_expr, @@ -64,7 +63,6 @@ fn delete_empty_set() { NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.as_bytes().to_vec()), NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv6Addr::TYPE.to_be_bytes().to_vec()), NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv6Addr::LEN.to_be_bytes().to_vec()), - NetlinkExpr::Final(NFTA_SET_ID, SET_ID.to_be_bytes().to_vec()), NetlinkExpr::Final(NFTA_SET_USERDATA, SET_USERDATA.as_bytes().to_vec()), ]) .to_raw() @@ -75,9 +73,8 @@ fn delete_empty_set() { fn new_set_with_data() { let ip1 = Ipv4Addr::new(127, 0, 0, 1); let ip2 = Ipv4Addr::new(1, 1, 1, 1); - let mut set_builder = - SetBuilder::<Ipv4Addr>::new(SET_NAME.to_string(), SET_ID, &get_test_table()) - .expect("Couldn't create a set"); + let mut set_builder = SetBuilder::<Ipv4Addr>::new(SET_NAME.to_string(), &get_test_table()) + .expect("Couldn't create a set"); set_builder.add(&ip1); set_builder.add(&ip2); |