diff options
author | Himbeer <himbeer@disroot.org> | 2024-08-19 23:14:03 +0200 |
---|---|---|
committer | Himbeer <himbeer@disroot.org> | 2024-08-19 23:14:03 +0200 |
commit | ca6753281558526fe17961a61094966e3f04eadb (patch) | |
tree | d001e7474fd22d5dcd28745db376d969ce2f0343 | |
parent | 7a0aad65f07deefbf17052682e6908694d79e30b (diff) |
Speed up reverse lookups using a bidirectional hashmap
Previously reverse lookups required iteration of all hashmap entries.
This commit makes it benefit from the same principle introduced in
aee8706bc84d98e96f7a30f1c3dbc1411f4d06a6 by adding a second hashmap
for reverse lookups so that a a simple indexing operation can be used
instead of iterating over the key-value pairs manually.
-rw-r--r-- | src/main.rs | 109 |
1 files changed, 61 insertions, 48 deletions
diff --git a/src/main.rs b/src/main.rs index c48f813..32fb013 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,6 +25,9 @@ const UPSTREAM_PRIMARY: &str = "[2620:fe::fe]:53"; const UPSTREAM_SECONDARY: &str = "9.9.9.9:53"; const UPSTREAM_TIMEOUT: Duration = Duration::from_secs(3); +type ForwardHosts = Arc<RwLock<HashMap<String, IpAddr>>>; +type ReverseHosts = Arc<RwLock<HashMap<IpAddr, String>>>; + #[derive(Debug, Error)] pub enum Error { #[error("hosts entry missing address column: {0}")] @@ -99,26 +102,27 @@ fn read_leases(cache: Arc<RwLock<Vec<Lease>>>) -> Result<()> { Ok(()) } -fn refresh_hosts(cache: Arc<RwLock<HashMap<String, IpAddr>>>) -> Result<()> { +fn refresh_hosts(forward: ForwardHosts, reverse: ReverseHosts) -> Result<()> { let mut signals = Signals::new([SIGUSR2])?; for _ in signals.forever() { - read_hosts(cache.clone())?; + read_hosts(forward.clone(), reverse.clone())?; } Ok(()) // unreachable } -fn refresh_hosts_supervised(cache: Arc<RwLock<HashMap<String, IpAddr>>>) -> ! { +fn refresh_hosts_supervised(forward: ForwardHosts, reverse: ReverseHosts) -> ! { loop { - match refresh_hosts(cache.clone()) { + match refresh_hosts(forward.clone(), reverse.clone()) { Ok(_) => {} Err(e) => println!("[warn] hosts refresh: {}", e), } } } -fn read_hosts(cache: Arc<RwLock<HashMap<String, IpAddr>>>) -> Result<()> { - let mut hosts = HashMap::new(); +fn read_hosts(forward: ForwardHosts, reverse: ReverseHosts) -> Result<()> { + let mut forward_hosts = HashMap::new(); + let mut reverse_hosts = HashMap::new(); let file = match File::open("/data/hosts.dnsd") { Ok(file) => file, @@ -137,12 +141,18 @@ fn read_hosts(cache: Arc<RwLock<HashMap<String, IpAddr>>>) -> Result<()> { let mut columns = split_input.split_whitespace(); let addr = columns.next().ok_or(Error::NoAddrColumn(line))?; + let parsed_addr = addr.parse()?; for host in columns { - hosts.insert(host.to_string() + ".", addr.parse()?); + let host = host.to_string() + "."; + + forward_hosts.insert(host.clone(), parsed_addr); + reverse_hosts.insert(parsed_addr, host); } } - *cache.write().unwrap() = hosts; + *forward.write().unwrap() = forward_hosts; + *reverse.write().unwrap() = reverse_hosts; + Ok(()) } @@ -155,11 +165,13 @@ fn main() -> Result<()> { let leases2 = leases.clone(); thread::spawn(move || refresh_leases_supervised(leases2)); - let hosts = Arc::new(RwLock::new(HashMap::new())); - read_hosts(hosts.clone())?; + let forward_hosts = Arc::new(RwLock::new(HashMap::new())); + let reverse_hosts = Arc::new(RwLock::new(HashMap::new())); + read_hosts(forward_hosts.clone(), reverse_hosts.clone())?; - let hosts2 = hosts.clone(); - thread::spawn(move || refresh_hosts_supervised(hosts2)); + let forward2 = forward_hosts.clone(); + let reverse2 = reverse_hosts.clone(); + thread::spawn(move || refresh_hosts_supervised(forward2, reverse2)); let domain = match fs::read_to_string("/data/dnsd.domain") { Ok(v) => match Name::from_utf8(v) { @@ -186,9 +198,10 @@ fn main() -> Result<()> { let sock2 = sock.try_clone()?; let buf = buf.to_vec(); let leases3 = leases.clone(); - let hosts3 = hosts.clone(); - thread::spawn( - move || match handle_query(&domain2, &sock2, &buf, raddr, leases3, hosts3) { + let forward3 = forward_hosts.clone(); + let reverse3 = reverse_hosts.clone(); + thread::spawn(move || { + match handle_query(&domain2, &sock2, &buf, raddr, leases3, forward3, reverse3) { Ok(_) => {} Err(e) => { match respond_with_error(&sock2, &buf, raddr) { @@ -198,8 +211,8 @@ fn main() -> Result<()> { print_query_error(&buf, raddr, e); } - }, - ); + } + }); } } @@ -260,7 +273,8 @@ fn handle_query( buf: &[u8], raddr: SocketAddr, leases: Arc<RwLock<Vec<Lease>>>, - hosts: Arc<RwLock<HashMap<String, IpAddr>>>, + forward_hosts: ForwardHosts, + reverse_hosts: ReverseHosts, ) -> Result<()> { let bytes = Bytes::copy_from_slice(buf); let mut msg = Dns::decode(bytes)?; @@ -278,7 +292,8 @@ fn handle_query( let (lan, fwd): (_, Vec<Question>) = msg.questions.into_iter().partition(|q| { let known = is_file_known( &usable_name(domain, &q.domain_name).expect("can't convert domain name"), - hosts.clone(), + forward_hosts.clone(), + reverse_hosts.clone(), ) || is_dhcp_known( &usable_name(domain, &q.domain_name).expect("can't convert domain name"), leases.clone(), @@ -314,7 +329,8 @@ fn handle_query( let hostname = usable_name(domain, &q.domain_name).expect("can't convert domain name"); if q.q_type == QType::A { - if let Some(entry) = file_entry(&hostname, hosts.clone()) { + if let Some(entry) = file_entry(&hostname, forward_hosts.clone(), reverse_hosts.clone()) + { let IpAddr::V4(addr_as_v4) = entry.1 else { return None; }; @@ -345,7 +361,7 @@ fn handle_query( Some(answer) } } else if q.q_type == QType::AAAA { - let entry = file_entry(&hostname, hosts.clone())?; + let entry = file_entry(&hostname, forward_hosts.clone(), reverse_hosts.clone())?; let IpAddr::V6(addr_as_v6) = entry.1 else { return None; }; @@ -358,7 +374,8 @@ fn handle_query( println!("[file] {} => {}", raddr, answer); Some(answer) } else if q.q_type == QType::PTR { - if let Some(entry) = file_entry(&hostname, hosts.clone()) { + if let Some(entry) = file_entry(&hostname, forward_hosts.clone(), reverse_hosts.clone()) + { let name = entry.0 + &domain .as_ref() @@ -506,41 +523,37 @@ fn upstream_query<A: ToSocketAddrs>(upstream: A, bytes: &[u8]) -> Result<Dns> { fn file_entry( hostname: &Name, - hosts: Arc<RwLock<HashMap<String, IpAddr>>>, + forward_hosts: ForwardHosts, + reverse_hosts: ReverseHosts, ) -> Option<(String, IpAddr)> { - let hosts = hosts.read().unwrap(); - let (host, addr) = if Name::from_str("in-addr.arpa.").unwrap().zone_of(hostname) - && hostname.iter().len() <= 6 - { - let (host, addr) = hosts - .iter() - .filter(|(_, addr)| addr.is_ipv4()) - .find(|(_, addr)| { - IpNet::new(**addr, 32).unwrap() - == hostname.parse_arpa_name().expect("can't parse arpa name") - })?; - (host.clone(), *addr) - } else if Name::from_str("ip6.arpa.").unwrap().zone_of(hostname) && hostname.iter().len() <= 34 + let forward_hosts = forward_hosts.read().unwrap(); + let reverse_hosts = reverse_hosts.read().unwrap(); + + let mapping = if (Name::from_str("in-addr.arpa.").unwrap().zone_of(hostname) + && hostname.iter().len() <= 6) + || (Name::from_str("ip6.arpa.").unwrap().zone_of(hostname) && hostname.iter().len() <= 34) { - let (host, addr) = hosts - .iter() - .filter(|(_, addr)| addr.is_ipv6()) - .find(|(_, addr)| { - IpNet::new(**addr, 128).unwrap() - == hostname.parse_arpa_name().expect("can't parse arpa name") - })?; - (host.clone(), *addr) + let addr = hostname + .parse_arpa_name() + .expect("can't parse arpa name") + .addr(); + let host = reverse_hosts.get(&addr)?; + (host.to_string(), addr) } else { let hostname_utf8 = hostname.to_utf8(); - let addr = hosts.get(&hostname_utf8)?; + let addr = forward_hosts.get(&hostname_utf8)?; (hostname_utf8, *addr) }; - Some((host, addr)) + Some(mapping) } -fn is_file_known(hostname: &Name, hosts: Arc<RwLock<HashMap<String, IpAddr>>>) -> bool { - file_entry(hostname, hosts).is_some() +fn is_file_known( + hostname: &Name, + forward_hosts: ForwardHosts, + reverse_hosts: ReverseHosts, +) -> bool { + file_entry(hostname, forward_hosts, reverse_hosts).is_some() } fn find_lease(hostname: &Name, mut leases: impl Iterator<Item = Lease>) -> Option<Lease> { |