aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml4
-rw-r--r--flake.nix8
-rw-r--r--macros/Cargo.toml5
-rw-r--r--macros/src/lib.rs147
-rw-r--r--src/batch.rs1
-rw-r--r--src/chain.rs11
-rw-r--r--src/error.rs6
-rw-r--r--src/expr/mod.rs4
-rw-r--r--src/expr/verdict.rs7
-rw-r--r--src/nlmsg.rs19
-rw-r--r--src/parser.rs77
-rw-r--r--src/parser_impls.rs50
-rw-r--r--src/query.rs14
-rw-r--r--src/table.rs6
14 files changed, 232 insertions, 127 deletions
diff --git a/Cargo.toml b/Cargo.toml
index c72fa39..9c9a170 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "rustables"
-version = "0.8.0"
+version = "0.8.1-alpha1"
authors = ["lafleur@boum.org", "Simon Thoby", "Mullvad VPN"]
license = "GPL-3.0-or-later"
description = "Safe abstraction for libnftnl. Provides low-level userspace access to the in-kernel nf_tables subsystem"
@@ -20,7 +20,7 @@ log = "0.4"
libc = "0.2.43"
nix = "0.23"
ipnetwork = { version = "0.20", default-features = false }
-rustables-macros = "0.1.0"
+rustables-macros = { path = "macros", version = "0.1.1-alpha1" }
[dev-dependencies]
env_logger = "0.9"
diff --git a/flake.nix b/flake.nix
index 0ea7178..3f0a50c 100644
--- a/flake.nix
+++ b/flake.nix
@@ -14,11 +14,13 @@
channel = "1.66.0";
sha256 = "S7epLlflwt0d1GZP44u5Xosgf6dRrmr8xxC+Ml2Pq7c=";
};
+ rust = rustChannel.rust.override {
+ targets = [ "x86_64-unknown-linux-musl" ];
+ };
in
{
- inherit rustChannel;
- rustc = rustChannel.rust;
- cargo = rustChannel.rust;
+ rustc = rust;
+ cargo = rust;
}
);
rustDevOverlay = final: prev: {
diff --git a/macros/Cargo.toml b/macros/Cargo.toml
index 5d0f297..20c3b5f 100644
--- a/macros/Cargo.toml
+++ b/macros/Cargo.toml
@@ -1,7 +1,7 @@
[package]
name = "rustables-macros"
-version = "0.1.0"
-authors = ["Simon Thoby"]
+version = "0.1.1-alpha1"
+authors = ["lafleur@boum.org", "Simon Thoby"]
license = "GPL-3.0-or-later"
description = "Internal macros for generation netlink structures for the rustables project"
repository = "https://gitlab.com/rustwall/rustables"
@@ -16,3 +16,4 @@ syn = { version = "1.0", features = ["full"] }
quote = "1.0"
proc-macro2 = "1.0"
proc-macro-error = "1"
+once_cell = "1.1"
diff --git a/macros/src/lib.rs b/macros/src/lib.rs
index 39f0d01..af90dab 100644
--- a/macros/src/lib.rs
+++ b/macros/src/lib.rs
@@ -1,16 +1,58 @@
+#![allow(rustdoc::broken_intra_doc_links)]
+
+use std::fs::File;
+use std::io::Read;
+use std::path::PathBuf;
+
use proc_macro::TokenStream;
use proc_macro2::{Group, Span};
-use quote::quote;
+use quote::{quote, quote_spanned};
use proc_macro_error::{abort, proc_macro_error};
use syn::parse::Parser;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{
- parse, parse2, Attribute, Expr, ExprCast, Ident, ItemEnum, ItemStruct, Lit, Meta, Path, Result,
- Token, Type, TypePath, Visibility,
+ parse, parse2, Attribute, Expr, ExprCast, Ident, Item, ItemEnum, ItemStruct, Lit, Meta, Path,
+ Result, Token, Type, TypePath, Visibility,
};
+use once_cell::sync::OnceCell;
+
+struct GlobalState {
+ declared_identifiers: Vec<String>,
+}
+
+static STATE: OnceCell<GlobalState> = OnceCell::new();
+
+fn get_state() -> &'static GlobalState {
+ STATE.get_or_init(|| {
+ let sys_file = {
+ // Load the header file and extract the constants defined inside.
+ // This is what determines whether optional attributes (or enum variants)
+ // will be supported or not in the resulting binary.
+ let out_path = PathBuf::from(std::env::var("OUT_DIR").unwrap()).join("sys.rs");
+ let mut sys_file = String::new();
+ File::open(out_path)
+ .expect("Error: could not open the output header file")
+ .read_to_string(&mut sys_file)
+ .expect("Could not read the header file");
+ syn::parse_file(&sys_file).expect("Could not parse the header file")
+ };
+
+ let mut declared_identifiers = Vec::new();
+ for item in sys_file.items {
+ if let Item::Const(v) = item {
+ declared_identifiers.push(v.ident.to_string());
+ }
+ }
+
+ GlobalState {
+ declared_identifiers,
+ }
+ })
+}
+
struct Field<'a> {
name: &'a Ident,
ty: &'a Type,
@@ -24,6 +66,7 @@ struct Field<'a> {
struct FieldArgs {
netlink_type: Option<Path>,
override_function_name: Option<String>,
+ optional: bool,
}
fn parse_field_args(input: proc_macro2::TokenStream) -> Result<FieldArgs> {
@@ -57,7 +100,14 @@ fn parse_field_args(input: proc_macro2::TokenStream) -> Result<FieldArgs> {
abort!(&namevalue.lit.span(), "Expected a string literal");
}
}
- _ => abort!(key.span(), "Unsupported macro parameter"),
+ "optional" => {
+ if let Lit::Bool(boolean) = &namevalue.lit {
+ args.optional = boolean.value;
+ } else {
+ abort!(&namevalue.lit.span(), "Expected a boolean");
+ }
+ }
+ _ => abort!(arg.span(), "Unsupported macro parameter"),
}
}
_ => abort!(arg.span(), "Unrecognized argument"),
@@ -115,7 +165,7 @@ fn parse_struct_args(input: TokenStream) -> Result<StructArgs> {
abort!(&namevalue.lit.span(), "Expected a boolean");
}
}
- _ => abort!(key.span(), "Unsupported macro parameter"),
+ _ => abort!(arg.span(), "Unsupported macro parameter"),
}
} else {
abort!(arg.span(), "Unrecognized argument");
@@ -124,6 +174,66 @@ fn parse_struct_args(input: TokenStream) -> Result<StructArgs> {
Ok(args)
}
+/// `nfnetlink_struct` is a macro wrapping structures that describe nftables objects.
+/// It allows serializing and deserializing these objects to the corresponding nfnetlink
+/// attributes.
+///
+/// It automatically generates getter and setter functions for each netlink properties.
+///
+/// # Parameters
+/// The macro have multiple parameters:
+/// - `nested` (defaults to `false`): the structure is nested (in the netlink sense)
+/// inside its parent structure. This is the case of most structures outside
+/// of the main nftables objects (batches, sets, rules, chains and tables), which are
+/// the outermost structures, and as such cannot be nested.
+/// - `derive_decoder` (defaults to `true`): derive a [`rustables::nlmsg::AttributeDecoder`]
+/// implementation for the structure
+/// - `derive_deserialize` (defaults to `true`): derive a [`rustables::nlmsg::NfNetlinkDeserializable`]
+/// implementation for the structure
+///
+/// # Example use
+/// ```
+/// #[nfnetlink_struct(derive_deserialize = false)]
+/// #[derive(PartialEq, Eq, Default, Debug)]
+/// pub struct Chain {
+/// family: ProtocolFamily,
+/// #[field(NFTA_CHAIN_TABLE)]
+/// table: String,
+/// #[field(NFTA_CHAIN_TYPE, name_in_functions = "type")]
+/// chain_type: ChainType,
+/// #[field(optional = true, crate::sys::NFTA_CHAIN_USERDATA)]
+/// userdata: Vec<u8>,
+/// ...
+/// }
+/// ```
+///
+/// # Type of fields
+/// This contrived example show the two possible type of fields:
+/// - A field that is not converted to a netlink attribute (`family`) because it is not
+/// annotated in `#[field]` attribute.
+/// When deserialized, this field will take the value it is given in the Default implementation
+/// of the struct.
+/// - A field that is annotated with the `#[field]` attribute.
+/// That attribute takes parameters (there are none here), and the netlink attribute type.
+/// When annotated with that attribute, the macro will generate `get_<name>`, `set_<name>` and
+/// `with_<name>` methods to manipulate the attribute (e.g. `get_table`, `set_table` and
+/// `with_table`).
+/// It will also replace the field type (here `String`) with an Option (`Option<String>`)
+/// so the struct may represent objects where that attribute is not set.
+///
+/// # `#[field]` parameters
+/// The `#[field]` attribute can be parametrized through two options:
+/// - `optional` (defaults to `false`): if the netlink attribute type (here `NFTA_CHAIN_USERDATA`)
+/// does not exist, do not generate methods and ignore this attribute if encountered
+/// while deserializing a nftables object.
+/// This is useful for attributes added recently to the kernel, which may not be supported on
+/// older kernels.
+/// Support for an attribute is detected according to the existence of that attribute in the kernel
+/// headers.
+/// - `name_in_functions` (not defined by default): overwrite the `<name`> in the name of the methods
+/// `get_<name>`, `set_<name>` and `with_<name>`.
+/// Here, this means that even though the field is called `chain_type`, users can query it with
+/// the method `get_type` instead of `get_chain_type`.
#[proc_macro_error]
#[proc_macro_attribute]
pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream {
@@ -135,6 +245,8 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream {
Err(_) => abort!(Span::call_site(), "Could not parse the macro arguments"),
};
+ let state = get_state();
+
let mut fields = Vec::with_capacity(ast.fields.len());
let mut identical_fields = Vec::new();
@@ -149,6 +261,21 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream {
}
};
if let Some(netlink_type) = field_args.netlink_type.clone() {
+ // optional fields are not generated when the kernel version you have on
+ // the system does not support that field
+ if field_args.optional {
+ let netlink_type_ident = netlink_type
+ .segments
+ .last()
+ .expect("empty path?")
+ .ident
+ .to_string();
+ if !state.declared_identifiers.contains(&netlink_type_ident) {
+ // reject the optional identifier
+ continue 'out;
+ }
+ }
+
fields.push(Field {
name: field.ident.as_ref().expect("Should be a names struct"),
ty: &field.ty,
@@ -276,7 +403,7 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream {
{
let size = crate::nlmsg::pad_netlink_object::<crate::sys::nlattr>()
+ crate::nlmsg::pad_netlink_object_with_variable_size(val.get_size());
- addr = addr.offset(size as isize);
+ addr = &mut addr[size..];
}
}
)
@@ -296,7 +423,7 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream {
size
}
- unsafe fn write_payload(&self, mut addr: *mut u8) {
+ fn write_payload(&self, mut addr: &mut [u8]) {
use crate::nlmsg::NfNetlinkAttribute;
#(#write_entries) *
@@ -312,7 +439,7 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream {
let ty = field.ty;
let attrs = &field.attrs;
let vis = &field.vis;
- quote!( #(#attrs) * #vis #name: Option<#ty>, )
+ quote_spanned!(name.span() => #(#attrs) * #vis #name: Option<#ty>, )
});
let nfnetlinkdeserialize_impl = if args.derive_deserialize {
quote!(
@@ -382,7 +509,7 @@ fn parse_enum_args(input: TokenStream) -> Result<EnumArgs> {
abort!(&namevalue.lit.span(), "Expected a boolean");
}
}
- _ => abort!(key.span(), "Unsupported macro parameter"),
+ _ => abort!(arg.span(), "Unsupported macro parameter"),
}
}
_ => abort!(arg.span(), "Unrecognized argument"),
@@ -483,7 +610,7 @@ pub fn nfnetlink_enum(attrs: TokenStream, item: TokenStream) -> TokenStream {
(*self as #repr_type).get_size()
}
- unsafe fn write_payload(&self, addr: *mut u8) {
+ fn write_payload(&self, addr: &mut [u8]) {
(*self as #repr_type).write_payload(addr);
}
}
diff --git a/src/batch.rs b/src/batch.rs
index b5c88b8..980194b 100644
--- a/src/batch.rs
+++ b/src/batch.rs
@@ -33,6 +33,7 @@ impl Batch {
pub fn new() -> Self {
// TODO: use a pinned Box ?
let mut buf = Box::new(Vec::with_capacity(default_batch_page_size() as usize));
+ // Safe because we hold onto the buffer for as long as `writer` exists
let mut writer = NfNetlinkWriter::new(unsafe {
std::mem::transmute(Box::as_mut(&mut buf) as *mut Vec<u8>)
});
diff --git a/src/chain.rs b/src/chain.rs
index 37e4cb3..53ac595 100644
--- a/src/chain.rs
+++ b/src/chain.rs
@@ -5,8 +5,7 @@ use crate::error::{DecodeError, QueryError};
use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject};
use crate::sys::{
NFTA_CHAIN_FLAGS, NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_POLICY, NFTA_CHAIN_TABLE,
- NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN,
- NFT_MSG_NEWCHAIN,
+ NFTA_CHAIN_TYPE, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN,
};
use crate::{Batch, ProtocolFamily, Table};
use std::fmt::Debug;
@@ -63,7 +62,7 @@ impl NfNetlinkAttribute for ChainPolicy {
(*self as i32).get_size()
}
- unsafe fn write_payload(&self, addr: *mut u8) {
+ fn write_payload(&self, addr: &mut [u8]) {
(*self as i32).write_payload(addr);
}
}
@@ -111,7 +110,7 @@ impl NfNetlinkAttribute for ChainType {
self.as_str().len()
}
- unsafe fn write_payload(&self, addr: *mut u8) {
+ fn write_payload(&self, addr: &mut [u8]) {
self.as_str().to_string().write_payload(addr);
}
}
@@ -135,8 +134,8 @@ impl NfNetlinkDeserializable for ChainType {
///
/// [`Table`]: struct.Table.html
/// [`Rule`]: struct.Rule.html
-#[derive(PartialEq, Eq, Default, Debug)]
#[nfnetlink_struct(derive_deserialize = false)]
+#[derive(PartialEq, Eq, Default, Debug)]
pub struct Chain {
family: ProtocolFamily,
#[field(NFTA_CHAIN_TABLE)]
@@ -151,7 +150,7 @@ pub struct Chain {
chain_type: ChainType,
#[field(NFTA_CHAIN_FLAGS)]
flags: u32,
- #[field(NFTA_CHAIN_USERDATA)]
+ #[field(optional = true, crate::sys::NFTA_CHAIN_USERDATA)]
userdata: Vec<u8>,
}
diff --git a/src/error.rs b/src/error.rs
index f6b6247..80f06d7 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -111,9 +111,6 @@ pub enum DecodeError {
#[error("Invalid value for a protocol family")]
UnknownProtocolFamily(i32),
-
- #[error("A custom error occured")]
- Custom(Box<dyn std::error::Error + 'static>),
}
#[derive(thiserror::Error, Debug)]
@@ -157,9 +154,6 @@ pub enum QueryError {
#[error("Error received from the kernel")]
NetlinkError(nlmsgerr),
- #[error("Custom error when customizing the query")]
- InitError(#[from] Box<dyn std::error::Error + Send + 'static>),
-
#[error("Couldn't allocate a netlink object, out of memory ?")]
NetlinkAllocationFailed,
diff --git a/src/expr/mod.rs b/src/expr/mod.rs
index 058b0cb..af29460 100644
--- a/src/expr/mod.rs
+++ b/src/expr/mod.rs
@@ -101,7 +101,7 @@ macro_rules! create_expr_variant {
}
}
- unsafe fn write_payload(&self, addr: *mut u8) {
+ fn write_payload(&self, addr: &mut [u8]) {
match self {
$(
$enum::$name(val) => val.write_payload(addr),
@@ -194,7 +194,7 @@ impl NfNetlinkAttribute for ExpressionRaw {
self.0.get_size()
}
- unsafe fn write_payload(&self, addr: *mut u8) {
+ fn write_payload(&self, addr: &mut [u8]) {
self.0.write_payload(addr);
}
}
diff --git a/src/expr/verdict.rs b/src/expr/verdict.rs
index 7edf7cd..c42ad32 100644
--- a/src/expr/verdict.rs
+++ b/src/expr/verdict.rs
@@ -4,8 +4,7 @@ use libc::{NF_ACCEPT, NF_DROP, NF_QUEUE};
use rustables_macros::{nfnetlink_enum, nfnetlink_struct};
use crate::sys::{
- NFTA_VERDICT_CHAIN, NFTA_VERDICT_CHAIN_ID, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE,
- NFT_GOTO, NFT_JUMP, NFT_RETURN,
+ NFTA_VERDICT_CHAIN, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE, NFT_GOTO, NFT_JUMP, NFT_RETURN,
};
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
@@ -21,14 +20,14 @@ pub enum VerdictType {
Return = NFT_RETURN,
}
-#[derive(Clone, PartialEq, Eq, Default, Debug)]
#[nfnetlink_struct(nested = true)]
+#[derive(Clone, PartialEq, Eq, Default, Debug)]
pub struct Verdict {
#[field(NFTA_VERDICT_CODE)]
code: VerdictType,
#[field(NFTA_VERDICT_CHAIN)]
chain: String,
- #[field(NFTA_VERDICT_CHAIN_ID)]
+ #[field(optional = true, crate::sys::NFTA_VERDICT_CHAIN_ID)]
chain_id: u32,
}
diff --git a/src/nlmsg.rs b/src/nlmsg.rs
index 1c5b519..b8fa857 100644
--- a/src/nlmsg.rs
+++ b/src/nlmsg.rs
@@ -39,6 +39,8 @@ pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 {
pub struct NfNetlinkWriter<'a> {
buf: &'a mut Vec<u8>,
+ // hold the position of the nlmsghdr and nfgenmsg structures for the object currently being
+ // written
headers: Option<(usize, usize)>,
}
@@ -52,6 +54,7 @@ impl<'a> NfNetlinkWriter<'a> {
let start = self.buf.len();
self.buf.resize(start + padded_size, 0);
+ // if we are *inside* an object begin written, extend the netlink object size
if let Some((msghdr_idx, _nfgenmsg_idx)) = self.headers {
let mut hdr: &mut nlmsghdr = unsafe {
std::mem::transmute(self.buf[msghdr_idx..].as_mut_ptr() as *mut nlmsghdr)
@@ -78,6 +81,7 @@ impl<'a> NfNetlinkWriter<'a> {
let nlmsghdr_len = pad_netlink_object::<nlmsghdr>();
let nfgenmsg_len = pad_netlink_object::<nfgenmsg>();
+ // serialize the nlmsghdr
let nlmsghdr_buf = self.add_data_zeroed(nlmsghdr_len);
let mut hdr: &mut nlmsghdr =
unsafe { std::mem::transmute(nlmsghdr_buf.as_mut_ptr() as *mut nlmsghdr) };
@@ -90,6 +94,7 @@ impl<'a> NfNetlinkWriter<'a> {
hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16 | flags;
hdr.nlmsg_seq = seq;
+ // serialize the nfgenmsg
let nfgenmsg_buf = self.add_data_zeroed(nfgenmsg_len);
let mut nfgenmsg: &mut nfgenmsg =
unsafe { std::mem::transmute(nfgenmsg_buf.as_mut_ptr() as *mut nfgenmsg) };
@@ -108,8 +113,10 @@ impl<'a> NfNetlinkWriter<'a> {
}
}
+pub type NetlinkType = u16;
+
pub trait AttributeDecoder {
- fn decode_attribute(&mut self, attr_type: u16, buf: &[u8]) -> Result<(), DecodeError>;
+ fn decode_attribute(&mut self, attr_type: NetlinkType, buf: &[u8]) -> Result<(), DecodeError>;
}
pub trait NfNetlinkDeserializable: Sized {
@@ -139,9 +146,7 @@ pub trait NfNetlinkObject:
None,
);
let buf = writer.add_data_zeroed(self.get_size());
- unsafe {
- self.write_payload(buf.as_mut_ptr());
- }
+ self.write_payload(buf);
writer.finalize_writing_object();
}
@@ -165,8 +170,6 @@ pub trait NfNetlinkObject:
}
}
-pub type NetlinkType = u16;
-
pub trait NfNetlinkAttribute: Debug + Sized {
// is it a nested argument that must be marked with a NLA_F_NESTED flag?
fn is_nested(&self) -> bool {
@@ -177,6 +180,6 @@ pub trait NfNetlinkAttribute: Debug + Sized {
size_of::<Self>()
}
- // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr, self.get_size());
- unsafe fn write_payload(&self, addr: *mut u8);
+ // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr.as_mut_ptr(), self.get_size());
+ fn write_payload(&self, addr: &mut [u8]);
}
diff --git a/src/parser.rs b/src/parser.rs
index 6ea34c1..82dd27e 100644
--- a/src/parser.rs
+++ b/src/parser.rs
@@ -105,14 +105,10 @@ pub fn parse_nlmsg<'a>(buf: &'a [u8]) -> Result<(nlmsghdr, NlMsg<'a>), DecodeErr
/// Write the attribute, preceded by a `libc::nlattr`
// rewrite of `mnl_attr_put`
-pub unsafe fn write_attribute<'a>(
- ty: NetlinkType,
- obj: &impl NfNetlinkAttribute,
- mut buf: *mut u8,
-) {
- let header_len = pad_netlink_object::<libc::nlattr>();
+pub fn write_attribute<'a>(ty: NetlinkType, obj: &impl NfNetlinkAttribute, mut buf: &mut [u8]) {
+ let header_len = pad_netlink_object::<nlattr>();
// copy the header
- *(buf as *mut nlattr) = nlattr {
+ let header = nlattr {
// nla_len contains the header size + the unpadded attribute length
nla_len: (header_len + obj.get_size() as usize) as u16,
nla_type: if obj.is_nested() {
@@ -121,7 +117,12 @@ pub unsafe fn write_attribute<'a>(
ty
},
};
- buf = buf.offset(pad_netlink_object::<nlattr>() as isize);
+
+ unsafe {
+ *(buf.as_mut_ptr() as *mut nlattr) = header;
+ }
+
+ buf = &mut buf[header_len..];
// copy the attribute data itself
obj.write_payload(buf);
}
@@ -169,48 +170,30 @@ pub trait InnerFormat {
) -> Result<DebugStruct<'a, 'b>, std::fmt::Error>;
}
-pub trait Parsable
-where
- Self: Sized,
-{
- fn parse_object(
- buf: &[u8],
- add_obj: u32,
- del_obj: u32,
- ) -> Result<(Self, nfgenmsg, &[u8]), DecodeError>;
-}
+pub(crate) fn parse_object<T: AttributeDecoder + Default + Sized>(
+ buf: &[u8],
+ add_obj: u32,
+ del_obj: u32,
+) -> Result<(T, nfgenmsg, &[u8]), DecodeError> {
+ debug!("parse_object() started");
+ let (hdr, msg) = parse_nlmsg(buf)?;
-impl<T> Parsable for T
-where
- T: AttributeDecoder + Default + Sized,
-{
- fn parse_object(
- buf: &[u8],
- add_obj: u32,
- del_obj: u32,
- ) -> Result<(Self, nfgenmsg, &[u8]), DecodeError> {
- debug!("parse_object() started");
- let (hdr, msg) = parse_nlmsg(buf)?;
-
- let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as u32;
-
- if op != add_obj && op != del_obj {
- return Err(DecodeError::UnexpectedType(hdr.nlmsg_type));
- }
+ let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as u32;
- let obj_size = hdr.nlmsg_len as usize
- - pad_netlink_object_with_variable_size(size_of::<nlmsghdr>() + size_of::<nfgenmsg>());
+ if op != add_obj && op != del_obj {
+ return Err(DecodeError::UnexpectedType(hdr.nlmsg_type));
+ }
- let remaining_data_offset = pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize);
- let remaining_data = &buf[remaining_data_offset..];
+ let obj_size = hdr.nlmsg_len as usize
+ - pad_netlink_object_with_variable_size(size_of::<nlmsghdr>() + size_of::<nfgenmsg>());
- let (nfgenmsg, res) = match msg {
- NlMsg::NfGenMsg(nfgenmsg, content) => {
- (nfgenmsg, read_attributes(&content[..obj_size])?)
- }
- _ => return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)),
- };
+ let remaining_data_offset = pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize);
+ let remaining_data = &buf[remaining_data_offset..];
- Ok((res, nfgenmsg, remaining_data))
- }
+ let (nfgenmsg, res) = match msg {
+ NlMsg::NfGenMsg(nfgenmsg, content) => (nfgenmsg, read_attributes(&content[..obj_size])?),
+ _ => return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)),
+ };
+
+ Ok((res, nfgenmsg, remaining_data))
}
diff --git a/src/parser_impls.rs b/src/parser_impls.rs
index b2681bb..c49c876 100644
--- a/src/parser_impls.rs
+++ b/src/parser_impls.rs
@@ -1,4 +1,7 @@
-use std::{fmt::Debug, mem::transmute};
+use std::{
+ fmt::Debug,
+ mem::{size_of, transmute},
+};
use rustables_macros::nfnetlink_struct;
@@ -6,17 +9,17 @@ use crate::{
error::DecodeError,
expr::Verdict,
nlmsg::{
- pad_netlink_object, pad_netlink_object_with_variable_size, NfNetlinkAttribute,
- NfNetlinkDeserializable, NfNetlinkObject,
+ pad_netlink_object, pad_netlink_object_with_variable_size, AttributeDecoder,
+ NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject,
},
- parser::{write_attribute, Parsable},
+ parser::{parse_object, write_attribute},
sys::{nlattr, NFTA_DATA_VALUE, NFTA_DATA_VERDICT, NFTA_LIST_ELEM, NLA_TYPE_MASK},
ProtocolFamily,
};
impl NfNetlinkAttribute for u8 {
- unsafe fn write_payload(&self, addr: *mut u8) {
- *addr = *self;
+ fn write_payload(&self, addr: &mut [u8]) {
+ addr[0] = *self;
}
}
@@ -27,8 +30,8 @@ impl NfNetlinkDeserializable for u8 {
}
impl NfNetlinkAttribute for u16 {
- unsafe fn write_payload(&self, addr: *mut u8) {
- *(addr as *mut Self) = self.to_be();
+ fn write_payload(&self, addr: &mut [u8]) {
+ addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes());
}
}
@@ -39,8 +42,8 @@ impl NfNetlinkDeserializable for u16 {
}
impl NfNetlinkAttribute for i32 {
- unsafe fn write_payload(&self, addr: *mut u8) {
- *(addr as *mut Self) = self.to_be();
+ fn write_payload(&self, addr: &mut [u8]) {
+ addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes());
}
}
@@ -54,8 +57,8 @@ impl NfNetlinkDeserializable for i32 {
}
impl NfNetlinkAttribute for u32 {
- unsafe fn write_payload(&self, addr: *mut u8) {
- *(addr as *mut Self) = self.to_be();
+ fn write_payload(&self, addr: &mut [u8]) {
+ addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes());
}
}
@@ -69,8 +72,8 @@ impl NfNetlinkDeserializable for u32 {
}
impl NfNetlinkAttribute for u64 {
- unsafe fn write_payload(&self, addr: *mut u8) {
- *(addr as *mut Self) = self.to_be();
+ fn write_payload(&self, addr: &mut [u8]) {
+ addr[0..size_of::<Self>()].copy_from_slice(&self.to_be_bytes());
}
}
@@ -90,8 +93,8 @@ impl NfNetlinkAttribute for String {
self.len()
}
- unsafe fn write_payload(&self, addr: *mut u8) {
- std::ptr::copy_nonoverlapping(self.as_bytes().as_ptr(), addr, self.len());
+ fn write_payload(&self, addr: &mut [u8]) {
+ addr[0..self.len()].copy_from_slice(&self.as_bytes());
}
}
@@ -110,8 +113,8 @@ impl NfNetlinkAttribute for Vec<u8> {
self.len()
}
- unsafe fn write_payload(&self, addr: *mut u8) {
- std::ptr::copy_nonoverlapping(self.as_ptr(), addr, self.len());
+ fn write_payload(&self, addr: &mut [u8]) {
+ addr[0..self.len()].copy_from_slice(&self.as_slice());
}
}
@@ -170,10 +173,11 @@ where
})
}
- unsafe fn write_payload(&self, mut addr: *mut u8) {
+ fn write_payload(&self, mut addr: &mut [u8]) {
for item in &self.objs {
write_attribute(NFTA_LIST_ELEM, item, addr);
- addr = addr.offset((pad_netlink_object::<nlattr>() + item.get_size()) as isize);
+ let offset = pad_netlink_object::<nlattr>() + item.get_size();
+ addr = &mut addr[offset..];
}
}
}
@@ -228,10 +232,10 @@ where
impl<T> NfNetlinkDeserializable for T
where
- T: NfNetlinkObject + Parsable,
+ T: NfNetlinkObject + AttributeDecoder + Default + Sized,
{
- fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
- let (mut obj, nfgenmsg, remaining_data) = Self::parse_object(
+ fn deserialize(buf: &[u8]) -> Result<(T, &[u8]), DecodeError> {
+ let (mut obj, nfgenmsg, remaining_data) = parse_object::<T>(
buf,
<T as NfNetlinkObject>::MSG_TYPE_ADD,
<T as NfNetlinkObject>::MSG_TYPE_DEL,
diff --git a/src/query.rs b/src/query.rs
index 7cf5050..3548d2a 100644
--- a/src/query.rs
+++ b/src/query.rs
@@ -59,7 +59,7 @@ pub(crate) fn recv_and_process<'a, T>(
}
// we cannot know when a sequence of messages will end if the messages do not end
- // with an NlMsg::Done marker while if a maximum sequence number wasn't specified
+ // with an NlMsg::Done marker if a maximum sequence number wasn't specified
if max_seq.is_none() && nlmsghdr.nlmsg_flags & NLM_F_MULTI as u16 == 0 {
return Err(QueryError::UndecidableMessageTermination);
}
@@ -79,13 +79,7 @@ pub(crate) fn recv_and_process<'a, T>(
// We achieve this by relocating the buffer content at the beginning of the buffer
if end_pos >= nft_nlmsg_maxsize() as usize {
if buf_start < end_pos {
- unsafe {
- std::ptr::copy(
- msg_buffer[buf_start..end_pos].as_ptr(),
- msg_buffer.as_mut_ptr(),
- end_pos - buf_start,
- );
- }
+ msg_buffer.copy_within(buf_start..end_pos, 0);
}
end_pos = end_pos - buf_start;
buf_start = 0;
@@ -128,9 +122,7 @@ pub fn get_list_of_objects<T: NfNetlinkAttribute>(
);
if let Some(filter) = filter {
let buf = writer.add_data_zeroed(filter.get_size());
- unsafe {
- filter.write_payload(buf.as_mut_ptr());
- }
+ filter.write_payload(buf);
}
writer.finalize_writing_object();
Ok(buffer)
diff --git a/src/table.rs b/src/table.rs
index 81a26ef..1d19abe 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -5,7 +5,7 @@ use rustables_macros::nfnetlink_struct;
use crate::error::QueryError;
use crate::nlmsg::NfNetlinkObject;
use crate::sys::{
- NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFTA_TABLE_USERDATA, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE,
+ NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE,
NFT_MSG_NEWTABLE,
};
use crate::{Batch, ProtocolFamily};
@@ -14,15 +14,15 @@ use crate::{Batch, ProtocolFamily};
/// family and contains [`Chain`]s that in turn hold the rules.
///
/// [`Chain`]: struct.Chain.html
-#[derive(Default, PartialEq, Eq, Debug)]
#[nfnetlink_struct(derive_deserialize = false)]
+#[derive(Default, PartialEq, Eq, Debug)]
pub struct Table {
family: ProtocolFamily,
#[field(NFTA_TABLE_NAME)]
name: String,
#[field(NFTA_TABLE_FLAGS)]
flags: u32,
- #[field(NFTA_TABLE_USERDATA)]
+ #[field(optional = true, crate::sys::NFTA_TABLE_USERDATA)]
userdata: Vec<u8>,
}