diff options
-rw-r--r-- | Cargo.toml | 4 | ||||
-rw-r--r-- | flake.nix | 8 | ||||
-rw-r--r-- | macros/Cargo.toml | 5 | ||||
-rw-r--r-- | macros/src/lib.rs | 147 | ||||
-rw-r--r-- | src/batch.rs | 1 | ||||
-rw-r--r-- | src/chain.rs | 11 | ||||
-rw-r--r-- | src/error.rs | 6 | ||||
-rw-r--r-- | src/expr/mod.rs | 4 | ||||
-rw-r--r-- | src/expr/verdict.rs | 7 | ||||
-rw-r--r-- | src/nlmsg.rs | 19 | ||||
-rw-r--r-- | src/parser.rs | 77 | ||||
-rw-r--r-- | src/parser_impls.rs | 50 | ||||
-rw-r--r-- | src/query.rs | 14 | ||||
-rw-r--r-- | src/table.rs | 6 |
14 files changed, 232 insertions, 127 deletions
@@ -1,6 +1,6 @@ [package] name = "rustables" -version = "0.8.0" +version = "0.8.1-alpha1" authors = ["lafleur@boum.org", "Simon Thoby", "Mullvad VPN"] license = "GPL-3.0-or-later" description = "Safe abstraction for libnftnl. Provides low-level userspace access to the in-kernel nf_tables subsystem" @@ -20,7 +20,7 @@ log = "0.4" libc = "0.2.43" nix = "0.23" ipnetwork = { version = "0.20", default-features = false } -rustables-macros = "0.1.0" +rustables-macros = { path = "macros", version = "0.1.1-alpha1" } [dev-dependencies] env_logger = "0.9" @@ -14,11 +14,13 @@ channel = "1.66.0"; sha256 = "S7epLlflwt0d1GZP44u5Xosgf6dRrmr8xxC+Ml2Pq7c="; }; + rust = rustChannel.rust.override { + targets = [ "x86_64-unknown-linux-musl" ]; + }; in { - inherit rustChannel; - rustc = rustChannel.rust; - cargo = rustChannel.rust; + rustc = rust; + cargo = rust; } ); rustDevOverlay = final: prev: { diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 5d0f297..20c3b5f 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "rustables-macros" -version = "0.1.0" -authors = ["Simon Thoby"] +version = "0.1.1-alpha1" +authors = ["lafleur@boum.org", "Simon Thoby"] license = "GPL-3.0-or-later" description = "Internal macros for generation netlink structures for the rustables project" repository = "https://gitlab.com/rustwall/rustables" @@ -16,3 +16,4 @@ syn = { version = "1.0", features = ["full"] } quote = "1.0" proc-macro2 = "1.0" proc-macro-error = "1" +once_cell = "1.1" diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 39f0d01..af90dab 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,16 +1,58 @@ +#![allow(rustdoc::broken_intra_doc_links)] + +use std::fs::File; +use std::io::Read; +use std::path::PathBuf; + use proc_macro::TokenStream; use proc_macro2::{Group, Span}; -use quote::quote; +use quote::{quote, quote_spanned}; use proc_macro_error::{abort, proc_macro_error}; use syn::parse::Parser; use syn::punctuated::Punctuated; use syn::spanned::Spanned; use syn::{ - parse, parse2, Attribute, Expr, ExprCast, Ident, ItemEnum, ItemStruct, Lit, Meta, Path, Result, - Token, Type, TypePath, Visibility, + parse, parse2, Attribute, Expr, ExprCast, Ident, Item, ItemEnum, ItemStruct, Lit, Meta, Path, + Result, Token, Type, TypePath, Visibility, }; +use once_cell::sync::OnceCell; + +struct GlobalState { + declared_identifiers: Vec<String>, +} + +static STATE: OnceCell<GlobalState> = OnceCell::new(); + +fn get_state() -> &'static GlobalState { + STATE.get_or_init(|| { + let sys_file = { + // Load the header file and extract the constants defined inside. + // This is what determines whether optional attributes (or enum variants) + // will be supported or not in the resulting binary. + let out_path = PathBuf::from(std::env::var("OUT_DIR").unwrap()).join("sys.rs"); + let mut sys_file = String::new(); + File::open(out_path) + .expect("Error: could not open the output header file") + .read_to_string(&mut sys_file) + .expect("Could not read the header file"); + syn::parse_file(&sys_file).expect("Could not parse the header file") + }; + + let mut declared_identifiers = Vec::new(); + for item in sys_file.items { + if let Item::Const(v) = item { + declared_identifiers.push(v.ident.to_string()); + } + } + + GlobalState { + declared_identifiers, + } + }) +} + struct Field<'a> { name: &'a Ident, ty: &'a Type, @@ -24,6 +66,7 @@ struct Field<'a> { struct FieldArgs { netlink_type: Option<Path>, override_function_name: Option<String>, + optional: bool, } fn parse_field_args(input: proc_macro2::TokenStream) -> Result<FieldArgs> { @@ -57,7 +100,14 @@ fn parse_field_args(input: proc_macro2::TokenStream) -> Result<FieldArgs> { abort!(&namevalue.lit.span(), "Expected a string literal"); } } - _ => abort!(key.span(), "Unsupported macro parameter"), + "optional" => { + if let Lit::Bool(boolean) = &namevalue.lit { + args.optional = boolean.value; + } else { + abort!(&namevalue.lit.span(), "Expected a boolean"); + } + } + _ => abort!(arg.span(), "Unsupported macro parameter"), } } _ => abort!(arg.span(), "Unrecognized argument"), @@ -115,7 +165,7 @@ fn parse_struct_args(input: TokenStream) -> Result<StructArgs> { abort!(&namevalue.lit.span(), "Expected a boolean"); } } - _ => abort!(key.span(), "Unsupported macro parameter"), + _ => abort!(arg.span(), "Unsupported macro parameter"), } } else { abort!(arg.span(), "Unrecognized argument"); @@ -124,6 +174,66 @@ fn parse_struct_args(input: TokenStream) -> Result<StructArgs> { Ok(args) } +/// `nfnetlink_struct` is a macro wrapping structures that describe nftables objects. +/// It allows serializing and deserializing these objects to the corresponding nfnetlink +/// attributes. +/// +/// It automatically generates getter and setter functions for each netlink properties. +/// +/// # Parameters +/// The macro have multiple parameters: +/// - `nested` (defaults to `false`): the structure is nested (in the netlink sense) +/// inside its parent structure. This is the case of most structures outside +/// of the main nftables objects (batches, sets, rules, chains and tables), which are +/// the outermost structures, and as such cannot be nested. +/// - `derive_decoder` (defaults to `true`): derive a [`rustables::nlmsg::AttributeDecoder`] +/// implementation for the structure +/// - `derive_deserialize` (defaults to `true`): derive a [`rustables::nlmsg::NfNetlinkDeserializable`] +/// implementation for the structure +/// +/// # Example use +/// ``` +/// #[nfnetlink_struct(derive_deserialize = false)] +/// #[derive(PartialEq, Eq, Default, Debug)] +/// pub struct Chain { +/// family: ProtocolFamily, +/// #[field(NFTA_CHAIN_TABLE)] +/// table: String, +/// #[field(NFTA_CHAIN_TYPE, name_in_functions = "type")] +/// chain_type: ChainType, +/// #[field(optional = true, crate::sys::NFTA_CHAIN_USERDATA)] +/// userdata: Vec<u8>, +/// ... +/// } +/// ``` +/// +/// # Type of fields +/// This contrived example show the two possible type of fields: +/// - A field that is not converted to a netlink attribute (`family`) because it is not +/// annotated in `#[field]` attribute. +/// When deserialized, this field will take the value it is given in the Default implementation +/// of the struct. +/// - A field that is annotated with the `#[field]` attribute. +/// That attribute takes parameters (there are none here), and the netlink attribute type. +/// When annotated with that attribute, the macro will generate `get_<name>`, `set_<name>` and +/// `with_<name>` methods to manipulate the attribute (e.g. `get_table`, `set_table` and +/// `with_table`). +/// It will also replace the field type (here `String`) with an Option (`Option<String>`) +/// so the struct may represent objects where that attribute is not set. +/// +/// # `#[field]` parameters +/// The `#[field]` attribute can be parametrized through two options: +/// - `optional` (defaults to `false`): if the netlink attribute type (here `NFTA_CHAIN_USERDATA`) +/// does not exist, do not generate methods and ignore this attribute if encountered +/// while deserializing a nftables object. +/// This is useful for attributes added recently to the kernel, which may not be supported on +/// older kernels. +/// Support for an attribute is detected according to the existence of that attribute in the kernel +/// headers. +/// - `name_in_functions` (not defined by default): overwrite the `<name`> in the name of the methods +/// `get_<name>`, `set_<name>` and `with_<name>`. +/// Here, this means that even though the field is called `chain_type`, users can query it with +/// the method `get_type` instead of `get_chain_type`. #[proc_macro_error] #[proc_macro_attribute] pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { @@ -135,6 +245,8 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { Err(_) => abort!(Span::call_site(), "Could not parse the macro arguments"), }; + let state = get_state(); + let mut fields = Vec::with_capacity(ast.fields.len()); let mut identical_fields = Vec::new(); @@ -149,6 +261,21 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { } }; if let Some(netlink_type) = field_args.netlink_type.clone() { + // optional fields are not generated when the kernel version you have on + // the system does not support that field + if field_args.optional { + let netlink_type_ident = netlink_type + .segments + .last() + .expect("empty path?") + .ident + .to_string(); + if !state.declared_identifiers.contains(&netlink_type_ident) { + // reject the optional identifier + continue 'out; + } + } + fields.push(Field { name: field.ident.as_ref().expect("Should be a names struct"), ty: &field.ty, @@ -276,7 +403,7 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { { let size = crate::nlmsg::pad_netlink_object::<crate::sys::nlattr>() + crate::nlmsg::pad_netlink_object_with_variable_size(val.get_size()); - addr = addr.offset(size as isize); + addr = &mut addr[size..]; } } ) @@ -296,7 +423,7 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { size } - unsafe fn write_payload(&self, mut addr: *mut u8) { + fn write_payload(&self, mut addr: &mut [u8]) { use crate::nlmsg::NfNetlinkAttribute; #(#write_entries) * @@ -312,7 +439,7 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { let ty = field.ty; let attrs = &field.attrs; let vis = &field.vis; - quote!( #(#attrs) * #vis #name: Option<#ty>, ) + quote_spanned!(name.span() => #(#attrs) * #vis #name: Option<#ty>, ) }); let nfnetlinkdeserialize_impl = if args.derive_deserialize { quote!( @@ -382,7 +509,7 @@ fn parse_enum_args(input: TokenStream) -> Result<EnumArgs> { abort!(&namevalue.lit.span(), "Expected a boolean"); } } - _ => abort!(key.span(), "Unsupported macro parameter"), + _ => abort!(arg.span(), "Unsupported macro parameter"), } } _ => abort!(arg.span(), "Unrecognized argument"), @@ -483,7 +610,7 @@ pub fn nfnetlink_enum(attrs: TokenStream, item: TokenStream) -> TokenStream { (*self as #repr_type).get_size() } - unsafe fn write_payload(&self, addr: *mut u8) { + fn write_payload(&self, addr: &mut [u8]) { (*self as #repr_type).write_payload(addr); } } diff --git a/src/batch.rs b/src/batch.rs index b5c88b8..980194b 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -33,6 +33,7 @@ impl Batch { pub fn new() -> Self { // TODO: use a pinned Box ? let mut buf = Box::new(Vec::with_capacity(default_batch_page_size() as usize)); + // Safe because we hold onto the buffer for as long as `writer` exists let mut writer = NfNetlinkWriter::new(unsafe { std::mem::transmute(Box::as_mut(&mut buf) as *mut Vec<u8>) }); diff --git a/src/chain.rs b/src/chain.rs index 37e4cb3..53ac595 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -5,8 +5,7 @@ use crate::error::{DecodeError, QueryError}; use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject}; use crate::sys::{ NFTA_CHAIN_FLAGS, NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_POLICY, NFTA_CHAIN_TABLE, - NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, - NFT_MSG_NEWCHAIN, + NFTA_CHAIN_TYPE, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN, }; use crate::{Batch, ProtocolFamily, Table}; use std::fmt::Debug; @@ -63,7 +62,7 @@ impl NfNetlinkAttribute for ChainPolicy { (*self as i32).get_size() } - unsafe fn write_payload(&self, addr: *mut u8) { + fn write_payload(&self, addr: &mut [u8]) { (*self as i32).write_payload(addr); } } @@ -111,7 +110,7 @@ impl NfNetlinkAttribute for ChainType { self.as_str().len() } - unsafe fn write_payload(&self, addr: *mut u8) { + fn write_payload(&self, addr: &mut [u8]) { self.as_str().to_string().write_payload(addr); } } @@ -135,8 +134,8 @@ impl NfNetlinkDeserializable for ChainType { /// /// [`Table`]: struct.Table.html /// [`Rule`]: struct.Rule.html -#[derive(PartialEq, Eq, Default, Debug)] #[nfnetlink_struct(derive_deserialize = false)] +#[derive(PartialEq, Eq, Default, Debug)] pub struct Chain { family: ProtocolFamily, #[field(NFTA_CHAIN_TABLE)] @@ -151,7 +150,7 @@ pub struct Chain { chain_type: ChainType, #[field(NFTA_CHAIN_FLAGS)] flags: u32, - #[field(NFTA_CHAIN_USERDATA)] + #[field(optional = true, crate::sys::NFTA_CHAIN_USERDATA)] userdata: Vec<u8>, } diff --git a/src/error.rs b/src/error.rs index f6b6247..80f06d7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -111,9 +111,6 @@ pub enum DecodeError { #[error("Invalid value for a protocol family")] UnknownProtocolFamily(i32), - - #[error("A custom error occured")] - Custom(Box<dyn std::error::Error + 'static>), } #[derive(thiserror::Error, Debug)] @@ -157,9 +154,6 @@ pub enum QueryError { #[error("Error received from the kernel")] NetlinkError(nlmsgerr), - #[error("Custom error when customizing the query")] - InitError(#[from] Box<dyn std::error::Error + Send + 'static>), - #[error("Couldn't allocate a netlink object, out of memory ?")] NetlinkAllocationFailed, diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 058b0cb..af29460 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -101,7 +101,7 @@ macro_rules! create_expr_variant { } } - unsafe fn write_payload(&self, addr: *mut u8) { + fn write_payload(&self, addr: &mut [u8]) { match self { $( $enum::$name(val) => val.write_payload(addr), @@ -194,7 +194,7 @@ impl NfNetlinkAttribute for ExpressionRaw { self.0.get_size() } - unsafe fn write_payload(&self, addr: *mut u8) { + fn write_payload(&self, addr: &mut [u8]) { self.0.write_payload(addr); } } diff --git a/src/expr/verdict.rs b/src/expr/verdict.rs index 7edf7cd..c42ad32 100644 --- a/src/expr/verdict.rs +++ b/src/expr/verdict.rs @@ -4,8 +4,7 @@ use libc::{NF_ACCEPT, NF_DROP, NF_QUEUE}; use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; use crate::sys::{ - NFTA_VERDICT_CHAIN, NFTA_VERDICT_CHAIN_ID, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE, - NFT_GOTO, NFT_JUMP, NFT_RETURN, + NFTA_VERDICT_CHAIN, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE, NFT_GOTO, NFT_JUMP, NFT_RETURN, }; #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] @@ -21,14 +20,14 @@ pub enum VerdictType { Return = NFT_RETURN, } -#[derive(Clone, PartialEq, Eq, Default, Debug)] #[nfnetlink_struct(nested = true)] +#[derive(Clone, PartialEq, Eq, Default, Debug)] pub struct Verdict { #[field(NFTA_VERDICT_CODE)] code: VerdictType, #[field(NFTA_VERDICT_CHAIN)] chain: String, - #[field(NFTA_VERDICT_CHAIN_ID)] + #[field(optional = true, crate::sys::NFTA_VERDICT_CHAIN_ID)] chain_id: u32, } diff --git a/src/nlmsg.rs b/src/nlmsg.rs index 1c5b519..b8fa857 100644 --- a/src/nlmsg.rs +++ b/src/nlmsg.rs @@ -39,6 +39,8 @@ pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 { pub struct NfNetlinkWriter<'a> { buf: &'a mut Vec<u8>, + // hold the position of the nlmsghdr and nfgenmsg structures for the object currently being + // written headers: Option<(usize, usize)>, } @@ -52,6 +54,7 @@ impl<'a> NfNetlinkWriter<'a> { let start = self.buf.len(); self.buf.resize(start + padded_size, 0); + // if we are *inside* an object begin written, extend the netlink object size if let Some((msghdr_idx, _nfgenmsg_idx)) = self.headers { let mut hdr: &mut nlmsghdr = unsafe { std::mem::transmute(self.buf[msghdr_idx..].as_mut_ptr() as *mut nlmsghdr) @@ -78,6 +81,7 @@ impl<'a> NfNetlinkWriter<'a> { let nlmsghdr_len = pad_netlink_object::<nlmsghdr>(); let nfgenmsg_len = pad_netlink_object::<nfgenmsg>(); + // serialize the nlmsghdr let nlmsghdr_buf = self.add_data_zeroed(nlmsghdr_len); let mut hdr: &mut nlmsghdr = unsafe { std::mem::transmute(nlmsghdr_buf.as_mut_ptr() as *mut nlmsghdr) }; @@ -90,6 +94,7 @@ impl<'a> NfNetlinkWriter<'a> { hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16 | flags; hdr.nlmsg_seq = seq; + // serialize the nfgenmsg let nfgenmsg_buf = self.add_data_zeroed(nfgenmsg_len); let mut nfgenmsg: &mut nfgenmsg = unsafe { std::mem::transmute(nfgenmsg_buf.as_mut_ptr() as *mut nfgenmsg) }; @@ -108,8 +113,10 @@ impl<'a> NfNetlinkWriter<'a> { } } +pub type NetlinkType = u16; + pub trait AttributeDecoder { - fn decode_attribute(&mut self, attr_type: u16, buf: &[u8]) -> Result<(), DecodeError>; + fn decode_attribute(&mut self, attr_type: NetlinkType, buf: &[u8]) -> Result<(), DecodeError>; } pub trait NfNetlinkDeserializable: Sized { @@ -139,9 +146,7 @@ pub trait NfNetlinkObject: None, ); let buf = writer.add_data_zeroed(self.get_size()); - unsafe { - self.write_payload(buf.as_mut_ptr()); - } + self.write_payload(buf); writer.finalize_writing_object(); } @@ -165,8 +170,6 @@ pub trait NfNetlinkObject: } } -pub type NetlinkType = u16; - pub trait NfNetlinkAttribute: Debug + Sized { // is it a nested argument that must be marked with a NLA_F_NESTED flag? fn is_nested(&self) -> bool { @@ -177,6 +180,6 @@ pub trait NfNetlinkAttribute: Debug + Sized { size_of::<Self>() } - // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr, self.get_size()); - unsafe fn write_payload(&self, addr: *mut u8); + // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr.as_mut_ptr(), self.get_size()); + fn write_payload(&self, addr: &mut [u8]); } diff --git a/src/parser.rs b/src/parser.rs index 6ea34c1..82dd27e 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -105,14 +105,10 @@ pub fn parse_nlmsg<'a>(buf: &'a [u8]) -> Result<(nlmsghdr, NlMsg<'a>), DecodeErr /// Write the attribute, preceded by a `libc::nlattr` // rewrite of `mnl_attr_put` -pub unsafe fn write_attribute<'a>( - ty: NetlinkType, - obj: &impl NfNetlinkAttribute, - mut buf: *mut u8, -) { - let header_len = pad_netlink_object::<libc::nlattr>(); +pub fn write_attribute<'a>(ty: NetlinkType, obj: &impl NfNetlinkAttribute, mut buf: &mut [u8]) { + let header_len = pad_netlink_object::<nlattr>(); // copy the header - *(buf as *mut nlattr) = nlattr { + let header = nlattr { // nla_len contains the header size + the unpadded attribute length nla_len: (header_len + obj.get_size() as usize) as u16, nla_type: if obj.is_nested() { @@ -121,7 +117,12 @@ pub unsafe fn write_attribute<'a>( ty }, }; - buf = buf.offset(pad_netlink_object::<nlattr>() as isize); + + unsafe { + *(buf.as_mut_ptr() as *mut nlattr) = header; + } + + buf = &mut buf[header_len..]; // copy the attribute data itself obj.write_payload(buf); } @@ -169,48 +170,30 @@ pub trait InnerFormat { ) -> Result<DebugStruct<'a, 'b>, std::fmt::Error>; } -pub trait Parsable -where - Self: Sized, -{ - fn parse_object( - buf: &[u8], - add_obj: u32, - del_obj: u32, - ) -> Result<(Self, nfgenmsg, &[u8]), DecodeError>; -} +pub(crate) fn parse_object<T: AttributeDecoder + Default + Sized>( + buf: &[u8], + add_obj: u32, + del_obj: u32, +) -> Result<(T, nfgenmsg, &[u8]), DecodeError> { + debug!("parse_object() started"); + let (hdr, msg) = parse_nlmsg(buf)?; -impl<T> Parsable for T -where - T: AttributeDecoder + Default + Sized, -{ - fn parse_object( - buf: &[u8], - add_obj: u32, - del_obj: u32, - ) -> Result<(Self, nfgenmsg, &[u8]), DecodeError> { - debug!("parse_object() started"); - let (hdr, msg) = parse_nlmsg(buf)?; - - let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as u32; - - if op != add_obj && op != del_obj { - return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)); - } + let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as u32; - let obj_size = hdr.nlmsg_len as usize - - pad_netlink_object_with_variable_size(size_of::<nlmsghdr>() + size_of::<nfgenmsg>()); + if op != add_obj && op != del_obj { + return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)); + } - let remaining_data_offset = pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize); - let remaining_data = &buf[remaining_data_offset..]; + let obj_size = hdr.nlmsg_len as usize + - pad_netlink_object_with_variable_size(size_of::<nlmsghdr>() + size_of::<nfgenmsg>()); - let (nfgenmsg, res) = match msg { - NlMsg::NfGenMsg(nfgenmsg, content) => { - (nfgenmsg, read_attributes(&content[..obj_size])?) - } - _ => return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)), - }; + let remaining_data_offset = pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize); + let remaining_data = &buf[remaining_data_offset..]; - Ok((res, nfgenmsg, remaining_data)) - } + let (nfgenmsg, res) = match msg { + NlMsg::NfGenMsg(nfgenmsg, content) => (nfgenmsg, read_attributes(&content[..obj_size])?), + _ => return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)), + }; + + Ok((res, nfgenmsg, remaining_data)) } diff --git a/src/parser_impls.rs b/src/parser_impls.rs index b2681bb..c49c876 100644 --- a/src/parser_impls.rs +++ b/src/parser_impls.rs @@ -1,4 +1,7 @@ -use std::{fmt::Debug, mem::transmute}; +use std::{ + fmt::Debug, + mem::{size_of, transmute}, +}; use rustables_macros::nfnetlink_struct; @@ -6,17 +9,17 @@ use crate::{ error::DecodeError, expr::Verdict, nlmsg::{ - pad_netlink_object, pad_netlink_object_with_variable_size, NfNetlinkAttribute, - NfNetlinkDeserializable, NfNetlinkObject, + pad_netlink_object, pad_netlink_object_with_variable_size, AttributeDecoder, + NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject, }, - parser::{write_attribute, Parsable}, + parser::{parse_object, write_attribute}, sys::{nlattr, NFTA_DATA_VALUE, NFTA_DATA_VERDICT, NFTA_LIST_ELEM, NLA_TYPE_MASK}, ProtocolFamily, }; impl NfNetlinkAttribute for u8 { - unsafe fn write_payload(&self, addr: *mut u8) { - *addr = *self; + fn write_payload(&self, addr: &mut [u8]) { + addr[0] = *self; } } @@ -27,8 +30,8 @@ impl NfNetlinkDeserializable for u8 { } impl NfNetlinkAttribute for u16 { - unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = self.to_be(); + fn write_payload(&self, addr: &mut [u8]) { + addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes()); } } @@ -39,8 +42,8 @@ impl NfNetlinkDeserializable for u16 { } impl NfNetlinkAttribute for i32 { - unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = self.to_be(); + fn write_payload(&self, addr: &mut [u8]) { + addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes()); } } @@ -54,8 +57,8 @@ impl NfNetlinkDeserializable for i32 { } impl NfNetlinkAttribute for u32 { - unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = self.to_be(); + fn write_payload(&self, addr: &mut [u8]) { + addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes()); } } @@ -69,8 +72,8 @@ impl NfNetlinkDeserializable for u32 { } impl NfNetlinkAttribute for u64 { - unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = self.to_be(); + fn write_payload(&self, addr: &mut [u8]) { + addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes()); } } @@ -90,8 +93,8 @@ impl NfNetlinkAttribute for String { self.len() } - unsafe fn write_payload(&self, addr: *mut u8) { - std::ptr::copy_nonoverlapping(self.as_bytes().as_ptr(), addr, self.len()); + fn write_payload(&self, addr: &mut [u8]) { + addr[0..self.len()].copy_from_slice(&self.as_bytes()); } } @@ -110,8 +113,8 @@ impl NfNetlinkAttribute for Vec<u8> { self.len() } - unsafe fn write_payload(&self, addr: *mut u8) { - std::ptr::copy_nonoverlapping(self.as_ptr(), addr, self.len()); + fn write_payload(&self, addr: &mut [u8]) { + addr[0..self.len()].copy_from_slice(&self.as_slice()); } } @@ -170,10 +173,11 @@ where }) } - unsafe fn write_payload(&self, mut addr: *mut u8) { + fn write_payload(&self, mut addr: &mut [u8]) { for item in &self.objs { write_attribute(NFTA_LIST_ELEM, item, addr); - addr = addr.offset((pad_netlink_object::<nlattr>() + item.get_size()) as isize); + let offset = pad_netlink_object::<nlattr>() + item.get_size(); + addr = &mut addr[offset..]; } } } @@ -228,10 +232,10 @@ where impl<T> NfNetlinkDeserializable for T where - T: NfNetlinkObject + Parsable, + T: NfNetlinkObject + AttributeDecoder + Default + Sized, { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - let (mut obj, nfgenmsg, remaining_data) = Self::parse_object( + fn deserialize(buf: &[u8]) -> Result<(T, &[u8]), DecodeError> { + let (mut obj, nfgenmsg, remaining_data) = parse_object::<T>( buf, <T as NfNetlinkObject>::MSG_TYPE_ADD, <T as NfNetlinkObject>::MSG_TYPE_DEL, diff --git a/src/query.rs b/src/query.rs index 7cf5050..3548d2a 100644 --- a/src/query.rs +++ b/src/query.rs @@ -59,7 +59,7 @@ pub(crate) fn recv_and_process<'a, T>( } // we cannot know when a sequence of messages will end if the messages do not end - // with an NlMsg::Done marker while if a maximum sequence number wasn't specified + // with an NlMsg::Done marker if a maximum sequence number wasn't specified if max_seq.is_none() && nlmsghdr.nlmsg_flags & NLM_F_MULTI as u16 == 0 { return Err(QueryError::UndecidableMessageTermination); } @@ -79,13 +79,7 @@ pub(crate) fn recv_and_process<'a, T>( // We achieve this by relocating the buffer content at the beginning of the buffer if end_pos >= nft_nlmsg_maxsize() as usize { if buf_start < end_pos { - unsafe { - std::ptr::copy( - msg_buffer[buf_start..end_pos].as_ptr(), - msg_buffer.as_mut_ptr(), - end_pos - buf_start, - ); - } + msg_buffer.copy_within(buf_start..end_pos, 0); } end_pos = end_pos - buf_start; buf_start = 0; @@ -128,9 +122,7 @@ pub fn get_list_of_objects<T: NfNetlinkAttribute>( ); if let Some(filter) = filter { let buf = writer.add_data_zeroed(filter.get_size()); - unsafe { - filter.write_payload(buf.as_mut_ptr()); - } + filter.write_payload(buf); } writer.finalize_writing_object(); Ok(buffer) diff --git a/src/table.rs b/src/table.rs index 81a26ef..1d19abe 100644 --- a/src/table.rs +++ b/src/table.rs @@ -5,7 +5,7 @@ use rustables_macros::nfnetlink_struct; use crate::error::QueryError; use crate::nlmsg::NfNetlinkObject; use crate::sys::{ - NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFTA_TABLE_USERDATA, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, + NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, NFT_MSG_NEWTABLE, }; use crate::{Batch, ProtocolFamily}; @@ -14,15 +14,15 @@ use crate::{Batch, ProtocolFamily}; /// family and contains [`Chain`]s that in turn hold the rules. /// /// [`Chain`]: struct.Chain.html -#[derive(Default, PartialEq, Eq, Debug)] #[nfnetlink_struct(derive_deserialize = false)] +#[derive(Default, PartialEq, Eq, Debug)] pub struct Table { family: ProtocolFamily, #[field(NFTA_TABLE_NAME)] name: String, #[field(NFTA_TABLE_FLAGS)] flags: u32, - #[field(NFTA_TABLE_USERDATA)] + #[field(optional = true, crate::sys::NFTA_TABLE_USERDATA)] userdata: Vec<u8>, } |