aboutsummaryrefslogtreecommitdiff
path: root/src/chain.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/chain.rs')
-rw-r--r--src/chain.rs381
1 files changed, 151 insertions, 230 deletions
diff --git a/src/chain.rs b/src/chain.rs
index a942a37..37e4cb3 100644
--- a/src/chain.rs
+++ b/src/chain.rs
@@ -1,41 +1,85 @@
-use crate::{MsgType, Table};
-use crate::sys::{self as sys, libc};
-#[cfg(feature = "query")]
-use std::convert::TryFrom;
-use std::{
- ffi::{c_void, CStr, CString},
- fmt,
- os::raw::c_char,
- rc::Rc,
+use libc::{NF_ACCEPT, NF_DROP};
+use rustables_macros::nfnetlink_struct;
+
+use crate::error::{DecodeError, QueryError};
+use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject};
+use crate::sys::{
+ NFTA_CHAIN_FLAGS, NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_POLICY, NFTA_CHAIN_TABLE,
+ NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN,
+ NFT_MSG_NEWCHAIN,
};
+use crate::{Batch, ProtocolFamily, Table};
+use std::fmt::Debug;
-pub type Priority = i32;
+pub type ChainPriority = i32;
/// The netfilter event hooks a chain can register for.
-#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
-#[repr(u16)]
-pub enum Hook {
+#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
+#[repr(i32)]
+pub enum HookClass {
/// Hook into the pre-routing stage of netfilter. Corresponds to `NF_INET_PRE_ROUTING`.
- PreRouting = libc::NF_INET_PRE_ROUTING as u16,
+ PreRouting = libc::NF_INET_PRE_ROUTING,
/// Hook into the input stage of netfilter. Corresponds to `NF_INET_LOCAL_IN`.
- In = libc::NF_INET_LOCAL_IN as u16,
+ In = libc::NF_INET_LOCAL_IN,
/// Hook into the forward stage of netfilter. Corresponds to `NF_INET_FORWARD`.
- Forward = libc::NF_INET_FORWARD as u16,
+ Forward = libc::NF_INET_FORWARD,
/// Hook into the output stage of netfilter. Corresponds to `NF_INET_LOCAL_OUT`.
- Out = libc::NF_INET_LOCAL_OUT as u16,
+ Out = libc::NF_INET_LOCAL_OUT,
/// Hook into the post-routing stage of netfilter. Corresponds to `NF_INET_POST_ROUTING`.
- PostRouting = libc::NF_INET_POST_ROUTING as u16,
+ PostRouting = libc::NF_INET_POST_ROUTING,
+}
+
+#[derive(Clone, PartialEq, Eq, Default, Debug)]
+#[nfnetlink_struct(nested = true)]
+pub struct Hook {
+ /// Define the action netfilter will apply to packets processed by this chain, but that did not match any rules in it.
+ #[field(NFTA_HOOK_HOOKNUM)]
+ class: u32,
+ #[field(NFTA_HOOK_PRIORITY)]
+ priority: u32,
+}
+
+impl Hook {
+ pub fn new(class: HookClass, priority: ChainPriority) -> Self {
+ Hook::default()
+ .with_class(class as u32)
+ .with_priority(priority as u32)
+ }
}
/// A chain policy. Decides what to do with a packet that was processed by the chain but did not
/// match any rules.
-#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
-#[repr(u32)]
-pub enum Policy {
+#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
+#[repr(i32)]
+pub enum ChainPolicy {
/// Accept the packet.
- Accept = libc::NF_ACCEPT as u32,
+ Accept = NF_ACCEPT,
/// Drop the packet.
- Drop = libc::NF_DROP as u32,
+ Drop = NF_DROP,
+}
+
+impl NfNetlinkAttribute for ChainPolicy {
+ fn get_size(&self) -> usize {
+ (*self as i32).get_size()
+ }
+
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ (*self as i32).write_payload(addr);
+ }
+}
+
+impl NfNetlinkDeserializable for ChainPolicy {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ let (v, remaining_data) = i32::deserialize(buf)?;
+ Ok((
+ match v {
+ NF_ACCEPT => ChainPolicy::Accept,
+ NF_DROP => ChainPolicy::Accept,
+ _ => return Err(DecodeError::UnknownChainPolicy),
+ },
+ remaining_data,
+ ))
+ }
}
/// Base chain type.
@@ -53,240 +97,117 @@ pub enum ChainType {
}
impl ChainType {
- fn as_c_str(&self) -> &'static [u8] {
+ fn as_str(&self) -> &'static str {
match *self {
- ChainType::Filter => b"filter\0",
- ChainType::Route => b"route\0",
- ChainType::Nat => b"nat\0",
+ ChainType::Filter => "filter",
+ ChainType::Route => "route",
+ ChainType::Nat => "nat",
}
}
}
-/// Abstraction of a `nftnl_chain`. Chains reside inside [`Table`]s and they hold [`Rule`]s.
-///
-/// There are two types of chains, "base chain" and "regular chain". See [`set_hook`] for more
-/// details.
+impl NfNetlinkAttribute for ChainType {
+ fn get_size(&self) -> usize {
+ self.as_str().len()
+ }
+
+ unsafe fn write_payload(&self, addr: *mut u8) {
+ self.as_str().to_string().write_payload(addr);
+ }
+}
+
+impl NfNetlinkDeserializable for ChainType {
+ fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
+ let (s, remaining_data) = String::deserialize(buf)?;
+ Ok((
+ match s.as_str() {
+ "filter" => ChainType::Filter,
+ "route" => ChainType::Route,
+ "nat" => ChainType::Nat,
+ _ => return Err(DecodeError::UnknownChainType),
+ },
+ remaining_data,
+ ))
+ }
+}
+
+/// Abstraction over an nftable chain. Chains reside inside [`Table`]s and they hold [`Rule`]s.
///
/// [`Table`]: struct.Table.html
/// [`Rule`]: struct.Rule.html
-/// [`set_hook`]: #method.set_hook
+#[derive(PartialEq, Eq, Default, Debug)]
+#[nfnetlink_struct(derive_deserialize = false)]
pub struct Chain {
- pub(crate) chain: *mut sys::nftnl_chain,
- pub(crate) table: Rc<Table>,
+ family: ProtocolFamily,
+ #[field(NFTA_CHAIN_TABLE)]
+ table: String,
+ #[field(NFTA_CHAIN_NAME)]
+ name: String,
+ #[field(NFTA_CHAIN_HOOK)]
+ hook: Hook,
+ #[field(NFTA_CHAIN_POLICY)]
+ policy: ChainPolicy,
+ #[field(NFTA_CHAIN_TYPE, name_in_functions = "type")]
+ chain_type: ChainType,
+ #[field(NFTA_CHAIN_FLAGS)]
+ flags: u32,
+ #[field(NFTA_CHAIN_USERDATA)]
+ userdata: Vec<u8>,
}
impl Chain {
- /// Creates a new chain instance inside the given [`Table`] and with the given name.
+ /// Creates a new chain instance inside the given [`Table`].
///
/// [`Table`]: struct.Table.html
- pub fn new<T: AsRef<CStr>>(name: &T, table: Rc<Table>) -> Chain {
- unsafe {
- let chain = try_alloc!(sys::nftnl_chain_alloc());
- sys::nftnl_chain_set_u32(
- chain,
- sys::NFTNL_CHAIN_FAMILY as u16,
- table.get_family() as u32,
- );
- sys::nftnl_chain_set_str(
- chain,
- sys::NFTNL_CHAIN_TABLE as u16,
- table.get_name().as_ptr(),
- );
- sys::nftnl_chain_set_str(chain, sys::NFTNL_CHAIN_NAME as u16, name.as_ref().as_ptr());
- Chain { chain, table }
- }
- }
-
- pub unsafe fn from_raw(chain: *mut sys::nftnl_chain, table: Rc<Table>) -> Self {
- Chain { chain, table }
- }
+ pub fn new(table: &Table) -> Chain {
+ let mut chain = Chain::default();
+ chain.family = table.get_family();
- /// Sets the hook and priority for this chain. Without calling this method the chain will
- /// become a "regular chain" without any hook and will thus not receive any traffic unless
- /// some rule forward packets to it via goto or jump verdicts.
- ///
- /// By calling `set_hook` with a hook the chain that is created will be registered with that
- /// hook and is thus a "base chain". A "base chain" is an entry point for packets from the
- /// networking stack.
- pub fn set_hook(&mut self, hook: Hook, priority: Priority) {
- unsafe {
- sys::nftnl_chain_set_u32(self.chain, sys::NFTNL_CHAIN_HOOKNUM as u16, hook as u32);
- sys::nftnl_chain_set_s32(self.chain, sys::NFTNL_CHAIN_PRIO as u16, priority);
+ if let Some(table_name) = table.get_name() {
+ chain.set_table(table_name);
}
- }
- /// Set the type of a base chain. This only applies if the chain has been registered
- /// with a hook by calling `set_hook`.
- pub fn set_type(&mut self, chain_type: ChainType) {
- unsafe {
- sys::nftnl_chain_set_str(
- self.chain,
- sys::NFTNL_CHAIN_TYPE as u16,
- chain_type.as_c_str().as_ptr() as *const c_char,
- );
- }
+ chain
}
- /// Sets the default policy for this chain. That means what action netfilter will apply to
- /// packets processed by this chain, but that did not match any rules in it.
- pub fn set_policy(&mut self, policy: Policy) {
- unsafe {
- sys::nftnl_chain_set_u32(self.chain, sys::NFTNL_CHAIN_POLICY as u16, policy as u32);
- }
- }
-
- /// Returns the userdata of this chain.
- pub fn get_userdata(&self) -> Option<&CStr> {
- unsafe {
- let ptr = sys::nftnl_chain_get_str(self.chain, sys::NFTNL_CHAIN_USERDATA as u16);
- if ptr == std::ptr::null() {
- return None;
- }
- Some(CStr::from_ptr(ptr))
- }
- }
-
- /// Updates the userdata of this chain.
- pub fn set_userdata(&self, data: &CStr) {
- unsafe {
- sys::nftnl_chain_set_str(self.chain, sys::NFTNL_CHAIN_USERDATA as u16, data.as_ptr());
- }
- }
-
- /// Returns the name of this chain.
- pub fn get_name(&self) -> &CStr {
- unsafe {
- let ptr = sys::nftnl_chain_get_str(self.chain, sys::NFTNL_CHAIN_NAME as u16);
- if ptr.is_null() {
- panic!("Impossible situation: retrieving the name of a chain failed")
- } else {
- CStr::from_ptr(ptr)
- }
- }
- }
-
- /// Returns a textual description of the chain.
- pub fn get_str(&self) -> CString {
- let mut descr_buf = vec![0i8; 4096];
- unsafe {
- sys::nftnl_chain_snprintf(
- descr_buf.as_mut_ptr() as *mut c_char,
- (descr_buf.len() - 1) as u64,
- self.chain,
- sys::NFTNL_OUTPUT_DEFAULT,
- 0,
- );
- CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned()
- }
- }
-
- /// Returns a reference to the [`Table`] this chain belongs to.
- ///
- /// [`Table`]: struct.Table.html
- pub fn get_table(&self) -> Rc<Table> {
- self.table.clone()
- }
-
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns the raw handle.
- pub fn as_ptr(&self) -> *const sys::nftnl_chain {
- self.chain as *const sys::nftnl_chain
- }
-
- #[cfg(feature = "unsafe-raw-handles")]
- /// Returns a mutable version of the raw handle.
- pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_chain {
- self.chain
+ /// Appends this chain to `batch`
+ pub fn add_to_batch(self, batch: &mut Batch) -> Self {
+ batch.add(&self, crate::MsgType::Add);
+ self
}
}
-impl fmt::Debug for Chain {
- /// Returns a string representation of the chain.
- fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(fmt, "{:?}", self.get_str())
- }
-}
+impl NfNetlinkObject for Chain {
+ const MSG_TYPE_ADD: u32 = NFT_MSG_NEWCHAIN;
+ const MSG_TYPE_DEL: u32 = NFT_MSG_DELCHAIN;
-impl PartialEq for Chain {
- fn eq(&self, other: &Self) -> bool {
- self.get_table() == other.get_table() && self.get_name() == other.get_name()
+ fn get_family(&self) -> ProtocolFamily {
+ self.family
}
-}
-unsafe impl crate::NlMsg for Chain {
- unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) {
- let raw_msg_type = match msg_type {
- MsgType::Add => libc::NFT_MSG_NEWCHAIN,
- MsgType::Del => libc::NFT_MSG_DELCHAIN,
- };
- let flags: u16 = match msg_type {
- MsgType::Add => (libc::NLM_F_ACK | libc::NLM_F_CREATE) as u16,
- MsgType::Del => libc::NLM_F_ACK as u16,
- } | libc::NLM_F_ACK as u16;
- let header = sys::nftnl_nlmsg_build_hdr(
- buf as *mut c_char,
- raw_msg_type as u16,
- self.table.get_family() as u16,
- flags,
- seq,
- );
- sys::nftnl_chain_nlmsg_build_payload(header, self.chain);
+ fn set_family(&mut self, family: ProtocolFamily) {
+ self.family = family;
}
}
-impl Drop for Chain {
- fn drop(&mut self) {
- unsafe { sys::nftnl_chain_free(self.chain) };
- }
-}
-
-#[cfg(feature = "query")]
-pub fn get_chains_cb<'a>(
- header: &libc::nlmsghdr,
- (table, chains): &mut (&Rc<Table>, &mut Vec<Chain>),
-) -> libc::c_int {
- unsafe {
- let chain = sys::nftnl_chain_alloc();
- if chain == std::ptr::null_mut() {
- return mnl::mnl_sys::MNL_CB_ERROR;
- }
- let err = sys::nftnl_chain_nlmsg_parse(header, chain);
- if err < 0 {
- error!("Failed to parse nelink chain message - {}", err);
- sys::nftnl_chain_free(chain);
- return err;
- }
-
- let table_name = CStr::from_ptr(sys::nftnl_chain_get_str(
- chain,
- sys::NFTNL_CHAIN_TABLE as u16,
- ));
- let family = sys::nftnl_chain_get_u32(chain, sys::NFTNL_CHAIN_FAMILY as u16);
- let family = match crate::ProtoFamily::try_from(family as i32) {
- Ok(family) => family,
- Err(crate::InvalidProtocolFamily) => {
- error!("The netlink table didn't have a valid protocol family !?");
- sys::nftnl_chain_free(chain);
- return mnl::mnl_sys::MNL_CB_ERROR;
+pub fn list_chains_for_table(table: &Table) -> Result<Vec<Chain>, QueryError> {
+ let mut result = Vec::new();
+ crate::query::list_objects_with_data(
+ libc::NFT_MSG_GETCHAIN as u16,
+ &|chain: Chain, (table, chains): &mut (&Table, &mut Vec<Chain>)| {
+ if chain.get_table() == table.get_name() {
+ chains.push(chain);
+ } else {
+ info!(
+ "Ignoring chain {:?} because it doesn't map the table {:?}",
+ chain.get_name(),
+ table.get_name()
+ );
}
- };
-
- if table_name != table.get_name() {
- sys::nftnl_chain_free(chain);
- return mnl::mnl_sys::MNL_CB_OK;
- }
-
- if family != crate::ProtoFamily::Unspec && family != table.get_family() {
- sys::nftnl_chain_free(chain);
- return mnl::mnl_sys::MNL_CB_OK;
- }
-
- chains.push(Chain::from_raw(chain, table.clone()));
- }
- mnl::mnl_sys::MNL_CB_OK
-}
-
-#[cfg(feature = "query")]
-pub fn list_chains_for_table(table: Rc<Table>) -> Result<Vec<Chain>, crate::query::Error> {
- crate::query::list_objects_with_data(libc::NFT_MSG_GETCHAIN as u16, get_chains_cb, &table, None)
+ Ok(())
+ },
+ None,
+ &mut (&table, &mut result),
+ )?;
+ Ok(result)
}