diff options
author | Simon Thoby <git@nightmared.fr> | 2023-02-26 18:35:39 +0100 |
---|---|---|
committer | Simon Thoby <git@nightmared.fr> | 2023-02-26 18:35:39 +0100 |
commit | b025914bd3bcf786ff7ab53c9dabdd6e37a05782 (patch) | |
tree | c49e9154bf15c2b32b76c09fcba896c768142589 | |
parent | e5c2b423473bb147763c8f6a73aec73212980e4b (diff) |
add support for optional attributes
-rw-r--r-- | macros/Cargo.toml | 1 | ||||
-rw-r--r-- | macros/src/lib.rs | 79 | ||||
-rw-r--r-- | src/chain.rs | 7 | ||||
-rw-r--r-- | src/expr/verdict.rs | 7 | ||||
-rw-r--r-- | src/table.rs | 6 |
5 files changed, 82 insertions, 18 deletions
diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 5d0f297..4ee8a45 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -16,3 +16,4 @@ syn = { version = "1.0", features = ["full"] } quote = "1.0" proc-macro2 = "1.0" proc-macro-error = "1" +once_cell = "1.1" diff --git a/macros/src/lib.rs b/macros/src/lib.rs index bfb5099..9929777 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,16 +1,56 @@ +use std::fs::File; +use std::io::Read; +use std::path::PathBuf; + use proc_macro::TokenStream; use proc_macro2::{Group, Span}; -use quote::quote; +use quote::{quote, quote_spanned}; use proc_macro_error::{abort, proc_macro_error}; use syn::parse::Parser; use syn::punctuated::Punctuated; use syn::spanned::Spanned; use syn::{ - parse, parse2, Attribute, Expr, ExprCast, Ident, ItemEnum, ItemStruct, Lit, Meta, Path, Result, - Token, Type, TypePath, Visibility, + parse, parse2, Attribute, Expr, ExprCast, Ident, Item, ItemEnum, ItemStruct, Lit, Meta, Path, + Result, Token, Type, TypePath, Visibility, }; +use once_cell::sync::OnceCell; + +struct GlobalState { + declared_identifiers: Vec<String>, +} + +static STATE: OnceCell<GlobalState> = OnceCell::new(); + +fn get_state() -> &'static GlobalState { + STATE.get_or_init(|| { + let sys_file = { + // Load the header file and extract the constants defined inside. + // This is what determines whether optional attributes (or enum variants) + // will be supported or not in the resulting binary. + let out_path = PathBuf::from(std::env::var("OUT_DIR").unwrap()).join("sys.rs"); + let mut sys_file = String::new(); + File::open(out_path) + .expect("Error: could not open the output header file") + .read_to_string(&mut sys_file) + .expect("Could not read the header file"); + syn::parse_file(&sys_file).expect("Could not parse the header file") + }; + + let mut declared_identifiers = Vec::new(); + for item in sys_file.items { + if let Item::Const(v) = item { + declared_identifiers.push(v.ident.to_string()); + } + } + + GlobalState { + declared_identifiers, + } + }) +} + struct Field<'a> { name: &'a Ident, ty: &'a Type, @@ -24,6 +64,7 @@ struct Field<'a> { struct FieldArgs { netlink_type: Option<Path>, override_function_name: Option<String>, + optional: bool, } fn parse_field_args(input: proc_macro2::TokenStream) -> Result<FieldArgs> { @@ -57,7 +98,14 @@ fn parse_field_args(input: proc_macro2::TokenStream) -> Result<FieldArgs> { abort!(&namevalue.lit.span(), "Expected a string literal"); } } - _ => abort!(key.span(), "Unsupported macro parameter"), + "optional" => { + if let Lit::Bool(boolean) = &namevalue.lit { + args.optional = boolean.value; + } else { + abort!(&namevalue.lit.span(), "Expected a boolean"); + } + } + _ => abort!(arg.span(), "Unsupported macro parameter"), } } _ => abort!(arg.span(), "Unrecognized argument"), @@ -115,7 +163,7 @@ fn parse_struct_args(input: TokenStream) -> Result<StructArgs> { abort!(&namevalue.lit.span(), "Expected a boolean"); } } - _ => abort!(key.span(), "Unsupported macro parameter"), + _ => abort!(arg.span(), "Unsupported macro parameter"), } } else { abort!(arg.span(), "Unrecognized argument"); @@ -135,6 +183,8 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { Err(_) => abort!(Span::call_site(), "Could not parse the macro arguments"), }; + let state = get_state(); + let mut fields = Vec::with_capacity(ast.fields.len()); let mut identical_fields = Vec::new(); @@ -149,6 +199,21 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { } }; if let Some(netlink_type) = field_args.netlink_type.clone() { + // optional fields are not generated when the kernel version you have on + // the system does not support that field + if field_args.optional { + let netlink_type_ident = netlink_type + .segments + .last() + .expect("empty path?") + .ident + .to_string(); + if !state.declared_identifiers.contains(&netlink_type_ident) { + // reject the optional identifier + continue 'out; + } + } + fields.push(Field { name: field.ident.as_ref().expect("Should be a names struct"), ty: &field.ty, @@ -312,7 +377,7 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { let ty = field.ty; let attrs = &field.attrs; let vis = &field.vis; - quote!( #(#attrs) * #vis #name: Option<#ty>, ) + quote_spanned!(name.span() => #(#attrs) * #vis #name: Option<#ty>, ) }); let nfnetlinkdeserialize_impl = if args.derive_deserialize { quote!( @@ -382,7 +447,7 @@ fn parse_enum_args(input: TokenStream) -> Result<EnumArgs> { abort!(&namevalue.lit.span(), "Expected a boolean"); } } - _ => abort!(key.span(), "Unsupported macro parameter"), + _ => abort!(arg.span(), "Unsupported macro parameter"), } } _ => abort!(arg.span(), "Unrecognized argument"), diff --git a/src/chain.rs b/src/chain.rs index 7d365a1..53ac595 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -5,8 +5,7 @@ use crate::error::{DecodeError, QueryError}; use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject}; use crate::sys::{ NFTA_CHAIN_FLAGS, NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_POLICY, NFTA_CHAIN_TABLE, - NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, - NFT_MSG_NEWCHAIN, + NFTA_CHAIN_TYPE, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN, }; use crate::{Batch, ProtocolFamily, Table}; use std::fmt::Debug; @@ -135,8 +134,8 @@ impl NfNetlinkDeserializable for ChainType { /// /// [`Table`]: struct.Table.html /// [`Rule`]: struct.Rule.html -#[derive(PartialEq, Eq, Default, Debug)] #[nfnetlink_struct(derive_deserialize = false)] +#[derive(PartialEq, Eq, Default, Debug)] pub struct Chain { family: ProtocolFamily, #[field(NFTA_CHAIN_TABLE)] @@ -151,7 +150,7 @@ pub struct Chain { chain_type: ChainType, #[field(NFTA_CHAIN_FLAGS)] flags: u32, - #[field(NFTA_CHAIN_USERDATA)] + #[field(optional = true, crate::sys::NFTA_CHAIN_USERDATA)] userdata: Vec<u8>, } diff --git a/src/expr/verdict.rs b/src/expr/verdict.rs index 7edf7cd..c42ad32 100644 --- a/src/expr/verdict.rs +++ b/src/expr/verdict.rs @@ -4,8 +4,7 @@ use libc::{NF_ACCEPT, NF_DROP, NF_QUEUE}; use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; use crate::sys::{ - NFTA_VERDICT_CHAIN, NFTA_VERDICT_CHAIN_ID, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE, - NFT_GOTO, NFT_JUMP, NFT_RETURN, + NFTA_VERDICT_CHAIN, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE, NFT_GOTO, NFT_JUMP, NFT_RETURN, }; #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] @@ -21,14 +20,14 @@ pub enum VerdictType { Return = NFT_RETURN, } -#[derive(Clone, PartialEq, Eq, Default, Debug)] #[nfnetlink_struct(nested = true)] +#[derive(Clone, PartialEq, Eq, Default, Debug)] pub struct Verdict { #[field(NFTA_VERDICT_CODE)] code: VerdictType, #[field(NFTA_VERDICT_CHAIN)] chain: String, - #[field(NFTA_VERDICT_CHAIN_ID)] + #[field(optional = true, crate::sys::NFTA_VERDICT_CHAIN_ID)] chain_id: u32, } diff --git a/src/table.rs b/src/table.rs index 81a26ef..1d19abe 100644 --- a/src/table.rs +++ b/src/table.rs @@ -5,7 +5,7 @@ use rustables_macros::nfnetlink_struct; use crate::error::QueryError; use crate::nlmsg::NfNetlinkObject; use crate::sys::{ - NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFTA_TABLE_USERDATA, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, + NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, NFT_MSG_NEWTABLE, }; use crate::{Batch, ProtocolFamily}; @@ -14,15 +14,15 @@ use crate::{Batch, ProtocolFamily}; /// family and contains [`Chain`]s that in turn hold the rules. /// /// [`Chain`]: struct.Chain.html -#[derive(Default, PartialEq, Eq, Debug)] #[nfnetlink_struct(derive_deserialize = false)] +#[derive(Default, PartialEq, Eq, Debug)] pub struct Table { family: ProtocolFamily, #[field(NFTA_TABLE_NAME)] name: String, #[field(NFTA_TABLE_FLAGS)] flags: u32, - #[field(NFTA_TABLE_USERDATA)] + #[field(optional = true, crate::sys::NFTA_TABLE_USERDATA)] userdata: Vec<u8>, } |