diff options
-rw-r--r-- | src/blocking.rs | 34 | ||||
-rw-r--r-- | src/error.rs | 2 | ||||
-rw-r--r-- | src/rule.rs | 130 |
3 files changed, 115 insertions, 51 deletions
diff --git a/src/blocking.rs b/src/blocking.rs index 7c47b5b..04b15c8 100644 --- a/src/blocking.rs +++ b/src/blocking.rs @@ -112,10 +112,38 @@ pub mod route { pub mod rule { use super::Connection; + use std::net::{Ipv4Addr, Ipv6Addr}; + use crate::rule::Rule; + use crate::Result; - impl Connection { - blockify!(rule_add, r: Rule); - blockify!(rule_del, r: Rule); + impl Rule<()> { + pub fn blocking_add(self, c: &Connection) -> Result<()> { + c.rt.block_on(self.add(&c.conn)) + } + + pub fn blocking_del(self, c: &Connection) -> Result<()> { + c.rt.block_on(self.del(&c.conn)) + } + } + + impl Rule<Ipv4Addr> { + pub fn blocking_add(self, c: &Connection) -> Result<()> { + c.rt.block_on(self.add(&c.conn)) + } + + pub fn blocking_del(self, c: &Connection) -> Result<()> { + c.rt.block_on(self.del(&c.conn)) + } + } + + impl Rule<Ipv6Addr> { + pub fn blocking_add(self, c: &Connection) -> Result<()> { + c.rt.block_on(self.add(&c.conn)) + } + + pub fn blocking_del(self, c: &Connection) -> Result<()> { + c.rt.block_on(self.del(&c.conn)) + } } } diff --git a/src/error.rs b/src/error.rs index d0bc7e7..cb2bd32 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,6 +6,8 @@ use thiserror::Error; pub enum Error { #[error("link {0} not found")] LinkNotFound(String), + #[error("prefix matching in protocol-agnostic route")] + PrefixesDisallowed, #[error("link name contains nul bytes: {0}")] Nul(#[from] ffi::NulError), diff --git a/src/rule.rs b/src/rule.rs index 044c067..bd68c73 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -4,26 +4,22 @@ use crate::{Connection, Error, Result}; pub use netlink_packet_route::rule::RuleAction; -use netlink_packet_route::rule::{RuleAttribute, RuleFlag, RuleHeader, RuleMessage}; -use rtnetlink::RuleAddRequest; - -trait IpAddr46 {} +use std::net::{Ipv4Addr, Ipv6Addr}; -impl IpAddr46 for () {} -impl IpAddr46 for Ipv4Addr {} -impl IpAddr46 for Ipv6Addr {} +use netlink_packet_route::rule::RuleFlag; +use rtnetlink::RuleAddRequest; /// A rule entry. #[derive(Debug)] -pub struct Rule<A: IpAddr46> { +pub struct Rule<T> { /// Whether to invert the matching criteria. pub invert: bool, /// Firewall mark to match against. pub fwmark: Option<u32>, /// Destination prefix to match against. - pub dst: Option<(A, u8)>, + pub dst: Option<(T, u8)>, /// Source prefix to match against. - pub src: Option<(A, u8)>, + pub src: Option<(T, u8)>, /// Action to perform. pub action: RuleAction, /// Routing table to use if `RuleAction::ToTable` is selected. @@ -31,75 +27,93 @@ pub struct Rule<A: IpAddr46> { } impl Rule<()> { - fn addSrcDst(&self, rq: RuleAddRequest) -> RuleAddRequest { - rq + pub async fn add(self, c: &Connection) -> Result<()> { + let add = self.prepare_add(c); + + if self.dst.is_some() || self.src.is_some() { + return Err(Error::PrefixesDisallowed); + } + + add.execute().await?; + Ok(()) + } + + pub async fn del(self, c: &Connection) -> Result<()> { + let mut add = self.prepare_add(c); + + if self.dst.is_some() || self.src.is_some() { + return Err(Error::PrefixesDisallowed); + } + + c.handle() + .rule() + .del(add.message_mut().clone()) + .execute() + .await?; + Ok(()) } } impl Rule<Ipv4Addr> { - fn addSrcDst(&self, mut rq: RuleAddRequest) -> RuleAddRequest { - rq = rq.v4(); + pub async fn add(self, c: &Connection) -> Result<()> { + let mut add = self.prepare_add(c).v4(); + if let Some(dst) = self.dst { - rq = rq.destination_prefix(dst.0, dst.1); + add = add.destination_prefix(dst.0, dst.1); } if let Some(src) = self.src { - rq = rq.destination_prefix(src.0, src.1); + add = add.source_prefix(src.0, src.1) } - rq + add.execute().await?; + Ok(()) } -} -impl Rule<Ipv6Addr> { - fn addSrcDst(&self, mut rq: RuleAddRequest) -> RuleAddRequest { - rq = rq.v6(); + pub async fn del(self, c: &Connection) -> Result<()> { + let mut add = self.prepare_add(c).v4(); + if let Some(dst) = self.dst { - rq = rq.destination_prefix(dst.0, dst.1); + add = add.destination_prefix(dst.0, dst.1); } if let Some(src) = self.src { - rq = rq.destination_prefix(src.0, src.1); + add = add.source_prefix(src.0, src.1) } - rq + c.handle() + .rule() + .del(add.message_mut().clone()) + .execute() + .await?; + Ok(()) } } -impl<A: IpAddr46> Connection { - /// Adds a rule entry. - pub async fn rule_add(&self, r: Rule<A>) -> Result<()> { - let mut add = self.handle().rule().add().action(r.action); +impl Rule<Ipv6Addr> { + pub async fn add(self, c: &Connection) -> Result<()> { + let mut add = self.prepare_add(c).v6(); - if let Some(fwmark) = r.fwmark { - add = add.fw_mark(fwmark); + if let Some(dst) = self.dst { + add = add.destination_prefix(dst.0, dst.1); } - if let Some(table) = r.table { - add = add.table_id(table); + if let Some(src) = self.src { + add = add.source_prefix(src.0, src.1) } - add = r.addSrcDst(add); - - add.message_mut().header.flags.push(RuleFlag::Invert); - add.execute().await?; Ok(()) } - /// Deletes a rule entry. - pub async fn rule_del(&self, r: Rule<A>) -> Result<()> { - let mut add = self.handle().rule().add().action(r.action); + pub async fn del(self, c: &Connection) -> Result<()> { + let mut add = self.prepare_add(c).v6(); - if let Some(fwmark) = r.fwmark { - add = add.fw_mark(fwmark); + if let Some(dst) = self.dst { + add = add.destination_prefix(dst.0, dst.1); } - if let Some(table) = r.table { - add = add.table_id(table); + if let Some(src) = self.src { + add = add.source_prefix(src.0, src.1) } - add = r.addSrcDst(add); - - add.message_mut().header.flags.push(RuleFlag::Invert); - - self.handle() + c.handle() .rule() .del(add.message_mut().clone()) .execute() @@ -107,3 +121,23 @@ impl<A: IpAddr46> Connection { Ok(()) } } + +impl<T> Rule<T> { + fn prepare_add(&self, c: &Connection) -> RuleAddRequest { + let mut add = c.handle().rule().add().action(self.action); + + if self.invert { + add.message_mut().header.flags.push(RuleFlag::Invert); + } + + if let Some(fwmark) = self.fwmark { + add = add.fw_mark(fwmark) + } + + if self.action == RuleAction::ToTable { + add = add.table_id(self.table) + } + + add + } +} |