aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon Thoby <git@nightmared.fr>2023-02-26 18:35:39 +0100
committerSimon Thoby <git@nightmared.fr>2023-02-26 18:35:39 +0100
commitb025914bd3bcf786ff7ab53c9dabdd6e37a05782 (patch)
treec49e9154bf15c2b32b76c09fcba896c768142589
parente5c2b423473bb147763c8f6a73aec73212980e4b (diff)
add support for optional attributes
-rw-r--r--macros/Cargo.toml1
-rw-r--r--macros/src/lib.rs79
-rw-r--r--src/chain.rs7
-rw-r--r--src/expr/verdict.rs7
-rw-r--r--src/table.rs6
5 files changed, 82 insertions, 18 deletions
diff --git a/macros/Cargo.toml b/macros/Cargo.toml
index 5d0f297..4ee8a45 100644
--- a/macros/Cargo.toml
+++ b/macros/Cargo.toml
@@ -16,3 +16,4 @@ syn = { version = "1.0", features = ["full"] }
quote = "1.0"
proc-macro2 = "1.0"
proc-macro-error = "1"
+once_cell = "1.1"
diff --git a/macros/src/lib.rs b/macros/src/lib.rs
index bfb5099..9929777 100644
--- a/macros/src/lib.rs
+++ b/macros/src/lib.rs
@@ -1,16 +1,56 @@
+use std::fs::File;
+use std::io::Read;
+use std::path::PathBuf;
+
use proc_macro::TokenStream;
use proc_macro2::{Group, Span};
-use quote::quote;
+use quote::{quote, quote_spanned};
use proc_macro_error::{abort, proc_macro_error};
use syn::parse::Parser;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{
- parse, parse2, Attribute, Expr, ExprCast, Ident, ItemEnum, ItemStruct, Lit, Meta, Path, Result,
- Token, Type, TypePath, Visibility,
+ parse, parse2, Attribute, Expr, ExprCast, Ident, Item, ItemEnum, ItemStruct, Lit, Meta, Path,
+ Result, Token, Type, TypePath, Visibility,
};
+use once_cell::sync::OnceCell;
+
+struct GlobalState {
+ declared_identifiers: Vec<String>,
+}
+
+static STATE: OnceCell<GlobalState> = OnceCell::new();
+
+fn get_state() -> &'static GlobalState {
+ STATE.get_or_init(|| {
+ let sys_file = {
+ // Load the header file and extract the constants defined inside.
+ // This is what determines whether optional attributes (or enum variants)
+ // will be supported or not in the resulting binary.
+ let out_path = PathBuf::from(std::env::var("OUT_DIR").unwrap()).join("sys.rs");
+ let mut sys_file = String::new();
+ File::open(out_path)
+ .expect("Error: could not open the output header file")
+ .read_to_string(&mut sys_file)
+ .expect("Could not read the header file");
+ syn::parse_file(&sys_file).expect("Could not parse the header file")
+ };
+
+ let mut declared_identifiers = Vec::new();
+ for item in sys_file.items {
+ if let Item::Const(v) = item {
+ declared_identifiers.push(v.ident.to_string());
+ }
+ }
+
+ GlobalState {
+ declared_identifiers,
+ }
+ })
+}
+
struct Field<'a> {
name: &'a Ident,
ty: &'a Type,
@@ -24,6 +64,7 @@ struct Field<'a> {
struct FieldArgs {
netlink_type: Option<Path>,
override_function_name: Option<String>,
+ optional: bool,
}
fn parse_field_args(input: proc_macro2::TokenStream) -> Result<FieldArgs> {
@@ -57,7 +98,14 @@ fn parse_field_args(input: proc_macro2::TokenStream) -> Result<FieldArgs> {
abort!(&namevalue.lit.span(), "Expected a string literal");
}
}
- _ => abort!(key.span(), "Unsupported macro parameter"),
+ "optional" => {
+ if let Lit::Bool(boolean) = &namevalue.lit {
+ args.optional = boolean.value;
+ } else {
+ abort!(&namevalue.lit.span(), "Expected a boolean");
+ }
+ }
+ _ => abort!(arg.span(), "Unsupported macro parameter"),
}
}
_ => abort!(arg.span(), "Unrecognized argument"),
@@ -115,7 +163,7 @@ fn parse_struct_args(input: TokenStream) -> Result<StructArgs> {
abort!(&namevalue.lit.span(), "Expected a boolean");
}
}
- _ => abort!(key.span(), "Unsupported macro parameter"),
+ _ => abort!(arg.span(), "Unsupported macro parameter"),
}
} else {
abort!(arg.span(), "Unrecognized argument");
@@ -135,6 +183,8 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream {
Err(_) => abort!(Span::call_site(), "Could not parse the macro arguments"),
};
+ let state = get_state();
+
let mut fields = Vec::with_capacity(ast.fields.len());
let mut identical_fields = Vec::new();
@@ -149,6 +199,21 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream {
}
};
if let Some(netlink_type) = field_args.netlink_type.clone() {
+ // optional fields are not generated when the kernel version you have on
+ // the system does not support that field
+ if field_args.optional {
+ let netlink_type_ident = netlink_type
+ .segments
+ .last()
+ .expect("empty path?")
+ .ident
+ .to_string();
+ if !state.declared_identifiers.contains(&netlink_type_ident) {
+ // reject the optional identifier
+ continue 'out;
+ }
+ }
+
fields.push(Field {
name: field.ident.as_ref().expect("Should be a names struct"),
ty: &field.ty,
@@ -312,7 +377,7 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream {
let ty = field.ty;
let attrs = &field.attrs;
let vis = &field.vis;
- quote!( #(#attrs) * #vis #name: Option<#ty>, )
+ quote_spanned!(name.span() => #(#attrs) * #vis #name: Option<#ty>, )
});
let nfnetlinkdeserialize_impl = if args.derive_deserialize {
quote!(
@@ -382,7 +447,7 @@ fn parse_enum_args(input: TokenStream) -> Result<EnumArgs> {
abort!(&namevalue.lit.span(), "Expected a boolean");
}
}
- _ => abort!(key.span(), "Unsupported macro parameter"),
+ _ => abort!(arg.span(), "Unsupported macro parameter"),
}
}
_ => abort!(arg.span(), "Unrecognized argument"),
diff --git a/src/chain.rs b/src/chain.rs
index 7d365a1..53ac595 100644
--- a/src/chain.rs
+++ b/src/chain.rs
@@ -5,8 +5,7 @@ use crate::error::{DecodeError, QueryError};
use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject};
use crate::sys::{
NFTA_CHAIN_FLAGS, NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_POLICY, NFTA_CHAIN_TABLE,
- NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN,
- NFT_MSG_NEWCHAIN,
+ NFTA_CHAIN_TYPE, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN,
};
use crate::{Batch, ProtocolFamily, Table};
use std::fmt::Debug;
@@ -135,8 +134,8 @@ impl NfNetlinkDeserializable for ChainType {
///
/// [`Table`]: struct.Table.html
/// [`Rule`]: struct.Rule.html
-#[derive(PartialEq, Eq, Default, Debug)]
#[nfnetlink_struct(derive_deserialize = false)]
+#[derive(PartialEq, Eq, Default, Debug)]
pub struct Chain {
family: ProtocolFamily,
#[field(NFTA_CHAIN_TABLE)]
@@ -151,7 +150,7 @@ pub struct Chain {
chain_type: ChainType,
#[field(NFTA_CHAIN_FLAGS)]
flags: u32,
- #[field(NFTA_CHAIN_USERDATA)]
+ #[field(optional = true, crate::sys::NFTA_CHAIN_USERDATA)]
userdata: Vec<u8>,
}
diff --git a/src/expr/verdict.rs b/src/expr/verdict.rs
index 7edf7cd..c42ad32 100644
--- a/src/expr/verdict.rs
+++ b/src/expr/verdict.rs
@@ -4,8 +4,7 @@ use libc::{NF_ACCEPT, NF_DROP, NF_QUEUE};
use rustables_macros::{nfnetlink_enum, nfnetlink_struct};
use crate::sys::{
- NFTA_VERDICT_CHAIN, NFTA_VERDICT_CHAIN_ID, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE,
- NFT_GOTO, NFT_JUMP, NFT_RETURN,
+ NFTA_VERDICT_CHAIN, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE, NFT_GOTO, NFT_JUMP, NFT_RETURN,
};
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
@@ -21,14 +20,14 @@ pub enum VerdictType {
Return = NFT_RETURN,
}
-#[derive(Clone, PartialEq, Eq, Default, Debug)]
#[nfnetlink_struct(nested = true)]
+#[derive(Clone, PartialEq, Eq, Default, Debug)]
pub struct Verdict {
#[field(NFTA_VERDICT_CODE)]
code: VerdictType,
#[field(NFTA_VERDICT_CHAIN)]
chain: String,
- #[field(NFTA_VERDICT_CHAIN_ID)]
+ #[field(optional = true, crate::sys::NFTA_VERDICT_CHAIN_ID)]
chain_id: u32,
}
diff --git a/src/table.rs b/src/table.rs
index 81a26ef..1d19abe 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -5,7 +5,7 @@ use rustables_macros::nfnetlink_struct;
use crate::error::QueryError;
use crate::nlmsg::NfNetlinkObject;
use crate::sys::{
- NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFTA_TABLE_USERDATA, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE,
+ NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE,
NFT_MSG_NEWTABLE,
};
use crate::{Batch, ProtocolFamily};
@@ -14,15 +14,15 @@ use crate::{Batch, ProtocolFamily};
/// family and contains [`Chain`]s that in turn hold the rules.
///
/// [`Chain`]: struct.Chain.html
-#[derive(Default, PartialEq, Eq, Debug)]
#[nfnetlink_struct(derive_deserialize = false)]
+#[derive(Default, PartialEq, Eq, Debug)]
pub struct Table {
family: ProtocolFamily,
#[field(NFTA_TABLE_NAME)]
name: String,
#[field(NFTA_TABLE_FLAGS)]
flags: u32,
- #[field(NFTA_TABLE_USERDATA)]
+ #[field(optional = true, crate::sys::NFTA_TABLE_USERDATA)]
userdata: Vec<u8>,
}