389 lines
9.7 KiB
Go
389 lines
9.7 KiB
Go
package main
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"math/rand"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"runtime"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
"github.com/patrickmn/go-cache"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
|
)
|
|
|
|
var (
|
|
addr = flag.String("addr", "0.0.0.0", "address to use")
|
|
port = flag.String("port", "53", "port to run on")
|
|
ns = flag.String("ns", "9.9.9.9", "nameservers to use")
|
|
blocklist = flag.String("blocklist", "", "blocklist path")
|
|
zones = flag.String("zones", "", "add a custom zone file")
|
|
hosts = flag.String("hosts", "", "add a custom hosts file")
|
|
cpu = flag.Int("cpu", 0, "number of cpu to use")
|
|
ttl = flag.Int("ttl", 10, "force queries TTL")
|
|
dot = flag.Bool("tls", false, "use tls to contact the resolver")
|
|
logs = flag.Bool("logs", false, "log queries")
|
|
metrics = flag.Bool("metrics", false, "enable prometheus metrics")
|
|
qcache *cache.Cache
|
|
|
|
up = promauto.NewGauge(prometheus.GaugeOpts{
|
|
Name: "rdns_up",
|
|
Help: "Non-null value when the server is ready",
|
|
})
|
|
blSize = promauto.NewGauge(prometheus.GaugeOpts{
|
|
Name: "rdns_blocklist_size",
|
|
Help: "Blocklist size in bytes",
|
|
})
|
|
blCount = promauto.NewGauge(prometheus.GaugeOpts{
|
|
Name: "rdns_blocklist_count",
|
|
Help: "Number of items in the blocklist",
|
|
})
|
|
cacheItems = promauto.NewGauge(prometheus.GaugeOpts{
|
|
Name: "rdns_cache_items",
|
|
Help: "Number of cached queries",
|
|
})
|
|
slowAnswers = promauto.NewSummary(prometheus.SummaryOpts{
|
|
Name: "rdns_answers_slow",
|
|
Help: "Number of queries answered within 100 millisecond",
|
|
})
|
|
cacheHits = promauto.NewCounter(prometheus.CounterOpts{
|
|
Name: "rdns_cache_hits",
|
|
Help: "Number of responses using cache",
|
|
})
|
|
queries = promauto.NewCounter(prometheus.CounterOpts{
|
|
Name: "rdns_queries_total",
|
|
Help: "Total number of queries",
|
|
})
|
|
qtypes = promauto.NewCounterVec(prometheus.CounterOpts{
|
|
Name: "rdns_queries",
|
|
Help: "Number of queries by type",
|
|
}, []string{"type"})
|
|
responses = promauto.NewCounterVec(prometheus.CounterOpts{
|
|
Name: "rdns_responses",
|
|
Help: "Number of responses by type",
|
|
}, []string{"type", "status"})
|
|
nameservers = promauto.NewCounterVec(prometheus.CounterOpts{
|
|
Name: "rdns_nameservers",
|
|
Help: "Number of responses by nameserver",
|
|
}, []string{"ns"})
|
|
)
|
|
|
|
// getRcode returns the mnemonic for a DNS response code. It is used in log
// lines and as the "status" label of the rdns_responses metric.
//
// Codes 0-5 are defined by RFC 1035 and 6-10 by RFC 2136. Any other value is
// reported as "SERVFAIL", which keeps the metric's label set bounded.
func getRcode(rcode int) string {
	switch rcode {
	case 0:
		return "NOERROR"
	case 1:
		return "FORMERR"
	case 2:
		return "SERVFAIL"
	case 3:
		return "NXDOMAIN"
	case 4:
		return "NOTIMP"
	case 5:
		return "REFUSED"
	case 6:
		return "YXDOMAIN"
	case 7:
		return "YXRRSET"
	case 8:
		return "NXRRSET"
	case 9:
		return "NOTAUTH"
	case 10:
		return "NOTZONE"
	default:
		return "SERVFAIL"
	}
}
|
|
|
|
// getDomain reduces a query name to its last two dot-separated labels
// ("www.example.com" -> "example.com"). Names with at most one dot are
// returned unchanged.
func getDomain(qn string) string {
	lastDot := strings.LastIndexByte(qn, '.')
	if lastDot < 0 {
		return qn
	}
	prevDot := strings.LastIndexByte(qn[:lastDot], '.')
	if prevDot < 0 {
		return qn
	}
	return qn[prevDot+1:]
}
|
|
|
|
// getZoneNS returns the nameserver configured for qdomain in zones, the
// loaded contents of the -zones file (one "<domain> <nameserver>" pair per
// line). The first token containing qdomain selects the token after it on
// the same line. "." is returned when nothing matches.
func getZoneNS(zones string, qdomain string) string {
	for _, line := range strings.Split(zones, "\n") {
		// Fields (rather than splitting on a single space) ignores runs of
		// spaces, tabs and a trailing '\r', none of which may leak into the
		// returned address.
		fields := strings.Fields(line)
		// Stop one short of the end: a match in the last field has no
		// nameserver after it, and indexing j+1 used to panic there.
		for j := 0; j+1 < len(fields); j++ {
			if strings.Contains(fields[j], qdomain) {
				return fields[j+1]
			}
		}
	}
	return "."
}
|
|
|
|
// getHost looks qn up in hosts, the loaded contents of the -hosts file in
// /etc/hosts layout ("<address> <name> [<name>...]" per line). It returns the
// address of the first line that lists qn among its names, or "" if none
// does.
//
// Text after '#' is ignored. Only the name columns are compared, so the
// address column can no longer match qn.
func getHost(hosts string, qn string) string {
	for _, line := range strings.Split(hosts, "\n") {
		if i := strings.IndexByte(line, '#'); i >= 0 {
			line = line[:i]
		}
		// Fields tolerates leading whitespace, runs of spaces/tabs and '\r',
		// so fields[0] is always the address.
		fields := strings.Fields(line)
		if len(fields) < 2 {
			continue
		}
		for _, name := range fields[1:] {
			if name == qn {
				return fields[0]
			}
		}
	}
	return ""
}
|
|
|
|
// doDot returns s as a fully-qualified name by appending a trailing "." if
// one is missing. The empty string (the root name) becomes ".". Indexing
// s[len(s)-1] used to panic for it.
func doDot(s string) string {
	if strings.HasSuffix(s, ".") {
		return s
	}
	return s + "."
}
|
|
|
|
func handleCache(q *dns.Msg, qname string, qtype string) (*dns.Msg, int) {
|
|
var qnt = qname + "." + qtype
|
|
if cached, found := qcache.Get(qnt); found {
|
|
cached.(*dns.Msg).MsgHdr.Id = q.MsgHdr.Id
|
|
rcode := cached.(*dns.Msg).MsgHdr.Rcode
|
|
return cached.(*dns.Msg), rcode
|
|
}
|
|
return q, 2
|
|
}
|
|
|
|
func lookup(q *dns.Msg, qname string) (*dns.Msg, string, error) {
|
|
var (
|
|
r *dns.Msg
|
|
rtt time.Duration
|
|
err error
|
|
n []string
|
|
notls bool = false
|
|
qdomain = getDomain(qname)
|
|
)
|
|
|
|
if strings.Contains(*zones, qdomain) {
|
|
n = []string{getZoneNS(*zones, qdomain)}
|
|
notls = true
|
|
} else {
|
|
n = strings.Split(*ns, ":")
|
|
if len(n) > 1 {
|
|
rand.Shuffle(len(n), func(i, j int) {
|
|
n[i], n[j] = n[j], n[i]
|
|
})
|
|
}
|
|
}
|
|
|
|
c := new(dns.Client)
|
|
for i := 0; i < len(n); i++ {
|
|
if *dot && !notls {
|
|
c.Net = "tcp-tls"
|
|
c.TLSConfig = &tls.Config{InsecureSkipVerify: true}
|
|
r, rtt, err = c.Exchange(q, net.JoinHostPort(n[i], "853"))
|
|
} else {
|
|
c.Net = ""
|
|
r, rtt, err = c.Exchange(q, net.JoinHostPort(n[i], "53"))
|
|
}
|
|
if err != nil {
|
|
log.Println("WARN:", err)
|
|
} else {
|
|
nameservers.WithLabelValues(n[i]).Inc()
|
|
break
|
|
}
|
|
}
|
|
|
|
if rtt/time.Millisecond > 100 {
|
|
slowAnswers.Observe(float64(rtt / time.Millisecond))
|
|
}
|
|
|
|
return r, c.Net, err
|
|
}
|
|
|
|
func handleQuery(w dns.ResponseWriter, q *dns.Msg) {
|
|
var (
|
|
r *dns.Msg
|
|
err error
|
|
rcode int
|
|
proto string
|
|
qname = strings.ToLower(q.Question[0].Name[:len(q.Question[0].Name)-1])
|
|
qclass = dns.Class(q.Question[0].Qclass).String()
|
|
qtype = dns.Type(q.Question[0].Qtype).String()
|
|
qnt = qname + "." + qtype
|
|
client = strings.Split(w.RemoteAddr().String(), ":")[0]
|
|
)
|
|
|
|
queries.Inc()
|
|
qtypes.WithLabelValues(qtype).Inc()
|
|
|
|
if qclass == "CH" && qname == "version.bind" {
|
|
qname = "version.bind."
|
|
r = new(dns.Msg)
|
|
rcode = 0
|
|
r.SetReply(q)
|
|
r.SetRcode(q, rcode)
|
|
r.Answer = append(r.Answer, &dns.TXT{
|
|
Hdr: dns.RR_Header{Name: qname, Rrtype: 16, Class: 3, Ttl: 86400},
|
|
Txt: []string{"rdns"},
|
|
})
|
|
log.Println(client, qname, qclass, qtype, getRcode(rcode))
|
|
qcache.SetDefault(qnt, r)
|
|
w.WriteMsg(r)
|
|
responses.WithLabelValues(qtype, getRcode(rcode)).Inc()
|
|
return
|
|
}
|
|
|
|
if r, rcode = handleCache(q, qname, qtype); rcode != 2 {
|
|
log.Println(client, qname, qclass, qtype, getRcode(rcode), "cache")
|
|
responses.WithLabelValues(qtype, getRcode(rcode)).Inc()
|
|
cacheHits.Inc()
|
|
w.WriteMsg(r)
|
|
return
|
|
}
|
|
|
|
if strings.Contains(*hosts, qname) && ( qtype == "A" || qtype == "AAAA" ) {
|
|
r = new(dns.Msg)
|
|
rcode = 0
|
|
r.SetReply(q)
|
|
r.SetRcode(q, rcode)
|
|
r.MsgHdr.RecursionAvailable = true
|
|
if qtype == "A" {
|
|
r.Answer = append(r.Answer, &dns.A{
|
|
Hdr: dns.RR_Header{Name: doDot(qname), Rrtype: 1, Class: 1, Ttl: 300},
|
|
A: net.ParseIP(getHost(*hosts, qname)),
|
|
})
|
|
}
|
|
if qtype == "AAAA" {
|
|
r.Answer = append(r.Answer, &dns.AAAA{
|
|
Hdr: dns.RR_Header{Name: doDot(qname), Rrtype: 28, Class: 1, Ttl: 300},
|
|
AAAA: net.ParseIP(getHost(*hosts, qname)),
|
|
})
|
|
}
|
|
log.Println(client, qname, qclass, qtype, getRcode(rcode))
|
|
qcache.SetDefault(qnt, r)
|
|
w.WriteMsg(r)
|
|
responses.WithLabelValues(qtype, getRcode(rcode)).Inc()
|
|
return
|
|
}
|
|
|
|
if strings.Contains(*blocklist, " "+qname+" ") {
|
|
r = new(dns.Msg)
|
|
rcode = 5
|
|
r.SetReply(q)
|
|
r.SetRcode(q, rcode)
|
|
qcache.SetDefault(qnt, r)
|
|
log.Println(client, qname, qclass, qtype, getRcode(rcode))
|
|
w.WriteMsg(r)
|
|
responses.WithLabelValues(qtype, getRcode(rcode)).Inc()
|
|
return
|
|
} else {
|
|
cacheItems.Inc()
|
|
}
|
|
|
|
r, proto, err = lookup(q, qname)
|
|
if err != nil {
|
|
log.Println("ERROR:", err)
|
|
r = new(dns.Msg)
|
|
rcode = 2
|
|
r.SetReply(q)
|
|
r.SetRcode(r, rcode)
|
|
log.Println(client, qname, qclass, qtype, getRcode(rcode))
|
|
} else {
|
|
rcode = r.MsgHdr.Rcode
|
|
r.MsgHdr.RecursionAvailable = true
|
|
log.Println(client, qname, qclass, qtype, getRcode(rcode), proto)
|
|
}
|
|
qcache.SetDefault(qnt, r)
|
|
w.WriteMsg(r)
|
|
responses.WithLabelValues(qtype, getRcode(rcode)).Inc()
|
|
}
|
|
|
|
func serve(addr string, port string, net string) {
|
|
server := &dns.Server{Addr: addr + ":" + port, Net: net, ReusePort: true}
|
|
if err := server.ListenAndServe(); err != nil {
|
|
fmt.Printf("Failed to setup the "+net+" server: %s\n", err.Error())
|
|
}
|
|
}
|
|
|
|
func getBlocklist(blocklist string) string {
|
|
bl, err := os.ReadFile(blocklist)
|
|
if err != nil {
|
|
log.Println("Error reading blocklist")
|
|
}
|
|
block := strings.Replace(string(bl), "\n", " ", -1)
|
|
blocklist = strings.Replace(string(block), "\t", " ", -1)
|
|
blCount.Set(float64(len(strings.Split(blocklist, " "))))
|
|
blSize.Set(float64(len(blocklist)))
|
|
return blocklist
|
|
}
|
|
|
|
// getZones reads the custom zones file at path zones and returns its
// contents in the form lookup and getZoneNS scan: lowercased, tabs turned
// into spaces, and '\r' removed. A file that cannot be read is logged and
// yields "".
func getZones(zones string) string {
	z, err := os.ReadFile(zones)
	if err != nil {
		log.Println("Error reading zones file:", err)
	}
	// Dropping '\r' keeps CRLF files from leaving "addr\r" as a nameserver.
	// Lowercasing matches the lowercased query names.
	normalize := strings.NewReplacer("\t", " ", "\r", "")
	return strings.ToLower(normalize.Replace(string(z)))
}
|
|
|
|
// getHosts reads the custom hosts file at path hosts and returns its
// contents in the form getHost scans: lowercased, tabs turned into spaces,
// and '\r' removed. A file that cannot be read is logged and yields "".
func getHosts(hosts string) string {
	h, err := os.ReadFile(hosts)
	if err != nil {
		log.Println("Error reading hosts file:", err)
	}
	// Lowercasing lets "NAS.lan" match the lowercased query name. Address
	// parsing is case-insensitive, so IPv6 literals are unaffected.
	normalize := strings.NewReplacer("\t", " ", "\r", "")
	return strings.ToLower(normalize.Replace(string(h)))
}
|
|
|
|
func cacheEviction(key string, value interface{}) {
|
|
log.Println("Cache: Evicted", key)
|
|
cacheItems.Set(float64(qcache.ItemCount()))
|
|
}
|
|
|
|
func main() {
|
|
flag.Parse()
|
|
|
|
qcache = cache.New(time.Duration(*ttl)*time.Minute, 5*time.Minute)
|
|
qcache.OnEvicted(cacheEviction)
|
|
|
|
fmt.Println("Starting Proxy Resolver:", net.JoinHostPort(*addr, *port), "->", *ns, "[UDP/TCP]")
|
|
fmt.Println("Cache TTL:", *ttl, "minutes\nTLS enabled:", *dot)
|
|
|
|
if *cpu != 0 {
|
|
runtime.GOMAXPROCS(*cpu)
|
|
}
|
|
if !*logs {
|
|
log.SetOutput(io.Discard)
|
|
}
|
|
if *blocklist != "" {
|
|
*blocklist = getBlocklist(*blocklist)
|
|
fmt.Println("Blocklist loaded:", len(strings.Split(*blocklist, " ")), "entries")
|
|
}
|
|
if *zones != "" {
|
|
*zones = getZones(*zones)
|
|
fmt.Println("Zones loaded:", len(strings.Split(*zones, " "))-1, "entries")
|
|
}
|
|
if *hosts != "" {
|
|
*hosts = getHosts(*hosts)
|
|
fmt.Println("Hosts loaded:", len(strings.Split(*zones, "\n"))-1, "lines")
|
|
}
|
|
if *metrics {
|
|
fmt.Println("Starting prometheus exporter on port 9153")
|
|
http.Handle("/metrics", promhttp.Handler())
|
|
go http.ListenAndServe(":9153", nil)
|
|
}
|
|
|
|
dns.HandleFunc(".", handleQuery)
|
|
go serve(*addr, *port, "tcp")
|
|
go serve(*addr, *port, "udp")
|
|
|
|
up.Set(float64(1))
|
|
|
|
sig := make(chan os.Signal, 1)
|
|
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
|
|
<-sig
|
|
fmt.Printf("\033[2K\rTime to sleep, goodbye\n")
|
|
}
|