Files
tchivert e88ddacdca
build / build (push) Successful in 6m11s
rtype
2024-08-09 21:05:49 +02:00

343 lines
7.4 KiB
Go

package main
import (
	"flag"
	"fmt"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"os"
	"os/signal"
	"runtime"
	"strconv"
	"strings"
	"syscall"

	"github.com/miekg/dns"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"gopkg.in/yaml.v2"
)
// Package-level configuration and state: command-line flags, the zone data
// loaded by main, and the Prometheus metrics (registered with the default
// registry by promauto, which promhttp.Handler serves when -metrics is set).
var (
// addr and port form the listen address passed to serve for TCP and UDP.
addr = flag.String("addr", "127.0.0.1", "address to use")
port = flag.String("port", "53", "port to run on")
// config is the path of the YAML zone file read by getZones; main skips
// loading it when the flag is empty.
config = flag.String("config", "./config.yml", "add a custom zone file")
// cpu, when non-zero, is passed to runtime.GOMAXPROCS.
cpu = flag.Int("cpu", 0, "number of cpu to use")
// logs enables log output; when false main discards everything written
// through the log package (query logs included).
logs = flag.Bool("logs", false, "log queries")
// metrics starts the Prometheus exporter on :9153.
metrics = flag.Bool("metrics", false, "enable prometheus metrics")
// zones holds the decoded config. main assigns it once, before the DNS
// servers are started; afterwards it is only read (getData, getTTL).
zones Zones
// up is set to 1 by main right after the server goroutines are launched.
// NOTE(review): this happens before the listeners are known to be bound.
up = promauto.NewGauge(prometheus.GaugeOpts{
Name: "rdns_up",
Help: "Non-null value when the server is ready"})
// queries counts every query handled by handleQuery.
queries = promauto.NewCounter(prometheus.CounterOpts{
Name: "rdns_queries_total",
Help: "Total number of queries"})
// qtypes counts queries by the question's type mnemonic (e.g. "A").
// NOTE(review): counter name lacks the conventional "_total" suffix; not
// renamed here to avoid breaking existing dashboards.
qtypes = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "rdns_queries",
Help: "Number of queries by type"}, []string{"type"})
// responses counts replies by the type actually answered (which may differ
// from the question, e.g. "CNAME" or "SOA") and by rcode name (getRcode).
responses = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "rdns_responses",
Help: "Number of responses by type"}, []string{"type", "status"})
)
// Domain is one entry of the "domains" list in the YAML config.
// Name is the zone name that query names are matched against after being
// reduced to their last two labels (getDomain). TTL is returned by getTTL
// for names in this zone; a missing "ttl" key decodes as 0. Records holds the
// zone's resource records.
type Domain struct {
Name string `yaml:"domain"`
TTL int `yaml:"ttl"`
Records []Record `yaml:"records"`
}
// Record is one record set of a Domain.
// Name is the owner name. Query names reach the lookup with their trailing
// dot removed (handleQuery), so presumably Name should be written without one.
// A Name starting with '*' is treated as a wildcard by getData.
// Type is the type mnemonic compared with the query type string
// (e.g. "A", "AAAA", "CNAME", "NS", "SOA").
// Data holds the record values. For SOA, replySOA reads a single string of
// seven space-separated fields: primary NS, mailbox, serial, refresh, retry,
// expire and minimum TTL.
type Record struct {
Name string `yaml:"name"`
Type string `yaml:"type"`
Data []string `yaml:"data"`
}
// Zones is the top-level document of the YAML config file, as decoded by
// getZones: a "domains" list of Domain entries.
type Zones struct {
Domains []Domain `yaml:"domains"`
}
// serve runs a DNS server for the given network ("tcp" or "udp") on addr:port.
// The server has no explicit Handler, so miekg/dns dispatches to
// dns.DefaultServeMux, where main registers handleQuery. serve blocks until the
// server stops; a failure to start is printed to stdout and serve returns.
//
// net.JoinHostPort is used so IPv6 literals such as "::1" are bracketed
// ("[::1]:53"). The network parameter is not named net so that it does not
// shadow the net package.
func serve(addr string, port string, network string) {
	server := &dns.Server{Addr: net.JoinHostPort(addr, port), Net: network, ReusePort: true}
	if err := server.ListenAndServe(); err != nil {
		fmt.Printf("Failed to setup the %s server: %s\n", network, err.Error())
	}
}
// atoi parses s as a base-10 integer. It is used for the numeric SOA fields
// read from the zone config. Input that is not a valid integer (empty,
// containing non-digit characters, or out of range for int) yields 0, so a
// malformed config value produces a predictable zero rather than an
// arbitrary number.
func atoi(s string) int {
	n, err := strconv.Atoi(s)
	if err != nil {
		return 0
	}
	return n
}
// getRcode returns the mnemonic of a DNS response code, as used in the query
// log and in the "status" label of the rdns_responses metric. Codes 0-5 come
// from RFC 1035 and 6-10 from RFC 2136 (YXDOMAIN, YXRRSET, NXRRSET, NOTAUTH,
// NOTZONE). Any other value is reported as "SERVFAIL".
func getRcode(rcode int) string {
	switch rcode {
	case 0:
		return "NOERROR"
	case 1:
		return "FORMERR"
	case 2:
		return "SERVFAIL"
	case 3:
		return "NXDOMAIN"
	case 4:
		return "NOTIMP"
	case 5:
		return "REFUSED"
	case 6:
		return "YXDOMAIN"
	case 7:
		return "YXRRSET"
	case 8:
		return "NXRRSET"
	case 9:
		return "NOTAUTH"
	case 10:
		return "NOTZONE"
	}
	return "SERVFAIL"
}
// getDomain reduces a name to the zone key used for lookups: its last two
// dot-separated labels ("www.example.com" -> "example.com"). A name with at
// most one dot is returned unchanged. Multi-label public suffixes such as
// "co.uk" are not special-cased.
func getDomain(qn string) string {
	last := strings.LastIndexByte(qn, '.')
	if last < 0 {
		return qn
	}
	prev := strings.LastIndexByte(qn[:last], '.')
	if prev < 0 {
		// Exactly one dot: the name already is a two-label domain.
		return qn
	}
	return qn[prev+1:]
}
// strchr returns how many times the rune c occurs in s. Despite the C-style
// name it does not return a position; getData uses it to compare the number
// of dots (label depth) of two names.
func strchr(s string, c rune) int {
	count := 0
	for _, r := range s {
		if r != c {
			continue
		}
		count++
	}
	return count
}
// getData returns the configured values of the qt record for name qn, or the
// sentinel []string{""} when there is none (callers test data[0] == "").
//
// The zone is selected by qn's last two labels (getDomain). Names are
// compared case-insensitively because DNS names are case-insensitive
// (RFC 4343). An exact owner-name match takes precedence over a wildcard; a
// wildcard record ("*.example.com") is used only when no exact record of that
// type exists, and it covers names one label deep under its suffix (see
// matchesWildcard). Records without any data are ignored so callers can
// always index data[0].
func getData(qn, qt string) []string {
	domain := getDomain(qn)
	var wildcard []string
	for _, zone := range zones.Domains {
		if !strings.EqualFold(zone.Name, domain) {
			continue
		}
		for _, record := range zone.Records {
			if record.Type != qt || len(record.Data) == 0 {
				continue
			}
			if strings.EqualFold(record.Name, qn) {
				return record.Data
			}
			// Remember only the first matching wildcard, in config order.
			if wildcard == nil && matchesWildcard(record.Name, qn) {
				wildcard = record.Data
			}
		}
	}
	if wildcard != nil {
		return wildcard
	}
	return []string{""}
}

// matchesWildcard reports whether the wildcard owner name pattern (e.g.
// "*.example.com") covers qn: pattern must start with '*', qn must end with
// the rest of the pattern (compared case-insensitively), and both names must
// have the same number of dots, so the '*' stands for exactly one label.
func matchesWildcard(pattern, qn string) bool {
	if !strings.HasPrefix(pattern, "*") {
		return false
	}
	suffix := pattern[1:]
	if len(qn) <= len(suffix) || !strings.EqualFold(qn[len(qn)-len(suffix):], suffix) {
		return false
	}
	return strchr(pattern, '.') == strchr(qn, '.')
}
// getTTL returns the TTL configured for the zone that owns qn. The zone is
// chosen by qn's last two labels (getDomain) and matched case-insensitively,
// because DNS names are case-insensitive (RFC 4343). When no zone matches,
// 86400 seconds is returned. qt is not used.
// NOTE(review): a zone without a "ttl" key yields 0, not the 86400 fallback.
func getTTL(qn, qt string) int {
	domain := getDomain(qn)
	for i := range zones.Domains {
		if strings.EqualFold(zones.Domains[i].Name, domain) {
			return zones.Domains[i].TTL
		}
	}
	return 86400
}
// doDot returns s as a fully qualified name, appending a trailing dot unless
// s already ends with one. The empty string becomes "." (the root) rather than
// panicking on s[len(s)-1].
func doDot(s string) string {
	if strings.HasSuffix(s, ".") {
		return s
	}
	return s + "."
}
// replySOA builds the SOA record of the zone containing qname. The zone is
// qname's last two labels (getDomain). Its SOA values come from the zone's
// "SOA" record, whose first data string holds seven whitespace-separated
// fields: primary NS, responsible mailbox, serial, refresh, retry, expire and
// minimum TTL. The record's own TTL is set to the minimum-TTL field.
//
// The ttl and data parameters are not used.
//
// When the zone has no SOA record, or its value has fewer than seven fields,
// replySOA returns nil instead of indexing past the end of the field list. A
// query for a domain that is not configured reaches this path.
func replySOA(qname string, ttl int, data []string) []dns.RR {
	domain := getDomain(qname)
	var fields []string
	if soa := getData(domain, "SOA"); len(soa) > 0 {
		// Fields (not Split on " ") tolerates repeated spaces and tabs.
		fields = strings.Fields(soa[0])
	}
	if len(fields) < 7 {
		return nil
	}
	minTTL := uint32(atoi(fields[6]))
	return []dns.RR{&dns.SOA{
		Hdr:     dns.RR_Header{Name: doDot(domain), Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: minTTL},
		Ns:      doDot(fields[0]),
		Mbox:    doDot(fields[1]),
		Serial:  uint32(atoi(fields[2])),
		Refresh: uint32(atoi(fields[3])),
		Retry:   uint32(atoi(fields[4])),
		Expire:  uint32(atoi(fields[5])),
		Minttl:  minTTL,
	}}
}
// replyNS turns each name server in data into an NS record owned by qname,
// carrying ttl. Both the owner and the target are made fully qualified with
// doDot. It returns nil when data is empty.
func replyNS(qname string, ttl int, data []string) []dns.RR {
	var records []dns.RR
	for i := range data {
		hdr := dns.RR_Header{Name: doDot(qname), Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: uint32(ttl)}
		records = append(records, &dns.NS{Hdr: hdr, Ns: doDot(data[i])})
	}
	return records
}
// replyCNAME turns each alias target in data into a CNAME record owned by
// qname, carrying ttl. Both the owner and the target are made fully qualified
// with doDot. It returns nil when data is empty.
func replyCNAME(qname string, ttl int, data []string) []dns.RR {
	var records []dns.RR
	for i := range data {
		hdr := dns.RR_Header{Name: doDot(qname), Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: uint32(ttl)}
		records = append(records, &dns.CNAME{Hdr: hdr, Target: doDot(data[i])})
	}
	return records
}
// replyA builds one A record per IPv4 address in data, owned by qname and
// carrying ttl. Entries that are not IPv4 addresses are skipped, because an A
// record needs a 4-byte address. This covers unparseable strings, IPv6
// literals, and getData's "" not-found sentinel, which handleQuery can pass
// when a CNAME target has no A record.
func replyA(qname string, ttl int, data []string) []dns.RR {
	var a []dns.RR
	for _, s := range data {
		ip := net.ParseIP(s).To4() // nil for invalid input and for IPv6
		if ip == nil {
			continue
		}
		a = append(a, &dns.A{
			Hdr: dns.RR_Header{Name: doDot(qname), Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: uint32(ttl)},
			A:   ip,
		})
	}
	return a
}
// replyAAAA builds one AAAA record per address in data, owned by qname and
// carrying ttl. Entries that net.ParseIP cannot parse (including getData's ""
// not-found sentinel) are skipped, so no record is produced with a nil
// address.
func replyAAAA(qname string, ttl int, data []string) []dns.RR {
	var a []dns.RR
	for _, s := range data {
		ip := net.ParseIP(s)
		if ip == nil {
			continue
		}
		a = append(a, &dns.AAAA{
			Hdr:  dns.RR_Header{Name: doDot(qname), Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: uint32(ttl)},
			AAAA: ip,
		})
	}
	return a
}
// handleQuery is registered by main for "." and so receives every query the
// server accepts. It answers from the zones loaded from the YAML config:
//
//   - a CHAOS-class "version.bind" query gets a TXT "rdns" answer;
//   - an A query with no A record falls back to the name's CNAME, and the
//     CNAME target's A records are appended when the zones contain them;
//   - A, AAAA, CNAME, NS and SOA queries are answered from the zone data;
//   - a query with no matching data gets a NOERROR reply with the zone SOA in
//     the authority section (a NODATA response, RFC 2308).
//
// Every answered query is counted in the Prometheus metrics and logged (log
// output is discarded unless -logs is set).
func handleQuery(w dns.ResponseWriter, q *dns.Msg) {
	r := new(dns.Msg)
	r.SetReply(q)

	// miekg/dns normally rejects messages without exactly one question before
	// calling the handler. Guard anyway rather than panic on q.Question[0].
	if len(q.Question) == 0 {
		r.SetRcode(q, dns.RcodeFormatError)
		if err := w.WriteMsg(r); err != nil {
			log.Println("write failed:", err)
		}
		return
	}

	question := q.Question[0]
	qname := strings.TrimSuffix(question.Name, ".")
	qclass := dns.Class(question.Qclass).String()
	qtype := dns.Type(question.Qtype).String()
	queries.Inc()
	qtypes.WithLabelValues(qtype).Inc()

	if qclass == "CH" && qname == "version.bind" {
		r.Answer = append(r.Answer, &dns.TXT{
			Hdr: dns.RR_Header{Name: "version.bind.", Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 86400},
			Txt: []string{"rdns"},
		})
		sendReply(w, q, r, qname, qclass, qtype)
		return
	}

	data := getData(qname, qtype)
	ttl := getTTL(qname, qtype)
	chased := false
	// No A record: answer with the name's CNAME instead and follow it below.
	if hasNoData(data) && qtype == "A" {
		qtype = "CNAME"
		data = getData(qname, qtype)
		chased = true
	}

	r.Authoritative = true
	switch {
	case hasNoData(data):
		// Negative answers carry the SOA in the authority section so
		// resolvers can cache them (RFC 2308). This also covers SOA queries
		// for names below the zone apex.
		qtype = "SOA"
		r.Ns = replySOA(qname, ttl, data)
	case qtype == "A":
		r.Answer = replyA(qname, ttl, data)
	case qtype == "AAAA":
		r.Answer = replyAAAA(qname, ttl, data)
	case qtype == "CNAME":
		r.Answer = replyCNAME(qname, ttl, data)
		if chased {
			// Zone names are stored without the trailing dot.
			target := strings.TrimSuffix(data[0], ".")
			if addrs := getData(target, "A"); !hasNoData(addrs) {
				r.Answer = append(r.Answer, replyA(target, getTTL(target, "A"), addrs)...)
			}
		}
	case qtype == "NS":
		r.Answer = replyNS(qname, ttl, data)
	case qtype == "SOA":
		// The SOA was asked for and exists: it is the answer itself.
		r.Answer = replySOA(qname, ttl, data)
	}
	sendReply(w, q, r, qname, qclass, qtype)
}

// sendReply fits r into the client's UDP payload size, logs the exchange,
// writes r to w, and counts it in the responses metric under qtype and r's
// response code.
func sendReply(w dns.ResponseWriter, q, r *dns.Msg, qname, qclass, qtype string) {
	// Only UDP replies are size-limited: 512 bytes, or the EDNS0 buffer size
	// the client advertised. Truncate drops records that do not fit and sets
	// TC only when it had to. TCP replies are never truncated.
	if _, udp := w.RemoteAddr().(*net.UDPAddr); udp {
		size := dns.MinMsgSize
		if opt := q.IsEdns0(); opt != nil {
			size = int(opt.UDPSize())
		}
		r.Truncate(size)
	}
	rcode := getRcode(r.Rcode)
	log.Println(clientHost(w), qname, qclass, qtype, rcode)
	if err := w.WriteMsg(r); err != nil {
		log.Println("write failed:", err)
	}
	responses.WithLabelValues(qtype, rcode).Inc()
}

// clientHost returns the client's address without the port. SplitHostPort is
// used so IPv6 addresses such as "[::1]:5353" yield "::1".
func clientHost(w dns.ResponseWriter) string {
	addr := w.RemoteAddr().String()
	if host, _, err := net.SplitHostPort(addr); err == nil {
		return host
	}
	return addr
}

// hasNoData reports whether data is getData's "not found" result
// ([]string{""}) or holds no values at all.
func hasNoData(data []string) bool {
	return len(data) == 0 || data[0] == ""
}
// getZones reads the YAML zone configuration at path config and returns the
// decoded zones. A read or decode error terminates the process with exit
// status 1.
//
// The error is written to stderr directly instead of through log.Fatalf.
// main discards log output unless -logs is set, and it does so before loading
// the config, so log.Fatalf would exit without saying why.
func getZones(config string) Zones {
	data, err := ioutil.ReadFile(config)
	if err != nil {
		fmt.Fprintf(os.Stderr, "error: reading zone config %q: %v\n", config, err)
		os.Exit(1)
	}
	var tmp Zones
	if err := yaml.Unmarshal(data, &tmp); err != nil {
		fmt.Fprintf(os.Stderr, "error: parsing zone config %q: %v\n", config, err)
		os.Exit(1)
	}
	return tmp
}
// main parses the flags and loads the zone config. It optionally starts the
// Prometheus exporter on :9153, then serves DNS over TCP and UDP on -addr and
// -port until SIGINT or SIGTERM arrives.
func main() {
	flag.Parse()
	fmt.Println("Starting Proxy Resolver:", net.JoinHostPort(*addr, *port))
	if *cpu != 0 {
		runtime.GOMAXPROCS(*cpu)
	}
	// Logging is opt-in: without -logs every log.* call is discarded.
	if !*logs {
		log.SetOutput(ioutil.Discard)
	}
	if *config != "" {
		zones = getZones(*config)
		fmt.Println("Zones:")
		for _, domain := range zones.Domains {
			fmt.Println(" ", domain.Name)
		}
	} else {
		// Printed directly: the log output may already be discarded.
		fmt.Println("Zone missing")
	}
	if *metrics {
		fmt.Println("Starting prometheus exporter on port 9153")
		http.Handle("/metrics", promhttp.Handler())
		go func() {
			if err := http.ListenAndServe(":9153", nil); err != nil {
				fmt.Printf("Failed to setup the prometheus exporter: %s\n", err.Error())
			}
		}()
	}
	dns.HandleFunc(".", handleQuery)
	go serve(*addr, *port, "tcp")
	go serve(*addr, *port, "udp")
	up.Set(float64(1))
	// signal.Notify never blocks when delivering, so the channel must be
	// buffered or a signal that arrives before the receive is lost.
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
	<-sig
	fmt.Printf("\033[2K\rTime to sleep, goodbye\n")
}