Files
rdns-go/rdns.go
2025-11-16 22:35:37 +01:00

476 lines
11 KiB
Go

package main
import (
"crypto/tls"
"flag"
"fmt"
"io"
"log"
"math/rand"
"net"
"net/http"
"os"
"os/signal"
"runtime"
"strings"
"syscall"
"time"
"github.com/miekg/dns"
"github.com/patrickmn/go-cache"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
var (
	// Command-line flags.
	addr          = flag.String("a", "0.0.0.0", "address to use")
	port          = flag.String("p", "53", "port to run on")
	ns            = flag.String("n", "9.9.9.9", "colon-separated list of nameservers to use")
	blocklistPath = flag.String("b", "", "blocklist path")
	zonesPath     = flag.String("z", "", "add a custom zone file")
	hostsPath     = flag.String("h", "", "add a custom hosts file")
	cpu           = flag.Int("c", 0, "number of cpu to use")
	ttl           = flag.Int("t", 10, "cache TTL in minutes")
	dot           = flag.Bool("s", false, "use tls to contact the resolver")
	logs          = flag.Bool("l", false, "log queries")
	metrics       = flag.Bool("m", false, "enable prometheus metrics")

	// qcache holds responses keyed by "<qname>.<qtype>"; entries expire after
	// -t minutes.
	qcache *cache.Cache

	// Lookup tables filled from the -b, -z and -h files in main before the
	// DNS servers are started; query handling only reads them.
	blocklistMap = make(map[string]bool)   // blocked query names
	zonesMap     = make(map[string]string) // domain -> nameserver to forward to
	hostsMap     = make(map[string]string) // host name -> IP address

	// Prometheus metrics, exported on :9153 when -m is set.
	up = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "rdns_up",
		Help: "Non-null value when the server is ready",
	})
	blSize = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "rdns_blocklist_size",
		Help: "Blocklist size in bytes",
	})
	blCount = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "rdns_blocklist_count",
		Help: "Number of items in the blocklist",
	})
	cacheItems = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "rdns_cache_items",
		Help: "Number of cached queries",
	})
	// slowAnswers only receives observations for upstream exchanges that took
	// more than 100 ms; the observed value is the round-trip time in ms.
	slowAnswers = promauto.NewSummary(prometheus.SummaryOpts{
		Name: "rdns_answers_slow",
		Help: "Upstream response time in milliseconds of queries slower than 100 milliseconds",
	})
	cacheHits = promauto.NewCounter(prometheus.CounterOpts{
		Name: "rdns_cache_hits",
		Help: "Number of responses using cache",
	})
	queries = promauto.NewCounter(prometheus.CounterOpts{
		Name: "rdns_queries_total",
		Help: "Total number of queries",
	})
	qtypes = promauto.NewCounterVec(prometheus.CounterOpts{
		Name: "rdns_queries",
		Help: "Number of queries by type",
	}, []string{"type"})
	responses = promauto.NewCounterVec(prometheus.CounterOpts{
		Name: "rdns_responses",
		Help: "Number of responses by type",
	}, []string{"type", "status"})
	nameservers = promauto.NewCounterVec(prometheus.CounterOpts{
		Name: "rdns_nameservers",
		Help: "Number of responses by nameserver",
	}, []string{"ns"})
)
// getRcode returns the mnemonic of a DNS response code, as used in log lines
// and in the "status" label of the rdns_responses metric. The names follow
// the IANA DNS RCODE registry (RFC 1035 codes 0-5, RFC 2136 codes 6-10).
// Codes outside 0-10 are reported as "SERVFAIL".
func getRcode(rcode int) string {
	switch rcode {
	case 0:
		return "NOERROR"
	case 1:
		return "FORMERR"
	case 2:
		return "SERVFAIL"
	case 3:
		return "NXDOMAIN"
	case 4:
		return "NOTIMP"
	case 5:
		return "REFUSED"
	case 6:
		return "YXDOMAIN"
	case 7:
		return "YXRRSET"
	case 8:
		return "NXRRSET"
	case 9:
		return "NOTAUTH"
	case 10:
		return "NOTZONE"
	default:
		return "SERVFAIL"
	}
}
// getDomain reduces a name to its last two dot-separated labels
// ("www.example.com" -> "example.com"). A name with at most one dot is
// returned unchanged.
func getDomain(qn string) string {
	last := strings.LastIndexByte(qn, '.')
	if last < 0 {
		return qn
	}
	prev := strings.LastIndexByte(qn[:last], '.')
	if prev < 0 {
		return qn
	}
	return qn[prev+1:]
}
// getZoneNS returns the nameserver that the zones file assigns to qdomain,
// or "." when the domain has no zone override.
func getZoneNS(qdomain string) string {
	server, found := zonesMap[qdomain]
	if !found {
		return "."
	}
	return server
}
// getHost returns the IP address that the hosts file assigns to qn, or the
// empty string when qn is not listed (the map's zero value).
func getHost(qn string) string {
	return hostsMap[qn]
}
// doDot makes s fully qualified by appending a trailing '.' when it is
// missing. The empty string becomes the root name ".".
func doDot(s string) string {
	if n := len(s); n > 0 && s[n-1] == '.' {
		return s
	}
	return s + "."
}
// handleCache looks up the cached response for the query (qname, qtype).
//
// On a hit it returns a private copy of the cached message, with its ID set
// to the ID of q, together with the cached rcode. On a miss it returns q
// itself and rcode 2 (SERVFAIL), which callers use as the "not cached"
// sentinel. As a consequence, a cached SERVFAIL response is never served from
// the cache.
func handleCache(q *dns.Msg, qname string, qtype string) (*dns.Msg, int) {
	cached, found := qcache.Get(qname + "." + qtype)
	if !found {
		return q, 2
	}
	msg, ok := cached.(*dns.Msg)
	if !ok {
		return q, 2
	}
	// The cached message is shared by every goroutine answering this name.
	// Rewriting its ID in place races with concurrent hits and can hand a
	// client a response carrying another client's ID, so work on a copy.
	reply := msg.Copy()
	reply.MsgHdr.Id = q.MsgHdr.Id
	return reply, reply.MsgHdr.Rcode
}
// lookup forwards q to an upstream nameserver. It returns the answer, the
// transport used by the last attempt ("udp" or "tcp-tls") and the error of
// the last attempt (nil as soon as one nameserver answered).
//
// If the zones file maps the query's domain (see getDomain) to a nameserver,
// only that nameserver is asked, over plain UDP. Otherwise the colon-separated
// -n nameservers are tried in random order, over TLS on port 853 when -s is
// set and over UDP on port 53 otherwise, stopping at the first success.
func lookup(q *dns.Msg, qname string) (*dns.Msg, string, error) {
	var (
		reply    *dns.Msg
		elapsed  time.Duration
		lastErr  error
		upstream []string
	)
	zoneNS := getZoneNS(getDomain(qname))
	// Zone overrides are always contacted without TLS.
	useTLS := *dot && zoneNS == "."

	if zoneNS == "." {
		upstream = strings.Split(*ns, ":")
		rand.Shuffle(len(upstream), func(a, b int) {
			upstream[a], upstream[b] = upstream[b], upstream[a]
		})
	} else {
		upstream = []string{zoneNS}
	}

	client := new(dns.Client)
	for _, server := range upstream {
		if useTLS {
			client.Net = "tcp-tls"
			// Setting ServerName keeps certificate verification enabled.
			client.TLSConfig = &tls.Config{ServerName: server}
			reply, elapsed, lastErr = client.Exchange(q, net.JoinHostPort(server, "853"))
			if lastErr != nil {
				// Reported on stderr regardless of -l (security-relevant).
				fmt.Fprintf(os.Stderr, "TLS connection failed to %s: %v\n", server, lastErr)
			}
		} else {
			client.Net = "udp"
			reply, elapsed, lastErr = client.Exchange(q, net.JoinHostPort(server, "53"))
			if lastErr != nil && *logs {
				log.Println("WARN:", lastErr)
			}
		}
		if lastErr == nil {
			nameservers.WithLabelValues(server).Inc()
			break
		}
	}

	// Only the round trip of the last attempt is considered.
	if ms := elapsed / time.Millisecond; ms > 100 {
		slowAnswers.Observe(float64(ms))
	}
	return reply, client.Net, lastErr
}
// handleQuery answers one DNS query. Sources are consulted in this order:
//
//  1. CHAOS-class "version.bind": a fixed TXT record "rdns";
//  2. the response cache (see handleCache; a cached SERVFAIL counts as a miss);
//  3. the hosts file, for A and AAAA queries;
//  4. the blocklist, answered with REFUSED;
//  5. the upstream nameservers via lookup, or SERVFAIL if none answered.
//
// Every response not served from the cache is stored in it under
// "<qname>.<qtype>". Each response is logged when -l is set and is counted in
// the rdns_responses metric.
func handleQuery(w dns.ResponseWriter, q *dns.Msg) {
	if len(q.Question) == 0 {
		// Nothing to answer. Reply FORMERR instead of panicking on Question[0].
		m := new(dns.Msg)
		m.SetRcode(q, dns.RcodeFormatError)
		if err := w.WriteMsg(m); err != nil && *logs {
			log.Println("ERROR: Failed to write response:", err)
		}
		return
	}
	question := q.Question[0]
	var (
		qname  = strings.ToLower(strings.TrimSuffix(question.Name, "."))
		qclass = dns.Class(question.Qclass).String()
		qtype  = dns.Type(question.Qtype).String()
		qnt    = qname + "." + qtype
		client = w.RemoteAddr().String()
	)
	// SplitHostPort understands bracketed IPv6 addresses ("[::1]:5353"),
	// which splitting on ':' would reduce to "[".
	if host, _, err := net.SplitHostPort(client); err == nil {
		client = host
	}
	queries.Inc()
	qtypes.WithLabelValues(qtype).Inc()

	// reply logs the outcome, caches r under qnt when store is set, sends r
	// to the client and counts it. extra is appended to the log line.
	reply := func(r *dns.Msg, rcode int, store bool, extra ...any) {
		if *logs {
			fields := append([]any{client, qname, qclass, qtype, getRcode(rcode)}, extra...)
			log.Println(fields...)
		}
		if store {
			qcache.SetDefault(qnt, r)
		}
		if err := w.WriteMsg(r); err != nil && *logs {
			log.Println("ERROR: Failed to write response:", err)
		}
		responses.WithLabelValues(qtype, getRcode(rcode)).Inc()
	}

	if qclass == "CH" && qname == "version.bind" {
		qname = "version.bind."
		r := new(dns.Msg)
		r.SetRcode(q, dns.RcodeSuccess)
		r.Answer = append(r.Answer, &dns.TXT{
			Hdr: dns.RR_Header{Name: qname, Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 86400},
			Txt: []string{"rdns"},
		})
		reply(r, dns.RcodeSuccess, true)
		return
	}

	if cached, rcode := handleCache(q, qname, qtype); rcode != 2 {
		cacheHits.Inc()
		reply(cached, rcode, false, "cache")
		return
	}

	if hostIP := getHost(qname); hostIP != "" && (qtype == "A" || qtype == "AAAA") {
		r := new(dns.Msg)
		r.SetRcode(q, dns.RcodeSuccess)
		r.MsgHdr.RecursionAvailable = true
		// Answer only with a record of the matching address family: an IPv6
		// address cannot be encoded in an A record, and a v4-mapped address
		// is not a meaningful AAAA answer. A mismatch (or an unparsable
		// address) yields an empty NOERROR reply.
		if ip := net.ParseIP(hostIP); ip != nil {
			name := doDot(qname)
			v4 := ip.To4()
			switch {
			case qtype == "A" && v4 != nil:
				r.Answer = append(r.Answer, &dns.A{
					Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300},
					A:   v4,
				})
			case qtype == "AAAA" && v4 == nil:
				r.Answer = append(r.Answer, &dns.AAAA{
					Hdr:  dns.RR_Header{Name: name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 300},
					AAAA: ip,
				})
			}
		}
		reply(r, dns.RcodeSuccess, true)
		return
	}

	if blocklistMap[qname] {
		r := new(dns.Msg)
		r.SetRcode(q, dns.RcodeRefused)
		reply(r, dns.RcodeRefused, true)
		return
	}

	r, proto, err := lookup(q, qname)
	if err != nil {
		if *logs {
			log.Println("ERROR:", err)
		}
		r = new(dns.Msg)
		r.SetRcode(q, dns.RcodeServerFailure)
		reply(r, dns.RcodeServerFailure, true)
		return
	}
	r.MsgHdr.RecursionAvailable = true
	reply(r, r.MsgHdr.Rcode, true, proto)
}
// serve runs a DNS server on addr:port over the given transport ("tcp" or
// "udp") and blocks until it stops. The server's Handler is left nil, so it
// dispatches to the handlers registered with dns.HandleFunc. A failure to
// start is reported on stdout.
func serve(addr string, port string, proto string) {
	// JoinHostPort brackets IPv6 literals ("[::]:53"); plain concatenation
	// would produce an unparsable ":::53".
	server := &dns.Server{Addr: net.JoinHostPort(addr, port), Net: proto, ReusePort: true}
	if err := server.ListenAndServe(); err != nil {
		fmt.Printf("Failed to setup the %s server: %s\n", proto, err.Error())
	}
}
// loadBlocklist reads a blocklist file with one domain per line into
// blocklistMap. Blank lines and lines starting with '#' are skipped.
//
// Entries are lowercased and stripped of a trailing '.' so that they match the
// names handleQuery looks up, which are lowercased and have no trailing dot.
// The rdns_blocklist_count gauge is set to the number of distinct entries and
// rdns_blocklist_size to the file size in bytes. A read error is logged and
// leaves the blocklist unchanged.
func loadBlocklist(path string) {
	data, err := os.ReadFile(path)
	if err != nil {
		log.Println("Error reading blocklist:", err)
		return
	}
	for _, line := range strings.Split(string(data), "\n") {
		line = strings.TrimSpace(line)
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		name := strings.ToLower(strings.TrimSuffix(line, "."))
		if name == "" {
			// A bare "." would otherwise block queries for the root name.
			continue
		}
		blocklistMap[name] = true
	}
	blCount.Set(float64(len(blocklistMap)))
	blSize.Set(float64(len(data)))
}
// loadZones reads a zones file into zonesMap, recording which nameserver each
// domain is forwarded to. Each line has the form "<domain> <nameserver>".
// Anything after '#' is a comment, and extra fields are ignored.
//
// Domains are lowercased and stripped of a trailing '.', because lookup
// matches them against getDomain of a query name that is already lowercased
// and has no trailing dot. The function returns the number of entries read,
// or 0 if the file cannot be read (the error is logged).
func loadZones(path string) int {
	data, err := os.ReadFile(path)
	if err != nil {
		log.Println("Error reading zones file:", err)
		return 0
	}
	count := 0
	for _, line := range strings.Split(string(data), "\n") {
		if i := strings.IndexByte(line, '#'); i >= 0 {
			line = line[:i]
		}
		fields := strings.Fields(line)
		if len(fields) < 2 {
			continue
		}
		domain := strings.ToLower(strings.TrimSuffix(fields[0], "."))
		if domain == "" {
			// A bare "." would otherwise capture queries for the root name.
			continue
		}
		zonesMap[domain] = fields[1]
		count++
	}
	return count
}
// loadHosts reads a hosts-format file ("<ip> <name> [<name>...]") into
// hostsMap, mapping every listed name to the IP in the first field. Anything
// after '#' is a comment. When a name appears on several lines, the last line
// wins.
//
// Names are lowercased and stripped of a trailing '.', because handleQuery
// looks hosts up by a lowercased name without the trailing dot. The function
// returns the number of names read, or 0 if the file cannot be read (the
// error is logged).
func loadHosts(path string) int {
	data, err := os.ReadFile(path)
	if err != nil {
		log.Println("Error reading hosts file:", err)
		return 0
	}
	count := 0
	for _, line := range strings.Split(string(data), "\n") {
		if i := strings.IndexByte(line, '#'); i >= 0 {
			line = line[:i]
		}
		fields := strings.Fields(line)
		if len(fields) < 2 {
			continue
		}
		ip := fields[0]
		for _, name := range fields[1:] {
			host := strings.ToLower(strings.TrimSuffix(name, "."))
			if host == "" {
				continue
			}
			hostsMap[host] = ip
			count++
		}
	}
	return count
}
// cacheEviction is registered with qcache.OnEvicted. It logs the evicted key
// when -l is set and refreshes the rdns_cache_items gauge with the number of
// entries still cached.
func cacheEviction(key string, _ any) {
	if *logs {
		log.Println("Cache: Evicted", key)
	}
	remaining := qcache.ItemCount()
	cacheItems.Set(float64(remaining))
}
// main parses the flags and creates the response cache (expiry -t minutes,
// cleanup every 5 minutes). It loads the optional blocklist, zones and hosts
// files and, with -m, starts the Prometheus exporter on :9153. It then serves
// DNS over TCP and UDP on -a/-p until SIGINT or SIGTERM arrives.
func main() {
	flag.Parse()
	qcache = cache.New(time.Duration(*ttl)*time.Minute, 5*time.Minute)
	qcache.OnEvicted(cacheEviction)
	fmt.Println("Starting Proxy Resolver:", net.JoinHostPort(*addr, *port), "->", *ns, "[UDP/TCP]")
	fmt.Println("Cache TTL:", *ttl, "minutes\nTLS enabled:", *dot)
	if *cpu != 0 {
		runtime.GOMAXPROCS(*cpu)
	}
	if *blocklistPath != "" {
		loadBlocklist(*blocklistPath)
		fmt.Println("Blocklist loaded:", len(blocklistMap), "entries")
	}
	if *zonesPath != "" {
		if count := loadZones(*zonesPath); count == 0 {
			fmt.Println("No zones found")
		} else {
			fmt.Println("Zones loaded:", count, "entries")
		}
	}
	if *hostsPath != "" {
		if count := loadHosts(*hostsPath); count == 0 {
			fmt.Println("No hosts found")
		} else {
			fmt.Println("Hosts loaded:", count, "entries")
		}
	}
	// Silence the standard logger only after loading. Query logging checks
	// -l itself, but the loaders above report unreadable files through the
	// logger, and those errors must reach the user.
	if !*logs {
		log.SetOutput(io.Discard)
	}
	if *metrics {
		fmt.Println("Starting prometheus exporter on port 9153")
		http.Handle("/metrics", promhttp.Handler())
		go func() {
			if err := http.ListenAndServe(":9153", nil); err != nil {
				fmt.Fprintf(os.Stderr, "Prometheus exporter stopped: %v\n", err)
			}
		}()
	}
	dns.HandleFunc(".", handleQuery)
	go serve(*addr, *port, "tcp")
	go serve(*addr, *port, "udp")
	up.Set(float64(1))
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
	<-sig
	fmt.Printf("\033[2K\rTime to sleep, goodbye\n")
}