aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHimbeer <himbeer@disroot.org>2024-08-19 23:14:03 +0200
committerHimbeer <himbeer@disroot.org>2024-08-19 23:14:03 +0200
commitca6753281558526fe17961a61094966e3f04eadb (patch)
treed001e7474fd22d5dcd28745db376d969ce2f0343
parent7a0aad65f07deefbf17052682e6908694d79e30b (diff)
Speed up reverse lookups using a bidirectional hashmap
Previously reverse lookups required iteration of all hashmap entries. This commit makes it benefit from the same principle introduced in aee8706bc84d98e96f7a30f1c3dbc1411f4d06a6 by adding a second hashmap for reverse lookups so that a a simple indexing operation can be used instead of iterating over the key-value pairs manually.
-rw-r--r--src/main.rs109
1 files changed, 61 insertions, 48 deletions
diff --git a/src/main.rs b/src/main.rs
index c48f813..32fb013 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -25,6 +25,9 @@ const UPSTREAM_PRIMARY: &str = "[2620:fe::fe]:53";
const UPSTREAM_SECONDARY: &str = "9.9.9.9:53";
const UPSTREAM_TIMEOUT: Duration = Duration::from_secs(3);
+type ForwardHosts = Arc<RwLock<HashMap<String, IpAddr>>>;
+type ReverseHosts = Arc<RwLock<HashMap<IpAddr, String>>>;
+
#[derive(Debug, Error)]
pub enum Error {
#[error("hosts entry missing address column: {0}")]
@@ -99,26 +102,27 @@ fn read_leases(cache: Arc<RwLock<Vec<Lease>>>) -> Result<()> {
Ok(())
}
-fn refresh_hosts(cache: Arc<RwLock<HashMap<String, IpAddr>>>) -> Result<()> {
+fn refresh_hosts(forward: ForwardHosts, reverse: ReverseHosts) -> Result<()> {
let mut signals = Signals::new([SIGUSR2])?;
for _ in signals.forever() {
- read_hosts(cache.clone())?;
+ read_hosts(forward.clone(), reverse.clone())?;
}
Ok(()) // unreachable
}
-fn refresh_hosts_supervised(cache: Arc<RwLock<HashMap<String, IpAddr>>>) -> ! {
+fn refresh_hosts_supervised(forward: ForwardHosts, reverse: ReverseHosts) -> ! {
loop {
- match refresh_hosts(cache.clone()) {
+ match refresh_hosts(forward.clone(), reverse.clone()) {
Ok(_) => {}
Err(e) => println!("[warn] hosts refresh: {}", e),
}
}
}
-fn read_hosts(cache: Arc<RwLock<HashMap<String, IpAddr>>>) -> Result<()> {
- let mut hosts = HashMap::new();
+fn read_hosts(forward: ForwardHosts, reverse: ReverseHosts) -> Result<()> {
+ let mut forward_hosts = HashMap::new();
+ let mut reverse_hosts = HashMap::new();
let file = match File::open("/data/hosts.dnsd") {
Ok(file) => file,
@@ -137,12 +141,18 @@ fn read_hosts(cache: Arc<RwLock<HashMap<String, IpAddr>>>) -> Result<()> {
let mut columns = split_input.split_whitespace();
let addr = columns.next().ok_or(Error::NoAddrColumn(line))?;
+ let parsed_addr = addr.parse()?;
for host in columns {
- hosts.insert(host.to_string() + ".", addr.parse()?);
+ let host = host.to_string() + ".";
+
+ forward_hosts.insert(host.clone(), parsed_addr);
+ reverse_hosts.insert(parsed_addr, host);
}
}
- *cache.write().unwrap() = hosts;
+ *forward.write().unwrap() = forward_hosts;
+ *reverse.write().unwrap() = reverse_hosts;
+
Ok(())
}
@@ -155,11 +165,13 @@ fn main() -> Result<()> {
let leases2 = leases.clone();
thread::spawn(move || refresh_leases_supervised(leases2));
- let hosts = Arc::new(RwLock::new(HashMap::new()));
- read_hosts(hosts.clone())?;
+ let forward_hosts = Arc::new(RwLock::new(HashMap::new()));
+ let reverse_hosts = Arc::new(RwLock::new(HashMap::new()));
+ read_hosts(forward_hosts.clone(), reverse_hosts.clone())?;
- let hosts2 = hosts.clone();
- thread::spawn(move || refresh_hosts_supervised(hosts2));
+ let forward2 = forward_hosts.clone();
+ let reverse2 = reverse_hosts.clone();
+ thread::spawn(move || refresh_hosts_supervised(forward2, reverse2));
let domain = match fs::read_to_string("/data/dnsd.domain") {
Ok(v) => match Name::from_utf8(v) {
@@ -186,9 +198,10 @@ fn main() -> Result<()> {
let sock2 = sock.try_clone()?;
let buf = buf.to_vec();
let leases3 = leases.clone();
- let hosts3 = hosts.clone();
- thread::spawn(
- move || match handle_query(&domain2, &sock2, &buf, raddr, leases3, hosts3) {
+ let forward3 = forward_hosts.clone();
+ let reverse3 = reverse_hosts.clone();
+ thread::spawn(move || {
+ match handle_query(&domain2, &sock2, &buf, raddr, leases3, forward3, reverse3) {
Ok(_) => {}
Err(e) => {
match respond_with_error(&sock2, &buf, raddr) {
@@ -198,8 +211,8 @@ fn main() -> Result<()> {
print_query_error(&buf, raddr, e);
}
- },
- );
+ }
+ });
}
}
@@ -260,7 +273,8 @@ fn handle_query(
buf: &[u8],
raddr: SocketAddr,
leases: Arc<RwLock<Vec<Lease>>>,
- hosts: Arc<RwLock<HashMap<String, IpAddr>>>,
+ forward_hosts: ForwardHosts,
+ reverse_hosts: ReverseHosts,
) -> Result<()> {
let bytes = Bytes::copy_from_slice(buf);
let mut msg = Dns::decode(bytes)?;
@@ -278,7 +292,8 @@ fn handle_query(
let (lan, fwd): (_, Vec<Question>) = msg.questions.into_iter().partition(|q| {
let known = is_file_known(
&usable_name(domain, &q.domain_name).expect("can't convert domain name"),
- hosts.clone(),
+ forward_hosts.clone(),
+ reverse_hosts.clone(),
) || is_dhcp_known(
&usable_name(domain, &q.domain_name).expect("can't convert domain name"),
leases.clone(),
@@ -314,7 +329,8 @@ fn handle_query(
let hostname = usable_name(domain, &q.domain_name).expect("can't convert domain name");
if q.q_type == QType::A {
- if let Some(entry) = file_entry(&hostname, hosts.clone()) {
+ if let Some(entry) = file_entry(&hostname, forward_hosts.clone(), reverse_hosts.clone())
+ {
let IpAddr::V4(addr_as_v4) = entry.1 else {
return None;
};
@@ -345,7 +361,7 @@ fn handle_query(
Some(answer)
}
} else if q.q_type == QType::AAAA {
- let entry = file_entry(&hostname, hosts.clone())?;
+ let entry = file_entry(&hostname, forward_hosts.clone(), reverse_hosts.clone())?;
let IpAddr::V6(addr_as_v6) = entry.1 else {
return None;
};
@@ -358,7 +374,8 @@ fn handle_query(
println!("[file] {} => {}", raddr, answer);
Some(answer)
} else if q.q_type == QType::PTR {
- if let Some(entry) = file_entry(&hostname, hosts.clone()) {
+ if let Some(entry) = file_entry(&hostname, forward_hosts.clone(), reverse_hosts.clone())
+ {
let name = entry.0
+ &domain
.as_ref()
@@ -506,41 +523,37 @@ fn upstream_query<A: ToSocketAddrs>(upstream: A, bytes: &[u8]) -> Result<Dns> {
fn file_entry(
hostname: &Name,
- hosts: Arc<RwLock<HashMap<String, IpAddr>>>,
+ forward_hosts: ForwardHosts,
+ reverse_hosts: ReverseHosts,
) -> Option<(String, IpAddr)> {
- let hosts = hosts.read().unwrap();
- let (host, addr) = if Name::from_str("in-addr.arpa.").unwrap().zone_of(hostname)
- && hostname.iter().len() <= 6
- {
- let (host, addr) = hosts
- .iter()
- .filter(|(_, addr)| addr.is_ipv4())
- .find(|(_, addr)| {
- IpNet::new(**addr, 32).unwrap()
- == hostname.parse_arpa_name().expect("can't parse arpa name")
- })?;
- (host.clone(), *addr)
- } else if Name::from_str("ip6.arpa.").unwrap().zone_of(hostname) && hostname.iter().len() <= 34
+ let forward_hosts = forward_hosts.read().unwrap();
+ let reverse_hosts = reverse_hosts.read().unwrap();
+
+ let mapping = if (Name::from_str("in-addr.arpa.").unwrap().zone_of(hostname)
+ && hostname.iter().len() <= 6)
+ || (Name::from_str("ip6.arpa.").unwrap().zone_of(hostname) && hostname.iter().len() <= 34)
{
- let (host, addr) = hosts
- .iter()
- .filter(|(_, addr)| addr.is_ipv6())
- .find(|(_, addr)| {
- IpNet::new(**addr, 128).unwrap()
- == hostname.parse_arpa_name().expect("can't parse arpa name")
- })?;
- (host.clone(), *addr)
+ let addr = hostname
+ .parse_arpa_name()
+ .expect("can't parse arpa name")
+ .addr();
+ let host = reverse_hosts.get(&addr)?;
+ (host.to_string(), addr)
} else {
let hostname_utf8 = hostname.to_utf8();
- let addr = hosts.get(&hostname_utf8)?;
+ let addr = forward_hosts.get(&hostname_utf8)?;
(hostname_utf8, *addr)
};
- Some((host, addr))
+ Some(mapping)
}
-fn is_file_known(hostname: &Name, hosts: Arc<RwLock<HashMap<String, IpAddr>>>) -> bool {
- file_entry(hostname, hosts).is_some()
+fn is_file_known(
+ hostname: &Name,
+ forward_hosts: ForwardHosts,
+ reverse_hosts: ReverseHosts,
+) -> bool {
+ file_entry(hostname, forward_hosts, reverse_hosts).is_some()
}
fn find_lease(hostname: &Name, mut leases: impl Iterator<Item = Lease>) -> Option<Lease> {