aboutsummaryrefslogtreecommitdiff
path: root/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs93
1 files changed, 76 insertions, 17 deletions
diff --git a/src/main.rs b/src/main.rs
index 25c419a..9b1c120 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,7 +1,7 @@
use std::cell::RefCell;
use std::fs::{self, File};
use std::io;
-use std::net::{IpAddr, SocketAddr, ToSocketAddrs, UdpSocket};
+use std::net::{IpAddr, SocketAddr, UdpSocket};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::thread;
@@ -19,6 +19,7 @@ use thiserror::Error;
const UPSTREAM_PRIMARY: &str = "[2620:fe::fe]:53";
const UPSTREAM_SECONDARY: &str = "9.9.9.9:53";
+const UPSTREAM_TIMEOUT: Duration = Duration::from_secs(1);
#[derive(Debug, Error)]
pub enum Error {
@@ -115,6 +116,14 @@ fn main() -> Result<()> {
let sock = UdpSocket::bind("[::]:53")?;
+ let uplink_primary = UdpSocket::bind("[::]:0")?;
+ uplink_primary.set_read_timeout(Some(UPSTREAM_TIMEOUT))?;
+ uplink_primary.connect(UPSTREAM_PRIMARY)?;
+
+ let uplink_secondary = UdpSocket::bind("[::]:0")?;
+ uplink_secondary.set_read_timeout(Some(UPSTREAM_TIMEOUT))?;
+ uplink_secondary.connect(UPSTREAM_SECONDARY)?;
+
loop {
let mut buf = [0; 1024];
let (n, raddr) = sock.recv_from(&mut buf)?;
@@ -122,14 +131,31 @@ fn main() -> Result<()> {
let domain2 = domain.clone();
let sock2 = sock.try_clone()?;
+ let uplink_primary2 = uplink_primary.try_clone()?;
+ let uplink_secondary2 = uplink_secondary.try_clone()?;
let buf = buf.to_vec();
let leases3 = leases.clone();
- thread::spawn(
- move || match handle_query(&domain2, sock2, &buf, raddr, leases3) {
+ thread::spawn(move || {
+ match handle_query(
+ &domain2,
+ &sock2,
+ uplink_primary2,
+ uplink_secondary2,
+ &buf,
+ raddr,
+ leases3,
+ ) {
Ok(_) => {}
- Err(e) => print_query_error(&buf, raddr, e),
- },
- );
+ Err(e) => {
+ match respond_with_error(&sock2, &buf, raddr) {
+ Ok(_) => {}
+ Err(e) => println!("[warn] {} send error response: {}", raddr, e),
+ }
+
+ print_query_error(&buf, raddr, e);
+ }
+ }
+ });
}
}
@@ -144,6 +170,39 @@ fn print_query_error(buf: &[u8], raddr: SocketAddr, e: Error) {
}
}
+fn respond_with_error(sock: &UdpSocket, buf: &[u8], raddr: SocketAddr) -> Result<()> {
+ let bytes = Bytes::copy_from_slice(buf);
+ let msg = Dns::decode(bytes)?;
+
+ let resp = Dns {
+ id: msg.id,
+ flags: Flags {
+ qr: true,
+ opcode: msg.flags.opcode,
+ aa: false,
+ tc: false,
+ rd: msg.flags.rd,
+ ra: true,
+ ad: false,
+ cd: false,
+ rcode: RCode::ServFail,
+ },
+ questions: Vec::default(),
+ answers: Vec::default(),
+ authorities: Vec::default(),
+ additionals: Vec::default(),
+ };
+
+ let bytes = resp.encode()?;
+
+ let n = sock.send_to(&bytes, raddr)?;
+ if n != bytes.len() {
+ return Err(Error::PartialSend(bytes.len(), n));
+ }
+
+ Ok(())
+}
+
fn extract_questions(buf: &[u8]) -> Result<Vec<Question>> {
let bytes = Bytes::copy_from_slice(buf);
let msg = Dns::decode(bytes)?;
@@ -153,7 +212,9 @@ fn extract_questions(buf: &[u8]) -> Result<Vec<Question>> {
fn handle_query(
domain: &Option<Name>,
- sock: UdpSocket,
+ sock: &UdpSocket,
+ uplink_primary: UdpSocket,
+ uplink_secondary: UdpSocket,
buf: &[u8],
raddr: SocketAddr,
leases: Arc<RwLock<Vec<Lease>>>,
@@ -286,12 +347,15 @@ fn handle_query(
if !msg.questions.is_empty() {
let bytes = msg.encode()?;
- let resp = match upstream_query(UPSTREAM_PRIMARY, &bytes) {
+ let resp = match upstream_query(uplink_primary, &bytes) {
Ok(v) => v,
- Err(e) => match upstream_query(UPSTREAM_SECONDARY, &bytes) {
- Ok(v) => v,
+ Err(e) => match upstream_query(uplink_secondary, &bytes) {
+ Ok(v) => {
+ println!("[warn] {} primary unavailable: {}", raddr, e);
+ v
+ }
Err(e2) => {
- println!("[warn] primary unavailable: {}", e);
+ println!("[warn] {} secondary unavailable: {}", raddr, e2);
return Err(e2);
}
},
@@ -339,12 +403,7 @@ fn handle_query(
Ok(())
}
-fn upstream_query<A: ToSocketAddrs>(upstream: A, bytes: &[u8]) -> Result<Dns> {
- let uplink = UdpSocket::bind("[::]:0")?;
-
- uplink.set_read_timeout(Some(Duration::from_secs(1)))?;
- uplink.connect(upstream)?;
-
+fn upstream_query(uplink: UdpSocket, bytes: &[u8]) -> Result<Dns> {
let n = uplink.send(bytes)?;
if n != bytes.len() {
return Err(Error::PartialSend(bytes.len(), n));