Files
rdns-go/rdns.go
tchivert c33fa0c089
build / build (push) Successful in 2m31s
fix blocklist
2025-07-12 01:05:34 +02:00

395 lines
9.7 KiB
Go

package main
import (
"crypto/tls"
"flag"
"fmt"
"io"
"log"
"math/rand"
"net"
"net/http"
"os"
"os/signal"
"runtime"
"strings"
"syscall"
"time"
"github.com/miekg/dns"
"github.com/patrickmn/go-cache"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// Command-line flags. main replaces *blocklist, *zones and *hosts, which are
// file paths on the command line, with the contents of those files.
var (
	addr = flag.String("addr", "0.0.0.0", "address to use")
	port = flag.String("port", "53", "port to run on")
	// ns is split on ':' by lookup, so IPv6 literals cannot be listed here.
	ns        = flag.String("ns", "9.9.9.9", "nameservers to use, separated by colons")
	blocklist = flag.String("blocklist", "", "blocklist path")
	zones     = flag.String("zones", "", "add a custom zone file")
	hosts     = flag.String("hosts", "", "add a custom hosts file")
	cpu       = flag.Int("cpu", 0, "number of cpu to use (0 keeps the Go default)")
	// ttl is the cache's default expiration, in minutes (see main).
	ttl     = flag.Int("ttl", 10, "cache TTL in minutes")
	dot     = flag.Bool("tls", false, "use tls to contact the resolver")
	logs    = flag.Bool("logs", false, "log queries")
	metrics = flag.Bool("metrics", false, "enable prometheus metrics")
)

// qcache holds replies keyed by "<qname>.<qtype>"; it is created in main.
var qcache *cache.Cache

// Prometheus metrics (see -metrics).
var (
	up = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "rdns_up",
		Help: "Non-null value when the server is ready",
	})
	blSize = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "rdns_blocklist_size",
		Help: "Blocklist size in bytes",
	})
	blCount = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "rdns_blocklist_count",
		Help: "Number of items in the blocklist",
	})
	cacheItems = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "rdns_cache_items",
		Help: "Number of cached queries",
	})
	// slowAnswers only receives upstream round-trip times above 100 ms,
	// observed in milliseconds (see lookup).
	slowAnswers = promauto.NewSummary(prometheus.SummaryOpts{
		Name: "rdns_answers_slow",
		Help: "Latency in milliseconds of upstream answers slower than 100 milliseconds",
	})
	cacheHits = promauto.NewCounter(prometheus.CounterOpts{
		Name: "rdns_cache_hits",
		Help: "Number of responses using cache",
	})
	queries = promauto.NewCounter(prometheus.CounterOpts{
		Name: "rdns_queries_total",
		Help: "Total number of queries",
	})
	qtypes = promauto.NewCounterVec(prometheus.CounterOpts{
		Name: "rdns_queries",
		Help: "Number of queries by type",
	}, []string{"type"})
	responses = promauto.NewCounterVec(prometheus.CounterOpts{
		Name: "rdns_responses",
		Help: "Number of responses by type",
	}, []string{"type", "status"})
	nameservers = promauto.NewCounterVec(prometheus.CounterOpts{
		Name: "rdns_nameservers",
		Help: "Number of responses by nameserver",
	}, []string{"ns"})
)
// getRcode returns the mnemonic of a DNS response code: codes 0-5 are from
// RFC 1035 and codes 6-10 from RFC 2136 (dynamic update). Any other value is
// reported as "SERVFAIL", so the result is always one of these eleven names.
func getRcode(rcode int) string {
	switch rcode {
	case 0:
		return "NOERROR"
	case 1:
		return "FORMERR"
	case 2:
		return "SERVFAIL"
	case 3:
		return "NXDOMAIN"
	case 4:
		return "NOTIMP"
	case 5:
		return "REFUSED"
	case 6:
		return "YXDOMAIN"
	case 7:
		return "YXRRSET"
	case 8:
		return "NXRRSET"
	case 9:
		return "NOTAUTH"
	case 10:
		return "NOTZONE"
	default:
		return "SERVFAIL"
	}
}
// getDomain returns the last two dot-separated labels of qn, e.g.
// "www.example.com" -> "example.com". Names containing at most one dot are
// returned unchanged. lookup uses the result to match names against the
// zones file.
func getDomain(qn string) string {
	last := strings.LastIndexByte(qn, '.')
	if last < 0 {
		return qn
	}
	prev := strings.LastIndexByte(qn[:last], '.')
	if prev < 0 {
		return qn
	}
	return qn[prev+1:]
}
// getZoneNS returns the nameserver configured for qdomain in zones, the zones
// file text. A line's tokens are whitespace-separated. The result is the token
// that follows the first token containing qdomain as a substring. It returns
// "." when no such pair exists or when qdomain is empty, since an empty
// string is contained in every token.
func getZoneNS(zones string, qdomain string) string {
	if qdomain == "" {
		return "."
	}
	for _, line := range strings.Split(zones, "\n") {
		// Fields collapses runs of spaces/tabs and strips a trailing '\r',
		// so the successor token is never an empty string.
		fields := strings.Fields(line)
		// The last field has no successor, so it can never be a key; bounding
		// j here avoids indexing past the end of the line.
		for j := 0; j+1 < len(fields); j++ {
			if strings.Contains(fields[j], qdomain) {
				return fields[j+1]
			}
		}
	}
	return "."
}
// getHost looks qn up in hosts, which is hosts-file text: each line holds an
// address followed by one or more names. It returns the address of the first
// line that lists qn among its names, compared case-insensitively, or "" when
// no line does. Text after '#' is treated as a comment. Only the name
// columns are searched, so an address never matches itself.
func getHost(hosts string, qn string) string {
	for _, line := range strings.Split(hosts, "\n") {
		if i := strings.IndexByte(line, '#'); i >= 0 {
			line = line[:i]
		}
		// Fields tolerates leading blanks, repeated separators and CRLF.
		fields := strings.Fields(line)
		if len(fields) < 2 {
			continue
		}
		for _, name := range fields[1:] {
			if strings.EqualFold(name, qn) {
				return fields[0]
			}
		}
	}
	return ""
}
// doDot returns s with a trailing dot, appending one only when s does not
// already end in '.'. The empty string (the root name) becomes ".".
func doDot(s string) string {
	// HasSuffix is safe on "", unlike indexing s[len(s)-1].
	if strings.HasSuffix(s, ".") {
		return s
	}
	return s + "."
}
// handleCache looks up the cached reply for qname and qtype, using the key
// "<qname>.<qtype>".
//
// On a hit it returns a private copy of the cached message carrying q's ID,
// together with that message's rcode. On a miss it returns q itself and 2
// (SERVFAIL), which callers treat as "not cached". As a consequence, a
// cached SERVFAIL reply is never served.
func handleCache(q *dns.Msg, qname string, qtype string) (*dns.Msg, int) {
	if v, found := qcache.Get(qname + "." + qtype); found {
		if cached, ok := v.(*dns.Msg); ok {
			// Work on a copy: the cached message is shared by every handler
			// goroutine, so setting its ID in place would race and could
			// send one client another client's ID.
			r := cached.Copy()
			r.Id = q.Id
			return r, r.Rcode
		}
	}
	return q, 2
}
// lookup forwards q to an upstream resolver.
//
// Routing works as follows:
//   - If the domain of qname (see getDomain) appears in the zones file and
//     getZoneNS finds a nameserver for it, the query goes to that nameserver
//     over plain DNS.
//   - Otherwise the -ns servers (colon-separated) are tried in random order
//     until one answers. With -tls they are contacted over DNS-over-TLS on
//     port 853; without it, over plain DNS on port 53.
//
// A plain UDP answer with the TC bit set is re-fetched from the same server
// over TCP.
//
// It returns the reply, the transport that produced it ("" for UDP, "tcp"
// or "tcp-tls"), and the error of the last attempt. The reply is nil when
// every server failed.
func lookup(q *dns.Msg, qname string) (*dns.Msg, string, error) {
	var servers []string
	zoneLocal := false
	// A root query yields an empty domain, which is a substring of any zones
	// text, so it must never be routed by the zones file.
	if qdomain := getDomain(qname); qdomain != "" && strings.Contains(*zones, qdomain) {
		// "." means getZoneNS found no nameserver; use the -ns servers then
		// instead of dialling ".:53".
		if zns := getZoneNS(*zones, qdomain); zns != "." {
			servers = []string{zns}
			zoneLocal = true
		}
	}
	if !zoneLocal {
		servers = strings.Split(*ns, ":")
		rand.Shuffle(len(servers), func(i, j int) {
			servers[i], servers[j] = servers[j], servers[i]
		})
	}

	c := new(dns.Client)
	upstreamPort := "53"
	if *dot && !zoneLocal {
		c.Net = "tcp-tls"
		// NOTE(review): certificate verification is disabled, so DoT here
		// hides queries from passive observers but does not authenticate the
		// resolver against an active man-in-the-middle.
		c.TLSConfig = &tls.Config{InsecureSkipVerify: true}
		upstreamPort = "853"
	}

	var (
		r     *dns.Msg
		rtt   time.Duration
		err   error
		proto string
	)
	for _, server := range servers {
		target := net.JoinHostPort(server, upstreamPort)
		proto = c.Net
		r, rtt, err = c.Exchange(q, target)
		if err == nil && r.Truncated && c.Net == "" {
			// The UDP answer did not fit; fetch the complete one over TCP so
			// that clients and the cache get a whole response.
			proto = "tcp"
			r, rtt, err = (&dns.Client{Net: "tcp"}).Exchange(q, target)
		}
		if err != nil {
			log.Println("WARN:", err)
			continue
		}
		nameservers.WithLabelValues(server).Inc()
		break
	}
	if ms := rtt / time.Millisecond; ms > 100 {
		slowAnswers.Observe(float64(ms))
	}
	return r, proto, err
}
// handleQuery answers a single DNS request. main registers it for ".".
//
// Requests are resolved in this order, stopping at the first step that
// applies:
//  1. a CHAOS-class "version.bind" query gets a TXT record "rdns" (not cached);
//  2. a cached reply for the same name and type (see handleCache) is served;
//  3. an A/AAAA query for a name in the hosts file is answered (see getHost);
//  4. a name listed in the blocklist gets REFUSED;
//  5. otherwise the query is forwarded upstream (see lookup), and the client
//     gets SERVFAIL if every nameserver fails.
//
// Every reply is logged and sent through writeReply, which counts it in
// rdns_responses and trims UDP replies to the client's buffer size.
func handleQuery(w dns.ResponseWriter, q *dns.Msg) {
	if len(q.Question) == 0 {
		// Guard the Question[0] accesses below against question-less messages.
		r := new(dns.Msg)
		r.SetRcode(q, dns.RcodeFormatError)
		_ = w.WriteMsg(r) // best effort: there is nothing else to report
		return
	}
	question := q.Question[0]
	qname := strings.ToLower(strings.TrimSuffix(question.Name, "."))
	qclass := dns.Class(question.Qclass).String()
	qtype := dns.Type(question.Qtype).String()
	qnt := qname + "." + qtype // same key format as handleCache
	// RemoteAddr is "host:port". SplitHostPort also handles bracketed IPv6
	// addresses such as "[::1]:53", which splitting on ":" would cut to "[".
	client := w.RemoteAddr().String()
	if host, _, err := net.SplitHostPort(client); err == nil {
		client = host
	}

	queries.Inc()
	qtypes.WithLabelValues(qtype).Inc()

	if qclass == "CH" && qname == "version.bind" {
		r := new(dns.Msg)
		r.SetRcode(q, dns.RcodeSuccess)
		r.Answer = append(r.Answer, &dns.TXT{
			Hdr: dns.RR_Header{Name: "version.bind.", Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 86400},
			Txt: []string{"rdns"},
		})
		log.Println(client, "version.bind.", qclass, qtype, getRcode(r.Rcode))
		// Not cached: the cache key carries no class, so the entry would only
		// ever be served to IN-class queries for "version.bind".
		writeReply(w, q, r, qtype)
		return
	}

	if r, rcode := handleCache(q, qname, qtype); rcode != 2 {
		log.Println(client, qname, qclass, qtype, getRcode(rcode), "cache")
		cacheHits.Inc()
		writeReply(w, q, r, qtype)
		return
	}

	if qtype == "A" || qtype == "AAAA" {
		// getHost matches whole names only. Names it cannot map to a valid
		// address fall through to the blocklist and the upstream lookup.
		if ip := net.ParseIP(getHost(*hosts, qname)); ip != nil {
			r := hostsReply(q, qname, qtype, ip)
			log.Println(client, qname, qclass, qtype, getRcode(r.Rcode))
			cacheReply(qnt, r)
			writeReply(w, q, r, qtype)
			return
		}
	}

	// Blocklist entries are separated by single spaces (see getBlocklist).
	// The root name "" is excluded because "  " can occur in the raw list.
	if qname != "" && strings.Contains(*blocklist, " "+qname+" ") {
		r := new(dns.Msg)
		r.SetRcode(q, dns.RcodeRefused)
		log.Println(client, qname, qclass, qtype, getRcode(r.Rcode))
		cacheReply(qnt, r)
		writeReply(w, q, r, qtype)
		return
	}

	r, proto, err := lookup(q, qname)
	if err != nil {
		log.Println("ERROR:", err)
		r = new(dns.Msg)
		r.SetRcode(q, dns.RcodeServerFailure)
		log.Println(client, qname, qclass, qtype, getRcode(r.Rcode))
		// Not cached: handleCache reports a SERVFAIL entry as a miss anyway.
		writeReply(w, q, r, qtype)
		return
	}
	r.RecursionAvailable = true
	log.Println(client, qname, qclass, qtype, getRcode(r.Rcode), proto)
	cacheReply(qnt, r)
	writeReply(w, q, r, qtype)
}

// hostsReply builds the NOERROR answer for a name found in the hosts file.
// An A query gets an A record only for an IPv4 address, and an AAAA query
// gets an AAAA record only for an IPv6 address. On a family mismatch the
// answer section stays empty (NODATA) instead of carrying a wrong address.
func hostsReply(q *dns.Msg, qname, qtype string, ip net.IP) *dns.Msg {
	r := new(dns.Msg)
	r.SetRcode(q, dns.RcodeSuccess)
	r.RecursionAvailable = true
	// qname had its trailing dot removed by handleQuery; add it back.
	hdr := dns.RR_Header{Name: qname + ".", Class: dns.ClassINET, Ttl: 300}
	v4 := ip.To4()
	switch {
	case qtype == "A" && v4 != nil:
		hdr.Rrtype = dns.TypeA
		r.Answer = append(r.Answer, &dns.A{Hdr: hdr, A: v4})
	case qtype == "AAAA" && v4 == nil:
		hdr.Rrtype = dns.TypeAAAA
		r.Answer = append(r.Answer, &dns.AAAA{Hdr: hdr, AAAA: ip})
	}
	return r
}

// cacheReply stores a private copy of r in qcache under key, with the
// default expiration, and refreshes the rdns_cache_items gauge.
// Truncated replies are skipped: they are incomplete, and a client retrying
// over TCP must not be handed the same truncated answer from the cache.
func cacheReply(key string, r *dns.Msg) {
	if r.Truncated {
		return
	}
	// The copy lets the caller keep using and sending r without sharing it
	// with concurrent cache readers.
	qcache.SetDefault(key, r.Copy())
	cacheItems.Set(float64(qcache.ItemCount()))
}

// writeReply sends r to the client and counts it in rdns_responses.
// Over UDP, a reply larger than the client's advertised EDNS0 buffer size
// (512 bytes without EDNS0) is truncated, with the TC bit set, so the client
// can retry over TCP. TCP replies are sent whole. Truncation works on a copy,
// so r itself is never modified.
func writeReply(w dns.ResponseWriter, q, r *dns.Msg, qtype string) {
	out := r
	if w.LocalAddr().Network() == "udp" {
		size := dns.MinMsgSize
		if opt := q.IsEdns0(); opt != nil {
			size = int(opt.UDPSize())
		}
		if r.Len() > size {
			out = r.Copy()
			out.Truncate(size)
		}
	}
	if err := w.WriteMsg(out); err != nil {
		log.Println("WARN:", err)
	}
	responses.WithLabelValues(qtype, getRcode(r.Rcode)).Inc()
}
// serve runs a DNS server for network ("tcp" or "udp") on addr:port with
// ReusePort enabled. It blocks until the server stops and reports any
// listen or serve error on stderr.
func serve(addr string, port string, network string) {
	// JoinHostPort brackets IPv6 literals (e.g. "::" -> "[::]:53"), which
	// plain concatenation does not. The parameter is not called "net", so
	// it no longer shadows the net package.
	server := &dns.Server{Addr: net.JoinHostPort(addr, port), Net: network, ReusePort: true}
	if err := server.ListenAndServe(); err != nil {
		fmt.Fprintf(os.Stderr, "Failed to setup the %s server: %s\n", network, err)
	}
}
func getBlocklist(blocklist string) string {
bl, err := os.ReadFile(blocklist)
if err != nil {
log.Println("Error reading blocklist")
}
block := strings.Replace(string(bl), "\n", " ", -1)
blocklist = strings.Replace(string(block), "\t", " ", -1)
blCount.Set(float64(len(strings.Split(blocklist, " "))))
blSize.Set(float64(len(blocklist)))
return blocklist
}
// getZones reads the zones file at path zones and returns its contents with
// tabs replaced by spaces, the form getZoneNS parses. If the file cannot be
// read, the error is reported on stderr and "" is returned.
func getZones(zones string) string {
	z, err := os.ReadFile(zones)
	if err != nil {
		// stderr, not log: main discards log output unless -logs is set.
		fmt.Fprintln(os.Stderr, "Error reading zones file:", err)
		return ""
	}
	return strings.ReplaceAll(string(z), "\t", " ")
}
// getHosts reads the hosts file at path hosts and returns its contents with
// tabs replaced by spaces, for use by getHost. If the file cannot be read,
// the error is reported on stderr and "" is returned.
func getHosts(hosts string) string {
	h, err := os.ReadFile(hosts)
	if err != nil {
		// stderr, not log: main discards log output unless -logs is set.
		fmt.Fprintln(os.Stderr, "Error reading hosts file:", err)
		return ""
	}
	return strings.ReplaceAll(string(h), "\t", " ")
}
func cacheEviction(key string, value interface{}) {
log.Println("Cache: Evicted", key)
cacheItems.Set(float64(qcache.ItemCount()))
}
// main parses the flags, creates the reply cache and loads the optional
// blocklist, zones and hosts files. It optionally starts the Prometheus
// exporter on :9153, then serves DNS over TCP and UDP on -addr/-port until
// SIGINT or SIGTERM arrives.
func main() {
	flag.Parse()
	// -ttl is the cache's default expiration, in minutes.
	qcache = cache.New(time.Duration(*ttl)*time.Minute, 5*time.Minute)
	qcache.OnEvicted(cacheEviction)
	fmt.Println("Starting Proxy Resolver:", net.JoinHostPort(*addr, *port), "->", *ns, "[UDP/TCP]")
	fmt.Println("Cache TTL:", *ttl, "minutes\nTLS enabled:", *dot)
	if *cpu != 0 {
		runtime.GOMAXPROCS(*cpu)
	}
	// From here on, log output is dropped unless -logs is set.
	if !*logs {
		log.SetOutput(io.Discard)
	}
	if *blocklist != "" {
		*blocklist = getBlocklist(*blocklist)
		// Fields counts only non-empty tokens, however the entries are
		// separated.
		fmt.Println("Blocklist loaded:", len(strings.Fields(*blocklist)), "entries")
	}
	if *zones != "" {
		*zones = getZones(*zones)
		fmt.Println("Zones loaded:", len(strings.Split(*zones, " "))-1, "entries")
	}
	if *hosts != "" {
		*hosts = getHosts(*hosts)
		fmt.Println("Hosts loaded:", len(strings.Split(*hosts, "\n"))-1, "lines")
	}
	if *metrics {
		fmt.Println("Starting prometheus exporter on port 9153")
		http.Handle("/metrics", promhttp.Handler())
		go func() {
			// ListenAndServe only returns with an error (e.g. port in use);
			// report it instead of silently running without metrics.
			if err := http.ListenAndServe(":9153", nil); err != nil {
				fmt.Fprintln(os.Stderr, "Prometheus exporter failed:", err)
			}
		}()
	}
	dns.HandleFunc(".", handleQuery)
	go serve(*addr, *port, "tcp")
	go serve(*addr, *port, "udp")
	up.Set(1)
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
	<-sig
	fmt.Printf("\033[2K\rTime to sleep, goodbye\n")
}