You've already forked rdns-authoritative
343 lines
7.4 KiB
Go
343 lines
7.4 KiB
Go
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"log"
|
|
"net/http"
|
|
"flag"
|
|
"os"
|
|
"os/signal"
|
|
"strings"
|
|
"syscall"
|
|
"runtime"
|
|
"io/ioutil"
|
|
|
|
"github.com/miekg/dns"
|
|
"gopkg.in/yaml.v2"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
|
)
|
|
|
|
var (
|
|
addr = flag.String("addr", "127.0.0.1", "address to use")
|
|
port = flag.String("port", "53", "port to run on")
|
|
config = flag.String("config", "./config.yml", "add a custom zone file")
|
|
cpu = flag.Int("cpu", 0, "number of cpu to use")
|
|
logs = flag.Bool("logs", false, "log queries")
|
|
metrics = flag.Bool("metrics", false, "enable prometheus metrics")
|
|
zones Zones
|
|
|
|
up = promauto.NewGauge(prometheus.GaugeOpts{
|
|
Name: "rdns_up",
|
|
Help: "Non-null value when the server is ready"})
|
|
queries = promauto.NewCounter(prometheus.CounterOpts{
|
|
Name: "rdns_queries_total",
|
|
Help: "Total number of queries"})
|
|
qtypes = promauto.NewCounterVec(prometheus.CounterOpts{
|
|
Name: "rdns_queries",
|
|
Help: "Number of queries by type"}, []string{"type"})
|
|
responses = promauto.NewCounterVec(prometheus.CounterOpts{
|
|
Name: "rdns_responses",
|
|
Help: "Number of responses by type"}, []string{"type", "status"})
|
|
)
|
|
|
|
// Zone configuration types, decoded from the YAML file by getZones.
type (
	// Domain is one zone. Its Name is compared against the last two labels
	// of a query name (see getDomain). TTL is the TTL put on answers from
	// the zone, and Records are the zone's resource records.
	Domain struct {
		Name    string   `yaml:"domain"`
		TTL     int      `yaml:"ttl"`
		Records []Record `yaml:"records"`
	}

	// Record is a single entry of a zone.
	//   - Name is the owner name; a leading '*' marks a wildcard.
	//   - Type is compared with the query's type mnemonic (e.g. "A", "SOA").
	//   - Data holds the record values. For SOA it is one space-separated
	//     string that replySOA splits into its fields.
	Record struct {
		Name string   `yaml:"name"`
		Type string   `yaml:"type"`
		Data []string `yaml:"data"`
	}

	// Zones is the top-level document of the configuration file.
	Zones struct {
		Domains []Domain `yaml:"domains"`
	}
)
|
|
|
|
func serve(addr string, port string, net string) {
|
|
server := &dns.Server{Addr: addr + ":" + port, Net: net, ReusePort: true}
|
|
if err := server.ListenAndServe(); err != nil {
|
|
fmt.Printf("Failed to setup the "+net+" server: %s\n", err.Error())
|
|
}
|
|
}
|
|
|
|
// atoi parses the leading run of ASCII decimal digits in s and returns its
// value. Parsing stops at the first byte that is not '0'..'9', so "" or a
// string that does not start with a digit (e.g. "-5", "x") yields 0.
//
// It is used for the numeric SOA fields in the zone file. The previous
// version folded every character into the result, so malformed input gave
// arbitrary numbers (e.g. "-5" became -25).
func atoi(s string) int {
	n := 0
	for i := 0; i < len(s); i++ {
		c := s[i]
		if c < '0' || c > '9' {
			break
		}
		n = n*10 + int(c-'0')
	}
	return n
}
|
|
|
|
// getRcode returns the mnemonic for a DNS response code. It covers
// 0–5 (RFC 1035) and 6–10 (RFC 2136), and is used for the query log and
// the "status" label of the responses metric.
//
// Any other value, including negative ones, maps to "SERVFAIL", as before.
// The earlier table mislabelled 7–9 ("XRRSET", "NOTAUTH", "NOTZONE") and had
// no entry for 10.
func getRcode(rcode int) string {
	switch rcode {
	case 0:
		return "NOERROR"
	case 1:
		return "FORMERR"
	case 2:
		return "SERVFAIL"
	case 3:
		return "NXDOMAIN"
	case 4:
		return "NOTIMP"
	case 5:
		return "REFUSED"
	case 6:
		return "YXDOMAIN"
	case 7:
		return "YXRRSET"
	case 8:
		return "NXRRSET"
	case 9:
		return "NOTAUTH"
	case 10:
		return "NOTZONE"
	}
	return "SERVFAIL"
}
|
|
|
|
// getDomain reduces a name to its last two dot-separated labels, which is the
// key used to find a zone ("www.example.com" -> "example.com"). A name that
// contains at most one dot is returned unchanged.
func getDomain(qn string) string {
	last := strings.LastIndexByte(qn, '.')
	if last < 0 {
		return qn
	}
	// Look for the dot before the final one. If there is none, the whole name
	// is already at most two labels long.
	prev := strings.LastIndexByte(qn[:last], '.')
	if prev < 0 {
		return qn
	}
	return qn[prev+1:]
}
|
|
|
|
// strchr counts how many times the rune c occurs in s. Despite the C-style
// name it does not return a position. getData uses it to compare the number
// of dots in two names.
func strchr(s string, c rune) (count int) {
	for _, r := range s {
		if r != c {
			continue
		}
		count++
	}
	return count
}
|
|
|
|
func getData(qn, qt string) []string {
|
|
domain := getDomain(qn)
|
|
for _, zone := range zones.Domains {
|
|
if zone.Name == domain {
|
|
for _, record := range zone.Records {
|
|
if record.Type == qt && (record.Name == qn ||
|
|
record.Name[0] == '*' && strchr(record.Name, '.') == strchr(qn, '.')) {
|
|
return(record.Data)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return([]string{""})
|
|
}
|
|
|
|
func getTTL(qn, qt string) int {
|
|
domain := getDomain(qn)
|
|
for _, zone := range zones.Domains {
|
|
if zone.Name == domain {
|
|
return zone.TTL
|
|
}
|
|
}
|
|
return 86400
|
|
}
|
|
|
|
// doDot makes a name fully qualified by appending a trailing dot when it does
// not already end in one. The empty string becomes ".", the root name.
// Previously an empty string caused an index-out-of-range panic on s[len(s)-1].
func doDot(s string) string {
	if strings.HasSuffix(s, ".") {
		return s
	}
	return s + "."
}
|
|
|
|
func replySOA(qname string, ttl int, data []string) []dns.RR {
|
|
var (
|
|
a []dns.RR
|
|
domain = getDomain(qname)
|
|
soa = strings.Split(getData(domain, "SOA")[0], " ")
|
|
)
|
|
a = append(a, &dns.SOA{
|
|
Hdr: dns.RR_Header{ Name: doDot(domain), Rrtype: 6, Class: 1, Ttl: uint32(atoi(soa[6])) },
|
|
Ns: doDot(soa[0]),
|
|
Mbox: doDot(soa[1]),
|
|
Serial: uint32(atoi(soa[2])),
|
|
Refresh: uint32(atoi(soa[3])),
|
|
Retry: uint32(atoi(soa[4])),
|
|
Expire: uint32(atoi(soa[5])),
|
|
Minttl: uint32(atoi(soa[6])),
|
|
})
|
|
return a
|
|
}
|
|
|
|
func replyNS(qname string, ttl int, data []string) []dns.RR {
|
|
var a []dns.RR
|
|
for _, name := range data {
|
|
a = append(a, &dns.NS{
|
|
Hdr: dns.RR_Header{ Name: doDot(qname), Rrtype: 2, Class: 1, Ttl: uint32(ttl) },
|
|
Ns: doDot(name),
|
|
})
|
|
}
|
|
return a
|
|
}
|
|
|
|
func replyCNAME(qname string, ttl int, data []string) []dns.RR {
|
|
var a []dns.RR
|
|
for _, name := range data {
|
|
a = append(a, &dns.CNAME{
|
|
Hdr: dns.RR_Header{ Name: doDot(qname), Rrtype: 5, Class: 1, Ttl: uint32(ttl) },
|
|
Target: doDot(name),
|
|
})
|
|
}
|
|
return a
|
|
}
|
|
|
|
func replyA(qname string, ttl int, data []string) []dns.RR {
|
|
var a []dns.RR
|
|
for _, ip := range data {
|
|
a = append(a, &dns.A{
|
|
Hdr: dns.RR_Header{ Name: doDot(qname), Rrtype: 1, Class: 1, Ttl: uint32(ttl) },
|
|
A: net.ParseIP(ip),
|
|
})
|
|
}
|
|
return a
|
|
}
|
|
|
|
func replyAAAA(qname string, ttl int, data []string) []dns.RR {
|
|
var a []dns.RR
|
|
for _, ip := range data {
|
|
a = append(a, &dns.AAAA{
|
|
Hdr: dns.RR_Header{ Name: doDot(qname), Rrtype: 28, Class: 1, Ttl: uint32(ttl) },
|
|
AAAA: net.ParseIP(ip),
|
|
})
|
|
}
|
|
return a
|
|
}
|
|
|
|
func handleQuery(w dns.ResponseWriter, q *dns.Msg) {
|
|
var (
|
|
r *dns.Msg
|
|
rcode int
|
|
qname = q.Question[0].Name[:len(q.Question[0].Name)-1]
|
|
qclass = dns.Class(q.Question[0].Qclass).String()
|
|
qtype = dns.Type(q.Question[0].Qtype).String()
|
|
client = strings.Split(w.RemoteAddr().String(), ":")[0]
|
|
data = getData(qname, qtype)
|
|
ttl = getTTL(qname, qtype)
|
|
ch bool
|
|
)
|
|
|
|
queries.Inc()
|
|
qtypes.WithLabelValues(qtype).Inc()
|
|
|
|
if qclass == "CH" && qname == "version.bind" {
|
|
qname = "version.bind."
|
|
r = new(dns.Msg)
|
|
rcode = 0
|
|
r.SetReply(q)
|
|
r.Answer = append(r.Answer, &dns.TXT{
|
|
Hdr: dns.RR_Header{ Name: qname, Rrtype: 16, Class: 3, Ttl: 86400 },
|
|
Txt: []string{ "rdns" },
|
|
})
|
|
}
|
|
|
|
if data[0] == "" && qtype == "A" {
|
|
qtype = "CNAME"
|
|
data = getData(qname, qtype)
|
|
ch = true
|
|
}
|
|
|
|
if data[0] == "" && qtype != "SOA" {
|
|
r = new(dns.Msg)
|
|
rcode = 0
|
|
qtype = "SOA"
|
|
r.SetReply(q)
|
|
r.SetRcode(q, rcode)
|
|
r.MsgHdr.Authoritative = true
|
|
r.Answer = replySOA(qname, ttl, data)
|
|
} else {
|
|
r = new(dns.Msg)
|
|
rcode = 0
|
|
r.SetReply(q)
|
|
r.MsgHdr.Authoritative = true
|
|
switch qtype {
|
|
case "A":
|
|
r.Answer = replyA(qname, ttl, data)
|
|
case "AAAA":
|
|
r.Answer = replyAAAA(qname, ttl, data)
|
|
case "CNAME":
|
|
r.Answer = replyCNAME(qname, ttl, data)
|
|
if ch == true {
|
|
qname = data[0]
|
|
data = getData(qname, "A")
|
|
cha := replyA(qname, ttl, data)
|
|
for _, ans := range cha {
|
|
r.Answer = append(r.Answer, ans)
|
|
}
|
|
}
|
|
case "NS":
|
|
r.Answer = replyNS(qname, ttl, data)
|
|
case "SOA":
|
|
r.MsgHdr.Authoritative = true
|
|
r.Ns = replySOA(qname, ttl, data)
|
|
}
|
|
}
|
|
log.Println(client, qname, qclass, qtype, getRcode(rcode))
|
|
if r.Len() > 512 {
|
|
r.Truncated = true
|
|
}
|
|
w.WriteMsg(r)
|
|
responses.WithLabelValues(qtype, getRcode(rcode)).Inc()
|
|
return
|
|
}
|
|
|
|
func getZones(config string) Zones {
|
|
data, err := ioutil.ReadFile(config)
|
|
if err != nil {
|
|
log.Fatalf("error: %v", err)
|
|
}
|
|
|
|
var tmp Zones
|
|
err = yaml.Unmarshal(data, &tmp)
|
|
if err != nil {
|
|
log.Fatalf("error: %v", err)
|
|
}
|
|
return (tmp)
|
|
}
|
|
|
|
func main() {
|
|
flag.Parse()
|
|
|
|
fmt.Println("Starting Proxy Resolver:", net.JoinHostPort(*addr, *port))
|
|
if *cpu != 0 {
|
|
runtime.GOMAXPROCS(*cpu)
|
|
}
|
|
if *logs == false {
|
|
log.SetOutput(ioutil.Discard)
|
|
}
|
|
if *config != "" {
|
|
zones = getZones(*config)
|
|
fmt.Println("Zones:")
|
|
for _, domain := range zones.Domains {
|
|
fmt.Println(" ", domain.Name)
|
|
}
|
|
} else {
|
|
log.Println("Zone missing")
|
|
}
|
|
if *metrics == true {
|
|
fmt.Println("Starting prometheus exporter on port 9153")
|
|
http.Handle("/metrics", promhttp.Handler())
|
|
go http.ListenAndServe(":9153", nil)
|
|
}
|
|
|
|
dns.HandleFunc(".", handleQuery)
|
|
go serve(*addr, *port, "tcp")
|
|
go serve(*addr, *port, "udp")
|
|
|
|
up.Set(float64(1))
|
|
|
|
sig := make(chan os.Signal)
|
|
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
|
|
_ = <-sig
|
|
fmt.Printf("\033[2K\rTime to sleep, goodbye\n")
|
|
}
|