diff --git a/src/client.rs b/src/client.rs index 92ec845..5ba5253 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,4 +1,5 @@ use crate::cache::DnsCache; +use crate::r#override::OverrideResolver; use domain::base::iana::{Opcode, Rcode}; use domain::base::message::Message; use domain::base::message_builder::MessageBuilder; @@ -18,13 +19,15 @@ pub struct ClientOptions { pub struct Client { options: ClientOptions, cache: DnsCache, + override_resolver: OverrideResolver, } impl Client { - pub fn new(options: ClientOptions) -> Client { + pub fn new(options: ClientOptions, override_resolver: OverrideResolver) -> Client { Client { options, cache: DnsCache::new(), + override_resolver, } } @@ -32,11 +35,11 @@ impl Client { &self, questions: Vec>>>, ) -> Result>, UnknownRecordData>>>, String> { - // Attempt to read from cache first - let (mut cached_answers, questions) = self.try_answer_from_cache(questions).await; + // Attempt to answer locally first + let (mut local_answers, questions) = self.try_answer_from_local(questions).await; if questions.len() == 0 { // No remaining questions to be handled. Return directly. - return Ok(cached_answers); + return Ok(local_answers); } let msg = Self::build_query(questions)?; @@ -48,7 +51,7 @@ impl Client { let mut ret = Self::extract_answers(resp)?; self.cache_answers(&ret).await; // Concatenate the cached answers we retrived previously with the newly-fetched answers - ret.append(&mut cached_answers); + ret.append(&mut local_answers); Ok(ret) } // NXDOMAIN is not an error we want to retry / panic upon @@ -181,10 +184,10 @@ impl Client { Ok(ret) } - // Try to answer the questions as much as we can from the cache + // Try to answer the questions as much as we can from the cache / override map // returns the available answers, and the remaining questions that cannot be - // answered from cache - async fn try_answer_from_cache( + // answered from cache or the override resolver + async fn try_answer_from_local( &self, questions: Vec>>>, ) -> ( @@ -194,9 +197,15 @@ impl Client { let mut answers = Vec::new(); let mut remaining = Vec::new(); for q in questions { - match self.cache.get_cache(&q).await { - Some(mut ans) => answers.append(&mut ans), - None => remaining.push(q), + if let Some(ans) = self.override_resolver.try_resolve(&q) { + // Try to resolve from override map first + answers.push(ans); + } else if let Some(mut ans) = self.cache.get_cache(&q).await { + // Then try cache + answers.append(&mut ans); + } else { + // If both failed, resolve via upstream + remaining.push(q); } } (answers, remaining) diff --git a/src/lib.rs b/src/lib.rs index f5f21f8..7603693 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ mod cache; mod client; mod kv; +mod r#override; mod server; mod util; diff --git a/src/override.rs b/src/override.rs new file mode 100644 index 0000000..a3e0b97 --- /dev/null +++ b/src/override.rs @@ -0,0 +1,82 @@ +use domain::base::rdata::UnknownRecordData; +use domain::base::{question::Question, Compose}; +use domain::base::{Dname, Record, Rtype}; +use domain::rdata::{Aaaa, AllRecordData, A}; +use std::collections::HashMap; +use std::net::IpAddr; + +pub struct OverrideResolver { + simple_matches: HashMap, + override_ttl: u32, +} + +impl OverrideResolver { + pub fn new(overrides: HashMap, override_ttl: u32) -> OverrideResolver { + OverrideResolver { + simple_matches: Self::build_simple_match_table(overrides), + override_ttl, + } + } + + fn build_simple_match_table(overrides: HashMap) -> HashMap { + let mut ret = HashMap::new(); + for (k, v) in overrides.into_iter() { + match v.parse::() { + Ok(addr) => { + ret.insert(k, addr); + } + // Ignore malformed IP addresses + Err(_) => continue, + } + } + return ret; + } + + pub fn try_resolve( + &self, + question: &Question>>, + ) -> Option>, UnknownRecordData>>> { + match question.qtype() { + // We only handle resolution of IP addresses + Rtype::A | Rtype::A6 | Rtype::Aaaa | Rtype::Cname | Rtype::Any => (), + // So if the question is anything else, just skip + _ => return None, + } + + let name = question.qname().to_string(); + if let Some(addr) = self.simple_matches.get(&name) { + self.respond_with_addr(question, addr) + } else { + None + } + } + + fn respond_with_addr( + &self, + question: &Question>>, + addr: &IpAddr, + ) -> Option>, UnknownRecordData>>> { + let (rtype, rdata): (_, AllRecordData, Dname>>) = match addr { + IpAddr::V4(addr) => (Rtype::A, AllRecordData::A(A::new(addr.clone()))), + IpAddr::V6(addr) => (Rtype::Aaaa, AllRecordData::Aaaa(Aaaa::new(addr.clone()))), + }; + + let qtype = question.qtype(); + if qtype == Rtype::Any || qtype == rtype { + // Convert AllRecordData to UnknownRecordData to match the type + // since our resolver client doesn't really care about the actual type + let mut rdata_buf: Vec = Vec::new(); + rdata.compose(&mut rdata_buf).ok()?; + let record = Record::new( + question.qname().clone(), + question.qclass(), + self.override_ttl, + UnknownRecordData::from_octets(rtype, rdata_buf), + ); + return Some(record); + } else { + // If the response and query types don't match, just return none + return None; + } + } +} diff --git a/src/server.rs b/src/server.rs index 647da57..5deaa12 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,4 +1,5 @@ use crate::client::*; +use crate::r#override::OverrideResolver; use async_static::async_static; use domain::base::iana::{Opcode, Rcode}; use domain::base::message::Message; @@ -10,6 +11,7 @@ use domain::base::{Dname, ToDname}; use js_sys::{ArrayBuffer, Uint8Array}; use serde::Deserialize; use std::borrow::Borrow; +use std::collections::HashMap; use wasm_bindgen_futures::JsFuture; use web_sys::*; @@ -43,6 +45,10 @@ enum DnsResponseFormat { pub struct ServerOptions { upstream_urls: Vec, retries: usize, + #[serde(default)] + overrides: HashMap, + #[serde(default)] + override_ttl: u32, } pub struct Server { @@ -53,9 +59,12 @@ pub struct Server { impl Server { fn new(options: ServerOptions) -> Server { Server { - client: Client::new(ClientOptions { - upstream_urls: options.upstream_urls.clone(), - }), + client: Client::new( + ClientOptions { + upstream_urls: options.upstream_urls.clone(), + }, + OverrideResolver::new(options.overrides.clone(), options.override_ttl), + ), options, } }