diff options
author | Simon THOBY <git@nightmared.fr> | 2022-12-03 22:53:23 +0100 |
---|---|---|
committer | Simon THOBY <git@nightmared.fr> | 2022-12-05 22:40:01 +0100 |
commit | edb440a952320ea4f021c1d7063ff6d5f2f13818 (patch) | |
tree | 5c18e7f1fabdcef8e140920ea75bfd0d0b400bd0 /macros | |
parent | 4b60b3cd41f5198c47a260ce69abf4c15b60ca92 (diff) |
Macros: introduce a macro to simplify enums
Diffstat (limited to 'macros')
-rw-r--r-- | macros/Cargo.toml | 2 | ||||
-rw-r--r-- | macros/src/lib.rs | 194 |
2 files changed, 175 insertions, 21 deletions
diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 7d9167f..82c8ad6 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" proc-macro = true [dependencies] -syn = { version = "1.0", features = ["full", "extra-traits"] } +syn = { version = "1.0", features = ["full"] } quote = "1.0" proc-macro2 = "1.0" proc-macro-error = "1" diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 38cde50..11aedaf 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,28 +1,26 @@ use proc_macro::TokenStream; -use proc_macro2::Group; +use proc_macro2::{Group, Span}; use quote::quote; use proc_macro_error::{abort, proc_macro_error}; use syn::parse::Parser; use syn::punctuated::Punctuated; use syn::spanned::Spanned; -use syn::token::Struct; use syn::{ - parse, parse2, parse_macro_input, Attribute, Expr, ExprLit, FnArg, Ident, ItemFn, ItemStruct, - Lit, Meta, NestedMeta, Pat, PatIdent, Path, Result, ReturnType, Token, Type, TypePath, + parse, parse2, Attribute, Expr, ExprCast, Ident, ItemEnum, ItemStruct, Lit, Meta, Path, Result, + Token, Type, TypePath, Visibility, }; -use syn::{parse::Parse, PatReference}; -use syn::{parse::ParseStream, TypeReference}; struct Field<'a> { name: &'a Ident, ty: &'a Type, args: FieldArgs, netlink_type: Path, + vis: &'a Visibility, attrs: Vec<&'a Attribute>, } -#[derive(Debug, Default)] +#[derive(Default)] struct FieldArgs { netlink_type: Option<Path>, override_function_name: Option<String>, @@ -68,7 +66,6 @@ fn parse_field_args(input: proc_macro2::TokenStream) -> Result<FieldArgs> { Ok(args) } -#[derive(Debug)] struct StructArgs { nested: bool, derive_decoder: bool, @@ -85,12 +82,10 @@ impl Default for StructArgs { } } -fn parse_struct_args(args: &mut StructArgs, input: TokenStream) -> Result<()> { - if input.is_empty() { - return Ok(()); - } +fn parse_struct_args(input: TokenStream) -> Result<StructArgs> { + let mut args = StructArgs::default(); let parser = Punctuated::<Meta, Token![,]>::parse_terminated; - let attribute_args = parser.parse(input)?; + let attribute_args = parser.parse(input.clone())?; for arg in attribute_args.iter() { if let Meta::NameValue(namevalue) = arg { let key = namevalue @@ -126,7 +121,7 @@ fn parse_struct_args(args: &mut StructArgs, input: TokenStream) -> Result<()> { abort!(arg.span(), "Unrecognized argument"); } } - Ok(()) + Ok(args) } #[proc_macro_error] @@ -135,8 +130,10 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { let ast: ItemStruct = parse(item).unwrap(); let name = ast.ident; - let mut args = StructArgs::default(); - parse_struct_args(&mut args, attrs).expect("Could not parse the macro arguments"); + let args = match parse_struct_args(attrs) { + Ok(x) => x, + Err(_) => abort!(Span::call_site(), "Could not parse the macro arguments"), + }; let mut fields = Vec::with_capacity(ast.fields.len()); let mut identical_fields = Vec::new(); @@ -145,15 +142,25 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { for attr in field.attrs.iter() { if let Some(id) = attr.path.get_ident() { if id == "field" { - let field_args = parse_field_args(attr.tokens.clone()) - .expect("Could not parse the field attributes"); + let field_args = match parse_field_args(attr.tokens.clone()) { + Ok(x) => x, + Err(_) => { + abort!(attr.tokens.span(), "Could not parse the field attributes") + } + }; if let Some(netlink_type) = field_args.netlink_type.clone() { fields.push(Field { name: field.ident.as_ref().expect("Should be a names struct"), ty: &field.ty, args: field_args, netlink_type, - attrs: field.attrs.iter().filter(|x| *x != attr).collect(), + vis: &field.vis, + // drop the "field" attribute + attrs: field + .attrs + .iter() + .filter(|x| x.path.get_ident() != attr.path.get_ident()) + .collect(), }); } else { abort!(attr.tokens.span(), "Missing Netlink Type in field"); @@ -297,7 +304,8 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { let name = field.name; let ty = field.ty; let attrs = &field.attrs; - quote!( #(#attrs) * #name: Option<#ty>, ) + let vis = &field.vis; + quote!( #(#attrs) * #vis #name: Option<#ty>, ) }); let nfnetlinkdeserialize_impl = if args.derive_deserialize { quote!( @@ -327,3 +335,149 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { res.into() } + +struct Variant<'a> { + inner: &'a syn::Variant, + name: &'a Ident, + value: &'a Path, +} + +#[derive(Default)] +struct EnumArgs { + nested: bool, + ty: Option<Path>, +} + +fn parse_enum_args(input: TokenStream) -> Result<EnumArgs> { + let mut args = EnumArgs::default(); + let parser = Punctuated::<Meta, Token![,]>::parse_terminated; + let attribute_args = parser.parse(input)?; + for arg in attribute_args.iter() { + match arg { + Meta::Path(path) => { + if args.ty.is_none() { + args.ty = Some(path.clone()); + } else { + abort!(arg.span(), "A value can only have a single representation"); + } + } + Meta::NameValue(namevalue) => { + let key = namevalue + .path + .get_ident() + .expect("the macro parameter is not an ident?") + .to_string(); + match key.as_str() { + "nested" => { + if let Lit::Bool(boolean) = &namevalue.lit { + args.nested = boolean.value; + } else { + abort!(&namevalue.lit.span(), "Expected a boolean"); + } + } + _ => abort!(key.span(), "Unsupported macro parameter"), + } + } + _ => abort!(arg.span(), "Unrecognized argument"), + } + } + Ok(args) +} + +#[proc_macro_error] +#[proc_macro_attribute] +pub fn nfnetlink_enum(attrs: TokenStream, item: TokenStream) -> TokenStream { + let ast: ItemEnum = parse(item).unwrap(); + let name = ast.ident; + + let args = match parse_enum_args(attrs) { + Ok(x) => x, + Err(_) => abort!(Span::call_site(), "Could not parse the macro arguments"), + }; + + if args.ty.is_none() { + abort!( + Span::call_site(), + "The target type representation is unspecified" + ); + } + + let mut variants = Vec::with_capacity(ast.variants.len()); + + for variant in ast.variants.iter() { + if variant.discriminant.is_none() { + abort!(variant.ident.span(), "Missing value"); + } + let discriminant = variant.discriminant.as_ref().unwrap(); + if let syn::Expr::Path(path) = &discriminant.1 { + variants.push(Variant { + inner: variant, + name: &variant.ident, + value: &path.path, + }); + } else { + abort!(discriminant.1.span(), "Expected a path"); + } + } + + let repr_type = args.ty.unwrap(); + let match_entries = variants.iter().map(|variant| { + let variant_name = variant.name; + let variant_value = &variant.value; + quote!( x if x == (#variant_value as #repr_type) => Self::#variant_name, ) + }); + let unknown_type_ident = Ident::new(&format!("Unknown{}", name.to_string()), name.span()); + let nfnetlinkdeserialize_impl = quote!( + impl crate::nlmsg::NfNetlinkDeserializable for #name { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), crate::parser::DecodeError> { + let (v, remaining_data) = #repr_type::deserialize(buf)?; + Ok(( + match v { + #(#match_entries) * + value => return Err(crate::parser::DecodeError::#unknown_type_ident(value)) + }, + remaining_data, + )) + } + } + ); + let vis = &ast.vis; + let attrs = ast.attrs; + let original_variants = variants.into_iter().map(|x| { + let mut inner = x.inner.clone(); + let mut discriminant = inner.discriminant.as_mut().unwrap(); + let cur_value = discriminant.1.clone(); + let cast_value = Expr::Cast(ExprCast { + attrs: vec![], + expr: Box::new(cur_value), + as_token: Token), + ty: Box::new(Type::Path(TypePath { + qself: None, + path: repr_type.clone(), + })), + }); + discriminant.1 = cast_value; + inner + }); + let res = quote! { + #[repr(#repr_type)] + #(#attrs) * #vis enum #name { + #(#original_variants),* + } + + impl crate::nlmsg::NfNetlinkAttribute for #name { + fn get_size(&self) -> usize { + (*self as #repr_type).get_size() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + (*self as #repr_type).write_payload(addr); + } + } + + #nfnetlinkdeserialize_impl + + }; + + res.into() +} |