From 9d55957b737cce1ace694376bdd157624286c020 Mon Sep 17 00:00:00 2001 From: Himbeer Date: Thu, 20 Mar 2025 14:18:55 +0100 Subject: Completely redesign rule API Rules require special representation in the type system due to the way rtnetlink exposes them. This is inconsistent with the design of the other APIs. The new API compiles. --- src/blocking.rs | 34 +++++++++++++-- src/error.rs | 2 + src/rule.rs | 130 +++++++++++++++++++++++++++++++++++--------------------- 3 files changed, 115 insertions(+), 51 deletions(-) diff --git a/src/blocking.rs b/src/blocking.rs index 7c47b5b..04b15c8 100644 --- a/src/blocking.rs +++ b/src/blocking.rs @@ -112,10 +112,38 @@ pub mod route { pub mod rule { use super::Connection; + use std::net::{Ipv4Addr, Ipv6Addr}; + use crate::rule::Rule; + use crate::Result; - impl Connection { - blockify!(rule_add, r: Rule); - blockify!(rule_del, r: Rule); + impl Rule<()> { + pub fn blocking_add(self, c: &Connection) -> Result<()> { + c.rt.block_on(self.add(&c.conn)) + } + + pub fn blocking_del(self, c: &Connection) -> Result<()> { + c.rt.block_on(self.del(&c.conn)) + } + } + + impl Rule { + pub fn blocking_add(self, c: &Connection) -> Result<()> { + c.rt.block_on(self.add(&c.conn)) + } + + pub fn blocking_del(self, c: &Connection) -> Result<()> { + c.rt.block_on(self.del(&c.conn)) + } + } + + impl Rule { + pub fn blocking_add(self, c: &Connection) -> Result<()> { + c.rt.block_on(self.add(&c.conn)) + } + + pub fn blocking_del(self, c: &Connection) -> Result<()> { + c.rt.block_on(self.del(&c.conn)) + } } } diff --git a/src/error.rs b/src/error.rs index d0bc7e7..cb2bd32 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,6 +6,8 @@ use thiserror::Error; pub enum Error { #[error("link {0} not found")] LinkNotFound(String), + #[error("prefix matching in protocol-agnostic route")] + PrefixesDisallowed, #[error("link name contains nul bytes: {0}")] Nul(#[from] ffi::NulError), diff --git a/src/rule.rs b/src/rule.rs index 044c067..bd68c73 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -4,26 +4,22 @@ use crate::{Connection, Error, Result}; pub use netlink_packet_route::rule::RuleAction; -use netlink_packet_route::rule::{RuleAttribute, RuleFlag, RuleHeader, RuleMessage}; -use rtnetlink::RuleAddRequest; - -trait IpAddr46 {} +use std::net::{Ipv4Addr, Ipv6Addr}; -impl IpAddr46 for () {} -impl IpAddr46 for Ipv4Addr {} -impl IpAddr46 for Ipv6Addr {} +use netlink_packet_route::rule::RuleFlag; +use rtnetlink::RuleAddRequest; /// A rule entry. #[derive(Debug)] -pub struct Rule { +pub struct Rule { /// Whether to invert the matching criteria. pub invert: bool, /// Firewall mark to match against. pub fwmark: Option, /// Destination prefix to match against. - pub dst: Option<(A, u8)>, + pub dst: Option<(T, u8)>, /// Source prefix to match against. - pub src: Option<(A, u8)>, + pub src: Option<(T, u8)>, /// Action to perform. pub action: RuleAction, /// Routing table to use if `RuleAction::ToTable` is selected. @@ -31,75 +27,93 @@ pub struct Rule { } impl Rule<()> { - fn addSrcDst(&self, rq: RuleAddRequest) -> RuleAddRequest { - rq + pub async fn add(self, c: &Connection) -> Result<()> { + let add = self.prepare_add(c); + + if self.dst.is_some() || self.src.is_some() { + return Err(Error::PrefixesDisallowed); + } + + add.execute().await?; + Ok(()) + } + + pub async fn del(self, c: &Connection) -> Result<()> { + let mut add = self.prepare_add(c); + + if self.dst.is_some() || self.src.is_some() { + return Err(Error::PrefixesDisallowed); + } + + c.handle() + .rule() + .del(add.message_mut().clone()) + .execute() + .await?; + Ok(()) } } impl Rule { - fn addSrcDst(&self, mut rq: RuleAddRequest) -> RuleAddRequest { - rq = rq.v4(); + pub async fn add(self, c: &Connection) -> Result<()> { + let mut add = self.prepare_add(c).v4(); + if let Some(dst) = self.dst { - rq = rq.destination_prefix(dst.0, dst.1); + add = add.destination_prefix(dst.0, dst.1); } if let Some(src) = self.src { - rq = rq.destination_prefix(src.0, src.1); + add = add.source_prefix(src.0, src.1) } - rq + add.execute().await?; + Ok(()) } -} -impl Rule { - fn addSrcDst(&self, mut rq: RuleAddRequest) -> RuleAddRequest { - rq = rq.v6(); + pub async fn del(self, c: &Connection) -> Result<()> { + let mut add = self.prepare_add(c).v4(); + if let Some(dst) = self.dst { - rq = rq.destination_prefix(dst.0, dst.1); + add = add.destination_prefix(dst.0, dst.1); } if let Some(src) = self.src { - rq = rq.destination_prefix(src.0, src.1); + add = add.source_prefix(src.0, src.1) } - rq + c.handle() + .rule() + .del(add.message_mut().clone()) + .execute() + .await?; + Ok(()) } } -impl Connection { - /// Adds a rule entry. - pub async fn rule_add(&self, r: Rule) -> Result<()> { - let mut add = self.handle().rule().add().action(r.action); +impl Rule { + pub async fn add(self, c: &Connection) -> Result<()> { + let mut add = self.prepare_add(c).v6(); - if let Some(fwmark) = r.fwmark { - add = add.fw_mark(fwmark); + if let Some(dst) = self.dst { + add = add.destination_prefix(dst.0, dst.1); } - if let Some(table) = r.table { - add = add.table_id(table); + if let Some(src) = self.src { + add = add.source_prefix(src.0, src.1) } - add = r.addSrcDst(add); - - add.message_mut().header.flags.push(RuleFlag::Invert); - add.execute().await?; Ok(()) } - /// Deletes a rule entry. - pub async fn rule_del(&self, r: Rule) -> Result<()> { - let mut add = self.handle().rule().add().action(r.action); + pub async fn del(self, c: &Connection) -> Result<()> { + let mut add = self.prepare_add(c).v6(); - if let Some(fwmark) = r.fwmark { - add = add.fw_mark(fwmark); + if let Some(dst) = self.dst { + add = add.destination_prefix(dst.0, dst.1); } - if let Some(table) = r.table { - add = add.table_id(table); + if let Some(src) = self.src { + add = add.source_prefix(src.0, src.1) } - add = r.addSrcDst(add); - - add.message_mut().header.flags.push(RuleFlag::Invert); - - self.handle() + c.handle() .rule() .del(add.message_mut().clone()) .execute() @@ -107,3 +121,23 @@ impl Connection { Ok(()) } } + +impl Rule { + fn prepare_add(&self, c: &Connection) -> RuleAddRequest { + let mut add = c.handle().rule().add().action(self.action); + + if self.invert { + add.message_mut().header.flags.push(RuleFlag::Invert); + } + + if let Some(fwmark) = self.fwmark { + add = add.fw_mark(fwmark) + } + + if self.action == RuleAction::ToTable { + add = add.table_id(self.table) + } + + add + } +} -- cgit v1.2.3