diff options
-rw-r--r-- | macros/src/lib.rs | 75 | ||||
-rw-r--r-- | src/chain.rs | 119 | ||||
-rw-r--r-- | src/table.rs | 11 |
3 files changed, 99 insertions, 106 deletions
diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 998e9eb..38cde50 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -17,6 +17,7 @@ use syn::{parse::ParseStream, TypeReference}; struct Field<'a> { name: &'a Ident, ty: &'a Type, + args: FieldArgs, netlink_type: Path, attrs: Vec<&'a Attribute>, } @@ -24,6 +25,7 @@ struct Field<'a> { #[derive(Debug, Default)] struct FieldArgs { netlink_type: Option<Path>, + override_function_name: Option<String>, } fn parse_field_args(input: proc_macro2::TokenStream) -> Result<FieldArgs> { @@ -32,17 +34,35 @@ fn parse_field_args(input: proc_macro2::TokenStream) -> Result<FieldArgs> { let parser = Punctuated::<Meta, Token![,]>::parse_terminated; let attribute_args = parser.parse2(input)?; for arg in attribute_args.iter() { - if let Meta::Path(path) = arg { - if args.netlink_type.is_none() { - args.netlink_type = Some(path.clone()); - } else { - abort!( - arg.span(), - "Only a single netlink value can exist for a given field" - ); + match arg { + Meta::Path(path) => { + if args.netlink_type.is_none() { + args.netlink_type = Some(path.clone()); + } else { + abort!( + arg.span(), + "Only a single netlink value can exist for a given field" + ); + } } - } else { - abort!(arg.span(), "Unrecognized argument"); + Meta::NameValue(namevalue) => { + let key = namevalue + .path + .get_ident() + .expect("the macro parameter is not an ident?") + .to_string(); + match key.as_str() { + "name_in_functions" => { + if let Lit::Str(val) = &namevalue.lit { + args.override_function_name = Some(val.value()); + } else { + abort!(&namevalue.lit.span(), "Expected a string literal"); + } + } + _ => abort!(key.span(), "Unsupported macro parameter"), + } + } + _ => abort!(arg.span(), "Unrecognized argument"), } } Ok(args) @@ -52,6 +72,7 @@ fn parse_field_args(input: proc_macro2::TokenStream) -> Result<FieldArgs> { struct StructArgs { nested: bool, derive_decoder: bool, + derive_deserialize: bool, } impl Default for StructArgs { @@ -59,6 +80,7 @@ impl Default for StructArgs { Self { nested: false, derive_decoder: true, + derive_deserialize: true, } } } @@ -74,7 +96,7 @@ fn parse_struct_args(args: &mut StructArgs, input: TokenStream) -> Result<()> { let key = namevalue .path .get_ident() - .expect("the macro parameter is not an ident ?") + .expect("the macro parameter is not an ident?") .to_string(); match key.as_str() { "derive_decoder" => { @@ -91,7 +113,13 @@ fn parse_struct_args(args: &mut StructArgs, input: TokenStream) -> Result<()> { abort!(&namevalue.lit.span(), "Expected a boolean"); } } - + "derive_deserialize" => { + if let Lit::Bool(boolean) = &namevalue.lit { + args.derive_deserialize = boolean.value; + } else { + abort!(&namevalue.lit.span(), "Expected a boolean"); + } + } _ => abort!(key.span(), "Unsupported macro parameter"), } } else { @@ -119,10 +147,11 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { if id == "field" { let field_args = parse_field_args(attr.tokens.clone()) .expect("Could not parse the field attributes"); - if let Some(netlink_type) = field_args.netlink_type { + if let Some(netlink_type) = field_args.netlink_type.clone() { fields.push(Field { name: field.ident.as_ref().expect("Should be a names struct"), ty: &field.ty, + args: field_args, netlink_type, attrs: field.attrs.iter().filter(|x| *x != attr).collect(), }); @@ -138,7 +167,14 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { let getters_and_setters = fields.iter().map(|field| { let field_name = field.name; + // use the name override if any let field_str = field_name.to_string(); + let field_str = field + .args + .override_function_name + .as_ref() + .map(|x| x.as_str()) + .unwrap_or(field_str.as_str()); let field_type = field.ty; let getter_name = format!("get_{}", field_str); @@ -263,6 +299,17 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { let attrs = &field.attrs; quote!( #(#attrs) * #name: Option<#ty>, ) }); + let nfnetlinkdeserialize_impl = if args.derive_deserialize { + quote!( + impl crate::nlmsg::NfNetlinkDeserializable for #name { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), crate::parser::DecodeError> { + Ok((crate::parser::read_attributes(buf)?, &[])) + } + } + ) + } else { + proc_macro2::TokenStream::new() + }; let res = quote! { #(#attrs) * #vis struct #name { #(#new_fields)* @@ -274,6 +321,8 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { #decoder #nfnetlinkattribute_impl + + #nfnetlinkdeserialize_impl }; res.into() diff --git a/src/chain.rs b/src/chain.rs index eeedcd1..8bdf95b 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -1,12 +1,14 @@ use libc::{NF_ACCEPT, NF_DROP}; +use rustables_macros::nfnetlink_struct; use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkWriter}; -use crate::parser::{DecodeError, InnerFormat, Parsable}; -use crate::sys::{self, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN, NLM_F_ACK, NLM_F_CREATE}; -use crate::{ - create_wrapper_type, impl_attr_getters_and_setters, impl_nfnetlinkattribute, MsgType, - ProtocolFamily, Table, +use crate::parser::{DecodeError, Parsable}; +use crate::sys::{ + self, NFTA_CHAIN_FLAGS, NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_POLICY, NFTA_CHAIN_TABLE, + NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, + NFT_MSG_NEWCHAIN, NLM_F_ACK, NLM_F_CREATE, }; +use crate::{create_wrapper_type, MsgType, ProtocolFamily, Table}; use std::convert::TryFrom; use std::fmt::Debug; @@ -28,28 +30,15 @@ pub enum HookClass { PostRouting = libc::NF_INET_POST_ROUTING, } -create_wrapper_type!( - nested: Hook, - [ - // Define the action netfilter will apply to packets processed by this chain, but that did not match any rules in it. - ( - get_class, - set_class, - with_class, - sys::NFTA_HOOK_HOOKNUM, - class, - u32 - ), - ( - get_priority, - set_priority, - with_priority, - sys::NFTA_HOOK_PRIORITY, - priority, - u32 - ) - ] -); +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(nested = true)] +pub struct Hook { + /// Define the action netfilter will apply to packets processed by this chain, but that did not match any rules in it. + #[field(NFTA_HOOK_HOOKNUM)] + class: u32, + #[field(NFTA_HOOK_PRIORITY)] + priority: u32, +} impl Hook { pub fn new(class: HookClass, priority: ChainPriority) -> Self { @@ -151,16 +140,24 @@ impl NfNetlinkDeserializable for ChainType { /// [`Table`]: struct.Table.html /// [`Rule`]: struct.Rule.html /// [`set_hook`]: #method.set_hook -#[derive(PartialEq, Eq, Default)] +#[derive(PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(derive_deserialize = false)] pub struct Chain { family: ProtocolFamily, - flags: Option<u32>, - name: Option<String>, - hook: Option<Hook>, - policy: Option<ChainPolicy>, - table: Option<String>, - chain_type: Option<ChainType>, - userdata: Option<Vec<u8>>, + #[field(NFTA_CHAIN_TABLE)] + table: String, + #[field(NFTA_CHAIN_NAME)] + name: String, + #[field(NFTA_CHAIN_HOOK)] + hook: Hook, + #[field(NFTA_CHAIN_POLICY)] + policy: ChainPolicy, + #[field(NFTA_CHAIN_TYPE, name_in_functions = "type")] + chain_type: ChainType, + #[field(NFTA_CHAIN_FLAGS)] + flags: u32, + #[field(NFTA_CHAIN_USERDATA)] + userdata: Vec<u8>, } impl Chain { @@ -208,14 +205,6 @@ impl PartialEq for Chain { } */ -impl Debug for Chain { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut res = f.debug_struct("Chain"); - res.field("family", &self.family); - self.inner_format_struct(res)?.finish() - } -} - impl NfNetlinkObject for Chain { fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32) { let raw_msg_type = match msg_type { @@ -251,50 +240,6 @@ impl NfNetlinkDeserializable for Chain { } } -impl_attr_getters_and_setters!( - Chain, - [ - (get_table, set_table, with_table, sys::NFTA_CHAIN_TABLE, table, String), - (get_name, set_name, with_name, sys::NFTA_CHAIN_NAME, name, String), - // Sets the hook and priority for this chain. Without calling this method the chain will - // become a "regular chain" without any hook and will thus not receive any traffic unless - // some rule forward packets to it via goto or jump verdicts. - // - // By calling `set_hook` with a hook the chain that is created will be registered with that - // hook and is thus a "base chain". A "base chain" is an entry point for packets from the - // networking stack. - (get_hook, set_hook, with_hook, sys::NFTA_CHAIN_HOOK, hook, Hook), - (get_policy, set_policy, with_policy, sys::NFTA_CHAIN_POLICY, policy, ChainPolicy), - // This only applies if the chain has been registered with a hook by calling `set_hook`. - (get_type, set_type, with_type, sys::NFTA_CHAIN_TYPE, chain_type, ChainType), - (get_flags, set_flags, with_flags, sys::NFTA_CHAIN_FLAGS, flags, u32), - ( - get_userdata, - set_userdata, - with_userdata, - sys::NFTA_CHAIN_USERDATA, - userdata, - Vec<u8> - ) - ] -); - -impl_nfnetlinkattribute!( - inline : Chain, - [ - (sys::NFTA_CHAIN_TABLE, table), - (sys::NFTA_CHAIN_NAME, name), - (sys::NFTA_CHAIN_HOOK, hook), - (sys::NFTA_CHAIN_POLICY, policy), - (sys::NFTA_CHAIN_TYPE, chain_type), - (sys::NFTA_CHAIN_FLAGS, flags), - ( - sys::NFTA_CHAIN_USERDATA, - userdata - ) - ] -); - pub fn list_chains_for_table(table: &Table) -> Result<Vec<Chain>, crate::query::Error> { let mut result = Vec::new(); crate::query::list_objects_with_data( diff --git a/src/table.rs b/src/table.rs index 820d765..e6a6a1a 100644 --- a/src/table.rs +++ b/src/table.rs @@ -4,20 +4,19 @@ use std::fmt::Debug; use rustables_macros::nfnetlink_struct; use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkWriter}; -use crate::parser::Parsable; -use crate::parser::{DecodeError, InnerFormat}; +use crate::parser::{DecodeError, Parsable}; use crate::sys::{ - self, NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFTA_TABLE_USERDATA, NFT_MSG_DELTABLE, - NFT_MSG_GETTABLE, NFT_MSG_NEWTABLE, NLM_F_ACK, NLM_F_CREATE, + NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFTA_TABLE_USERDATA, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, + NFT_MSG_NEWTABLE, NLM_F_ACK, NLM_F_CREATE, }; -use crate::{impl_attr_getters_and_setters, impl_nfnetlinkattribute, MsgType, ProtocolFamily}; +use crate::{MsgType, ProtocolFamily}; /// Abstraction of a `nftnl_table`, the top level container in netfilter. A table has a protocol /// family and contains [`Chain`]s that in turn hold the rules. /// /// [`Chain`]: struct.Chain.html #[derive(Default, PartialEq, Eq, Debug)] -#[nfnetlink_struct] +#[nfnetlink_struct(derive_deserialize = false)] pub struct Table { #[field(NFTA_TABLE_NAME)] name: String, |