diff options
author | Simon THOBY <git@nightmared.fr> | 2023-01-09 18:54:11 +0000 |
---|---|---|
committer | Simon THOBY <git@nightmared.fr> | 2023-01-09 18:54:11 +0000 |
commit | d5b9ec5185a27414286ee303eb3d21ce3069db09 (patch) | |
tree | 369eb90e8a2da307d7cd8f0b15a3318bbdba0003 | |
parent | 3e48e7efa516183d623f80d2e4e393cecc2acde9 (diff) | |
parent | c3e3773cccd01f80f2d72a7691e0654d304e6b2d (diff) |
Merge branch 'no_mnl' into 'master'
experimental support for a full-rust rewrite of the codebase (no libnftnl/libmnl anymore)
See merge request rustwall/rustables!16
58 files changed, 6201 insertions, 4344 deletions
@@ -1,5 +1,3 @@ -src/sys.rs -tests/sys.rs /target/ **/*.rs.bk Cargo.lock diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 9d6980d..25c2707 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,9 +1,3 @@ -default: - cache: - paths: - - target/ - - cargo/ - test:alpine: # Official language image. Look for the different tagged releases at: # https://hub.docker.com/r/library/rust/tags/ @@ -21,10 +15,9 @@ test:alpine: RUSTDOCFLAGS: "-C target-feature=-crt-static" before_script: - apk update - - apk add build-base libmnl-dev libnftnl-dev clang-libs + - apk add build-base clang-dev llvm-dev script: - mkdir -p target cargo - - du -sh target cargo - rustc --version && cargo --version - cargo test --workspace --verbose --release @@ -35,9 +28,8 @@ test:debian-stable: CARGO_HOME: $CI_PROJECT_DIR/cargo before_script: - apt-get -y update - - apt-get -y install libmnl-dev libnftnl-dev libclang-11-dev + - apt-get -y install libclang-13-dev llvm-13-dev script: - mkdir -p target cargo - - du -sh target cargo - rustc --version && cargo --version - cargo test --workspace --verbose --release diff --git a/Cargo.nix b/Cargo.nix new file mode 100644 index 0000000..c4a4522 --- /dev/null +++ b/Cargo.nix @@ -0,0 +1,2050 @@ + +# This file was @generated by crate2nix 0.10.0 with the command: +# "generate" +# See https://github.com/kolloch/crate2nix for more info. + +{ nixpkgs ? <nixpkgs> +, pkgs ? import nixpkgs { config = {}; } +, lib ? pkgs.lib +, stdenv ? pkgs.stdenv +, buildRustCrateForPkgs ? pkgs: pkgs.buildRustCrate + # This is used as the `crateOverrides` argument for `buildRustCrate`. +, defaultCrateOverrides ? pkgs.defaultCrateOverrides + # The features to enable for the root_crate or the workspace_members. +, rootFeatures ? [ "default" ] + # If true, throw errors instead of issueing deprecation warnings. +, strictDeprecation ? false + # Used for conditional compilation based on CPU feature detection. +, targetFeatures ? [] + # Whether to perform release builds: longer compile times, faster binaries. +, release ? true + # Additional crate2nix configuration if it exists. +, crateConfig + ? if builtins.pathExists ./crate-config.nix + then pkgs.callPackage ./crate-config.nix {} + else {} +}: + +rec { + # + # "public" attributes that we attempt to keep stable with new versions of crate2nix. + # + + rootCrate = rec { + packageId = "rustables"; + + # Use this attribute to refer to the derivation building your root crate package. + # You can override the features with rootCrate.build.override { features = [ "default" "feature1" ... ]; }. + build = internal.buildRustCrateWithFeatures { + inherit packageId; + }; + + # Debug support which might change between releases. + # File a bug if you depend on any for non-debug work! + debug = internal.debugCrate { inherit packageId; }; + }; + # Refer your crate build derivation by name here. + # You can override the features with + # workspaceMembers."${crateName}".build.override { features = [ "default" "feature1" ... ]; }. + workspaceMembers = { + "rustables" = rec { + packageId = "rustables"; + build = internal.buildRustCrateWithFeatures { + packageId = "rustables"; + }; + + # Debug support which might change between releases. + # File a bug if you depend on any for non-debug work! + debug = internal.debugCrate { inherit packageId; }; + }; + }; + + # A derivation that joins the outputs of all workspace members together. + allWorkspaceMembers = pkgs.symlinkJoin { + name = "all-workspace-members"; + paths = + let members = builtins.attrValues workspaceMembers; + in builtins.map (m: m.build) members; + }; + + # + # "internal" ("private") attributes that may change in every new version of crate2nix. + # + + internal = rec { + # Build and dependency information for crates. + # Many of the fields are passed one-to-one to buildRustCrate. + # + # Noteworthy: + # * `dependencies`/`buildDependencies`: similar to the corresponding fields for buildRustCrate. + # but with additional information which is used during dependency/feature resolution. + # * `resolvedDependencies`: the selected default features reported by cargo - only included for debugging. + # * `devDependencies` as of now not used by `buildRustCrate` but used to + # inject test dependencies into the build + + crates = { + "aho-corasick" = rec { + crateName = "aho-corasick"; + version = "0.7.20"; + edition = "2018"; + sha256 = "1b3if3nav4qzgjz9bf75b2cv2h2yisrqfs0np70i38kgz4cn94yc"; + libName = "aho_corasick"; + authors = [ + "Andrew Gallant <jamslam@gmail.com>" + ]; + dependencies = [ + { + name = "memchr"; + packageId = "memchr"; + usesDefaultFeatures = false; + } + ]; + features = { + "default" = [ "std" ]; + "std" = [ "memchr/std" ]; + }; + resolvedDefaultFeatures = [ "default" "std" ]; + }; + "ansi_term" = rec { + crateName = "ansi_term"; + version = "0.12.1"; + edition = "2015"; + sha256 = "1ljmkbilxgmhavxvxqa7qvm6f3fjggi7q2l3a72q9x0cxjvrnanm"; + authors = [ + "ogham@bsago.me" + "Ryan Scheel (Havvy) <ryan.havvy@gmail.com>" + "Josh Triplett <josh@joshtriplett.org>" + ]; + dependencies = [ + { + name = "winapi"; + packageId = "winapi"; + target = { target, features }: (target."os" == "windows"); + features = [ "consoleapi" "errhandlingapi" "fileapi" "handleapi" "processenv" ]; + } + ]; + features = { + "derive_serde_style" = [ "serde" ]; + "serde" = [ "dep:serde" ]; + }; + }; + "atty" = rec { + crateName = "atty"; + version = "0.2.14"; + edition = "2015"; + sha256 = "1s7yslcs6a28c5vz7jwj63lkfgyx8mx99fdirlhi9lbhhzhrpcyr"; + authors = [ + "softprops <d.tangren@gmail.com>" + ]; + dependencies = [ + { + name = "hermit-abi"; + packageId = "hermit-abi"; + target = { target, features }: (target."os" == "hermit"); + } + { + name = "libc"; + packageId = "libc"; + usesDefaultFeatures = false; + target = { target, features }: (target."unix" or false); + } + { + name = "winapi"; + packageId = "winapi"; + target = { target, features }: (target."windows" or false); + features = [ "consoleapi" "processenv" "minwinbase" "minwindef" "winbase" ]; + } + ]; + + }; + "autocfg" = rec { + crateName = "autocfg"; + version = "1.1.0"; + edition = "2015"; + sha256 = "1ylp3cb47ylzabimazvbz9ms6ap784zhb6syaz6c1jqpmcmq0s6l"; + authors = [ + "Josh Stone <cuviper@gmail.com>" + ]; + + }; + "bindgen" = rec { + crateName = "bindgen"; + version = "0.53.3"; + edition = "2015"; + crateBin = []; + sha256 = "1rc9grfd25bk5b2acmqljhx55ndbzmh7w8b3x6q707cb4s6rfan7"; + authors = [ + "Jyun-Yan You <jyyou.tw@gmail.com>" + "Emilio Cobos Álvarez <emilio@crisal.io>" + "Nick Fitzgerald <fitzgen@gmail.com>" + "The Servo project developers" + ]; + dependencies = [ + { + name = "bitflags"; + packageId = "bitflags"; + } + { + name = "cexpr"; + packageId = "cexpr"; + } + { + name = "cfg-if"; + packageId = "cfg-if 0.1.10"; + } + { + name = "clang-sys"; + packageId = "clang-sys"; + features = [ "clang_6_0" ]; + } + { + name = "clap"; + packageId = "clap"; + optional = true; + } + { + name = "env_logger"; + packageId = "env_logger 0.7.1"; + optional = true; + } + { + name = "lazy_static"; + packageId = "lazy_static"; + } + { + name = "lazycell"; + packageId = "lazycell"; + } + { + name = "log"; + packageId = "log"; + optional = true; + } + { + name = "peeking_take_while"; + packageId = "peeking_take_while"; + } + { + name = "proc-macro2"; + packageId = "proc-macro2"; + usesDefaultFeatures = false; + } + { + name = "quote"; + packageId = "quote"; + usesDefaultFeatures = false; + } + { + name = "regex"; + packageId = "regex"; + usesDefaultFeatures = false; + features = [ "std" "unicode" ]; + } + { + name = "rustc-hash"; + packageId = "rustc-hash"; + } + { + name = "shlex"; + packageId = "shlex"; + } + { + name = "which"; + packageId = "which"; + optional = true; + usesDefaultFeatures = false; + } + ]; + devDependencies = [ + { + name = "clap"; + packageId = "clap"; + } + { + name = "shlex"; + packageId = "shlex"; + } + ]; + features = { + "clap" = [ "dep:clap" ]; + "default" = [ "logging" "clap" "runtime" "which-rustfmt" ]; + "env_logger" = [ "dep:env_logger" ]; + "log" = [ "dep:log" ]; + "logging" = [ "env_logger" "log" ]; + "runtime" = [ "clang-sys/runtime" ]; + "static" = [ "clang-sys/static" ]; + "which" = [ "dep:which" ]; + "which-rustfmt" = [ "which" ]; + }; + resolvedDefaultFeatures = [ "clap" "default" "env_logger" "log" "logging" "runtime" "which" "which-rustfmt" ]; + }; + "bitflags" = rec { + crateName = "bitflags"; + version = "1.3.2"; + edition = "2018"; + sha256 = "12ki6w8gn1ldq7yz9y680llwk5gmrhrzszaa17g1sbrw2r2qvwxy"; + authors = [ + "The Rust Project Developers" + ]; + features = { + "compiler_builtins" = [ "dep:compiler_builtins" ]; + "core" = [ "dep:core" ]; + "rustc-dep-of-std" = [ "core" "compiler_builtins" ]; + }; + resolvedDefaultFeatures = [ "default" ]; + }; + "cc" = rec { + crateName = "cc"; + version = "1.0.78"; + edition = "2018"; + crateBin = []; + sha256 = "0gcch8g41jsjs4zk8fy7k4jhc33sfqdab4nxsrcsds2w6gi080d2"; + authors = [ + "Alex Crichton <alex@alexcrichton.com>" + ]; + features = { + "jobserver" = [ "dep:jobserver" ]; + "parallel" = [ "jobserver" ]; + }; + }; + "cexpr" = rec { + crateName = "cexpr"; + version = "0.4.0"; + edition = "2018"; + sha256 = "09qd1k1mrhcqfhqmsz4y1bya9gcs29si7y3w96pqkgid4y2dpbpl"; + authors = [ + "Jethro Beekman <jethro@jbeekman.nl>" + ]; + dependencies = [ + { + name = "nom"; + packageId = "nom"; + usesDefaultFeatures = false; + features = [ "std" ]; + } + ]; + + }; + "cfg-if 0.1.10" = rec { + crateName = "cfg-if"; + version = "0.1.10"; + edition = "2018"; + sha256 = "08h80ihs74jcyp24cd75wwabygbbdgl05k6p5dmq8akbr78vv1a7"; + authors = [ + "Alex Crichton <alex@alexcrichton.com>" + ]; + features = { + "compiler_builtins" = [ "dep:compiler_builtins" ]; + "core" = [ "dep:core" ]; + "rustc-dep-of-std" = [ "core" "compiler_builtins" ]; + }; + }; + "cfg-if 1.0.0" = rec { + crateName = "cfg-if"; + version = "1.0.0"; + edition = "2018"; + sha256 = "1za0vb97n4brpzpv8lsbnzmq5r8f2b0cpqqr0sy8h5bn751xxwds"; + authors = [ + "Alex Crichton <alex@alexcrichton.com>" + ]; + features = { + "compiler_builtins" = [ "dep:compiler_builtins" ]; + "core" = [ "dep:core" ]; + "rustc-dep-of-std" = [ "core" "compiler_builtins" ]; + }; + }; + "clang-sys" = rec { + crateName = "clang-sys"; + version = "0.29.3"; + edition = "2015"; + sha256 = "02nibl74zbz5x693iy5vdbhnfckja47m7j1mp2bj7fjw3pgkfs7y"; + authors = [ + "Kyle Mayes <kyle@mayeses.com>" + ]; + dependencies = [ + { + name = "glob"; + packageId = "glob"; + } + { + name = "libc"; + packageId = "libc"; + usesDefaultFeatures = false; + } + { + name = "libloading"; + packageId = "libloading"; + optional = true; + } + ]; + buildDependencies = [ + { + name = "glob"; + packageId = "glob"; + } + ]; + features = { + "clang_3_6" = [ "gte_clang_3_6" ]; + "clang_3_7" = [ "gte_clang_3_6" "gte_clang_3_7" ]; + "clang_3_8" = [ "gte_clang_3_6" "gte_clang_3_7" "gte_clang_3_8" ]; + "clang_3_9" = [ "gte_clang_3_6" "gte_clang_3_7" "gte_clang_3_8" "gte_clang_3_9" ]; + "clang_4_0" = [ "gte_clang_3_6" "gte_clang_3_7" "gte_clang_3_8" "gte_clang_3_9" "gte_clang_4_0" ]; + "clang_5_0" = [ "gte_clang_3_6" "gte_clang_3_7" "gte_clang_3_8" "gte_clang_3_9" "gte_clang_4_0" "gte_clang_5_0" ]; + "clang_6_0" = [ "gte_clang_3_6" "gte_clang_3_7" "gte_clang_3_8" "gte_clang_3_9" "gte_clang_4_0" "gte_clang_5_0" "gte_clang_6_0" ]; + "clang_7_0" = [ "gte_clang_3_6" "gte_clang_3_7" "gte_clang_3_8" "gte_clang_3_9" "gte_clang_4_0" "gte_clang_5_0" "gte_clang_6_0" "gte_clang_7_0" ]; + "clang_8_0" = [ "gte_clang_3_6" "gte_clang_3_7" "gte_clang_3_8" "gte_clang_3_9" "gte_clang_4_0" "gte_clang_5_0" "gte_clang_6_0" "gte_clang_7_0" "gte_clang_8_0" ]; + "clang_9_0" = [ "gte_clang_3_6" "gte_clang_3_7" "gte_clang_3_8" "gte_clang_3_9" "gte_clang_4_0" "gte_clang_5_0" "gte_clang_6_0" "gte_clang_7_0" "gte_clang_8_0" "gte_clang_9_0" ]; + "libloading" = [ "dep:libloading" ]; + "runtime" = [ "libloading" ]; + }; + resolvedDefaultFeatures = [ "clang_6_0" "gte_clang_3_6" "gte_clang_3_7" "gte_clang_3_8" "gte_clang_3_9" "gte_clang_4_0" "gte_clang_5_0" "gte_clang_6_0" "libloading" "runtime" ]; + }; + "clap" = rec { + crateName = "clap"; + version = "2.34.0"; + edition = "2018"; + sha256 = "071q5d8jfwbazi6zhik9xwpacx5i6kb2vkzy060vhf0c3120aqd0"; + authors = [ + "Kevin K. <kbknapp@gmail.com>" + ]; + dependencies = [ + { + name = "ansi_term"; + packageId = "ansi_term"; + optional = true; + target = { target, features }: (!(target."windows" or false)); + } + { + name = "atty"; + packageId = "atty"; + optional = true; + } + { + name = "bitflags"; + packageId = "bitflags"; + } + { + name = "strsim"; + packageId = "strsim"; + optional = true; + } + { + name = "textwrap"; + packageId = "textwrap"; + } + { + name = "unicode-width"; + packageId = "unicode-width"; + } + { + name = "vec_map"; + packageId = "vec_map"; + optional = true; + } + ]; + features = { + "ansi_term" = [ "dep:ansi_term" ]; + "atty" = [ "dep:atty" ]; + "clippy" = [ "dep:clippy" ]; + "color" = [ "ansi_term" "atty" ]; + "default" = [ "suggestions" "color" "vec_map" ]; + "doc" = [ "yaml" ]; + "strsim" = [ "dep:strsim" ]; + "suggestions" = [ "strsim" ]; + "term_size" = [ "dep:term_size" ]; + "vec_map" = [ "dep:vec_map" ]; + "wrap_help" = [ "term_size" "textwrap/term_size" ]; + "yaml" = [ "yaml-rust" ]; + "yaml-rust" = [ "dep:yaml-rust" ]; + }; + resolvedDefaultFeatures = [ "ansi_term" "atty" "color" "default" "strsim" "suggestions" "vec_map" ]; + }; + "env_logger 0.7.1" = rec { + crateName = "env_logger"; + version = "0.7.1"; + edition = "2018"; + sha256 = "0djx8h8xfib43g5w94r1m1mkky5spcw4wblzgnhiyg5vnfxknls4"; + authors = [ + "The Rust Project Developers" + ]; + dependencies = [ + { + name = "atty"; + packageId = "atty"; + optional = true; + } + { + name = "humantime"; + packageId = "humantime 1.3.0"; + optional = true; + } + { + name = "log"; + packageId = "log"; + features = [ "std" ]; + } + { + name = "regex"; + packageId = "regex"; + optional = true; + } + { + name = "termcolor"; + packageId = "termcolor"; + optional = true; + } + ]; + features = { + "atty" = [ "dep:atty" ]; + "default" = [ "termcolor" "atty" "humantime" "regex" ]; + "humantime" = [ "dep:humantime" ]; + "regex" = [ "dep:regex" ]; + "termcolor" = [ "dep:termcolor" ]; + }; + resolvedDefaultFeatures = [ "atty" "default" "humantime" "regex" "termcolor" ]; + }; + "env_logger 0.9.3" = rec { + crateName = "env_logger"; + version = "0.9.3"; + edition = "2018"; + sha256 = "1rq0kqpa8my6i1qcyhfqrn1g9xr5fbkwwbd42nqvlzn9qibncbm1"; + dependencies = [ + { + name = "atty"; + packageId = "atty"; + optional = true; + } + { + name = "humantime"; + packageId = "humantime 2.1.0"; + optional = true; + } + { + name = "log"; + packageId = "log"; + features = [ "std" ]; + } + { + name = "regex"; + packageId = "regex"; + optional = true; + usesDefaultFeatures = false; + features = [ "std" "perf" ]; + } + { + name = "termcolor"; + packageId = "termcolor"; + optional = true; + } + ]; + features = { + "atty" = [ "dep:atty" ]; + "default" = [ "termcolor" "atty" "humantime" "regex" ]; + "humantime" = [ "dep:humantime" ]; + "regex" = [ "dep:regex" ]; + "termcolor" = [ "dep:termcolor" ]; + }; + resolvedDefaultFeatures = [ "atty" "default" "humantime" "regex" "termcolor" ]; + }; + "glob" = rec { + crateName = "glob"; + version = "0.3.1"; + edition = "2015"; + sha256 = "16zca52nglanv23q5qrwd5jinw3d3as5ylya6y1pbx47vkxvrynj"; + authors = [ + "The Rust Project Developers" + ]; + + }; + "hermit-abi" = rec { + crateName = "hermit-abi"; + version = "0.1.19"; + edition = "2018"; + sha256 = "0cxcm8093nf5fyn114w8vxbrbcyvv91d4015rdnlgfll7cs6gd32"; + authors = [ + "Stefan Lankes" + ]; + dependencies = [ + { + name = "libc"; + packageId = "libc"; + usesDefaultFeatures = false; + } + ]; + features = { + "compiler_builtins" = [ "dep:compiler_builtins" ]; + "core" = [ "dep:core" ]; + "rustc-dep-of-std" = [ "core" "compiler_builtins/rustc-dep-of-std" "libc/rustc-dep-of-std" ]; + }; + resolvedDefaultFeatures = [ "default" ]; + }; + "humantime 1.3.0" = rec { + crateName = "humantime"; + version = "1.3.0"; + edition = "2015"; + sha256 = "0krwgbf35pd46xvkqg14j070vircsndabahahlv3rwhflpy4q06z"; + authors = [ + "Paul Colomiets <paul@colomiets.name>" + ]; + dependencies = [ + { + name = "quick-error"; + packageId = "quick-error"; + } + ]; + + }; + "humantime 2.1.0" = rec { + crateName = "humantime"; + version = "2.1.0"; + edition = "2018"; + sha256 = "1r55pfkkf5v0ji1x6izrjwdq9v6sc7bv99xj6srywcar37xmnfls"; + authors = [ + "Paul Colomiets <paul@colomiets.name>" + ]; + + }; + "ipnetwork" = rec { + crateName = "ipnetwork"; + version = "0.20.0"; + edition = "2021"; + sha256 = "03hhmxyimz0800z44wl3z1ak8iw91xcnk7sgx5p5jinmx50naimz"; + authors = [ + "Abhishek Chanda <abhishek.becs@gmail.com>" + "Linus Färnstrand <faern@faern.net>" + ]; + features = { + "default" = [ "serde" ]; + "schemars" = [ "dep:schemars" ]; + "serde" = [ "dep:serde" ]; + }; + }; + "lazy_static" = rec { + crateName = "lazy_static"; + version = "1.4.0"; + edition = "2015"; + sha256 = "0in6ikhw8mgl33wjv6q6xfrb5b9jr16q8ygjy803fay4zcisvaz2"; + authors = [ + "Marvin Löbel <loebel.marvin@gmail.com>" + ]; + features = { + "spin" = [ "dep:spin" ]; + "spin_no_std" = [ "spin" ]; + }; + }; + "lazycell" = rec { + crateName = "lazycell"; + version = "1.3.0"; + edition = "2015"; + sha256 = "0m8gw7dn30i0zjjpjdyf6pc16c34nl71lpv461mix50x3p70h3c3"; + authors = [ + "Alex Crichton <alex@alexcrichton.com>" + "Nikita Pekin <contact@nikitapek.in>" + ]; + features = { + "clippy" = [ "dep:clippy" ]; + "nightly-testing" = [ "clippy" "nightly" ]; + "serde" = [ "dep:serde" ]; + }; + }; + "libc" = rec { + crateName = "libc"; + version = "0.2.139"; + edition = "2015"; + sha256 = "0yaz3z56c72p2nfgv2y2zdi8bzi7x3kdq2hzgishgw0da8ky6790"; + authors = [ + "The Rust Project Developers" + ]; + features = { + "default" = [ "std" ]; + "rustc-dep-of-std" = [ "align" "rustc-std-workspace-core" ]; + "rustc-std-workspace-core" = [ "dep:rustc-std-workspace-core" ]; + "use_std" = [ "std" ]; + }; + resolvedDefaultFeatures = [ "default" "extra_traits" "std" ]; + }; + "libloading" = rec { + crateName = "libloading"; + version = "0.5.2"; + edition = "2015"; + sha256 = "0lyply8rcqc8agajzxs7bq6ivba9dnn1i68kgb9z2flnfjh13cgj"; + authors = [ + "Simonas Kazlauskas <libloading@kazlauskas.me>" + ]; + dependencies = [ + { + name = "winapi"; + packageId = "winapi"; + target = { target, features }: (target."windows" or false); + features = [ "winerror" "errhandlingapi" "libloaderapi" ]; + } + ]; + buildDependencies = [ + { + name = "cc"; + packageId = "cc"; + } + ]; + + }; + "log" = rec { + crateName = "log"; + version = "0.4.17"; + edition = "2015"; + sha256 = "0biqlaaw1lsr8bpnmbcc0fvgjj34yy79ghqzyi0ali7vgil2xcdb"; + authors = [ + "The Rust Project Developers" + ]; + dependencies = [ + { + name = "cfg-if"; + packageId = "cfg-if 1.0.0"; + } + ]; + features = { + "kv_unstable" = [ "value-bag" ]; + "kv_unstable_serde" = [ "kv_unstable_std" "value-bag/serde" "serde" ]; + "kv_unstable_std" = [ "std" "kv_unstable" "value-bag/error" ]; + "kv_unstable_sval" = [ "kv_unstable" "value-bag/sval" "sval" ]; + "serde" = [ "dep:serde" ]; + "sval" = [ "dep:sval" ]; + "value-bag" = [ "dep:value-bag" ]; + }; + resolvedDefaultFeatures = [ "std" ]; + }; + "memchr" = rec { + crateName = "memchr"; + version = "2.5.0"; + edition = "2018"; + sha256 = "0vanfk5mzs1g1syqnj03q8n0syggnhn55dq535h2wxr7rwpfbzrd"; + authors = [ + "Andrew Gallant <jamslam@gmail.com>" + "bluss" + ]; + features = { + "compiler_builtins" = [ "dep:compiler_builtins" ]; + "core" = [ "dep:core" ]; + "default" = [ "std" ]; + "libc" = [ "dep:libc" ]; + "rustc-dep-of-std" = [ "core" "compiler_builtins" ]; + "use_std" = [ "std" ]; + }; + resolvedDefaultFeatures = [ "default" "std" "use_std" ]; + }; + "memoffset" = rec { + crateName = "memoffset"; + version = "0.6.5"; + edition = "2015"; + sha256 = "1kkrzll58a3ayn5zdyy9i1f1v3mx0xgl29x0chq614zazba638ss"; + authors = [ + "Gilad Naaman <gilad.naaman@gmail.com>" + ]; + buildDependencies = [ + { + name = "autocfg"; + packageId = "autocfg"; + } + ]; + features = { + }; + resolvedDefaultFeatures = [ "default" ]; + }; + "nix" = rec { + crateName = "nix"; + version = "0.23.2"; + edition = "2018"; + sha256 = "0p5kxhm5d8lry0szqbsllpcb5i3z7lg1dkglw0ni2l011b090dwg"; + authors = [ + "The nix-rust Project Developers" + ]; + dependencies = [ + { + name = "bitflags"; + packageId = "bitflags"; + } + { + name = "cfg-if"; + packageId = "cfg-if 1.0.0"; + } + { + name = "libc"; + packageId = "libc"; + features = [ "extra_traits" ]; + } + { + name = "memoffset"; + packageId = "memoffset"; + target = { target, features }: (!(target."os" == "redox")); + } + ]; + buildDependencies = [ + { + name = "cc"; + packageId = "cc"; + target = {target, features}: (target."os" == "dragonfly"); + } + ]; + + }; + "nom" = rec { + crateName = "nom"; + version = "5.1.2"; + edition = "2018"; + sha256 = "1br74rwdp3c2ddga03bphnf355spn4mzwf1slg0a30zd4qnjdd7z"; + authors = [ + "contact@geoffroycouprie.com" + ]; + dependencies = [ + { + name = "memchr"; + packageId = "memchr"; + usesDefaultFeatures = false; + } + ]; + buildDependencies = [ + { + name = "version_check"; + packageId = "version_check"; + } + ]; + features = { + "default" = [ "std" "lexical" ]; + "lazy_static" = [ "dep:lazy_static" ]; + "lexical" = [ "lexical-core" ]; + "lexical-core" = [ "dep:lexical-core" ]; + "regex" = [ "dep:regex" ]; + "regexp" = [ "regex" ]; + "regexp_macros" = [ "regexp" "lazy_static" ]; + "std" = [ "alloc" "memchr/use_std" ]; + }; + resolvedDefaultFeatures = [ "alloc" "std" ]; + }; + "peeking_take_while" = rec { + crateName = "peeking_take_while"; + version = "0.1.2"; + edition = "2015"; + sha256 = "16bhqr6rdyrp12zv381cxaaqqd0pwysvm1q8h2ygihvypvfprc8r"; + authors = [ + "Nick Fitzgerald <fitzgen@gmail.com>" + ]; + + }; + "proc-macro-error" = rec { + crateName = "proc-macro-error"; + version = "1.0.4"; + edition = "2018"; + sha256 = "1373bhxaf0pagd8zkyd03kkx6bchzf6g0dkwrwzsnal9z47lj9fs"; + authors = [ + "CreepySkeleton <creepy-skeleton@yandex.ru>" + ]; + dependencies = [ + { + name = "proc-macro-error-attr"; + packageId = "proc-macro-error-attr"; + } + { + name = "proc-macro2"; + packageId = "proc-macro2"; + } + { + name = "quote"; + packageId = "quote"; + } + { + name = "syn"; + packageId = "syn"; + optional = true; + usesDefaultFeatures = false; + } + ]; + buildDependencies = [ + { + name = "version_check"; + packageId = "version_check"; + } + ]; + features = { + "default" = [ "syn-error" ]; + "syn" = [ "dep:syn" ]; + "syn-error" = [ "syn" ]; + }; + resolvedDefaultFeatures = [ "default" "syn" "syn-error" ]; + }; + "proc-macro-error-attr" = rec { + crateName = "proc-macro-error-attr"; + version = "1.0.4"; + edition = "2018"; + sha256 = "0sgq6m5jfmasmwwy8x4mjygx5l7kp8s4j60bv25ckv2j1qc41gm1"; + procMacro = true; + authors = [ + "CreepySkeleton <creepy-skeleton@yandex.ru>" + ]; + dependencies = [ + { + name = "proc-macro2"; + packageId = "proc-macro2"; + } + { + name = "quote"; + packageId = "quote"; + } + ]; + buildDependencies = [ + { + name = "version_check"; + packageId = "version_check"; + } + ]; + + }; + "proc-macro2" = rec { + crateName = "proc-macro2"; + version = "1.0.49"; + edition = "2018"; + sha256 = "19b3xdfmnay9mchza82lhb3n8qjrfzkxwd23f50xxzy4z6lyra2p"; + authors = [ + "David Tolnay <dtolnay@gmail.com>" + "Alex Crichton <alex@alexcrichton.com>" + ]; + dependencies = [ + { + name = "unicode-ident"; + packageId = "unicode-ident"; + } + ]; + features = { + "default" = [ "proc-macro" ]; + }; + resolvedDefaultFeatures = [ "default" "proc-macro" ]; + }; + "quick-error" = rec { + crateName = "quick-error"; + version = "1.2.3"; + edition = "2015"; + sha256 = "1q6za3v78hsspisc197bg3g7rpc989qycy8ypr8ap8igv10ikl51"; + authors = [ + "Paul Colomiets <paul@colomiets.name>" + "Colin Kiegel <kiegel@gmx.de>" + ]; + + }; + "quote" = rec { + crateName = "quote"; + version = "1.0.23"; + edition = "2018"; + sha256 = "0ywwzw5xfwwgq62ihp4fbjbfdjb3ilss2vh3fka18ai59lvdhml8"; + authors = [ + "David Tolnay <dtolnay@gmail.com>" + ]; + dependencies = [ + { + name = "proc-macro2"; + packageId = "proc-macro2"; + usesDefaultFeatures = false; + } + ]; + features = { + "default" = [ "proc-macro" ]; + "proc-macro" = [ "proc-macro2/proc-macro" ]; + }; + resolvedDefaultFeatures = [ "default" "proc-macro" ]; + }; + "regex" = rec { + crateName = "regex"; + version = "1.7.1"; + edition = "2018"; + sha256 = "0czp6hxg02lm02hvlhp9xjkd65cjcagw119crnaznwd5idsabaj8"; + authors = [ + "The Rust Project Developers" + ]; + dependencies = [ + { + name = "aho-corasick"; + packageId = "aho-corasick"; + optional = true; + } + { + name = "memchr"; + packageId = "memchr"; + optional = true; + } + { + name = "regex-syntax"; + packageId = "regex-syntax"; + usesDefaultFeatures = false; + } + ]; + features = { + "aho-corasick" = [ "dep:aho-corasick" ]; + "default" = [ "std" "perf" "unicode" "regex-syntax/default" ]; + "memchr" = [ "dep:memchr" ]; + "perf" = [ "perf-cache" "perf-dfa" "perf-inline" "perf-literal" ]; + "perf-literal" = [ "aho-corasick" "memchr" ]; + "unicode" = [ "unicode-age" "unicode-bool" "unicode-case" "unicode-gencat" "unicode-perl" "unicode-script" "unicode-segment" "regex-syntax/unicode" ]; + "unicode-age" = [ "regex-syntax/unicode-age" ]; + "unicode-bool" = [ "regex-syntax/unicode-bool" ]; + "unicode-case" = [ "regex-syntax/unicode-case" ]; + "unicode-gencat" = [ "regex-syntax/unicode-gencat" ]; + "unicode-perl" = [ "regex-syntax/unicode-perl" ]; + "unicode-script" = [ "regex-syntax/unicode-script" ]; + "unicode-segment" = [ "regex-syntax/unicode-segment" ]; + "unstable" = [ "pattern" ]; + "use_std" = [ "std" ]; + }; + resolvedDefaultFeatures = [ "aho-corasick" "default" "memchr" "perf" "perf-cache" "perf-dfa" "perf-inline" "perf-literal" "std" "unicode" "unicode-age" "unicode-bool" "unicode-case" "unicode-gencat" "unicode-perl" "unicode-script" "unicode-segment" ]; + }; + "regex-syntax" = rec { + crateName = "regex-syntax"; + version = "0.6.28"; + edition = "2018"; + sha256 = "0j68z4jnxshfymb08j1drvxn9wgs1469047lfaq4im78wcxn0v25"; + authors = [ + "The Rust Project Developers" + ]; + features = { + "default" = [ "unicode" ]; + "unicode" = [ "unicode-age" "unicode-bool" "unicode-case" "unicode-gencat" "unicode-perl" "unicode-script" "unicode-segment" ]; + }; + resolvedDefaultFeatures = [ "default" "unicode" "unicode-age" "unicode-bool" "unicode-case" "unicode-gencat" "unicode-perl" "unicode-script" "unicode-segment" ]; + }; + "rustables" = rec { + crateName = "rustables"; + version = "0.8.0"; + edition = "2021"; + # We can't filter paths with references in Nix 2.4 + # See https://github.com/NixOS/nix/issues/5410 + src = if (lib.versionOlder builtins.nixVersion "2.4pre20211007") + then lib.cleanSourceWith { filter = sourceFilter; src = ./.; } + else ./.; + authors = [ + "lafleur@boum.org" + "Simon Thoby" + "Mullvad VPN" + ]; + dependencies = [ + { + name = "bitflags"; + packageId = "bitflags"; + } + { + name = "ipnetwork"; + packageId = "ipnetwork"; + usesDefaultFeatures = false; + } + { + name = "libc"; + packageId = "libc"; + } + { + name = "log"; + packageId = "log"; + } + { + name = "nix"; + packageId = "nix"; + } + { + name = "rustables-macros"; + packageId = "rustables-macros"; + } + { + name = "thiserror"; + packageId = "thiserror"; + } + ]; + buildDependencies = [ + { + name = "bindgen"; + packageId = "bindgen"; + } + { + name = "regex"; + packageId = "regex"; + } + ]; + devDependencies = [ + { + name = "env_logger"; + packageId = "env_logger 0.9.3"; + } + ]; + + }; + "rustables-macros" = rec { + crateName = "rustables-macros"; + version = "0.1.0"; + edition = "2021"; + sha256 = "093ygmvwd4w69qiry4p99xvyzm2g4ywf8zx0hxrqhyrwy1fldqxm"; + procMacro = true; + authors = [ + "Simon Thoby" + ]; + dependencies = [ + { + name = "proc-macro-error"; + packageId = "proc-macro-error"; + } + { + name = "proc-macro2"; + packageId = "proc-macro2"; + } + { + name = "quote"; + packageId = "quote"; + } + { + name = "syn"; + packageId = "syn"; + features = [ "full" ]; + } + ]; + + }; + "rustc-hash" = rec { + crateName = "rustc-hash"; + version = "1.1.0"; + edition = "2015"; + sha256 = "1qkc5khrmv5pqi5l5ca9p5nl5hs742cagrndhbrlk3dhlrx3zm08"; + authors = [ + "The Rust Project Developers" + ]; + features = { + "default" = [ "std" ]; + }; + resolvedDefaultFeatures = [ "default" "std" ]; + }; + "shlex" = rec { + crateName = "shlex"; + version = "0.1.1"; + edition = "2015"; + sha256 = "1lmv6san7g8dv6jdfp14m7bdczq9ss7j7bgsfqyqjc3jnjfippvz"; + authors = [ + "comex <comexk@gmail.com>" + ]; + + }; + "strsim" = rec { + crateName = "strsim"; + version = "0.8.0"; + edition = "2015"; + sha256 = "0sjsm7hrvjdifz661pjxq5w4hf190hx53fra8dfvamacvff139cf"; + authors = [ + "Danny Guo <dannyguo91@gmail.com>" + ]; + + }; + "syn" = rec { + crateName = "syn"; + version = "1.0.107"; + edition = "2018"; + sha256 = "1xg3315vx8civ8y0l5zxq5mkx07qskaqwnjak18aw0vfn6sn8h0z"; + authors = [ + "David Tolnay <dtolnay@gmail.com>" + ]; + dependencies = [ + { + name = "proc-macro2"; + packageId = "proc-macro2"; + usesDefaultFeatures = false; + } + { + name = "quote"; + packageId = "quote"; + optional = true; + usesDefaultFeatures = false; + } + { + name = "unicode-ident"; + packageId = "unicode-ident"; + } + ]; + features = { + "default" = [ "derive" "parsing" "printing" "clone-impls" "proc-macro" ]; + "printing" = [ "quote" ]; + "proc-macro" = [ "proc-macro2/proc-macro" "quote/proc-macro" ]; + "quote" = [ "dep:quote" ]; + "test" = [ "syn-test-suite/all-features" ]; + }; + resolvedDefaultFeatures = [ "clone-impls" "default" "derive" "full" "parsing" "printing" "proc-macro" "quote" ]; + }; + "termcolor" = rec { + crateName = "termcolor"; + version = "1.1.3"; + edition = "2018"; + sha256 = "0mbpflskhnz3jf312k50vn0hqbql8ga2rk0k79pkgchip4q4vcms"; + authors = [ + "Andrew Gallant <jamslam@gmail.com>" + ]; + dependencies = [ + { + name = "winapi-util"; + packageId = "winapi-util"; + target = { target, features }: (target."windows" or false); + } + ]; + + }; + "textwrap" = rec { + crateName = "textwrap"; + version = "0.11.0"; + edition = "2015"; + sha256 = "0q5hky03ik3y50s9sz25r438bc4nwhqc6dqwynv4wylc807n29nk"; + authors = [ + "Martin Geisler <martin@geisler.net>" + ]; + dependencies = [ + { + name = "unicode-width"; + packageId = "unicode-width"; + } + ]; + features = { + "hyphenation" = [ "dep:hyphenation" ]; + "term_size" = [ "dep:term_size" ]; + }; + }; + "thiserror" = rec { + crateName = "thiserror"; + version = "1.0.38"; + edition = "2018"; + sha256 = "1l7yh18iqcr2jnl6qjx3ywvhny98cvda3biwc334ap3xm65d373a"; + authors = [ + "David Tolnay <dtolnay@gmail.com>" + ]; + dependencies = [ + { + name = "thiserror-impl"; + packageId = "thiserror-impl"; + } + ]; + + }; + "thiserror-impl" = rec { + crateName = "thiserror-impl"; + version = "1.0.38"; + edition = "2018"; + sha256 = "0vzkcjqkzzgrwwby92xvnbp11a8d70b1gkybm0zx1r458spjgcqz"; + procMacro = true; + authors = [ + "David Tolnay <dtolnay@gmail.com>" + ]; + dependencies = [ + { + name = "proc-macro2"; + packageId = "proc-macro2"; + } + { + name = "quote"; + packageId = "quote"; + } + { + name = "syn"; + packageId = "syn"; + } + ]; + + }; + "unicode-ident" = rec { + crateName = "unicode-ident"; + version = "1.0.6"; + edition = "2018"; + sha256 = "1g2fdsw5sv9l1m73whm99za3lxq3nw4gzx5kvi562h4b46gjp8l4"; + authors = [ + "David Tolnay <dtolnay@gmail.com>" + ]; + + }; + "unicode-width" = rec { + crateName = "unicode-width"; + version = "0.1.10"; + edition = "2015"; + sha256 = "12vc3wv0qwg8rzcgb9bhaf5119dlmd6lmkhbfy1zfls6n7jx3vf0"; + authors = [ + "kwantam <kwantam@gmail.com>" + "Manish Goregaokar <manishsmail@gmail.com>" + ]; + features = { + "compiler_builtins" = [ "dep:compiler_builtins" ]; + "core" = [ "dep:core" ]; + "rustc-dep-of-std" = [ "std" "core" "compiler_builtins" ]; + "std" = [ "dep:std" ]; + }; + resolvedDefaultFeatures = [ "default" ]; + }; + "vec_map" = rec { + crateName = "vec_map"; + version = "0.8.2"; + edition = "2015"; + sha256 = "1481w9g1dw9rxp3l6snkdqihzyrd2f8vispzqmwjwsdyhw8xzggi"; + authors = [ + "Alex Crichton <alex@alexcrichton.com>" + "Jorge Aparicio <japaricious@gmail.com>" + "Alexis Beingessner <a.beingessner@gmail.com>" + "Brian Anderson <>" + "tbu- <>" + "Manish Goregaokar <>" + "Aaron Turon <aturon@mozilla.com>" + "Adolfo Ochagavía <>" + "Niko Matsakis <>" + "Steven Fackler <>" + "Chase Southwood <csouth3@illinois.edu>" + "Eduard Burtescu <>" + "Florian Wilkens <>" + "Félix Raimundo <>" + "Tibor Benke <>" + "Markus Siemens <markus@m-siemens.de>" + "Josh Branchaud <jbranchaud@gmail.com>" + "Huon Wilson <dbau.pp@gmail.com>" + "Corey Farwell <coref@rwell.org>" + "Aaron Liblong <>" + "Nick Cameron <nrc@ncameron.org>" + "Patrick Walton <pcwalton@mimiga.net>" + "Felix S Klock II <>" + "Andrew Paseltiner <apaseltiner@gmail.com>" + "Sean McArthur <sean.monstar@gmail.com>" + "Vadim Petrochenkov <>" + ]; + features = { + "eders" = [ "serde" ]; + "serde" = [ "dep:serde" ]; + }; + }; + "version_check" = rec { + crateName = "version_check"; + version = "0.9.4"; + edition = "2015"; + sha256 = "0gs8grwdlgh0xq660d7wr80x14vxbizmd8dbp29p2pdncx8lp1s9"; + authors = [ + "Sergio Benitez <sb@sergio.bz>" + ]; + + }; + "which" = rec { + crateName = "which"; + version = "3.1.1"; + edition = "2015"; + sha256 = "094pw9pi48szshn9ln69z2kg7syq1jp80h5ps1qncbsaw4d0f4fh"; + authors = [ + "Harry Fei <tiziyuanfang@gmail.com>" + ]; + dependencies = [ + { + name = "libc"; + packageId = "libc"; + } + ]; + features = { + "default" = [ "failure" ]; + "failure" = [ "dep:failure" ]; + }; + }; + "winapi" = rec { + crateName = "winapi"; + version = "0.3.9"; + edition = "2015"; + sha256 = "06gl025x418lchw1wxj64ycr7gha83m44cjr5sarhynd9xkrm0sw"; + authors = [ + "Peter Atashian <retep998@gmail.com>" + ]; + dependencies = [ + { + name = "winapi-i686-pc-windows-gnu"; + packageId = "winapi-i686-pc-windows-gnu"; + target = { target, features }: (pkgs.rust.lib.toRustTarget stdenv.hostPlatform == "i686-pc-windows-gnu"); + } + { + name = "winapi-x86_64-pc-windows-gnu"; + packageId = "winapi-x86_64-pc-windows-gnu"; + target = { target, features }: (pkgs.rust.lib.toRustTarget stdenv.hostPlatform == "x86_64-pc-windows-gnu"); + } + ]; + features = { + "debug" = [ "impl-debug" ]; + }; + resolvedDefaultFeatures = [ "consoleapi" "errhandlingapi" "fileapi" "handleapi" "libloaderapi" "minwinbase" "minwindef" "processenv" "std" "winbase" "wincon" "winerror" "winnt" ]; + }; + "winapi-i686-pc-windows-gnu" = rec { + crateName = "winapi-i686-pc-windows-gnu"; + version = "0.4.0"; + edition = "2015"; + sha256 = "1dmpa6mvcvzz16zg6d5vrfy4bxgg541wxrcip7cnshi06v38ffxc"; + authors = [ + "Peter Atashian <retep998@gmail.com>" + ]; + + }; + "winapi-util" = rec { + crateName = "winapi-util"; + version = "0.1.5"; + edition = "2018"; + sha256 = "0y71bp7f6d536czj40dhqk0d55wfbbwqfp2ymqf1an5ibgl6rv3h"; + authors = [ + "Andrew Gallant <jamslam@gmail.com>" + ]; + dependencies = [ + { + name = "winapi"; + packageId = "winapi"; + target = { target, features }: (target."windows" or false); + features = [ "std" "consoleapi" "errhandlingapi" "fileapi" "minwindef" "processenv" "winbase" "wincon" "winerror" "winnt" ]; + } + ]; + + }; + "winapi-x86_64-pc-windows-gnu" = rec { + crateName = "winapi-x86_64-pc-windows-gnu"; + version = "0.4.0"; + edition = "2015"; + sha256 = "0gqq64czqb64kskjryj8isp62m2sgvx25yyj3kpc2myh85w24bki"; + authors = [ + "Peter Atashian <retep998@gmail.com>" + ]; + + }; + }; + + # +# crate2nix/default.nix (excerpt start) +# + + /* Target (platform) data for conditional dependencies. + This corresponds roughly to what buildRustCrate is setting. + */ + makeDefaultTarget = platform: { + unix = platform.isUnix; + windows = platform.isWindows; + fuchsia = true; + test = false; + + /* We are choosing an arbitrary rust version to grab `lib` from, + which is unfortunate, but `lib` has been version-agnostic the + whole time so this is good enough for now. + */ + os = pkgs.rust.lib.toTargetOs platform; + arch = pkgs.rust.lib.toTargetArch platform; + family = "unix"; + env = "gnu"; + endian = + if platform.parsed.cpu.significantByte.name == "littleEndian" + then "little" else "big"; + pointer_width = toString platform.parsed.cpu.bits; + vendor = platform.parsed.vendor.name; + debug_assertions = false; + }; + + /* Filters common temp files and build files. */ + # TODO(pkolloch): Substitute with gitignore filter + sourceFilter = name: type: + let + baseName = builtins.baseNameOf (builtins.toString name); + in + ! ( + # Filter out git + baseName == ".gitignore" + || (type == "directory" && baseName == ".git") + + # Filter out build results + || ( + type == "directory" && ( + baseName == "target" + || baseName == "_site" + || baseName == ".sass-cache" + || baseName == ".jekyll-metadata" + || baseName == "build-artifacts" + ) + ) + + # Filter out nix-build result symlinks + || ( + type == "symlink" && lib.hasPrefix "result" baseName + ) + + # Filter out IDE config + || ( + type == "directory" && ( + baseName == ".idea" || baseName == ".vscode" + ) + ) || lib.hasSuffix ".iml" baseName + + # Filter out nix build files + || baseName == "Cargo.nix" + + # Filter out editor backup / swap files. + || lib.hasSuffix "~" baseName + || builtins.match "^\\.sw[a-z]$$" baseName != null + || builtins.match "^\\..*\\.sw[a-z]$$" baseName != null + || lib.hasSuffix ".tmp" baseName + || lib.hasSuffix ".bak" baseName + || baseName == "tests.nix" + ); + + /* Returns a crate which depends on successful test execution + of crate given as the second argument. + + testCrateFlags: list of flags to pass to the test exectuable + testInputs: list of packages that should be available during test execution + */ + crateWithTest = { crate, testCrate, testCrateFlags, testInputs, testPreRun, testPostRun }: + assert builtins.typeOf testCrateFlags == "list"; + assert builtins.typeOf testInputs == "list"; + assert builtins.typeOf testPreRun == "string"; + assert builtins.typeOf testPostRun == "string"; + let + # override the `crate` so that it will build and execute tests instead of + # building the actual lib and bin targets We just have to pass `--test` + # to rustc and it will do the right thing. We execute the tests and copy + # their log and the test executables to $out for later inspection. + test = + let + drv = testCrate.override + ( + _: { + buildTests = true; + } + ); + # If the user hasn't set any pre/post commands, we don't want to + # insert empty lines. This means that any existing users of crate2nix + # don't get a spurious rebuild unless they set these explicitly. + testCommand = pkgs.lib.concatStringsSep "\n" + (pkgs.lib.filter (s: s != "") [ + testPreRun + "$f $testCrateFlags 2>&1 | tee -a $out" + testPostRun + ]); + in + pkgs.runCommand "run-tests-${testCrate.name}" + { + inherit testCrateFlags; + buildInputs = testInputs; + } '' + set -ex + + export RUST_BACKTRACE=1 + + # recreate a file hierarchy as when running tests with cargo + + # the source for test data + ${pkgs.xorg.lndir}/bin/lndir ${crate.src} + + # build outputs + testRoot=target/debug + mkdir -p $testRoot + + # executables of the crate + # we copy to prevent std::env::current_exe() to resolve to a store location + for i in ${crate}/bin/*; do + cp "$i" "$testRoot" + done + chmod +w -R . + + # test harness executables are suffixed with a hash, like cargo does + # this allows to prevent name collision with the main + # executables of the crate + hash=$(basename $out) + for file in ${drv}/tests/*; do + f=$testRoot/$(basename $file)-$hash + cp $file $f + ${testCommand} + done + ''; + in + pkgs.runCommand "${crate.name}-linked" + { + inherit (crate) outputs crateName; + passthru = (crate.passthru or { }) // { + inherit test; + }; + } '' + echo tested by ${test} + ${lib.concatMapStringsSep "\n" (output: "ln -s ${crate.${output}} ${"$"}${output}") crate.outputs} + ''; + + /* A restricted overridable version of builtRustCratesWithFeatures. */ + buildRustCrateWithFeatures = + { packageId + , features ? rootFeatures + , crateOverrides ? defaultCrateOverrides + , buildRustCrateForPkgsFunc ? null + , runTests ? false + , testCrateFlags ? [ ] + , testInputs ? [ ] + # Any command to run immediatelly before a test is executed. + , testPreRun ? "" + # Any command run immediatelly after a test is executed. + , testPostRun ? "" + }: + lib.makeOverridable + ( + { features + , crateOverrides + , runTests + , testCrateFlags + , testInputs + , testPreRun + , testPostRun + }: + let + buildRustCrateForPkgsFuncOverriden = + if buildRustCrateForPkgsFunc != null + then buildRustCrateForPkgsFunc + else + ( + if crateOverrides == pkgs.defaultCrateOverrides + then buildRustCrateForPkgs + else + pkgs: (buildRustCrateForPkgs pkgs).override { + defaultCrateOverrides = crateOverrides; + } + ); + builtRustCrates = builtRustCratesWithFeatures { + inherit packageId features; + buildRustCrateForPkgsFunc = buildRustCrateForPkgsFuncOverriden; + runTests = false; + }; + builtTestRustCrates = builtRustCratesWithFeatures { + inherit packageId features; + buildRustCrateForPkgsFunc = buildRustCrateForPkgsFuncOverriden; + runTests = true; + }; + drv = builtRustCrates.crates.${packageId}; + testDrv = builtTestRustCrates.crates.${packageId}; + derivation = + if runTests then + crateWithTest + { + crate = drv; + testCrate = testDrv; + inherit testCrateFlags testInputs testPreRun testPostRun; + } + else drv; + in + derivation + ) + { inherit features crateOverrides runTests testCrateFlags testInputs testPreRun testPostRun; }; + + /* Returns an attr set with packageId mapped to the result of buildRustCrateForPkgsFunc + for the corresponding crate. + */ + builtRustCratesWithFeatures = + { packageId + , features + , crateConfigs ? crates + , buildRustCrateForPkgsFunc + , runTests + , makeTarget ? makeDefaultTarget + } @ args: + assert (builtins.isAttrs crateConfigs); + assert (builtins.isString packageId); + assert (builtins.isList features); + assert (builtins.isAttrs (makeTarget stdenv.hostPlatform)); + assert (builtins.isBool runTests); + let + rootPackageId = packageId; + mergedFeatures = mergePackageFeatures + ( + args // { + inherit rootPackageId; + target = makeTarget stdenv.hostPlatform // { test = runTests; }; + } + ); + # Memoize built packages so that reappearing packages are only built once. + builtByPackageIdByPkgs = mkBuiltByPackageIdByPkgs pkgs; + mkBuiltByPackageIdByPkgs = pkgs: + let + self = { + crates = lib.mapAttrs (packageId: value: buildByPackageIdForPkgsImpl self pkgs packageId) crateConfigs; + target = makeTarget pkgs.stdenv.hostPlatform; + build = mkBuiltByPackageIdByPkgs pkgs.buildPackages; + }; + in + self; + buildByPackageIdForPkgsImpl = self: pkgs: packageId: + let + features = mergedFeatures."${packageId}" or [ ]; + crateConfig' = crateConfigs."${packageId}"; + crateConfig = + builtins.removeAttrs crateConfig' [ "resolvedDefaultFeatures" "devDependencies" ]; + devDependencies = + lib.optionals + (runTests && packageId == rootPackageId) + (crateConfig'.devDependencies or [ ]); + dependencies = + dependencyDerivations { + inherit features; + inherit (self) target; + buildByPackageId = depPackageId: + # proc_macro crates must be compiled for the build architecture + if crateConfigs.${depPackageId}.procMacro or false + then self.build.crates.${depPackageId} + else self.crates.${depPackageId}; + dependencies = + (crateConfig.dependencies or [ ]) + ++ devDependencies; + }; + buildDependencies = + dependencyDerivations { + inherit features; + inherit (self.build) target; + buildByPackageId = depPackageId: + self.build.crates.${depPackageId}; + dependencies = crateConfig.buildDependencies or [ ]; + }; + dependenciesWithRenames = + let + buildDeps = filterEnabledDependencies { + inherit features; + inherit (self) target; + dependencies = crateConfig.dependencies or [ ] ++ devDependencies; + }; + hostDeps = filterEnabledDependencies { + inherit features; + inherit (self.build) target; + dependencies = crateConfig.buildDependencies or [ ]; + }; + in + lib.filter (d: d ? "rename") (hostDeps ++ buildDeps); + # Crate renames have the form: + # + # { + # crate_name = [ + # { version = "1.2.3"; rename = "crate_name01"; } + # ]; + # # ... + # } + crateRenames = + let + grouped = + lib.groupBy + (dependency: dependency.name) + dependenciesWithRenames; + versionAndRename = dep: + let + package = crateConfigs."${dep.packageId}"; + in + { inherit (dep) rename; version = package.version; }; + in + lib.mapAttrs (name: choices: builtins.map versionAndRename choices) grouped; + in + buildRustCrateForPkgsFunc pkgs + ( + crateConfig // { + src = crateConfig.src or ( + pkgs.fetchurl rec { + name = "${crateConfig.crateName}-${crateConfig.version}.tar.gz"; + # https://www.pietroalbini.org/blog/downloading-crates-io/ + # Not rate-limited, CDN URL. + url = "https://static.crates.io/crates/${crateConfig.crateName}/${crateConfig.crateName}-${crateConfig.version}.crate"; + sha256 = + assert (lib.assertMsg (crateConfig ? sha256) "Missing sha256 for ${name}"); + crateConfig.sha256; + } + ); + extraRustcOpts = lib.lists.optional (targetFeatures != [ ]) "-C target-feature=${lib.concatMapStringsSep "," (x: "+${x}") targetFeatures}"; + inherit features dependencies buildDependencies crateRenames release; + } + ); + in + builtByPackageIdByPkgs; + + /* Returns the actual derivations for the given dependencies. */ + dependencyDerivations = + { buildByPackageId + , features + , dependencies + , target + }: + assert (builtins.isList features); + assert (builtins.isList dependencies); + assert (builtins.isAttrs target); + let + enabledDependencies = filterEnabledDependencies { + inherit dependencies features target; + }; + depDerivation = dependency: buildByPackageId dependency.packageId; + in + map depDerivation enabledDependencies; + + /* Returns a sanitized version of val with all values substituted that cannot + be serialized as JSON. + */ + sanitizeForJson = val: + if builtins.isAttrs val + then lib.mapAttrs (n: v: sanitizeForJson v) val + else if builtins.isList val + then builtins.map sanitizeForJson val + else if builtins.isFunction val + then "function" + else val; + + /* Returns various tools to debug a crate. */ + debugCrate = { packageId, target ? makeDefaultTarget stdenv.hostPlatform }: + assert (builtins.isString packageId); + let + debug = rec { + # The built tree as passed to buildRustCrate. + buildTree = buildRustCrateWithFeatures { + buildRustCrateForPkgsFunc = _: lib.id; + inherit packageId; + }; + sanitizedBuildTree = sanitizeForJson buildTree; + dependencyTree = sanitizeForJson + ( + buildRustCrateWithFeatures { + buildRustCrateForPkgsFunc = _: crate: { + "01_crateName" = crate.crateName or false; + "02_features" = crate.features or [ ]; + "03_dependencies" = crate.dependencies or [ ]; + }; + inherit packageId; + } + ); + mergedPackageFeatures = mergePackageFeatures { + features = rootFeatures; + inherit packageId target; + }; + diffedDefaultPackageFeatures = diffDefaultPackageFeatures { + inherit packageId target; + }; + }; + in + { internal = debug; }; + + /* Returns differences between cargo default features and crate2nix default + features. + + This is useful for verifying the feature resolution in crate2nix. + */ + diffDefaultPackageFeatures = + { crateConfigs ? crates + , packageId + , target + }: + assert (builtins.isAttrs crateConfigs); + let + prefixValues = prefix: lib.mapAttrs (n: v: { "${prefix}" = v; }); + mergedFeatures = + prefixValues + "crate2nix" + (mergePackageFeatures { inherit crateConfigs packageId target; features = [ "default" ]; }); + configs = prefixValues "cargo" crateConfigs; + combined = lib.foldAttrs (a: b: a // b) { } [ mergedFeatures configs ]; + onlyInCargo = + builtins.attrNames + (lib.filterAttrs (n: v: !(v ? "crate2nix") && (v ? "cargo")) combined); + onlyInCrate2Nix = + builtins.attrNames + (lib.filterAttrs (n: v: (v ? "crate2nix") && !(v ? "cargo")) combined); + differentFeatures = lib.filterAttrs + ( + n: v: + (v ? "crate2nix") + && (v ? "cargo") + && (v.crate2nix.features or [ ]) != (v."cargo".resolved_default_features or [ ]) + ) + combined; + in + builtins.toJSON { + inherit onlyInCargo onlyInCrate2Nix differentFeatures; + }; + + /* Returns an attrset mapping packageId to the list of enabled features. + + If multiple paths to a dependency enable different features, the + corresponding feature sets are merged. Features in rust are additive. + */ + mergePackageFeatures = + { crateConfigs ? crates + , packageId + , rootPackageId ? packageId + , features ? rootFeatures + , dependencyPath ? [ crates.${packageId}.crateName ] + , featuresByPackageId ? { } + , target + # Adds devDependencies to the crate with rootPackageId. + , runTests ? false + , ... + } @ args: + assert (builtins.isAttrs crateConfigs); + assert (builtins.isString packageId); + assert (builtins.isString rootPackageId); + assert (builtins.isList features); + assert (builtins.isList dependencyPath); + assert (builtins.isAttrs featuresByPackageId); + assert (builtins.isAttrs target); + assert (builtins.isBool runTests); + let + crateConfig = crateConfigs."${packageId}" or (builtins.throw "Package not found: ${packageId}"); + expandedFeatures = expandFeatures (crateConfig.features or { }) features; + enabledFeatures = enableFeatures (crateConfig.dependencies or [ ]) expandedFeatures; + depWithResolvedFeatures = dependency: + let + packageId = dependency.packageId; + features = dependencyFeatures enabledFeatures dependency; + in + { inherit packageId features; }; + resolveDependencies = cache: path: dependencies: + assert (builtins.isAttrs cache); + assert (builtins.isList dependencies); + let + enabledDependencies = filterEnabledDependencies { + inherit dependencies target; + features = enabledFeatures; + }; + directDependencies = map depWithResolvedFeatures enabledDependencies; + foldOverCache = op: lib.foldl op cache directDependencies; + in + foldOverCache + ( + cache: { packageId, features }: + let + cacheFeatures = cache.${packageId} or [ ]; + combinedFeatures = sortedUnique (cacheFeatures ++ features); + in + if cache ? ${packageId} && cache.${packageId} == combinedFeatures + then cache + else + mergePackageFeatures { + features = combinedFeatures; + featuresByPackageId = cache; + inherit crateConfigs packageId target runTests rootPackageId; + } + ); + cacheWithSelf = + let + cacheFeatures = featuresByPackageId.${packageId} or [ ]; + combinedFeatures = sortedUnique (cacheFeatures ++ enabledFeatures); + in + featuresByPackageId // { + "${packageId}" = combinedFeatures; + }; + cacheWithDependencies = + resolveDependencies cacheWithSelf "dep" + ( + crateConfig.dependencies or [ ] + ++ lib.optionals + (runTests && packageId == rootPackageId) + (crateConfig.devDependencies or [ ]) + ); + cacheWithAll = + resolveDependencies + cacheWithDependencies "build" + (crateConfig.buildDependencies or [ ]); + in + cacheWithAll; + + /* Returns the enabled dependencies given the enabled features. */ + filterEnabledDependencies = { dependencies, features, target }: + assert (builtins.isList dependencies); + assert (builtins.isList features); + assert (builtins.isAttrs target); + + lib.filter + ( + dep: + let + targetFunc = dep.target or (features: true); + in + targetFunc { inherit features target; } + && ( + !(dep.optional or false) + || builtins.any (doesFeatureEnableDependency dep) features + ) + ) + dependencies; + + /* Returns whether the given feature should enable the given dependency. */ + doesFeatureEnableDependency = dependency: feature: + let + name = dependency.rename or dependency.name; + prefix = "${name}/"; + len = builtins.stringLength prefix; + startsWithPrefix = builtins.substring 0 len feature == prefix; + in + feature == name || feature == "dep:" + name || startsWithPrefix; + + /* Returns the expanded features for the given inputFeatures by applying the + rules in featureMap. + + featureMap is an attribute set which maps feature names to lists of further + feature names to enable in case this feature is selected. + */ + expandFeatures = featureMap: inputFeatures: + assert (builtins.isAttrs featureMap); + assert (builtins.isList inputFeatures); + let + expandFeature = feature: + assert (builtins.isString feature); + [ feature ] ++ (expandFeatures featureMap (featureMap."${feature}" or [ ])); + outFeatures = lib.concatMap expandFeature inputFeatures; + in + sortedUnique outFeatures; + + /* This function adds optional dependencies as features if they are enabled + indirectly by dependency features. This function mimics Cargo's behavior + described in a note at: + https://doc.rust-lang.org/nightly/cargo/reference/features.html#dependency-features + */ + enableFeatures = dependencies: features: + assert (builtins.isList features); + assert (builtins.isList dependencies); + let + additionalFeatures = lib.concatMap + ( + dependency: + assert (builtins.isAttrs dependency); + let + enabled = builtins.any (doesFeatureEnableDependency dependency) features; + in + if (dependency.optional or false) && enabled + then [ (dependency.rename or dependency.name) ] + else [ ] + ) + dependencies; + in + sortedUnique (features ++ additionalFeatures); + + /* + Returns the actual features for the given dependency. + + features: The features of the crate that refers this dependency. + */ + dependencyFeatures = features: dependency: + assert (builtins.isList features); + assert (builtins.isAttrs dependency); + let + defaultOrNil = + if dependency.usesDefaultFeatures or true + then [ "default" ] + else [ ]; + explicitFeatures = dependency.features or [ ]; + additionalDependencyFeatures = + let + dependencyPrefix = (dependency.rename or dependency.name) + "/"; + dependencyFeatures = + builtins.filter (f: lib.hasPrefix dependencyPrefix f) features; + in + builtins.map (lib.removePrefix dependencyPrefix) dependencyFeatures; + in + defaultOrNil ++ explicitFeatures ++ additionalDependencyFeatures; + + /* Sorts and removes duplicates from a list of strings. */ + sortedUnique = features: + assert (builtins.isList features); + assert (builtins.all builtins.isString features); + let + outFeaturesSet = lib.foldl (set: feature: set // { "${feature}" = 1; }) { } features; + outFeaturesUnique = builtins.attrNames outFeaturesSet; + in + builtins.sort (a: b: a < b) outFeaturesUnique; + + deprecationWarning = message: value: + if strictDeprecation + then builtins.throw "strictDeprecation enabled, aborting: ${message}" + else builtins.trace message value; + + # + # crate2nix/default.nix (excerpt end) + # + }; +} + @@ -1,35 +1,30 @@ [package] name = "rustables" -version = "0.7.0" -resolver = "2" -authors = ["lafleur@boum.org, Simon Thoby, Mullvad VPN"] +version = "0.8.0" +authors = ["lafleur@boum.org", "Simon Thoby", "Mullvad VPN"] license = "GPL-3.0-or-later" description = "Safe abstraction for libnftnl. Provides low-level userspace access to the in-kernel nf_tables subsystem" repository = "https://gitlab.com/rustwall/rustables" readme = "README.md" keywords = ["nftables", "nft", "firewall", "iptables", "netfilter"] categories = ["network-programming", "os::unix-apis", "api-bindings"] -edition = "2018" +resolver = "2" +edition = "2021" [features] -query = [] -unsafe-raw-handles = [] [dependencies] bitflags = "1.0" thiserror = "1.0" log = "0.4" libc = "0.2.43" -mnl = "0.2" -ipnetwork = "0.16" -serde = { version = "1.0", features = ["derive"] } +nix = "0.23" +ipnetwork = { version = "0.20", default-features = false } +rustables-macros = "0.1.0" [dev-dependencies] -rustables = { path = ".", features = ["query"] } +env_logger = "0.9" [build-dependencies] bindgen = "0.53.1" -pkg-config = "0.3" regex = "1.5.4" -lazy_static = "1.4.0" - @@ -1,36 +1,23 @@ # rustables -Safe abstraction for [`libnftnl`]. Provides low-level userspace access to the -in-kernel nf_tables subsystem. See [`rustables-sys`] for the low level FFI -bindings to the C library. - -Can be used to create, list and remove tables, chains, sets and rules from the -nftables firewall, the successor to iptables. - -This library is directly derived from the [`nftnl-rs`] crate. Let us thank here -the original project team for their great work without which this library would -probably not exist today. - -It currently has quite rough edges and does not make adding and removing -netfilter entries super easy and elegant. That is partly because the library -needs more work, but also partly because nftables is super low level and -extremely customizable, making it hard, and probably wrong, to try and create a -too simple/limited wrapper. See examples for inspiration. One can also look -at how the original project this crate was developed to support uses it : -[Mullvad VPN app]. - -Understanding how to use [`libnftnl`] and implementing this crate has mostly -been done by reading the source code for the [`nftables`] program and attaching -debuggers to the `nft` binary. Since the implementation is mostly based on -trial and error, there might of course be a number of places where the -underlying library is used in an invalid or not intended way. Large portions -of [`libnftnl`] are also not covered yet. Contributions are welcome! - -## Supported versions of `libnftnl` - -This crate will automatically link to the currently installed version of -libnftnl upon build. It requires libnftnl version 1.0.6 or higher. See how the -low level FFI bindings to the C library are generated in [`build.rs`]. +Safe abstraction for userspace access to the in-kernel nf_tables subsystem. +Can be used to create and remove tables, chains, sets and rules from the nftables +firewall, the successor to iptables. + +This library is a fork of the [`nftnl-rs`] crate. Let us thank here the original project +team for their great work without which this library would probably not exist today. + +This library currently has quite rough edges and does not make adding and removing netfilter +entries super easy and elegant. That is partly because the library needs more work, but also +partly because nftables is super low level and extremely customizable, making it hard, and +probably wrong, to try and create a too simple/limited wrapper. See examples for inspiration. + +Understanding how to use the netlink subsystem and implementing this crate has mostly been done by +reading the source code for the [`nftables`] userspace program and its corresponding kernel code, +as well as attaching debuggers to the `nft` binary. +Since the implementation is mostly based on trial and error, there might of course be +a number of places where the forged netlink messages are used in an invalid or not intended way. +Contributions are welcome! ## Licensing @@ -39,8 +26,5 @@ License: GNU GPLv3 Original work licensed by Amagicom AB under MIT/Apache-2.0. [`nftnl-rs`]: https://github.com/mullvad/nftnl-rs -[Mullvad VPN app]: https://github.com/mullvad/mullvadvpn-app -[`libnftnl`]: https://netfilter.org/projects/libnftnl/ [`nftables`]: https://netfilter.org/projects/nftables/ -[`build.rs`]: https://gitlab.com/rustwall/rustables/-/blob/master/build.rs @@ -1,10 +1,6 @@ -//! This build script leverages `bindgen` to generate rust sys files that link to the libnftnl -//! library. It retrieves the includes needed by `bindgen` using `pkg_config`, and tells cargo -//! the directives needed by the linker to link against the exported symbols. +//! This build script leverages `bindgen` to generate rust sys files. use bindgen; -use lazy_static::lazy_static; -use pkg_config; use regex::{Captures, Regex}; use std::borrow::Cow; use std::env; @@ -13,141 +9,22 @@ use std::io::Write; use std::path::PathBuf; const SYS_HEADER_FILE: &str = "include/wrapper.h"; -const SYS_BINDINGS_FILE: &str = "src/sys.rs"; -const TESTS_HEADER_FILE: &str = "include/tests_wrapper.h"; -const TESTS_BINDINGS_FILE: &str = "tests/sys.rs"; -const MIN_LIBNFTNL_VERSION: &str = "1.0.6"; -const MIN_LIBMNL_VERSION: &str = "1.0.0"; - fn main() { - pkg_config_mnl(); - let clang_args = pkg_config_nftnl(); - generate_sys(clang_args.into_iter()); - generate_tests_sys(); -} - -/// Setup rust linking directives for libnftnl, and return the include directory list. -fn pkg_config_nftnl() -> Vec<String> { - let mut res = vec![]; - - if let Some(lib_dir) = get_env("LIBNFTNL_LIB_DIR") { - if !lib_dir.is_dir() { - panic!( - "libnftnl library directory does not exist: {}", - lib_dir.display() - ); - } - println!("cargo:rustc-link-search=native={}", lib_dir.display()); - println!("cargo:rustc-link-lib=nftnl"); - } else { - // Trying with pkg-config instead - println!("Minimum libnftnl version: {}", MIN_LIBNFTNL_VERSION); - let pkg_config_res = pkg_config::Config::new() - .atleast_version(MIN_LIBNFTNL_VERSION) - .probe("libnftnl") - .unwrap(); - for path in pkg_config_res.include_paths { - res.push(format!("-I{}", path.to_str().unwrap())); - } - } - - res -} - -/// Setup rust linking directives for libmnl. -fn pkg_config_mnl() { - if let Some(lib_dir) = get_env("LIBMNL_LIB_DIR") { - if !lib_dir.is_dir() { - panic!( - "libmnl library directory does not exist: {}", - lib_dir.display() - ); - } - println!("cargo:rustc-link-search=native={}", lib_dir.display()); - println!("cargo:rustc-link-lib=mnl"); - } else { - // Trying with pkg-config instead - pkg_config::Config::new() - .atleast_version(MIN_LIBMNL_VERSION) - .probe("libmnl") - .unwrap(); - } -} - -fn get_env(var: &'static str) -> Option<PathBuf> { - println!("cargo:rerun-if-env-changed={}", var); - env::var_os(var).map(PathBuf::from) + generate_sys(); } -/// `bindgen`erate a rust sys file from the C headers of the nftnl library. -fn generate_sys(clang_args: impl Iterator<Item = String>) { +/// `bindgen`erate a rust sys file from the C kernel headers of the nf_tables capabilities. +fn generate_sys() { // Tell cargo to invalidate the built crate whenever the headers change. println!("cargo:rerun-if-changed={}", SYS_HEADER_FILE); let bindings = bindgen::Builder::default() .header(SYS_HEADER_FILE) - .clang_args(clang_args) - .generate_comments(false) - .prepend_enum_name(false) - .use_core() - .whitelist_function("^nftnl_.+$") - .whitelist_type("^nftnl_.+$") - .whitelist_var("^nftnl_.+$") - .whitelist_var("^NFTNL_.+$") - .blacklist_type("(FILE|iovec)") - .blacklist_type("^_IO_.+$") - .blacklist_type("^__.+$") - .blacklist_type("nlmsghdr") - .raw_line("#![allow(non_camel_case_types)]\n\n") - .raw_line("pub use libc;") - .raw_line("use libc::{c_char, c_int, c_ulong, c_void, iovec, nlmsghdr, FILE};") - .raw_line("use core::option::Option;") - .ctypes_prefix("libc") - // Tell cargo to invalidate the built crate whenever any of the - // included header files changed. - .parse_callbacks(Box::new(bindgen::CargoCallbacks)) - // Finish the builder and generate the bindings. - .generate() - // Unwrap the Result and panic on failure. - .expect("Error: unable to generate bindings"); - - let mut s = bindings.to_string() - // Add newlines because in alpine bindgen doesn't add them after - // statements. - .replace(" ; ", ";\n") - .replace("#[derive(Debug, Copy, Clone)]", ""); - let re = Regex::new(r"libc::(c_[a-z]*)").unwrap(); - s = re.replace_all(&s, "$1").into(); - let re = Regex::new(r"::core::option::(Option)").unwrap(); - s = re.replace_all(&s, "$1").into(); - let re = Regex::new(r"_bindgen_ty_[0-9]+").unwrap(); - s = re.replace_all(&s, "u32").into(); - // Change struct bodies to c_void. - let re = Regex::new(r"(pub struct .*) \{\n *_unused: \[u8; 0\],\n\}\n").unwrap(); - s = re.replace_all(&s, "$1(c_void);\n").into(); - let re = Regex::new(r"pub type u32 = u32;\n").unwrap(); - s = re.replace_all(&s, "").into(); - - // Write the bindings to the rust header file. - let out_path = PathBuf::from(SYS_BINDINGS_FILE); - File::create(out_path) - .expect("Error: could not create rust header file.") - .write_all(&s.as_bytes()) - .expect("Error: could not write to the rust header file."); -} - -/// `bindgen`erate a rust sys file from the C kernel headers of the nf_tables capabilities. -/// Used in the rustables tests. -fn generate_tests_sys() { - // Tell cargo to invalidate the built crate whenever the headers change. - println!("cargo:rerun-if-changed={}", TESTS_HEADER_FILE); - - let bindings = bindgen::Builder::default() - .header(TESTS_HEADER_FILE) .generate_comments(false) .prepend_enum_name(false) - .raw_line("#![allow(non_camel_case_types, dead_code)]\n\n") + .layout_tests(false) + .derive_partialeq(true) // Tell cargo to invalidate the built crate whenever any of the // included header files changed. .parse_callbacks(Box::new(bindgen::CargoCallbacks)) @@ -161,7 +38,7 @@ fn generate_tests_sys() { let s = reformat_units(&s); // Write the bindings to the rust header file. - let out_path = PathBuf::from(TESTS_BINDINGS_FILE); + let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("sys.rs"); File::create(out_path) .expect("Error: could not create rust header file.") .write_all(&s.as_bytes()) @@ -170,11 +47,8 @@ fn generate_tests_sys() { /// Recast nft_*_attributes from u32 to u16 in header string `header`. fn reformat_units(header: &str) -> Cow<str> { - lazy_static! { - static ref RE: Regex = Regex::new(r"(pub type nft[a-zA-Z_]*_attributes) = u32;").unwrap(); - } - RE.replace_all(header, |captures: &Captures| { + let re = Regex::new(r"(pub type nft[a-zA-Z_]*_attributes) = u32;").unwrap(); + re.replace_all(header, |captures: &Captures| { format!("{} = u16;", &captures[1]) }) } - diff --git a/examples/add-rules.rs b/examples/add-rules.rs index 3aae7ee..a2b9c9c 100644 --- a/examples/add-rules.rs +++ b/examples/add-rules.rs @@ -37,212 +37,155 @@ //! ``` use ipnetwork::{IpNetwork, Ipv4Network}; -use rustables::{nft_expr, sys::libc, Batch, Chain, FinalizedBatch, ProtoFamily, Rule, Table}; -use std::{ - ffi::{self, CString}, - io, - net::Ipv4Addr, - rc::Rc +use rustables::{ + data_type::ip_to_vec, + expr::{ + Bitwise, Cmp, CmpOp, Counter, HighLevelPayload, ICMPv6HeaderField, IPv4HeaderField, + IcmpCode, Immediate, Meta, MetaType, NetworkHeaderField, TransportHeaderField, VerdictKind, + }, + iface_index, Batch, Chain, ChainPolicy, Hook, HookClass, MsgType, ProtocolFamily, Rule, Table, }; +use std::net::Ipv4Addr; const TABLE_NAME: &str = "example-table"; const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets"; const IN_CHAIN_NAME: &str = "chain-for-incoming-packets"; fn main() -> Result<(), Error> { + env_logger::init(); + // Create a batch. This is used to store all the netlink messages we will later send. // Creating a new batch also automatically writes the initial batch begin message needed // to tell netlink this is a single transaction that might arrive over multiple netlink packets. let mut batch = Batch::new(); // Create a netfilter table operating on both IPv4 and IPv6 (ProtoFamily::Inet) - let table = Rc::new(Table::new(&CString::new(TABLE_NAME).unwrap(), ProtoFamily::Inet)); + let table = Table::new(ProtocolFamily::Inet).with_name(TABLE_NAME); // Add the table to the batch with the `MsgType::Add` type, thus instructing netfilter to add - // this table under its `ProtoFamily::Inet` ruleset. - batch.add(&Rc::clone(&table), rustables::MsgType::Add); + // this table under its `ProtocolFamily::Inet` ruleset. + batch.add(&table, MsgType::Add); // Create input and output chains under the table we created above. // Hook the chains to the input and output event hooks, with highest priority (priority zero). - // See the `Chain::set_hook` documentation for details. - let mut out_chain = Chain::new(&CString::new(OUT_CHAIN_NAME).unwrap(), Rc::clone(&table)); - let mut in_chain = Chain::new(&CString::new(IN_CHAIN_NAME).unwrap(), Rc::clone(&table)); + let mut out_chain = Chain::new(&table).with_name(OUT_CHAIN_NAME); + let mut in_chain = Chain::new(&table).with_name(IN_CHAIN_NAME); - out_chain.set_hook(rustables::Hook::Out, 0); - in_chain.set_hook(rustables::Hook::In, 0); + out_chain.set_hook(Hook::new(HookClass::Out, 0)); + in_chain.set_hook(Hook::new(HookClass::In, 0)); // Set the default policies on the chains. If no rule matches a packet processed by the // `out_chain` or the `in_chain` it will accept the packet. - out_chain.set_policy(rustables::Policy::Accept); - in_chain.set_policy(rustables::Policy::Accept); - - let out_chain = Rc::new(out_chain); - let in_chain = Rc::new(in_chain); + out_chain.set_policy(ChainPolicy::Accept); + in_chain.set_policy(ChainPolicy::Accept); // Add the two chains to the batch with the `MsgType` to tell netfilter to create the chains // under the table. - batch.add(&Rc::clone(&out_chain), rustables::MsgType::Add); - batch.add(&Rc::clone(&in_chain), rustables::MsgType::Add); + batch.add(&out_chain, MsgType::Add); + batch.add(&in_chain, MsgType::Add); // === ADD RULE ALLOWING ALL TRAFFIC TO THE LOOPBACK DEVICE === - // Create a new rule object under the input chain. - let mut allow_loopback_in_rule = Rule::new(Rc::clone(&in_chain)); // Lookup the interface index of the loopback interface. let lo_iface_index = iface_index("lo")?; - // First expression to be evaluated in this rule is load the meta information "iif" - // (incoming interface index) into the comparison register of netfilter. - // When an incoming network packet is processed by this rule it will first be processed by this - // expression, which will load the interface index of the interface the packet came from into - // a special "register" in netfilter. - allow_loopback_in_rule.add_expr(&nft_expr!(meta iif)); - // Next expression in the rule is to compare the value loaded into the register with our desired - // interface index, and succeed only if it's equal. For any packet processed where the equality - // does not hold the packet is said to not match this rule, and the packet moves on to be - // processed by the next rule in the chain instead. - allow_loopback_in_rule.add_expr(&nft_expr!(cmp == lo_iface_index)); - - // Add a verdict expression to the rule. Any packet getting this far in the expression - // processing without failing any expression will be given the verdict added here. - allow_loopback_in_rule.add_expr(&nft_expr!(verdict accept)); + // Create a new rule object under the input chain. + let allow_loopback_in_rule = Rule::new(&in_chain)? + // First expression to be evaluated in this rule is load the meta information "iif" + // (incoming interface index) into the comparison register of netfilter. + // When an incoming network packet is processed by this rule it will first be processed by this + // expression, which will load the interface index of the interface the packet came from into + // a special "register" in netfilter. + .with_expr(Meta::new(MetaType::Iif)) + + // Next expression in the rule is to compare the value loaded into the register with our desired + // interface index, and succeed only if it's equal. For any packet processed where the equality + // does not hold the packet is said to not match this rule, and the packet moves on to be + // processed by the next rule in the chain instead. + .with_expr(Cmp::new(CmpOp::Eq, lo_iface_index.to_le_bytes())) + + // Add a verdict expression to the rule. Any packet getting this far in the expression + // processing without failing any expression will be given the verdict added here. + .with_expr(Immediate::new_verdict(VerdictKind::Accept)); // Add the rule to the batch. batch.add(&allow_loopback_in_rule, rustables::MsgType::Add); // === ADD A RULE ALLOWING (AND COUNTING) ALL PACKETS TO THE 10.1.0.0/24 NETWORK === - let mut block_out_to_private_net_rule = Rule::new(Rc::clone(&out_chain)); let private_net_ip = Ipv4Addr::new(10, 1, 0, 0); let private_net_prefix = 24; let private_net = IpNetwork::V4(Ipv4Network::new(private_net_ip, private_net_prefix)?); - // Load the `nfproto` metadata into the netfilter register. This metadata denotes which layer3 - // protocol the packet being processed is using. - block_out_to_private_net_rule.add_expr(&nft_expr!(meta nfproto)); - // Check if the currently processed packet is an IPv4 packet. This must be done before payload - // data assuming the packet uses IPv4 can be loaded in the next expression. - block_out_to_private_net_rule.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV4 as u8)); - - // Load the IPv4 destination address into the netfilter register. - block_out_to_private_net_rule.add_expr(&nft_expr!(payload ipv4 daddr)); - // Mask out the part of the destination address that is not part of the network bits. The result - // of this bitwise masking is stored back into the same netfilter register. - block_out_to_private_net_rule.add_expr(&nft_expr!(bitwise mask private_net.mask(), xor 0)); - // Compare the result of the masking with the IP of the network we are interested in. - block_out_to_private_net_rule.add_expr(&nft_expr!(cmp == private_net.ip())); - - // Add a packet counter to the rule. Shows how many packets have been evaluated against this - // expression. Since expressions are evaluated from first to last, putting this counter before - // the above IP net check would make the counter increment on all packets also *not* matching - // those expressions. Because the counter would then be evaluated before it fails a check. - // Similarly, if the counter was added after the verdict it would always remain at zero. Since - // when the packet hits the verdict expression any further processing of expressions stop. - block_out_to_private_net_rule.add_expr(&nft_expr!(counter)); - - // Accept all the packets matching the rule so far. - block_out_to_private_net_rule.add_expr(&nft_expr!(verdict accept)); + let block_out_to_private_net_rule = Rule::new(&out_chain)? + // Load the `nfproto` metadata into the netfilter register. This metadata denotes which layer3 + // protocol the packet being processed is using. + .with_expr(Meta::new(MetaType::NfProto)) - // Add the rule to the batch. Without this nothing would be sent over netlink and netfilter, - // and all the work on `block_out_to_private_net_rule` so far would go to waste. - batch.add(&block_out_to_private_net_rule, rustables::MsgType::Add); + // Check if the currently processed packet is an IPv4 packet. This must be done before payload + // data assuming the packet uses IPv4 can be loaded in the next expression. + .with_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV4 as u8])) - // === ADD A RULE ALLOWING ALL OUTGOING ICMPv6 PACKETS WITH TYPE 133 AND CODE 0 === + // Load the IPv4 destination address into the netfilter register. + .with_expr(HighLevelPayload::Network(NetworkHeaderField::IPv4(IPv4HeaderField::Daddr)).build()) - let mut allow_router_solicitation = Rule::new(Rc::clone(&out_chain)); + // Mask out the part of the destination address that is not part of the network bits. The result + // of this bitwise masking is stored back into the same netfilter register. + .with_expr(Bitwise::new(ip_to_vec(private_net.mask()), [0u8; 4])?) - // Check that the packet is IPv6 and ICMPv6 - allow_router_solicitation.add_expr(&nft_expr!(meta nfproto)); - allow_router_solicitation.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV6 as u8)); - allow_router_solicitation.add_expr(&nft_expr!(meta l4proto)); - allow_router_solicitation.add_expr(&nft_expr!(cmp == libc::IPPROTO_ICMPV6 as u8)); + // Compare the result of the masking with the IP of the network we are interested in. + .with_expr(Cmp::new(CmpOp::Eq, ip_to_vec(private_net.ip()))) - allow_router_solicitation.add_expr(&rustables::expr::Payload::Transport( - rustables::expr::TransportHeaderField::Icmpv6(rustables::expr::Icmpv6HeaderField::Type), - )); - allow_router_solicitation.add_expr(&nft_expr!(cmp == 133u8)); - allow_router_solicitation.add_expr(&rustables::expr::Payload::Transport( - rustables::expr::TransportHeaderField::Icmpv6(rustables::expr::Icmpv6HeaderField::Code), - )); - allow_router_solicitation.add_expr(&nft_expr!(cmp == 0u8)); + // Add a packet counter to the rule. Shows how many packets have been evaluated against this + // expression. Since expressions are evaluated from first to last, putting this counter before + // the above IP net check would make the counter increment on all packets also *not* matching + // those expressions. Because the counter would then be evaluated before it fails a check. + // Similarly, if the counter was added after the verdict it would always remain at zero. Since + // when the packet hits the verdict expression any further processing of expressions stop. + .with_expr(Counter::default()) - allow_router_solicitation.add_expr(&nft_expr!(verdict accept)); + // Accept all the packets matching the rule so far. + .with_expr(Immediate::new_verdict(VerdictKind::Accept)); - batch.add(&allow_router_solicitation, rustables::MsgType::Add); + // Add the rule to the batch. Without this nothing would be sent over netlink and netfilter, + // and all the work on `block_out_to_private_net_rule` so far would go to waste. + batch.add(&block_out_to_private_net_rule, rustables::MsgType::Add); - // === FINALIZE THE TRANSACTION AND SEND THE DATA TO NETFILTER === + // === ADD A RULE ALLOWING ALL OUTGOING ICMPv6 PACKETS WITH TYPE 133 AND CODE 0 === - // Finalize the batch. This means the batch end message is written into the batch, telling - // netfilter the we reached the end of the transaction message. It's also converted to a type - // that implements `IntoIterator<Item = &'a [u8]>`, thus allowing us to get the raw netlink data - // out so it can be sent over a netlink socket to netfilter. - match batch.finalize() { - Some(mut finalized_batch) => { - // Send the entire batch and process any returned messages. - send_and_process(&mut finalized_batch)?; - Ok(()) - }, - None => todo!() - } -} + let allow_router_solicitation = Rule::new(&out_chain)? + // Check that the packet is IPv6 and ICMPv6 + .with_expr(Meta::new(MetaType::NfProto)) + .with_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8])) + .with_expr(Meta::new(MetaType::L4Proto)) + .with_expr(Cmp::new(CmpOp::Eq, [libc::IPPROTO_ICMPV6 as u8])) + .with_expr( + HighLevelPayload::Transport(TransportHeaderField::ICMPv6(ICMPv6HeaderField::Type)) + .build(), + ) + .with_expr(Cmp::new(CmpOp::Eq, [133u8])) + .with_expr( + HighLevelPayload::Transport(TransportHeaderField::ICMPv6(ICMPv6HeaderField::Code)) + .build(), + ) + .with_expr(Cmp::new(CmpOp::Eq, [IcmpCode::NoRoute as u8])) + .with_expr(Immediate::new_verdict(VerdictKind::Accept)); -// Look up the interface index for a given interface name. -fn iface_index(name: &str) -> Result<libc::c_uint, Error> { - let c_name = CString::new(name)?; - let index = unsafe { libc::if_nametoindex(c_name.as_ptr()) }; - if index == 0 { - Err(Error::from(io::Error::last_os_error())) - } else { - Ok(index) - } -} + batch.add(&allow_router_solicitation, rustables::MsgType::Add); -fn send_and_process(batch: &mut FinalizedBatch) -> Result<(), Error> { - // Create a netlink socket to netfilter. - let socket = mnl::Socket::new(mnl::Bus::Netfilter)?; - // Send all the bytes in the batch. - socket.send_all(&mut *batch)?; - - // Try to parse the messages coming back from netfilter. This part is still very unclear. - let portid = socket.portid(); - let mut buffer = vec![0; rustables::nft_nlmsg_maxsize() as usize]; - let very_unclear_what_this_is_for = 2; - while let Some(message) = socket_recv(&socket, &mut buffer[..])? { - match mnl::cb_run(message, very_unclear_what_this_is_for, portid)? { - mnl::CbResult::Stop => { - break; - } - mnl::CbResult::Ok => (), - } - } - Ok(()) -} + // === FINALIZE THE TRANSACTION AND SEND THE DATA TO NETFILTER === -fn socket_recv<'a>(socket: &mnl::Socket, buf: &'a mut [u8]) -> Result<Option<&'a [u8]>, Error> { - let ret = socket.recv(buf)?; - if ret > 0 { - Ok(Some(&buf[..ret])) - } else { - Ok(None) - } + // Finalize the batch and send it. This means the batch end message is written into the batch, telling + // netfilter the we reached the end of the transaction message. It's also converted to a + // Vec<u8>, containing the raw netlink data so it can be sent over a netlink socket to netfilter. + // Finally, the batch is sent over to the kernel. + Ok(batch.send()?) } #[derive(Debug)] struct Error(String); -impl From<io::Error> for Error { - fn from(error: io::Error) -> Self { - Error(error.to_string()) - } -} - -impl From<ffi::NulError> for Error { - fn from(error: ffi::NulError) -> Self { - Error(error.to_string()) - } -} - -impl From<ipnetwork::IpNetworkError> for Error { - fn from(error: ipnetwork::IpNetworkError) -> Self { +impl<T: std::error::Error> From<T> for Error { + fn from(error: T) -> Self { Error(error.to_string()) } } diff --git a/examples/filter-ethernet.rs b/examples/filter-ethernet.rs index b16c49e..a136731 100644 --- a/examples/filter-ethernet.rs +++ b/examples/filter-ethernet.rs @@ -1,63 +1,69 @@ -//! Adds a table, chain and a rule that blocks all traffic to a given MAC address -//! -//! Run the following to print out current active tables, chains and rules in netfilter. Must be -//! executed as root: -//! ```bash -//! # nft list ruleset -//! ``` -//! After running this example, the output should be the following: -//! ```ignore -//! table inet example-filter-ethernet { -//! chain chain-for-outgoing-packets { -//! type filter hook output priority 3; policy accept; -//! ether daddr 00:00:00:00:00:00 drop -//! counter packets 0 bytes 0 meta random > 2147483647 counter packets 0 bytes 0 -//! } -//! } -//! ``` -//! -//! -//! Everything created by this example can be removed by running -//! ```bash -//! # nft delete table inet example-filter-ethernet -//! ``` - -use rustables::{nft_expr, sys::libc, Batch, Chain, FinalizedBatch, ProtoFamily, Rule, Table}; -use std::{ffi::CString, io, rc::Rc}; +///! Adds a table, chain and a rule that blocks all traffic to a given MAC address +///! +///! Run the following to print out current active tables, chains and rules in netfilter. Must be +///! executed as root: +///! ```bash +///! # nft list ruleset +///! ``` +///! After running this example, the output should be the following: +///! ```ignore +///! table inet example-filter-ethernet { +///! chain chain-for-outgoing-packets { +///! type filter hook output priority 3; policy accept; +///! ether daddr 01:02:03:04:05:06 drop +///! counter packets 0 bytes 0 meta random > 2147483647 counter packets 0 bytes 0 +///! } +///! } +///! ``` +///! +///! +///! Everything created by this example can be removed by running +///! ```bash +///! # nft delete table inet example-filter-ethernet +///! ``` +use rustables::{ + expr::{ + Cmp, CmpOp, Counter, ExpressionList, HighLevelPayload, Immediate, LLHeaderField, Meta, + MetaType, VerdictKind, + }, + Batch, Chain, ChainPolicy, Hook, HookClass, ProtocolFamily, Rule, Table, +}; const TABLE_NAME: &str = "example-filter-ethernet"; const OUT_CHAIN_NAME: &str = "chain-for-outgoing-packets"; -const BLOCK_THIS_MAC: &[u8] = &[0, 0, 0, 0, 0, 0]; +const BLOCK_THIS_MAC: &[u8] = &[1, 2, 3, 4, 5, 6]; -fn main() -> Result<(), Error> { +fn main() { // For verbose explanations of what all these lines up until the rule creation does, see the // `add-rules` example. let mut batch = Batch::new(); - let table = Rc::new(Table::new(&CString::new(TABLE_NAME).unwrap(), ProtoFamily::Inet)); - batch.add(&Rc::clone(&table), rustables::MsgType::Add); + let table = Table::new(ProtocolFamily::Inet).with_name(TABLE_NAME); + batch.add(&table, rustables::MsgType::Add); - let mut out_chain = Chain::new(&CString::new(OUT_CHAIN_NAME).unwrap(), Rc::clone(&table)); - out_chain.set_hook(rustables::Hook::Out, 3); - out_chain.set_policy(rustables::Policy::Accept); - let out_chain = Rc::new(out_chain); - batch.add(&Rc::clone(&out_chain), rustables::MsgType::Add); + let mut out_chain = Chain::new(&table).with_name(OUT_CHAIN_NAME); + out_chain.set_hook(Hook::new(HookClass::Out, 3)); + out_chain.set_policy(ChainPolicy::Accept); + batch.add(&out_chain, rustables::MsgType::Add); // === ADD RULE DROPPING ALL TRAFFIC TO THE MAC ADDRESS IN `BLOCK_THIS_MAC` === - let mut block_ethernet_rule = Rule::new(Rc::clone(&out_chain)); + let mut block_ethernet_rule = Rule::new(&out_chain).unwrap(); - // Check that the interface type is an ethernet interface. Must be done before we can check - // payload values in the ethernet header. - block_ethernet_rule.add_expr(&nft_expr!(meta iiftype)); - block_ethernet_rule.add_expr(&nft_expr!(cmp == libc::ARPHRD_ETHER)); + block_ethernet_rule.set_expressions( + ExpressionList::default() + // Check that the interface type is an ethernet interface. Must be done before we can check + // payload values in the ethernet header. + .with_value(Meta::new(MetaType::IifType)) + .with_value(Cmp::new(CmpOp::Eq, (libc::ARPHRD_ETHER as u16).to_le_bytes())) - // Compare the ethernet destination address against the MAC address we want to drop - block_ethernet_rule.add_expr(&nft_expr!(payload ethernet daddr)); - block_ethernet_rule.add_expr(&nft_expr!(cmp == BLOCK_THIS_MAC)); + // Compare the ethernet destination address against the MAC address we want to drop + .with_value(HighLevelPayload::LinkLayer(LLHeaderField::Daddr).build()) + .with_value(Cmp::new(CmpOp::Eq, BLOCK_THIS_MAC)) - // Drop the matching packets. - block_ethernet_rule.add_expr(&nft_expr!(verdict drop)); + // Drop the matching packets. + .with_value(Immediate::new_verdict(VerdictKind::Drop)), + ); batch.add(&block_ethernet_rule, rustables::MsgType::Add); @@ -67,67 +73,26 @@ fn main() -> Result<(), Error> { // So after a number of packets has passed through this rule, the first counter should have a // value approximately double that of the second counter. This rule has no verdict, so it never // does anything with the matching packets. - let mut random_rule = Rule::new(Rc::clone(&out_chain)); - // This counter expression will be evaluated (and increment the counter) for all packets coming - // through. - random_rule.add_expr(&nft_expr!(counter)); + let mut random_rule = Rule::new(&out_chain).unwrap(); - // Load a pseudo-random 32 bit unsigned integer into the netfilter register. - random_rule.add_expr(&nft_expr!(meta random)); - // Check if the random integer is larger than `u32::MAX/2`, thus having 50% chance of success. - random_rule.add_expr(&nft_expr!(cmp > (::std::u32::MAX / 2).to_be())); + random_rule.set_expressions( + ExpressionList::default() + // This counter expression will be evaluated (and increment the counter) for all packets coming + // through. + .with_value(Counter::default()) - // Add a second counter. This will only be incremented for the packets passing the random check. - random_rule.add_expr(&nft_expr!(counter)); + // Load a pseudo-random 32 bit unsigned integer into the netfilter register. + .with_value(Meta::new(MetaType::PRandom)) + // Check if the random integer is larger than `u32::MAX/2`, thus having 50% chance of success. + .with_value(Cmp::new(CmpOp::Gt, (::std::u32::MAX / 2).to_be_bytes())) + + // Add a second counter. This will only be incremented for the packets passing the random check. + .with_value(Counter::default()), + ); batch.add(&random_rule, rustables::MsgType::Add); // === FINALIZE THE TRANSACTION AND SEND THE DATA TO NETFILTER === - match batch.finalize() { - Some(mut finalized_batch) => { - send_and_process(&mut finalized_batch)?; - Ok(()) - }, - None => todo!() - } -} - -fn send_and_process(batch: &mut FinalizedBatch) -> Result<(), Error> { - // Create a netlink socket to netfilter. - let socket = mnl::Socket::new(mnl::Bus::Netfilter)?; - // Send all the bytes in the batch. - socket.send_all(&mut *batch)?; - - // Try to parse the messages coming back from netfilter. This part is still very unclear. - let portid = socket.portid(); - let mut buffer = vec![0; rustables::nft_nlmsg_maxsize() as usize]; - let very_unclear_what_this_is_for = 2; - while let Some(message) = socket_recv(&socket, &mut buffer[..])? { - match mnl::cb_run(message, very_unclear_what_this_is_for, portid)? { - mnl::CbResult::Stop => { - break; - } - mnl::CbResult::Ok => (), - } - } - Ok(()) -} - -fn socket_recv<'a>(socket: &mnl::Socket, buf: &'a mut [u8]) -> Result<Option<&'a [u8]>, Error> { - let ret = socket.recv(buf)?; - if ret > 0 { - Ok(Some(&buf[..ret])) - } else { - Ok(None) - } -} - -#[derive(Debug)] -struct Error(String); - -impl From<io::Error> for Error { - fn from(error: io::Error) -> Self { - Error(error.to_string()) - } + batch.send().unwrap(); } diff --git a/examples/firewall.rs b/examples/firewall.rs index 46a0a4d..3169cdc 100644 --- a/examples/firewall.rs +++ b/examples/firewall.rs @@ -1,35 +1,26 @@ -use rustables::{Batch, Chain, ChainMethods, Hook, MatchError, ProtoFamily, - Protocol, Rule, RuleMethods, Table, MsgType, Policy}; -use rustables::query::{send_batch, Error as QueryError}; -use rustables::expr::{LogGroup, LogPrefix, LogPrefixError}; +//use rustables::{Batch, Chain, ChainMethods, Hook, MatchError, ProtoFamily, +// Protocol, Rule, RuleMethods, Table, MsgType, Policy}; +//use rustables::query::{send_batch, Error as QueryError}; +//use rustables::expr::{LogGroup, LogPrefix, LogPrefixError}; use ipnetwork::IpNetwork; -use std::ffi::{CString, NulError}; -use std::str::Utf8Error; -use std::rc::Rc; - +use rustables::error::{BuilderError, QueryError}; +use rustables::expr::Log; +use rustables::{ + Batch, Chain, ChainPolicy, Hook, HookClass, MsgType, Protocol, ProtocolFamily, Rule, Table, +}; #[derive(thiserror::Error, Debug)] pub enum Error { - #[error("Unable to open netlink socket to netfilter")] - NetlinkOpenError(#[source] std::io::Error), - #[error("Firewall is already started")] - AlreadyDone, - #[error("Error converting from a C String")] - NulError(#[from] NulError), - #[error("Error creating match")] - MatchError(#[from] MatchError), - #[error("Error converting to utf-8 string")] - Utf8Error(#[from] Utf8Error), - #[error("Error applying batch")] - BatchError(#[from] std::io::Error), + #[error("Error building a netlink object")] + BuildError(#[from] BuilderError), #[error("Error applying batch")] QueryError(#[from] QueryError), - #[error("Error encoding the prefix")] - LogPrefixError(#[from] LogPrefixError), } const TABLE_NAME: &str = "main-table"; - +const INBOUND_CHAIN_NAME: &str = "in-chain"; +const FORWARD_CHAIN_NAME: &str = "forward-chain"; +const OUTBOUND_CHAIN_NAME: &str = "out-chain"; fn main() -> Result<(), Error> { let fw = Firewall::new()?; @@ -37,93 +28,89 @@ fn main() -> Result<(), Error> { Ok(()) } - /// An example firewall. See the source of its `start()` method. pub struct Firewall { batch: Batch, - inbound: Rc<Chain>, - _outbound: Rc<Chain>, - _forward: Rc<Chain>, - table: Rc<Table>, + inbound: Chain, + _outbound: Chain, + _forward: Chain, + table: Table, } impl Firewall { pub fn new() -> Result<Self, Error> { let mut batch = Batch::new(); - let table = Rc::new( - Table::new(&CString::new(TABLE_NAME)?, ProtoFamily::Inet) - ); + let table = Table::new(ProtocolFamily::Inet).with_name(TABLE_NAME); batch.add(&table, MsgType::Add); // Create base chains. Base chains are hooked into a Direction/Hook. - let inbound = Rc::new( - Chain::from_hook(Hook::In, Rc::clone(&table)) - .verdict(Policy::Drop) - .add_to_batch(&mut batch) - ); - let _outbound = Rc::new( - Chain::from_hook(Hook::Out, Rc::clone(&table)) - .verdict(Policy::Accept) - .add_to_batch(&mut batch) - ); - let _forward = Rc::new( - Chain::from_hook(Hook::Forward, Rc::clone(&table)) - .verdict(Policy::Accept) - .add_to_batch(&mut batch) - ); + let inbound = Chain::new(&table) + .with_name(INBOUND_CHAIN_NAME) + .with_hook(Hook::new(HookClass::In, 0)) + .with_policy(ChainPolicy::Drop) + .add_to_batch(&mut batch); + let _outbound = Chain::new(&table) + .with_name(OUTBOUND_CHAIN_NAME) + .with_hook(Hook::new(HookClass::Out, 0)) + .with_policy(ChainPolicy::Accept) + .add_to_batch(&mut batch); + let _forward = Chain::new(&table) + .with_name(FORWARD_CHAIN_NAME) + .with_hook(Hook::new(HookClass::Forward, 0)) + .with_policy(ChainPolicy::Accept) + .add_to_batch(&mut batch); Ok(Firewall { table, batch, inbound, _outbound, - _forward + _forward, }) } /// Allow some common-sense exceptions to inbound drop, and accept outbound and forward. pub fn start(mut self) -> Result<(), Error> { // Allow all established connections to get in. - Rule::new(Rc::clone(&self.inbound)) - .established() - .accept() - .add_to_batch(&mut self.batch); + Rule::new(&self.inbound)? + .established()? + .accept() + .add_to_batch(&mut self.batch); // Allow all traffic on the loopback interface. - Rule::new(Rc::clone(&self.inbound)) - .iface("lo")? - .accept() - .add_to_batch(&mut self.batch); + Rule::new(&self.inbound)? + .iface("lo")? + .accept() + .add_to_batch(&mut self.batch); // Allow ssh from anywhere, and log to dmesg with a prefix. - Rule::new(Rc::clone(&self.inbound)) - .dport("22", &Protocol::TCP)? - .accept() - .log(None, Some(LogPrefix::new("allow ssh connection:")?)) - .add_to_batch(&mut self.batch); + Rule::new(&self.inbound)? + .dport(22, Protocol::TCP) + .accept() + .with_expr(Log::new(None, Some("allow ssh connection:"))?) + .add_to_batch(&mut self.batch); // Allow http from all IPs in 192.168.1.255/24 . let local_net = IpNetwork::new([192, 168, 1, 0].into(), 24).unwrap(); - Rule::new(Rc::clone(&self.inbound)) - .dport("80", &Protocol::TCP)? - .snetwork(local_net) - .accept() - .add_to_batch(&mut self.batch); + Rule::new(&self.inbound)? + .dport(80, Protocol::TCP) + .snetwork(local_net)? + .accept() + .add_to_batch(&mut self.batch); // Allow ICMP traffic, drop IGMP. - Rule::new(Rc::clone(&self.inbound)) - .icmp() - .accept() - .add_to_batch(&mut self.batch); - Rule::new(Rc::clone(&self.inbound)) - .igmp() - .drop() - .add_to_batch(&mut self.batch); + Rule::new(&self.inbound)? + .icmp() + .accept() + .add_to_batch(&mut self.batch); + Rule::new(&self.inbound)? + .igmp() + .drop() + .add_to_batch(&mut self.batch); // Log all traffic not accepted to NF_LOG group 1, accessible with ulogd. - Rule::new(Rc::clone(&self.inbound)) - .log(Some(LogGroup(1)), None) - .add_to_batch(&mut self.batch); + Rule::new(&self.inbound)? + .with_expr(Log::new(Some(1), None::<String>)?) + .add_to_batch(&mut self.batch); - let mut finalized_batch = self.batch.finalize().unwrap(); - send_batch(&mut finalized_batch)?; + self.batch.send()?; println!("table {} commited", TABLE_NAME); Ok(()) } @@ -132,11 +119,8 @@ impl Firewall { self.batch.add(&self.table, MsgType::Add); self.batch.add(&self.table, MsgType::Del); - let mut finalized_batch = self.batch.finalize().unwrap(); - send_batch(&mut finalized_batch)?; + self.batch.send()?; println!("table {} destroyed", TABLE_NAME); Ok(()) } } - - diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..c527b71 --- /dev/null +++ b/flake.lock @@ -0,0 +1,78 @@ +{ + "nodes": { + "crate2nix": { + "flake": false, + "locked": { + "lastModified": 1667176522, + "narHash": "sha256-BCAfYlEdC19gprvgTV3ht5gC24qQ+HL6kbVJWBOxcio=", + "owner": "kolloch", + "repo": "crate2nix", + "rev": "3e6fbcc8ecd384018196223023cdd7868bbce4e6", + "type": "github" + }, + "original": { + "owner": "kolloch", + "ref": "master", + "repo": "crate2nix", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1667610399, + "narHash": "sha256-XZd0f4ZWAY0QOoUSdiNWj/eFiKb4B9CJPtl9uO9SYY4=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "1dd8696f96db47156e1424a49578fe7dd4ce99a4", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixpkgs-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs-mozilla": { + "flake": false, + "locked": { + "lastModified": 1664789696, + "narHash": "sha256-UGWJHQShiwLCr4/DysMVFrYdYYHcOqAOVsWNUu+l6YU=", + "owner": "mozilla", + "repo": "nixpkgs-mozilla", + "rev": "80627b282705101e7b38e19ca6e8df105031b072", + "type": "github" + }, + "original": { + "owner": "mozilla", + "repo": "nixpkgs-mozilla", + "type": "github" + } + }, + "root": { + "inputs": { + "crate2nix": "crate2nix", + "nixpkgs": "nixpkgs", + "nixpkgs-mozilla": "nixpkgs-mozilla", + "utils": "utils" + } + }, + "utils": { + "locked": { + "lastModified": 1667395993, + "narHash": "sha256-nuEHfE/LcWyuSWnS8t12N1wc105Qtau+/OdUAjtQ0rA=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "5aed5285a952e0b949eb3ba02c12fa4fcfef535f", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..0ea7178 --- /dev/null +++ b/flake.nix @@ -0,0 +1,66 @@ +{ + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; + nixpkgs-mozilla = { url = "github:mozilla/nixpkgs-mozilla"; flake = false; }; + crate2nix = { url = "github:kolloch/crate2nix/master"; flake = false; }; + utils.url = "github:numtide/flake-utils"; + }; + + outputs = { self, nixpkgs, nixpkgs-mozilla, crate2nix, utils } @ inputs : + let + rustOverlay = (final: prev: + let + rustChannel = prev.rustChannelOf { + channel = "1.66.0"; + sha256 = "S7epLlflwt0d1GZP44u5Xosgf6dRrmr8xxC+Ml2Pq7c="; + }; + in + { + inherit rustChannel; + rustc = rustChannel.rust; + cargo = rustChannel.rust; + } + ); + rustDevOverlay = final: prev: { + # rust-analyzer needs core source + rustc-with-src = prev.rustc.override { extensions = [ "rust-src" ]; }; + }; + in + utils.lib.eachDefaultSystem (system: + let + pkgs = import nixpkgs { + inherit system; + overlays = [ (import "${nixpkgs-mozilla}/rust-overlay.nix") rustOverlay rustDevOverlay ]; + }; + nativeBuildInputs = with pkgs; [ pkg-config ]; + buildInputs = with pkgs; [ clang linuxHeaders ]; + LIBCLANG_PATH = pkgs.lib.makeLibraryPath [ pkgs.llvmPackages_latest.libclang.lib ]; + customBuildCrate = pkgs: pkgs.buildRustCrate.override { + defaultCrateOverrides = pkgs.defaultCrateOverrides // { + rustables = attrs: { + nativeBuildInputs = nativeBuildInputs; + buildInputs = buildInputs; + LIBCLANG_PATH = LIBCLANG_PATH; + }; + }; + }; + cargoNix = import ./Cargo.nix { + inherit pkgs; + buildRustCrateForPkgs = customBuildCrate; + release = false; + }; + in { + defaultPackage = cargoNix.rootCrate.build; + packages = { + rustables = cargoNix.rootCrate.build; + }; + devShell = pkgs.mkShell { + name = "rustables"; + nativeBuildInputs = nativeBuildInputs; + buildInputs = buildInputs; + LIBCLANG_PATH = LIBCLANG_PATH; + packages = with pkgs; [ rust-analyzer rustc-with-src ]; + }; + } + ); +} diff --git a/include/tests_wrapper.h b/include/tests_wrapper.h deleted file mode 100644 index 8f976e8..0000000 --- a/include/tests_wrapper.h +++ /dev/null @@ -1 +0,0 @@ -#include "linux/netfilter/nf_tables.h" diff --git a/include/wrapper.h b/include/wrapper.h index e6eb221..cb96617 100644 --- a/include/wrapper.h +++ b/include/wrapper.h @@ -1,12 +1,3 @@ -#include <libnftnl/batch.h> -#include <libnftnl/chain.h> -#include <libnftnl/common.h> -#include <libnftnl/expr.h> -#include <libnftnl/gen.h> -#include <libnftnl/object.h> -#include <libnftnl/rule.h> -#include <libnftnl/ruleset.h> -#include <libnftnl/set.h> -#include <libnftnl/table.h> -#include <libnftnl/trace.h> -#include <libnftnl/udata.h> +#include <linux/netlink.h> +#include <linux/netfilter/nfnetlink.h> +#include <linux/netfilter/nf_tables.h> diff --git a/macros/Cargo.toml b/macros/Cargo.toml new file mode 100644 index 0000000..5d0f297 --- /dev/null +++ b/macros/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "rustables-macros" +version = "0.1.0" +authors = ["Simon Thoby"] +license = "GPL-3.0-or-later" +description = "Internal macros for generation netlink structures for the rustables project" +repository = "https://gitlab.com/rustwall/rustables" +resolver = "2" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "1.0", features = ["full"] } +quote = "1.0" +proc-macro2 = "1.0" +proc-macro-error = "1" diff --git a/macros/src/lib.rs b/macros/src/lib.rs new file mode 100644 index 0000000..39f0d01 --- /dev/null +++ b/macros/src/lib.rs @@ -0,0 +1,497 @@ +use proc_macro::TokenStream; +use proc_macro2::{Group, Span}; +use quote::quote; + +use proc_macro_error::{abort, proc_macro_error}; +use syn::parse::Parser; +use syn::punctuated::Punctuated; +use syn::spanned::Spanned; +use syn::{ + parse, parse2, Attribute, Expr, ExprCast, Ident, ItemEnum, ItemStruct, Lit, Meta, Path, Result, + Token, Type, TypePath, Visibility, +}; + +struct Field<'a> { + name: &'a Ident, + ty: &'a Type, + args: FieldArgs, + netlink_type: Path, + vis: &'a Visibility, + attrs: Vec<&'a Attribute>, +} + +#[derive(Default)] +struct FieldArgs { + netlink_type: Option<Path>, + override_function_name: Option<String>, +} + +fn parse_field_args(input: proc_macro2::TokenStream) -> Result<FieldArgs> { + let input = parse2::<Group>(input)?.stream(); + let mut args = FieldArgs::default(); + let parser = Punctuated::<Meta, Token![,]>::parse_terminated; + let attribute_args = parser.parse2(input)?; + for arg in attribute_args.iter() { + match arg { + Meta::Path(path) => { + if args.netlink_type.is_none() { + args.netlink_type = Some(path.clone()); + } else { + abort!( + arg.span(), + "Only a single netlink value can exist for a given field" + ); + } + } + Meta::NameValue(namevalue) => { + let key = namevalue + .path + .get_ident() + .expect("the macro parameter is not an ident?") + .to_string(); + match key.as_str() { + "name_in_functions" => { + if let Lit::Str(val) = &namevalue.lit { + args.override_function_name = Some(val.value()); + } else { + abort!(&namevalue.lit.span(), "Expected a string literal"); + } + } + _ => abort!(key.span(), "Unsupported macro parameter"), + } + } + _ => abort!(arg.span(), "Unrecognized argument"), + } + } + Ok(args) +} + +struct StructArgs { + nested: bool, + derive_decoder: bool, + derive_deserialize: bool, +} + +impl Default for StructArgs { + fn default() -> Self { + Self { + nested: false, + derive_decoder: true, + derive_deserialize: true, + } + } +} + +fn parse_struct_args(input: TokenStream) -> Result<StructArgs> { + let mut args = StructArgs::default(); + let parser = Punctuated::<Meta, Token![,]>::parse_terminated; + let attribute_args = parser.parse(input.clone())?; + for arg in attribute_args.iter() { + if let Meta::NameValue(namevalue) = arg { + let key = namevalue + .path + .get_ident() + .expect("the macro parameter is not an ident?") + .to_string(); + match key.as_str() { + "derive_decoder" => { + if let Lit::Bool(boolean) = &namevalue.lit { + args.derive_decoder = boolean.value; + } else { + abort!(&namevalue.lit.span(), "Expected a boolean"); + } + } + "nested" => { + if let Lit::Bool(boolean) = &namevalue.lit { + args.nested = boolean.value; + } else { + abort!(&namevalue.lit.span(), "Expected a boolean"); + } + } + "derive_deserialize" => { + if let Lit::Bool(boolean) = &namevalue.lit { + args.derive_deserialize = boolean.value; + } else { + abort!(&namevalue.lit.span(), "Expected a boolean"); + } + } + _ => abort!(key.span(), "Unsupported macro parameter"), + } + } else { + abort!(arg.span(), "Unrecognized argument"); + } + } + Ok(args) +} + +#[proc_macro_error] +#[proc_macro_attribute] +pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream { + let ast: ItemStruct = parse(item).unwrap(); + let name = ast.ident; + + let args = match parse_struct_args(attrs) { + Ok(x) => x, + Err(_) => abort!(Span::call_site(), "Could not parse the macro arguments"), + }; + + let mut fields = Vec::with_capacity(ast.fields.len()); + let mut identical_fields = Vec::new(); + + 'out: for field in ast.fields.iter() { + for attr in field.attrs.iter() { + if let Some(id) = attr.path.get_ident() { + if id == "field" { + let field_args = match parse_field_args(attr.tokens.clone()) { + Ok(x) => x, + Err(_) => { + abort!(attr.tokens.span(), "Could not parse the field attributes") + } + }; + if let Some(netlink_type) = field_args.netlink_type.clone() { + fields.push(Field { + name: field.ident.as_ref().expect("Should be a names struct"), + ty: &field.ty, + args: field_args, + netlink_type, + vis: &field.vis, + // drop the "field" attribute + attrs: field + .attrs + .iter() + .filter(|x| x.path.get_ident() != attr.path.get_ident()) + .collect(), + }); + } else { + abort!(attr.tokens.span(), "Missing Netlink Type in field"); + } + continue 'out; + } + } + } + identical_fields.push(field); + } + + let getters_and_setters = fields.iter().map(|field| { + let field_name = field.name; + // use the name override if any + let field_str = field_name.to_string(); + let field_str = field + .args + .override_function_name + .as_ref() + .map(|x| x.as_str()) + .unwrap_or(field_str.as_str()); + let field_type = field.ty; + + let getter_name = format!("get_{}", field_str); + let getter_name = Ident::new(&getter_name, field.name.span()); + + let muttable_getter_name = format!("get_mut_{}", field_str); + let muttable_getter_name = Ident::new(&muttable_getter_name, field.name.span()); + + let setter_name = format!("set_{}", field_str); + let setter_name = Ident::new(&setter_name, field.name.span()); + + let in_place_edit_name = format!("with_{}", field_str); + let in_place_edit_name = Ident::new(&in_place_edit_name, field.name.span()); + quote!( + #[allow(dead_code)] + impl #name { + pub fn #getter_name(&self) -> Option<&#field_type> { + self.#field_name.as_ref() + } + + pub fn #muttable_getter_name(&mut self) -> Option<&mut #field_type> { + self.#field_name.as_mut() + } + + pub fn #setter_name(&mut self, val: impl Into<#field_type>) { + self.#field_name = Some(val.into()); + } + + pub fn #in_place_edit_name(mut self, val: impl Into<#field_type>) -> Self { + self.#field_name = Some(val.into()); + self + } + }) + }); + + let decoder = if args.derive_decoder { + let match_entries = fields.iter().map(|field| { + let field_name = field.name; + let field_type = field.ty; + let netlink_value = &field.netlink_type; + quote!( + x if x == #netlink_value => { + debug!("Calling {}::deserialize()", std::any::type_name::<#field_type>()); + let (val, remaining) = <#field_type>::deserialize(buf)?; + if remaining.len() != 0 { + return Err(crate::error::DecodeError::InvalidDataSize); + } + self.#field_name = Some(val); + Ok(()) + } + ) + }); + quote!( + impl crate::nlmsg::AttributeDecoder for #name { + #[allow(dead_code)] + fn decode_attribute(&mut self, attr_type: u16, buf: &[u8]) -> Result<(), crate::error::DecodeError> { + use crate::nlmsg::NfNetlinkDeserializable; + debug!("Decoding attribute {} in type {}", attr_type, std::any::type_name::<#name>()); + match attr_type { + #(#match_entries),* + _ => Err(crate::error::DecodeError::UnsupportedAttributeType(attr_type)), + } + } + } + ) + } else { + proc_macro2::TokenStream::new() + }; + + let nfnetlinkattribute_impl = { + let size_entries = fields.iter().map(|field| { + let field_name = field.name; + quote!( + if let Some(val) = &self.#field_name { + // Attribute header + attribute value + size += crate::nlmsg::pad_netlink_object::<crate::sys::nlattr>() + + crate::nlmsg::pad_netlink_object_with_variable_size(val.get_size()); + } + ) + }); + let write_entries = fields.iter().map(|field| { + let field_name = field.name; + let field_str = field_name.to_string(); + let netlink_value = &field.netlink_type; + quote!( + if let Some(val) = &self.#field_name { + debug!("writing attribute {} - {:?}", #field_str, val); + + crate::parser::write_attribute(#netlink_value, val, addr); + + #[allow(unused)] + { + let size = crate::nlmsg::pad_netlink_object::<crate::sys::nlattr>() + + crate::nlmsg::pad_netlink_object_with_variable_size(val.get_size()); + addr = addr.offset(size as isize); + } + } + ) + }); + let nested = args.nested; + quote!( + impl crate::nlmsg::NfNetlinkAttribute for #name { + fn is_nested(&self) -> bool { + #nested + } + + fn get_size(&self) -> usize { + use crate::nlmsg::NfNetlinkAttribute; + + let mut size = 0; + #(#size_entries) * + size + } + + unsafe fn write_payload(&self, mut addr: *mut u8) { + use crate::nlmsg::NfNetlinkAttribute; + + #(#write_entries) * + } + } + ) + }; + + let vis = &ast.vis; + let attrs = ast.attrs; + let new_fields = fields.iter().map(|field| { + let name = field.name; + let ty = field.ty; + let attrs = &field.attrs; + let vis = &field.vis; + quote!( #(#attrs) * #vis #name: Option<#ty>, ) + }); + let nfnetlinkdeserialize_impl = if args.derive_deserialize { + quote!( + impl crate::nlmsg::NfNetlinkDeserializable for #name { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), crate::error::DecodeError> { + Ok((crate::parser::read_attributes(buf)?, &[])) + } + } + ) + } else { + proc_macro2::TokenStream::new() + }; + let res = quote! { + #(#attrs) * #vis struct #name { + #(#new_fields)* + #(#identical_fields),* + } + + #(#getters_and_setters) * + + #decoder + + #nfnetlinkattribute_impl + + #nfnetlinkdeserialize_impl + }; + + res.into() +} + +struct Variant<'a> { + inner: &'a syn::Variant, + name: &'a Ident, + value: &'a Path, +} + +#[derive(Default)] +struct EnumArgs { + nested: bool, + ty: Option<Path>, +} + +fn parse_enum_args(input: TokenStream) -> Result<EnumArgs> { + let mut args = EnumArgs::default(); + let parser = Punctuated::<Meta, Token![,]>::parse_terminated; + let attribute_args = parser.parse(input)?; + for arg in attribute_args.iter() { + match arg { + Meta::Path(path) => { + if args.ty.is_none() { + args.ty = Some(path.clone()); + } else { + abort!(arg.span(), "A value can only have a single representation"); + } + } + Meta::NameValue(namevalue) => { + let key = namevalue + .path + .get_ident() + .expect("the macro parameter is not an ident?") + .to_string(); + match key.as_str() { + "nested" => { + if let Lit::Bool(boolean) = &namevalue.lit { + args.nested = boolean.value; + } else { + abort!(&namevalue.lit.span(), "Expected a boolean"); + } + } + _ => abort!(key.span(), "Unsupported macro parameter"), + } + } + _ => abort!(arg.span(), "Unrecognized argument"), + } + } + Ok(args) +} + +#[proc_macro_error] +#[proc_macro_attribute] +pub fn nfnetlink_enum(attrs: TokenStream, item: TokenStream) -> TokenStream { + let ast: ItemEnum = parse(item).unwrap(); + let name = ast.ident; + + let args = match parse_enum_args(attrs) { + Ok(x) => x, + Err(_) => abort!(Span::call_site(), "Could not parse the macro arguments"), + }; + + if args.ty.is_none() { + abort!( + Span::call_site(), + "The target type representation is unspecified" + ); + } + + let mut variants = Vec::with_capacity(ast.variants.len()); + + for variant in ast.variants.iter() { + if variant.discriminant.is_none() { + abort!(variant.ident.span(), "Missing value"); + } + let discriminant = variant.discriminant.as_ref().unwrap(); + if let syn::Expr::Path(path) = &discriminant.1 { + variants.push(Variant { + inner: variant, + name: &variant.ident, + value: &path.path, + }); + } else { + abort!(discriminant.1.span(), "Expected a path"); + } + } + + let repr_type = args.ty.unwrap(); + let match_entries = variants.iter().map(|variant| { + let variant_name = variant.name; + let variant_value = &variant.value; + quote!( x if x == (#variant_value as #repr_type) => Ok(Self::#variant_name), ) + }); + let unknown_type_ident = Ident::new(&format!("Unknown{}", name.to_string()), name.span()); + let tryfrom_impl = quote!( + impl ::core::convert::TryFrom<#repr_type> for #name { + type Error = crate::error::DecodeError; + + fn try_from(val: #repr_type) -> Result<Self, Self::Error> { + match val { + #(#match_entries) * + value => Err(crate::error::DecodeError::#unknown_type_ident(value)) + } + } + } + ); + let nfnetlinkdeserialize_impl = quote!( + impl crate::nlmsg::NfNetlinkDeserializable for #name { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), crate::error::DecodeError> { + let (v, remaining_data) = #repr_type::deserialize(buf)?; + <#name>::try_from(v).map(|x| (x, remaining_data)) + } + } + ); + let vis = &ast.vis; + let attrs = ast.attrs; + let original_variants = variants.into_iter().map(|x| { + let mut inner = x.inner.clone(); + let mut discriminant = inner.discriminant.as_mut().unwrap(); + let cur_value = discriminant.1.clone(); + let cast_value = Expr::Cast(ExprCast { + attrs: vec![], + expr: Box::new(cur_value), + as_token: Token), + ty: Box::new(Type::Path(TypePath { + qself: None, + path: repr_type.clone(), + })), + }); + discriminant.1 = cast_value; + inner + }); + let res = quote! { + #[repr(#repr_type)] + #(#attrs) * #vis enum #name { + #(#original_variants),* + } + + impl crate::nlmsg::NfNetlinkAttribute for #name { + fn get_size(&self) -> usize { + (*self as #repr_type).get_size() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + (*self as #repr_type).write_payload(addr); + } + } + + #tryfrom_impl + + #nfnetlinkdeserialize_impl + }; + + res.into() +} diff --git a/src/batch.rs b/src/batch.rs index 198e8d0..b5c88b8 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -1,31 +1,29 @@ -use crate::{MsgType, NlMsg}; -use crate::sys::{self as sys, libc}; -use std::ffi::c_void; -use std::os::raw::c_char; -use std::ptr; +use libc; + use thiserror::Error; +use crate::error::QueryError; +use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter}; +use crate::sys::NFNL_SUBSYS_NFTABLES; +use crate::{MsgType, ProtocolFamily}; + +use nix::sys::socket::{ + self, AddressFamily, MsgFlags, NetlinkAddr, SockAddr, SockFlag, SockProtocol, SockType, +}; + /// Error while communicating with netlink. #[derive(Error, Debug)] #[error("Error while communicating with netlink")] pub struct NetlinkError(()); -#[cfg(feature = "query")] -/// Check if the kernel supports batched netlink messages to netfilter. -pub fn batch_is_supported() -> std::result::Result<bool, NetlinkError> { - match unsafe { sys::nftnl_batch_is_supported() } { - 1 => Ok(true), - 0 => Ok(false), - _ => Err(NetlinkError(())), - } -} - -/// A batch of netfilter messages to be performed in one atomic operation. Corresponds to -/// `nftnl_batch` in libnftnl. +/// A batch of netfilter messages to be performed in one atomic operation. pub struct Batch { - pub(crate) batch: *mut sys::nftnl_batch, - pub(crate) seq: u32, - pub(crate) is_empty: bool, + buf: Box<Vec<u8>>, + // the 'static lifetime here is a cheat, as the writer can only be used as long + // as `self.buf` exists. This is why this member must never be exposed directly to + // the rest of the crate (let alone publicly). + writer: NfNetlinkWriter<'static>, + seq: u32, } impl Batch { @@ -33,48 +31,40 @@ impl Batch { /// /// [default page size]: fn.default_batch_page_size.html pub fn new() -> Self { - Self::with_page_size(default_batch_page_size()) - } - - pub unsafe fn from_raw(batch: *mut sys::nftnl_batch, seq: u32) -> Self { - Batch { - batch, + // TODO: use a pinned Box ? + let mut buf = Box::new(Vec::with_capacity(default_batch_page_size() as usize)); + let mut writer = NfNetlinkWriter::new(unsafe { + std::mem::transmute(Box::as_mut(&mut buf) as *mut Vec<u8>) + }); + let seq = 0; + writer.write_header( + libc::NFNL_MSG_BATCH_BEGIN as u16, + ProtocolFamily::Unspec, + 0, seq, - // we assume this batch is not empty by default - is_empty: false, + Some(libc::NFNL_SUBSYS_NFTABLES as u16), + ); + writer.finalize_writing_object(); + Batch { + buf, + writer, + seq: seq + 1, } } - /// Creates a new nftnl batch with the given batch size. - pub fn with_page_size(batch_page_size: u32) -> Self { - let batch = try_alloc!(unsafe { - sys::nftnl_batch_alloc(batch_page_size, crate::nft_nlmsg_maxsize()) - }); - let mut this = Batch { - batch, - seq: 0, - is_empty: true, - }; - this.write_begin_msg(); - this - } - /// Adds the given message to this batch. - pub fn add<T: NlMsg>(&mut self, msg: &T, msg_type: MsgType) { + pub fn add<T: NfNetlinkObject>(&mut self, msg: &T, msg_type: MsgType) { trace!("Writing NlMsg with seq {} to batch", self.seq); - unsafe { msg.write(self.current(), self.seq, msg_type) }; - self.is_empty = false; - self.next() + msg.add_or_remove(&mut self.writer, msg_type, self.seq); + self.seq += 1; } - /// Adds all the messages in the given iterator to this batch. If any message fails to be - /// added the error for that failure is returned and all messages up until that message stay - /// added to the batch. - pub fn add_iter<T, I>(&mut self, msg_iter: I, msg_type: MsgType) - where - T: NlMsg, - I: Iterator<Item = T>, - { + /// Adds all the messages in the given iterator to this batch. + pub fn add_iter<T: NfNetlinkObject, I: Iterator<Item = T>>( + &mut self, + msg_iter: I, + msg_type: MsgType, + ) { for msg in msg_iter { self.add(&msg, msg_type); } @@ -86,109 +76,46 @@ impl Batch { /// Return None if there is no object in the batch (this could block forever). /// /// [`FinalizedBatch`]: struct.FinalizedBatch.html - pub fn finalize(mut self) -> Option<FinalizedBatch> { - self.write_end_msg(); - if self.is_empty { - return None; - } - Some(FinalizedBatch { batch: self }) - } - - fn current(&self) -> *mut c_void { - unsafe { sys::nftnl_batch_buffer(self.batch) } - } - - fn next(&mut self) { - if unsafe { sys::nftnl_batch_update(self.batch) } < 0 { - // See try_alloc definition. - std::process::abort(); - } - self.seq += 1; - } - - fn write_begin_msg(&mut self) { - unsafe { sys::nftnl_batch_begin(self.current() as *mut c_char, self.seq) }; - self.next(); - } - - fn write_end_msg(&mut self) { - unsafe { sys::nftnl_batch_end(self.current() as *mut c_char, self.seq) }; - self.next(); - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns the raw handle. - pub fn as_ptr(&self) -> *const sys::nftnl_batch { - self.batch as *const sys::nftnl_batch - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns a mutable version of the raw handle. - pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_batch { - self.batch - } -} - -impl Drop for Batch { - fn drop(&mut self) { - unsafe { sys::nftnl_batch_free(self.batch) }; - } -} - -/// A wrapper over [`Batch`], guaranteed to start with a proper batch begin and end with a proper -/// batch end message. Created from [`Batch::finalize`]. -/// -/// Can be turned into an iterator of the byte buffers to send to netlink to execute this batch. -/// -/// [`Batch`]: struct.Batch.html -/// [`Batch::finalize`]: struct.Batch.html#method.finalize -pub struct FinalizedBatch { - batch: Batch, -} - -impl FinalizedBatch { - /// Returns the iterator over byte buffers to send to netlink. - pub fn iter(&mut self) -> Iter<'_> { - let num_pages = unsafe { sys::nftnl_batch_iovec_len(self.batch.batch) as usize }; - let mut iovecs = vec![ - libc::iovec { - iov_base: ptr::null_mut(), - iov_len: 0, - }; - num_pages - ]; - let iovecs_ptr = iovecs.as_mut_ptr(); - unsafe { - sys::nftnl_batch_iovec(self.batch.batch, iovecs_ptr, num_pages as u32); - } - Iter { - iovecs: iovecs.into_iter(), - _marker: ::std::marker::PhantomData, + pub fn finalize(mut self) -> Vec<u8> { + self.writer.write_header( + libc::NFNL_MSG_BATCH_END as u16, + ProtocolFamily::Unspec, + 0, + self.seq, + Some(NFNL_SUBSYS_NFTABLES as u16), + ); + self.writer.finalize_writing_object(); + *self.buf + } + + pub fn send(self) -> Result<(), QueryError> { + use crate::query::{recv_and_process, socket_close_wrapper}; + + let sock = socket::socket( + AddressFamily::Netlink, + SockType::Raw, + SockFlag::empty(), + SockProtocol::NetlinkNetFilter, + ) + .map_err(QueryError::NetlinkOpenError)?; + + let max_seq = self.seq - 1; + + let addr = SockAddr::Netlink(NetlinkAddr::new(0, 0)); + // while this bind() is not strictly necessary, strace have trouble decoding the messages + // if we don't + socket::bind(sock, &addr).expect("bind"); + + let to_send = self.finalize(); + let mut sent = 0; + while sent != to_send.len() { + sent += socket::send(sock, &to_send[sent..], MsgFlags::empty()) + .map_err(QueryError::NetlinkSendError)?; } - } -} - -impl<'a> IntoIterator for &'a mut FinalizedBatch { - type Item = &'a [u8]; - type IntoIter = Iter<'a>; - - fn into_iter(self) -> Iter<'a> { - self.iter() - } -} - -pub struct Iter<'a> { - iovecs: ::std::vec::IntoIter<libc::iovec>, - _marker: ::std::marker::PhantomData<&'a ()>, -} - -impl<'a> Iterator for Iter<'a> { - type Item = &'a [u8]; - fn next(&mut self) -> Option<&'a [u8]> { - self.iovecs.next().map(|iovec| unsafe { - ::std::slice::from_raw_parts(iovec.iov_base as *const u8, iovec.iov_len) - }) + Ok(socket_close_wrapper(sock, move |sock| { + recv_and_process(sock, Some(max_seq), None, &mut ()) + })?) } } diff --git a/src/chain.rs b/src/chain.rs index a942a37..37e4cb3 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -1,41 +1,85 @@ -use crate::{MsgType, Table}; -use crate::sys::{self as sys, libc}; -#[cfg(feature = "query")] -use std::convert::TryFrom; -use std::{ - ffi::{c_void, CStr, CString}, - fmt, - os::raw::c_char, - rc::Rc, +use libc::{NF_ACCEPT, NF_DROP}; +use rustables_macros::nfnetlink_struct; + +use crate::error::{DecodeError, QueryError}; +use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject}; +use crate::sys::{ + NFTA_CHAIN_FLAGS, NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_POLICY, NFTA_CHAIN_TABLE, + NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, + NFT_MSG_NEWCHAIN, }; +use crate::{Batch, ProtocolFamily, Table}; +use std::fmt::Debug; -pub type Priority = i32; +pub type ChainPriority = i32; /// The netfilter event hooks a chain can register for. -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[repr(u16)] -pub enum Hook { +#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +#[repr(i32)] +pub enum HookClass { /// Hook into the pre-routing stage of netfilter. Corresponds to `NF_INET_PRE_ROUTING`. - PreRouting = libc::NF_INET_PRE_ROUTING as u16, + PreRouting = libc::NF_INET_PRE_ROUTING, /// Hook into the input stage of netfilter. Corresponds to `NF_INET_LOCAL_IN`. - In = libc::NF_INET_LOCAL_IN as u16, + In = libc::NF_INET_LOCAL_IN, /// Hook into the forward stage of netfilter. Corresponds to `NF_INET_FORWARD`. - Forward = libc::NF_INET_FORWARD as u16, + Forward = libc::NF_INET_FORWARD, /// Hook into the output stage of netfilter. Corresponds to `NF_INET_LOCAL_OUT`. - Out = libc::NF_INET_LOCAL_OUT as u16, + Out = libc::NF_INET_LOCAL_OUT, /// Hook into the post-routing stage of netfilter. Corresponds to `NF_INET_POST_ROUTING`. - PostRouting = libc::NF_INET_POST_ROUTING as u16, + PostRouting = libc::NF_INET_POST_ROUTING, +} + +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(nested = true)] +pub struct Hook { + /// Define the action netfilter will apply to packets processed by this chain, but that did not match any rules in it. + #[field(NFTA_HOOK_HOOKNUM)] + class: u32, + #[field(NFTA_HOOK_PRIORITY)] + priority: u32, +} + +impl Hook { + pub fn new(class: HookClass, priority: ChainPriority) -> Self { + Hook::default() + .with_class(class as u32) + .with_priority(priority as u32) + } } /// A chain policy. Decides what to do with a packet that was processed by the chain but did not /// match any rules. -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[repr(u32)] -pub enum Policy { +#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +#[repr(i32)] +pub enum ChainPolicy { /// Accept the packet. - Accept = libc::NF_ACCEPT as u32, + Accept = NF_ACCEPT, /// Drop the packet. - Drop = libc::NF_DROP as u32, + Drop = NF_DROP, +} + +impl NfNetlinkAttribute for ChainPolicy { + fn get_size(&self) -> usize { + (*self as i32).get_size() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + (*self as i32).write_payload(addr); + } +} + +impl NfNetlinkDeserializable for ChainPolicy { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let (v, remaining_data) = i32::deserialize(buf)?; + Ok(( + match v { + NF_ACCEPT => ChainPolicy::Accept, + NF_DROP => ChainPolicy::Accept, + _ => return Err(DecodeError::UnknownChainPolicy), + }, + remaining_data, + )) + } } /// Base chain type. @@ -53,240 +97,117 @@ pub enum ChainType { } impl ChainType { - fn as_c_str(&self) -> &'static [u8] { + fn as_str(&self) -> &'static str { match *self { - ChainType::Filter => b"filter\0", - ChainType::Route => b"route\0", - ChainType::Nat => b"nat\0", + ChainType::Filter => "filter", + ChainType::Route => "route", + ChainType::Nat => "nat", } } } -/// Abstraction of a `nftnl_chain`. Chains reside inside [`Table`]s and they hold [`Rule`]s. -/// -/// There are two types of chains, "base chain" and "regular chain". See [`set_hook`] for more -/// details. +impl NfNetlinkAttribute for ChainType { + fn get_size(&self) -> usize { + self.as_str().len() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + self.as_str().to_string().write_payload(addr); + } +} + +impl NfNetlinkDeserializable for ChainType { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let (s, remaining_data) = String::deserialize(buf)?; + Ok(( + match s.as_str() { + "filter" => ChainType::Filter, + "route" => ChainType::Route, + "nat" => ChainType::Nat, + _ => return Err(DecodeError::UnknownChainType), + }, + remaining_data, + )) + } +} + +/// Abstraction over an nftable chain. Chains reside inside [`Table`]s and they hold [`Rule`]s. /// /// [`Table`]: struct.Table.html /// [`Rule`]: struct.Rule.html -/// [`set_hook`]: #method.set_hook +#[derive(PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(derive_deserialize = false)] pub struct Chain { - pub(crate) chain: *mut sys::nftnl_chain, - pub(crate) table: Rc<Table>, + family: ProtocolFamily, + #[field(NFTA_CHAIN_TABLE)] + table: String, + #[field(NFTA_CHAIN_NAME)] + name: String, + #[field(NFTA_CHAIN_HOOK)] + hook: Hook, + #[field(NFTA_CHAIN_POLICY)] + policy: ChainPolicy, + #[field(NFTA_CHAIN_TYPE, name_in_functions = "type")] + chain_type: ChainType, + #[field(NFTA_CHAIN_FLAGS)] + flags: u32, + #[field(NFTA_CHAIN_USERDATA)] + userdata: Vec<u8>, } impl Chain { - /// Creates a new chain instance inside the given [`Table`] and with the given name. + /// Creates a new chain instance inside the given [`Table`]. /// /// [`Table`]: struct.Table.html - pub fn new<T: AsRef<CStr>>(name: &T, table: Rc<Table>) -> Chain { - unsafe { - let chain = try_alloc!(sys::nftnl_chain_alloc()); - sys::nftnl_chain_set_u32( - chain, - sys::NFTNL_CHAIN_FAMILY as u16, - table.get_family() as u32, - ); - sys::nftnl_chain_set_str( - chain, - sys::NFTNL_CHAIN_TABLE as u16, - table.get_name().as_ptr(), - ); - sys::nftnl_chain_set_str(chain, sys::NFTNL_CHAIN_NAME as u16, name.as_ref().as_ptr()); - Chain { chain, table } - } - } - - pub unsafe fn from_raw(chain: *mut sys::nftnl_chain, table: Rc<Table>) -> Self { - Chain { chain, table } - } + pub fn new(table: &Table) -> Chain { + let mut chain = Chain::default(); + chain.family = table.get_family(); - /// Sets the hook and priority for this chain. Without calling this method the chain will - /// become a "regular chain" without any hook and will thus not receive any traffic unless - /// some rule forward packets to it via goto or jump verdicts. - /// - /// By calling `set_hook` with a hook the chain that is created will be registered with that - /// hook and is thus a "base chain". A "base chain" is an entry point for packets from the - /// networking stack. - pub fn set_hook(&mut self, hook: Hook, priority: Priority) { - unsafe { - sys::nftnl_chain_set_u32(self.chain, sys::NFTNL_CHAIN_HOOKNUM as u16, hook as u32); - sys::nftnl_chain_set_s32(self.chain, sys::NFTNL_CHAIN_PRIO as u16, priority); + if let Some(table_name) = table.get_name() { + chain.set_table(table_name); } - } - /// Set the type of a base chain. This only applies if the chain has been registered - /// with a hook by calling `set_hook`. - pub fn set_type(&mut self, chain_type: ChainType) { - unsafe { - sys::nftnl_chain_set_str( - self.chain, - sys::NFTNL_CHAIN_TYPE as u16, - chain_type.as_c_str().as_ptr() as *const c_char, - ); - } + chain } - /// Sets the default policy for this chain. That means what action netfilter will apply to - /// packets processed by this chain, but that did not match any rules in it. - pub fn set_policy(&mut self, policy: Policy) { - unsafe { - sys::nftnl_chain_set_u32(self.chain, sys::NFTNL_CHAIN_POLICY as u16, policy as u32); - } - } - - /// Returns the userdata of this chain. - pub fn get_userdata(&self) -> Option<&CStr> { - unsafe { - let ptr = sys::nftnl_chain_get_str(self.chain, sys::NFTNL_CHAIN_USERDATA as u16); - if ptr == std::ptr::null() { - return None; - } - Some(CStr::from_ptr(ptr)) - } - } - - /// Updates the userdata of this chain. - pub fn set_userdata(&self, data: &CStr) { - unsafe { - sys::nftnl_chain_set_str(self.chain, sys::NFTNL_CHAIN_USERDATA as u16, data.as_ptr()); - } - } - - /// Returns the name of this chain. - pub fn get_name(&self) -> &CStr { - unsafe { - let ptr = sys::nftnl_chain_get_str(self.chain, sys::NFTNL_CHAIN_NAME as u16); - if ptr.is_null() { - panic!("Impossible situation: retrieving the name of a chain failed") - } else { - CStr::from_ptr(ptr) - } - } - } - - /// Returns a textual description of the chain. - pub fn get_str(&self) -> CString { - let mut descr_buf = vec![0i8; 4096]; - unsafe { - sys::nftnl_chain_snprintf( - descr_buf.as_mut_ptr() as *mut c_char, - (descr_buf.len() - 1) as u64, - self.chain, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); - CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned() - } - } - - /// Returns a reference to the [`Table`] this chain belongs to. - /// - /// [`Table`]: struct.Table.html - pub fn get_table(&self) -> Rc<Table> { - self.table.clone() - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns the raw handle. - pub fn as_ptr(&self) -> *const sys::nftnl_chain { - self.chain as *const sys::nftnl_chain - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns a mutable version of the raw handle. - pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_chain { - self.chain + /// Appends this chain to `batch` + pub fn add_to_batch(self, batch: &mut Batch) -> Self { + batch.add(&self, crate::MsgType::Add); + self } } -impl fmt::Debug for Chain { - /// Returns a string representation of the chain. - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "{:?}", self.get_str()) - } -} +impl NfNetlinkObject for Chain { + const MSG_TYPE_ADD: u32 = NFT_MSG_NEWCHAIN; + const MSG_TYPE_DEL: u32 = NFT_MSG_DELCHAIN; -impl PartialEq for Chain { - fn eq(&self, other: &Self) -> bool { - self.get_table() == other.get_table() && self.get_name() == other.get_name() + fn get_family(&self) -> ProtocolFamily { + self.family } -} -unsafe impl crate::NlMsg for Chain { - unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { - let raw_msg_type = match msg_type { - MsgType::Add => libc::NFT_MSG_NEWCHAIN, - MsgType::Del => libc::NFT_MSG_DELCHAIN, - }; - let flags: u16 = match msg_type { - MsgType::Add => (libc::NLM_F_ACK | libc::NLM_F_CREATE) as u16, - MsgType::Del => libc::NLM_F_ACK as u16, - } | libc::NLM_F_ACK as u16; - let header = sys::nftnl_nlmsg_build_hdr( - buf as *mut c_char, - raw_msg_type as u16, - self.table.get_family() as u16, - flags, - seq, - ); - sys::nftnl_chain_nlmsg_build_payload(header, self.chain); + fn set_family(&mut self, family: ProtocolFamily) { + self.family = family; } } -impl Drop for Chain { - fn drop(&mut self) { - unsafe { sys::nftnl_chain_free(self.chain) }; - } -} - -#[cfg(feature = "query")] -pub fn get_chains_cb<'a>( - header: &libc::nlmsghdr, - (table, chains): &mut (&Rc<Table>, &mut Vec<Chain>), -) -> libc::c_int { - unsafe { - let chain = sys::nftnl_chain_alloc(); - if chain == std::ptr::null_mut() { - return mnl::mnl_sys::MNL_CB_ERROR; - } - let err = sys::nftnl_chain_nlmsg_parse(header, chain); - if err < 0 { - error!("Failed to parse nelink chain message - {}", err); - sys::nftnl_chain_free(chain); - return err; - } - - let table_name = CStr::from_ptr(sys::nftnl_chain_get_str( - chain, - sys::NFTNL_CHAIN_TABLE as u16, - )); - let family = sys::nftnl_chain_get_u32(chain, sys::NFTNL_CHAIN_FAMILY as u16); - let family = match crate::ProtoFamily::try_from(family as i32) { - Ok(family) => family, - Err(crate::InvalidProtocolFamily) => { - error!("The netlink table didn't have a valid protocol family !?"); - sys::nftnl_chain_free(chain); - return mnl::mnl_sys::MNL_CB_ERROR; +pub fn list_chains_for_table(table: &Table) -> Result<Vec<Chain>, QueryError> { + let mut result = Vec::new(); + crate::query::list_objects_with_data( + libc::NFT_MSG_GETCHAIN as u16, + &|chain: Chain, (table, chains): &mut (&Table, &mut Vec<Chain>)| { + if chain.get_table() == table.get_name() { + chains.push(chain); + } else { + info!( + "Ignoring chain {:?} because it doesn't map the table {:?}", + chain.get_name(), + table.get_name() + ); } - }; - - if table_name != table.get_name() { - sys::nftnl_chain_free(chain); - return mnl::mnl_sys::MNL_CB_OK; - } - - if family != crate::ProtoFamily::Unspec && family != table.get_family() { - sys::nftnl_chain_free(chain); - return mnl::mnl_sys::MNL_CB_OK; - } - - chains.push(Chain::from_raw(chain, table.clone())); - } - mnl::mnl_sys::MNL_CB_OK -} - -#[cfg(feature = "query")] -pub fn list_chains_for_table(table: Rc<Table>) -> Result<Vec<Chain>, crate::query::Error> { - crate::query::list_objects_with_data(libc::NFT_MSG_GETCHAIN as u16, get_chains_cb, &table, None) + Ok(()) + }, + None, + &mut (&table, &mut result), + )?; + Ok(result) } diff --git a/src/chain_methods.rs b/src/chain_methods.rs deleted file mode 100644 index d384c35..0000000 --- a/src/chain_methods.rs +++ /dev/null @@ -1,40 +0,0 @@ -use crate::{Batch, Chain, Hook, MsgType, Policy, Table}; -use std::ffi::CString; -use std::rc::Rc; - - -/// A helper trait over [`crate::Chain`]. -pub trait ChainMethods { - /// Creates a new Chain instance from a [`crate::Hook`] over a [`crate::Table`]. - fn from_hook(hook: Hook, table: Rc<Table>) -> Self - where Self: std::marker::Sized; - /// Adds a [`crate::Policy`] to the current Chain. - fn verdict(self, policy: Policy) -> Self; - fn add_to_batch(self, batch: &mut Batch) -> Self; -} - - -impl ChainMethods for Chain { - fn from_hook(hook: Hook, table: Rc<Table>) -> Self { - let chain_name = match hook { - Hook::PreRouting => "prerouting", - Hook::Out => "out", - Hook::PostRouting => "postrouting", - Hook::Forward => "forward", - Hook::In => "in", - }; - let chain_name = CString::new(chain_name).unwrap(); - let mut chain = Chain::new(&chain_name, table); - chain.set_hook(hook, 0); - chain - } - fn verdict(mut self, policy: Policy) -> Self { - self.set_policy(policy); - self - } - fn add_to_batch(self, batch: &mut Batch) -> Self { - batch.add(&self, MsgType::Add); - self - } -} - diff --git a/src/data_type.rs b/src/data_type.rs new file mode 100644 index 0000000..43a7f1a --- /dev/null +++ b/src/data_type.rs @@ -0,0 +1,42 @@ +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + +pub trait DataType { + const TYPE: u32; + const LEN: u32; + + fn data(&self) -> Vec<u8>; +} + +impl DataType for Ipv4Addr { + const TYPE: u32 = 7; + const LEN: u32 = 4; + + fn data(&self) -> Vec<u8> { + self.octets().to_vec() + } +} + +impl DataType for Ipv6Addr { + const TYPE: u32 = 8; + const LEN: u32 = 16; + + fn data(&self) -> Vec<u8> { + self.octets().to_vec() + } +} + +impl<const N: usize> DataType for [u8; N] { + const TYPE: u32 = 5; + const LEN: u32 = N as u32; + + fn data(&self) -> Vec<u8> { + self.to_vec() + } +} + +pub fn ip_to_vec(ip: IpAddr) -> Vec<u8> { + match ip { + IpAddr::V4(x) => x.octets().to_vec(), + IpAddr::V6(x) => x.octets().to_vec(), + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..f6b6247 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,180 @@ +use std::string::FromUtf8Error; + +use nix::errno::Errno; +use thiserror::Error; + +use crate::sys::nlmsgerr; + +#[derive(Error, Debug)] +pub enum DecodeError { + #[error("The buffer is too small to hold a valid message")] + BufTooSmall, + + #[error("The message is too small")] + NlMsgTooSmall, + + #[error("The message holds unexpected data")] + InvalidDataSize, + + #[error("Invalid subsystem, expected NFTABLES")] + InvalidSubsystem(u8), + + #[error("Invalid version, expected NFNETLINK_V0")] + InvalidVersion(u8), + + #[error("Invalid port ID")] + InvalidPortId(u32), + + #[error("Invalid sequence number")] + InvalidSeq(u32), + + #[error("The generation number was bumped in the kernel while the operation was running, interrupting it")] + ConcurrentGenerationUpdate, + + #[error("Unsupported message type")] + UnsupportedType(u16), + + #[error("Invalid attribute type")] + InvalidAttributeType, + + #[error("Invalid type for a chain")] + UnknownChainType, + + #[error("Invalid policy for a chain")] + UnknownChainPolicy, + + #[error("Unknown type for a Meta expression")] + UnknownMetaType(u32), + + #[error("Unsupported value for an icmp reject type")] + UnknownRejectType(u32), + + #[error("Unsupported value for an icmp code in a reject expression")] + UnknownIcmpCode(u8), + + #[error("Invalid value for a register")] + UnknownRegister(u32), + + #[error("Invalid type for a verdict expression")] + UnknownVerdictType(i32), + + #[error("Invalid type for a nat expression")] + UnknownNatType(i32), + + #[error("Invalid type for a payload expression")] + UnknownPayloadType(u32), + + #[error("Invalid type for a compare expression")] + UnknownCmpOp(u32), + + #[error("Invalid type for a conntrack key")] + UnknownConntrackKey(u32), + + #[error("Unsupported value for a link layer header field")] + UnknownLinkLayerHeaderField(u32, u32), + + #[error("Unsupported value for an IPv4 header field")] + UnknownIPv4HeaderField(u32, u32), + + #[error("Unsupported value for an IPv6 header field")] + UnknownIPv6HeaderField(u32, u32), + + #[error("Unsupported value for a TCP header field")] + UnknownTCPHeaderField(u32, u32), + + #[error("Unsupported value for an UDP header field")] + UnknownUDPHeaderField(u32, u32), + + #[error("Unsupported value for an ICMPv6 header field")] + UnknownICMPv6HeaderField(u32, u32), + + #[error("Missing the 'base' attribute to deserialize the payload object")] + PayloadMissingBase, + + #[error("Missing the 'offset' attribute to deserialize the payload object")] + PayloadMissingOffset, + + #[error("Missing the 'len' attribute to deserialize the payload object")] + PayloadMissingLen, + + #[error("The object does not contain a name for the expression being parsed")] + MissingExpressionName, + + #[error("Unsupported attribute type")] + UnsupportedAttributeType(u16), + + #[error("Unexpected message type")] + UnexpectedType(u16), + + #[error("The decoded String is not UTF8 compliant")] + StringDecodeFailure(#[from] FromUtf8Error), + + #[error("Invalid value for a protocol family")] + UnknownProtocolFamily(i32), + + #[error("A custom error occured")] + Custom(Box<dyn std::error::Error + 'static>), +} + +#[derive(thiserror::Error, Debug)] +pub enum BuilderError { + #[error("The length of the arguments are not compatible with each other")] + IncompatibleLength, + + #[error("The table does not have a name")] + MissingTableName, + + #[error("Missing information in the chain to create a rule")] + MissingChainInformationError, + + #[error("Missing name for the set")] + MissingSetName, + + #[error("The interface name is too long to be written")] + InterfaceNameTooLong, + + #[error("The log prefix string is more than 127 characters long")] + TooLongLogPrefix, +} + +#[derive(thiserror::Error, Debug)] +pub enum QueryError { + #[error("Unable to open netlink socket to netfilter")] + NetlinkOpenError(#[source] nix::Error), + + #[error("Unable to send netlink command to netfilter")] + NetlinkSendError(#[source] nix::Error), + + #[error("Error while reading from netlink socket")] + NetlinkRecvError(#[source] nix::Error), + + #[error("Error while processing an incoming netlink message")] + ProcessNetlinkError(#[from] DecodeError), + + #[error("Error while building netlink objects in Rust")] + BuilderError(#[from] BuilderError), + + #[error("Error received from the kernel")] + NetlinkError(nlmsgerr), + + #[error("Custom error when customizing the query")] + InitError(#[from] Box<dyn std::error::Error + Send + 'static>), + + #[error("Couldn't allocate a netlink object, out of memory ?")] + NetlinkAllocationFailed, + + #[error("This socket is not a netlink socket")] + NotNetlinkSocket, + + #[error("Couldn't retrieve information on a socket")] + RetrievingSocketInfoFailed, + + #[error("Only a part of the message was sent")] + TruncatedSend, + + #[error("Got a message without the NLM_F_MULTI flag, but a maximum sequence number was not specified")] + UndecidableMessageTermination, + + #[error("Couldn't close the socket")] + CloseFailed(#[source] Errno), +} diff --git a/src/expr/bitwise.rs b/src/expr/bitwise.rs index d34d22c..fb40a04 100644 --- a/src/expr/bitwise.rs +++ b/src/expr/bitwise.rs @@ -1,69 +1,47 @@ -use super::{Expression, Rule, ToSlice}; -use crate::sys::{self, libc}; -use std::ffi::c_void; -use std::os::raw::c_char; - -/// Expression for performing bitwise masking and XOR on the data in a register. -pub struct Bitwise<M: ToSlice, X: ToSlice> { - mask: M, - xor: X, +use rustables_macros::nfnetlink_struct; + +use super::{Expression, Register}; +use crate::error::BuilderError; +use crate::parser_impls::NfNetlinkData; +use crate::sys::{ + NFTA_BITWISE_DREG, NFTA_BITWISE_LEN, NFTA_BITWISE_MASK, NFTA_BITWISE_SREG, NFTA_BITWISE_XOR, +}; + +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] +pub struct Bitwise { + #[field(NFTA_BITWISE_SREG)] + sreg: Register, + #[field(NFTA_BITWISE_DREG)] + dreg: Register, + #[field(NFTA_BITWISE_LEN)] + len: u32, + #[field(NFTA_BITWISE_MASK)] + mask: NfNetlinkData, + #[field(NFTA_BITWISE_XOR)] + xor: NfNetlinkData, } -impl<M: ToSlice, X: ToSlice> Bitwise<M, X> { - /// Returns a new `Bitwise` instance that first masks the value it's applied to with `mask` and - /// then performs xor with the value in `xor`. - pub fn new(mask: M, xor: X) -> Self { - Self { mask, xor } +impl Expression for Bitwise { + fn get_name() -> &'static str { + "bitwise" } } -impl<M: ToSlice, X: ToSlice> Expression for Bitwise<M, X> { - fn get_raw_name() -> *const c_char { - b"bitwise\0" as *const _ as *const c_char - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - let mask = self.mask.to_slice(); - let xor = self.xor.to_slice(); - assert!(mask.len() == xor.len()); - let len = mask.len() as u32; - - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_BITWISE_SREG as u16, - libc::NFT_REG_1 as u32, - ); - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_BITWISE_DREG as u16, - libc::NFT_REG_1 as u32, - ); - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_BITWISE_LEN as u16, len); - - sys::nftnl_expr_set( - expr, - sys::NFTNL_EXPR_BITWISE_MASK as u16, - mask.as_ref() as *const _ as *const c_void, - len, - ); - sys::nftnl_expr_set( - expr, - sys::NFTNL_EXPR_BITWISE_XOR as u16, - xor.as_ref() as *const _ as *const c_void, - len, - ); - - expr +impl Bitwise { + /// Returns a new `Bitwise` instance that first masks the value it's applied to with `mask` and + /// then performs xor with the value in `xor` + pub fn new(mask: impl Into<Vec<u8>>, xor: impl Into<Vec<u8>>) -> Result<Self, BuilderError> { + let mask = mask.into(); + let xor = xor.into(); + if mask.len() != xor.len() { + return Err(BuilderError::IncompatibleLength); } + Ok(Bitwise::default() + .with_sreg(Register::Reg1) + .with_dreg(Register::Reg1) + .with_len(mask.len() as u32) + .with_xor(NfNetlinkData::default().with_value(xor)) + .with_mask(NfNetlinkData::default().with_value(mask))) } } - -#[macro_export] -macro_rules! nft_expr_bitwise { - (mask $mask:expr,xor $xor:expr) => { - $crate::expr::Bitwise::new($mask, $xor) - }; -} diff --git a/src/expr/cmp.rs b/src/expr/cmp.rs index f6ea900..86d3587 100644 --- a/src/expr/cmp.rs +++ b/src/expr/cmp.rs @@ -1,187 +1,64 @@ -use super::{DeserializationError, Expression, Rule, ToSlice}; -use crate::sys::{self, libc}; -use std::{ - borrow::Cow, - ffi::{c_void, CString}, - os::raw::c_char, +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; + +use crate::{ + parser_impls::NfNetlinkData, + sys::{ + NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFT_CMP_EQ, NFT_CMP_GT, NFT_CMP_GTE, NFT_CMP_LT, + NFT_CMP_LTE, NFT_CMP_NEQ, + }, }; +use super::{Expression, Register}; + /// Comparison operator. #[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[nfnetlink_enum(u32, nested = true)] pub enum CmpOp { /// Equals. - Eq, + Eq = NFT_CMP_EQ, /// Not equal. - Neq, + Neq = NFT_CMP_NEQ, /// Less than. - Lt, + Lt = NFT_CMP_LT, /// Less than, or equal. - Lte, + Lte = NFT_CMP_LTE, /// Greater than. - Gt, + Gt = NFT_CMP_GT, /// Greater than, or equal. - Gte, -} - -impl CmpOp { - /// Returns the corresponding `NFT_*` constant for this comparison operation. - pub fn to_raw(self) -> u32 { - use self::CmpOp::*; - match self { - Eq => libc::NFT_CMP_EQ as u32, - Neq => libc::NFT_CMP_NEQ as u32, - Lt => libc::NFT_CMP_LT as u32, - Lte => libc::NFT_CMP_LTE as u32, - Gt => libc::NFT_CMP_GT as u32, - Gte => libc::NFT_CMP_GTE as u32, - } - } - - pub fn from_raw(val: u32) -> Result<Self, DeserializationError> { - use self::CmpOp::*; - match val as i32 { - libc::NFT_CMP_EQ => Ok(Eq), - libc::NFT_CMP_NEQ => Ok(Neq), - libc::NFT_CMP_LT => Ok(Lt), - libc::NFT_CMP_LTE => Ok(Lte), - libc::NFT_CMP_GT => Ok(Gt), - libc::NFT_CMP_GTE => Ok(Gte), - _ => Err(DeserializationError::InvalidValue), - } - } + Gte = NFT_CMP_GTE, } /// Comparator expression. Allows comparing the content of the netfilter register with any value. -#[derive(Debug, PartialEq)] -pub struct Cmp<T> { +#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[nfnetlink_struct] +pub struct Cmp { + #[field(NFTA_CMP_SREG)] + sreg: Register, + #[field(NFTA_CMP_OP)] op: CmpOp, - data: T, + #[field(NFTA_CMP_DATA)] + data: NfNetlinkData, } -impl<T: ToSlice> Cmp<T> { +impl Cmp { /// Returns a new comparison expression comparing the value loaded in the register with the /// data in `data` using the comparison operator `op`. - pub fn new(op: CmpOp, data: T) -> Self { - Cmp { op, data } - } -} - -impl<T: ToSlice> Expression for Cmp<T> { - fn get_raw_name() -> *const c_char { - b"cmp\0" as *const _ as *const c_char - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - let data = self.data.to_slice(); - trace!("Creating a cmp expr comparing with data {:?}", data); - - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_CMP_SREG as u16, - libc::NFT_REG_1 as u32, - ); - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_CMP_OP as u16, self.op.to_raw()); - sys::nftnl_expr_set( - expr, - sys::NFTNL_EXPR_CMP_DATA as u16, - data.as_ptr() as *const c_void, - data.len() as u32, - ); - - expr - } - } -} - -impl<const N: usize> Expression for Cmp<[u8; N]> { - fn get_raw_name() -> *const c_char { - Cmp::<u8>::get_raw_name() - } - - /// The raw data contained inside `Cmp` expressions can only be deserialized to arrays of - /// bytes, to ensure that the memory layout of retrieved data cannot be violated. It is your - /// responsibility to provide the correct length of the byte data. If the data size is invalid, - /// you will get the error `DeserializationError::InvalidDataSize`. - /// - /// Example (warning, no error checking!): - /// ```rust - /// use std::ffi::CString; - /// use std::net::Ipv4Addr; - /// use std::rc::Rc; - /// - /// use rustables::{Chain, expr::{Cmp, CmpOp}, ProtoFamily, Rule, Table}; - /// - /// let table = Rc::new(Table::new(&CString::new("mytable").unwrap(), ProtoFamily::Inet)); - /// let chain = Rc::new(Chain::new(&CString::new("mychain").unwrap(), table)); - /// let mut rule = Rule::new(chain); - /// rule.add_expr(&Cmp::new(CmpOp::Eq, 1337u16)); - /// for expr in Rc::new(rule).get_exprs() { - /// println!("{:?}", expr.decode_expr::<Cmp<[u8; 2]>>().unwrap()); - /// } - /// ``` - /// These limitations occur because casting bytes to any type of the same size - /// as the raw input would be *extremely* dangerous in terms of memory safety. - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { - unsafe { - let ref_len = std::mem::size_of::<[u8; N]>() as u32; - let mut data_len = 0; - let data = sys::nftnl_expr_get( - expr, - sys::NFTNL_EXPR_CMP_DATA as u16, - &mut data_len as *mut u32, - ); - - if data.is_null() { - return Err(DeserializationError::NullPointer); - } else if data_len != ref_len { - return Err(DeserializationError::InvalidDataSize); - } - - let data = *(data as *const [u8; N]); - - let op = CmpOp::from_raw(sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_CMP_OP as u16))?; - Ok(Cmp { op, data }) - } - } - - // call to the other implementation to generate the expression - fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr { + pub fn new(op: CmpOp, data: impl Into<Vec<u8>>) -> Self { Cmp { - data: &self.data as &[u8], - op: self.op, + sreg: Some(Register::Reg1), + op: Some(op), + data: Some(NfNetlinkData::default().with_value(data.into())), } - .to_expr(rule) } } -#[macro_export(local_inner_macros)] -macro_rules! nft_expr_cmp { - (@cmp_op ==) => { - $crate::expr::CmpOp::Eq - }; - (@cmp_op !=) => { - $crate::expr::CmpOp::Neq - }; - (@cmp_op <) => { - $crate::expr::CmpOp::Lt - }; - (@cmp_op <=) => { - $crate::expr::CmpOp::Lte - }; - (@cmp_op >) => { - $crate::expr::CmpOp::Gt - }; - (@cmp_op >=) => { - $crate::expr::CmpOp::Gte - }; - ($op:tt $data:expr) => { - $crate::expr::Cmp::new(nft_expr_cmp!(@cmp_op $op), $data) - }; +impl Expression for Cmp { + fn get_name() -> &'static str { + "cmp" + } } +/* /// Can be used to compare the value loaded by [`Meta::IifName`] and [`Meta::OifName`]. Please note /// that it is faster to check interface index than name. /// @@ -207,13 +84,4 @@ impl ToSlice for InterfaceName { Cow::from(bytes) } } - -impl<'a> ToSlice for &'a InterfaceName { - fn to_slice(&self) -> Cow<'_, [u8]> { - let bytes = match *self { - InterfaceName::Exact(ref name) => name.as_bytes_with_nul(), - InterfaceName::StartingWith(ref name) => name.as_bytes(), - }; - Cow::from(bytes) - } -} +*/ diff --git a/src/expr/counter.rs b/src/expr/counter.rs index 4732e85..d22fb8a 100644 --- a/src/expr/counter.rs +++ b/src/expr/counter.rs @@ -1,46 +1,21 @@ -use super::{DeserializationError, Expression, Rule}; +use rustables_macros::nfnetlink_struct; + +use super::Expression; use crate::sys; -use std::os::raw::c_char; /// A counter expression adds a counter to the rule that is incremented to count number of packets /// and number of bytes for all packets that have matched the rule. -#[derive(Debug, PartialEq)] +#[derive(Default, Clone, Debug, PartialEq, Eq)] +#[nfnetlink_struct] pub struct Counter { + #[field(sys::NFTA_COUNTER_BYTES)] pub nb_bytes: u64, + #[field(sys::NFTA_COUNTER_PACKETS)] pub nb_packets: u64, } -impl Counter { - pub fn new() -> Self { - Self { - nb_bytes: 0, - nb_packets: 0, - } - } -} - impl Expression for Counter { - fn get_raw_name() -> *const c_char { - b"counter\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { - unsafe { - let nb_bytes = sys::nftnl_expr_get_u64(expr, sys::NFTNL_EXPR_CTR_BYTES as u16); - let nb_packets = sys::nftnl_expr_get_u64(expr, sys::NFTNL_EXPR_CTR_PACKETS as u16); - Ok(Counter { - nb_bytes, - nb_packets, - }) - } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - sys::nftnl_expr_set_u64(expr, sys::NFTNL_EXPR_CTR_BYTES as u16, self.nb_bytes); - sys::nftnl_expr_set_u64(expr, sys::NFTNL_EXPR_CTR_PACKETS as u16, self.nb_packets); - expr - } + fn get_name() -> &'static str { + "counter" } } diff --git a/src/expr/ct.rs b/src/expr/ct.rs index 7d6614c..ad76989 100644 --- a/src/expr/ct.rs +++ b/src/expr/ct.rs @@ -1,9 +1,13 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys::{self, libc}; -use std::os::raw::c_char; +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; + +use crate::sys::{ + NFTA_CT_DIRECTION, NFTA_CT_DREG, NFTA_CT_KEY, NFTA_CT_SREG, NFT_CT_MARK, NFT_CT_STATE, +}; + +use super::{Expression, Register}; bitflags::bitflags! { - pub struct States: u32 { + pub struct ConnTrackState: u32 { const INVALID = 1; const ESTABLISHED = 2; const RELATED = 4; @@ -12,76 +16,54 @@ bitflags::bitflags! { } } -pub enum Conntrack { - State, - Mark { set: bool }, +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[nfnetlink_enum(u32, nested = true)] +pub enum ConntrackKey { + State = NFT_CT_STATE, + Mark = NFT_CT_MARK, } -impl Conntrack { - fn raw_key(&self) -> u32 { - match *self { - Conntrack::State => libc::NFT_CT_STATE as u32, - Conntrack::Mark { .. } => libc::NFT_CT_MARK as u32, - } - } +#[derive(Default, Clone, Debug, PartialEq, Eq)] +#[nfnetlink_struct(nested = true)] +pub struct Conntrack { + #[field(NFTA_CT_DREG)] + pub dreg: Register, + #[field(NFTA_CT_KEY)] + pub key: ConntrackKey, + #[field(NFTA_CT_DIRECTION)] + pub direction: u8, + #[field(NFTA_CT_SREG)] + pub sreg: Register, } impl Expression for Conntrack { - fn get_raw_name() -> *const c_char { - b"ct\0" as *const _ as *const c_char + fn get_name() -> &'static str { + "ct" } +} - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - unsafe { - let ct_key = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_CT_KEY as u16); - let ct_sreg_is_set = sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_CT_SREG as u16); - - match ct_key as i32 { - libc::NFT_CT_STATE => Ok(Conntrack::State), - libc::NFT_CT_MARK => Ok(Conntrack::Mark { - set: ct_sreg_is_set, - }), - _ => Err(DeserializationError::InvalidValue), - } - } +impl Conntrack { + pub fn new(key: ConntrackKey) -> Self { + Self::default().with_dreg(Register::Reg1).with_key(key) } - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); + pub fn set_mark_value(&mut self, reg: Register) { + self.set_sreg(reg); + self.set_key(ConntrackKey::Mark); + } - if let Conntrack::Mark { set: true } = self { - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_CT_SREG as u16, - libc::NFT_REG_1 as u32, - ); - } else { - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_CT_DREG as u16, - libc::NFT_REG_1 as u32, - ); - } - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_CT_KEY as u16, self.raw_key()); + pub fn with_mark_value(mut self, reg: Register) -> Self { + self.set_mark_value(reg); + self + } - expr - } + pub fn retrieve_value(&mut self, key: ConntrackKey) { + self.set_key(key); + self.set_dreg(Register::Reg1); } -} -#[macro_export] -macro_rules! nft_expr_ct { - (state) => { - $crate::expr::Conntrack::State - }; - (mark set) => { - $crate::expr::Conntrack::Mark { set: true } - }; - (mark) => { - $crate::expr::Conntrack::Mark { set: false } - }; + pub fn with_retrieve_value(mut self, key: ConntrackKey) -> Self { + self.retrieve_value(key); + self + } } diff --git a/src/expr/immediate.rs b/src/expr/immediate.rs index 71453b3..2fd9bd5 100644 --- a/src/expr/immediate.rs +++ b/src/expr/immediate.rs @@ -1,124 +1,50 @@ -use super::{DeserializationError, Expression, Register, Rule, ToSlice}; -use crate::sys; -use std::ffi::c_void; -use std::os::raw::c_char; - -/// An immediate expression. Used to set immediate data. Verdicts are handled separately by -/// [crate::expr::Verdict]. -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub struct Immediate<T> { - pub data: T, - pub register: Register, +use rustables_macros::nfnetlink_struct; + +use super::{Expression, Register, Verdict, VerdictKind, VerdictType}; +use crate::{ + parser_impls::NfNetlinkData, + sys::{NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG}, +}; + +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] +pub struct Immediate { + #[field(NFTA_IMMEDIATE_DREG)] + dreg: Register, + #[field(NFTA_IMMEDIATE_DATA)] + data: NfNetlinkData, } -impl<T> Immediate<T> { - pub fn new(data: T, register: Register) -> Self { - Self { data, register } +impl Immediate { + pub fn new_data(data: Vec<u8>, register: Register) -> Self { + Immediate::default() + .with_dreg(register) + .with_data(NfNetlinkData::default().with_value(data)) } -} - -impl<T: ToSlice> Expression for Immediate<T> { - fn get_raw_name() -> *const c_char { - b"immediate\0" as *const _ as *const c_char - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_IMM_DREG as u16, - self.register.to_raw(), - ); - - let data = self.data.to_slice(); - sys::nftnl_expr_set( - expr, - sys::NFTNL_EXPR_IMM_DATA as u16, - data.as_ptr() as *const c_void, - data.len() as u32, - ); - - expr + pub fn new_verdict(kind: VerdictKind) -> Self { + let code = match kind { + VerdictKind::Drop => VerdictType::Drop, + VerdictKind::Accept => VerdictType::Accept, + VerdictKind::Queue => VerdictType::Queue, + VerdictKind::Continue => VerdictType::Continue, + VerdictKind::Break => VerdictType::Break, + VerdictKind::Jump { .. } => VerdictType::Jump, + VerdictKind::Goto { .. } => VerdictType::Goto, + VerdictKind::Return => VerdictType::Return, + }; + let mut data = Verdict::default().with_code(code); + if let VerdictKind::Jump { chain } | VerdictKind::Goto { chain } = kind { + data.set_chain(chain); } + Immediate::default() + .with_dreg(Register::Verdict) + .with_data(NfNetlinkData::default().with_verdict(data)) } } -impl<const N: usize> Expression for Immediate<[u8; N]> { - fn get_raw_name() -> *const c_char { - Immediate::<u8>::get_raw_name() - } - - /// The raw data contained inside `Immediate` expressions can only be deserialized to arrays of - /// bytes, to ensure that the memory layout of retrieved data cannot be violated. It is your - /// responsibility to provide the correct length of the byte data. If the data size is invalid, - /// you will get the error `DeserializationError::InvalidDataSize`. - /// - /// Example (warning, no error checking!): - /// ```rust - /// use std::ffi::CString; - /// use std::net::Ipv4Addr; - /// use std::rc::Rc; - /// - /// use rustables::{Chain, expr::{Immediate, Register}, ProtoFamily, Rule, Table}; - /// - /// let table = Rc::new(Table::new(&CString::new("mytable").unwrap(), ProtoFamily::Inet)); - /// let chain = Rc::new(Chain::new(&CString::new("mychain").unwrap(), table)); - /// let mut rule = Rule::new(chain); - /// rule.add_expr(&Immediate::new(42u8, Register::Reg1)); - /// for expr in Rc::new(rule).get_exprs() { - /// println!("{:?}", expr.decode_expr::<Immediate<[u8; 1]>>().unwrap()); - /// } - /// ``` - /// These limitations occur because casting bytes to any type of the same size as the raw input - /// would be *extremely* dangerous in terms of memory safety. - // As casting bytes to any type of the same size as the input would be *extremely* dangerous in - // terms of memory safety, rustables only accept to deserialize expressions with variable-size - // data to arrays of bytes, so that the memory layout cannot be invalid. - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { - unsafe { - let ref_len = std::mem::size_of::<[u8; N]>() as u32; - let mut data_len = 0; - let data = sys::nftnl_expr_get( - expr, - sys::NFTNL_EXPR_IMM_DATA as u16, - &mut data_len as *mut u32, - ); - - if data.is_null() { - return Err(DeserializationError::NullPointer); - } else if data_len != ref_len { - return Err(DeserializationError::InvalidDataSize); - } - - let data = *(data as *const [u8; N]); - - let register = Register::from_raw(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_IMM_DREG as u16, - ))?; - - Ok(Immediate { data, register }) - } - } - - // call to the other implementation to generate the expression - fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr { - Immediate { - register: self.register, - data: &self.data as &[u8], - } - .to_expr(rule) +impl Expression for Immediate { + fn get_name() -> &'static str { + "immediate" } } - -#[macro_export] -macro_rules! nft_expr_immediate { - (data $value:expr) => { - $crate::expr::Immediate { - data: $value, - register: $crate::expr::Register::Reg1, - } - }; -} diff --git a/src/expr/log.rs b/src/expr/log.rs index 8d20b48..cc2728e 100644 --- a/src/expr/log.rs +++ b/src/expr/log.rs @@ -1,112 +1,41 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys; -use std::ffi::{CStr, CString}; -use std::os::raw::c_char; -use thiserror::Error; +use rustables_macros::nfnetlink_struct; +use super::Expression; +use crate::{ + error::BuilderError, + sys::{NFTA_LOG_GROUP, NFTA_LOG_PREFIX}, +}; + +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] /// A Log expression will log all packets that match the rule. -#[derive(Debug, PartialEq)] pub struct Log { - pub group: Option<LogGroup>, - pub prefix: Option<LogPrefix>, + #[field(NFTA_LOG_GROUP)] + group: u16, + #[field(NFTA_LOG_PREFIX)] + prefix: String, } -impl Expression for Log { - fn get_raw_name() -> *const sys::libc::c_char { - b"log\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - unsafe { - let mut group = None; - if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_LOG_GROUP as u16) { - group = Some(LogGroup(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_LOG_GROUP as u16, - ) as u16)); - } - let mut prefix = None; - if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16) { - let raw_prefix = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16); - if raw_prefix.is_null() { - return Err(DeserializationError::NullPointer); - } else { - prefix = Some(LogPrefix(CStr::from_ptr(raw_prefix).to_owned())); - } - } - Ok(Log { group, prefix }) +impl Log { + pub fn new(group: Option<u16>, prefix: Option<impl Into<String>>) -> Result<Log, BuilderError> { + let mut res = Log::default(); + if let Some(group) = group { + res.set_group(group); } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(b"log\0" as *const _ as *const c_char)); - if let Some(log_group) = self.group { - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOG_GROUP as u16, log_group.0 as u32); - }; - if let Some(LogPrefix(prefix)) = &self.prefix { - sys::nftnl_expr_set_str(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16, prefix.as_ptr()); - }; + if let Some(prefix) = prefix { + let prefix = prefix.into(); - expr + if prefix.bytes().count() > 127 { + return Err(BuilderError::TooLongLogPrefix); + } + res.set_prefix(prefix); } + Ok(res) } } -#[derive(Error, Debug)] -pub enum LogPrefixError { - #[error("The log prefix string is more than 128 characters long")] - TooLongPrefix, - #[error("The log prefix string contains an invalid Nul character.")] - PrefixContainsANul(#[from] std::ffi::NulError), -} - -/// The NFLOG group that will be assigned to each log line. -#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] -pub struct LogGroup(pub u16); - -/// A prefix that will get prepended to each log line. -#[derive(Debug, Clone, PartialEq)] -pub struct LogPrefix(CString); - -impl LogPrefix { - /// Creates a new LogPrefix from a String. Converts it to CString as needed by nftnl. Note that - /// LogPrefix should not be more than 127 characters long. - pub fn new(prefix: &str) -> Result<Self, LogPrefixError> { - if prefix.chars().count() > 127 { - return Err(LogPrefixError::TooLongPrefix); - } - Ok(LogPrefix(CString::new(prefix)?)) +impl Expression for Log { + fn get_name() -> &'static str { + "log" } } - -#[macro_export] -macro_rules! nft_expr_log { - (group $group:ident prefix $prefix:expr) => { - $crate::expr::Log { - group: $group, - prefix: $prefix, - } - }; - (prefix $prefix:expr) => { - $crate::expr::Log { - group: None, - prefix: $prefix, - } - }; - (group $group:ident) => { - $crate::expr::Log { - group: $group, - prefix: None, - } - }; - () => { - $crate::expr::Log { - group: None, - prefix: None, - } - }; -} diff --git a/src/expr/lookup.rs b/src/expr/lookup.rs index a0cc021..2ef830e 100644 --- a/src/expr/lookup.rs +++ b/src/expr/lookup.rs @@ -1,78 +1,40 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::set::Set; -use crate::sys::{self, libc}; -use std::ffi::{CStr, CString}; -use std::os::raw::c_char; +use rustables_macros::nfnetlink_struct; -#[derive(Debug, PartialEq)] +use super::{Expression, Register}; +use crate::error::BuilderError; +use crate::sys::{NFTA_LOOKUP_DREG, NFTA_LOOKUP_SET, NFTA_LOOKUP_SET_ID, NFTA_LOOKUP_SREG}; +use crate::Set; + +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] pub struct Lookup { - set_name: CString, + #[field(NFTA_LOOKUP_SET)] + set: String, + #[field(NFTA_LOOKUP_SREG)] + sreg: Register, + #[field(NFTA_LOOKUP_DREG)] + dreg: Register, + #[field(NFTA_LOOKUP_SET_ID)] set_id: u32, } impl Lookup { - /// Creates a new lookup entry. May return None if the set has no name. - pub fn new<K>(set: &Set<K>) -> Option<Self> { - set.get_name().map(|set_name| Lookup { - set_name: set_name.to_owned(), - set_id: set.get_id(), - }) - } -} - -impl Expression for Lookup { - fn get_raw_name() -> *const libc::c_char { - b"lookup\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - unsafe { - let set_name = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_LOOKUP_SET as u16); - let set_id = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_LOOKUP_SET_ID as u16); - - if set_name.is_null() { - return Err(DeserializationError::NullPointer); - } - - let set_name = CStr::from_ptr(set_name).to_owned(); - - Ok(Lookup { set_id, set_name }) + /// Creates a new lookup entry. May return BuilderError::MissingSetName if the set has no name. + pub fn new(set: &Set) -> Result<Self, BuilderError> { + let mut res = Lookup::default() + .with_set(set.get_name().ok_or(BuilderError::MissingSetName)?) + .with_sreg(Register::Reg1); + + if let Some(id) = set.get_id() { + res.set_set_id(*id); } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_LOOKUP_SREG as u16, - libc::NFT_REG_1 as u32, - ); - sys::nftnl_expr_set_str( - expr, - sys::NFTNL_EXPR_LOOKUP_SET as u16, - self.set_name.as_ptr() as *const _ as *const c_char, - ); - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOOKUP_SET_ID as u16, self.set_id); - // This code is left here since it's quite likely we need it again when we get further - // if self.reverse { - // sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOOKUP_FLAGS as u16, - // libc::NFT_LOOKUP_F_INV as u32); - // } - - expr - } + Ok(res) } } -#[macro_export] -macro_rules! nft_expr_lookup { - ($set:expr) => { - $crate::expr::Lookup::new($set) - }; +impl Expression for Lookup { + fn get_name() -> &'static str { + "lookup" + } } diff --git a/src/expr/masquerade.rs b/src/expr/masquerade.rs index c1a06de..dce787f 100644 --- a/src/expr/masquerade.rs +++ b/src/expr/masquerade.rs @@ -1,24 +1,20 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys; -use std::os::raw::c_char; +use rustables_macros::nfnetlink_struct; + +use super::Expression; /// Sets the source IP to that of the output interface. -#[derive(Debug, PartialEq)] +#[derive(Default, Debug, PartialEq, Eq)] +#[nfnetlink_struct(nested = true)] pub struct Masquerade; -impl Expression for Masquerade { - fn get_raw_name() -> *const sys::libc::c_char { - b"masq\0" as *const _ as *const c_char - } - - fn from_expr(_expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - Ok(Masquerade) +impl Clone for Masquerade { + fn clone(&self) -> Self { + Masquerade {} } +} - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - try_alloc!(unsafe { sys::nftnl_expr_alloc(Self::get_raw_name()) }) +impl Expression for Masquerade { + fn get_name() -> &'static str { + "masq" } } diff --git a/src/expr/meta.rs b/src/expr/meta.rs index a015f65..3ecb1d1 100644 --- a/src/expr/meta.rs +++ b/src/expr/meta.rs @@ -1,175 +1,62 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys::{self, libc}; -use std::os::raw::c_char; +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; + +use super::{Expression, Register}; +use crate::sys; /// A meta expression refers to meta data associated with a packet. #[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[nfnetlink_enum(u32)] #[non_exhaustive] -pub enum Meta { +pub enum MetaType { /// Packet ethertype protocol (skb->protocol), invalid in OUTPUT. - Protocol, + Protocol = sys::NFT_META_PROTOCOL, /// Packet mark. - Mark { set: bool }, + Mark = sys::NFT_META_MARK, /// Packet input interface index (dev->ifindex). - Iif, + Iif = sys::NFT_META_IIF, /// Packet output interface index (dev->ifindex). - Oif, + Oif = sys::NFT_META_OIF, /// Packet input interface name (dev->name). - IifName, + IifName = sys::NFT_META_IIFNAME, /// Packet output interface name (dev->name). - OifName, + OifName = sys::NFT_META_OIFNAME, /// Packet input interface type (dev->type). - IifType, + IifType = libc::NFT_META_IIFTYPE, /// Packet output interface type (dev->type). - OifType, + OifType = sys::NFT_META_OIFTYPE, /// Originating socket UID (fsuid). - SkUid, + SkUid = sys::NFT_META_SKUID, /// Originating socket GID (fsgid). - SkGid, + SkGid = sys::NFT_META_SKGID, /// Netfilter protocol (Transport layer protocol). - NfProto, + NfProto = sys::NFT_META_NFPROTO, /// Layer 4 protocol number. - L4Proto, + L4Proto = sys::NFT_META_L4PROTO, /// Socket control group (skb->sk->sk_classid). - Cgroup, + Cgroup = sys::NFT_META_CGROUP, /// A 32bit pseudo-random number. - PRandom, + PRandom = sys::NFT_META_PRANDOM, } -impl Meta { - /// Returns the corresponding `NFT_*` constant for this meta expression. - pub fn to_raw_key(&self) -> u32 { - use Meta::*; - match *self { - Protocol => libc::NFT_META_PROTOCOL as u32, - Mark { .. } => libc::NFT_META_MARK as u32, - Iif => libc::NFT_META_IIF as u32, - Oif => libc::NFT_META_OIF as u32, - IifName => libc::NFT_META_IIFNAME as u32, - OifName => libc::NFT_META_OIFNAME as u32, - IifType => libc::NFT_META_IIFTYPE as u32, - OifType => libc::NFT_META_OIFTYPE as u32, - SkUid => libc::NFT_META_SKUID as u32, - SkGid => libc::NFT_META_SKGID as u32, - NfProto => libc::NFT_META_NFPROTO as u32, - L4Proto => libc::NFT_META_L4PROTO as u32, - Cgroup => libc::NFT_META_CGROUP as u32, - PRandom => libc::NFT_META_PRANDOM as u32, - } - } +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] +pub struct Meta { + #[field(sys::NFTA_META_DREG)] + dreg: Register, + #[field(sys::NFTA_META_KEY)] + key: MetaType, + #[field(sys::NFTA_META_SREG)] + sreg: Register, +} - fn from_raw(val: u32) -> Result<Self, DeserializationError> { - match val as i32 { - libc::NFT_META_PROTOCOL => Ok(Self::Protocol), - libc::NFT_META_MARK => Ok(Self::Mark { set: false }), - libc::NFT_META_IIF => Ok(Self::Iif), - libc::NFT_META_OIF => Ok(Self::Oif), - libc::NFT_META_IIFNAME => Ok(Self::IifName), - libc::NFT_META_OIFNAME => Ok(Self::OifName), - libc::NFT_META_IIFTYPE => Ok(Self::IifType), - libc::NFT_META_OIFTYPE => Ok(Self::OifType), - libc::NFT_META_SKUID => Ok(Self::SkUid), - libc::NFT_META_SKGID => Ok(Self::SkGid), - libc::NFT_META_NFPROTO => Ok(Self::NfProto), - libc::NFT_META_L4PROTO => Ok(Self::L4Proto), - libc::NFT_META_CGROUP => Ok(Self::Cgroup), - libc::NFT_META_PRANDOM => Ok(Self::PRandom), - _ => Err(DeserializationError::InvalidValue), - } +impl Meta { + pub fn new(ty: MetaType) -> Self { + Meta::default().with_dreg(Register::Reg1).with_key(ty) } } impl Expression for Meta { - fn get_raw_name() -> *const libc::c_char { - b"meta\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - unsafe { - let mut ret = Self::from_raw(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_META_KEY as u16, - ))?; - - if let Self::Mark { ref mut set } = ret { - *set = sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_META_SREG as u16); - } - - Ok(ret) - } + fn get_name() -> &'static str { + "meta" } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - if let Meta::Mark { set: true } = self { - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_META_SREG as u16, - libc::NFT_REG_1 as u32, - ); - } else { - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_META_DREG as u16, - libc::NFT_REG_1 as u32, - ); - } - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_META_KEY as u16, self.to_raw_key()); - expr - } - } -} - -#[macro_export] -macro_rules! nft_expr_meta { - (proto) => { - $crate::expr::Meta::Protocol - }; - (mark set) => { - $crate::expr::Meta::Mark { set: true } - }; - (mark) => { - $crate::expr::Meta::Mark { set: false } - }; - (iif) => { - $crate::expr::Meta::Iif - }; - (oif) => { - $crate::expr::Meta::Oif - }; - (iifname) => { - $crate::expr::Meta::IifName - }; - (oifname) => { - $crate::expr::Meta::OifName - }; - (iiftype) => { - $crate::expr::Meta::IifType - }; - (oiftype) => { - $crate::expr::Meta::OifType - }; - (skuid) => { - $crate::expr::Meta::SkUid - }; - (skgid) => { - $crate::expr::Meta::SkGid - }; - (nfproto) => { - $crate::expr::Meta::NfProto - }; - (l4proto) => { - $crate::expr::Meta::L4Proto - }; - (cgroup) => { - $crate::expr::Meta::Cgroup - }; - (random) => { - $crate::expr::Meta::PRandom - }; } diff --git a/src/expr/mod.rs b/src/expr/mod.rs index dc59507..058b0cb 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -3,14 +3,14 @@ //! //! [`Rule`]: struct.Rule.html -use std::borrow::Cow; -use std::net::IpAddr; -use std::net::Ipv4Addr; -use std::net::Ipv6Addr; +use std::fmt::Debug; -use super::rule::Rule; -use crate::sys::{self, libc}; -use thiserror::Error; +use rustables_macros::nfnetlink_struct; + +use crate::error::DecodeError; +use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable}; +use crate::parser_impls::NfNetlinkList; +use crate::sys::{self, NFTA_EXPR_DATA, NFTA_EXPR_NAME}; mod bitwise; pub use self::bitwise::*; @@ -46,7 +46,7 @@ mod payload; pub use self::payload::*; mod reject; -pub use self::reject::{IcmpCode, Reject}; +pub use self::reject::{IcmpCode, Reject, RejectType}; mod register; pub use self::register::Register; @@ -54,189 +54,161 @@ pub use self::register::Register; mod verdict; pub use self::verdict::*; -mod wrapper; -pub use self::wrapper::ExpressionWrapper; - -#[derive(Debug, Error)] -pub enum DeserializationError { - #[error("The expected expression type doesn't match the name of the raw expression")] - /// The expected expression type doesn't match the name of the raw expression. - InvalidExpressionKind, - - #[error("Deserializing the requested type isn't implemented yet")] - /// Deserializing the requested type isn't implemented yet. - NotImplemented, - - #[error("The expression value cannot be deserialized to the requested type")] - /// The expression value cannot be deserialized to the requested type. - InvalidValue, - - #[error("A pointer was null while a non-null pointer was expected")] - /// A pointer was null while a non-null pointer was expected. - NullPointer, - - #[error( - "The size of a raw value was incoherent with the expected type of the deserialized value" - )] - /// The size of a raw value was incoherent with the expected type of the deserialized value/ - InvalidDataSize, - - #[error(transparent)] - /// Couldn't find a matching protocol. - InvalidProtolFamily(#[from] super::InvalidProtocolFamily), -} - -/// Trait for every safe wrapper of an nftables expression. pub trait Expression { - /// Returns the raw name used by nftables to identify the rule. - fn get_raw_name() -> *const libc::c_char; - - /// Try to parse the expression from a raw nftables expression, returning a - /// [DeserializationError] if the attempted parsing failed. - fn from_expr(_expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - Err(DeserializationError::NotImplemented) - } - - /// Allocates and returns the low level `nftnl_expr` representation of this expression. The - /// caller to this method is responsible for freeing the expression. - fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr; + fn get_name() -> &'static str; } -/// A type that can be converted into a byte buffer. -pub trait ToSlice { - /// Returns the data this type represents. - fn to_slice(&self) -> Cow<'_, [u8]>; +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(nested = true, derive_decoder = false)] +pub struct RawExpression { + #[field(NFTA_EXPR_NAME)] + name: String, + #[field(NFTA_EXPR_DATA)] + data: ExpressionVariant, } -impl<'a> ToSlice for &'a [u8] { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Borrowed(self) +impl<T> From<T> for RawExpression +where + T: Expression, + ExpressionVariant: From<T>, +{ + fn from(val: T) -> Self { + RawExpression::default() + .with_name(T::get_name()) + .with_data(ExpressionVariant::from(val)) } } -impl<'a> ToSlice for &'a [u16] { - fn to_slice(&self) -> Cow<'_, [u8]> { - let ptr = self.as_ptr() as *const u8; - let len = self.len() * 2; - Cow::Borrowed(unsafe { std::slice::from_raw_parts(ptr, len) }) - } -} - -impl ToSlice for IpAddr { - fn to_slice(&self) -> Cow<'_, [u8]> { - match *self { - IpAddr::V4(ref addr) => addr.to_slice(), - IpAddr::V6(ref addr) => addr.to_slice(), +macro_rules! create_expr_variant { + ($enum:ident $(, [$name:ident, $type:ty])+) => { + #[derive(Debug, Clone, PartialEq, Eq)] + pub enum $enum { + $( + $name($type), + )+ } - } -} -impl ToSlice for Ipv4Addr { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Owned(self.octets().to_vec()) - } -} - -impl ToSlice for Ipv6Addr { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Owned(self.octets().to_vec()) - } -} + impl $crate::nlmsg::NfNetlinkAttribute for $enum { + fn is_nested(&self) -> bool { + true + } + + fn get_size(&self) -> usize { + match self { + $( + $enum::$name(val) => val.get_size(), + )+ + } + } + + unsafe fn write_payload(&self, addr: *mut u8) { + match self { + $( + $enum::$name(val) => val.write_payload(addr), + )+ + } + } + } -impl ToSlice for u8 { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Owned(vec![*self]) - } + $( + impl From<$type> for $enum { + fn from(val: $type) -> Self { + $enum::$name(val) + } + } + )+ + + impl $crate::nlmsg::AttributeDecoder for RawExpression { + fn decode_attribute( + &mut self, + attr_type: u16, + buf: &[u8], + ) -> Result<(), $crate::error::DecodeError> { + debug!("Decoding attribute {} in an expression", attr_type); + match attr_type { + x if x == sys::NFTA_EXPR_NAME => { + debug!("Calling {}::deserialize()", std::any::type_name::<String>()); + let (val, remaining) = String::deserialize(buf)?; + if remaining.len() != 0 { + return Err($crate::error::DecodeError::InvalidDataSize); + } + self.name = Some(val); + Ok(()) + }, + x if x == sys::NFTA_EXPR_DATA => { + // we can assume we have already the name parsed, as that's how we identify the + // type of expression + let name = self.name.as_ref() + .ok_or($crate::error::DecodeError::MissingExpressionName)?; + match name { + $( + x if x == <$type>::get_name() => { + debug!("Calling {}::deserialize()", std::any::type_name::<$type>()); + let (res, remaining) = <$type>::deserialize(buf)?; + if remaining.len() != 0 { + return Err($crate::error::DecodeError::InvalidDataSize); + } + self.data = Some(ExpressionVariant::from(res)); + Ok(()) + }, + )+ + name => { + info!("Unrecognized expression '{}', generating an ExpressionRaw", name); + self.data = Some(ExpressionVariant::ExpressionRaw(ExpressionRaw::deserialize(buf)?.0)); + Ok(()) + } + } + }, + _ => Err(DecodeError::UnsupportedAttributeType(attr_type)), + } + } + } + }; } -impl ToSlice for u16 { - fn to_slice(&self) -> Cow<'_, [u8]> { - let b0 = (*self & 0x00ff) as u8; - let b1 = (*self >> 8) as u8; - Cow::Owned(vec![b0, b1]) +create_expr_variant!( + ExpressionVariant, + [Bitwise, Bitwise], + [Cmp, Cmp], + [Conntrack, Conntrack], + [Counter, Counter], + [ExpressionRaw, ExpressionRaw], + [Immediate, Immediate], + [Log, Log], + [Lookup, Lookup], + [Masquerade, Masquerade], + [Meta, Meta], + [Nat, Nat], + [Payload, Payload], + [Reject, Reject] +); + +pub type ExpressionList = NfNetlinkList<RawExpression>; + +// default type for expressions that we do not handle yet +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ExpressionRaw(Vec<u8>); + +impl NfNetlinkAttribute for ExpressionRaw { + fn get_size(&self) -> usize { + self.0.get_size() } -} -impl ToSlice for u32 { - fn to_slice(&self) -> Cow<'_, [u8]> { - let b0 = *self as u8; - let b1 = (*self >> 8) as u8; - let b2 = (*self >> 16) as u8; - let b3 = (*self >> 24) as u8; - Cow::Owned(vec![b0, b1, b2, b3]) + unsafe fn write_payload(&self, addr: *mut u8) { + self.0.write_payload(addr); } } -impl ToSlice for i32 { - fn to_slice(&self) -> Cow<'_, [u8]> { - let b0 = *self as u8; - let b1 = (*self >> 8) as u8; - let b2 = (*self >> 16) as u8; - let b3 = (*self >> 24) as u8; - Cow::Owned(vec![b0, b1, b2, b3]) +impl NfNetlinkDeserializable for ExpressionRaw { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok((ExpressionRaw(buf.to_vec()), &[])) } } -impl<'a> ToSlice for &'a str { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::from(self.as_bytes()) +// Because we loose the name of the expression when parsing, this is the only expression +// where deserializing a message and then reserializing it is invalid +impl Expression for ExpressionRaw { + fn get_name() -> &'static str { + "unknown_expression" } } - -#[macro_export(local_inner_macros)] -macro_rules! nft_expr { - (bitwise mask $mask:expr,xor $xor:expr) => { - nft_expr_bitwise!(mask $mask, xor $xor) - }; - (cmp $op:tt $data:expr) => { - nft_expr_cmp!($op $data) - }; - (counter) => { - $crate::expr::Counter { nb_bytes: 0, nb_packets: 0} - }; - (ct $key:ident set) => { - nft_expr_ct!($key set) - }; - (ct $key:ident) => { - nft_expr_ct!($key) - }; - (immediate $expr:ident $value:expr) => { - nft_expr_immediate!($expr $value) - }; - (log group $group:ident prefix $prefix:expr) => { - nft_expr_log!(group $group prefix $prefix) - }; - (log group $group:ident) => { - nft_expr_log!(group $group) - }; - (log prefix $prefix:expr) => { - nft_expr_log!(prefix $prefix) - }; - (log) => { - nft_expr_log!() - }; - (lookup $set:expr) => { - nft_expr_lookup!($set) - }; - (masquerade) => { - $crate::expr::Masquerade - }; - (meta $expr:ident set) => { - nft_expr_meta!($expr set) - }; - (meta $expr:ident) => { - nft_expr_meta!($expr) - }; - (payload $proto:ident $field:ident) => { - nft_expr_payload!($proto $field) - }; - (verdict $verdict:ident) => { - nft_expr_verdict!($verdict) - }; - (verdict $verdict:ident $chain:expr) => { - nft_expr_verdict!($verdict $chain) - }; -} diff --git a/src/expr/nat.rs b/src/expr/nat.rs index ce6b881..406b2e6 100644 --- a/src/expr/nat.rs +++ b/src/expr/nat.rs @@ -1,99 +1,37 @@ -use super::{DeserializationError, Expression, Register, Rule}; -use crate::ProtoFamily; -use crate::sys::{self, libc}; -use std::{convert::TryFrom, os::raw::c_char}; +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; + +use super::{Expression, Register}; +use crate::{ + sys::{self, NFT_NAT_DNAT, NFT_NAT_SNAT}, + ProtocolFamily, +}; #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[repr(i32)] +#[nfnetlink_enum(i32)] pub enum NatType { /// Source NAT. Changes the source address of a packet. - SNat = libc::NFT_NAT_SNAT, + SNat = NFT_NAT_SNAT, /// Destination NAT. Changes the destination address of a packet. - DNat = libc::NFT_NAT_DNAT, -} - -impl NatType { - fn from_raw(val: u32) -> Result<Self, DeserializationError> { - match val as i32 { - libc::NFT_NAT_SNAT => Ok(NatType::SNat), - libc::NFT_NAT_DNAT => Ok(NatType::DNat), - _ => Err(DeserializationError::InvalidValue), - } - } + DNat = NFT_NAT_DNAT, } /// A source or destination NAT statement. Modifies the source or destination address (and possibly /// port) of packets. -#[derive(Debug, PartialEq)] +#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[nfnetlink_struct(nested = true)] pub struct Nat { + #[field(sys::NFTA_NAT_TYPE)] pub nat_type: NatType, - pub family: ProtoFamily, + #[field(sys::NFTA_NAT_FAMILY)] + pub family: ProtocolFamily, + #[field(sys::NFTA_NAT_REG_ADDR_MIN)] pub ip_register: Register, - pub port_register: Option<Register>, + #[field(sys::NFTA_NAT_REG_PROTO_MIN)] + pub port_register: Register, } impl Expression for Nat { - fn get_raw_name() -> *const libc::c_char { - b"nat\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - unsafe { - let nat_type = NatType::from_raw(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_NAT_TYPE as u16, - ))?; - - let family = ProtoFamily::try_from(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_NAT_FAMILY as u16, - ) as i32)?; - - let ip_register = Register::from_raw(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_NAT_REG_ADDR_MIN as u16, - ))?; - - let mut port_register = None; - if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_NAT_REG_PROTO_MIN as u16) { - port_register = Some(Register::from_raw(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_NAT_REG_PROTO_MIN as u16, - ))?); - } - - Ok(Nat { - ip_register, - nat_type, - family, - port_register, - }) - } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - let expr = try_alloc!(unsafe { sys::nftnl_expr_alloc(Self::get_raw_name()) }); - - unsafe { - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_NAT_TYPE as u16, self.nat_type as u32); - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_NAT_FAMILY as u16, self.family as u32); - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_NAT_REG_ADDR_MIN as u16, - self.ip_register.to_raw(), - ); - if let Some(port_register) = self.port_register { - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_NAT_REG_PROTO_MIN as u16, - port_register.to_raw(), - ); - } - } - - expr + fn get_name() -> &'static str { + "nat" } } diff --git a/src/expr/payload.rs b/src/expr/payload.rs index a108fe8..d0b2cea 100644 --- a/src/expr/payload.rs +++ b/src/expr/payload.rs @@ -1,128 +1,96 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys::{self, libc}; -use std::os::raw::c_char; +use rustables_macros::nfnetlink_struct; -pub trait HeaderField { - fn offset(&self) -> u32; - fn len(&self) -> u32; +use super::{Expression, Register}; +use crate::{ + error::DecodeError, + sys::{self, NFT_PAYLOAD_LL_HEADER, NFT_PAYLOAD_NETWORK_HEADER, NFT_PAYLOAD_TRANSPORT_HEADER}, +}; + +/// Payload expressions refer to data from the packet's payload. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +#[nfnetlink_struct(nested = true)] +pub struct Payload { + #[field(sys::NFTA_PAYLOAD_DREG)] + dreg: Register, + #[field(sys::NFTA_PAYLOAD_BASE)] + base: u32, + #[field(sys::NFTA_PAYLOAD_OFFSET)] + offset: u32, + #[field(sys::NFTA_PAYLOAD_LEN)] + len: u32, + #[field(sys::NFTA_PAYLOAD_SREG)] + sreg: Register, +} + +impl Expression for Payload { + fn get_name() -> &'static str { + "payload" + } } /// Payload expressions refer to data from the packet's payload. #[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub enum Payload { +pub enum HighLevelPayload { LinkLayer(LLHeaderField), Network(NetworkHeaderField), Transport(TransportHeaderField), } -impl Payload { - pub fn build(&self) -> RawPayload { +impl HighLevelPayload { + pub fn build(&self) -> Payload { match *self { - Payload::LinkLayer(ref f) => RawPayload::LinkLayer(RawPayloadData { - offset: f.offset(), - len: f.len(), - }), - Payload::Network(ref f) => RawPayload::Network(RawPayloadData { - offset: f.offset(), - len: f.len(), - }), - Payload::Transport(ref f) => RawPayload::Transport(RawPayloadData { - offset: f.offset(), - len: f.len(), - }), + HighLevelPayload::LinkLayer(ref f) => Payload::default() + .with_base(NFT_PAYLOAD_LL_HEADER) + .with_offset(f.offset()) + .with_len(f.len()), + HighLevelPayload::Network(ref f) => Payload::default() + .with_base(NFT_PAYLOAD_NETWORK_HEADER) + .with_offset(f.offset()) + .with_len(f.len()), + HighLevelPayload::Transport(ref f) => Payload::default() + .with_base(NFT_PAYLOAD_TRANSPORT_HEADER) + .with_offset(f.offset()) + .with_len(f.len()), } + .with_dreg(Register::Reg1) } } -impl Expression for Payload { - fn get_raw_name() -> *const libc::c_char { - RawPayload::get_raw_name() - } - - fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr { - self.build().to_expr(rule) - } -} - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct RawPayloadData { - offset: u32, - len: u32, -} - -/// Because deserializing a `Payload` expression is not possible (there is not enough information -/// in the expression itself), this enum should be used to deserialize payloads. +/// Payload expressions refer to data from the packet's payload. #[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub enum RawPayload { - LinkLayer(RawPayloadData), - Network(RawPayloadData), - Transport(RawPayloadData), +pub enum PayloadType { + LinkLayer(LLHeaderField), + Network, + Transport, } -impl RawPayload { - fn base(&self) -> u32 { - match self { - Self::LinkLayer(_) => libc::NFT_PAYLOAD_LL_HEADER as u32, - Self::Network(_) => libc::NFT_PAYLOAD_NETWORK_HEADER as u32, - Self::Transport(_) => libc::NFT_PAYLOAD_TRANSPORT_HEADER as u32, +impl PayloadType { + pub fn parse_from_payload(raw: &Payload) -> Result<Self, DecodeError> { + if raw.base.is_none() { + return Err(DecodeError::PayloadMissingBase); } - } -} - -impl HeaderField for RawPayload { - fn offset(&self) -> u32 { - match self { - Self::LinkLayer(ref f) | Self::Network(ref f) | Self::Transport(ref f) => f.offset, + if raw.len.is_none() { + return Err(DecodeError::PayloadMissingLen); } - } - - fn len(&self) -> u32 { - match self { - Self::LinkLayer(ref f) | Self::Network(ref f) | Self::Transport(ref f) => f.len, + if raw.offset.is_none() { + return Err(DecodeError::PayloadMissingOffset); } + Ok(match raw.base { + Some(NFT_PAYLOAD_LL_HEADER) => PayloadType::LinkLayer(LLHeaderField::from_raw_data( + raw.offset.unwrap(), + raw.len.unwrap(), + )?), + Some(NFT_PAYLOAD_NETWORK_HEADER) => PayloadType::Network, + Some(NFT_PAYLOAD_TRANSPORT_HEADER) => PayloadType::Transport, + Some(v) => return Err(DecodeError::UnknownPayloadType(v)), + None => return Err(DecodeError::PayloadMissingBase), + }) } } -impl Expression for RawPayload { - fn get_raw_name() -> *const libc::c_char { - b"payload\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { - unsafe { - let base = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_BASE as u16); - let offset = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_OFFSET as u16); - let len = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_LEN as u16); - match base as i32 { - libc::NFT_PAYLOAD_LL_HEADER => Ok(Self::LinkLayer(RawPayloadData { offset, len })), - libc::NFT_PAYLOAD_NETWORK_HEADER => { - Ok(Self::Network(RawPayloadData { offset, len })) - } - libc::NFT_PAYLOAD_TRANSPORT_HEADER => { - Ok(Self::Transport(RawPayloadData { offset, len })) - } - - _ => return Err(DeserializationError::InvalidValue), - } - } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_PAYLOAD_BASE as u16, self.base()); - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_PAYLOAD_OFFSET as u16, self.offset()); - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_PAYLOAD_LEN as u16, self.len()); - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_PAYLOAD_DREG as u16, - libc::NFT_REG_1 as u32, - ); - - expr - } - } +pub trait HeaderField { + fn offset(&self) -> u32; + fn len(&self) -> u32; } #[derive(Debug, Copy, Clone, Eq, PartialEq)] @@ -154,58 +122,52 @@ impl HeaderField for LLHeaderField { } impl LLHeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { - let off = data.offset; - let len = data.len; - - if off == 0 && len == 6 { - Ok(Self::Daddr) - } else if off == 6 && len == 6 { - Ok(Self::Saddr) - } else if off == 12 && len == 2 { - Ok(Self::EtherType) - } else { - Err(DeserializationError::InvalidValue) - } + pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> { + Ok(match (offset, len) { + (0, 6) => Self::Daddr, + (6, 6) => Self::Saddr, + (12, 2) => Self::EtherType, + _ => return Err(DecodeError::UnknownLinkLayerHeaderField(offset, len)), + }) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum NetworkHeaderField { - Ipv4(Ipv4HeaderField), - Ipv6(Ipv6HeaderField), + IPv4(IPv4HeaderField), + IPv6(IPv6HeaderField), } impl HeaderField for NetworkHeaderField { fn offset(&self) -> u32 { use self::NetworkHeaderField::*; match *self { - Ipv4(ref f) => f.offset(), - Ipv6(ref f) => f.offset(), + IPv4(ref f) => f.offset(), + IPv6(ref f) => f.offset(), } } fn len(&self) -> u32 { use self::NetworkHeaderField::*; match *self { - Ipv4(ref f) => f.len(), - Ipv6(ref f) => f.len(), + IPv4(ref f) => f.len(), + IPv6(ref f) => f.len(), } } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[non_exhaustive] -pub enum Ipv4HeaderField { +pub enum IPv4HeaderField { Ttl, Protocol, Saddr, Daddr, } -impl HeaderField for Ipv4HeaderField { +impl HeaderField for IPv4HeaderField { fn offset(&self) -> u32 { - use self::Ipv4HeaderField::*; + use self::IPv4HeaderField::*; match *self { Ttl => 8, Protocol => 9, @@ -215,7 +177,7 @@ impl HeaderField for Ipv4HeaderField { } fn len(&self) -> u32 { - use self::Ipv4HeaderField::*; + use self::IPv4HeaderField::*; match *self { Ttl => 1, Protocol => 1, @@ -225,37 +187,30 @@ impl HeaderField for Ipv4HeaderField { } } -impl Ipv4HeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { - let off = data.offset; - let len = data.len; - - if off == 8 && len == 1 { - Ok(Self::Ttl) - } else if off == 9 && len == 1 { - Ok(Self::Protocol) - } else if off == 12 && len == 4 { - Ok(Self::Saddr) - } else if off == 16 && len == 4 { - Ok(Self::Daddr) - } else { - Err(DeserializationError::InvalidValue) - } +impl IPv4HeaderField { + pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> { + Ok(match (offset, len) { + (8, 1) => Self::Ttl, + (9, 1) => Self::Protocol, + (12, 4) => Self::Saddr, + (16, 4) => Self::Daddr, + _ => return Err(DecodeError::UnknownIPv4HeaderField(offset, len)), + }) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[non_exhaustive] -pub enum Ipv6HeaderField { +pub enum IPv6HeaderField { NextHeader, HopLimit, Saddr, Daddr, } -impl HeaderField for Ipv6HeaderField { +impl HeaderField for IPv6HeaderField { fn offset(&self) -> u32 { - use self::Ipv6HeaderField::*; + use self::IPv6HeaderField::*; match *self { NextHeader => 6, HopLimit => 7, @@ -265,7 +220,7 @@ impl HeaderField for Ipv6HeaderField { } fn len(&self) -> u32 { - use self::Ipv6HeaderField::*; + use self::IPv6HeaderField::*; match *self { NextHeader => 1, HopLimit => 1, @@ -275,31 +230,24 @@ impl HeaderField for Ipv6HeaderField { } } -impl Ipv6HeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { - let off = data.offset; - let len = data.len; - - if off == 6 && len == 1 { - Ok(Self::NextHeader) - } else if off == 7 && len == 1 { - Ok(Self::HopLimit) - } else if off == 8 && len == 16 { - Ok(Self::Saddr) - } else if off == 24 && len == 16 { - Ok(Self::Daddr) - } else { - Err(DeserializationError::InvalidValue) - } +impl IPv6HeaderField { + pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> { + Ok(match (offset, len) { + (6, 1) => Self::NextHeader, + (7, 1) => Self::HopLimit, + (8, 16) => Self::Saddr, + (24, 16) => Self::Daddr, + _ => return Err(DecodeError::UnknownIPv6HeaderField(offset, len)), + }) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[non_exhaustive] pub enum TransportHeaderField { - Tcp(TcpHeaderField), - Udp(UdpHeaderField), - Icmpv6(Icmpv6HeaderField), + Tcp(TCPHeaderField), + Udp(UDPHeaderField), + ICMPv6(ICMPv6HeaderField), } impl HeaderField for TransportHeaderField { @@ -308,7 +256,7 @@ impl HeaderField for TransportHeaderField { match *self { Tcp(ref f) => f.offset(), Udp(ref f) => f.offset(), - Icmpv6(ref f) => f.offset(), + ICMPv6(ref f) => f.offset(), } } @@ -317,21 +265,21 @@ impl HeaderField for TransportHeaderField { match *self { Tcp(ref f) => f.len(), Udp(ref f) => f.len(), - Icmpv6(ref f) => f.len(), + ICMPv6(ref f) => f.len(), } } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[non_exhaustive] -pub enum TcpHeaderField { +pub enum TCPHeaderField { Sport, Dport, } -impl HeaderField for TcpHeaderField { +impl HeaderField for TCPHeaderField { fn offset(&self) -> u32 { - use self::TcpHeaderField::*; + use self::TCPHeaderField::*; match *self { Sport => 0, Dport => 2, @@ -339,7 +287,7 @@ impl HeaderField for TcpHeaderField { } fn len(&self) -> u32 { - use self::TcpHeaderField::*; + use self::TCPHeaderField::*; match *self { Sport => 2, Dport => 2, @@ -347,32 +295,27 @@ impl HeaderField for TcpHeaderField { } } -impl TcpHeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { - let off = data.offset; - let len = data.len; - - if off == 0 && len == 2 { - Ok(Self::Sport) - } else if off == 2 && len == 2 { - Ok(Self::Dport) - } else { - Err(DeserializationError::InvalidValue) - } +impl TCPHeaderField { + pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> { + Ok(match (offset, len) { + (0, 2) => Self::Sport, + (2, 2) => Self::Dport, + _ => return Err(DecodeError::UnknownTCPHeaderField(offset, len)), + }) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[non_exhaustive] -pub enum UdpHeaderField { +pub enum UDPHeaderField { Sport, Dport, Len, } -impl HeaderField for UdpHeaderField { +impl HeaderField for UDPHeaderField { fn offset(&self) -> u32 { - use self::UdpHeaderField::*; + use self::UDPHeaderField::*; match *self { Sport => 0, Dport => 2, @@ -381,7 +324,7 @@ impl HeaderField for UdpHeaderField { } fn len(&self) -> u32 { - use self::UdpHeaderField::*; + use self::UDPHeaderField::*; match *self { Sport => 2, Dport => 2, @@ -390,34 +333,28 @@ impl HeaderField for UdpHeaderField { } } -impl UdpHeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { - let off = data.offset; - let len = data.len; - - if off == 0 && len == 2 { - Ok(Self::Sport) - } else if off == 2 && len == 2 { - Ok(Self::Dport) - } else if off == 4 && len == 2 { - Ok(Self::Len) - } else { - Err(DeserializationError::InvalidValue) - } +impl UDPHeaderField { + pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> { + Ok(match (offset, len) { + (0, 2) => Self::Sport, + (2, 2) => Self::Dport, + (4, 2) => Self::Len, + _ => return Err(DecodeError::UnknownUDPHeaderField(offset, len)), + }) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[non_exhaustive] -pub enum Icmpv6HeaderField { +pub enum ICMPv6HeaderField { Type, Code, Checksum, } -impl HeaderField for Icmpv6HeaderField { +impl HeaderField for ICMPv6HeaderField { fn offset(&self) -> u32 { - use self::Icmpv6HeaderField::*; + use self::ICMPv6HeaderField::*; match *self { Type => 0, Code => 1, @@ -426,7 +363,7 @@ impl HeaderField for Icmpv6HeaderField { } fn len(&self) -> u32 { - use self::Icmpv6HeaderField::*; + use self::ICMPv6HeaderField::*; match *self { Type => 1, Code => 1, @@ -435,97 +372,13 @@ impl HeaderField for Icmpv6HeaderField { } } -impl Icmpv6HeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { - let off = data.offset; - let len = data.len; - - if off == 0 && len == 1 { - Ok(Self::Type) - } else if off == 1 && len == 1 { - Ok(Self::Code) - } else if off == 2 && len == 2 { - Ok(Self::Checksum) - } else { - Err(DeserializationError::InvalidValue) - } +impl ICMPv6HeaderField { + pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> { + Ok(match (offset, len) { + (0, 1) => Self::Type, + (1, 1) => Self::Code, + (2, 2) => Self::Checksum, + _ => return Err(DecodeError::UnknownICMPv6HeaderField(offset, len)), + }) } } - -#[macro_export(local_inner_macros)] -macro_rules! nft_expr_payload { - (@ipv4_field ttl) => { - $crate::expr::Ipv4HeaderField::Ttl - }; - (@ipv4_field protocol) => { - $crate::expr::Ipv4HeaderField::Protocol - }; - (@ipv4_field saddr) => { - $crate::expr::Ipv4HeaderField::Saddr - }; - (@ipv4_field daddr) => { - $crate::expr::Ipv4HeaderField::Daddr - }; - - (@ipv6_field nextheader) => { - $crate::expr::Ipv6HeaderField::NextHeader - }; - (@ipv6_field hoplimit) => { - $crate::expr::Ipv6HeaderField::HopLimit - }; - (@ipv6_field saddr) => { - $crate::expr::Ipv6HeaderField::Saddr - }; - (@ipv6_field daddr) => { - $crate::expr::Ipv6HeaderField::Daddr - }; - - (@tcp_field sport) => { - $crate::expr::TcpHeaderField::Sport - }; - (@tcp_field dport) => { - $crate::expr::TcpHeaderField::Dport - }; - - (@udp_field sport) => { - $crate::expr::UdpHeaderField::Sport - }; - (@udp_field dport) => { - $crate::expr::UdpHeaderField::Dport - }; - (@udp_field len) => { - $crate::expr::UdpHeaderField::Len - }; - - (ethernet daddr) => { - $crate::expr::Payload::LinkLayer($crate::expr::LLHeaderField::Daddr) - }; - (ethernet saddr) => { - $crate::expr::Payload::LinkLayer($crate::expr::LLHeaderField::Saddr) - }; - (ethernet ethertype) => { - $crate::expr::Payload::LinkLayer($crate::expr::LLHeaderField::EtherType) - }; - - (ipv4 $field:ident) => { - $crate::expr::Payload::Network($crate::expr::NetworkHeaderField::Ipv4( - nft_expr_payload!(@ipv4_field $field), - )) - }; - (ipv6 $field:ident) => { - $crate::expr::Payload::Network($crate::expr::NetworkHeaderField::Ipv6( - nft_expr_payload!(@ipv6_field $field), - )) - }; - - (tcp $field:ident) => { - $crate::expr::Payload::Transport($crate::expr::TransportHeaderField::Tcp( - nft_expr_payload!(@tcp_field $field), - )) - }; - (udp $field:ident) => { - $crate::expr::Payload::Transport($crate::expr::TransportHeaderField::Udp( - nft_expr_payload!(@udp_field $field), - )) - }; -} diff --git a/src/expr/register.rs b/src/expr/register.rs index a05af7e..9cc1bee 100644 --- a/src/expr/register.rs +++ b/src/expr/register.rs @@ -1,34 +1,17 @@ use std::fmt::Debug; -use crate::sys::libc; +use rustables_macros::nfnetlink_enum; -use super::DeserializationError; +use crate::sys::{NFT_REG_1, NFT_REG_2, NFT_REG_3, NFT_REG_4, NFT_REG_VERDICT}; /// A netfilter data register. The expressions store and read data to and from these when /// evaluating rule statements. #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[repr(i32)] +#[nfnetlink_enum(u32)] pub enum Register { - Verdict = libc::NFT_REG_VERDICT, - Reg1 = libc::NFT_REG_1, - Reg2 = libc::NFT_REG_2, - Reg3 = libc::NFT_REG_3, - Reg4 = libc::NFT_REG_4, -} - -impl Register { - pub fn to_raw(self) -> u32 { - self as u32 - } - - pub fn from_raw(val: u32) -> Result<Self, DeserializationError> { - match val as i32 { - libc::NFT_REG_VERDICT => Ok(Self::Verdict), - libc::NFT_REG_1 => Ok(Self::Reg1), - libc::NFT_REG_2 => Ok(Self::Reg2), - libc::NFT_REG_3 => Ok(Self::Reg3), - libc::NFT_REG_4 => Ok(Self::Reg4), - _ => Err(DeserializationError::InvalidValue), - } - } + Verdict = NFT_REG_VERDICT, + Reg1 = NFT_REG_1, + Reg2 = NFT_REG_2, + Reg3 = NFT_REG_3, + Reg4 = NFT_REG_4, } diff --git a/src/expr/reject.rs b/src/expr/reject.rs index 19752ce..83fd843 100644 --- a/src/expr/reject.rs +++ b/src/expr/reject.rs @@ -1,95 +1,40 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::ProtoFamily; -use crate::sys::{self, libc::{self, c_char}}; +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; -/// A reject expression that defines the type of rejection message sent when discarding a packet. -#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] -pub enum Reject { - /// Returns an ICMP unreachable packet. - Icmp(IcmpCode), - /// Rejects by sending a TCP RST packet. - TcpRst, -} +use crate::sys; -impl Reject { - fn to_raw(&self, family: ProtoFamily) -> u32 { - use libc::*; - let value = match *self { - Self::Icmp(..) => match family { - ProtoFamily::Bridge | ProtoFamily::Inet => NFT_REJECT_ICMPX_UNREACH, - _ => NFT_REJECT_ICMP_UNREACH, - }, - Self::TcpRst => NFT_REJECT_TCP_RST, - }; - value as u32 - } -} +use super::Expression; impl Expression for Reject { - fn get_raw_name() -> *const libc::c_char { - b"reject\0" as *const _ as *const c_char + fn get_name() -> &'static str { + "reject" } +} - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - unsafe { - if sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_REJECT_TYPE as u16) - == libc::NFT_REJECT_TCP_RST as u32 - { - Ok(Self::TcpRst) - } else { - Ok(Self::Icmp(IcmpCode::from_raw(sys::nftnl_expr_get_u8( - expr, - sys::NFTNL_EXPR_REJECT_CODE as u16, - ))?)) - } - } - } - - fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr { - let family = rule.get_chain().get_table().get_family(); - - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_REJECT_TYPE as u16, - self.to_raw(family), - ); - - let reject_code = match *self { - Reject::Icmp(code) => code as u8, - Reject::TcpRst => 0, - }; - - sys::nftnl_expr_set_u8(expr, sys::NFTNL_EXPR_REJECT_CODE as u16, reject_code); - - expr - } - } +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] +/// A reject expression that defines the type of rejection message sent when discarding a packet. +pub struct Reject { + #[field(sys::NFTA_REJECT_TYPE, name_in_functions = "type")] + reject_type: RejectType, + #[field(sys::NFTA_REJECT_ICMP_CODE)] + icmp_code: IcmpCode, } /// An ICMP reject code. #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] -#[repr(u8)] -pub enum IcmpCode { - NoRoute = libc::NFT_REJECT_ICMPX_NO_ROUTE as u8, - PortUnreach = libc::NFT_REJECT_ICMPX_PORT_UNREACH as u8, - HostUnreach = libc::NFT_REJECT_ICMPX_HOST_UNREACH as u8, - AdminProhibited = libc::NFT_REJECT_ICMPX_ADMIN_PROHIBITED as u8, +#[nfnetlink_enum(u32)] +pub enum RejectType { + IcmpUnreach = sys::NFT_REJECT_ICMP_UNREACH, + TcpRst = sys::NFT_REJECT_TCP_RST, + IcmpxUnreach = sys::NFT_REJECT_ICMPX_UNREACH, } -impl IcmpCode { - fn from_raw(code: u8) -> Result<Self, DeserializationError> { - match code as i32 { - libc::NFT_REJECT_ICMPX_NO_ROUTE => Ok(Self::NoRoute), - libc::NFT_REJECT_ICMPX_PORT_UNREACH => Ok(Self::PortUnreach), - libc::NFT_REJECT_ICMPX_HOST_UNREACH => Ok(Self::HostUnreach), - libc::NFT_REJECT_ICMPX_ADMIN_PROHIBITED => Ok(Self::AdminProhibited), - _ => Err(DeserializationError::InvalidValue), - } - } +/// An ICMP reject code. +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] +#[nfnetlink_enum(u8)] +pub enum IcmpCode { + NoRoute = sys::NFT_REJECT_ICMPX_NO_ROUTE, + PortUnreach = sys::NFT_REJECT_ICMPX_PORT_UNREACH, + HostUnreach = sys::NFT_REJECT_ICMPX_HOST_UNREACH, + AdminProhibited = sys::NFT_REJECT_ICMPX_ADMIN_PROHIBITED, } diff --git a/src/expr/verdict.rs b/src/expr/verdict.rs index 3c4c374..7edf7cd 100644 --- a/src/expr/verdict.rs +++ b/src/expr/verdict.rs @@ -1,11 +1,39 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys::{self, libc::{self, c_char}}; -use std::ffi::{CStr, CString}; +use std::fmt::Debug; + +use libc::{NF_ACCEPT, NF_DROP, NF_QUEUE}; +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; + +use crate::sys::{ + NFTA_VERDICT_CHAIN, NFTA_VERDICT_CHAIN_ID, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE, + NFT_GOTO, NFT_JUMP, NFT_RETURN, +}; + +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[nfnetlink_enum(i32)] +pub enum VerdictType { + Drop = NF_DROP, + Accept = NF_ACCEPT, + Queue = NF_QUEUE, + Continue = NFT_CONTINUE, + Break = NFT_BREAK, + Jump = NFT_JUMP, + Goto = NFT_GOTO, + Return = NFT_RETURN, +} + +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(nested = true)] +pub struct Verdict { + #[field(NFTA_VERDICT_CODE)] + code: VerdictType, + #[field(NFTA_VERDICT_CHAIN)] + chain: String, + #[field(NFTA_VERDICT_CHAIN_ID)] + chain_id: u32, +} -/// A verdict expression. In the background, this is usually an "Immediate" expression in nftnl -/// terms, but here it is simplified to only represent a verdict. #[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub enum Verdict { +pub enum VerdictKind { /// Silently drop the packet. Drop, /// Accept the packet and let it pass. @@ -14,135 +42,10 @@ pub enum Verdict { Continue, Break, Jump { - chain: CString, + chain: String, }, Goto { - chain: CString, + chain: String, }, Return, } - -impl Verdict { - fn chain(&self) -> Option<&CStr> { - match *self { - Verdict::Jump { ref chain } => Some(chain.as_c_str()), - Verdict::Goto { ref chain } => Some(chain.as_c_str()), - _ => None, - } - } -} - -impl Expression for Verdict { - fn get_raw_name() -> *const libc::c_char { - b"immediate\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { - unsafe { - let mut chain = None; - if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16) { - let raw_chain = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16); - - if raw_chain.is_null() { - return Err(DeserializationError::NullPointer); - } - chain = Some(CStr::from_ptr(raw_chain).to_owned()); - } - - let verdict = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_IMM_VERDICT as u16); - - match verdict as i32 { - libc::NF_DROP => Ok(Verdict::Drop), - libc::NF_ACCEPT => Ok(Verdict::Accept), - libc::NF_QUEUE => Ok(Verdict::Queue), - libc::NFT_CONTINUE => Ok(Verdict::Continue), - libc::NFT_BREAK => Ok(Verdict::Break), - libc::NFT_JUMP => { - if let Some(chain) = chain { - Ok(Verdict::Jump { chain }) - } else { - Err(DeserializationError::InvalidValue) - } - } - libc::NFT_GOTO => { - if let Some(chain) = chain { - Ok(Verdict::Goto { chain }) - } else { - Err(DeserializationError::InvalidValue) - } - } - libc::NFT_RETURN => Ok(Verdict::Return), - _ => Err(DeserializationError::InvalidValue), - } - } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - let immediate_const = match *self { - Verdict::Drop => libc::NF_DROP, - Verdict::Accept => libc::NF_ACCEPT, - Verdict::Queue => libc::NF_QUEUE, - Verdict::Continue => libc::NFT_CONTINUE, - Verdict::Break => libc::NFT_BREAK, - Verdict::Jump { .. } => libc::NFT_JUMP, - Verdict::Goto { .. } => libc::NFT_GOTO, - Verdict::Return => libc::NFT_RETURN, - }; - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc( - b"immediate\0" as *const _ as *const c_char - )); - - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_IMM_DREG as u16, - libc::NFT_REG_VERDICT as u32, - ); - - if let Some(chain) = self.chain() { - sys::nftnl_expr_set_str(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16, chain.as_ptr()); - } - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_IMM_VERDICT as u16, - immediate_const as u32, - ); - - expr - } - } -} - -#[macro_export] -macro_rules! nft_expr_verdict { - (drop) => { - $crate::expr::Verdict::Drop - }; - (accept) => { - $crate::expr::Verdict::Accept - }; - (reject icmp $code:expr) => { - $crate::expr::Verdict::Reject(RejectionType::Icmp($code)) - }; - (reject tcp-rst) => { - $crate::expr::Verdict::Reject(RejectionType::TcpRst) - }; - (queue) => { - $crate::expr::Verdict::Queue - }; - (continue) => { - $crate::expr::Verdict::Continue - }; - (break) => { - $crate::expr::Verdict::Break - }; - (jump $chain:expr) => { - $crate::expr::Verdict::Jump { chain: $chain } - }; - (goto $chain:expr) => { - $crate::expr::Verdict::Goto { chain: $chain } - }; - (return) => { - $crate::expr::Verdict::Return - }; -} diff --git a/src/expr/wrapper.rs b/src/expr/wrapper.rs deleted file mode 100644 index 12ef60b..0000000 --- a/src/expr/wrapper.rs +++ /dev/null @@ -1,61 +0,0 @@ -use std::ffi::CStr; -use std::ffi::CString; -use std::fmt::Debug; -use std::rc::Rc; -use std::os::raw::c_char; - -use super::{DeserializationError, Expression}; -use crate::{sys, Rule}; - -pub struct ExpressionWrapper { - pub(crate) expr: *const sys::nftnl_expr, - // we also need the rule here to ensure that the rule lives as long as the `expr` pointer - #[allow(dead_code)] - pub(crate) rule: Rc<Rule>, -} - -impl Debug for ExpressionWrapper { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.get_str()) - } -} - -impl ExpressionWrapper { - /// Retrieves a textual description of the expression. - pub fn get_str(&self) -> CString { - let mut descr_buf = vec![0i8; 4096]; - unsafe { - sys::nftnl_expr_snprintf( - descr_buf.as_mut_ptr() as *mut c_char, - (descr_buf.len() - 1) as u64, - self.expr, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); - CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned() - } - } - - /// Retrieves the type of expression ("log", "counter", ...). - pub fn get_kind(&self) -> Option<&CStr> { - unsafe { - let ptr = sys::nftnl_expr_get_str(self.expr, sys::NFTNL_EXPR_NAME as u16); - if !ptr.is_null() { - Some(CStr::from_ptr(ptr)) - } else { - None - } - } - } - - /// Attempts to decode the expression as the type T. - pub fn decode_expr<T: Expression>(&self) -> Result<T, DeserializationError> { - if let Some(kind) = self.get_kind() { - let raw_name = unsafe { CStr::from_ptr(T::get_raw_name()) }; - if kind == raw_name { - return T::from_expr(self.expr); - } - } - Err(DeserializationError::InvalidExpressionKind) - } -} @@ -1,4 +1,4 @@ -// Copyryght (c) 2021 GPL lafleur@boum.org and Simon Thoby +// Copyryght (c) 2021-2022 GPL lafleur@boum.org and Simon Thoby // // This file is free software: you may copy, redistribute and/or modify it // under the terms of the GNU General Public License as published by the @@ -24,106 +24,70 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! Safe abstraction for [`libnftnl`]. Provides userspace access to the in-kernel nf_tables -//! subsystem. Can be used to create and remove tables, chains, sets and rules from the nftables +//! Safe abstraction for userspace access to the in-kernel nf_tables subsystem. +//! Can be used to create and remove tables, chains, sets and rules from the nftables //! firewall, the successor to iptables. //! //! This library currently has quite rough edges and does not make adding and removing netfilter //! entries super easy and elegant. That is partly because the library needs more work, but also //! partly because nftables is super low level and extremely customizable, making it hard, and //! probably wrong, to try and create a too simple/limited wrapper. See examples for inspiration. -//! One can also look at how the original project this crate was developed to support uses it: -//! [Mullvad VPN app](https://github.com/mullvad/mullvadvpn-app) //! -//! Understanding how to use [`libnftnl`] and implementing this crate has mostly been done by -//! reading the source code for the [`nftables`] program and attaching debuggers to the `nft` -//! binary. Since the implementation is mostly based on trial and error, there might of course be -//! a number of places where the underlying library is used in an invalid or not intended way. -//! Large portions of [`libnftnl`] are also not covered yet. Contributions are welcome! +//! Understanding how to use the netlink subsystem and implementing this crate has mostly been done by +//! reading the source code for the [`nftables`] userspace program and its corresponding kernel code, +//! as well as attaching debuggers to the `nft` binary. +//! Since the implementation is mostly based on trial and error, there might of course be +//! a number of places where the forged netlink messages are used in an invalid or not intended way. +//! Contributions are welcome! //! -//! # Supported versions of `libnftnl` -//! -//! This crate will automatically link to the currently installed version of libnftnl upon build. -//! It requires libnftnl version 1.0.6 or higher. See how the low level FFI bindings to the C -//! library are generated in [`build.rs`]. -//! -//! # Access to raw handles -//! -//! Retrieving raw handles is considered unsafe and should only ever be enabled if you absolutely -//! need it. It is disabled by default and hidden behind the feature gate `unsafe-raw-handles`. -//! The reason for that special treatment is we cannot guarantee the lack of aliasing. For -//! example, a program using a const handle to a object in a thread and writing through a mutable -//! handle in another could reach all kind of undefined (and dangerous!) behaviors. By enabling -//! that feature flag, you acknowledge that guaranteeing the respect of safety invariants is now -//! your responsibility! Despite these shortcomings, that feature is still available because it -//! may allow you to perform manipulations that this library doesn't currently expose. If that is -//! your case, we would be very happy to hear from you and maybe help you get the necessary -//! functionality upstream. -//! -//! Our current lack of confidence in our availability to provide a safe abstraction over the use -//! of raw handles in the face of concurrency is the reason we decided to settly on `Rc` pointers -//! instead of `Arc` (besides, this should gives us some nice performance boost, not that it -//! matters much of course) and why we do not declare the types exposed by the library as `Send` -//! nor `Sync`. -//! -//! [`libnftnl`]: https://netfilter.org/projects/libnftnl/ //! [`nftables`]: https://netfilter.org/projects/nftables/ -//! [`build.rs`]: https://gitlab.com/rustwall/rustables/-/blob/master/build.rs - -use thiserror::Error; #[macro_use] extern crate log; -pub mod sys; -use std::{convert::TryFrom, ffi::c_void, ops::Deref}; -use sys::libc; - -macro_rules! try_alloc { - ($e:expr) => {{ - let ptr = $e; - if ptr.is_null() { - // OOM, and the tried allocation was likely very small, - // so we are in a very tight situation. We do what libstd does, aborts. - std::process::abort(); - } - ptr - }}; -} +use libc; + +use rustables_macros::nfnetlink_enum; +use std::convert::TryFrom; mod batch; -#[cfg(feature = "query")] -pub use batch::{batch_is_supported, default_batch_page_size}; -pub use batch::{Batch, FinalizedBatch, NetlinkError}; +pub use batch::{default_batch_page_size, Batch}; -pub mod expr; +pub mod data_type; -pub mod table; +mod table; +pub use table::list_tables; pub use table::Table; -#[cfg(feature = "query")] -pub use table::{get_tables_cb, list_tables}; mod chain; -#[cfg(feature = "query")] -pub use chain::{get_chains_cb, list_chains_for_table}; -pub use chain::{Chain, ChainType, Hook, Policy, Priority}; +pub use chain::list_chains_for_table; +pub use chain::{Chain, ChainPolicy, ChainPriority, ChainType, Hook, HookClass}; -mod chain_methods; -pub use chain_methods::ChainMethods; +pub mod error; pub mod query; +pub(crate) mod nlmsg; +pub(crate) mod parser; +pub(crate) mod parser_impls; + mod rule; +pub use rule::list_rules_for_chain; pub use rule::Rule; -#[cfg(feature = "query")] -pub use rule::{get_rules_cb, list_rules_for_chain}; + +pub mod expr; mod rule_methods; -pub use rule_methods::{iface_index, Protocol, RuleMethods, Error as MatchError}; +pub use rule_methods::{iface_index, Protocol}; pub mod set; pub use set::Set; +pub mod sys; + +#[cfg(test)] +mod tests; + /// The type of the message as it's sent to netfilter. A message consists of an object, such as a /// [`Table`], [`Chain`] or [`Rule`] for example, and a [`MsgType`] to describe what to do with /// that object. If a [`Table`] object is sent with `MsgType::Add` then that table will be added @@ -133,7 +97,7 @@ pub use set::Set; /// [`Chain`]: struct.Chain.html /// [`Rule`]: struct.Rule.html /// [`MsgType`]: enum.MsgType.html -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum MsgType { /// Add the object to netfilter. Add, @@ -142,69 +106,22 @@ pub enum MsgType { } /// Denotes a protocol. Used to specify which protocol a table or set belongs to. -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[repr(u16)] -pub enum ProtoFamily { - Unspec = libc::NFPROTO_UNSPEC as u16, +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[nfnetlink_enum(i32)] +pub enum ProtocolFamily { + Unspec = libc::NFPROTO_UNSPEC, /// Inet - Means both IPv4 and IPv6 - Inet = libc::NFPROTO_INET as u16, - Ipv4 = libc::NFPROTO_IPV4 as u16, - Arp = libc::NFPROTO_ARP as u16, - NetDev = libc::NFPROTO_NETDEV as u16, - Bridge = libc::NFPROTO_BRIDGE as u16, - Ipv6 = libc::NFPROTO_IPV6 as u16, - DecNet = libc::NFPROTO_DECNET as u16, -} -#[derive(Error, Debug)] -#[error("Couldn't find a matching protocol")] -pub struct InvalidProtocolFamily; - -impl TryFrom<i32> for ProtoFamily { - type Error = InvalidProtocolFamily; - fn try_from(value: i32) -> Result<Self, Self::Error> { - match value { - libc::NFPROTO_UNSPEC => Ok(ProtoFamily::Unspec), - libc::NFPROTO_INET => Ok(ProtoFamily::Inet), - libc::NFPROTO_IPV4 => Ok(ProtoFamily::Ipv4), - libc::NFPROTO_ARP => Ok(ProtoFamily::Arp), - libc::NFPROTO_NETDEV => Ok(ProtoFamily::NetDev), - libc::NFPROTO_BRIDGE => Ok(ProtoFamily::Bridge), - libc::NFPROTO_IPV6 => Ok(ProtoFamily::Ipv6), - libc::NFPROTO_DECNET => Ok(ProtoFamily::DecNet), - _ => Err(InvalidProtocolFamily), - } - } + Inet = libc::NFPROTO_INET, + Ipv4 = libc::NFPROTO_IPV4, + Arp = libc::NFPROTO_ARP, + NetDev = libc::NFPROTO_NETDEV, + Bridge = libc::NFPROTO_BRIDGE, + Ipv6 = libc::NFPROTO_IPV6, + DecNet = libc::NFPROTO_DECNET, } -/// Trait for all types in this crate that can serialize to a Netlink message. -/// -/// # Unsafe -/// -/// This trait is unsafe to implement because it must never serialize to anything larger than the -/// largest possible netlink message. Internally the `nft_nlmsg_maxsize()` function is used to -/// make sure the `buf` pointer passed to `write` always has room for the largest possible Netlink -/// message. -pub unsafe trait NlMsg { - /// Serializes the Netlink message to the buffer at `buf`. `buf` must have space for at least - /// `nft_nlmsg_maxsize()` bytes. This is not checked by the compiler, which is why this method - /// is unsafe. - unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType); -} - -unsafe impl<T, R> NlMsg for T -where - T: Deref<Target = R>, - R: NlMsg, -{ - unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { - self.deref().write(buf, seq, msg_type); +impl Default for ProtocolFamily { + fn default() -> Self { + ProtocolFamily::Unspec } } - -/// The largest nf_tables netlink message is the set element message, which contains the -/// NFTA_SET_ELEM_LIST_ELEMENTS attribute. This attribute is a nest that describes the set -/// elements. Given that the netlink attribute length (nla_len) is 16 bits, the largest message is -/// a bit larger than 64 KBytes. -pub fn nft_nlmsg_maxsize() -> u32 { - u32::from(::std::u16::MAX) + unsafe { libc::sysconf(libc::_SC_PAGESIZE) } as u32 -} diff --git a/src/nlmsg.rs b/src/nlmsg.rs new file mode 100644 index 0000000..1c5b519 --- /dev/null +++ b/src/nlmsg.rs @@ -0,0 +1,182 @@ +use std::{fmt::Debug, mem::size_of}; + +use crate::{ + error::DecodeError, + sys::{ + nfgenmsg, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END, + NFNL_SUBSYS_NFTABLES, NLMSG_ALIGNTO, NLM_F_ACK, NLM_F_CREATE, + }, + MsgType, ProtocolFamily, +}; +/// +/// The largest nf_tables netlink message is the set element message, which contains the +/// NFTA_SET_ELEM_LIST_ELEMENTS attribute. This attribute is a nest that describes the set +/// elements. Given that the netlink attribute length (nla_len) is 16 bits, the largest message is +/// a bit larger than 64 KBytes. +pub fn nft_nlmsg_maxsize() -> u32 { + u32::from(::std::u16::MAX) + unsafe { libc::sysconf(libc::_SC_PAGESIZE) } as u32 +} + +#[inline] +pub const fn pad_netlink_object_with_variable_size(size: usize) -> usize { + // align on a 4 bytes boundary + (size + (NLMSG_ALIGNTO as usize - 1)) & !(NLMSG_ALIGNTO as usize - 1) +} + +#[inline] +pub const fn pad_netlink_object<T>() -> usize { + let size = size_of::<T>(); + pad_netlink_object_with_variable_size(size) +} + +pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 { + ((x & 0xff00) >> 8) as u8 +} + +pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 { + (x & 0x00ff) as u8 +} + +pub struct NfNetlinkWriter<'a> { + buf: &'a mut Vec<u8>, + headers: Option<(usize, usize)>, +} + +impl<'a> NfNetlinkWriter<'a> { + pub fn new(buf: &'a mut Vec<u8>) -> NfNetlinkWriter<'a> { + NfNetlinkWriter { buf, headers: None } + } + + pub fn add_data_zeroed<'b>(&'b mut self, size: usize) -> &'b mut [u8] { + let padded_size = pad_netlink_object_with_variable_size(size); + let start = self.buf.len(); + self.buf.resize(start + padded_size, 0); + + if let Some((msghdr_idx, _nfgenmsg_idx)) = self.headers { + let mut hdr: &mut nlmsghdr = unsafe { + std::mem::transmute(self.buf[msghdr_idx..].as_mut_ptr() as *mut nlmsghdr) + }; + hdr.nlmsg_len += padded_size as u32; + } + + &mut self.buf[start..start + size] + } + + // rewrite of `__nftnl_nlmsg_build_hdr` + pub fn write_header( + &mut self, + msg_type: u16, + family: ProtocolFamily, + flags: u16, + seq: u32, + ressource_id: Option<u16>, + ) { + if self.headers.is_some() { + error!("Calling write_header while still holding headers open!?"); + } + + let nlmsghdr_len = pad_netlink_object::<nlmsghdr>(); + let nfgenmsg_len = pad_netlink_object::<nfgenmsg>(); + + let nlmsghdr_buf = self.add_data_zeroed(nlmsghdr_len); + let mut hdr: &mut nlmsghdr = + unsafe { std::mem::transmute(nlmsghdr_buf.as_mut_ptr() as *mut nlmsghdr) }; + hdr.nlmsg_len = (nlmsghdr_len + nfgenmsg_len) as u32; + hdr.nlmsg_type = msg_type; + // batch messages are not specific to the nftables subsystem + if msg_type != NFNL_MSG_BATCH_BEGIN as u16 && msg_type != NFNL_MSG_BATCH_END as u16 { + hdr.nlmsg_type |= (NFNL_SUBSYS_NFTABLES as u16) << 8; + } + hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16 | flags; + hdr.nlmsg_seq = seq; + + let nfgenmsg_buf = self.add_data_zeroed(nfgenmsg_len); + let mut nfgenmsg: &mut nfgenmsg = + unsafe { std::mem::transmute(nfgenmsg_buf.as_mut_ptr() as *mut nfgenmsg) }; + nfgenmsg.nfgen_family = family as u8; + nfgenmsg.version = NFNETLINK_V0 as u8; + nfgenmsg.res_id = ressource_id.unwrap_or(0); + + self.headers = Some(( + self.buf.len() - (nlmsghdr_len + nfgenmsg_len), + self.buf.len() - nfgenmsg_len, + )); + } + + pub fn finalize_writing_object(&mut self) { + self.headers = None; + } +} + +pub trait AttributeDecoder { + fn decode_attribute(&mut self, attr_type: u16, buf: &[u8]) -> Result<(), DecodeError>; +} + +pub trait NfNetlinkDeserializable: Sized { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError>; +} + +pub trait NfNetlinkObject: + Sized + AttributeDecoder + NfNetlinkDeserializable + NfNetlinkAttribute +{ + const MSG_TYPE_ADD: u32; + const MSG_TYPE_DEL: u32; + + fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32) { + let raw_msg_type = match msg_type { + MsgType::Add => Self::MSG_TYPE_ADD, + MsgType::Del => Self::MSG_TYPE_DEL, + } as u16; + writer.write_header( + raw_msg_type, + self.get_family(), + (if let MsgType::Add = msg_type { + self.get_add_flags() + } else { + self.get_del_flags() + } | NLM_F_ACK) as u16, + seq, + None, + ); + let buf = writer.add_data_zeroed(self.get_size()); + unsafe { + self.write_payload(buf.as_mut_ptr()); + } + writer.finalize_writing_object(); + } + + fn get_family(&self) -> ProtocolFamily; + + fn set_family(&mut self, _family: ProtocolFamily) { + // the default impl do nothing, because some types are family-agnostic + } + + fn with_family(mut self, family: ProtocolFamily) -> Self { + self.set_family(family); + self + } + + fn get_add_flags(&self) -> u32 { + NLM_F_CREATE + } + + fn get_del_flags(&self) -> u32 { + 0 + } +} + +pub type NetlinkType = u16; + +pub trait NfNetlinkAttribute: Debug + Sized { + // is it a nested argument that must be marked with a NLA_F_NESTED flag? + fn is_nested(&self) -> bool { + false + } + + fn get_size(&self) -> usize { + size_of::<Self>() + } + + // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr, self.get_size()); + unsafe fn write_payload(&self, addr: *mut u8); +} diff --git a/src/parser.rs b/src/parser.rs new file mode 100644 index 0000000..6ea34c1 --- /dev/null +++ b/src/parser.rs @@ -0,0 +1,216 @@ +use std::{ + fmt::{Debug, DebugStruct}, + mem::{size_of, transmute}, +}; + +use crate::{ + error::DecodeError, + nlmsg::{ + get_operation_from_nlmsghdr_type, get_subsystem_from_nlmsghdr_type, pad_netlink_object, + pad_netlink_object_with_variable_size, AttributeDecoder, NetlinkType, NfNetlinkAttribute, + }, + sys::{ + nfgenmsg, nlattr, nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, + NFNL_MSG_BATCH_END, NFNL_SUBSYS_NFTABLES, NLA_F_NESTED, NLA_TYPE_MASK, NLMSG_DONE, + NLMSG_ERROR, NLMSG_MIN_TYPE, NLMSG_NOOP, NLM_F_DUMP_INTR, + }, +}; + +pub fn get_nlmsghdr(buf: &[u8]) -> Result<nlmsghdr, DecodeError> { + let size_of_hdr = size_of::<nlmsghdr>(); + + if buf.len() < size_of_hdr { + return Err(DecodeError::BufTooSmall); + } + + let nlmsghdr_ptr = buf[0..size_of_hdr].as_ptr() as *const nlmsghdr; + let nlmsghdr = unsafe { *nlmsghdr_ptr }; + + if nlmsghdr.nlmsg_len as usize > buf.len() || (nlmsghdr.nlmsg_len as usize) < size_of_hdr { + return Err(DecodeError::NlMsgTooSmall); + } + + if nlmsghdr.nlmsg_flags & NLM_F_DUMP_INTR as u16 != 0 { + return Err(DecodeError::ConcurrentGenerationUpdate); + } + + Ok(nlmsghdr) +} + +#[derive(Debug, Clone, PartialEq)] +pub enum NlMsg<'a> { + Done, + Noop, + Error(nlmsgerr), + NfGenMsg(nfgenmsg, &'a [u8]), +} + +pub fn parse_nlmsg<'a>(buf: &'a [u8]) -> Result<(nlmsghdr, NlMsg<'a>), DecodeError> { + // in theory the message is composed of the following parts: + // - nlmsghdr (contains the message size and type) + // - struct nlmsgerr OR nfgenmsg (nftables header that describes the message family) + // - the raw value that we want to validate (if the previous part is nfgenmsg) + let hdr = get_nlmsghdr(buf)?; + + let size_of_hdr = pad_netlink_object::<nlmsghdr>(); + + if hdr.nlmsg_type < NLMSG_MIN_TYPE as u16 { + match hdr.nlmsg_type as u32 { + x if x == NLMSG_NOOP => return Ok((hdr, NlMsg::Noop)), + x if x == NLMSG_ERROR => { + if (hdr.nlmsg_len as usize) < size_of_hdr + size_of::<nlmsgerr>() { + return Err(DecodeError::NlMsgTooSmall); + } + let mut err = unsafe { + *(buf[size_of_hdr..size_of_hdr + size_of::<nlmsgerr>()].as_ptr() + as *const nlmsgerr) + }; + // some APIs return negative values, while other return positive values + err.error = err.error.abs(); + return Ok((hdr, NlMsg::Error(err))); + } + x if x == NLMSG_DONE => return Ok((hdr, NlMsg::Done)), + x => return Err(DecodeError::UnsupportedType(x as u16)), + } + } + + // batch messages are not specific to the nftables subsystem + if hdr.nlmsg_type != NFNL_MSG_BATCH_BEGIN as u16 && hdr.nlmsg_type != NFNL_MSG_BATCH_END as u16 + { + // verify that we are decoding nftables messages + let subsys = get_subsystem_from_nlmsghdr_type(hdr.nlmsg_type); + if subsys != NFNL_SUBSYS_NFTABLES as u8 { + return Err(DecodeError::InvalidSubsystem(subsys)); + } + } + + let size_of_nfgenmsg = pad_netlink_object::<nfgenmsg>(); + if hdr.nlmsg_len as usize > buf.len() + || (hdr.nlmsg_len as usize) < size_of_hdr + size_of_nfgenmsg + { + return Err(DecodeError::NlMsgTooSmall); + } + + let nfgenmsg_ptr = buf[size_of_hdr..size_of_hdr + size_of_nfgenmsg].as_ptr() as *const nfgenmsg; + let nfgenmsg = unsafe { *nfgenmsg_ptr }; + + if nfgenmsg.version != NFNETLINK_V0 as u8 { + return Err(DecodeError::InvalidVersion(nfgenmsg.version)); + } + + let raw_value = &buf[size_of_hdr + size_of_nfgenmsg..hdr.nlmsg_len as usize]; + + Ok((hdr, NlMsg::NfGenMsg(nfgenmsg, raw_value))) +} + +/// Write the attribute, preceded by a `libc::nlattr` +// rewrite of `mnl_attr_put` +pub unsafe fn write_attribute<'a>( + ty: NetlinkType, + obj: &impl NfNetlinkAttribute, + mut buf: *mut u8, +) { + let header_len = pad_netlink_object::<libc::nlattr>(); + // copy the header + *(buf as *mut nlattr) = nlattr { + // nla_len contains the header size + the unpadded attribute length + nla_len: (header_len + obj.get_size() as usize) as u16, + nla_type: if obj.is_nested() { + ty | NLA_F_NESTED as u16 + } else { + ty + }, + }; + buf = buf.offset(pad_netlink_object::<nlattr>() as isize); + // copy the attribute data itself + obj.write_payload(buf); +} + +pub(crate) fn read_attributes<T: AttributeDecoder + Default>(buf: &[u8]) -> Result<T, DecodeError> { + debug!( + "Calling <{} as NfNetlinkDeserialize>::deserialize()", + std::any::type_name::<T>() + ); + let mut remaining_size = buf.len(); + let mut pos = 0; + let mut res = T::default(); + while remaining_size > pad_netlink_object::<nlattr>() { + let nlattr = unsafe { *transmute::<*const u8, *const nlattr>(buf[pos..].as_ptr()) }; + // ignore the byteorder and nested attributes + let nla_type = nlattr.nla_type & NLA_TYPE_MASK as u16; + + pos += pad_netlink_object::<nlattr>(); + let attr_remaining_size = nlattr.nla_len as usize - pad_netlink_object::<nlattr>(); + match T::decode_attribute(&mut res, nla_type, &buf[pos..pos + attr_remaining_size]) { + Ok(()) => {} + Err(DecodeError::UnsupportedAttributeType(t)) => info!( + "Ignoring unsupported attribute type {} for type {}", + t, + std::any::type_name::<T>() + ), + Err(e) => return Err(e), + } + pos += pad_netlink_object_with_variable_size(attr_remaining_size); + + remaining_size -= pad_netlink_object_with_variable_size(nlattr.nla_len as usize); + } + + if remaining_size != 0 { + Err(DecodeError::InvalidDataSize) + } else { + Ok(res) + } +} + +pub trait InnerFormat { + fn inner_format_struct<'a, 'b: 'a>( + &'a self, + s: DebugStruct<'a, 'b>, + ) -> Result<DebugStruct<'a, 'b>, std::fmt::Error>; +} + +pub trait Parsable +where + Self: Sized, +{ + fn parse_object( + buf: &[u8], + add_obj: u32, + del_obj: u32, + ) -> Result<(Self, nfgenmsg, &[u8]), DecodeError>; +} + +impl<T> Parsable for T +where + T: AttributeDecoder + Default + Sized, +{ + fn parse_object( + buf: &[u8], + add_obj: u32, + del_obj: u32, + ) -> Result<(Self, nfgenmsg, &[u8]), DecodeError> { + debug!("parse_object() started"); + let (hdr, msg) = parse_nlmsg(buf)?; + + let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as u32; + + if op != add_obj && op != del_obj { + return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)); + } + + let obj_size = hdr.nlmsg_len as usize + - pad_netlink_object_with_variable_size(size_of::<nlmsghdr>() + size_of::<nfgenmsg>()); + + let remaining_data_offset = pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize); + let remaining_data = &buf[remaining_data_offset..]; + + let (nfgenmsg, res) = match msg { + NlMsg::NfGenMsg(nfgenmsg, content) => { + (nfgenmsg, read_attributes(&content[..obj_size])?) + } + _ => return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)), + }; + + Ok((res, nfgenmsg, remaining_data)) + } +} diff --git a/src/parser_impls.rs b/src/parser_impls.rs new file mode 100644 index 0000000..b2681bb --- /dev/null +++ b/src/parser_impls.rs @@ -0,0 +1,243 @@ +use std::{fmt::Debug, mem::transmute}; + +use rustables_macros::nfnetlink_struct; + +use crate::{ + error::DecodeError, + expr::Verdict, + nlmsg::{ + pad_netlink_object, pad_netlink_object_with_variable_size, NfNetlinkAttribute, + NfNetlinkDeserializable, NfNetlinkObject, + }, + parser::{write_attribute, Parsable}, + sys::{nlattr, NFTA_DATA_VALUE, NFTA_DATA_VERDICT, NFTA_LIST_ELEM, NLA_TYPE_MASK}, + ProtocolFamily, +}; + +impl NfNetlinkAttribute for u8 { + unsafe fn write_payload(&self, addr: *mut u8) { + *addr = *self; + } +} + +impl NfNetlinkDeserializable for u8 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok((buf[0], &buf[1..])) + } +} + +impl NfNetlinkAttribute for u16 { + unsafe fn write_payload(&self, addr: *mut u8) { + *(addr as *mut Self) = self.to_be(); + } +} + +impl NfNetlinkDeserializable for u16 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok((u16::from_be_bytes([buf[0], buf[1]]), &buf[2..])) + } +} + +impl NfNetlinkAttribute for i32 { + unsafe fn write_payload(&self, addr: *mut u8) { + *(addr as *mut Self) = self.to_be(); + } +} + +impl NfNetlinkDeserializable for i32 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok(( + i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), + &buf[4..], + )) + } +} + +impl NfNetlinkAttribute for u32 { + unsafe fn write_payload(&self, addr: *mut u8) { + *(addr as *mut Self) = self.to_be(); + } +} + +impl NfNetlinkDeserializable for u32 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok(( + u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), + &buf[4..], + )) + } +} + +impl NfNetlinkAttribute for u64 { + unsafe fn write_payload(&self, addr: *mut u8) { + *(addr as *mut Self) = self.to_be(); + } +} + +impl NfNetlinkDeserializable for u64 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok(( + u64::from_be_bytes([ + buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], + ]), + &buf[8..], + )) + } +} + +impl NfNetlinkAttribute for String { + fn get_size(&self) -> usize { + self.len() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + std::ptr::copy_nonoverlapping(self.as_bytes().as_ptr(), addr, self.len()); + } +} + +impl NfNetlinkDeserializable for String { + fn deserialize(mut buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + // ignore the NULL byte terminator, if any + if buf.len() > 0 && buf[buf.len() - 1] == 0 { + buf = &buf[..buf.len() - 1]; + } + Ok((String::from_utf8(buf.to_vec())?, &[])) + } +} + +impl NfNetlinkAttribute for Vec<u8> { + fn get_size(&self) -> usize { + self.len() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + std::ptr::copy_nonoverlapping(self.as_ptr(), addr, self.len()); + } +} + +impl NfNetlinkDeserializable for Vec<u8> { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok((buf.to_vec(), &[])) + } +} +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(nested = true)] +pub struct NfNetlinkData { + #[field(NFTA_DATA_VALUE)] + value: Vec<u8>, + #[field(NFTA_DATA_VERDICT)] + verdict: Verdict, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct NfNetlinkList<T> +where + T: NfNetlinkDeserializable + NfNetlinkAttribute + Debug + Clone + Eq + Default, +{ + objs: Vec<T>, +} + +impl<T> NfNetlinkList<T> +where + T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default, +{ + pub fn add_value(&mut self, e: impl Into<T>) { + self.objs.push(e.into()); + } + + pub fn with_value(mut self, e: impl Into<T>) -> Self { + self.add_value(e); + self + } + + pub fn iter<'a>(&'a self) -> impl Iterator<Item = &'a T> { + self.objs.iter() + } +} + +impl<T> NfNetlinkAttribute for NfNetlinkList<T> +where + T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default, +{ + fn is_nested(&self) -> bool { + true + } + + fn get_size(&self) -> usize { + // one nlattr LIST_ELEM per object + self.objs.iter().fold(0, |acc, item| { + acc + item.get_size() + pad_netlink_object::<nlattr>() + }) + } + + unsafe fn write_payload(&self, mut addr: *mut u8) { + for item in &self.objs { + write_attribute(NFTA_LIST_ELEM, item, addr); + addr = addr.offset((pad_netlink_object::<nlattr>() + item.get_size()) as isize); + } + } +} + +impl<T> NfNetlinkDeserializable for NfNetlinkList<T> +where + T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default, +{ + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let mut objs = Vec::new(); + + let mut pos = 0; + while buf.len() - pos > pad_netlink_object::<nlattr>() { + let nlattr = unsafe { *transmute::<*const u8, *const nlattr>(buf[pos..].as_ptr()) }; + // ignore the byteorder and nested attributes + let nla_type = nlattr.nla_type & NLA_TYPE_MASK as u16; + + if nla_type != NFTA_LIST_ELEM { + return Err(DecodeError::UnsupportedAttributeType(nla_type)); + } + + let (obj, remaining) = T::deserialize( + &buf[pos + pad_netlink_object::<nlattr>()..pos + nlattr.nla_len as usize], + )?; + if remaining.len() != 0 { + return Err(DecodeError::InvalidDataSize); + } + objs.push(obj); + + pos += pad_netlink_object_with_variable_size(nlattr.nla_len as usize); + } + + if pos != buf.len() { + Err(DecodeError::InvalidDataSize) + } else { + Ok((Self { objs }, &[])) + } + } +} + +impl<O, T> From<Vec<O>> for NfNetlinkList<T> +where + T: From<O>, + T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default, +{ + fn from(v: Vec<O>) -> Self { + NfNetlinkList { + objs: v.into_iter().map(T::from).collect(), + } + } +} + +impl<T> NfNetlinkDeserializable for T +where + T: NfNetlinkObject + Parsable, +{ + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let (mut obj, nfgenmsg, remaining_data) = Self::parse_object( + buf, + <T as NfNetlinkObject>::MSG_TYPE_ADD, + <T as NfNetlinkObject>::MSG_TYPE_DEL, + )?; + obj.set_family(ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?); + + Ok((obj, remaining_data)) + } +} diff --git a/src/query.rs b/src/query.rs index bc1d02e..7cf5050 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,129 +1,178 @@ -use crate::{nft_nlmsg_maxsize, sys, ProtoFamily}; -use sys::libc; - -/// Returns a buffer containing a netlink message which requests a list of all the netfilter -/// matching objects (e.g. tables, chains, rules, ...). -/// Supply the type of objects to retrieve (e.g. libc::NFT_MSG_GETTABLE), and optionally a callback -/// to execute on the header, to set parameters for example. -/// To pass arbitrary data inside that callback, please use a closure. -pub fn get_list_of_objects<Error>( - seq: u32, - target: u16, - setup_cb: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>, -) -> Result<Vec<u8>, Error> { - let mut buffer = vec![0; nft_nlmsg_maxsize() as usize]; - let hdr = unsafe { - &mut *sys::nftnl_nlmsg_build_hdr( - buffer.as_mut_ptr() as *mut libc::c_char, - target, - ProtoFamily::Unspec as u16, - (libc::NLM_F_ROOT | libc::NLM_F_MATCH) as u16, - seq, - ) - }; - if let Some(cb) = setup_cb { - cb(hdr)?; - } - Ok(buffer) -} - -#[cfg(feature = "query")] -mod inner { - use crate::FinalizedBatch; - - use super::*; - - #[derive(thiserror::Error, Debug)] - pub enum Error { - #[error("Unable to open netlink socket to netfilter")] - NetlinkOpenError(#[source] std::io::Error), - - #[error("Unable to send netlink command to netfilter")] - NetlinkSendError(#[source] std::io::Error), - - #[error("Error while reading from netlink socket")] - NetlinkRecvError(#[source] std::io::Error), +use std::os::unix::prelude::RawFd; + +use nix::sys::socket::{self, AddressFamily, MsgFlags, SockFlag, SockProtocol, SockType}; + +use crate::{ + error::QueryError, + nlmsg::{ + nft_nlmsg_maxsize, pad_netlink_object_with_variable_size, NfNetlinkAttribute, + NfNetlinkObject, NfNetlinkWriter, + }, + parser::{parse_nlmsg, NlMsg}, + sys::{NLM_F_DUMP, NLM_F_MULTI}, + ProtocolFamily, +}; + +pub(crate) fn recv_and_process<'a, T>( + sock: RawFd, + max_seq: Option<u32>, + cb: Option<&dyn Fn(&[u8], &mut T) -> Result<(), QueryError>>, + working_data: &'a mut T, +) -> Result<(), QueryError> { + let mut msg_buffer = vec![0; 2 * nft_nlmsg_maxsize() as usize]; + let mut buf_start = 0; + let mut end_pos = 0; + + loop { + let nb_recv = socket::recv(sock, &mut msg_buffer[end_pos..], MsgFlags::empty()) + .map_err(QueryError::NetlinkRecvError)?; + if nb_recv <= 0 { + return Ok(()); + } + end_pos += nb_recv; + loop { + let buf = &msg_buffer.as_slice()[buf_start..end_pos]; + // exit the loop and try to receive further messages when we consumed all the buffer + if buf.len() == 0 { + break; + } - #[error("Error while processing an incoming netlink message")] - ProcessNetlinkError(#[source] std::io::Error), + debug!("Calling parse_nlmsg"); + let (nlmsghdr, msg) = parse_nlmsg(&buf)?; + debug!("Got a valid netlink message: {:?} {:?}", nlmsghdr, msg); + + match msg { + NlMsg::Done => { + return Ok(()); + } + NlMsg::Error(e) => { + if e.error != 0 { + return Err(QueryError::NetlinkError(e)); + } + } + NlMsg::Noop => {} + NlMsg::NfGenMsg(_genmsg, _data) => { + if let Some(cb) = cb { + cb(&buf[0..nlmsghdr.nlmsg_len as usize], working_data)?; + } + } + } - #[error("Custom error when customizing the query")] - InitError(#[from] Box<dyn std::error::Error + 'static>), + // we cannot know when a sequence of messages will end if the messages do not end + // with an NlMsg::Done marker while if a maximum sequence number wasn't specified + if max_seq.is_none() && nlmsghdr.nlmsg_flags & NLM_F_MULTI as u16 == 0 { + return Err(QueryError::UndecidableMessageTermination); + } - #[error("Couldn't allocate a netlink object, out of memory ?")] - NetlinkAllocationFailed, - } + // retrieve the next message + if let Some(max_seq) = max_seq { + if nlmsghdr.nlmsg_seq >= max_seq { + return Ok(()); + } + } - /// Lists objects of a certain type (e.g. libc::NFT_MSG_GETTABLE) with the help of a helper - /// function called by mnl::cb_run2. - /// The callback expects a tuple of additional data (supplied as an argument to this function) - /// and of the output vector, to which it should append the parsed object it received. - pub fn list_objects_with_data<'a, A, T>( - data_type: u16, - cb: fn(&libc::nlmsghdr, &mut (&'a A, &mut Vec<T>)) -> libc::c_int, - additional_data: &'a A, - req_hdr_customize: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>, - ) -> Result<Vec<T>, Error> - where - T: 'a, - { - debug!("listing objects of kind {}", data_type); - let socket = mnl::Socket::new(mnl::Bus::Netfilter).map_err(Error::NetlinkOpenError)?; - - let seq = 0; - let portid = 0; - - let chains_buf = get_list_of_objects(seq, data_type, req_hdr_customize)?; - socket.send(&chains_buf).map_err(Error::NetlinkSendError)?; - - let mut res = Vec::new(); - - let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize]; - while socket - .recv(&mut msg_buffer) - .map_err(Error::NetlinkRecvError)? - > 0 - { - if let mnl::CbResult::Stop = mnl::cb_run2( - &msg_buffer, - seq, - portid, - cb, - &mut (additional_data, &mut res), - ) - .map_err(Error::ProcessNetlinkError)? - { - break; + // netlink messages are 4bytes aligned + let aligned_length = pad_netlink_object_with_variable_size(nlmsghdr.nlmsg_len as usize); + buf_start += aligned_length; + } + // Ensure that we always have nft_nlmsg_maxsize() free space available in the buffer. + // We achieve this by relocating the buffer content at the beginning of the buffer + if end_pos >= nft_nlmsg_maxsize() as usize { + if buf_start < end_pos { + unsafe { + std::ptr::copy( + msg_buffer[buf_start..end_pos].as_ptr(), + msg_buffer.as_mut_ptr(), + end_pos - buf_start, + ); + } } + end_pos = end_pos - buf_start; + buf_start = 0; } - - Ok(res) } +} - pub fn send_batch(batch: &mut FinalizedBatch) -> Result<(), Error> { - let socket = mnl::Socket::new(mnl::Bus::Netfilter).map_err(Error::NetlinkOpenError)?; +pub(crate) fn socket_close_wrapper<E>( + sock: RawFd, + cb: impl FnOnce(RawFd) -> Result<(), E>, +) -> Result<(), QueryError> +where + QueryError: From<E>, +{ + let ret = cb(sock); - let seq = 0; - let portid = socket.portid(); + // we don't need to shutdown the socket (in fact, Linux doesn't support that operation; + // and return EOPNOTSUPP if we try) + nix::unistd::close(sock).map_err(QueryError::CloseFailed)?; - socket.send_all(batch).map_err(Error::NetlinkSendError)?; - debug!("sent"); + Ok(ret?) +} - let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize]; - while socket - .recv(&mut msg_buffer) - .map_err(Error::NetlinkRecvError)? - > 0 - { - if let mnl::CbResult::Stop = - mnl::cb_run(&msg_buffer, seq, portid).map_err(Error::ProcessNetlinkError)? - { - break; - } +/// Returns a buffer containing a netlink message which requests a list of all the netfilter +/// matching objects (e.g. tables, chains, rules, ...). +/// Supply the type of objects to retrieve (e.g. libc::NFT_MSG_GETTABLE), and a search filter. +pub fn get_list_of_objects<T: NfNetlinkAttribute>( + msg_type: u16, + seq: u32, + filter: Option<&T>, +) -> Result<Vec<u8>, QueryError> { + let mut buffer = Vec::new(); + let mut writer = NfNetlinkWriter::new(&mut buffer); + writer.write_header( + msg_type, + ProtocolFamily::Unspec, + NLM_F_DUMP as u16, + seq, + None, + ); + if let Some(filter) = filter { + let buf = writer.add_data_zeroed(filter.get_size()); + unsafe { + filter.write_payload(buf.as_mut_ptr()); } - Ok(()) } + writer.finalize_writing_object(); + Ok(buffer) } -#[cfg(feature = "query")] -pub use inner::*; +/// Lists objects of a certain type (e.g. libc::NFT_MSG_GETTABLE) with the help of a helper +/// function called by mnl::cb_run2. +/// The callback expects a tuple of additional data (supplied as an argument to this function) +/// and of the output vector, to which it should append the parsed object it received. +pub fn list_objects_with_data<'a, Object, Accumulator>( + data_type: u16, + cb: &dyn Fn(Object, &mut Accumulator) -> Result<(), QueryError>, + filter: Option<&Object>, + working_data: &'a mut Accumulator, +) -> Result<(), QueryError> +where + Object: NfNetlinkObject + NfNetlinkAttribute, +{ + debug!("Listing objects of kind {}", data_type); + let sock = socket::socket( + AddressFamily::Netlink, + SockType::Raw, + SockFlag::empty(), + SockProtocol::NetlinkNetFilter, + ) + .map_err(QueryError::NetlinkOpenError)?; + + let seq = 0; + + let chains_buf = get_list_of_objects(data_type, seq, filter)?; + socket::send(sock, &chains_buf, MsgFlags::empty()).map_err(QueryError::NetlinkSendError)?; + + socket_close_wrapper(sock, move |sock| { + // the kernel should return NLM_F_MULTI objects + recv_and_process( + sock, + None, + Some(&|buf: &[u8], working_data: &mut Accumulator| { + debug!("Calling Object::deserialize()"); + cb(Object::deserialize(buf)?.0, working_data) + }), + working_data, + ) + }) +} diff --git a/src/rule.rs b/src/rule.rs index 2ee5308..858b9ce 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -1,341 +1,111 @@ -use crate::expr::ExpressionWrapper; -use crate::{chain::Chain, expr::Expression, MsgType}; -use crate::sys::{self, libc}; -use std::ffi::{c_void, CStr, CString}; use std::fmt::Debug; -use std::os::raw::c_char; -use std::rc::Rc; + +use rustables_macros::nfnetlink_struct; + +use crate::chain::Chain; +use crate::error::{BuilderError, QueryError}; +use crate::expr::{ExpressionList, RawExpression}; +use crate::nlmsg::NfNetlinkObject; +use crate::query::list_objects_with_data; +use crate::sys::{ + NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, NFTA_RULE_HANDLE, NFTA_RULE_ID, NFTA_RULE_POSITION, + NFTA_RULE_TABLE, NFTA_RULE_USERDATA, NFT_MSG_DELRULE, NFT_MSG_NEWRULE, NLM_F_APPEND, + NLM_F_CREATE, +}; +use crate::{Batch, ProtocolFamily}; /// A nftables firewall rule. +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(derive_deserialize = false)] pub struct Rule { - pub(crate) rule: *mut sys::nftnl_rule, - pub(crate) chain: Rc<Chain>, + family: ProtocolFamily, + #[field(NFTA_RULE_TABLE)] + table: String, + #[field(NFTA_RULE_CHAIN)] + chain: String, + #[field(NFTA_RULE_HANDLE)] + handle: u64, + #[field(NFTA_RULE_EXPRESSIONS)] + expressions: ExpressionList, + #[field(NFTA_RULE_POSITION)] + position: u64, + #[field(NFTA_RULE_USERDATA)] + userdata: Vec<u8>, + #[field(NFTA_RULE_ID)] + id: u32, } impl Rule { /// Creates a new rule object in the given [`Chain`]. /// /// [`Chain`]: struct.Chain.html - pub fn new(chain: Rc<Chain>) -> Rule { - unsafe { - let rule = try_alloc!(sys::nftnl_rule_alloc()); - sys::nftnl_rule_set_u32( - rule, - sys::NFTNL_RULE_FAMILY as u16, - chain.get_table().get_family() as u32, - ); - sys::nftnl_rule_set_str( - rule, - sys::NFTNL_RULE_TABLE as u16, - chain.get_table().get_name().as_ptr(), - ); - sys::nftnl_rule_set_str( - rule, - sys::NFTNL_RULE_CHAIN as u16, - chain.get_name().as_ptr(), - ); - - Rule { rule, chain } - } - } - - pub unsafe fn from_raw(rule: *mut sys::nftnl_rule, chain: Rc<Chain>) -> Self { - Rule { rule, chain } - } - - pub fn get_position(&self) -> u64 { - unsafe { sys::nftnl_rule_get_u64(self.rule, sys::NFTNL_RULE_POSITION as u16) } - } - - /// Sets the position of this rule within the chain it lives in. By default a new rule is added - /// to the end of the chain. - pub fn set_position(&mut self, position: u64) { - unsafe { - sys::nftnl_rule_set_u64(self.rule, sys::NFTNL_RULE_POSITION as u16, position); - } - } - - pub fn get_handle(&self) -> u64 { - unsafe { sys::nftnl_rule_get_u64(self.rule, sys::NFTNL_RULE_HANDLE as u16) } - } - - pub fn set_handle(&mut self, handle: u64) { - unsafe { - sys::nftnl_rule_set_u64(self.rule, sys::NFTNL_RULE_HANDLE as u16, handle); - } - } - - /// Adds an expression to this rule. Expressions are evaluated from first to last added. - /// As soon as an expression does not match the packet it's being evaluated for, evaluation - /// stops and the packet is evaluated against the next rule in the chain. - pub fn add_expr(&mut self, expr: &impl Expression) { - unsafe { sys::nftnl_rule_add_expr(self.rule, expr.to_expr(self)) } - } - - /// Returns a reference to the [`Chain`] this rule lives in. - /// - /// [`Chain`]: struct.Chain.html - pub fn get_chain(&self) -> Rc<Chain> { - self.chain.clone() - } - - /// Returns the userdata of this chain. - pub fn get_userdata(&self) -> Option<&CStr> { - unsafe { - let ptr = sys::nftnl_rule_get_str(self.rule, sys::NFTNL_RULE_USERDATA as u16); - if !ptr.is_null() { - Some(CStr::from_ptr(ptr)) - } else { - None - } - } - } - - /// Updates the userdata of this chain. - pub fn set_userdata(&self, data: &CStr) { - unsafe { - sys::nftnl_rule_set_str(self.rule, sys::NFTNL_RULE_USERDATA as u16, data.as_ptr()); - } - } - - /// Returns a textual description of the rule. - pub fn get_str(&self) -> CString { - let mut descr_buf = vec![0i8; 4096]; - unsafe { - sys::nftnl_rule_snprintf( - descr_buf.as_mut_ptr() as *mut c_char, - (descr_buf.len() - 1) as u64, - self.rule, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); - CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned() - } - } - - /// Retrieves an iterator to loop over the expressions of the rule. - pub fn get_exprs(self: &Rc<Self>) -> RuleExprsIter { - RuleExprsIter::new(self.clone()) - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns the raw handle. - pub fn as_ptr(&self) -> *const sys::nftnl_rule { - self.rule as *const sys::nftnl_rule - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns a mutable version of the raw handle. - pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_rule { - self.rule - } - - /// Performs a deep comparizon of rules, by checking they have the same expressions inside. - /// This is not enabled by default in our PartialEq implementation because of the difficulty to - /// compare an expression generated by the library with the expressions returned by the kernel - /// when iterating over the currently in-use rules. The kernel-returned expressions may have - /// additional attributes despite being generated from the same rule. This is particularly true - /// for the 'nat' expression). - pub fn deep_eq(&self, other: &Self) -> bool { - if self != other { - return false; - } - - let self_exprs = - try_alloc!(unsafe { sys::nftnl_expr_iter_create(self.rule as *const sys::nftnl_rule) }); - let other_exprs = try_alloc!(unsafe { - sys::nftnl_expr_iter_create(other.rule as *const sys::nftnl_rule) - }); - - loop { - let self_next = unsafe { sys::nftnl_expr_iter_next(self_exprs) }; - let other_next = unsafe { sys::nftnl_expr_iter_next(other_exprs) }; - if self_next.is_null() && other_next.is_null() { - return true; - } else if self_next.is_null() || other_next.is_null() { - return false; - } - - // we are falling back on comparing the strings, because there is no easy mechanism to - // perform a memcmp() between the two expressions :/ - let mut self_str = [0; 256]; - let mut other_str = [0; 256]; - unsafe { - sys::nftnl_expr_snprintf( - self_str.as_mut_ptr(), - (self_str.len() - 1) as u64, - self_next, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); - sys::nftnl_expr_snprintf( - other_str.as_mut_ptr(), - (other_str.len() - 1) as u64, - other_next, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); + pub fn new(chain: &Chain) -> Result<Rule, BuilderError> { + Ok(Rule::default() + .with_family(chain.get_family()) + .with_table( + chain + .get_table() + .ok_or(BuilderError::MissingChainInformationError)?, + ) + .with_chain( + chain + .get_name() + .ok_or(BuilderError::MissingChainInformationError)?, + )) + } + + pub fn add_expr(&mut self, e: impl Into<RawExpression>) { + let exprs = match self.get_mut_expressions() { + Some(x) => x, + None => { + self.set_expressions(ExpressionList::default()); + self.get_mut_expressions().unwrap() } - - if self_str != other_str { - return false; - } - } - } -} - -impl Debug for Rule { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.get_str()) - } -} - -impl PartialEq for Rule { - fn eq(&self, other: &Self) -> bool { - if self.get_chain() != other.get_chain() { - return false; - } - - unsafe { - if sys::nftnl_rule_is_set(self.rule, sys::NFTNL_RULE_HANDLE as u16) - && sys::nftnl_rule_is_set(other.rule, sys::NFTNL_RULE_HANDLE as u16) - { - if self.get_handle() != other.get_handle() { - return false; - } - } - if sys::nftnl_rule_is_set(self.rule, sys::NFTNL_RULE_POSITION as u16) - && sys::nftnl_rule_is_set(other.rule, sys::NFTNL_RULE_POSITION as u16) - { - if self.get_position() != other.get_position() { - return false; - } - } - } - - return false; - } -} - -unsafe impl crate::NlMsg for Rule { - unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { - let type_ = match msg_type { - MsgType::Add => libc::NFT_MSG_NEWRULE, - MsgType::Del => libc::NFT_MSG_DELRULE, }; - let flags: u16 = match msg_type { - MsgType::Add => (libc::NLM_F_CREATE | libc::NLM_F_APPEND | libc::NLM_F_EXCL) as u16, - MsgType::Del => 0u16, - } | libc::NLM_F_ACK as u16; - let header = sys::nftnl_nlmsg_build_hdr( - buf as *mut c_char, - type_ as u16, - self.chain.get_table().get_family() as u16, - flags, - seq, - ); - sys::nftnl_rule_nlmsg_build_payload(header, self.rule); + exprs.add_value(e); } -} -impl Drop for Rule { - fn drop(&mut self) { - unsafe { sys::nftnl_rule_free(self.rule) }; + pub fn with_expr(mut self, e: impl Into<RawExpression>) -> Self { + self.add_expr(e); + self } -} - -pub struct RuleExprsIter { - rule: Rc<Rule>, - iter: *mut sys::nftnl_expr_iter, -} -impl RuleExprsIter { - fn new(rule: Rc<Rule>) -> Self { - let iter = - try_alloc!(unsafe { sys::nftnl_expr_iter_create(rule.rule as *const sys::nftnl_rule) }); - RuleExprsIter { rule, iter } + /// Appends this rule to `batch` + pub fn add_to_batch(self, batch: &mut Batch) -> Self { + batch.add(&self, crate::MsgType::Add); + self } } -impl Iterator for RuleExprsIter { - type Item = ExpressionWrapper; +impl NfNetlinkObject for Rule { + const MSG_TYPE_ADD: u32 = NFT_MSG_NEWRULE; + const MSG_TYPE_DEL: u32 = NFT_MSG_DELRULE; - fn next(&mut self) -> Option<Self::Item> { - let next = unsafe { sys::nftnl_expr_iter_next(self.iter) }; - if next.is_null() { - trace!("RulesExprsIter iterator ending"); - None - } else { - trace!("RulesExprsIter returning new expression"); - Some(ExpressionWrapper { - expr: next, - rule: self.rule.clone(), - }) - } + fn get_family(&self) -> ProtocolFamily { + self.family } -} -impl Drop for RuleExprsIter { - fn drop(&mut self) { - unsafe { sys::nftnl_expr_iter_destroy(self.iter) }; + fn set_family(&mut self, family: ProtocolFamily) { + self.family = family; } -} - -#[cfg(feature = "query")] -pub fn get_rules_cb( - header: &libc::nlmsghdr, - (chain, rules): &mut (&Rc<Chain>, &mut Vec<Rule>), -) -> libc::c_int { - unsafe { - let rule = sys::nftnl_rule_alloc(); - if rule == std::ptr::null_mut() { - return mnl::mnl_sys::MNL_CB_ERROR; - } - let err = sys::nftnl_rule_nlmsg_parse(header, rule); - if err < 0 { - error!("Failed to parse nelink rule message - {}", err); - sys::nftnl_rule_free(rule); - return err; - } - rules.push(Rule::from_raw(rule, chain.clone())); + // append at the end of the chain, instead of the beginning + fn get_add_flags(&self) -> u32 { + NLM_F_CREATE | NLM_F_APPEND } - mnl::mnl_sys::MNL_CB_OK } -#[cfg(feature = "query")] -pub fn list_rules_for_chain(chain: &Rc<Chain>) -> Result<Vec<Rule>, crate::query::Error> { - crate::query::list_objects_with_data( +pub fn list_rules_for_chain(chain: &Chain) -> Result<Vec<Rule>, QueryError> { + let mut result = Vec::new(); + list_objects_with_data( libc::NFT_MSG_GETRULE as u16, - get_rules_cb, - &chain, - // only retrieve rules from the currently targetted chain - Some(&|hdr| unsafe { - let rule = sys::nftnl_rule_alloc(); - if rule as *const _ == std::ptr::null() { - return Err(crate::query::Error::NetlinkAllocationFailed); - } - - sys::nftnl_rule_set_str( - rule, - sys::NFTNL_RULE_TABLE as u16, - chain.get_table().get_name().as_ptr(), - ); - sys::nftnl_rule_set_u32( - rule, - sys::NFTNL_RULE_FAMILY as u16, - chain.get_table().get_family() as u32, - ); - sys::nftnl_rule_set_str( - rule, - sys::NFTNL_RULE_CHAIN as u16, - chain.get_name().as_ptr(), - ); - - sys::nftnl_rule_nlmsg_build_payload(hdr, rule); - - sys::nftnl_rule_free(rule); + &|rule: Rule, rules: &mut Vec<Rule>| { + rules.push(rule); Ok(()) - }), - ) + }, + // only retrieve rules from the currently targetted chain + Some(&Rule::new(chain)?), + &mut result, + )?; + Ok(result) } diff --git a/src/rule_methods.rs b/src/rule_methods.rs index d7145d7..dff9bf6 100644 --- a/src/rule_methods.rs +++ b/src/rule_methods.rs @@ -1,230 +1,211 @@ -use crate::{Batch, Rule, nft_expr, sys::libc}; -use crate::expr::{LogGroup, LogPrefix}; -use ipnetwork::IpNetwork; -use std::ffi::{CString, NulError}; +use std::ffi::CString; use std::net::IpAddr; -use std::num::ParseIntError; - -#[derive(thiserror::Error, Debug)] -pub enum Error { - #[error("Unable to open netlink socket to netfilter")] - NetlinkOpenError(#[source] std::io::Error), - #[error("Firewall is already started")] - AlreadyDone, - #[error("Error converting from a C string to a string")] - CStringError(#[from] NulError), - #[error("no interface found under that name")] - NoSuchIface, - #[error("Error converting from a string to an integer")] - ParseError(#[from] ParseIntError), - #[error("the interface name is too long")] - NameTooLong, -} +use ipnetwork::IpNetwork; +use crate::data_type::ip_to_vec; +use crate::error::BuilderError; +use crate::expr::ct::{ConnTrackState, Conntrack, ConntrackKey}; +use crate::expr::{ + Bitwise, Cmp, CmpOp, HighLevelPayload, IPv4HeaderField, IPv6HeaderField, Immediate, Meta, + MetaType, NetworkHeaderField, TCPHeaderField, TransportHeaderField, UDPHeaderField, + VerdictKind, +}; +use crate::Rule; /// Simple protocol description. Note that it does not implement other layer 4 protocols as /// IGMP et al. See [`Rule::igmp`] for a workaround. -#[derive(Debug, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum Protocol { TCP, - UDP + UDP, } -/// A RuleMethods trait over [`crate::Rule`], to make it match some criteria, and give it a -/// verdict. Mostly adapted from [talpid-core's firewall]. -/// All methods return the rule itself, allowing them to be chained. Usage example : -/// ```rust -/// use rustables::{Batch, Chain, ChainMethods, Protocol, ProtoFamily, Rule, RuleMethods, Table, MsgType, Hook}; -/// use std::ffi::CString; -/// use std::rc::Rc; -/// let table = Rc::new(Table::new(&CString::new("main_table").unwrap(), ProtoFamily::Inet)); -/// let mut batch = Batch::new(); -/// batch.add(&table, MsgType::Add); -/// let inbound = Rc::new(Chain::from_hook(Hook::In, Rc::clone(&table)) -/// .add_to_batch(&mut batch)); -/// let rule = Rule::new(inbound) -/// .dport("80", &Protocol::TCP).unwrap() -/// .accept() -/// .add_to_batch(&mut batch); -/// ``` -/// [talpid-core's firewall]: -/// https://github.com/mullvad/mullvadvpn-app/blob/d92376b4d1df9b547930c68aa9bae9640ff2a022/talpid-core/src/firewall/linux.rs -pub trait RuleMethods { - /// Matches ICMP packets. - fn icmp(self) -> Self; - /// Matches IGMP packets. - fn igmp(self) -> Self; - /// Matches packets to destination `port` and `protocol`. - fn dport(self, port: &str, protocol: &Protocol) -> Result<Self, Error> - where Self: std::marker::Sized; - /// Matches packets on `protocol`. - fn protocol(self, protocol: Protocol) -> Result<Self, Error> - where Self: std::marker::Sized; - /// Matches packets in an already established connection. - fn established(self) -> Self where Self: std::marker::Sized; - /// Matches packets going through `iface_index`. Interface indexes can be queried with - /// `iface_index()`. - fn iface_id(self, iface_index: libc::c_uint) -> Result<Self, Error> - where Self: std::marker::Sized; - /// Matches packets going through `iface_name`, an interface name, as in "wlan0" or "lo". - fn iface(self, iface_name: &str) -> Result<Self, Error> - where Self: std::marker::Sized; - /// Adds a log instruction to the rule. `group` is the NFLog group, `prefix` is a prefix - /// appended to each log line. - fn log(self, group: Option<LogGroup>, prefix: Option<LogPrefix>) -> Self; - /// Matches packets whose source IP address is `saddr`. - fn saddr(self, ip: IpAddr) -> Self; - /// Matches packets whose source network is `snet`. - fn snetwork(self, ip: IpNetwork) -> Self; - /// Adds the `Accept` verdict to the rule. The packet will be sent to destination. - fn accept(self) -> Self; - /// Adds the `Drop` verdict to the rule. The packet will be dropped. - fn drop(self) -> Self; - /// Appends this rule to `batch`. - fn add_to_batch(self, batch: &mut Batch) -> Self; -} - -/// A trait to add helper functions to match some criterium over `crate::Rule`. -impl RuleMethods for Rule { - fn icmp(mut self) -> Self { - self.add_expr(&nft_expr!(meta l4proto)); - //self.add_expr(&nft_expr!(cmp == libc::IPPROTO_ICMPV6 as u8)); - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_ICMP as u8)); - self - } - fn igmp(mut self) -> Self { - self.add_expr(&nft_expr!(meta l4proto)); - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_IGMP as u8)); +impl Rule { + fn match_port(mut self, port: u16, protocol: Protocol, source: bool) -> Self { + self = self.protocol(protocol); + self.add_expr( + HighLevelPayload::Transport(match protocol { + Protocol::TCP => TransportHeaderField::Tcp(if source { + TCPHeaderField::Sport + } else { + TCPHeaderField::Dport + }), + Protocol::UDP => TransportHeaderField::Udp(if source { + UDPHeaderField::Sport + } else { + UDPHeaderField::Dport + }), + }) + .build(), + ); + self.add_expr(Cmp::new(CmpOp::Eq, port.to_be_bytes())); self } - fn dport(mut self, port: &str, protocol: &Protocol) -> Result<Self, Error> { - self.add_expr(&nft_expr!(meta l4proto)); - match protocol { - &Protocol::TCP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_TCP as u8)); - self.add_expr(&nft_expr!(payload tcp dport)); - }, - &Protocol::UDP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_UDP as u8)); - self.add_expr(&nft_expr!(payload udp dport)); - } - } - // Convert the port to Big-Endian number spelling. - // See https://github.com/mullvad/mullvadvpn-app/blob/d92376b4d1df9b547930c68aa9bae9640ff2a022/talpid-core/src/firewall/linux.rs#L969 - self.add_expr(&nft_expr!(cmp == port.parse::<u16>()?.to_be())); - Ok(self) - } - fn protocol(mut self, protocol: Protocol) -> Result<Self, Error> { - self.add_expr(&nft_expr!(meta l4proto)); - match protocol { - Protocol::TCP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_TCP as u8)); - }, - Protocol::UDP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_UDP as u8)); - } - } - Ok(self) - } - fn established(mut self) -> Self { - let allowed_states = crate::expr::ct::States::ESTABLISHED.bits(); - self.add_expr(&nft_expr!(ct state)); - self.add_expr(&nft_expr!(bitwise mask allowed_states, xor 0u32)); - self.add_expr(&nft_expr!(cmp != 0u32)); - self - } - fn iface_id(mut self, iface_index: libc::c_uint) -> Result<Self, Error> { - self.add_expr(&nft_expr!(meta iif)); - self.add_expr(&nft_expr!(cmp == iface_index)); - Ok(self) - } - fn iface(mut self, iface_name: &str) -> Result<Self, Error> { - if iface_name.len() >= libc::IFNAMSIZ { - return Err(Error::NameTooLong); - } - let mut name_arr = [0u8; libc::IFNAMSIZ]; - for (pos, i) in iface_name.bytes().enumerate() { - name_arr[pos] = i; - } - self.add_expr(&nft_expr!(meta iifname)); - self.add_expr(&nft_expr!(cmp == name_arr.as_ref())); - Ok(self) - } - fn saddr(mut self, ip: IpAddr) -> Self { - self.add_expr(&nft_expr!(meta nfproto)); + pub fn match_ip(mut self, ip: IpAddr, source: bool) -> Self { + self.add_expr(Meta::new(MetaType::NfProto)); match ip { IpAddr::V4(addr) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV4 as u8)); - self.add_expr(&nft_expr!(payload ipv4 saddr)); - self.add_expr(&nft_expr!(cmp == addr)) - }, + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV4 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv4(if source { + IPv4HeaderField::Saddr + } else { + IPv4HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Cmp::new(CmpOp::Eq, addr.octets())); + } IpAddr::V6(addr) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV6 as u8)); - self.add_expr(&nft_expr!(payload ipv6 saddr)); - self.add_expr(&nft_expr!(cmp == addr)) + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv6(if source { + IPv6HeaderField::Saddr + } else { + IPv6HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Cmp::new(CmpOp::Eq, addr.octets())); } } self } - fn snetwork(mut self, net: IpNetwork) -> Self { - self.add_expr(&nft_expr!(meta nfproto)); + + pub fn match_network(mut self, net: IpNetwork, source: bool) -> Result<Self, BuilderError> { + self.add_expr(Meta::new(MetaType::NfProto)); match net { IpNetwork::V4(_) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV4 as u8)); - self.add_expr(&nft_expr!(payload ipv4 saddr)); - self.add_expr(&nft_expr!(bitwise mask net.mask(), xor 0u32)); - self.add_expr(&nft_expr!(cmp == net.network())); - }, + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV4 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv4(if source { + IPv4HeaderField::Saddr + } else { + IPv4HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Bitwise::new(ip_to_vec(net.mask()), 0u32.to_be_bytes())?); + } IpNetwork::V6(_) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV6 as u8)); - self.add_expr(&nft_expr!(payload ipv6 saddr)); - self.add_expr(&nft_expr!(bitwise mask net.mask(), xor &[0u16; 8][..])); - self.add_expr(&nft_expr!(cmp == net.network())); + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv6(if source { + IPv6HeaderField::Saddr + } else { + IPv6HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Bitwise::new(ip_to_vec(net.mask()), 0u128.to_be_bytes())?); } } + self.add_expr(Cmp::new(CmpOp::Eq, ip_to_vec(net.network()))); + Ok(self) + } +} + +impl Rule { + /// Matches ICMP packets. + pub fn icmp(mut self) -> Self { + // quid of icmpv6? + self.add_expr(Meta::new(MetaType::L4Proto)); + self.add_expr(Cmp::new(CmpOp::Eq, [libc::IPPROTO_ICMP as u8])); self } - fn log(mut self, group: Option<LogGroup>, prefix: Option<LogPrefix>) -> Self { - match (group.is_some(), prefix.is_some()) { - (true, true) => { - self.add_expr(&nft_expr!(log group group prefix prefix)); - }, - (false, true) => { - self.add_expr(&nft_expr!(log prefix prefix)); - }, - (true, false) => { - self.add_expr(&nft_expr!(log group group)); - }, - (false, false) => { - self.add_expr(&nft_expr!(log)); - } - } + /// Matches IGMP packets. + pub fn igmp(mut self) -> Self { + self.add_expr(Meta::new(MetaType::L4Proto)); + self.add_expr(Cmp::new(CmpOp::Eq, [libc::IPPROTO_IGMP as u8])); self } - fn accept(mut self) -> Self { - self.add_expr(&nft_expr!(verdict accept)); + /// Matches packets from source `port` and `protocol`. + pub fn sport(self, port: u16, protocol: Protocol) -> Self { + self.match_port(port, protocol, false) + } + /// Matches packets to destination `port` and `protocol`. + pub fn dport(self, port: u16, protocol: Protocol) -> Self { + self.match_port(port, protocol, false) + } + /// Matches packets on `protocol`. + pub fn protocol(mut self, protocol: Protocol) -> Self { + self.add_expr(Meta::new(MetaType::L4Proto)); + self.add_expr(Cmp::new( + CmpOp::Eq, + [match protocol { + Protocol::TCP => libc::IPPROTO_TCP, + Protocol::UDP => libc::IPPROTO_UDP, + } as u8], + )); + self + } + /// Matches packets in an already established connection. + pub fn established(mut self) -> Result<Self, BuilderError> { + let allowed_states = ConnTrackState::ESTABLISHED.bits(); + self.add_expr(Conntrack::new(ConntrackKey::State)); + self.add_expr(Bitwise::new( + allowed_states.to_le_bytes(), + 0u32.to_be_bytes(), + )?); + self.add_expr(Cmp::new(CmpOp::Neq, 0u32.to_be_bytes())); + Ok(self) + } + /// Matches packets going through `iface_index`. Interface indexes can be queried with + /// `iface_index()`. + pub fn iface_id(mut self, iface_index: libc::c_uint) -> Self { + self.add_expr(Meta::new(MetaType::Iif)); + self.add_expr(Cmp::new(CmpOp::Eq, iface_index.to_be_bytes())); self } - fn drop(mut self) -> Self { - self.add_expr(&nft_expr!(verdict drop)); + /// Matches packets going through `iface_name`, an interface name, as in "wlan0" or "lo" + pub fn iface(mut self, iface_name: &str) -> Result<Self, BuilderError> { + if iface_name.len() >= libc::IFNAMSIZ { + return Err(BuilderError::InterfaceNameTooLong); + } + let mut iface_vec = iface_name.as_bytes().to_vec(); + // null terminator + iface_vec.push(0u8); + + self.add_expr(Meta::new(MetaType::IifName)); + self.add_expr(Cmp::new(CmpOp::Eq, iface_vec)); + Ok(self) + } + /// Matches packets whose source IP address is `saddr`. + pub fn saddr(self, ip: IpAddr) -> Self { + self.match_ip(ip, true) + } + /// Matches packets whose destination IP address is `saddr`. + pub fn daddr(self, ip: IpAddr) -> Self { + self.match_ip(ip, false) + } + /// Matches packets whose source network is `net`. + pub fn snetwork(self, net: IpNetwork) -> Result<Self, BuilderError> { + self.match_network(net, true) + } + /// Matches packets whose destination network is `net`. + pub fn dnetwork(self, net: IpNetwork) -> Result<Self, BuilderError> { + self.match_network(net, false) + } + /// Adds the `Accept` verdict to the rule. The packet will be sent to destination. + pub fn accept(mut self) -> Self { + self.add_expr(Immediate::new_verdict(VerdictKind::Accept)); self } - fn add_to_batch(self, batch: &mut Batch) -> Self { - batch.add(&self, crate::MsgType::Add); + /// Adds the `Drop` verdict to the rule. The packet will be dropped. + pub fn drop(mut self) -> Self { + self.add_expr(Immediate::new_verdict(VerdictKind::Drop)); self } } /// Looks up the interface index for a given interface name. -pub fn iface_index(name: &str) -> Result<libc::c_uint, Error> { +pub fn iface_index(name: &str) -> Result<libc::c_uint, std::io::Error> { let c_name = CString::new(name)?; let index = unsafe { libc::if_nametoindex(c_name.as_ptr()) }; match index { - 0 => Err(Error::NoSuchIface), - _ => Ok(index) + 0 => Err(std::io::Error::last_os_error()), + _ => Ok(index), } } - - @@ -1,273 +1,116 @@ -use crate::sys::{self, libc}; -use crate::{table::Table, MsgType, ProtoFamily}; -use std::{ - cell::Cell, - ffi::{c_void, CStr, CString}, - fmt::Debug, - net::{Ipv4Addr, Ipv6Addr}, - os::raw::c_char, - rc::Rc, +use rustables_macros::nfnetlink_struct; + +use crate::data_type::DataType; +use crate::error::BuilderError; +use crate::nlmsg::NfNetlinkObject; +use crate::parser_impls::{NfNetlinkData, NfNetlinkList}; +use crate::sys::{ + NFTA_SET_ELEM_KEY, NFTA_SET_ELEM_LIST_ELEMENTS, NFTA_SET_ELEM_LIST_SET, + NFTA_SET_ELEM_LIST_TABLE, NFTA_SET_FLAGS, NFTA_SET_ID, NFTA_SET_KEY_LEN, NFTA_SET_KEY_TYPE, + NFTA_SET_NAME, NFTA_SET_TABLE, NFTA_SET_USERDATA, NFT_MSG_DELSET, NFT_MSG_DELSETELEM, + NFT_MSG_NEWSET, NFT_MSG_NEWSETELEM, }; - -#[macro_export] -macro_rules! nft_set { - ($name:expr, $id:expr, $table:expr, $family:expr) => { - $crate::set::Set::new($name, $id, $table, $family) - }; - ($name:expr, $id:expr, $table:expr, $family:expr; [ ]) => { - nft_set!($name, $id, $table, $family) - }; - ($name:expr, $id:expr, $table:expr, $family:expr; [ $($value:expr,)* ]) => {{ - let mut set = nft_set!($name, $id, $table, $family).expect("Set allocation failed"); - $( - set.add($value).expect(stringify!(Unable to add $value to set $name)); - )* - set - }}; -} - -pub struct Set<K> { - pub(crate) set: *mut sys::nftnl_set, - pub(crate) table: Rc<Table>, - pub(crate) family: ProtoFamily, - _marker: ::std::marker::PhantomData<K>, +use crate::table::Table; +use crate::ProtocolFamily; +use std::fmt::Debug; +use std::marker::PhantomData; + +#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[nfnetlink_struct(derive_deserialize = false)] +pub struct Set { + pub family: ProtocolFamily, + #[field(NFTA_SET_TABLE)] + pub table: String, + #[field(NFTA_SET_NAME)] + pub name: String, + #[field(NFTA_SET_FLAGS)] + pub flags: u32, + #[field(NFTA_SET_KEY_TYPE)] + pub key_type: u32, + #[field(NFTA_SET_KEY_LEN)] + pub key_len: u32, + #[field(NFTA_SET_ID)] + pub id: u32, + #[field(NFTA_SET_USERDATA)] + pub userdata: String, } -impl<K> Set<K> { - pub fn new(name: &CStr, id: u32, table: Rc<Table>, family: ProtoFamily) -> Self - where - K: SetKey, - { - unsafe { - let set = try_alloc!(sys::nftnl_set_alloc()); - - sys::nftnl_set_set_u32(set, sys::NFTNL_SET_FAMILY as u16, family as u32); - sys::nftnl_set_set_str(set, sys::NFTNL_SET_TABLE as u16, table.get_name().as_ptr()); - sys::nftnl_set_set_str(set, sys::NFTNL_SET_NAME as u16, name.as_ptr()); - sys::nftnl_set_set_u32(set, sys::NFTNL_SET_ID as u16, id); - - sys::nftnl_set_set_u32( - set, - sys::NFTNL_SET_FLAGS as u16, - (libc::NFT_SET_ANONYMOUS | libc::NFT_SET_CONSTANT) as u32, - ); - sys::nftnl_set_set_u32(set, sys::NFTNL_SET_KEY_TYPE as u16, K::TYPE); - sys::nftnl_set_set_u32(set, sys::NFTNL_SET_KEY_LEN as u16, K::LEN); - - Set { - set, - table, - family, - _marker: ::std::marker::PhantomData, - } - } - } - - pub unsafe fn from_raw(set: *mut sys::nftnl_set, table: Rc<Table>, family: ProtoFamily) -> Self - where - K: SetKey, - { - Set { - set, - table, - family, - _marker: ::std::marker::PhantomData, - } - } - - pub fn add(&mut self, key: &K) - where - K: SetKey, - { - unsafe { - let elem = try_alloc!(sys::nftnl_set_elem_alloc()); - - let data = key.data(); - let data_len = data.len() as u32; - trace!("Adding key {:?} with len {}", data, data_len); - sys::nftnl_set_elem_set( - elem, - sys::NFTNL_SET_ELEM_KEY as u16, - data.as_ref() as *const _ as *const c_void, - data_len, - ); - sys::nftnl_set_elem_add(self.set, elem); - } - } - - pub fn elems_iter(&self) -> SetElemsIter<K> { - SetElemsIter::new(self) - } +impl NfNetlinkObject for Set { + const MSG_TYPE_ADD: u32 = NFT_MSG_NEWSET; + const MSG_TYPE_DEL: u32 = NFT_MSG_DELSET; - #[cfg(feature = "unsafe-raw-handles")] - /// Returns the raw handle. - pub fn as_ptr(&self) -> *const sys::nftnl_set { - self.set as *const sys::nftnl_set - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns a mutable version of the raw handle. - pub fn as_mut_ptr(&self) -> *mut sys::nftnl_set { - self.set - } - - pub fn get_family(&self) -> ProtoFamily { + fn get_family(&self) -> ProtocolFamily { self.family } - /// Returns a textual description of the set. - pub fn get_str(&self) -> CString { - let mut descr_buf = vec![0i8; 4096]; - unsafe { - sys::nftnl_set_snprintf( - descr_buf.as_mut_ptr() as *mut c_char, - (descr_buf.len() - 1) as u64, - self.set, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); - CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned() - } - } - - pub fn get_name(&self) -> Option<&CStr> { - unsafe { - let ptr = sys::nftnl_set_get_str(self.set, sys::NFTNL_SET_NAME as u16); - if !ptr.is_null() { - Some(CStr::from_ptr(ptr)) - } else { - None - } - } - } - - pub fn get_id(&self) -> u32 { - unsafe { sys::nftnl_set_get_u32(self.set, sys::NFTNL_SET_ID as u16) } + fn set_family(&mut self, family: ProtocolFamily) { + self.family = family; } } -impl<K> Debug for Set<K> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.get_str()) - } -} - -unsafe impl<K> crate::NlMsg for Set<K> { - unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { - let type_ = match msg_type { - MsgType::Add => libc::NFT_MSG_NEWSET, - MsgType::Del => libc::NFT_MSG_DELSET, - }; - let header = sys::nftnl_nlmsg_build_hdr( - buf as *mut c_char, - type_ as u16, - self.table.get_family() as u16, - (libc::NLM_F_APPEND | libc::NLM_F_CREATE | libc::NLM_F_ACK) as u16, - seq, - ); - sys::nftnl_set_nlmsg_build_payload(header, self.set); - } +pub struct SetBuilder<K: DataType> { + inner: Set, + list: SetElementList, + _phantom: PhantomData<K>, } -impl<K> Drop for Set<K> { - fn drop(&mut self) { - unsafe { sys::nftnl_set_free(self.set) }; - } -} - -pub struct SetElemsIter<'a, K> { - set: &'a Set<K>, - iter: *mut sys::nftnl_set_elems_iter, - ret: Rc<Cell<i32>>, -} - -impl<'a, K> SetElemsIter<'a, K> { - fn new(set: &'a Set<K>) -> Self { - let iter = try_alloc!(unsafe { - sys::nftnl_set_elems_iter_create(set.set as *const sys::nftnl_set) +impl<K: DataType> SetBuilder<K> { + pub fn new(name: impl Into<String>, table: &Table) -> Result<Self, BuilderError> { + let table_name = table.get_name().ok_or(BuilderError::MissingTableName)?; + let set_name = name.into(); + let set = Set::default() + .with_key_type(K::TYPE) + .with_key_len(K::LEN) + .with_table(table_name) + .with_name(&set_name); + + Ok(SetBuilder { + inner: set, + list: SetElementList { + table: Some(table_name.clone()), + set: Some(set_name), + elements: Some(SetElementListElements::default()), + }, + _phantom: PhantomData, + }) + } + + pub fn add(&mut self, key: &K) { + self.list.elements.as_mut().unwrap().add_value(SetElement { + key: Some(NfNetlinkData::default().with_value(key.data())), }); - SetElemsIter { - set, - iter, - ret: Rc::new(Cell::new(1)), - } } -} - -impl<'a, K> Iterator for SetElemsIter<'a, K> { - type Item = SetElemsMsg<'a, K>; - fn next(&mut self) -> Option<Self::Item> { - if self.ret.get() <= 0 || unsafe { sys::nftnl_set_elems_iter_cur(self.iter).is_null() } { - trace!("SetElemsIter iterator ending"); - None - } else { - trace!("SetElemsIter returning new SetElemsMsg"); - Some(SetElemsMsg { - set: self.set, - iter: self.iter, - ret: self.ret.clone(), - }) - } + pub fn finish(self) -> (Set, SetElementList) { + (self.inner, self.list) } } -impl<'a, K> Drop for SetElemsIter<'a, K> { - fn drop(&mut self) { - unsafe { sys::nftnl_set_elems_iter_destroy(self.iter) }; - } +#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[nfnetlink_struct(nested = true, derive_deserialize = false)] +pub struct SetElementList { + #[field(NFTA_SET_ELEM_LIST_TABLE)] + pub table: String, + #[field(NFTA_SET_ELEM_LIST_SET)] + pub set: String, + #[field(NFTA_SET_ELEM_LIST_ELEMENTS)] + pub elements: SetElementListElements, } -pub struct SetElemsMsg<'a, K> { - set: &'a Set<K>, - iter: *mut sys::nftnl_set_elems_iter, - ret: Rc<Cell<i32>>, -} +impl NfNetlinkObject for SetElementList { + const MSG_TYPE_ADD: u32 = NFT_MSG_NEWSETELEM; + const MSG_TYPE_DEL: u32 = NFT_MSG_DELSETELEM; -unsafe impl<'a, K> crate::NlMsg for SetElemsMsg<'a, K> { - unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { - trace!("Writing SetElemsMsg to NlMsg"); - let (type_, flags) = match msg_type { - MsgType::Add => ( - libc::NFT_MSG_NEWSETELEM, - libc::NLM_F_CREATE | libc::NLM_F_EXCL | libc::NLM_F_ACK, - ), - MsgType::Del => (libc::NFT_MSG_DELSETELEM, libc::NLM_F_ACK), - }; - let header = sys::nftnl_nlmsg_build_hdr( - buf as *mut c_char, - type_ as u16, - self.set.get_family() as u16, - flags as u16, - seq, - ); - self.ret.set(sys::nftnl_set_elems_nlmsg_build_payload_iter( - header, self.iter, - )); + fn get_family(&self) -> ProtocolFamily { + ProtocolFamily::Unspec } } -pub trait SetKey { - const TYPE: u32; - const LEN: u32; - - fn data(&self) -> Box<[u8]>; -} - -impl SetKey for Ipv4Addr { - const TYPE: u32 = 7; - const LEN: u32 = 4; - - fn data(&self) -> Box<[u8]> { - self.octets().to_vec().into_boxed_slice() - } +#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[nfnetlink_struct(nested = true)] +pub struct SetElement { + #[field(NFTA_SET_ELEM_KEY)] + pub key: NfNetlinkData, } -impl SetKey for Ipv6Addr { - const TYPE: u32 = 8; - const LEN: u32 = 16; - - fn data(&self) -> Box<[u8]> { - self.octets().to_vec().into_boxed_slice() - } -} +type SetElementListElements = NfNetlinkList<SetElement>; diff --git a/src/sys.rs b/src/sys.rs new file mode 100644 index 0000000..4384a1c --- /dev/null +++ b/src/sys.rs @@ -0,0 +1,3 @@ +#![allow(non_camel_case_types, dead_code)] + +include!(concat!(env!("OUT_DIR"), "/sys.rs")); diff --git a/src/table.rs b/src/table.rs index 593fffb..81a26ef 100644 --- a/src/table.rs +++ b/src/table.rs @@ -1,171 +1,68 @@ -use crate::{MsgType, ProtoFamily}; -use crate::sys::{self, libc}; -#[cfg(feature = "query")] -use std::convert::TryFrom; -use std::{ - ffi::{c_void, CStr, CString}, - fmt::Debug, - os::raw::c_char, +use std::fmt::Debug; + +use rustables_macros::nfnetlink_struct; + +use crate::error::QueryError; +use crate::nlmsg::NfNetlinkObject; +use crate::sys::{ + NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFTA_TABLE_USERDATA, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, + NFT_MSG_NEWTABLE, }; +use crate::{Batch, ProtocolFamily}; -/// Abstraction of `nftnl_table`, the top level container in netfilter. A table has a protocol +/// Abstraction of a `nftnl_table`, the top level container in netfilter. A table has a protocol /// family and contains [`Chain`]s that in turn hold the rules. /// /// [`Chain`]: struct.Chain.html +#[derive(Default, PartialEq, Eq, Debug)] +#[nfnetlink_struct(derive_deserialize = false)] pub struct Table { - table: *mut sys::nftnl_table, - family: ProtoFamily, + family: ProtocolFamily, + #[field(NFTA_TABLE_NAME)] + name: String, + #[field(NFTA_TABLE_FLAGS)] + flags: u32, + #[field(NFTA_TABLE_USERDATA)] + userdata: Vec<u8>, } impl Table { - /// Creates a new table instance with the given name and protocol family. - pub fn new<T: AsRef<CStr>>(name: &T, family: ProtoFamily) -> Table { - unsafe { - let table = try_alloc!(sys::nftnl_table_alloc()); - - sys::nftnl_table_set_u32(table, sys::NFTNL_TABLE_FAMILY as u16, family as u32); - sys::nftnl_table_set_str(table, sys::NFTNL_TABLE_NAME as u16, name.as_ref().as_ptr()); - sys::nftnl_table_set_u32(table, sys::NFTNL_TABLE_FLAGS as u16, 0u32); - Table { table, family } - } - } - - pub unsafe fn from_raw(table: *mut sys::nftnl_table, family: ProtoFamily) -> Self { - Table { table, family } - } - - /// Returns the name of this table. - pub fn get_name(&self) -> &CStr { - unsafe { - let ptr = sys::nftnl_table_get_str(self.table, sys::NFTNL_TABLE_NAME as u16); - if ptr.is_null() { - panic!("Impossible situation: retrieving the name of a chain failed") - } else { - CStr::from_ptr(ptr) - } - } - } - - /// Returns a textual description of the table. - pub fn get_str(&self) -> CString { - let mut descr_buf = vec![0i8; 4096]; - unsafe { - sys::nftnl_table_snprintf( - descr_buf.as_mut_ptr() as *mut c_char, - (descr_buf.len() - 1) as u64, - self.table, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); - CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned() - } - } - - /// Returns the protocol family for this table. - pub fn get_family(&self) -> ProtoFamily { - self.family - } - - /// Returns the userdata of this chain. - pub fn get_userdata(&self) -> Option<&CStr> { - unsafe { - let ptr = sys::nftnl_table_get_str(self.table, sys::NFTNL_TABLE_USERDATA as u16); - if !ptr.is_null() { - Some(CStr::from_ptr(ptr)) - } else { - None - } - } + pub fn new(family: ProtocolFamily) -> Table { + let mut res = Self::default(); + res.family = family; + res } - /// Updates the userdata of this chain. - pub fn set_userdata(&self, data: &CStr) { - unsafe { - sys::nftnl_table_set_str(self.table, sys::NFTNL_TABLE_USERDATA as u16, data.as_ptr()); - } - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns the raw handle. - pub fn as_ptr(&self) -> *const sys::nftnl_table { - self.table as *const sys::nftnl_table - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns a mutable version of the raw handle. - pub fn as_mut_ptr(&self) -> *mut sys::nftnl_table { - self.table - } -} - -impl PartialEq for Table { - fn eq(&self, other: &Self) -> bool { - self.get_name() == other.get_name() && self.get_family() == other.get_family() + /// Appends this rule to `batch` + pub fn add_to_batch(self, batch: &mut Batch) -> Self { + batch.add(&self, crate::MsgType::Add); + self } } -impl Debug for Table { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.get_str()) - } -} +impl NfNetlinkObject for Table { + const MSG_TYPE_ADD: u32 = NFT_MSG_NEWTABLE; + const MSG_TYPE_DEL: u32 = NFT_MSG_DELTABLE; -unsafe impl crate::NlMsg for Table { - unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { - let raw_msg_type = match msg_type { - MsgType::Add => libc::NFT_MSG_NEWTABLE, - MsgType::Del => libc::NFT_MSG_DELTABLE, - }; - let header = sys::nftnl_nlmsg_build_hdr( - buf as *mut c_char, - raw_msg_type as u16, - self.family as u16, - libc::NLM_F_ACK as u16, - seq, - ); - sys::nftnl_table_nlmsg_build_payload(header, self.table); - } -} - -impl Drop for Table { - fn drop(&mut self) { - unsafe { sys::nftnl_table_free(self.table) }; + fn get_family(&self) -> ProtocolFamily { + self.family } -} -#[cfg(feature = "query")] -/// A callback to parse the response for messages created with `get_tables_nlmsg`. -pub fn get_tables_cb( - header: &libc::nlmsghdr, - (_, tables): &mut (&(), &mut Vec<Table>), -) -> libc::c_int { - unsafe { - let table = sys::nftnl_table_alloc(); - if table == std::ptr::null_mut() { - return mnl::mnl_sys::MNL_CB_ERROR; - } - let err = sys::nftnl_table_nlmsg_parse(header, table); - if err < 0 { - error!("Failed to parse nelink table message - {}", err); - sys::nftnl_table_free(table); - return err; - } - let family = sys::nftnl_table_get_u32(table, sys::NFTNL_TABLE_FAMILY as u16); - match crate::ProtoFamily::try_from(family as i32) { - Ok(family) => { - tables.push(Table::from_raw(table, family)); - mnl::mnl_sys::MNL_CB_OK - } - Err(crate::InvalidProtocolFamily) => { - error!("The netlink table didn't have a valid protocol family !?"); - sys::nftnl_table_free(table); - mnl::mnl_sys::MNL_CB_ERROR - } - } + fn set_family(&mut self, family: ProtocolFamily) { + self.family = family; } } -#[cfg(feature = "query")] -pub fn list_tables() -> Result<Vec<Table>, crate::query::Error> { - crate::query::list_objects_with_data(libc::NFT_MSG_GETTABLE as u16, get_tables_cb, &(), None) +pub fn list_tables() -> Result<Vec<Table>, QueryError> { + let mut result = Vec::new(); + crate::query::list_objects_with_data( + NFT_MSG_GETTABLE as u16, + &|table: Table, tables: &mut Vec<Table>| { + tables.push(table); + Ok(()) + }, + None, + &mut result, + )?; + Ok(result) } diff --git a/src/tests/batch.rs b/src/tests/batch.rs new file mode 100644 index 0000000..12f373f --- /dev/null +++ b/src/tests/batch.rs @@ -0,0 +1,96 @@ +use std::mem::size_of; + +use libc::{AF_UNSPEC, NFNL_MSG_BATCH_BEGIN, NLM_F_REQUEST}; +use nix::libc::NFNL_MSG_BATCH_END; + +use crate::nlmsg::{pad_netlink_object_with_variable_size, NfNetlinkDeserializable}; +use crate::parser::{parse_nlmsg, NlMsg}; +use crate::sys::{nfgenmsg, nlmsghdr, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES}; +use crate::{Batch, MsgType, Table}; + +use super::get_test_table; + +const HEADER_SIZE: u32 = + pad_netlink_object_with_variable_size(size_of::<nlmsghdr>() + size_of::<nfgenmsg>()) as u32; + +const DEFAULT_BATCH_BEGIN_HDR: nlmsghdr = nlmsghdr { + nlmsg_len: HEADER_SIZE, + nlmsg_flags: NLM_F_REQUEST as u16, + nlmsg_type: NFNL_MSG_BATCH_BEGIN as u16, + nlmsg_seq: 0, + nlmsg_pid: 0, +}; +const DEFAULT_BATCH_MSG: NlMsg = NlMsg::NfGenMsg( + nfgenmsg { + nfgen_family: AF_UNSPEC as u8, + version: NFNETLINK_V0 as u8, + res_id: NFNL_SUBSYS_NFTABLES as u16, + }, + &[], +); + +const DEFAULT_BATCH_END_HDR: nlmsghdr = nlmsghdr { + nlmsg_len: HEADER_SIZE, + nlmsg_flags: NLM_F_REQUEST as u16, + nlmsg_type: NFNL_MSG_BATCH_END as u16, + nlmsg_seq: 1, + nlmsg_pid: 0, +}; + +#[test] +fn batch_empty() { + let batch = Batch::new(); + let buf = batch.finalize(); + + let (hdr, msg) = parse_nlmsg(&buf).expect("Invalid nlmsg message"); + assert_eq!(hdr, DEFAULT_BATCH_BEGIN_HDR); + assert_eq!(msg, DEFAULT_BATCH_MSG); + + let remaining_data_offset = pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize); + + let (hdr, msg) = parse_nlmsg(&buf[remaining_data_offset..]).expect("Invalid nlmsg message"); + assert_eq!(hdr, DEFAULT_BATCH_END_HDR); + assert_eq!(msg, DEFAULT_BATCH_MSG); +} + +#[test] +fn batch_with_objects() { + let mut original_tables = vec![]; + for i in 0..10 { + let mut table = get_test_table(); + table.set_userdata(vec![i as u8]); + original_tables.push(table); + } + + let mut batch = Batch::new(); + for i in 0..10 { + batch.add( + &original_tables[i], + if i % 2 == 0 { + MsgType::Add + } else { + MsgType::Del + }, + ); + } + let buf = batch.finalize(); + + let (hdr, msg) = parse_nlmsg(&buf).expect("Invalid nlmsg message"); + assert_eq!(hdr, DEFAULT_BATCH_BEGIN_HDR); + assert_eq!(msg, DEFAULT_BATCH_MSG); + let mut remaining_data = &buf[pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize)..]; + + for i in 0..10 { + let (deserialized_table, rest) = + Table::deserialize(&remaining_data).expect("could not deserialize a table"); + remaining_data = rest; + + assert_eq!(deserialized_table, original_tables[i]); + } + + let (hdr, msg) = parse_nlmsg(&remaining_data).expect("Invalid nlmsg message"); + let mut end_hdr = DEFAULT_BATCH_END_HDR; + end_hdr.nlmsg_seq = 11; + assert_eq!(hdr, end_hdr); + assert_eq!(msg, DEFAULT_BATCH_MSG); +} diff --git a/src/tests/chain.rs b/src/tests/chain.rs new file mode 100644 index 0000000..7f696e6 --- /dev/null +++ b/src/tests/chain.rs @@ -0,0 +1,120 @@ +use crate::{ + nlmsg::get_operation_from_nlmsghdr_type, + sys::{ + NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_TABLE, NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, + NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN, + }, + ChainType, Hook, HookClass, MsgType, +}; + +use super::{ + get_test_chain, get_test_nlmsg, get_test_nlmsg_with_msg_type, NetlinkExpr, CHAIN_NAME, + CHAIN_USERDATA, TABLE_NAME, +}; + +#[test] +fn new_empty_chain() { + let mut chain = get_test_chain(); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut chain); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWCHAIN as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 52); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_CHAIN_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_CHAIN_NAME, CHAIN_NAME.as_bytes().to_vec()), + ]) + .to_raw() + ); +} + +#[test] +fn new_empty_chain_with_hook_and_type() { + let mut chain = get_test_chain() + .with_hook(Hook::new(HookClass::In, 0)) + .with_type(ChainType::Filter); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut chain); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWCHAIN as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 84); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_CHAIN_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_CHAIN_NAME, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_CHAIN_TYPE, "filter".as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_CHAIN_HOOK, + vec![ + NetlinkExpr::List(vec![NetlinkExpr::Final( + NFTA_HOOK_HOOKNUM, + vec![0, 0, 0, 1] + )]), + NetlinkExpr::List(vec![NetlinkExpr::Final( + NFTA_HOOK_PRIORITY, + vec![0, 0, 0, 0] + )]) + ] + ), + ]) + .to_raw() + ); +} + +#[test] +fn new_empty_chain_with_userdata() { + let mut chain = get_test_chain(); + chain.set_userdata(CHAIN_USERDATA); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut chain); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWCHAIN as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 72); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_CHAIN_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_CHAIN_NAME, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_CHAIN_USERDATA, CHAIN_USERDATA.as_bytes().to_vec()) + ]) + .to_raw() + ); +} + +#[test] +fn delete_empty_chain() { + let mut chain = get_test_chain(); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = + get_test_nlmsg_with_msg_type(&mut buf, &mut chain, MsgType::Del); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_DELCHAIN as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 52); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_CHAIN_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_CHAIN_NAME, CHAIN_NAME.as_bytes().to_vec()), + ]) + .to_raw() + ); +} diff --git a/tests/expr.rs b/src/tests/expr.rs index 4af18f2..35c4fea 100644 --- a/tests/expr.rs +++ b/src/tests/expr.rs @@ -1,53 +1,53 @@ -use rustables::expr::{ - Bitwise, Cmp, CmpOp, Conntrack, Counter, Expression, HeaderField, IcmpCode, Immediate, Log, - LogGroup, LogPrefix, Lookup, Meta, Nat, NatType, Payload, Register, Reject, TcpHeaderField, - TransportHeaderField, Verdict, -}; -use rustables::set::Set; -use rustables::sys::libc::{nlmsghdr, NF_DROP}; -use rustables::{ProtoFamily, Rule}; -use std::ffi::CStr; use std::net::Ipv4Addr; -mod sys; -use sys::*; - -mod lib; -use lib::*; +use libc::NF_DROP; -fn get_test_nlmsg_from_expr( - rule: &mut Rule, - expr: &impl Expression, -) -> (nlmsghdr, Nfgenmsg, Vec<u8>) { - rule.add_expr(expr); +use crate::{ + expr::{ + Bitwise, Cmp, CmpOp, Conntrack, ConntrackKey, Counter, ExpressionList, HeaderField, + HighLevelPayload, IcmpCode, Immediate, Log, Lookup, Masquerade, Meta, MetaType, Nat, + NatType, Register, Reject, RejectType, TCPHeaderField, TransportHeaderField, VerdictKind, + }, + set::SetBuilder, + sys::{ + NFTA_BITWISE_DREG, NFTA_BITWISE_LEN, NFTA_BITWISE_MASK, NFTA_BITWISE_SREG, + NFTA_BITWISE_XOR, NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFTA_COUNTER_BYTES, + NFTA_COUNTER_PACKETS, NFTA_CT_DREG, NFTA_CT_KEY, NFTA_DATA_VALUE, NFTA_DATA_VERDICT, + NFTA_EXPR_DATA, NFTA_EXPR_NAME, NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG, NFTA_LIST_ELEM, + NFTA_LOG_GROUP, NFTA_LOG_PREFIX, NFTA_LOOKUP_SET, NFTA_LOOKUP_SREG, NFTA_META_DREG, + NFTA_META_KEY, NFTA_NAT_FAMILY, NFTA_NAT_REG_ADDR_MIN, NFTA_NAT_TYPE, NFTA_PAYLOAD_BASE, + NFTA_PAYLOAD_DREG, NFTA_PAYLOAD_LEN, NFTA_PAYLOAD_OFFSET, NFTA_REJECT_ICMP_CODE, + NFTA_REJECT_TYPE, NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, NFTA_RULE_TABLE, + NFTA_VERDICT_CODE, NFT_CMP_EQ, NFT_CT_STATE, NFT_META_PROTOCOL, NFT_NAT_SNAT, + NFT_PAYLOAD_TRANSPORT_HEADER, NFT_REG_1, NFT_REG_VERDICT, NFT_REJECT_ICMPX_UNREACH, + }, + tests::{get_test_table, SET_NAME}, + ProtocolFamily, +}; - let (nlmsghdr, nfgenmsg, raw_expr) = get_test_nlmsg(rule); - assert_eq!( - get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), - NFT_MSG_NEWRULE as u8 - ); - (nlmsghdr, nfgenmsg, raw_expr) -} +use super::{get_test_nlmsg, get_test_rule, NetlinkExpr, CHAIN_NAME, TABLE_NAME}; #[test] fn bitwise_expr_is_valid() { let netmask = Ipv4Addr::new(255, 255, 255, 0); - let bitwise = Bitwise::new(netmask, 0); - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &bitwise); + let bitwise = Bitwise::new(netmask.octets(), [0, 0, 0, 0]).unwrap(); + let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(bitwise)); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); assert_eq!(nlmsghdr.nlmsg_len, 124); assert_eq!( raw_expr, NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), NetlinkExpr::Nested( NFTA_RULE_EXPRESSIONS, vec![NetlinkExpr::Nested( NFTA_LIST_ELEM, vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"bitwise\0".to_vec()), + NetlinkExpr::Final(NFTA_EXPR_NAME, b"bitwise".to_vec()), NetlinkExpr::Nested( NFTA_EXPR_DATA, vec![ @@ -86,22 +86,25 @@ fn bitwise_expr_is_valid() { #[test] fn cmp_expr_is_valid() { - let cmp = Cmp::new(CmpOp::Eq, 0); - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &cmp); + let val = [1u8, 2, 3, 4]; + let cmp = Cmp::new(CmpOp::Eq, val.clone()); + let mut rule = get_test_rule().with_expressions(vec![cmp]); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); assert_eq!(nlmsghdr.nlmsg_len, 100); assert_eq!( raw_expr, NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), NetlinkExpr::Nested( NFTA_RULE_EXPRESSIONS, vec![NetlinkExpr::Nested( NFTA_LIST_ELEM, vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"cmp\0".to_vec()), + NetlinkExpr::Final(NFTA_EXPR_NAME, b"cmp".to_vec()), NetlinkExpr::Nested( NFTA_EXPR_DATA, vec![ @@ -109,7 +112,7 @@ fn cmp_expr_is_valid() { NetlinkExpr::Final(NFTA_CMP_OP, NFT_CMP_EQ.to_be_bytes().to_vec()), NetlinkExpr::Nested( NFTA_CMP_DATA, - vec![NetlinkExpr::Final(1u16, 0u32.to_be_bytes().to_vec())] + vec![NetlinkExpr::Final(NFTA_DATA_VALUE, val.to_vec())] ) ] ) @@ -125,25 +128,27 @@ fn cmp_expr_is_valid() { fn counter_expr_is_valid() { let nb_bytes = 123456u64; let nb_packets = 987u64; - let mut counter = Counter::new(); - counter.nb_bytes = nb_bytes; - counter.nb_packets = nb_packets; + let counter = Counter::default() + .with_nb_bytes(nb_bytes) + .with_nb_packets(nb_packets); + + let mut rule = get_test_rule().with_expressions(vec![counter]); - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &counter); + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); assert_eq!(nlmsghdr.nlmsg_len, 100); assert_eq!( raw_expr, NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), NetlinkExpr::Nested( NFTA_RULE_EXPRESSIONS, vec![NetlinkExpr::Nested( NFTA_LIST_ELEM, vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"counter\0".to_vec()), + NetlinkExpr::Final(NFTA_EXPR_NAME, b"counter".to_vec()), NetlinkExpr::Nested( NFTA_EXPR_DATA, vec![ @@ -167,22 +172,24 @@ fn counter_expr_is_valid() { #[test] fn ct_expr_is_valid() { - let ct = Conntrack::State; - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &ct); + let ct = Conntrack::default().with_retrieve_value(ConntrackKey::State); + let mut rule = get_test_rule().with_expressions(vec![ct]); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); assert_eq!(nlmsghdr.nlmsg_len, 88); assert_eq!( raw_expr, NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), NetlinkExpr::Nested( NFTA_RULE_EXPRESSIONS, vec![NetlinkExpr::Nested( NFTA_LIST_ELEM, vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"ct\0".to_vec()), + NetlinkExpr::Final(NFTA_EXPR_NAME, b"ct".to_vec()), NetlinkExpr::Nested( NFTA_EXPR_DATA, vec![ @@ -203,22 +210,25 @@ fn ct_expr_is_valid() { #[test] fn immediate_expr_is_valid() { - let immediate = Immediate::new(42u8, Register::Reg1); - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &immediate); + let immediate = Immediate::new_data(vec![42u8], Register::Reg1); + let mut rule = + get_test_rule().with_expressions(ExpressionList::default().with_value(immediate)); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); assert_eq!(nlmsghdr.nlmsg_len, 100); assert_eq!( raw_expr, NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), NetlinkExpr::Nested( NFTA_RULE_EXPRESSIONS, vec![NetlinkExpr::Nested( NFTA_LIST_ELEM, vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"immediate\0".to_vec()), + NetlinkExpr::Final(NFTA_EXPR_NAME, b"immediate".to_vec()), NetlinkExpr::Nested( NFTA_EXPR_DATA, vec![ @@ -242,30 +252,29 @@ fn immediate_expr_is_valid() { #[test] fn log_expr_is_valid() { - let log = Log { - group: Some(LogGroup(1)), - prefix: Some(LogPrefix::new("mockprefix").unwrap()), - }; - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &log); + let log = Log::new(Some(1337), Some("mockprefix")).expect("Could not build a log expression"); + let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(log)); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); assert_eq!(nlmsghdr.nlmsg_len, 96); assert_eq!( raw_expr, NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), NetlinkExpr::Nested( NFTA_RULE_EXPRESSIONS, vec![NetlinkExpr::Nested( NFTA_LIST_ELEM, vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"log\0".to_vec()), + NetlinkExpr::Final(NFTA_EXPR_NAME, b"log".to_vec()), NetlinkExpr::Nested( NFTA_EXPR_DATA, vec![ - NetlinkExpr::Final(NFTA_LOG_PREFIX, b"mockprefix\0".to_vec()), - NetlinkExpr::Final(NFTA_LOG_GROUP, 1u16.to_be_bytes().to_vec()) + NetlinkExpr::Final(NFTA_LOG_GROUP, 1337u16.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_LOG_PREFIX, b"mockprefix".to_vec()), ] ) ] @@ -278,36 +287,38 @@ fn log_expr_is_valid() { #[test] fn lookup_expr_is_valid() { - let set_name = &CStr::from_bytes_with_nul(b"mockset\0").unwrap(); - let mut rule = get_test_rule(); - let table = rule.get_chain().get_table(); - let mut set = Set::new(set_name, 0, table, ProtoFamily::Inet); + let table = get_test_table(); + let mut set_builder = SetBuilder::new(SET_NAME, &table).unwrap(); let address: Ipv4Addr = [8, 8, 8, 8].into(); - set.add(&address); + set_builder.add(&address); + let (set, _set_elements) = set_builder.finish(); let lookup = Lookup::new(&set).unwrap(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &lookup); - assert_eq!(nlmsghdr.nlmsg_len, 104); + + let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(lookup)); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 96); assert_eq!( raw_expr, NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), NetlinkExpr::Nested( NFTA_RULE_EXPRESSIONS, vec![NetlinkExpr::Nested( NFTA_LIST_ELEM, vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"lookup\0".to_vec()), + NetlinkExpr::Final(NFTA_EXPR_NAME, b"lookup".to_vec()), NetlinkExpr::Nested( NFTA_EXPR_DATA, vec![ + NetlinkExpr::Final(NFTA_LOOKUP_SET, b"mockset".to_vec()), NetlinkExpr::Final( NFTA_LOOKUP_SREG, NFT_REG_1.to_be_bytes().to_vec() ), - NetlinkExpr::Final(NFTA_LOOKUP_SET, b"mockset\0".to_vec()), - NetlinkExpr::Final(NFTA_LOOKUP_SET_ID, 0u32.to_be_bytes().to_vec()), ] ) ] @@ -318,25 +329,26 @@ fn lookup_expr_is_valid() { ); } -use rustables::expr::Masquerade; #[test] fn masquerade_expr_is_valid() { - let masquerade = Masquerade; - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &masquerade); - assert_eq!(nlmsghdr.nlmsg_len, 76); + let masquerade = Masquerade::default(); + let mut rule = get_test_rule().with_expressions(vec![masquerade]); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 72); assert_eq!( raw_expr, NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), NetlinkExpr::Nested( NFTA_RULE_EXPRESSIONS, vec![NetlinkExpr::Nested( NFTA_LIST_ELEM, vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"masq\0".to_vec()), + NetlinkExpr::Final(NFTA_EXPR_NAME, b"masq".to_vec()), NetlinkExpr::Nested(NFTA_EXPR_DATA, vec![]), ] )] @@ -348,22 +360,26 @@ fn masquerade_expr_is_valid() { #[test] fn meta_expr_is_valid() { - let meta = Meta::Protocol; - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &meta); - assert_eq!(nlmsghdr.nlmsg_len, 92); + let meta = Meta::default() + .with_key(MetaType::Protocol) + .with_dreg(Register::Reg1); + let mut rule = get_test_rule().with_expressions(vec![meta]); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 88); assert_eq!( raw_expr, NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), NetlinkExpr::Nested( NFTA_RULE_EXPRESSIONS, vec![NetlinkExpr::Nested( NFTA_LIST_ELEM, vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"meta\0".to_vec()), + NetlinkExpr::Final(NFTA_EXPR_NAME, b"meta".to_vec()), NetlinkExpr::Nested( NFTA_EXPR_DATA, vec![ @@ -387,27 +403,27 @@ fn meta_expr_is_valid() { #[test] fn nat_expr_is_valid() { - let nat = Nat { - nat_type: NatType::SNat, - family: ProtoFamily::Ipv4, - ip_register: Register::Reg1, - port_register: None, - }; - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &nat); + let nat = Nat::default() + .with_nat_type(NatType::SNat) + .with_family(ProtocolFamily::Ipv4) + .with_ip_register(Register::Reg1); + let mut rule = get_test_rule().with_expressions(vec![nat]); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); assert_eq!(nlmsghdr.nlmsg_len, 96); assert_eq!( raw_expr, NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), NetlinkExpr::Nested( NFTA_RULE_EXPRESSIONS, vec![NetlinkExpr::Nested( NFTA_LIST_ELEM, vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"nat\0".to_vec()), + NetlinkExpr::Final(NFTA_EXPR_NAME, b"nat".to_vec()), NetlinkExpr::Nested( NFTA_EXPR_DATA, vec![ @@ -417,7 +433,7 @@ fn nat_expr_is_valid() { ), NetlinkExpr::Final( NFTA_NAT_FAMILY, - (ProtoFamily::Ipv4 as u32).to_be_bytes().to_vec(), + (ProtocolFamily::Ipv4 as u32).to_be_bytes().to_vec(), ), NetlinkExpr::Final( NFTA_NAT_REG_ADDR_MIN, @@ -435,24 +451,26 @@ fn nat_expr_is_valid() { #[test] fn payload_expr_is_valid() { - let tcp_header_field = TcpHeaderField::Sport; + let tcp_header_field = TCPHeaderField::Sport; let transport_header_field = TransportHeaderField::Tcp(tcp_header_field); - let payload = Payload::Transport(transport_header_field); - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &payload); + let payload = HighLevelPayload::Transport(transport_header_field); + let mut rule = get_test_rule().with_expressions(vec![payload.build()]); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); assert_eq!(nlmsghdr.nlmsg_len, 108); assert_eq!( raw_expr, NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), NetlinkExpr::Nested( NFTA_RULE_EXPRESSIONS, vec![NetlinkExpr::Nested( NFTA_LIST_ELEM, vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"payload\0".to_vec()), + NetlinkExpr::Final(NFTA_EXPR_NAME, b"payload".to_vec()), NetlinkExpr::Nested( NFTA_EXPR_DATA, vec![ @@ -485,22 +503,25 @@ fn payload_expr_is_valid() { #[test] fn reject_expr_is_valid() { let code = IcmpCode::NoRoute; - let reject = Reject::Icmp(code); - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &reject); + let reject = Reject::default() + .with_type(RejectType::IcmpxUnreach) + .with_icmp_code(code); + let mut rule = get_test_rule().with_expressions(vec![reject]); + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); assert_eq!(nlmsghdr.nlmsg_len, 92); assert_eq!( raw_expr, NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), NetlinkExpr::Nested( NFTA_RULE_EXPRESSIONS, vec![NetlinkExpr::Nested( NFTA_LIST_ELEM, vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"reject\0".to_vec()), + NetlinkExpr::Final(NFTA_EXPR_NAME, b"reject".to_vec()), NetlinkExpr::Nested( NFTA_EXPR_DATA, vec![ @@ -524,22 +545,24 @@ fn reject_expr_is_valid() { #[test] fn verdict_expr_is_valid() { - let verdict = Verdict::Drop; - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_from_expr(&mut rule, &verdict); + let verdict = Immediate::new_verdict(VerdictKind::Drop); + let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(verdict)); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); assert_eq!(nlmsghdr.nlmsg_len, 104); assert_eq!( raw_expr, NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), NetlinkExpr::Nested( NFTA_RULE_EXPRESSIONS, vec![NetlinkExpr::Nested( NFTA_LIST_ELEM, vec![ - NetlinkExpr::Final(NFTA_EXPR_NAME, b"immediate\0".to_vec()), + NetlinkExpr::Final(NFTA_EXPR_NAME, b"immediate".to_vec()), NetlinkExpr::Nested( NFTA_EXPR_DATA, vec![ diff --git a/src/tests/mod.rs b/src/tests/mod.rs new file mode 100644 index 0000000..75fe8b0 --- /dev/null +++ b/src/tests/mod.rs @@ -0,0 +1,193 @@ +use crate::data_type::DataType; +use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter}; +use crate::parser::{parse_nlmsg, NlMsg}; +use crate::set::{Set, SetBuilder}; +use crate::{sys::*, Chain, MsgType, ProtocolFamily, Rule, Table}; + +mod batch; +mod chain; +mod expr; +mod rule; +mod set; +mod table; + +pub const TABLE_NAME: &'static str = "mocktable"; +pub const CHAIN_NAME: &'static str = "mockchain"; +pub const SET_NAME: &'static str = "mockset"; + +pub const TABLE_USERDATA: &'static str = "mocktabledata"; +pub const CHAIN_USERDATA: &'static str = "mockchaindata"; +pub const RULE_USERDATA: &'static str = "mockruledata"; +pub const SET_USERDATA: &'static str = "mocksetdata"; + +type NetLinkType = u16; + +#[derive(Debug, thiserror::Error)] +#[error("empty data")] +pub struct EmptyDataError; + +#[derive(Debug, Clone, Eq, Ord)] +pub enum NetlinkExpr { + Nested(NetLinkType, Vec<NetlinkExpr>), + Final(NetLinkType, Vec<u8>), + List(Vec<NetlinkExpr>), +} + +impl NetlinkExpr { + pub fn to_raw(self) -> Vec<u8> { + match self.sort() { + NetlinkExpr::Final(ty, val) => { + let len = val.len() + 4; + let mut res = Vec::with_capacity(len); + + res.extend(&(len as u16).to_le_bytes()); + res.extend(&ty.to_le_bytes()); + res.extend(val); + // alignment + while res.len() % 4 != 0 { + res.push(0); + } + + res + } + NetlinkExpr::Nested(ty, exprs) => { + // some heuristic to decrease allocations (even though this is + // only useful for testing so performance is not an objective) + let mut sub = Vec::with_capacity(exprs.len() * 50); + + for expr in exprs { + sub.append(&mut expr.to_raw()); + } + + let len = sub.len() + 4; + let mut res = Vec::with_capacity(len); + + // set the "NESTED" flag + res.extend(&(len as u16).to_le_bytes()); + res.extend(&(ty | NLA_F_NESTED as u16).to_le_bytes()); + res.extend(sub); + + res + } + NetlinkExpr::List(exprs) => { + // some heuristic to decrease allocations (even though this is + // only useful for testing so performance is not an objective) + let mut list = Vec::with_capacity(exprs.len() * 50); + + for expr in exprs { + list.append(&mut expr.to_raw()); + } + + list + } + } + } + + pub fn sort(self) -> Self { + match self { + NetlinkExpr::Final(_, _) => self, + NetlinkExpr::Nested(ty, mut exprs) => { + exprs.sort(); + NetlinkExpr::Nested(ty, exprs) + } + NetlinkExpr::List(mut exprs) => { + exprs.sort(); + NetlinkExpr::List(exprs) + } + } + } +} + +impl PartialEq for NetlinkExpr { + fn eq(&self, other: &Self) -> bool { + match (self.clone().sort(), other.clone().sort()) { + (NetlinkExpr::Nested(k1, v1), NetlinkExpr::Nested(k2, v2)) => k1 == k2 && v1 == v2, + (NetlinkExpr::Final(k1, v1), NetlinkExpr::Final(k2, v2)) => k1 == k2 && v1 == v2, + (NetlinkExpr::List(v1), NetlinkExpr::List(v2)) => v1 == v2, + _ => false, + } + } +} + +impl PartialOrd for NetlinkExpr { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + match (self, other) { + ( + NetlinkExpr::Nested(k1, _) | NetlinkExpr::Final(k1, _), + NetlinkExpr::Nested(k2, _) | NetlinkExpr::Final(k2, _), + ) => k1.partial_cmp(k2), + (NetlinkExpr::List(v1), NetlinkExpr::List(v2)) => v1.partial_cmp(v2), + (_, NetlinkExpr::List(_)) => Some(std::cmp::Ordering::Less), + (NetlinkExpr::List(_), _) => Some(std::cmp::Ordering::Greater), + } + } +} + +pub fn get_test_table() -> Table { + Table::new(ProtocolFamily::Inet) + .with_name(TABLE_NAME) + .with_flags(0u32) +} + +pub fn get_test_table_raw_expr() -> NetlinkExpr { + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_TABLE_FLAGS, 0u32.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_TABLE_NAME, TABLE_NAME.as_bytes().to_vec()), + ]) + .sort() +} + +pub fn get_test_table_with_userdata_raw_expr() -> NetlinkExpr { + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_TABLE_FLAGS, 0u32.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_TABLE_NAME, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_TABLE_USERDATA, TABLE_USERDATA.as_bytes().to_vec()), + ]) + .sort() +} + +pub fn get_test_chain() -> Chain { + Chain::new(&get_test_table()).with_name(CHAIN_NAME) +} + +pub fn get_test_rule() -> Rule { + Rule::new(&get_test_chain()).unwrap() +} + +pub fn get_test_set<K: DataType>() -> Set { + SetBuilder::<K>::new(SET_NAME, &get_test_table()) + .expect("Couldn't create a set") + .finish() + .0 + .with_userdata(SET_USERDATA) +} + +pub fn get_test_nlmsg_with_msg_type<'a>( + buf: &'a mut Vec<u8>, + obj: &mut impl NfNetlinkObject, + msg_type: MsgType, +) -> (nlmsghdr, nfgenmsg, &'a [u8]) { + let mut writer = NfNetlinkWriter::new(buf); + obj.add_or_remove(&mut writer, msg_type, 0); + + let (hdr, msg) = parse_nlmsg(buf.as_slice()).expect("Couldn't parse the message"); + + let (nfgenmsg, raw_value) = match msg { + NlMsg::NfGenMsg(nfgenmsg, raw_value) => (nfgenmsg, raw_value), + _ => panic!("Invalid return value type, expected a valid message"), + }; + + // sanity checks on the global message (this should be very similar/factorisable for the + // most part in other tests) + // TODO: check the messages flags + assert_eq!(nfgenmsg.res_id.to_be(), 0); + + (hdr, nfgenmsg, raw_value) +} + +pub fn get_test_nlmsg<'a>( + buf: &'a mut Vec<u8>, + obj: &mut impl NfNetlinkObject, +) -> (nlmsghdr, nfgenmsg, &'a [u8]) { + get_test_nlmsg_with_msg_type(buf, obj, MsgType::Add) +} diff --git a/src/tests/rule.rs b/src/tests/rule.rs new file mode 100644 index 0000000..08b4139 --- /dev/null +++ b/src/tests/rule.rs @@ -0,0 +1,132 @@ +use crate::{ + nlmsg::get_operation_from_nlmsghdr_type, + sys::{ + NFTA_RULE_CHAIN, NFTA_RULE_HANDLE, NFTA_RULE_POSITION, NFTA_RULE_TABLE, NFTA_RULE_USERDATA, + NFT_MSG_DELRULE, NFT_MSG_NEWRULE, + }, + MsgType, +}; + +use super::{ + get_test_nlmsg, get_test_nlmsg_with_msg_type, get_test_rule, NetlinkExpr, CHAIN_NAME, + RULE_USERDATA, TABLE_NAME, +}; + +#[test] +fn new_empty_rule() { + let mut rule = get_test_rule(); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWRULE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 52); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + ]) + .to_raw() + ); +} + +#[test] +fn new_empty_rule_with_userdata() { + let mut rule = get_test_rule().with_userdata(RULE_USERDATA); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWRULE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 68); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_USERDATA, RULE_USERDATA.as_bytes().to_vec()) + ]) + .to_raw() + ); +} + +#[test] +fn new_empty_rule_with_position_and_handle() { + let handle: u64 = 1337; + let position: u64 = 42; + let mut rule = get_test_rule().with_handle(handle).with_position(position); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWRULE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 76); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_HANDLE, handle.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_POSITION, position.to_be_bytes().to_vec()), + ]) + .to_raw() + ); +} + +#[test] +fn delete_empty_rule() { + let mut rule = get_test_rule(); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = + get_test_nlmsg_with_msg_type(&mut buf, &mut rule, MsgType::Del); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_DELRULE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 52); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + ]) + .to_raw() + ); +} + +#[test] +fn delete_empty_rule_with_handle() { + let handle: u64 = 42; + let mut rule = get_test_rule().with_handle(handle); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = + get_test_nlmsg_with_msg_type(&mut buf, &mut rule, MsgType::Del); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_DELRULE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 64); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_HANDLE, handle.to_be_bytes().to_vec()), + ]) + .to_raw() + ); +} diff --git a/src/tests/set.rs b/src/tests/set.rs new file mode 100644 index 0000000..6c8247c --- /dev/null +++ b/src/tests/set.rs @@ -0,0 +1,119 @@ +use std::net::{Ipv4Addr, Ipv6Addr}; + +use crate::{ + data_type::DataType, + nlmsg::get_operation_from_nlmsghdr_type, + set::SetBuilder, + sys::{ + NFTA_DATA_VALUE, NFTA_LIST_ELEM, NFTA_SET_ELEM_KEY, NFTA_SET_ELEM_LIST_ELEMENTS, + NFTA_SET_ELEM_LIST_SET, NFTA_SET_ELEM_LIST_TABLE, NFTA_SET_KEY_LEN, NFTA_SET_KEY_TYPE, + NFTA_SET_NAME, NFTA_SET_TABLE, NFTA_SET_USERDATA, NFT_MSG_DELSET, NFT_MSG_NEWSET, + NFT_MSG_NEWSETELEM, + }, + MsgType, +}; + +use super::{ + get_test_nlmsg, get_test_nlmsg_with_msg_type, get_test_set, get_test_table, NetlinkExpr, + SET_NAME, SET_USERDATA, TABLE_NAME, +}; + +#[test] +fn new_empty_set() { + let mut set = get_test_set::<Ipv4Addr>(); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut set); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWSET as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 80); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_SET_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv4Addr::TYPE.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv4Addr::LEN.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_USERDATA, SET_USERDATA.as_bytes().to_vec()), + ]) + .to_raw() + ); +} + +#[test] +fn delete_empty_set() { + let mut set = get_test_set::<Ipv6Addr>(); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = + get_test_nlmsg_with_msg_type(&mut buf, &mut set, MsgType::Del); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_DELSET as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 80); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_SET_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv6Addr::TYPE.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv6Addr::LEN.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_USERDATA, SET_USERDATA.as_bytes().to_vec()), + ]) + .to_raw() + ); +} + +#[test] +fn new_set_with_data() { + let ip1 = Ipv4Addr::new(127, 0, 0, 1); + let ip2 = Ipv4Addr::new(1, 1, 1, 1); + let mut set_builder = SetBuilder::<Ipv4Addr>::new(SET_NAME.to_string(), &get_test_table()) + .expect("Couldn't create a set"); + + set_builder.add(&ip1); + set_builder.add(&ip2); + let (_set, mut elem_list) = set_builder.finish(); + + let mut buf = Vec::new(); + + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut elem_list); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWSETELEM as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 84); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_SET_ELEM_LIST_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_ELEM_LIST_SET, SET_NAME.as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_SET_ELEM_LIST_ELEMENTS, + vec![ + NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![NetlinkExpr::Nested( + NFTA_DATA_VALUE, + vec![NetlinkExpr::Final(NFTA_SET_ELEM_KEY, ip1.data().to_vec())] + )] + ), + NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![NetlinkExpr::Nested( + NFTA_DATA_VALUE, + vec![NetlinkExpr::Final(NFTA_SET_ELEM_KEY, ip2.data().to_vec())] + )] + ), + ] + ), + ]) + .to_raw() + ); +} diff --git a/src/tests/table.rs b/src/tests/table.rs new file mode 100644 index 0000000..39bf399 --- /dev/null +++ b/src/tests/table.rs @@ -0,0 +1,67 @@ +use crate::{ + nlmsg::{get_operation_from_nlmsghdr_type, nft_nlmsg_maxsize, NfNetlinkDeserializable}, + sys::{NFT_MSG_DELTABLE, NFT_MSG_NEWTABLE}, + MsgType, Table, +}; + +use super::{ + get_test_nlmsg, get_test_nlmsg_with_msg_type, get_test_table, get_test_table_raw_expr, + get_test_table_with_userdata_raw_expr, TABLE_USERDATA, +}; + +#[test] +fn new_empty_table() { + let mut table = get_test_table(); + let mut buf = Vec::with_capacity(nft_nlmsg_maxsize() as usize); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut table); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWTABLE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 44); + + assert_eq!(raw_expr, get_test_table_raw_expr().to_raw()); +} + +#[test] +fn new_empty_table_with_userdata() { + let mut table = get_test_table(); + table.set_userdata(TABLE_USERDATA.as_bytes().to_vec()); + let mut buf = Vec::with_capacity(nft_nlmsg_maxsize() as usize); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut table); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWTABLE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 64); + + assert_eq!(raw_expr, get_test_table_with_userdata_raw_expr().to_raw()); +} + +#[test] +fn delete_empty_table() { + let mut table = get_test_table(); + let mut buf = Vec::with_capacity(nft_nlmsg_maxsize() as usize); + let (nlmsghdr, _nfgenmsg, raw_expr) = + get_test_nlmsg_with_msg_type(&mut buf, &mut table, MsgType::Del); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_DELTABLE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 44); + + assert_eq!(raw_expr, get_test_table_raw_expr().to_raw()); +} + +#[test] +fn parse_table() { + let mut table = get_test_table(); + table.set_userdata(TABLE_USERDATA.as_bytes().to_vec()); + let mut buf = Vec::with_capacity(nft_nlmsg_maxsize() as usize); + let (_nlmsghdr, _nfgenmsg, _raw_expr) = get_test_nlmsg(&mut buf, &mut table); + + let (deserialized_table, remaining) = + Table::deserialize(&buf).expect("Couldn't deserialize the object"); + assert_eq!(table, deserialized_table); + assert_eq!(remaining.len(), 0); +} diff --git a/tests/chain.rs b/tests/chain.rs deleted file mode 100644 index 4b6da91..0000000 --- a/tests/chain.rs +++ /dev/null @@ -1,70 +0,0 @@ -use std::ffi::CStr; - -mod sys; -use rustables::MsgType; -use sys::*; - -mod lib; -use lib::*; - -#[test] -fn new_empty_chain() { - let mut chain = get_test_chain(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut chain); - assert_eq!( - get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), - NFT_MSG_NEWCHAIN as u8 - ); - assert_eq!(nlmsghdr.nlmsg_len, 52); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_CHAIN_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_CHAIN_NAME, CHAIN_NAME.to_vec()), - ]) - .to_raw() - ); -} - -#[test] -fn new_empty_chain_with_userdata() { - let mut chain = get_test_chain(); - chain.set_userdata(CStr::from_bytes_with_nul(CHAIN_USERDATA).unwrap()); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut chain); - assert_eq!( - get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), - NFT_MSG_NEWCHAIN as u8 - ); - assert_eq!(nlmsghdr.nlmsg_len, 72); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_CHAIN_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_CHAIN_NAME, CHAIN_NAME.to_vec()), - NetlinkExpr::Final(NFTA_CHAIN_USERDATA, CHAIN_USERDATA.to_vec()) - ]) - .to_raw() - ); -} - -#[test] -fn delete_empty_chain() { - let mut chain = get_test_chain(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_with_msg_type(&mut chain, MsgType::Del); - assert_eq!( - get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), - NFT_MSG_DELCHAIN as u8 - ); - assert_eq!(nlmsghdr.nlmsg_len, 52); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_CHAIN_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_CHAIN_NAME, CHAIN_NAME.to_vec()), - ]) - .to_raw() - ); -} diff --git a/tests/lib.rs b/tests/lib.rs deleted file mode 100644 index 0d7132c..0000000 --- a/tests/lib.rs +++ /dev/null @@ -1,169 +0,0 @@ -#![allow(dead_code)] -use libc::{nlmsghdr, AF_UNIX, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES}; -use rustables::set::SetKey; -use rustables::{nft_nlmsg_maxsize, Chain, MsgType, NlMsg, ProtoFamily, Rule, Set, Table}; -use std::ffi::{c_void, CStr}; -use std::mem::size_of; -use std::rc::Rc; - -pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 { - ((x & 0xff00) >> 8) as u8 -} - -pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 { - (x & 0x00ff) as u8 -} - -pub const TABLE_NAME: &[u8; 10] = b"mocktable\0"; -pub const CHAIN_NAME: &[u8; 10] = b"mockchain\0"; -pub const SET_NAME: &[u8; 8] = b"mockset\0"; - -pub const TABLE_USERDATA: &[u8; 14] = b"mocktabledata\0"; -pub const CHAIN_USERDATA: &[u8; 14] = b"mockchaindata\0"; -pub const RULE_USERDATA: &[u8; 13] = b"mockruledata\0"; -pub const SET_USERDATA: &[u8; 12] = b"mocksetdata\0"; - -pub const SET_ID: u32 = 123456; - -type NetLinkType = u16; - -#[derive(Debug, thiserror::Error)] -#[error("empty data")] -pub struct EmptyDataError; - -#[derive(Debug, PartialEq)] -pub enum NetlinkExpr { - Nested(NetLinkType, Vec<NetlinkExpr>), - Final(NetLinkType, Vec<u8>), - List(Vec<NetlinkExpr>), -} - -impl NetlinkExpr { - pub fn to_raw(self) -> Vec<u8> { - match self { - NetlinkExpr::Final(ty, val) => { - let len = val.len() + 4; - let mut res = Vec::with_capacity(len); - - res.extend(&(len as u16).to_le_bytes()); - res.extend(&ty.to_le_bytes()); - res.extend(val); - // alignment - while res.len() % 4 != 0 { - res.push(0); - } - - res - } - NetlinkExpr::Nested(ty, exprs) => { - // some heuristic to decrease allocations (even though this is - // only useful for testing so performance is not an objective) - let mut sub = Vec::with_capacity(exprs.len() * 50); - - for expr in exprs { - sub.append(&mut expr.to_raw()); - } - - let len = sub.len() + 4; - let mut res = Vec::with_capacity(len); - - // set the "NESTED" flag - res.extend(&(len as u16).to_le_bytes()); - res.extend(&(ty | 0x8000).to_le_bytes()); - res.extend(sub); - - res - } - NetlinkExpr::List(exprs) => { - // some heuristic to decrease allocations (even though this is - // only useful for testing so performance is not an objective) - let mut list = Vec::with_capacity(exprs.len() * 50); - - for expr in exprs { - list.append(&mut expr.to_raw()); - } - - list - } - } - } -} - -#[repr(C)] -#[derive(Clone, Copy)] -pub struct Nfgenmsg { - family: u8, /* AF_xxx */ - version: u8, /* nfnetlink version */ - res_id: u16, /* resource id */ -} - -pub fn get_test_table() -> Table { - Table::new( - &CStr::from_bytes_with_nul(TABLE_NAME).unwrap(), - ProtoFamily::Inet, - ) -} - -pub fn get_test_chain() -> Chain { - Chain::new( - &CStr::from_bytes_with_nul(CHAIN_NAME).unwrap(), - Rc::new(get_test_table()), - ) -} - -pub fn get_test_rule() -> Rule { - Rule::new(Rc::new(get_test_chain())) -} - -pub fn get_test_set<T: SetKey>() -> Set<T> { - Set::new( - CStr::from_bytes_with_nul(SET_NAME).unwrap(), - SET_ID, - Rc::new(get_test_table()), - ProtoFamily::Ipv4, - ) -} - -pub fn get_test_nlmsg_with_msg_type( - obj: &mut dyn NlMsg, - msg_type: MsgType, -) -> (nlmsghdr, Nfgenmsg, Vec<u8>) { - let mut buf = vec![0u8; nft_nlmsg_maxsize() as usize]; - unsafe { - obj.write(buf.as_mut_ptr() as *mut c_void, 0, msg_type); - - // right now the message is composed of the following parts: - // - nlmsghdr (contains the message size and type) - // - nfgenmsg (nftables header that describes the message family) - // - the raw value that we want to validate - - let size_of_hdr = size_of::<nlmsghdr>(); - let size_of_nfgenmsg = size_of::<Nfgenmsg>(); - let nlmsghdr = *(buf[0..size_of_hdr].as_ptr() as *const nlmsghdr); - let nfgenmsg = - *(buf[size_of_hdr..size_of_hdr + size_of_nfgenmsg].as_ptr() as *const Nfgenmsg); - let raw_value = buf[size_of_hdr + size_of_nfgenmsg..nlmsghdr.nlmsg_len as usize] - .iter() - .map(|x| *x) - .collect(); - - // sanity checks on the global message (this should be very similar/factorisable for the - // most part in other tests) - // TODO: check the messages flags - assert_eq!( - get_subsystem_from_nlmsghdr_type(nlmsghdr.nlmsg_type), - NFNL_SUBSYS_NFTABLES as u8 - ); - assert_eq!(nlmsghdr.nlmsg_seq, 0); - assert_eq!(nlmsghdr.nlmsg_pid, 0); - assert_eq!(nfgenmsg.family, AF_UNIX as u8); - assert_eq!(nfgenmsg.version, NFNETLINK_V0 as u8); - assert_eq!(nfgenmsg.res_id.to_be(), 0); - - (nlmsghdr, nfgenmsg, raw_value) - } -} - -pub fn get_test_nlmsg(obj: &mut dyn NlMsg) -> (nlmsghdr, Nfgenmsg, Vec<u8>) { - get_test_nlmsg_with_msg_type(obj, MsgType::Add) -} diff --git a/tests/rule.rs b/tests/rule.rs deleted file mode 100644 index b601a61..0000000 --- a/tests/rule.rs +++ /dev/null @@ -1,119 +0,0 @@ -use std::ffi::CStr; - -mod sys; -use rustables::MsgType; -use sys::*; - -mod lib; -use lib::*; - -#[test] -fn new_empty_rule() { - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut rule); - assert_eq!( - get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), - NFT_MSG_NEWRULE as u8 - ); - assert_eq!(nlmsghdr.nlmsg_len, 52); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), - ]) - .to_raw() - ); -} - -#[test] -fn new_empty_rule_with_userdata() { - let mut rule = get_test_rule(); - rule.set_userdata(CStr::from_bytes_with_nul(RULE_USERDATA).unwrap()); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut rule); - assert_eq!( - get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), - NFT_MSG_NEWRULE as u8 - ); - assert_eq!(nlmsghdr.nlmsg_len, 72); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_USERDATA, RULE_USERDATA.to_vec()) - ]) - .to_raw() - ); -} - -#[test] -fn new_empty_rule_with_position_and_handle() { - let handle = 1337; - let position = 42; - let mut rule = get_test_rule(); - rule.set_handle(handle); - rule.set_position(position); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut rule); - assert_eq!( - get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), - NFT_MSG_NEWRULE as u8 - ); - assert_eq!(nlmsghdr.nlmsg_len, 76); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_HANDLE, handle.to_be_bytes().to_vec()), - NetlinkExpr::Final(NFTA_RULE_POSITION, position.to_be_bytes().to_vec()), - ]) - .to_raw() - ); -} - -#[test] -fn delete_empty_rule() { - let mut rule = get_test_rule(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_with_msg_type(&mut rule, MsgType::Del); - assert_eq!( - get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), - NFT_MSG_DELRULE as u8 - ); - assert_eq!(nlmsghdr.nlmsg_len, 52); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), - ]) - .to_raw() - ); -} - -#[test] -fn delete_empty_rule_with_handle() { - let handle = 42; - let mut rule = get_test_rule(); - rule.set_handle(handle); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_with_msg_type(&mut rule, MsgType::Del); - assert_eq!( - get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), - NFT_MSG_DELRULE as u8 - ); - assert_eq!(nlmsghdr.nlmsg_len, 64); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.to_vec()), - NetlinkExpr::Final(NFTA_RULE_HANDLE, handle.to_be_bytes().to_vec()), - ]) - .to_raw() - ); -} diff --git a/tests/set.rs b/tests/set.rs deleted file mode 100644 index d5b2ad7..0000000 --- a/tests/set.rs +++ /dev/null @@ -1,66 +0,0 @@ -mod sys; -use std::net::{Ipv4Addr, Ipv6Addr}; - -use rustables::{set::SetKey, MsgType}; -use sys::*; - -mod lib; -use lib::*; - -#[test] -fn new_empty_set() { - let mut set = get_test_set::<Ipv4Addr>(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut set); - assert_eq!( - get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), - NFT_MSG_NEWSET as u8 - ); - assert_eq!(nlmsghdr.nlmsg_len, 80); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_SET_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.to_vec()), - NetlinkExpr::Final( - NFTA_SET_FLAGS, - ((libc::NFT_SET_ANONYMOUS | libc::NFT_SET_CONSTANT) as u32) - .to_be_bytes() - .to_vec() - ), - NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv4Addr::TYPE.to_be_bytes().to_vec()), - NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv4Addr::LEN.to_be_bytes().to_vec()), - NetlinkExpr::Final(NFTA_SET_ID, SET_ID.to_be_bytes().to_vec()), - ]) - .to_raw() - ); -} - -#[test] -fn delete_empty_set() { - let mut set = get_test_set::<Ipv6Addr>(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_with_msg_type(&mut set, MsgType::Del); - assert_eq!( - get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), - NFT_MSG_DELSET as u8 - ); - assert_eq!(nlmsghdr.nlmsg_len, 80); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_SET_TABLE, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.to_vec()), - NetlinkExpr::Final( - NFTA_SET_FLAGS, - ((libc::NFT_SET_ANONYMOUS | libc::NFT_SET_CONSTANT) as u32) - .to_be_bytes() - .to_vec() - ), - NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv6Addr::TYPE.to_be_bytes().to_vec()), - NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv6Addr::LEN.to_be_bytes().to_vec()), - NetlinkExpr::Final(NFTA_SET_ID, SET_ID.to_be_bytes().to_vec()), - ]) - .to_raw() - ); -} diff --git a/tests/table.rs b/tests/table.rs deleted file mode 100644 index 3d8957c..0000000 --- a/tests/table.rs +++ /dev/null @@ -1,70 +0,0 @@ -use std::ffi::CStr; - -mod sys; -use rustables::MsgType; -use sys::*; - -mod lib; -use lib::*; - -#[test] -fn new_empty_table() { - let mut table = get_test_table(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut table); - assert_eq!( - get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), - NFT_MSG_NEWTABLE as u8 - ); - assert_eq!(nlmsghdr.nlmsg_len, 44); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_TABLE_NAME, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_TABLE_FLAGS, 0u32.to_be_bytes().to_vec()), - ]) - .to_raw() - ); -} - -#[test] -fn new_empty_table_with_userdata() { - let mut table = get_test_table(); - table.set_userdata(CStr::from_bytes_with_nul(TABLE_USERDATA).unwrap()); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut table); - assert_eq!( - get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), - NFT_MSG_NEWTABLE as u8 - ); - assert_eq!(nlmsghdr.nlmsg_len, 64); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_TABLE_NAME, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_TABLE_FLAGS, 0u32.to_be_bytes().to_vec()), - NetlinkExpr::Final(NFTA_TABLE_USERDATA, TABLE_USERDATA.to_vec()) - ]) - .to_raw() - ); -} - -#[test] -fn delete_empty_table() { - let mut table = get_test_table(); - let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg_with_msg_type(&mut table, MsgType::Del); - assert_eq!( - get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), - NFT_MSG_DELTABLE as u8 - ); - assert_eq!(nlmsghdr.nlmsg_len, 44); - - assert_eq!( - raw_expr, - NetlinkExpr::List(vec![ - NetlinkExpr::Final(NFTA_TABLE_NAME, TABLE_NAME.to_vec()), - NetlinkExpr::Final(NFTA_TABLE_FLAGS, 0u32.to_be_bytes().to_vec()), - ]) - .to_raw() - ); -} |