diff options
author | Himbeer <himbeer@disroot.org> | 2024-07-22 14:28:53 +0200 |
---|---|---|
committer | Himbeer <himbeer@disroot.org> | 2024-07-22 14:28:53 +0200 |
commit | 7330e326b1f747646257f33bf713876a0d753ca5 (patch) | |
tree | 26bd4146b5cb5c6086cedca1ab26377dc42f961e | |
parent | eabf5703bf783be57e064e90f84207c3cf486218 (diff) |
Handle upstream read timeouts less disruptively
-rw-r--r-- | src/main.rs | 93 |
1 files changed, 76 insertions, 17 deletions
diff --git a/src/main.rs b/src/main.rs index 25c419a..9b1c120 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ use std::cell::RefCell; use std::fs::{self, File}; use std::io; -use std::net::{IpAddr, SocketAddr, ToSocketAddrs, UdpSocket}; +use std::net::{IpAddr, SocketAddr, UdpSocket}; use std::str::FromStr; use std::sync::{Arc, RwLock}; use std::thread; @@ -19,6 +19,7 @@ use thiserror::Error; const UPSTREAM_PRIMARY: &str = "[2620:fe::fe]:53"; const UPSTREAM_SECONDARY: &str = "9.9.9.9:53"; +const UPSTREAM_TIMEOUT: Duration = Duration::from_secs(1); #[derive(Debug, Error)] pub enum Error { @@ -115,6 +116,14 @@ fn main() -> Result<()> { let sock = UdpSocket::bind("[::]:53")?; + let uplink_primary = UdpSocket::bind("[::]:0")?; + uplink_primary.set_read_timeout(Some(UPSTREAM_TIMEOUT))?; + uplink_primary.connect(UPSTREAM_PRIMARY)?; + + let uplink_secondary = UdpSocket::bind("[::]:0")?; + uplink_secondary.set_read_timeout(Some(UPSTREAM_TIMEOUT))?; + uplink_secondary.connect(UPSTREAM_SECONDARY)?; + loop { let mut buf = [0; 1024]; let (n, raddr) = sock.recv_from(&mut buf)?; @@ -122,14 +131,31 @@ fn main() -> Result<()> { let domain2 = domain.clone(); let sock2 = sock.try_clone()?; + let uplink_primary2 = uplink_primary.try_clone()?; + let uplink_secondary2 = uplink_secondary.try_clone()?; let buf = buf.to_vec(); let leases3 = leases.clone(); - thread::spawn( - move || match handle_query(&domain2, sock2, &buf, raddr, leases3) { + thread::spawn(move || { + match handle_query( + &domain2, + &sock2, + uplink_primary2, + uplink_secondary2, + &buf, + raddr, + leases3, + ) { Ok(_) => {} - Err(e) => print_query_error(&buf, raddr, e), - }, - ); + Err(e) => { + match respond_with_error(&sock2, &buf, raddr) { + Ok(_) => {} + Err(e) => println!("[warn] {} send error response: {}", raddr, e), + } + + print_query_error(&buf, raddr, e); + } + } + }); } } @@ -144,6 +170,39 @@ fn print_query_error(buf: &[u8], raddr: SocketAddr, e: Error) { } } +fn respond_with_error(sock: &UdpSocket, buf: &[u8], raddr: SocketAddr) -> Result<()> { + let bytes = Bytes::copy_from_slice(buf); + let msg = Dns::decode(bytes)?; + + let resp = Dns { + id: msg.id, + flags: Flags { + qr: true, + opcode: msg.flags.opcode, + aa: false, + tc: false, + rd: msg.flags.rd, + ra: true, + ad: false, + cd: false, + rcode: RCode::ServFail, + }, + questions: Vec::default(), + answers: Vec::default(), + authorities: Vec::default(), + additionals: Vec::default(), + }; + + let bytes = resp.encode()?; + + let n = sock.send_to(&bytes, raddr)?; + if n != bytes.len() { + return Err(Error::PartialSend(bytes.len(), n)); + } + + Ok(()) +} + fn extract_questions(buf: &[u8]) -> Result<Vec<Question>> { let bytes = Bytes::copy_from_slice(buf); let msg = Dns::decode(bytes)?; @@ -153,7 +212,9 @@ fn extract_questions(buf: &[u8]) -> Result<Vec<Question>> { fn handle_query( domain: &Option<Name>, - sock: UdpSocket, + sock: &UdpSocket, + uplink_primary: UdpSocket, + uplink_secondary: UdpSocket, buf: &[u8], raddr: SocketAddr, leases: Arc<RwLock<Vec<Lease>>>, @@ -286,12 +347,15 @@ fn handle_query( if !msg.questions.is_empty() { let bytes = msg.encode()?; - let resp = match upstream_query(UPSTREAM_PRIMARY, &bytes) { + let resp = match upstream_query(uplink_primary, &bytes) { Ok(v) => v, - Err(e) => match upstream_query(UPSTREAM_SECONDARY, &bytes) { - Ok(v) => v, + Err(e) => match upstream_query(uplink_secondary, &bytes) { + Ok(v) => { + println!("[warn] {} primary unavailable: {}", raddr, e); + v + } Err(e2) => { - println!("[warn] primary unavailable: {}", e); + println!("[warn] {} secondary unavailable: {}", raddr, e2); return Err(e2); } }, @@ -339,12 +403,7 @@ fn handle_query( Ok(()) } -fn upstream_query<A: ToSocketAddrs>(upstream: A, bytes: &[u8]) -> Result<Dns> { - let uplink = UdpSocket::bind("[::]:0")?; - - uplink.set_read_timeout(Some(Duration::from_secs(1)))?; - uplink.connect(upstream)?; - +fn upstream_query(uplink: UdpSocket, bytes: &[u8]) -> Result<Dns> { let n = uplink.send(bytes)?; if n != bytes.len() { return Err(Error::PartialSend(bytes.len(), n)); |