diff options
author | Simon THOBY <git@nightmared.fr> | 2022-12-28 16:28:42 +0100 |
---|---|---|
committer | Simon THOBY <git@nightmared.fr> | 2023-01-08 13:46:02 +0100 |
commit | 9821456643bcb6a6a14e6b2a0d3895701f123d03 (patch) | |
tree | 094e023f855f4d153988da070079c3199dbcbe9d | |
parent | 603d540a45c968ad48071a73e1452f87abea377b (diff) |
Re-implement set support
-rw-r--r-- | Cargo.nix | 159 | ||||
-rw-r--r-- | README.md | 52 | ||||
-rw-r--r-- | examples/add-rules.rs | 51 | ||||
-rw-r--r-- | flake.nix | 4 | ||||
-rw-r--r-- | macros/src/lib.rs | 43 | ||||
-rw-r--r-- | src/batch.rs | 8 | ||||
-rw-r--r-- | src/chain.rs | 84 | ||||
-rw-r--r-- | src/data_type.rs | 35 | ||||
-rw-r--r-- | src/error.rs | 174 | ||||
-rw-r--r-- | src/expr/bitwise.rs | 17 | ||||
-rw-r--r-- | src/expr/cmp.rs | 18 | ||||
-rw-r--r-- | src/expr/immediate.rs | 31 | ||||
-rw-r--r-- | src/expr/lookup.rs | 94 | ||||
-rw-r--r-- | src/expr/meta.rs | 6 | ||||
-rw-r--r-- | src/expr/mod.rs | 165 | ||||
-rw-r--r-- | src/expr/payload.rs | 2 | ||||
-rw-r--r-- | src/expr/verdict.rs | 25 | ||||
-rw-r--r-- | src/lib.rs | 94 | ||||
-rw-r--r-- | src/nlmsg.rs | 135 | ||||
-rw-r--r-- | src/parser.rs | 280 | ||||
-rw-r--r-- | src/parser_impls.rs | 243 | ||||
-rw-r--r-- | src/query.rs | 84 | ||||
-rw-r--r-- | src/rule.rs | 73 | ||||
-rw-r--r-- | src/set.rs | 365 | ||||
-rw-r--r-- | src/table.rs | 49 | ||||
-rw-r--r-- | src/tests/batch.rs (renamed from tests/batch.rs) | 12 | ||||
-rw-r--r-- | src/tests/chain.rs (renamed from tests/chain.rs) | 10 | ||||
-rw-r--r-- | src/tests/expr.rs (renamed from tests/expr.rs) | 120 | ||||
-rw-r--r-- | src/tests/mod.rs (renamed from tests/common.rs) | 50 | ||||
-rw-r--r-- | src/tests/rule.rs (renamed from tests/rule.rs) | 10 | ||||
-rw-r--r-- | src/tests/set.rs | 122 | ||||
-rw-r--r-- | src/tests/table.rs (renamed from tests/table.rs) | 11 | ||||
-rw-r--r-- | tests/set.rs | 66 |
33 files changed, 1279 insertions, 1413 deletions
@@ -85,9 +85,9 @@ rec { crates = { "aho-corasick" = rec { crateName = "aho-corasick"; - version = "0.7.19"; + version = "0.7.20"; edition = "2018"; - sha256 = "0knl5n9f396068qk4zrvhcf01d5qp9ja2my4j7ywny093bcmpxdl"; + sha256 = "1b3if3nav4qzgjz9bf75b2cv2h2yisrqfs0np70i38kgz4cn94yc"; libName = "aho_corasick"; authors = [ "Andrew Gallant <jamslam@gmail.com>" @@ -295,10 +295,10 @@ rec { }; "cc" = rec { crateName = "cc"; - version = "1.0.74"; + version = "1.0.78"; edition = "2018"; crateBin = []; - sha256 = "0x0m14cizayy1ydiyvw76gl0wij8120w8ppb7zm55b1sj2x5s7sq"; + sha256 = "0gcch8g41jsjs4zk8fy7k4jhc33sfqdab4nxsrcsds2w6gi080d2"; authors = [ "Alex Crichton <alex@alexcrichton.com>" ]; @@ -548,9 +548,9 @@ rec { }; "glob" = rec { crateName = "glob"; - version = "0.3.0"; + version = "0.3.1"; edition = "2015"; - sha256 = "0x25wfr7vg3mzxc9x05dcphvd3nwlcmbnxrvwcvrrdwplcrrk4cv"; + sha256 = "16zca52nglanv23q5qrwd5jinw3d3as5ylya6y1pbx47vkxvrynj"; authors = [ "The Rust Project Developers" ]; @@ -649,9 +649,9 @@ rec { }; "libc" = rec { crateName = "libc"; - version = "0.2.137"; + version = "0.2.139"; edition = "2015"; - sha256 = "12dz2lk4a7lm03k079n2rkm1l6cpdhvy6nrngbfprzrv19icqzzw"; + sha256 = "0yaz3z56c72p2nfgv2y2zdi8bzi7x3kdq2hzgishgw0da8ky6790"; authors = [ "The Rust Project Developers" ]; @@ -751,9 +751,9 @@ rec { }; "nix" = rec { crateName = "nix"; - version = "0.23.1"; + version = "0.23.2"; edition = "2018"; - sha256 = "1iimixk7y2qk0jswqich4mkd8kqyzdghcgy6203j8fmxmhbn71lz"; + sha256 = "0p5kxhm5d8lry0szqbsllpcb5i3z7lg1dkglw0ni2l011b090dwg"; authors = [ "The nix-rust Project Developers" ]; @@ -829,11 +829,79 @@ rec { ]; }; + "proc-macro-error" = rec { + crateName = "proc-macro-error"; + version = "1.0.4"; + edition = "2018"; + sha256 = "1373bhxaf0pagd8zkyd03kkx6bchzf6g0dkwrwzsnal9z47lj9fs"; + authors = [ + "CreepySkeleton <creepy-skeleton@yandex.ru>" + ]; + dependencies = [ + { + name = "proc-macro-error-attr"; + packageId = "proc-macro-error-attr"; + } + { + name = "proc-macro2"; + packageId = "proc-macro2"; + } + { + name = "quote"; + packageId = "quote"; + } + { + name = "syn"; + packageId = "syn"; + optional = true; + usesDefaultFeatures = false; + } + ]; + buildDependencies = [ + { + name = "version_check"; + packageId = "version_check"; + } + ]; + features = { + "default" = [ "syn-error" ]; + "syn" = [ "dep:syn" ]; + "syn-error" = [ "syn" ]; + }; + resolvedDefaultFeatures = [ "default" "syn" "syn-error" ]; + }; + "proc-macro-error-attr" = rec { + crateName = "proc-macro-error-attr"; + version = "1.0.4"; + edition = "2018"; + sha256 = "0sgq6m5jfmasmwwy8x4mjygx5l7kp8s4j60bv25ckv2j1qc41gm1"; + procMacro = true; + authors = [ + "CreepySkeleton <creepy-skeleton@yandex.ru>" + ]; + dependencies = [ + { + name = "proc-macro2"; + packageId = "proc-macro2"; + } + { + name = "quote"; + packageId = "quote"; + } + ]; + buildDependencies = [ + { + name = "version_check"; + packageId = "version_check"; + } + ]; + + }; "proc-macro2" = rec { crateName = "proc-macro2"; - version = "1.0.47"; + version = "1.0.49"; edition = "2018"; - sha256 = "09g7alc7mlbycsadfh7lwskr1qfxbiic9qp9z751cqz3n04dk8sy"; + sha256 = "19b3xdfmnay9mchza82lhb3n8qjrfzkxwd23f50xxzy4z6lyra2p"; authors = [ "David Tolnay <dtolnay@gmail.com>" "Alex Crichton <alex@alexcrichton.com>" @@ -862,9 +930,9 @@ rec { }; "quote" = rec { crateName = "quote"; - version = "1.0.21"; + version = "1.0.23"; edition = "2018"; - sha256 = "0yai5cyd9h95n7hkwjcx8ig3yv0hindmz5gm60g9dmm7fzrlir5v"; + sha256 = "0ywwzw5xfwwgq62ihp4fbjbfdjb3ilss2vh3fka18ai59lvdhml8"; authors = [ "David Tolnay <dtolnay@gmail.com>" ]; @@ -883,9 +951,9 @@ rec { }; "regex" = rec { crateName = "regex"; - version = "1.6.0"; + version = "1.7.0"; edition = "2018"; - sha256 = "12wqvyh4i75j7pc8sgvmqh4yy3qaj4inc4alyv1cdf3lf4kb6kjc"; + sha256 = "12l6if07cb6fa6nigql90qrw0happnbnzqvr6jpg4hg2z2g5axp0"; authors = [ "The Rust Project Developers" ]; @@ -927,9 +995,9 @@ rec { }; "regex-syntax" = rec { crateName = "regex-syntax"; - version = "0.6.27"; + version = "0.6.28"; edition = "2018"; - sha256 = "0i32nnvyzzkvz1rqp2qyfxrp2170859z8ck37jd63c8irrrppy53"; + sha256 = "0j68z4jnxshfymb08j1drvxn9wgs1469047lfaq4im78wcxn0v25"; authors = [ "The Rust Project Developers" ]; @@ -942,7 +1010,7 @@ rec { "rustables" = rec { crateName = "rustables"; version = "0.7.0"; - edition = "2018"; + edition = "2021"; # We can't filter paths with references in Nix 2.4 # See https://github.com/NixOS/nix/issues/5410 src = if (lib.versionOlder builtins.nixVersion "2.4pre20211007") @@ -974,6 +1042,10 @@ rec { packageId = "nix"; } { + name = "rustables-macros"; + packageId = "rustables-macros"; + } + { name = "thiserror"; packageId = "thiserror"; } @@ -1000,6 +1072,37 @@ rec { ]; }; + "rustables-macros" = rec { + crateName = "rustables-macros"; + version = "0.1.0"; + edition = "2021"; + # We can't filter paths with references in Nix 2.4 + # See https://github.com/NixOS/nix/issues/5410 + src = if (lib.versionOlder builtins.nixVersion "2.4pre20211007") + then lib.cleanSourceWith { filter = sourceFilter; src = ./macros; } + else ./macros; + procMacro = true; + dependencies = [ + { + name = "proc-macro-error"; + packageId = "proc-macro-error"; + } + { + name = "proc-macro2"; + packageId = "proc-macro2"; + } + { + name = "quote"; + packageId = "quote"; + } + { + name = "syn"; + packageId = "syn"; + features = [ "full" ]; + } + ]; + + }; "rustc-hash" = rec { crateName = "rustc-hash"; version = "1.1.0"; @@ -1035,9 +1138,9 @@ rec { }; "syn" = rec { crateName = "syn"; - version = "1.0.103"; + version = "1.0.107"; edition = "2018"; - sha256 = "0pa4b6g938drphblgdhmjnzclp7gcbf4zdgkmfaxlfhk54i08r58"; + sha256 = "1xg3315vx8civ8y0l5zxq5mkx07qskaqwnjak18aw0vfn6sn8h0z"; authors = [ "David Tolnay <dtolnay@gmail.com>" ]; @@ -1065,7 +1168,7 @@ rec { "quote" = [ "dep:quote" ]; "test" = [ "syn-test-suite/all-features" ]; }; - resolvedDefaultFeatures = [ "clone-impls" "default" "derive" "parsing" "printing" "proc-macro" "quote" ]; + resolvedDefaultFeatures = [ "clone-impls" "default" "derive" "full" "parsing" "printing" "proc-macro" "quote" ]; }; "termcolor" = rec { crateName = "termcolor"; @@ -1105,9 +1208,9 @@ rec { }; "thiserror" = rec { crateName = "thiserror"; - version = "1.0.37"; + version = "1.0.38"; edition = "2018"; - sha256 = "0gky83x4i87gd87w3fknnp920wvk9yycp7dgkf5h3jg364vb7phh"; + sha256 = "1l7yh18iqcr2jnl6qjx3ywvhny98cvda3biwc334ap3xm65d373a"; authors = [ "David Tolnay <dtolnay@gmail.com>" ]; @@ -1121,9 +1224,9 @@ rec { }; "thiserror-impl" = rec { crateName = "thiserror-impl"; - version = "1.0.37"; + version = "1.0.38"; edition = "2018"; - sha256 = "1fydmpksd14x1mkc24zas01qjssz8q43sbn2ywl6n527dda1fbcq"; + sha256 = "0vzkcjqkzzgrwwby92xvnbp11a8d70b1gkybm0zx1r458spjgcqz"; procMacro = true; authors = [ "David Tolnay <dtolnay@gmail.com>" @@ -1146,9 +1249,9 @@ rec { }; "unicode-ident" = rec { crateName = "unicode-ident"; - version = "1.0.5"; + version = "1.0.6"; edition = "2018"; - sha256 = "1wznr6ax3jl09vxkvj4a62vip2avfgif13js9sflkjg4b6fv7skc"; + sha256 = "1g2fdsw5sv9l1m73whm99za3lxq3nw4gzx5kvi562h4b46gjp8l4"; authors = [ "David Tolnay <dtolnay@gmail.com>" ]; @@ -1,36 +1,23 @@ # rustables -Safe abstraction for [`libnftnl`]. Provides low-level userspace access to the -in-kernel nf_tables subsystem. See [`rustables-sys`] for the low level FFI -bindings to the C library. - -Can be used to create, list and remove tables, chains, sets and rules from the -nftables firewall, the successor to iptables. - -This library is directly derived from the [`nftnl-rs`] crate. Let us thank here -the original project team for their great work without which this library would -probably not exist today. - -It currently has quite rough edges and does not make adding and removing -netfilter entries super easy and elegant. That is partly because the library -needs more work, but also partly because nftables is super low level and -extremely customizable, making it hard, and probably wrong, to try and create a -too simple/limited wrapper. See examples for inspiration. One can also look -at how the original project this crate was developed to support uses it : -[Mullvad VPN app]. - -Understanding how to use [`libnftnl`] and implementing this crate has mostly -been done by reading the source code for the [`nftables`] program and attaching -debuggers to the `nft` binary. Since the implementation is mostly based on -trial and error, there might of course be a number of places where the -underlying library is used in an invalid or not intended way. Large portions -of [`libnftnl`] are also not covered yet. Contributions are welcome! - -## Supported versions of `libnftnl` - -This crate will automatically link to the currently installed version of -libnftnl upon build. It requires libnftnl version 1.0.6 or higher. See how the -low level FFI bindings to the C library are generated in [`build.rs`]. +Safe abstraction for userspace access to the in-kernel nf_tables subsystem. +Can be used to create and remove tables, chains, sets and rules from the nftables +firewall, the successor to iptables. + +This library is a fork of the [`nftnl-rs`] crate. Let us thank here the original project +team for their great work without which this library would probably not exist today. + +This library currently has quite rough edges and does not make adding and removing netfilter +entries super easy and elegant. That is partly because the library needs more work, but also +partly because nftables is super low level and extremely customizable, making it hard, and +probably wrong, to try and create a too simple/limited wrapper. See examples for inspiration. + +Understanding how to use the netlink subsystem and implementing this crate has mostly been done by +reading the source code for the [`nftables`] userspace program and its corresponding kernel code, +as well as attaching debuggers to the `nft` binary. +Since the implementation is mostly based on trial and error, there might of course be +a number of places where the forged netlink messages are used in an invalid or not intended way. +Contributions are welcome! ## Licensing @@ -39,8 +26,5 @@ License: GNU GPLv3 Original work licensed by Amagicom AB under MIT/Apache-2.0. [`nftnl-rs`]: https://github.com/mullvad/nftnl-rs -[Mullvad VPN app]: https://github.com/mullvad/mullvadvpn-app -[`libnftnl`]: https://netfilter.org/projects/libnftnl/ [`nftables`]: https://netfilter.org/projects/nftables/ -[`build.rs`]: https://gitlab.com/rustwall/rustables/-/blob/master/build.rs diff --git a/examples/add-rules.rs b/examples/add-rules.rs index 49e8b7b..b145291 100644 --- a/examples/add-rules.rs +++ b/examples/add-rules.rs @@ -38,7 +38,7 @@ use ipnetwork::{IpNetwork, Ipv4Network}; use rustables::{ - expr::{ExpressionList, Immediate, VerdictKind}, + expr::{Cmp, CmpOp, ExpressionList, Immediate, Meta, MetaType, Verdict, VerdictKind}, list_chains_for_table, list_rules_for_chain, list_tables, Batch, Chain, ChainPolicy, Hook, HookClass, MsgType, ProtocolFamily, Rule, Table, }; @@ -82,7 +82,12 @@ fn main() -> Result<(), Error> { batch.add(&in_chain, MsgType::Add); let rule = Rule::new(&in_chain)?.with_expressions( - ExpressionList::default().with_expression(Immediate::new_verdict(VerdictKind::Accept)), + ExpressionList::default().with_value(Immediate::new_verdict(VerdictKind::Accept)), + ); + batch.add(&rule, MsgType::Add); + + let rule = Rule::new(&in_chain)?.with_expressions( + ExpressionList::default().with_value(Immediate::new_verdict(VerdictKind::Continue)), ); batch.add(&rule, MsgType::Add); @@ -94,26 +99,28 @@ fn main() -> Result<(), Error> { // Lookup the interface index of the loopback interface. let lo_iface_index = iface_index("lo")?; - // First expression to be evaluated in this rule is load the meta information "iif" - // (incoming interface index) into the comparison register of netfilter. - // When an incoming network packet is processed by this rule it will first be processed by this - // expression, which will load the interface index of the interface the packet came from into - // a special "register" in netfilter. - //allow_loopback_in_rule.set_expressions(ExpressionList::builder().with_expression()); - //add_expr(&nft_expr!(meta iif)); - // // Next expression in the rule is to compare the value loaded into the register with our desired - // // interface index, and succeed only if it's equal. For any packet processed where the equality - // // does not hold the packet is said to not match this rule, and the packet moves on to be - // // processed by the next rule in the chain instead. - // allow_loopback_in_rule.add_expr(&nft_expr!(cmp == lo_iface_index)); - // - // // Add a verdict expression to the rule. Any packet getting this far in the expression - // // processing without failing any expression will be given the verdict added here. - // allow_loopback_in_rule.add_expr(&nft_expr!(verdict accept)); - // - // // Add the rule to the batch. - // batch.add(&allow_loopback_in_rule, rustables::MsgType::Add); - // + allow_loopback_in_rule.set_expressions( + ExpressionList::default() + // First expression to be evaluated in this rule is load the meta information "iif" + // (incoming interface index) into the comparison register of netfilter. + // When an incoming network packet is processed by this rule it will first be processed by this + // expression, which will load the interface index of the interface the packet came from into + // a special "register" in netfilter. + .with_value(Meta::new(MetaType::Iif)) + // Next expression in the rule is to compare the value loaded into the register with our desired + // interface index, and succeed only if it's equal. For any packet processed where the equality + // does not hold the packet is said to not match this rule, and the packet moves on to be + // processed by the next rule in the chain instead. + .with_value(Cmp::new(CmpOp::Eq, lo_iface_index.to_le_bytes())) + + // Add a verdict expression to the rule. Any packet getting this far in the expression + // processing without failing any expression will be given the verdict added here. + .with_value(Immediate::new_verdict(VerdictKind::Accept)), + ); + + // Add the rule to the batch. + batch.add(&allow_loopback_in_rule, rustables::MsgType::Add); + // // === ADD A RULE ALLOWING (AND COUNTING) ALL PACKETS TO THE 10.1.0.0/24 NETWORK === // // let mut block_out_to_private_net_rule = Rule::new(Rc::clone(&out_chain)); @@ -11,8 +11,8 @@ rustOverlay = (final: prev: let rustChannel = prev.rustChannelOf { - channel = "1.65.0"; - sha256 = "DzNEaW724O8/B8844tt5AVHmSjSQ3cmzlU4BP90oRlY="; + channel = "1.66.0"; + sha256 = "S7epLlflwt0d1GZP44u5Xosgf6dRrmr8xxC+Ml2Pq7c="; }; in { diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 11aedaf..9170e82 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -220,7 +220,7 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { debug!("Calling {}::deserialize()", std::any::type_name::<#field_type>()); let (val, remaining) = <#field_type>::deserialize(buf)?; if remaining.len() != 0 { - return Err(crate::parser::DecodeError::InvalidDataSize); + return Err(crate::error::DecodeError::InvalidDataSize); } self.#field_name = Some(val); Ok(()) @@ -230,12 +230,12 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { quote!( impl crate::nlmsg::AttributeDecoder for #name { #[allow(dead_code)] - fn decode_attribute(&mut self, attr_type: u16, buf: &[u8]) -> Result<(), crate::parser::DecodeError> { + fn decode_attribute(&mut self, attr_type: u16, buf: &[u8]) -> Result<(), crate::error::DecodeError> { use crate::nlmsg::NfNetlinkDeserializable; debug!("Decoding attribute {} in type {}", attr_type, std::any::type_name::<#name>()); match attr_type { #(#match_entries),* - _ => Err(crate::parser::DecodeError::UnsupportedAttributeType(attr_type)), + _ => Err(crate::error::DecodeError::UnsupportedAttributeType(attr_type)), } } } @@ -250,8 +250,8 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { quote!( if let Some(val) = &self.#field_name { // Attribute header + attribute value - size += crate::parser::pad_netlink_object::<crate::sys::nlattr>() - + crate::parser::pad_netlink_object_with_variable_size(val.get_size()); + size += crate::nlmsg::pad_netlink_object::<crate::sys::nlattr>() + + crate::nlmsg::pad_netlink_object_with_variable_size(val.get_size()); } ) }); @@ -267,8 +267,8 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { #[allow(unused)] { - let size = crate::parser::pad_netlink_object::<crate::sys::nlattr>() - + crate::parser::pad_netlink_object_with_variable_size(val.get_size()); + let size = crate::nlmsg::pad_netlink_object::<crate::sys::nlattr>() + + crate::nlmsg::pad_netlink_object_with_variable_size(val.get_size()); addr = addr.offset(size as isize); } } @@ -310,7 +310,7 @@ pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { let nfnetlinkdeserialize_impl = if args.derive_deserialize { quote!( impl crate::nlmsg::NfNetlinkDeserializable for #name { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), crate::parser::DecodeError> { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), crate::error::DecodeError> { Ok((crate::parser::read_attributes(buf)?, &[])) } } @@ -424,20 +424,26 @@ pub fn nfnetlink_enum(attrs: TokenStream, item: TokenStream) -> TokenStream { let match_entries = variants.iter().map(|variant| { let variant_name = variant.name; let variant_value = &variant.value; - quote!( x if x == (#variant_value as #repr_type) => Self::#variant_name, ) + quote!( x if x == (#variant_value as #repr_type) => Ok(Self::#variant_name), ) }); let unknown_type_ident = Ident::new(&format!("Unknown{}", name.to_string()), name.span()); + let tryfrom_impl = quote!( + impl ::core::convert::TryFrom<#repr_type> for #name { + type Error = crate::error::DecodeError; + + fn try_from(val: #repr_type) -> Result<Self, Self::Error> { + match val { + #(#match_entries) * + value => Err(crate::error::DecodeError::#unknown_type_ident(value)) + } + } + } + ); let nfnetlinkdeserialize_impl = quote!( impl crate::nlmsg::NfNetlinkDeserializable for #name { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), crate::parser::DecodeError> { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), crate::error::DecodeError> { let (v, remaining_data) = #repr_type::deserialize(buf)?; - Ok(( - match v { - #(#match_entries) * - value => return Err(crate::parser::DecodeError::#unknown_type_ident(value)) - }, - remaining_data, - )) + <#name>::try_from(v).map(|x| (x, remaining_data)) } } ); @@ -475,8 +481,9 @@ pub fn nfnetlink_enum(attrs: TokenStream, item: TokenStream) -> TokenStream { } } - #nfnetlinkdeserialize_impl + #tryfrom_impl + #nfnetlinkdeserialize_impl }; res.into() diff --git a/src/batch.rs b/src/batch.rs index d885813..b5c88b8 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -2,11 +2,11 @@ use libc; use thiserror::Error; +use crate::error::QueryError; use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter}; use crate::sys::NFNL_SUBSYS_NFTABLES; use crate::{MsgType, ProtocolFamily}; -use crate::query::Error; use nix::sys::socket::{ self, AddressFamily, MsgFlags, NetlinkAddr, SockAddr, SockFlag, SockProtocol, SockType, }; @@ -88,7 +88,7 @@ impl Batch { *self.buf } - pub fn send(mut self) -> Result<(), Error> { + pub fn send(self) -> Result<(), QueryError> { use crate::query::{recv_and_process, socket_close_wrapper}; let sock = socket::socket( @@ -97,7 +97,7 @@ impl Batch { SockFlag::empty(), SockProtocol::NetlinkNetFilter, ) - .map_err(Error::NetlinkOpenError)?; + .map_err(QueryError::NetlinkOpenError)?; let max_seq = self.seq - 1; @@ -110,7 +110,7 @@ impl Batch { let mut sent = 0; while sent != to_send.len() { sent += socket::send(sock, &to_send[sent..], MsgFlags::empty()) - .map_err(Error::NetlinkSendError)?; + .map_err(QueryError::NetlinkSendError)?; } Ok(socket_close_wrapper(sock, move |sock| { diff --git a/src/chain.rs b/src/chain.rs index 7a62fb2..0ce0ad8 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -1,15 +1,14 @@ use libc::{NF_ACCEPT, NF_DROP}; use rustables_macros::nfnetlink_struct; -use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkWriter}; -use crate::parser::{DecodeError, Parsable}; +use crate::error::{DecodeError, QueryError}; +use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject}; use crate::sys::{ NFTA_CHAIN_FLAGS, NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_POLICY, NFTA_CHAIN_TABLE, NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, - NFT_MSG_NEWCHAIN, NLM_F_ACK, NLM_F_CREATE, + NFT_MSG_NEWCHAIN, }; -use crate::{MsgType, ProtocolFamily, Table}; -use std::convert::TryFrom; +use crate::{ProtocolFamily, Table}; use std::fmt::Debug; pub type ChainPriority = i32; @@ -132,14 +131,10 @@ impl NfNetlinkDeserializable for ChainType { } } -/// Abstraction of a `nftnl_chain`. Chains reside inside [`Table`]s and they hold [`Rule`]s. -/// -/// There are two types of chains, "base chain" and "regular chain". See [`set_hook`] for more -/// details. +/// Abstraction over an nftable chain. Chains reside inside [`Table`]s and they hold [`Rule`]s. /// /// [`Table`]: struct.Table.html /// [`Rule`]: struct.Rule.html -/// [`set_hook`]: #method.set_hook #[derive(PartialEq, Eq, Default, Debug)] #[nfnetlink_struct(derive_deserialize = false)] pub struct Chain { @@ -166,7 +161,7 @@ impl Chain { /// [`Table`]: struct.Table.html pub fn new(table: &Table) -> Chain { let mut chain = Chain::default(); - chain.family = table.family; + chain.family = table.get_family(); if let Some(table_name) = table.get_name() { chain.set_table(table_name); @@ -174,73 +169,22 @@ impl Chain { chain } - - pub fn get_family(&self) -> ProtocolFamily { - self.family - } - - /* - /// Returns a textual description of the chain. - pub fn get_str(&self) -> CString { - let mut descr_buf = vec![0i8; 4096]; - unsafe { - sys::nftnl_chain_snprintf( - descr_buf.as_mut_ptr() as *mut c_char, - (descr_buf.len() - 1) as u64, - self.chain, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); - CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned() - } - } - */ } -/* -impl PartialEq for Chain { - fn eq(&self, other: &Self) -> bool { - self.get_table() == other.get_table() && self.get_name() == other.get_name() - } -} -*/ - impl NfNetlinkObject for Chain { - fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32) { - let raw_msg_type = match msg_type { - MsgType::Add => NFT_MSG_NEWCHAIN, - MsgType::Del => NFT_MSG_DELCHAIN, - } as u16; - writer.write_header( - raw_msg_type, - self.family, - (if let MsgType::Add = msg_type { - NLM_F_CREATE - } else { - 0 - } | NLM_F_ACK) as u16, - seq, - None, - ); - let buf = writer.add_data_zeroed(self.get_size()); - unsafe { - self.write_payload(buf.as_mut_ptr()); - } - writer.finalize_writing_object(); - } -} + const MSG_TYPE_ADD: u32 = NFT_MSG_NEWCHAIN; + const MSG_TYPE_DEL: u32 = NFT_MSG_DELCHAIN; -impl NfNetlinkDeserializable for Chain { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - let (mut obj, nfgenmsg, remaining_data) = - Self::parse_object(buf, NFT_MSG_NEWCHAIN, NFT_MSG_DELCHAIN)?; - obj.family = ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?; + fn get_family(&self) -> ProtocolFamily { + self.family + } - Ok((obj, remaining_data)) + fn set_family(&mut self, family: ProtocolFamily) { + self.family = family; } } -pub fn list_chains_for_table(table: &Table) -> Result<Vec<Chain>, crate::query::Error> { +pub fn list_chains_for_table(table: &Table) -> Result<Vec<Chain>, QueryError> { let mut result = Vec::new(); crate::query::list_objects_with_data( libc::NFT_MSG_GETCHAIN as u16, diff --git a/src/data_type.rs b/src/data_type.rs new file mode 100644 index 0000000..f9c97cb --- /dev/null +++ b/src/data_type.rs @@ -0,0 +1,35 @@ +use std::net::{Ipv4Addr, Ipv6Addr}; + +pub trait DataType { + const TYPE: u32; + const LEN: u32; + + fn data(&self) -> Vec<u8>; +} + +impl DataType for Ipv4Addr { + const TYPE: u32 = 7; + const LEN: u32 = 4; + + fn data(&self) -> Vec<u8> { + self.octets().to_vec() + } +} + +impl DataType for Ipv6Addr { + const TYPE: u32 = 8; + const LEN: u32 = 16; + + fn data(&self) -> Vec<u8> { + self.octets().to_vec() + } +} + +impl<const N: usize> DataType for [u8; N] { + const TYPE: u32 = 5; + const LEN: u32 = N as u32; + + fn data(&self) -> Vec<u8> { + self.to_vec() + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..eae6898 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,174 @@ +use std::string::FromUtf8Error; + +use nix::errno::Errno; +use thiserror::Error; + +use crate::sys::nlmsgerr; + +#[derive(Error, Debug)] +pub enum DecodeError { + #[error("The buffer is too small to hold a valid message")] + BufTooSmall, + + #[error("The message is too small")] + NlMsgTooSmall, + + #[error("The message holds unexpected data")] + InvalidDataSize, + + #[error("Invalid subsystem, expected NFTABLES")] + InvalidSubsystem(u8), + + #[error("Invalid version, expected NFNETLINK_V0")] + InvalidVersion(u8), + + #[error("Invalid port ID")] + InvalidPortId(u32), + + #[error("Invalid sequence number")] + InvalidSeq(u32), + + #[error("The generation number was bumped in the kernel while the operation was running, interrupting it")] + ConcurrentGenerationUpdate, + + #[error("Unsupported message type")] + UnsupportedType(u16), + + #[error("Invalid attribute type")] + InvalidAttributeType, + + #[error("Invalid type for a chain")] + UnknownChainType, + + #[error("Invalid policy for a chain")] + UnknownChainPolicy, + + #[error("Unknown type for a Meta expression")] + UnknownMetaType(u32), + + #[error("Unsupported value for an icmp reject type")] + UnknownRejectType(u32), + + #[error("Unsupported value for an icmp code in a reject expression")] + UnknownIcmpCode(u8), + + #[error("Invalid value for a register")] + UnknownRegister(u32), + + #[error("Invalid type for a verdict expression")] + UnknownVerdictType(i32), + + #[error("Invalid type for a nat expression")] + UnknownNatType(i32), + + #[error("Invalid type for a payload expression")] + UnknownPayloadType(u32), + + #[error("Invalid type for a compare expression")] + UnknownCmpOp(u32), + + #[error("Invalid type for a conntrack key")] + UnknownConntrackKey(u32), + + #[error("Unsupported value for a link layer header field")] + UnknownLinkLayerHeaderField(u32, u32), + + #[error("Unsupported value for an IPv4 header field")] + UnknownIPv4HeaderField(u32, u32), + + #[error("Unsupported value for an IPv6 header field")] + UnknownIPv6HeaderField(u32, u32), + + #[error("Unsupported value for a TCP header field")] + UnknownTCPHeaderField(u32, u32), + + #[error("Unsupported value for an UDP header field")] + UnknownUDPHeaderField(u32, u32), + + #[error("Unsupported value for an ICMPv6 header field")] + UnknownICMPv6HeaderField(u32, u32), + + #[error("Missing the 'base' attribute to deserialize the payload object")] + PayloadMissingBase, + + #[error("Missing the 'offset' attribute to deserialize the payload object")] + PayloadMissingOffset, + + #[error("Missing the 'len' attribute to deserialize the payload object")] + PayloadMissingLen, + + #[error("The object does not contain a name for the expression being parsed")] + MissingExpressionName, + + #[error("Unsupported attribute type")] + UnsupportedAttributeType(u16), + + #[error("Unexpected message type")] + UnexpectedType(u16), + + #[error("The decoded String is not UTF8 compliant")] + StringDecodeFailure(#[from] FromUtf8Error), + + #[error("Invalid value for a protocol family")] + UnknownProtocolFamily(i32), + + #[error("A custom error occured")] + Custom(Box<dyn std::error::Error + 'static>), +} + +#[derive(thiserror::Error, Debug)] +pub enum BuilderError { + #[error("The length of the arguments are not compatible with each other")] + IncompatibleLength, + + #[error("The table does not have a name")] + MissingTableName, + + #[error("Missing information in the chain to create a rule")] + MissingChainInformationError, + + #[error("Missing name for the set")] + MissingSetName, +} + +#[derive(thiserror::Error, Debug)] +pub enum QueryError { + #[error("Unable to open netlink socket to netfilter")] + NetlinkOpenError(#[source] nix::Error), + + #[error("Unable to send netlink command to netfilter")] + NetlinkSendError(#[source] nix::Error), + + #[error("Error while reading from netlink socket")] + NetlinkRecvError(#[source] nix::Error), + + #[error("Error while processing an incoming netlink message")] + ProcessNetlinkError(#[from] DecodeError), + + #[error("Error while building netlink objects in Rust")] + BuilderError(#[from] BuilderError), + + #[error("Error received from the kernel")] + NetlinkError(nlmsgerr), + + #[error("Custom error when customizing the query")] + InitError(#[from] Box<dyn std::error::Error + Send + 'static>), + + #[error("Couldn't allocate a netlink object, out of memory ?")] + NetlinkAllocationFailed, + + #[error("This socket is not a netlink socket")] + NotNetlinkSocket, + + #[error("Couldn't retrieve information on a socket")] + RetrievingSocketInfoFailed, + + #[error("Only a part of the message was sent")] + TruncatedSend, + + #[error("Got a message without the NLM_F_MULTI flag, but a maximum sequence number was not specified")] + UndecidableMessageTermination, + + #[error("Couldn't close the socket")] + CloseFailed(#[source] Errno), +} diff --git a/src/expr/bitwise.rs b/src/expr/bitwise.rs index 29d2d63..fb40a04 100644 --- a/src/expr/bitwise.rs +++ b/src/expr/bitwise.rs @@ -1,7 +1,8 @@ use rustables_macros::nfnetlink_struct; -use super::{Expression, ExpressionData, Register}; -use crate::parser::DecodeError; +use super::{Expression, Register}; +use crate::error::BuilderError; +use crate::parser_impls::NfNetlinkData; use crate::sys::{ NFTA_BITWISE_DREG, NFTA_BITWISE_LEN, NFTA_BITWISE_MASK, NFTA_BITWISE_SREG, NFTA_BITWISE_XOR, }; @@ -16,9 +17,9 @@ pub struct Bitwise { #[field(NFTA_BITWISE_LEN)] len: u32, #[field(NFTA_BITWISE_MASK)] - mask: ExpressionData, + mask: NfNetlinkData, #[field(NFTA_BITWISE_XOR)] - xor: ExpressionData, + xor: NfNetlinkData, } impl Expression for Bitwise { @@ -30,17 +31,17 @@ impl Expression for Bitwise { impl Bitwise { /// Returns a new `Bitwise` instance that first masks the value it's applied to with `mask` and /// then performs xor with the value in `xor` - pub fn new(mask: impl Into<Vec<u8>>, xor: impl Into<Vec<u8>>) -> Result<Self, DecodeError> { + pub fn new(mask: impl Into<Vec<u8>>, xor: impl Into<Vec<u8>>) -> Result<Self, BuilderError> { let mask = mask.into(); let xor = xor.into(); if mask.len() != xor.len() { - return Err(DecodeError::IncompatibleLength); + return Err(BuilderError::IncompatibleLength); } Ok(Bitwise::default() .with_sreg(Register::Reg1) .with_dreg(Register::Reg1) .with_len(mask.len() as u32) - .with_xor(ExpressionData::default().with_value(xor)) - .with_mask(ExpressionData::default().with_value(mask))) + .with_xor(NfNetlinkData::default().with_value(xor)) + .with_mask(NfNetlinkData::default().with_value(mask))) } } diff --git a/src/expr/cmp.rs b/src/expr/cmp.rs index d69f73c..223902f 100644 --- a/src/expr/cmp.rs +++ b/src/expr/cmp.rs @@ -1,11 +1,15 @@ use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; -use crate::sys::{ - NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFT_CMP_EQ, NFT_CMP_GT, NFT_CMP_GTE, NFT_CMP_LT, - NFT_CMP_LTE, NFT_CMP_NEQ, +use crate::{ + data_type::DataType, + parser_impls::NfNetlinkData, + sys::{ + NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFT_CMP_EQ, NFT_CMP_GT, NFT_CMP_GTE, NFT_CMP_LT, + NFT_CMP_LTE, NFT_CMP_NEQ, + }, }; -use super::{Expression, ExpressionData, Register}; +use super::{Expression, Register}; /// Comparison operator. #[derive(Debug, Copy, Clone, Eq, PartialEq)] @@ -34,17 +38,17 @@ pub struct Cmp { #[field(NFTA_CMP_OP)] op: CmpOp, #[field(NFTA_CMP_DATA)] - data: ExpressionData, + data: NfNetlinkData, } impl Cmp { /// Returns a new comparison expression comparing the value loaded in the register with the /// data in `data` using the comparison operator `op`. - pub fn new(op: CmpOp, data: impl Into<Vec<u8>>) -> Self { + pub fn new(op: CmpOp, data: impl DataType) -> Self { Cmp { sreg: Some(Register::Reg1), op: Some(op), - data: Some(ExpressionData::default().with_value(data)), + data: Some(NfNetlinkData::default().with_value(data.data())), } } } diff --git a/src/expr/immediate.rs b/src/expr/immediate.rs index 134f7e1..2fd9bd5 100644 --- a/src/expr/immediate.rs +++ b/src/expr/immediate.rs @@ -1,7 +1,10 @@ use rustables_macros::nfnetlink_struct; -use super::{Expression, ExpressionData, Register}; -use crate::sys::{NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG}; +use super::{Expression, Register, Verdict, VerdictKind, VerdictType}; +use crate::{ + parser_impls::NfNetlinkData, + sys::{NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG}, +}; #[derive(Clone, PartialEq, Eq, Default, Debug)] #[nfnetlink_struct] @@ -9,14 +12,34 @@ pub struct Immediate { #[field(NFTA_IMMEDIATE_DREG)] dreg: Register, #[field(NFTA_IMMEDIATE_DATA)] - data: ExpressionData, + data: NfNetlinkData, } impl Immediate { pub fn new_data(data: Vec<u8>, register: Register) -> Self { Immediate::default() .with_dreg(register) - .with_data(ExpressionData::default().with_value(data)) + .with_data(NfNetlinkData::default().with_value(data)) + } + + pub fn new_verdict(kind: VerdictKind) -> Self { + let code = match kind { + VerdictKind::Drop => VerdictType::Drop, + VerdictKind::Accept => VerdictType::Accept, + VerdictKind::Queue => VerdictType::Queue, + VerdictKind::Continue => VerdictType::Continue, + VerdictKind::Break => VerdictType::Break, + VerdictKind::Jump { .. } => VerdictType::Jump, + VerdictKind::Goto { .. } => VerdictType::Goto, + VerdictKind::Return => VerdictType::Return, + }; + let mut data = Verdict::default().with_code(code); + if let VerdictKind::Jump { chain } | VerdictKind::Goto { chain } = kind { + data.set_chain(chain); + } + Immediate::default() + .with_dreg(Register::Verdict) + .with_data(NfNetlinkData::default().with_verdict(data)) } } diff --git a/src/expr/lookup.rs b/src/expr/lookup.rs index a0cc021..2ef830e 100644 --- a/src/expr/lookup.rs +++ b/src/expr/lookup.rs @@ -1,78 +1,40 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::set::Set; -use crate::sys::{self, libc}; -use std::ffi::{CStr, CString}; -use std::os::raw::c_char; +use rustables_macros::nfnetlink_struct; -#[derive(Debug, PartialEq)] +use super::{Expression, Register}; +use crate::error::BuilderError; +use crate::sys::{NFTA_LOOKUP_DREG, NFTA_LOOKUP_SET, NFTA_LOOKUP_SET_ID, NFTA_LOOKUP_SREG}; +use crate::Set; + +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] pub struct Lookup { - set_name: CString, + #[field(NFTA_LOOKUP_SET)] + set: String, + #[field(NFTA_LOOKUP_SREG)] + sreg: Register, + #[field(NFTA_LOOKUP_DREG)] + dreg: Register, + #[field(NFTA_LOOKUP_SET_ID)] set_id: u32, } impl Lookup { - /// Creates a new lookup entry. May return None if the set has no name. - pub fn new<K>(set: &Set<K>) -> Option<Self> { - set.get_name().map(|set_name| Lookup { - set_name: set_name.to_owned(), - set_id: set.get_id(), - }) - } -} - -impl Expression for Lookup { - fn get_raw_name() -> *const libc::c_char { - b"lookup\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - unsafe { - let set_name = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_LOOKUP_SET as u16); - let set_id = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_LOOKUP_SET_ID as u16); - - if set_name.is_null() { - return Err(DeserializationError::NullPointer); - } - - let set_name = CStr::from_ptr(set_name).to_owned(); - - Ok(Lookup { set_id, set_name }) + /// Creates a new lookup entry. May return BuilderError::MissingSetName if the set has no name. + pub fn new(set: &Set) -> Result<Self, BuilderError> { + let mut res = Lookup::default() + .with_set(set.get_name().ok_or(BuilderError::MissingSetName)?) + .with_sreg(Register::Reg1); + + if let Some(id) = set.get_id() { + res.set_set_id(*id); } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_LOOKUP_SREG as u16, - libc::NFT_REG_1 as u32, - ); - sys::nftnl_expr_set_str( - expr, - sys::NFTNL_EXPR_LOOKUP_SET as u16, - self.set_name.as_ptr() as *const _ as *const c_char, - ); - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOOKUP_SET_ID as u16, self.set_id); - // This code is left here since it's quite likely we need it again when we get further - // if self.reverse { - // sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOOKUP_FLAGS as u16, - // libc::NFT_LOOKUP_F_INV as u32); - // } - - expr - } + Ok(res) } } -#[macro_export] -macro_rules! nft_expr_lookup { - ($set:expr) => { - $crate::expr::Lookup::new($set) - }; +impl Expression for Lookup { + fn get_name() -> &'static str { + "lookup" + } } diff --git a/src/expr/meta.rs b/src/expr/meta.rs index 79016bd..d0fecee 100644 --- a/src/expr/meta.rs +++ b/src/expr/meta.rs @@ -49,6 +49,12 @@ pub struct Meta { sreg: Register, } +impl Meta { + pub fn new(ty: MetaType) -> Self { + Meta::default().with_dreg(Register::Reg1).with_key(ty) + } +} + impl Expression for Meta { fn get_name() -> &'static str { "meta" diff --git a/src/expr/mod.rs b/src/expr/mod.rs index cfc01c8..979ebb2 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -4,21 +4,15 @@ //! [`Rule`]: struct.Rule.html use std::fmt::Debug; -use std::mem::transmute; - -use crate::nlmsg::NfNetlinkAttribute; -use crate::nlmsg::NfNetlinkDeserializable; -use crate::parser::pad_netlink_object; -use crate::parser::pad_netlink_object_with_variable_size; -use crate::parser::write_attribute; -use crate::parser::DecodeError; -use crate::sys::{self, nlattr}; -use crate::sys::{ - NFTA_DATA_VALUE, NFTA_DATA_VERDICT, NFTA_EXPR_DATA, NFTA_EXPR_NAME, NLA_TYPE_MASK, -}; + use rustables_macros::nfnetlink_struct; use thiserror::Error; +use crate::error::DecodeError; +use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable}; +use crate::parser_impls::NfNetlinkList; +use crate::sys::{self, NFTA_EXPR_DATA, NFTA_EXPR_NAME}; + mod bitwise; pub use self::bitwise::*; @@ -36,11 +30,9 @@ pub use self::immediate::*; mod log; pub use self::log::*; -/* mod lookup; pub use self::lookup::*; -*/ mod masquerade; pub use self::masquerade::*; @@ -105,19 +97,18 @@ pub struct RawExpression { data: ExpressionVariant, } -impl RawExpression { - pub fn new<T>(expr: T) -> Self - where - T: Expression, - ExpressionVariant: From<T>, - { +impl<T> From<T> for RawExpression +where + T: Expression, + ExpressionVariant: From<T>, +{ + fn from(val: T) -> Self { RawExpression::default() .with_name(T::get_name()) - .with_data(ExpressionVariant::from(expr)) + .with_data(ExpressionVariant::from(val)) } } -#[macro_export] macro_rules! create_expr_variant { ($enum:ident $(, [$name:ident, $type:ty])+) => { #[derive(Debug, Clone, PartialEq, Eq)] @@ -162,14 +153,14 @@ macro_rules! create_expr_variant { &mut self, attr_type: u16, buf: &[u8], - ) -> Result<(), $crate::parser::DecodeError> { + ) -> Result<(), $crate::error::DecodeError> { debug!("Decoding attribute {} in an expression", attr_type); match attr_type { x if x == sys::NFTA_EXPR_NAME => { debug!("Calling {}::deserialize()", std::any::type_name::<String>()); let (val, remaining) = String::deserialize(buf)?; if remaining.len() != 0 { - return Err($crate::parser::DecodeError::InvalidDataSize); + return Err($crate::error::DecodeError::InvalidDataSize); } self.name = Some(val); Ok(()) @@ -178,14 +169,14 @@ macro_rules! create_expr_variant { // we can assume we have already the name parsed, as that's how we identify the // type of expression let name = self.name.as_ref() - .ok_or($crate::parser::DecodeError::MissingExpressionName)?; + .ok_or($crate::error::DecodeError::MissingExpressionName)?; match name { $( x if x == <$type>::get_name() => { debug!("Calling {}::deserialize()", std::any::type_name::<$type>()); let (res, remaining) = <$type>::deserialize(buf)?; if remaining.len() != 0 { - return Err($crate::parser::DecodeError::InvalidDataSize); + return Err($crate::error::DecodeError::InvalidDataSize); } self.data = Some(ExpressionVariant::from(res)); Ok(()) @@ -207,126 +198,22 @@ macro_rules! create_expr_variant { create_expr_variant!( ExpressionVariant, - [Log, Log], - [Immediate, Immediate], [Bitwise, Bitwise], + [Cmp, Cmp], + [Conntrack, Conntrack], + [Counter, Counter], [ExpressionRaw, ExpressionRaw], + [Immediate, Immediate], + [Log, Log], + [Lookup, Lookup], + [Masquerade, Masquerade], [Meta, Meta], - [Reject, Reject], - [Counter, Counter], [Nat, Nat], [Payload, Payload], - [Cmp, Cmp], - [Conntrack, Conntrack], - [Masquerade, Masquerade] + [Reject, Reject] ); -#[derive(Debug, Clone, PartialEq, Eq, Default)] -pub struct ExpressionList { - exprs: Vec<RawExpression>, -} - -impl ExpressionList { - /// Useful to add raw expressions because RawExpression cannot infer alone its type - pub fn add_raw_expression(&mut self, e: RawExpression) { - self.exprs.push(e); - } - - pub fn add_expression<T>(&mut self, e: T) - where - T: Expression, - ExpressionVariant: From<T>, - { - self.exprs.push(RawExpression::new(e)); - } - - pub fn with_expression<T>(mut self, e: T) -> Self - where - T: Expression, - ExpressionVariant: From<T>, - { - self.add_expression(e); - self - } - - pub fn iter<'a>(&'a self) -> impl Iterator<Item = &'a ExpressionVariant> { - self.exprs.iter().map(|e| e.get_data().unwrap()) - } -} - -impl NfNetlinkAttribute for ExpressionList { - fn is_nested(&self) -> bool { - true - } - - fn get_size(&self) -> usize { - // one nlattr LIST_ELEM per object - self.exprs.iter().fold(0, |acc, item| { - acc + item.get_size() + pad_netlink_object::<nlattr>() - }) - } - - unsafe fn write_payload(&self, mut addr: *mut u8) { - for item in &self.exprs { - write_attribute(sys::NFTA_LIST_ELEM, item, addr); - addr = addr.offset((pad_netlink_object::<nlattr>() + item.get_size()) as isize); - } - } -} - -impl NfNetlinkDeserializable for ExpressionList { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - let mut exprs = Vec::new(); - - let mut pos = 0; - while buf.len() - pos > pad_netlink_object::<nlattr>() { - let nlattr = unsafe { *transmute::<*const u8, *const nlattr>(buf[pos..].as_ptr()) }; - // ignore the byteorder and nested attributes - let nla_type = nlattr.nla_type & NLA_TYPE_MASK as u16; - - if nla_type != sys::NFTA_LIST_ELEM { - return Err(DecodeError::UnsupportedAttributeType(nla_type)); - } - - let (expr, remaining) = RawExpression::deserialize( - &buf[pos + pad_netlink_object::<nlattr>()..pos + nlattr.nla_len as usize], - )?; - if remaining.len() != 0 { - return Err(DecodeError::InvalidDataSize); - } - exprs.push(expr); - - pos += pad_netlink_object_with_variable_size(nlattr.nla_len as usize); - } - - if pos != buf.len() { - Err(DecodeError::InvalidDataSize) - } else { - Ok((Self { exprs }, &[])) - } - } -} - -impl<T> From<Vec<T>> for ExpressionList -where - ExpressionVariant: From<T>, - T: Expression, -{ - fn from(v: Vec<T>) -> Self { - ExpressionList { - exprs: v.into_iter().map(RawExpression::new).collect(), - } - } -} - -#[derive(Clone, PartialEq, Eq, Default, Debug)] -#[nfnetlink_struct(nested = true)] -pub struct ExpressionData { - #[field(NFTA_DATA_VALUE)] - value: Vec<u8>, - #[field(NFTA_DATA_VERDICT)] - verdict: VerdictAttribute, -} +pub type ExpressionList = NfNetlinkList<RawExpression>; // default type for expressions that we do not handle yet #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/src/expr/payload.rs b/src/expr/payload.rs index 490a4ec..d0b2cea 100644 --- a/src/expr/payload.rs +++ b/src/expr/payload.rs @@ -2,7 +2,7 @@ use rustables_macros::nfnetlink_struct; use super::{Expression, Register}; use crate::{ - parser::DecodeError, + error::DecodeError, sys::{self, NFT_PAYLOAD_LL_HEADER, NFT_PAYLOAD_NETWORK_HEADER, NFT_PAYLOAD_TRANSPORT_HEADER}, }; diff --git a/src/expr/verdict.rs b/src/expr/verdict.rs index c4facfb..7edf7cd 100644 --- a/src/expr/verdict.rs +++ b/src/expr/verdict.rs @@ -3,7 +3,6 @@ use std::fmt::Debug; use libc::{NF_ACCEPT, NF_DROP, NF_QUEUE}; use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; -use super::{ExpressionData, Immediate, Register}; use crate::sys::{ NFTA_VERDICT_CHAIN, NFTA_VERDICT_CHAIN_ID, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE, NFT_GOTO, NFT_JUMP, NFT_RETURN, @@ -24,7 +23,7 @@ pub enum VerdictType { #[derive(Clone, PartialEq, Eq, Default, Debug)] #[nfnetlink_struct(nested = true)] -pub struct VerdictAttribute { +pub struct Verdict { #[field(NFTA_VERDICT_CODE)] code: VerdictType, #[field(NFTA_VERDICT_CHAIN)] @@ -50,25 +49,3 @@ pub enum VerdictKind { }, Return, } - -impl Immediate { - pub fn new_verdict(kind: VerdictKind) -> Self { - let code = match kind { - VerdictKind::Drop => VerdictType::Drop, - VerdictKind::Accept => VerdictType::Accept, - VerdictKind::Queue => VerdictType::Queue, - VerdictKind::Continue => VerdictType::Continue, - VerdictKind::Break => VerdictType::Break, - VerdictKind::Jump { .. } => VerdictType::Jump, - VerdictKind::Goto { .. } => VerdictType::Goto, - VerdictKind::Return => VerdictType::Return, - }; - let mut data = VerdictAttribute::default().with_code(code); - if let VerdictKind::Jump { chain } | VerdictKind::Goto { chain } = kind { - data.set_chain(chain); - } - Immediate::default() - .with_dreg(Register::Verdict) - .with_data(ExpressionData::default().with_verdict(data)) - } -} @@ -1,4 +1,4 @@ -// Copyryght (c) 2021 GPL lafleur@boum.org and Simon Thoby +// Copyryght (c) 2021-2022 GPL lafleur@boum.org and Simon Thoby // // This file is free software: you may copy, redistribute and/or modify it // under the terms of the GNU General Public License as published by the @@ -24,64 +24,37 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! Safe abstraction for [`libnftnl`]. Provides userspace access to the in-kernel nf_tables -//! subsystem. Can be used to create and remove tables, chains, sets and rules from the nftables +//! Safe abstraction for userspace access to the in-kernel nf_tables subsystem. +//! Can be used to create and remove tables, chains, sets and rules from the nftables //! firewall, the successor to iptables. //! //! This library currently has quite rough edges and does not make adding and removing netfilter //! entries super easy and elegant. That is partly because the library needs more work, but also //! partly because nftables is super low level and extremely customizable, making it hard, and //! probably wrong, to try and create a too simple/limited wrapper. See examples for inspiration. -//! One can also look at how the original project this crate was developed to support uses it: -//! [Mullvad VPN app](https://github.com/mullvad/mullvadvpn-app) //! -//! Understanding how to use [`libnftnl`] and implementing this crate has mostly been done by -//! reading the source code for the [`nftables`] program and attaching debuggers to the `nft` -//! binary. Since the implementation is mostly based on trial and error, there might of course be -//! a number of places where the underlying library is used in an invalid or not intended way. -//! Large portions of [`libnftnl`] are also not covered yet. Contributions are welcome! +//! Understanding how to use the netlink subsystem and implementing this crate has mostly been done by +//! reading the source code for the [`nftables`] userspace program and its corresponding kernel code, +//! as well as attaching debuggers to the `nft` binary. +//! Since the implementation is mostly based on trial and error, there might of course be +//! a number of places where the forged netlink messages are used in an invalid or not intended way. +//! Contributions are welcome! //! -//! # Supported versions of `libnftnl` -//! -//! This crate will automatically link to the currently installed version of libnftnl upon build. -//! It requires libnftnl version 1.0.6 or higher. See how the low level FFI bindings to the C -//! library are generated in [`build.rs`]. -//! -//! # Access to raw handles -//! -//! Retrieving raw handles is considered unsafe and should only ever be enabled if you absolutely -//! need it. It is disabled by default and hidden behind the feature gate `unsafe-raw-handles`. -//! The reason for that special treatment is we cannot guarantee the lack of aliasing. For -//! example, a program using a const handle to a object in a thread and writing through a mutable -//! handle in another could reach all kind of undefined (and dangerous!) behaviors. By enabling -//! that feature flag, you acknowledge that guaranteeing the respect of safety invariants is now -//! your responsibility! Despite these shortcomings, that feature is still available because it -//! may allow you to perform manipulations that this library doesn't currently expose. If that is -//! your case, we would be very happy to hear from you and maybe help you get the necessary -//! functionality upstream. -//! -//! Our current lack of confidence in our availability to provide a safe abstraction over the use -//! of raw handles in the face of concurrency is the reason we decided to settly on `Rc` pointers -//! instead of `Arc` (besides, this should gives us some nice performance boost, not that it -//! matters much of course) and why we do not declare the types exposed by the library as `Send` -//! nor `Sync`. -//! -//! [`libnftnl`]: https://netfilter.org/projects/libnftnl/ //! [`nftables`]: https://netfilter.org/projects/nftables/ -//! [`build.rs`]: https://gitlab.com/rustwall/rustables/-/blob/master/build.rs - -use parser::DecodeError; #[macro_use] extern crate log; -pub mod sys; use libc; + +use rustables_macros::nfnetlink_enum; use std::convert::TryFrom; mod batch; pub use batch::{default_batch_page_size, Batch}; +mod data_type; + mod table; pub use table::list_tables; pub use table::Table; @@ -90,13 +63,16 @@ mod chain; pub use chain::list_chains_for_table; pub use chain::{Chain, ChainPolicy, ChainPriority, ChainType, Hook, HookClass}; +pub mod error; + //mod chain_methods; //pub use chain_methods::ChainMethods; pub mod query; -pub mod nlmsg; -pub mod parser; +pub(crate) mod nlmsg; +pub(crate) mod parser; +pub(crate) mod parser_impls; mod rule; pub use rule::list_rules_for_chain; @@ -107,8 +83,13 @@ pub mod expr; //mod rule_methods; //pub use rule_methods::{iface_index, Error as MatchError, Protocol, RuleMethods}; -//pub mod set; -//pub use set::Set; +pub mod set; +pub use set::Set; + +pub mod sys; + +#[cfg(test)] +mod tests; /// The type of the message as it's sent to netfilter. A message consists of an object, such as a /// [`Table`], [`Chain`] or [`Rule`] for example, and a [`MsgType`] to describe what to do with @@ -119,7 +100,7 @@ pub mod expr; /// [`Chain`]: struct.Chain.html /// [`Rule`]: struct.Rule.html /// [`MsgType`]: enum.MsgType.html -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum MsgType { /// Add the object to netfilter. Add, @@ -128,8 +109,8 @@ pub enum MsgType { } /// Denotes a protocol. Used to specify which protocol a table or set belongs to. -#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] -#[repr(i32)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[nfnetlink_enum(i32)] pub enum ProtocolFamily { Unspec = libc::NFPROTO_UNSPEC, /// Inet - Means both IPv4 and IPv6 @@ -144,23 +125,6 @@ pub enum ProtocolFamily { impl Default for ProtocolFamily { fn default() -> Self { - Self::Unspec - } -} - -impl TryFrom<i32> for ProtocolFamily { - type Error = DecodeError; - fn try_from(value: i32) -> Result<Self, Self::Error> { - match value { - libc::NFPROTO_UNSPEC => Ok(ProtocolFamily::Unspec), - libc::NFPROTO_INET => Ok(ProtocolFamily::Inet), - libc::NFPROTO_IPV4 => Ok(ProtocolFamily::Ipv4), - libc::NFPROTO_ARP => Ok(ProtocolFamily::Arp), - libc::NFPROTO_NETDEV => Ok(ProtocolFamily::NetDev), - libc::NFPROTO_BRIDGE => Ok(ProtocolFamily::Bridge), - libc::NFPROTO_IPV6 => Ok(ProtocolFamily::Ipv6), - libc::NFPROTO_DECNET => Ok(ProtocolFamily::DecNet), - _ => Err(DecodeError::InvalidProtocolFamily(value)), - } + ProtocolFamily::Unspec } } diff --git a/src/nlmsg.rs b/src/nlmsg.rs index 8563a37..b3710bf 100644 --- a/src/nlmsg.rs +++ b/src/nlmsg.rs @@ -1,13 +1,41 @@ use std::{fmt::Debug, mem::size_of}; use crate::{ - parser::{pad_netlink_object, pad_netlink_object_with_variable_size, DecodeError}, + error::DecodeError, sys::{ nfgenmsg, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END, - NFNL_SUBSYS_NFTABLES, + NFNL_SUBSYS_NFTABLES, NLMSG_ALIGNTO, NLM_F_ACK, NLM_F_CREATE, }, MsgType, ProtocolFamily, }; +/// +/// The largest nf_tables netlink message is the set element message, which contains the +/// NFTA_SET_ELEM_LIST_ELEMENTS attribute. This attribute is a nest that describes the set +/// elements. Given that the netlink attribute length (nla_len) is 16 bits, the largest message is +/// a bit larger than 64 KBytes. +pub fn nft_nlmsg_maxsize() -> u32 { + u32::from(::std::u16::MAX) + unsafe { libc::sysconf(libc::_SC_PAGESIZE) } as u32 +} + +#[inline] +pub const fn pad_netlink_object_with_variable_size(size: usize) -> usize { + // align on a 4 bytes boundary + (size + (NLMSG_ALIGNTO as usize - 1)) & !(NLMSG_ALIGNTO as usize - 1) +} + +#[inline] +pub const fn pad_netlink_object<T>() -> usize { + let size = size_of::<T>(); + pad_netlink_object_with_variable_size(size) +} + +pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 { + ((x & 0xff00) >> 8) as u8 +} + +pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 { + (x & 0x00ff) as u8 +} pub struct NfNetlinkWriter<'a> { buf: &'a mut Vec<u8>, @@ -92,76 +120,67 @@ pub trait NfNetlinkDeserializable: Sized { fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError>; } -pub trait NfNetlinkObject: Sized + AttributeDecoder + NfNetlinkDeserializable { - fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32); -} - -pub type NetlinkType = u16; - -pub trait NfNetlinkAttribute: Debug + Sized { - // is it a nested argument that must be marked with a NLA_F_NESTED flag? - fn is_nested(&self) -> bool { - false +pub trait NfNetlinkObject: + Sized + AttributeDecoder + NfNetlinkDeserializable + NfNetlinkAttribute +{ + const MSG_TYPE_ADD: u32; + const MSG_TYPE_DEL: u32; + + fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32) { + let raw_msg_type = match msg_type { + MsgType::Add => Self::MSG_TYPE_ADD, + MsgType::Del => Self::MSG_TYPE_DEL, + } as u16; + writer.write_header( + raw_msg_type, + self.get_family(), + (if let MsgType::Add = msg_type { + self.get_add_flags() + } else { + self.get_del_flags() + } | NLM_F_ACK) as u16, + seq, + None, + ); + let buf = writer.add_data_zeroed(self.get_size()); + unsafe { + self.write_payload(buf.as_mut_ptr()); + } + writer.finalize_writing_object(); } - fn get_size(&self) -> usize { - size_of::<Self>() - } + fn get_family(&self) -> ProtocolFamily; - // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr, self.get_size()); - unsafe fn write_payload(&self, addr: *mut u8); -} - -/* -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct NfNetlinkAttributes { - pub attributes: BTreeMap<NetlinkType, AttributeType>, -} - -impl NfNetlinkAttributes { - pub fn new() -> Self { - NfNetlinkAttributes { - attributes: BTreeMap::new(), - } + fn set_family(&mut self, _family: ProtocolFamily) { + // the default impl do nothing, because some types are family-agnostic } - pub fn set_attr(&mut self, ty: NetlinkType, obj: AttributeType) { - self.attributes.insert(ty, obj); + fn with_family(mut self, family: ProtocolFamily) -> Self { + self.set_family(family); + self } - pub fn get_attr(&self, ty: NetlinkType) -> Option<&AttributeType> { - self.attributes.get(&ty) + fn get_add_flags(&self) -> u32 { + NLM_F_CREATE } - pub fn serialize<'a>(&self, writer: &mut NfNetlinkWriter<'a>) { - let buf = writer.add_data_zeroed(self.get_size()); - unsafe { - self.write_payload(buf.as_mut_ptr()); - } + fn get_del_flags(&self) -> u32 { + 0 } } -impl NfNetlinkAttribute for NfNetlinkAttributes { - fn get_size(&self) -> usize { - let mut size = 0; - - for (_type, attr) in self.attributes.iter() { - // Attribute header + attribute value - size += pad_netlink_object::<nlattr>() - + pad_netlink_object_with_variable_size(attr.get_size()); - } +pub type NetlinkType = u16; - size +pub trait NfNetlinkAttribute: Debug + Sized { + // is it a nested argument that must be marked with a NLA_F_NESTED flag? + fn is_nested(&self) -> bool { + false } - unsafe fn write_payload(&self, mut addr: *mut u8) { - for (ty, attr) in self.attributes.iter() { - debug!("writing attribute {} - {:?}", ty, attr); - write_attribute(*ty, attr, addr); - let size = pad_netlink_object::<nlattr>() - + pad_netlink_object_with_variable_size(attr.get_size()); - addr = addr.offset(size as isize); - } + fn get_size(&self) -> usize { + size_of::<Self>() } + + // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr, self.get_size()); + unsafe fn write_payload(&self, addr: *mut u8); } -*/ diff --git a/src/parser.rs b/src/parser.rs index c402dae..6ea34c1 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1,167 +1,21 @@ use std::{ - convert::TryFrom, fmt::{Debug, DebugStruct}, mem::{size_of, transmute}, - string::FromUtf8Error, }; -use thiserror::Error; - use crate::{ - nlmsg::{AttributeDecoder, NetlinkType, NfNetlinkAttribute, NfNetlinkDeserializable}, + error::DecodeError, + nlmsg::{ + get_operation_from_nlmsghdr_type, get_subsystem_from_nlmsghdr_type, pad_netlink_object, + pad_netlink_object_with_variable_size, AttributeDecoder, NetlinkType, NfNetlinkAttribute, + }, sys::{ nfgenmsg, nlattr, nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, - NFNL_MSG_BATCH_END, NFNL_SUBSYS_NFTABLES, NLA_F_NESTED, NLA_TYPE_MASK, NLMSG_ALIGNTO, - NLMSG_DONE, NLMSG_ERROR, NLMSG_MIN_TYPE, NLMSG_NOOP, NLM_F_DUMP_INTR, + NFNL_MSG_BATCH_END, NFNL_SUBSYS_NFTABLES, NLA_F_NESTED, NLA_TYPE_MASK, NLMSG_DONE, + NLMSG_ERROR, NLMSG_MIN_TYPE, NLMSG_NOOP, NLM_F_DUMP_INTR, }, - ProtocolFamily, }; -#[derive(Error, Debug)] -pub enum DecodeError { - #[error("The buffer is too small to hold a valid message")] - BufTooSmall, - - #[error("The message is too small")] - NlMsgTooSmall, - - #[error("The message holds unexpected data")] - InvalidDataSize, - - #[error("Missing information in the chain to create a rule")] - MissingChainInformationError, - - #[error("The length of the arguments are not compatible with each other")] - IncompatibleLength, - - #[error("Invalid subsystem, expected NFTABLES")] - InvalidSubsystem(u8), - - #[error("Invalid version, expected NFNETLINK_V0")] - InvalidVersion(u8), - - #[error("Invalid port ID")] - InvalidPortId(u32), - - #[error("Invalid sequence number")] - InvalidSeq(u32), - - #[error("The generation number was bumped in the kernel while the operation was running, interrupting it")] - ConcurrentGenerationUpdate, - - #[error("Unsupported message type")] - UnsupportedType(u16), - - #[error("Invalid attribute type")] - InvalidAttributeType, - - #[error("Invalid type for a chain")] - UnknownChainType, - - #[error("Invalid policy for a chain")] - UnknownChainPolicy, - - #[error("Unknown type for a Meta expression")] - UnknownMetaType(u32), - - #[error("Unsupported value for an icmp reject type")] - UnknownRejectType(u32), - - #[error("Unsupported value for an icmp code in a reject expression")] - UnknownIcmpCode(u8), - - #[error("Invalid value for a register")] - UnknownRegister(u32), - - #[error("Invalid type for a verdict expression")] - UnknownVerdictType(i32), - - #[error("Invalid type for a nat expression")] - UnknownNatType(i32), - - #[error("Invalid type for a payload expression")] - UnknownPayloadType(u32), - - #[error("Invalid type for a compare expression")] - UnknownCmpOp(u32), - - #[error("Invalid type for a conntrack key")] - UnknownConntrackKey(u32), - - #[error("Unsupported value for a link layer header field")] - UnknownLinkLayerHeaderField(u32, u32), - - #[error("Unsupported value for an IPv4 header field")] - UnknownIPv4HeaderField(u32, u32), - - #[error("Unsupported value for an IPv6 header field")] - UnknownIPv6HeaderField(u32, u32), - - #[error("Unsupported value for a TCP header field")] - UnknownTCPHeaderField(u32, u32), - - #[error("Unsupported value for an UDP header field")] - UnknownUDPHeaderField(u32, u32), - - #[error("Unsupported value for an ICMPv6 header field")] - UnknownICMPv6HeaderField(u32, u32), - - #[error("Missing the 'base' attribute to deserialize the payload object")] - PayloadMissingBase, - - #[error("Missing the 'offset' attribute to deserialize the payload object")] - PayloadMissingOffset, - - #[error("Missing the 'len' attribute to deserialize the payload object")] - PayloadMissingLen, - - #[error("The object does not contain a name for the expression being parsed")] - MissingExpressionName, - - #[error("Unsupported attribute type")] - UnsupportedAttributeType(u16), - - #[error("Unexpected message type")] - UnexpectedType(u16), - - #[error("The decoded String is not UTF8 compliant")] - StringDecodeFailure(#[from] FromUtf8Error), - - #[error("Invalid value for a protocol family")] - InvalidProtocolFamily(i32), - - #[error("A custom error occured")] - Custom(Box<dyn std::error::Error + 'static>), -} - -/// The largest nf_tables netlink message is the set element message, which contains the -/// NFTA_SET_ELEM_LIST_ELEMENTS attribute. This attribute is a nest that describes the set -/// elements. Given that the netlink attribute length (nla_len) is 16 bits, the largest message is -/// a bit larger than 64 KBytes. -pub fn nft_nlmsg_maxsize() -> u32 { - u32::from(::std::u16::MAX) + unsafe { libc::sysconf(libc::_SC_PAGESIZE) } as u32 -} - -#[inline] -pub const fn pad_netlink_object_with_variable_size(size: usize) -> usize { - // align on a 4 bytes boundary - (size + (NLMSG_ALIGNTO as usize - 1)) & !(NLMSG_ALIGNTO as usize - 1) -} - -#[inline] -pub const fn pad_netlink_object<T>() -> usize { - let size = size_of::<T>(); - pad_netlink_object_with_variable_size(size) -} - -pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 { - ((x & 0xff00) >> 8) as u8 -} - -pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 { - (x & 0x00ff) as u8 -} - pub fn get_nlmsghdr(buf: &[u8]) -> Result<nlmsghdr, DecodeError> { let size_of_hdr = size_of::<nlmsghdr>(); @@ -272,126 +126,6 @@ pub unsafe fn write_attribute<'a>( obj.write_payload(buf); } -impl NfNetlinkAttribute for u8 { - unsafe fn write_payload(&self, addr: *mut u8) { - *addr = *self; - } -} - -impl NfNetlinkDeserializable for u8 { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - Ok((buf[0], &buf[1..])) - } -} - -impl NfNetlinkAttribute for u16 { - unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = self.to_be(); - } -} - -impl NfNetlinkDeserializable for u16 { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - Ok((u16::from_be_bytes([buf[0], buf[1]]), &buf[2..])) - } -} - -impl NfNetlinkAttribute for i32 { - unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = self.to_be(); - } -} - -impl NfNetlinkDeserializable for i32 { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - Ok(( - i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), - &buf[4..], - )) - } -} - -impl NfNetlinkAttribute for u32 { - unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = self.to_be(); - } -} - -impl NfNetlinkDeserializable for u32 { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - Ok(( - u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), - &buf[4..], - )) - } -} - -impl NfNetlinkAttribute for u64 { - unsafe fn write_payload(&self, addr: *mut u8) { - *(addr as *mut Self) = self.to_be(); - } -} - -impl NfNetlinkDeserializable for u64 { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - Ok(( - u64::from_be_bytes([ - buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], - ]), - &buf[8..], - )) - } -} - -impl NfNetlinkAttribute for String { - fn get_size(&self) -> usize { - self.len() - } - - unsafe fn write_payload(&self, addr: *mut u8) { - std::ptr::copy_nonoverlapping(self.as_bytes().as_ptr(), addr, self.len()); - } -} - -impl NfNetlinkDeserializable for String { - fn deserialize(mut buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - // ignore the NULL byte terminator, if any - if buf.len() > 0 && buf[buf.len() - 1] == 0 { - buf = &buf[..buf.len() - 1]; - } - Ok((String::from_utf8(buf.to_vec())?, &[])) - } -} - -impl NfNetlinkAttribute for Vec<u8> { - fn get_size(&self) -> usize { - self.len() - } - - unsafe fn write_payload(&self, addr: *mut u8) { - std::ptr::copy_nonoverlapping(self.as_ptr(), addr, self.len()); - } -} - -impl NfNetlinkDeserializable for Vec<u8> { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - Ok((buf.to_vec(), &[])) - } -} - -impl NfNetlinkAttribute for ProtocolFamily { - unsafe fn write_payload(&self, addr: *mut u8) { - (*self as i32).write_payload(addr); - } -} - -impl NfNetlinkDeserializable for ProtocolFamily { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - let (v, remaining_data) = i32::deserialize(buf)?; - Ok((Self::try_from(v)?, remaining_data)) - } -} - pub(crate) fn read_attributes<T: AttributeDecoder + Default>(buf: &[u8]) -> Result<T, DecodeError> { debug!( "Calling <{} as NfNetlinkDeserialize>::deserialize()", diff --git a/src/parser_impls.rs b/src/parser_impls.rs new file mode 100644 index 0000000..b2681bb --- /dev/null +++ b/src/parser_impls.rs @@ -0,0 +1,243 @@ +use std::{fmt::Debug, mem::transmute}; + +use rustables_macros::nfnetlink_struct; + +use crate::{ + error::DecodeError, + expr::Verdict, + nlmsg::{ + pad_netlink_object, pad_netlink_object_with_variable_size, NfNetlinkAttribute, + NfNetlinkDeserializable, NfNetlinkObject, + }, + parser::{write_attribute, Parsable}, + sys::{nlattr, NFTA_DATA_VALUE, NFTA_DATA_VERDICT, NFTA_LIST_ELEM, NLA_TYPE_MASK}, + ProtocolFamily, +}; + +impl NfNetlinkAttribute for u8 { + unsafe fn write_payload(&self, addr: *mut u8) { + *addr = *self; + } +} + +impl NfNetlinkDeserializable for u8 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok((buf[0], &buf[1..])) + } +} + +impl NfNetlinkAttribute for u16 { + unsafe fn write_payload(&self, addr: *mut u8) { + *(addr as *mut Self) = self.to_be(); + } +} + +impl NfNetlinkDeserializable for u16 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok((u16::from_be_bytes([buf[0], buf[1]]), &buf[2..])) + } +} + +impl NfNetlinkAttribute for i32 { + unsafe fn write_payload(&self, addr: *mut u8) { + *(addr as *mut Self) = self.to_be(); + } +} + +impl NfNetlinkDeserializable for i32 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok(( + i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), + &buf[4..], + )) + } +} + +impl NfNetlinkAttribute for u32 { + unsafe fn write_payload(&self, addr: *mut u8) { + *(addr as *mut Self) = self.to_be(); + } +} + +impl NfNetlinkDeserializable for u32 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok(( + u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), + &buf[4..], + )) + } +} + +impl NfNetlinkAttribute for u64 { + unsafe fn write_payload(&self, addr: *mut u8) { + *(addr as *mut Self) = self.to_be(); + } +} + +impl NfNetlinkDeserializable for u64 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok(( + u64::from_be_bytes([ + buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], + ]), + &buf[8..], + )) + } +} + +impl NfNetlinkAttribute for String { + fn get_size(&self) -> usize { + self.len() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + std::ptr::copy_nonoverlapping(self.as_bytes().as_ptr(), addr, self.len()); + } +} + +impl NfNetlinkDeserializable for String { + fn deserialize(mut buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + // ignore the NULL byte terminator, if any + if buf.len() > 0 && buf[buf.len() - 1] == 0 { + buf = &buf[..buf.len() - 1]; + } + Ok((String::from_utf8(buf.to_vec())?, &[])) + } +} + +impl NfNetlinkAttribute for Vec<u8> { + fn get_size(&self) -> usize { + self.len() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + std::ptr::copy_nonoverlapping(self.as_ptr(), addr, self.len()); + } +} + +impl NfNetlinkDeserializable for Vec<u8> { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok((buf.to_vec(), &[])) + } +} +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(nested = true)] +pub struct NfNetlinkData { + #[field(NFTA_DATA_VALUE)] + value: Vec<u8>, + #[field(NFTA_DATA_VERDICT)] + verdict: Verdict, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct NfNetlinkList<T> +where + T: NfNetlinkDeserializable + NfNetlinkAttribute + Debug + Clone + Eq + Default, +{ + objs: Vec<T>, +} + +impl<T> NfNetlinkList<T> +where + T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default, +{ + pub fn add_value(&mut self, e: impl Into<T>) { + self.objs.push(e.into()); + } + + pub fn with_value(mut self, e: impl Into<T>) -> Self { + self.add_value(e); + self + } + + pub fn iter<'a>(&'a self) -> impl Iterator<Item = &'a T> { + self.objs.iter() + } +} + +impl<T> NfNetlinkAttribute for NfNetlinkList<T> +where + T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default, +{ + fn is_nested(&self) -> bool { + true + } + + fn get_size(&self) -> usize { + // one nlattr LIST_ELEM per object + self.objs.iter().fold(0, |acc, item| { + acc + item.get_size() + pad_netlink_object::<nlattr>() + }) + } + + unsafe fn write_payload(&self, mut addr: *mut u8) { + for item in &self.objs { + write_attribute(NFTA_LIST_ELEM, item, addr); + addr = addr.offset((pad_netlink_object::<nlattr>() + item.get_size()) as isize); + } + } +} + +impl<T> NfNetlinkDeserializable for NfNetlinkList<T> +where + T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default, +{ + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let mut objs = Vec::new(); + + let mut pos = 0; + while buf.len() - pos > pad_netlink_object::<nlattr>() { + let nlattr = unsafe { *transmute::<*const u8, *const nlattr>(buf[pos..].as_ptr()) }; + // ignore the byteorder and nested attributes + let nla_type = nlattr.nla_type & NLA_TYPE_MASK as u16; + + if nla_type != NFTA_LIST_ELEM { + return Err(DecodeError::UnsupportedAttributeType(nla_type)); + } + + let (obj, remaining) = T::deserialize( + &buf[pos + pad_netlink_object::<nlattr>()..pos + nlattr.nla_len as usize], + )?; + if remaining.len() != 0 { + return Err(DecodeError::InvalidDataSize); + } + objs.push(obj); + + pos += pad_netlink_object_with_variable_size(nlattr.nla_len as usize); + } + + if pos != buf.len() { + Err(DecodeError::InvalidDataSize) + } else { + Ok((Self { objs }, &[])) + } + } +} + +impl<O, T> From<Vec<O>> for NfNetlinkList<T> +where + T: From<O>, + T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default, +{ + fn from(v: Vec<O>) -> Self { + NfNetlinkList { + objs: v.into_iter().map(T::from).collect(), + } + } +} + +impl<T> NfNetlinkDeserializable for T +where + T: NfNetlinkObject + Parsable, +{ + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let (mut obj, nfgenmsg, remaining_data) = Self::parse_object( + buf, + <T as NfNetlinkObject>::MSG_TYPE_ADD, + <T as NfNetlinkObject>::MSG_TYPE_DEL, + )?; + obj.set_family(ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?); + + Ok((obj, remaining_data)) + } +} diff --git a/src/query.rs b/src/query.rs index 294cbfe..7cf5050 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,71 +1,31 @@ use std::os::unix::prelude::RawFd; +use nix::sys::socket::{self, AddressFamily, MsgFlags, SockFlag, SockProtocol, SockType}; + use crate::{ - nlmsg::{NfNetlinkAttribute, NfNetlinkObject, NfNetlinkWriter}, - parser::{nft_nlmsg_maxsize, pad_netlink_object_with_variable_size}, - sys::{nlmsgerr, NLM_F_DUMP, NLM_F_MULTI}, + error::QueryError, + nlmsg::{ + nft_nlmsg_maxsize, pad_netlink_object_with_variable_size, NfNetlinkAttribute, + NfNetlinkObject, NfNetlinkWriter, + }, + parser::{parse_nlmsg, NlMsg}, + sys::{NLM_F_DUMP, NLM_F_MULTI}, ProtocolFamily, }; -use nix::{ - errno::Errno, - sys::socket::{self, AddressFamily, MsgFlags, SockFlag, SockProtocol, SockType}, -}; - -use crate::parser::{parse_nlmsg, DecodeError, NlMsg}; - -#[derive(thiserror::Error, Debug)] -pub enum Error { - #[error("Unable to open netlink socket to netfilter")] - NetlinkOpenError(#[source] nix::Error), - - #[error("Unable to send netlink command to netfilter")] - NetlinkSendError(#[source] nix::Error), - - #[error("Error while reading from netlink socket")] - NetlinkRecvError(#[source] nix::Error), - - #[error("Error while processing an incoming netlink message")] - ProcessNetlinkError(#[from] DecodeError), - - #[error("Error received from the kernel")] - NetlinkError(nlmsgerr), - - #[error("Custom error when customizing the query")] - InitError(#[from] Box<dyn std::error::Error + Send + 'static>), - - #[error("Couldn't allocate a netlink object, out of memory ?")] - NetlinkAllocationFailed, - - #[error("This socket is not a netlink socket")] - NotNetlinkSocket, - - #[error("Couldn't retrieve information on a socket")] - RetrievingSocketInfoFailed, - - #[error("Only a part of the message was sent")] - TruncatedSend, - - #[error("Got a message without the NLM_F_MULTI flag, but a maximum sequence number was not specified")] - UndecidableMessageTermination, - - #[error("Couldn't close the socket")] - CloseFailed(#[source] Errno), -} - pub(crate) fn recv_and_process<'a, T>( sock: RawFd, max_seq: Option<u32>, - cb: Option<&dyn Fn(&[u8], &mut T) -> Result<(), Error>>, + cb: Option<&dyn Fn(&[u8], &mut T) -> Result<(), QueryError>>, working_data: &'a mut T, -) -> Result<(), Error> { +) -> Result<(), QueryError> { let mut msg_buffer = vec![0; 2 * nft_nlmsg_maxsize() as usize]; let mut buf_start = 0; let mut end_pos = 0; loop { let nb_recv = socket::recv(sock, &mut msg_buffer[end_pos..], MsgFlags::empty()) - .map_err(Error::NetlinkRecvError)?; + .map_err(QueryError::NetlinkRecvError)?; if nb_recv <= 0 { return Ok(()); } @@ -87,7 +47,7 @@ pub(crate) fn recv_and_process<'a, T>( } NlMsg::Error(e) => { if e.error != 0 { - return Err(Error::NetlinkError(e)); + return Err(QueryError::NetlinkError(e)); } } NlMsg::Noop => {} @@ -101,7 +61,7 @@ pub(crate) fn recv_and_process<'a, T>( // we cannot know when a sequence of messages will end if the messages do not end // with an NlMsg::Done marker while if a maximum sequence number wasn't specified if max_seq.is_none() && nlmsghdr.nlmsg_flags & NLM_F_MULTI as u16 == 0 { - return Err(Error::UndecidableMessageTermination); + return Err(QueryError::UndecidableMessageTermination); } // retrieve the next message @@ -136,15 +96,15 @@ pub(crate) fn recv_and_process<'a, T>( pub(crate) fn socket_close_wrapper<E>( sock: RawFd, cb: impl FnOnce(RawFd) -> Result<(), E>, -) -> Result<(), Error> +) -> Result<(), QueryError> where - Error: From<E>, + QueryError: From<E>, { let ret = cb(sock); // we don't need to shutdown the socket (in fact, Linux doesn't support that operation; // and return EOPNOTSUPP if we try) - nix::unistd::close(sock).map_err(Error::CloseFailed)?; + nix::unistd::close(sock).map_err(QueryError::CloseFailed)?; Ok(ret?) } @@ -156,7 +116,7 @@ pub fn get_list_of_objects<T: NfNetlinkAttribute>( msg_type: u16, seq: u32, filter: Option<&T>, -) -> Result<Vec<u8>, Error> { +) -> Result<Vec<u8>, QueryError> { let mut buffer = Vec::new(); let mut writer = NfNetlinkWriter::new(&mut buffer); writer.write_header( @@ -182,10 +142,10 @@ pub fn get_list_of_objects<T: NfNetlinkAttribute>( /// and of the output vector, to which it should append the parsed object it received. pub fn list_objects_with_data<'a, Object, Accumulator>( data_type: u16, - cb: &dyn Fn(Object, &mut Accumulator) -> Result<(), Error>, + cb: &dyn Fn(Object, &mut Accumulator) -> Result<(), QueryError>, filter: Option<&Object>, working_data: &'a mut Accumulator, -) -> Result<(), Error> +) -> Result<(), QueryError> where Object: NfNetlinkObject + NfNetlinkAttribute, { @@ -196,12 +156,12 @@ where SockFlag::empty(), SockProtocol::NetlinkNetFilter, ) - .map_err(Error::NetlinkOpenError)?; + .map_err(QueryError::NetlinkOpenError)?; let seq = 0; let chains_buf = get_list_of_objects(data_type, seq, filter)?; - socket::send(sock, &chains_buf, MsgFlags::empty()).map_err(Error::NetlinkSendError)?; + socket::send(sock, &chains_buf, MsgFlags::empty()).map_err(QueryError::NetlinkSendError)?; socket_close_wrapper(sock, move |sock| { // the kernel should return NLM_F_MULTI objects diff --git a/src/rule.rs b/src/rule.rs index 5d13ac4..7f732d3 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -1,22 +1,24 @@ +use std::fmt::Debug; + use rustables_macros::nfnetlink_struct; +use crate::chain::Chain; +use crate::error::{BuilderError, QueryError}; use crate::expr::ExpressionList; -use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkWriter}; -use crate::parser::{DecodeError, Parsable}; +use crate::nlmsg::NfNetlinkObject; use crate::query::list_objects_with_data; use crate::sys::{ NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, NFTA_RULE_HANDLE, NFTA_RULE_ID, NFTA_RULE_POSITION, - NFTA_RULE_TABLE, NFTA_RULE_USERDATA, NFT_MSG_DELRULE, NFT_MSG_NEWRULE, NLM_F_ACK, NLM_F_CREATE, + NFTA_RULE_TABLE, NFTA_RULE_USERDATA, NFT_MSG_DELRULE, NFT_MSG_NEWRULE, NLM_F_APPEND, + NLM_F_CREATE, }; use crate::ProtocolFamily; -use crate::{chain::Chain, MsgType}; -use std::convert::TryFrom; -use std::fmt::Debug; /// A nftables firewall rule. #[derive(Clone, PartialEq, Eq, Default, Debug)] #[nfnetlink_struct(derive_deserialize = false)] pub struct Rule { + family: ProtocolFamily, #[field(NFTA_RULE_TABLE)] table: String, #[field(NFTA_RULE_CHAIN)] @@ -31,78 +33,47 @@ pub struct Rule { userdata: Vec<u8>, #[field(NFTA_RULE_ID)] id: u32, - family: ProtocolFamily, } impl Rule { /// Creates a new rule object in the given [`Chain`]. /// /// [`Chain`]: struct.Chain.html - pub fn new(chain: &Chain) -> Result<Rule, DecodeError> { + pub fn new(chain: &Chain) -> Result<Rule, BuilderError> { Ok(Rule::default() .with_family(chain.get_family()) .with_table( chain .get_table() - .ok_or(DecodeError::MissingChainInformationError)?, + .ok_or(BuilderError::MissingChainInformationError)?, ) .with_chain( chain .get_name() - .ok_or(DecodeError::MissingChainInformationError)?, + .ok_or(BuilderError::MissingChainInformationError)?, )) } +} + +impl NfNetlinkObject for Rule { + const MSG_TYPE_ADD: u32 = NFT_MSG_NEWRULE; + const MSG_TYPE_DEL: u32 = NFT_MSG_DELRULE; - pub fn get_family(&self) -> ProtocolFamily { + fn get_family(&self) -> ProtocolFamily { self.family } - pub fn set_family(&mut self, family: ProtocolFamily) { + fn set_family(&mut self, family: ProtocolFamily) { self.family = family; } - pub fn with_family(mut self, family: ProtocolFamily) -> Self { - self.set_family(family); - self - } -} - -impl NfNetlinkObject for Rule { - fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32) { - let raw_msg_type = match msg_type { - MsgType::Add => NFT_MSG_NEWRULE, - MsgType::Del => NFT_MSG_DELRULE, - } as u16; - writer.write_header( - raw_msg_type, - self.family, - (if let MsgType::Add = msg_type { - NLM_F_CREATE - } else { - 0 - } | NLM_F_ACK) as u16, - seq, - None, - ); - let buf = writer.add_data_zeroed(self.get_size()); - unsafe { - self.write_payload(buf.as_mut_ptr()); - } - writer.finalize_writing_object(); - } -} - -impl NfNetlinkDeserializable for Rule { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - let (mut obj, nfgenmsg, remaining_data) = - Self::parse_object(buf, NFT_MSG_NEWRULE, NFT_MSG_DELRULE)?; - obj.family = ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?; - - Ok((obj, remaining_data)) + // append at the end of the chain, instead of the beginning + fn get_add_flags(&self) -> u32 { + NLM_F_CREATE | NLM_F_APPEND } } -pub fn list_rules_for_chain(chain: &Chain) -> Result<Vec<Rule>, crate::query::Error> { +pub fn list_rules_for_chain(chain: &Chain) -> Result<Vec<Rule>, QueryError> { let mut result = Vec::new(); list_objects_with_data( libc::NFT_MSG_GETRULE as u16, @@ -1,278 +1,117 @@ -use crate::nlmsg::NlMsg; -use crate::sys::{self, libc}; -use crate::{table::Table, MsgType}; -use std::{ - cell::Cell, - ffi::{c_void, CStr, CString}, - fmt::Debug, - net::{Ipv4Addr, Ipv6Addr}, - os::raw::c_char, - rc::Rc, +use rustables_macros::nfnetlink_struct; + +use crate::data_type::DataType; +use crate::error::BuilderError; +use crate::nlmsg::NfNetlinkObject; +use crate::parser_impls::{NfNetlinkData, NfNetlinkList}; +use crate::sys::{ + NFTA_SET_ELEM_KEY, NFTA_SET_ELEM_LIST_ELEMENTS, NFTA_SET_ELEM_LIST_SET, + NFTA_SET_ELEM_LIST_TABLE, NFTA_SET_FLAGS, NFTA_SET_ID, NFTA_SET_KEY_LEN, NFTA_SET_KEY_TYPE, + NFTA_SET_NAME, NFTA_SET_TABLE, NFTA_SET_USERDATA, NFT_MSG_DELSET, NFT_MSG_DELSETELEM, + NFT_MSG_NEWSET, NFT_MSG_NEWSETELEM, }; - -#[macro_export] -macro_rules! nft_set { - ($name:expr, $id:expr, $table:expr) => { - $crate::set::Set::new(Some($name), $id, $table, $family) - }; - ($name:expr, $id:expr, $table:expr; [ ]) => { - nft_set!(Some($name), $id, $table) - }; - ($name:expr, $id:expr, $table:expr; [ $($value:expr,)* ]) => {{ - let mut set = nft_set!(Some($name), $id, $table).expect("Set allocation failed"); - $( - set.add($value).expect(stringify!(Unable to add $value to set $name)); - )* - set - }}; -} - -pub struct Set<K> { - pub(crate) set: *mut sys::nftnl_set, - pub(crate) table: Rc<Table>, - _marker: ::std::marker::PhantomData<K>, -} - -impl<K> Set<K> { - pub fn new(name: &CStr, id: u32, table: Rc<Table>) -> Self - where - K: SetKey, - { - unsafe { - let set = try_alloc!(sys::nftnl_set_alloc()); - - sys::nftnl_set_set_u32(set, sys::NFTNL_SET_FAMILY as u16, table.get_family() as u32); - sys::nftnl_set_set_str(set, sys::NFTNL_SET_TABLE as u16, table.get_name().as_ptr()); - sys::nftnl_set_set_str(set, sys::NFTNL_SET_NAME as u16, name.as_ptr()); - sys::nftnl_set_set_u32(set, sys::NFTNL_SET_ID as u16, id); - - sys::nftnl_set_set_u32( - set, - sys::NFTNL_SET_FLAGS as u16, - (libc::NFT_SET_ANONYMOUS | libc::NFT_SET_CONSTANT) as u32, - ); - sys::nftnl_set_set_u32(set, sys::NFTNL_SET_KEY_TYPE as u16, K::TYPE); - sys::nftnl_set_set_u32(set, sys::NFTNL_SET_KEY_LEN as u16, K::LEN); - - Set { - set, - table, - _marker: ::std::marker::PhantomData, - } - } - } - - pub unsafe fn from_raw(set: *mut sys::nftnl_set, table: Rc<Table>) -> Self - where - K: SetKey, - { - Set { - set, - table, - _marker: ::std::marker::PhantomData, - } - } - - pub fn add(&mut self, key: &K) - where - K: SetKey, - { - unsafe { - let elem = try_alloc!(sys::nftnl_set_elem_alloc()); - - let data = key.data(); - let data_len = data.len() as u32; - trace!("Adding key {:?} with len {}", data, data_len); - sys::nftnl_set_elem_set( - elem, - sys::NFTNL_SET_ELEM_KEY as u16, - data.as_ref() as *const _ as *const c_void, - data_len, - ); - sys::nftnl_set_elem_add(self.set, elem); - } - } - - pub fn elems_iter(&self) -> SetElemsIter<K> { - SetElemsIter::new(self) - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns the raw handle. - pub fn as_ptr(&self) -> *const sys::nftnl_set { - self.set as *const sys::nftnl_set - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns a mutable version of the raw handle. - pub fn as_mut_ptr(&self) -> *mut sys::nftnl_set { - self.set - } - - /// Returns a textual description of the set. - pub fn get_str(&self) -> CString { - let mut descr_buf = vec![0i8; 4096]; - unsafe { - sys::nftnl_set_snprintf( - descr_buf.as_mut_ptr() as *mut c_char, - (descr_buf.len() - 1) as u64, - self.set, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); - CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned() - } - } - - pub fn get_name(&self) -> Option<&CStr> { - unsafe { - let ptr = sys::nftnl_set_get_str(self.set, sys::NFTNL_SET_NAME as u16); - if !ptr.is_null() { - Some(CStr::from_ptr(ptr)) - } else { - None - } - } - } - - pub fn get_id(&self) -> u32 { - unsafe { sys::nftnl_set_get_u32(self.set, sys::NFTNL_SET_ID as u16) } - } -} - -impl<K> Debug for Set<K> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.get_str()) - } -} - -unsafe impl<K> NlMsg for Set<K> { - unsafe fn write(&self, buf: &mut Vec<u8>, seq: u32, msg_type: MsgType) { - let type_ = match msg_type { - MsgType::Add => libc::NFT_MSG_NEWSET, - MsgType::Del => libc::NFT_MSG_DELSET, - }; - /* - let header = sys::nftnl_nlmsg_build_hdr( - buf as *mut c_char, - type_ as u16, - self.table.get_family() as u16, - (libc::NLM_F_APPEND | libc::NLM_F_CREATE | libc::NLM_F_ACK) as u16, - seq, - ); - sys::nftnl_set_nlmsg_build_payload(header, self.set); - */ - } -} - -impl<K> Drop for Set<K> { - fn drop(&mut self) { - unsafe { sys::nftnl_set_free(self.set) }; - } -} - -pub struct SetElemsIter<'a, K> { - set: &'a Set<K>, - iter: *mut sys::nftnl_set_elems_iter, - ret: Rc<Cell<i32>>, -} - -impl<'a, K> SetElemsIter<'a, K> { - fn new(set: &'a Set<K>) -> Self { - let iter = try_alloc!(unsafe { - sys::nftnl_set_elems_iter_create(set.set as *const sys::nftnl_set) +use crate::table::Table; +use crate::ProtocolFamily; +use std::fmt::Debug; +use std::marker::PhantomData; + +#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[nfnetlink_struct(derive_deserialize = false)] +pub struct Set { + pub family: ProtocolFamily, + #[field(NFTA_SET_TABLE)] + pub table: String, + #[field(NFTA_SET_NAME)] + pub name: String, + #[field(NFTA_SET_FLAGS)] + pub flags: u32, + #[field(NFTA_SET_KEY_TYPE)] + pub key_type: u32, + #[field(NFTA_SET_KEY_LEN)] + pub key_len: u32, + #[field(NFTA_SET_ID)] + pub id: u32, + #[field(NFTA_SET_USERDATA)] + pub userdata: String, +} + +impl NfNetlinkObject for Set { + const MSG_TYPE_ADD: u32 = NFT_MSG_NEWSET; + const MSG_TYPE_DEL: u32 = NFT_MSG_DELSET; + + fn get_family(&self) -> ProtocolFamily { + self.family + } + + fn set_family(&mut self, family: ProtocolFamily) { + self.family = family; + } +} + +pub struct SetBuilder<K: DataType> { + inner: Set, + list: SetElementList, + _phantom: PhantomData<K>, +} + +impl<K: DataType> SetBuilder<K> { + pub fn new(name: impl Into<String>, id: u32, table: &Table) -> Result<Self, BuilderError> { + let table_name = table.get_name().ok_or(BuilderError::MissingTableName)?; + let set_name = name.into(); + let set = Set::default() + .with_id(id) + .with_key_type(K::TYPE) + .with_key_len(K::LEN) + .with_table(table_name) + .with_name(&set_name); + + Ok(SetBuilder { + inner: set, + list: SetElementList { + table: Some(table_name.clone()), + set: Some(set_name), + elements: Some(SetElementListElements::default()), + }, + _phantom: PhantomData, + }) + } + + pub fn add(&mut self, key: &K) { + self.list.elements.as_mut().unwrap().add_value(SetElement { + key: Some(NfNetlinkData::default().with_value(key.data())), }); - SetElemsIter { - set, - iter, - ret: Rc::new(Cell::new(1)), - } } -} - -impl<'a, K> Iterator for SetElemsIter<'a, K> { - type Item = SetElemsMsg<'a, K>; - fn next(&mut self) -> Option<Self::Item> { - if self.ret.get() <= 0 || unsafe { sys::nftnl_set_elems_iter_cur(self.iter).is_null() } { - trace!("SetElemsIter iterator ending"); - None - } else { - trace!("SetElemsIter returning new SetElemsMsg"); - Some(SetElemsMsg { - set: self.set, - iter: self.iter, - ret: self.ret.clone(), - }) - } + pub fn finish(self) -> (Set, SetElementList) { + (self.inner, self.list) } } -impl<'a, K> Drop for SetElemsIter<'a, K> { - fn drop(&mut self) { - unsafe { sys::nftnl_set_elems_iter_destroy(self.iter) }; - } +#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[nfnetlink_struct(nested = true, derive_deserialize = false)] +pub struct SetElementList { + #[field(NFTA_SET_ELEM_LIST_TABLE)] + pub table: String, + #[field(NFTA_SET_ELEM_LIST_SET)] + pub set: String, + #[field(NFTA_SET_ELEM_LIST_ELEMENTS)] + pub elements: SetElementListElements, } -pub struct SetElemsMsg<'a, K> { - set: &'a Set<K>, - iter: *mut sys::nftnl_set_elems_iter, - ret: Rc<Cell<i32>>, -} +impl NfNetlinkObject for SetElementList { + const MSG_TYPE_ADD: u32 = NFT_MSG_NEWSETELEM; + const MSG_TYPE_DEL: u32 = NFT_MSG_DELSETELEM; -unsafe impl<'a, K> NlMsg for SetElemsMsg<'a, K> { - unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { - trace!("Writing SetElemsMsg to NlMsg"); - let (type_, flags) = match msg_type { - MsgType::Add => ( - libc::NFT_MSG_NEWSETELEM, - libc::NLM_F_CREATE | libc::NLM_F_EXCL | libc::NLM_F_ACK, - ), - MsgType::Del => (libc::NFT_MSG_DELSETELEM, libc::NLM_F_ACK), - }; - let header = sys::nftnl_nlmsg_build_hdr( - buf as *mut c_char, - type_ as u16, - self.set.table.get_family() as u16, - flags as u16, - seq, - ); - self.ret.set(sys::nftnl_set_elems_nlmsg_build_payload_iter( - header, self.iter, - )); + fn get_family(&self) -> ProtocolFamily { + ProtocolFamily::Unspec } } -pub trait SetKey { - const TYPE: u32; - const LEN: u32; - - fn data(&self) -> Box<[u8]>; +#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[nfnetlink_struct(nested = true)] +pub struct SetElement { + #[field(NFTA_SET_ELEM_KEY)] + pub key: NfNetlinkData, } -impl SetKey for Ipv4Addr { - const TYPE: u32 = 7; - const LEN: u32 = 4; - - fn data(&self) -> Box<[u8]> { - self.octets().to_vec().into_boxed_slice() - } -} - -impl SetKey for Ipv6Addr { - const TYPE: u32 = 8; - const LEN: u32 = 16; - - fn data(&self) -> Box<[u8]> { - self.octets().to_vec().into_boxed_slice() - } -} - -impl<const N: usize> SetKey for [u8; N] { - const TYPE: u32 = 5; - const LEN: u32 = N as u32; - - fn data(&self) -> Box<[u8]> { - Box::new(*self) - } -} +type SetElementListElements = NfNetlinkList<SetElement>; diff --git a/src/table.rs b/src/table.rs index e6a6a1a..63bf669 100644 --- a/src/table.rs +++ b/src/table.rs @@ -1,15 +1,14 @@ -use std::convert::TryFrom; use std::fmt::Debug; use rustables_macros::nfnetlink_struct; -use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject, NfNetlinkWriter}; -use crate::parser::{DecodeError, Parsable}; +use crate::error::QueryError; +use crate::nlmsg::NfNetlinkObject; use crate::sys::{ NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFTA_TABLE_USERDATA, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, - NFT_MSG_NEWTABLE, NLM_F_ACK, NLM_F_CREATE, + NFT_MSG_NEWTABLE, }; -use crate::{MsgType, ProtocolFamily}; +use crate::ProtocolFamily; /// Abstraction of a `nftnl_table`, the top level container in netfilter. A table has a protocol /// family and contains [`Chain`]s that in turn hold the rules. @@ -18,13 +17,13 @@ use crate::{MsgType, ProtocolFamily}; #[derive(Default, PartialEq, Eq, Debug)] #[nfnetlink_struct(derive_deserialize = false)] pub struct Table { + family: ProtocolFamily, #[field(NFTA_TABLE_NAME)] name: String, #[field(NFTA_TABLE_FLAGS)] flags: u32, #[field(NFTA_TABLE_USERDATA)] userdata: Vec<u8>, - pub family: ProtocolFamily, } impl Table { @@ -36,41 +35,19 @@ impl Table { } impl NfNetlinkObject for Table { - fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32) { - let raw_msg_type = match msg_type { - MsgType::Add => NFT_MSG_NEWTABLE, - MsgType::Del => NFT_MSG_DELTABLE, - } as u16; - writer.write_header( - raw_msg_type, - self.family, - (if let MsgType::Add = msg_type { - NLM_F_CREATE - } else { - 0 - } | NLM_F_ACK) as u16, - seq, - None, - ); - let buf = writer.add_data_zeroed(self.get_size()); - unsafe { - self.write_payload(buf.as_mut_ptr()); - } - writer.finalize_writing_object(); - } -} + const MSG_TYPE_ADD: u32 = NFT_MSG_NEWTABLE; + const MSG_TYPE_DEL: u32 = NFT_MSG_DELTABLE; -impl NfNetlinkDeserializable for Table { - fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { - let (mut obj, nfgenmsg, remaining_data) = - Self::parse_object(buf, NFT_MSG_NEWTABLE, NFT_MSG_DELTABLE)?; - obj.family = ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?; + fn get_family(&self) -> ProtocolFamily { + self.family + } - Ok((obj, remaining_data)) + fn set_family(&mut self, family: ProtocolFamily) { + self.family = family; } } -pub fn list_tables() -> Result<Vec<Table>, crate::query::Error> { +pub fn list_tables() -> Result<Vec<Table>, QueryError> { let mut result = Vec::new(); crate::query::list_objects_with_data( NFT_MSG_GETTABLE as u16, diff --git a/tests/batch.rs b/src/tests/batch.rs index 5a766b0..12f373f 100644 --- a/tests/batch.rs +++ b/src/tests/batch.rs @@ -3,13 +3,12 @@ use std::mem::size_of; use libc::{AF_UNSPEC, NFNL_MSG_BATCH_BEGIN, NLM_F_REQUEST}; use nix::libc::NFNL_MSG_BATCH_END; -use rustables::nlmsg::NfNetlinkDeserializable; -use rustables::parser::{pad_netlink_object_with_variable_size, parse_nlmsg, NlMsg}; -use rustables::sys::{nfgenmsg, nlmsghdr, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES}; -use rustables::{Batch, MsgType, Table}; +use crate::nlmsg::{pad_netlink_object_with_variable_size, NfNetlinkDeserializable}; +use crate::parser::{parse_nlmsg, NlMsg}; +use crate::sys::{nfgenmsg, nlmsghdr, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES}; +use crate::{Batch, MsgType, Table}; -mod common; -use common::*; +use super::get_test_table; const HEADER_SIZE: u32 = pad_netlink_object_with_variable_size(size_of::<nlmsghdr>() + size_of::<nfgenmsg>()) as u32; @@ -42,7 +41,6 @@ const DEFAULT_BATCH_END_HDR: nlmsghdr = nlmsghdr { fn batch_empty() { let batch = Batch::new(); let buf = batch.finalize(); - println!("{:?}", buf); let (hdr, msg) = parse_nlmsg(&buf).expect("Invalid nlmsg message"); assert_eq!(hdr, DEFAULT_BATCH_BEGIN_HDR); diff --git a/tests/chain.rs b/src/tests/chain.rs index 99347da..7f696e6 100644 --- a/tests/chain.rs +++ b/src/tests/chain.rs @@ -1,5 +1,5 @@ -use rustables::{ - parser::get_operation_from_nlmsghdr_type, +use crate::{ + nlmsg::get_operation_from_nlmsghdr_type, sys::{ NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_TABLE, NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN, @@ -7,8 +7,10 @@ use rustables::{ ChainType, Hook, HookClass, MsgType, }; -mod common; -use common::*; +use super::{ + get_test_chain, get_test_nlmsg, get_test_nlmsg_with_msg_type, NetlinkExpr, CHAIN_NAME, + CHAIN_USERDATA, TABLE_NAME, +}; #[test] fn new_empty_chain() { diff --git a/tests/expr.rs b/src/tests/expr.rs index da98677..141f6ac 100644 --- a/tests/expr.rs +++ b/src/tests/expr.rs @@ -1,4 +1,8 @@ -use rustables::{ +use std::net::Ipv4Addr; + +use libc::NF_DROP; + +use crate::{ expr::{ Bitwise, Cmp, CmpOp, Conntrack, ConntrackKey, Counter, ExpressionList, HeaderField, HighLevelPayload, IcmpCode, Immediate, Log, Masquerade, Meta, MetaType, Nat, NatType, @@ -18,29 +22,14 @@ use rustables::{ }, ProtocolFamily, }; -//use rustables::expr::{ -// Bitwise, Cmp, CmpOp, Conntrack, Counter, Expression, HeaderField, IcmpCode, Immediate, Log, -// LogGroup, LogPrefix, Lookup, Meta, Nat, NatType, Payload, Register, Reject, TcpHeaderField, -// TransportHeaderField, Verdict, -//}; -//use rustables::query::{get_operation_from_nlmsghdr_type, Nfgenmsg}; -//use rustables::set::Set; -//use rustables::sys::libc::{nlmsghdr, NF_DROP}; -//use rustables::{ProtoFamily, Rule}; -//use std::ffi::CStr; -use std::net::Ipv4Addr; -use libc::NF_DROP; - -mod common; -use common::*; +use super::{get_test_nlmsg, get_test_rule, NetlinkExpr, CHAIN_NAME, TABLE_NAME}; #[test] fn bitwise_expr_is_valid() { let netmask = Ipv4Addr::new(255, 255, 255, 0); let bitwise = Bitwise::new(netmask.octets(), [0, 0, 0, 0]).unwrap(); - let mut rule = - get_test_rule().with_expressions(ExpressionList::default().with_expression(bitwise)); + let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(bitwise)); let mut buf = Vec::new(); let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); @@ -95,7 +84,7 @@ fn bitwise_expr_is_valid() { #[test] fn cmp_expr_is_valid() { - let val = vec![1u8, 2, 3, 4]; + let val = [1u8, 2, 3, 4]; let cmp = Cmp::new(CmpOp::Eq, val.clone()); let mut rule = get_test_rule().with_expressions(vec![cmp]); @@ -121,7 +110,7 @@ fn cmp_expr_is_valid() { NetlinkExpr::Final(NFTA_CMP_OP, NFT_CMP_EQ.to_be_bytes().to_vec()), NetlinkExpr::Nested( NFTA_CMP_DATA, - vec![NetlinkExpr::Final(NFTA_DATA_VALUE, val)] + vec![NetlinkExpr::Final(NFTA_DATA_VALUE, val.to_vec())] ) ] ) @@ -221,7 +210,7 @@ fn ct_expr_is_valid() { fn immediate_expr_is_valid() { let immediate = Immediate::new_data(vec![42u8], Register::Reg1); let mut rule = - get_test_rule().with_expressions(ExpressionList::default().with_expression(immediate)); + get_test_rule().with_expressions(ExpressionList::default().with_value(immediate)); let mut buf = Vec::new(); let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); @@ -262,7 +251,7 @@ fn immediate_expr_is_valid() { #[test] fn log_expr_is_valid() { let log = Log::new(Some(1337), Some("mockprefix")).expect("Could not build a log expression"); - let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_expression(log)); + let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(log)); let mut buf = Vec::new(); let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); @@ -294,47 +283,49 @@ fn log_expr_is_valid() { ); } -//#[test] -//fn lookup_expr_is_valid() { -// let set_name = &CStr::from_bytes_with_nul(b"mockset\0").unwrap(); -// let mut rule = get_test_rule(); -// let table = rule.get_chain().get_table(); -// let mut set = Set::new(set_name, 0, table); -// let address: Ipv4Addr = [8, 8, 8, 8].into(); -// set.add(&address); -// let lookup = Lookup::new(&set).unwrap(); -// let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &lookup); -// assert_eq!(nlmsghdr.nlmsg_len, 104); -// -// assert_eq!( -// raw_expr, -// NetlinkExpr::List(vec![ -// NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), -// NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), -// NetlinkExpr::Nested( -// NFTA_RULE_EXPRESSIONS, -// vec![NetlinkExpr::Nested( -// NFTA_LIST_ELEM, -// vec![ -// NetlinkExpr::Final(NFTA_EXPR_NAME, b"lookup\0".to_vec()), -// NetlinkExpr::Nested( -// NFTA_EXPR_DATA, -// vec![ -// NetlinkExpr::Final( -// NFTA_LOOKUP_SREG, -// NFT_REG_1.to_be_bytes().to_vec() -// ), -// NetlinkExpr::Final(NFTA_LOOKUP_SET, b"mockset\0".to_vec()), -// NetlinkExpr::Final(NFTA_LOOKUP_SET_ID, 0u32.to_be_bytes().to_vec()), -// ] -// ) -// ] -// )] -// ) -// ]) -// .to_raw() -// ); -//} +/* +#[test] +fn lookup_expr_is_valid() { + let set_name = &CStr::from_bytes_with_nul(b"mockset\0").unwrap(); + let mut rule = get_test_rule(); + let table = rule.get_chain().get_table(); + let mut set = Set::new(set_name, 0, table); + let address: Ipv4Addr = [8, 8, 8, 8].into(); + set.add(&address); + let lookup = Lookup::new(&set).unwrap(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &lookup); + assert_eq!(nlmsghdr.nlmsg_len, 104); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"lookup\0".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_LOOKUP_SREG, + NFT_REG_1.to_be_bytes().to_vec() + ), + NetlinkExpr::Final(NFTA_LOOKUP_SET, b"mockset\0".to_vec()), + NetlinkExpr::Final(NFTA_LOOKUP_SET_ID, 0u32.to_be_bytes().to_vec()), + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} +*/ #[test] fn masquerade_expr_is_valid() { @@ -553,8 +544,7 @@ fn reject_expr_is_valid() { #[test] fn verdict_expr_is_valid() { let verdict = Immediate::new_verdict(VerdictKind::Drop); - let mut rule = - get_test_rule().with_expressions(ExpressionList::default().with_expression(verdict)); + let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(verdict)); let mut buf = Vec::new(); let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); diff --git a/tests/common.rs b/src/tests/mod.rs index 99b5a6a..3693d35 100644 --- a/tests/common.rs +++ b/src/tests/mod.rs @@ -1,11 +1,15 @@ -#![allow(dead_code)] -use std::ffi::CString; - -use rustables::nlmsg::{NfNetlinkObject, NfNetlinkWriter}; -//use rustables::set::SetKey; -use rustables::{sys::*, Chain, MsgType, ProtocolFamily, Rule, Table}; - -//use rustables::{nft_nlmsg_maxsize, Chain, MsgType, NlMsg, Rule, Set, Table}; +use crate::data_type::DataType; +use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter}; +use crate::parser::{parse_nlmsg, NlMsg}; +use crate::set::{Set, SetBuilder}; +use crate::{sys::*, Chain, MsgType, ProtocolFamily, Rule, Table}; + +mod batch; +mod chain; +mod expr; +mod rule; +mod set; +mod table; pub const TABLE_NAME: &'static str = "mocktable"; pub const CHAIN_NAME: &'static str = "mockchain"; @@ -130,10 +134,7 @@ pub fn get_test_table() -> Table { pub fn get_test_table_raw_expr() -> NetlinkExpr { NetlinkExpr::List(vec![ NetlinkExpr::Final(NFTA_TABLE_FLAGS, 0u32.to_be_bytes().to_vec()), - NetlinkExpr::Final( - NFTA_TABLE_NAME, - CString::new(TABLE_NAME).unwrap().to_bytes().to_vec(), - ), + NetlinkExpr::Final(NFTA_TABLE_NAME, TABLE_NAME.as_bytes().to_vec()), ]) .sort() } @@ -141,14 +142,8 @@ pub fn get_test_table_raw_expr() -> NetlinkExpr { pub fn get_test_table_with_userdata_raw_expr() -> NetlinkExpr { NetlinkExpr::List(vec![ NetlinkExpr::Final(NFTA_TABLE_FLAGS, 0u32.to_be_bytes().to_vec()), - NetlinkExpr::Final( - NFTA_TABLE_NAME, - CString::new(TABLE_NAME).unwrap().to_bytes().to_vec(), - ), - NetlinkExpr::Final( - NFTA_TABLE_USERDATA, - CString::new(TABLE_USERDATA).unwrap().to_bytes().to_vec(), - ), + NetlinkExpr::Final(NFTA_TABLE_NAME, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_TABLE_USERDATA, TABLE_USERDATA.as_bytes().to_vec()), ]) .sort() } @@ -161,11 +156,13 @@ pub fn get_test_rule() -> Rule { Rule::new(&get_test_chain()).unwrap() } -/* -pub fn get_test_set<T: SetKey>() -> Set<T> { - Set::new(SET_NAME, SET_ID, Rc::new(get_test_table())) +pub fn get_test_set<K: DataType>() -> Set { + SetBuilder::<K>::new(SET_NAME, SET_ID, &get_test_table()) + .expect("Couldn't create a set") + .finish() + .0 + .with_userdata(SET_USERDATA) } -*/ pub fn get_test_nlmsg_with_msg_type<'a>( buf: &'a mut Vec<u8>, @@ -175,11 +172,10 @@ pub fn get_test_nlmsg_with_msg_type<'a>( let mut writer = NfNetlinkWriter::new(buf); obj.add_or_remove(&mut writer, msg_type, 0); - let (hdr, msg) = - rustables::parser::parse_nlmsg(buf.as_slice()).expect("Couldn't parse the message"); + let (hdr, msg) = parse_nlmsg(buf.as_slice()).expect("Couldn't parse the message"); let (nfgenmsg, raw_value) = match msg { - rustables::parser::NlMsg::NfGenMsg(nfgenmsg, raw_value) => (nfgenmsg, raw_value), + NlMsg::NfGenMsg(nfgenmsg, raw_value) => (nfgenmsg, raw_value), _ => panic!("Invalid return value type, expected a valid message"), }; diff --git a/tests/rule.rs b/src/tests/rule.rs index de5be3c..08b4139 100644 --- a/tests/rule.rs +++ b/src/tests/rule.rs @@ -1,5 +1,5 @@ -use rustables::{ - parser::get_operation_from_nlmsghdr_type, +use crate::{ + nlmsg::get_operation_from_nlmsghdr_type, sys::{ NFTA_RULE_CHAIN, NFTA_RULE_HANDLE, NFTA_RULE_POSITION, NFTA_RULE_TABLE, NFTA_RULE_USERDATA, NFT_MSG_DELRULE, NFT_MSG_NEWRULE, @@ -7,8 +7,10 @@ use rustables::{ MsgType, }; -mod common; -use common::*; +use super::{ + get_test_nlmsg, get_test_nlmsg_with_msg_type, get_test_rule, NetlinkExpr, CHAIN_NAME, + RULE_USERDATA, TABLE_NAME, +}; #[test] fn new_empty_rule() { diff --git a/src/tests/set.rs b/src/tests/set.rs new file mode 100644 index 0000000..db27ced --- /dev/null +++ b/src/tests/set.rs @@ -0,0 +1,122 @@ +use std::net::{Ipv4Addr, Ipv6Addr}; + +use crate::{ + data_type::DataType, + nlmsg::get_operation_from_nlmsghdr_type, + set::SetBuilder, + sys::{ + NFTA_DATA_VALUE, NFTA_LIST_ELEM, NFTA_SET_ELEM_KEY, NFTA_SET_ELEM_LIST_ELEMENTS, + NFTA_SET_ELEM_LIST_SET, NFTA_SET_ELEM_LIST_TABLE, NFTA_SET_ID, NFTA_SET_KEY_LEN, + NFTA_SET_KEY_TYPE, NFTA_SET_NAME, NFTA_SET_TABLE, NFTA_SET_USERDATA, NFT_MSG_DELSET, + NFT_MSG_NEWSET, NFT_MSG_NEWSETELEM, + }, + MsgType, +}; + +use super::{ + get_test_nlmsg, get_test_nlmsg_with_msg_type, get_test_set, get_test_table, NetlinkExpr, + SET_ID, SET_NAME, SET_USERDATA, TABLE_NAME, +}; + +#[test] +fn new_empty_set() { + let mut set = get_test_set::<Ipv4Addr>(); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut set); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWSET as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 88); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_SET_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv4Addr::TYPE.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv4Addr::LEN.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_ID, SET_ID.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_USERDATA, SET_USERDATA.as_bytes().to_vec()), + ]) + .to_raw() + ); +} + +#[test] +fn delete_empty_set() { + let mut set = get_test_set::<Ipv6Addr>(); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = + get_test_nlmsg_with_msg_type(&mut buf, &mut set, MsgType::Del); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_DELSET as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 88); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_SET_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv6Addr::TYPE.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv6Addr::LEN.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_ID, SET_ID.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_USERDATA, SET_USERDATA.as_bytes().to_vec()), + ]) + .to_raw() + ); +} + +#[test] +fn new_set_with_data() { + let ip1 = Ipv4Addr::new(127, 0, 0, 1); + let ip2 = Ipv4Addr::new(1, 1, 1, 1); + let mut set_builder = + SetBuilder::<Ipv4Addr>::new(SET_NAME.to_string(), SET_ID, &get_test_table()) + .expect("Couldn't create a set"); + + set_builder.add(&ip1); + set_builder.add(&ip2); + let (_set, mut elem_list) = set_builder.finish(); + + let mut buf = Vec::new(); + + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut elem_list); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWSETELEM as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 84); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_SET_ELEM_LIST_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_ELEM_LIST_SET, SET_NAME.as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_SET_ELEM_LIST_ELEMENTS, + vec![ + NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![NetlinkExpr::Nested( + NFTA_DATA_VALUE, + vec![NetlinkExpr::Final(NFTA_SET_ELEM_KEY, ip1.data().to_vec())] + )] + ), + NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![NetlinkExpr::Nested( + NFTA_DATA_VALUE, + vec![NetlinkExpr::Final(NFTA_SET_ELEM_KEY, ip2.data().to_vec())] + )] + ), + ] + ), + ]) + .to_raw() + ); +} diff --git a/tests/table.rs b/src/tests/table.rs index 44394c9..39bf399 100644 --- a/tests/table.rs +++ b/src/tests/table.rs @@ -1,12 +1,13 @@ -use rustables::{ - nlmsg::NfNetlinkDeserializable, - parser::{get_operation_from_nlmsghdr_type, nft_nlmsg_maxsize}, +use crate::{ + nlmsg::{get_operation_from_nlmsghdr_type, nft_nlmsg_maxsize, NfNetlinkDeserializable}, sys::{NFT_MSG_DELTABLE, NFT_MSG_NEWTABLE}, MsgType, Table, }; -mod common; -use common::*; +use super::{ + get_test_nlmsg, get_test_nlmsg_with_msg_type, get_test_table, get_test_table_raw_expr, + get_test_table_with_userdata_raw_expr, TABLE_USERDATA, +}; #[test] fn new_empty_table() { diff --git a/tests/set.rs b/tests/set.rs deleted file mode 100644 index 0e5d002..0000000 --- a/tests/set.rs +++ /dev/null @@ -1,66 +0,0 @@ -//mod sys; -//use std::net::{Ipv4Addr, Ipv6Addr}; -// -//use rustables::{query::get_operation_from_nlmsghdr_type, set::SetKey, MsgType}; -//use sys::*; -// -//mod lib; -//use lib::*; -// -//#[test] -//fn new_empty_set() { -// let mut set = get_test_set::<Ipv4Addr>(); -// let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut set); -// assert_eq!( -// get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), -// NFT_MSG_NEWSET as u8 -// ); -// assert_eq!(nlmsghdr.nlmsg_len, 80); -// -// assert_eq!( -// raw_expr, -// NetlinkExpr::List(vec![ -// NetlinkExpr::Final(NFTA_SET_TABLE, TABLE_NAME.to_vec()), -// NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.to_vec()), -// NetlinkExpr::Final( -// NFTA_SET_FLAGS, -// ((libc::NFT_SET_ANONYMOUS | libc::NFT_SET_CONSTANT) as u32) -// .to_be_bytes() -// .to_vec() -// ), -// NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv4Addr::TYPE.to_be_bytes().to_vec()), -// NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv4Addr::LEN.to_be_bytes().to_vec()), -// NetlinkExpr::Final(NFTA_SET_ID, SET_ID.to_be_bytes().to_vec()), -// ]) -// .to_raw() -// ); -//} -// -//#[test] -//fn delete_empty_set() { -// let mut set = get_test_set::<Ipv6Addr>(); -// let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_with_msg_type(&mut set, MsgType::Del); -// assert_eq!( -// get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), -// NFT_MSG_DELSET as u8 -// ); -// assert_eq!(nlmsghdr.nlmsg_len, 80); -// -// assert_eq!( -// raw_expr, -// NetlinkExpr::List(vec![ -// NetlinkExpr::Final(NFTA_SET_TABLE, TABLE_NAME.to_vec()), -// NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.to_vec()), -// NetlinkExpr::Final( -// NFTA_SET_FLAGS, -// ((libc::NFT_SET_ANONYMOUS | libc::NFT_SET_CONSTANT) as u32) -// .to_be_bytes() -// .to_vec() -// ), -// NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv6Addr::TYPE.to_be_bytes().to_vec()), -// NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv6Addr::LEN.to_be_bytes().to_vec()), -// NetlinkExpr::Final(NFTA_SET_ID, SET_ID.to_be_bytes().to_vec()), -// ]) -// .to_raw() -// ); -//} |