diff --git a/Makefile b/Makefile index a41dbbd..129c8e3 100644 --- a/Makefile +++ b/Makefile @@ -8,20 +8,22 @@ SHELL := /bin/bash .PHONY: default default: repo -GOPATH=$(shell pwd)/go_$(VERSION) +mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST))) +mkfile_dir := $(patsubst %/,%,$(dir $(mkfile_path))) +GOPATH := $(mkfile_dir)/go_$(VERSION) clean: rm -fv *.deb build: check-vars clean - mkdir -p $(GOPATH) - GOPATH=$(GOPATH) go get -v -d github.com/AdguardTeam/AdguardDNS + mkdir -p $(GOPATH)/src/bit.adguard.com/dns + if [ ! -h $(GOPATH)/src/bit.adguard.com/dns/adguard-internal-dns ]; then rm -rf $(GOPATH)/src/bit.adguard.com/dns/adguard-internal-dns && ln -fs $(mkfile_dir) $(GOPATH)/src/bit.adguard.com/dns/adguard-internal-dns; fi GOPATH=$(GOPATH) go get -v -d github.com/coredns/coredns cp plugin.cfg $(GOPATH)/src/github.com/coredns/coredns cd $(GOPATH)/src/github.com/coredns/coredns; GOPATH=$(GOPATH) go generate cd $(GOPATH)/src/github.com/coredns/coredns; GOPATH=$(GOPATH) go get -v -d -t . 
// defaultSOA is the template SOA record used when synthesizing negative
// (NXDOMAIN) responses; genSOA copies it and fills in the header.
var defaultSOA = &dns.SOA{
	// values copied from verisign's nonexistent .com domain
	// their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers
	Refresh: 1800,
	Retry:   900,
	Expire:  604800,
	Minttl:  86400,
}

// init registers this plugin with caddy under the name "dnsfilter" so it can
// be referenced from a Corefile on a "dns" server block.
func init() {
	caddy.RegisterPlugin("dnsfilter", caddy.Plugin{
		ServerType: "dns",
		Action:     setup,
	})
}

// plugFilter describes one filter list: its numeric ID and the path to the
// file its rules are loaded from.
type plugFilter struct {
	ID   int64
	Path string
}

// plugSettings holds the plugin configuration parsed from the Corefile block.
type plugSettings struct {
	SafeBrowsingBlockHost string // host substituted for safebrowsing-blocked requests
	ParentalBlockHost     string // host substituted for parental-blocked requests
	QueryLogEnabled       bool   // whether to keep an in-memory/on-disk query log
	BlockedTTL            uint32 // in seconds, default 3600
	Filters               []plugFilter
}
{ + d *dnsfilter.Dnsfilter + Next plugin.Handler + upstream upstream.Upstream + settings plugSettings + + sync.RWMutex +} + +var defaultPluginSettings = plugSettings{ + SafeBrowsingBlockHost: "safebrowsing.block.dns.adguard.com", + ParentalBlockHost: "family.block.dns.adguard.com", + BlockedTTL: 3600, // in seconds + Filters: make([]plugFilter, 0), +} + +// +// coredns handling functions +// +func setupPlugin(c *caddy.Controller) (*plug, error) { + // create new Plugin and copy default values + p := &plug{ + settings: defaultPluginSettings, + d: dnsfilter.New(), + } + + log.Println("Initializing the CoreDNS plugin") + + for c.Next() { + for c.NextBlock() { + blockValue := c.Val() + switch blockValue { + case "safebrowsing": + log.Println("Browsing security service is enabled") + p.d.EnableSafeBrowsing() + if c.NextArg() { + if len(c.Val()) == 0 { + return nil, c.ArgErr() + } + p.d.SetSafeBrowsingServer(c.Val()) + } + case "safesearch": + log.Println("Safe search is enabled") + p.d.EnableSafeSearch() + case "parental": + if !c.NextArg() { + return nil, c.ArgErr() + } + sensitivity, err := strconv.Atoi(c.Val()) + if err != nil { + return nil, c.ArgErr() + } + + log.Println("Parental control is enabled") + err = p.d.EnableParental(sensitivity) + if err != nil { + return nil, c.ArgErr() + } + if c.NextArg() { + if len(c.Val()) == 0 { + return nil, c.ArgErr() + } + p.settings.ParentalBlockHost = c.Val() + } + case "blocked_ttl": + if !c.NextArg() { + return nil, c.ArgErr() + } + blockedTtl, err := strconv.ParseUint(c.Val(), 10, 32) + if err != nil { + return nil, c.ArgErr() + } + log.Printf("Blocked request TTL is %d", blockedTtl) + p.settings.BlockedTTL = uint32(blockedTtl) + case "querylog": + log.Println("Query log is enabled") + p.settings.QueryLogEnabled = true + case "filter": + if !c.NextArg() { + return nil, c.ArgErr() + } + + filterId, err := strconv.ParseInt(c.Val(), 10, 64) + if err != nil { + return nil, c.ArgErr() + } + if !c.NextArg() { + return nil, 
c.ArgErr() + } + filterPath := c.Val() + + // Initialize filter and add it to the list + p.settings.Filters = append(p.settings.Filters, plugFilter{ + ID: filterId, + Path: filterPath, + }) + } + } + } + + for _, filter := range p.settings.Filters { + log.Printf("Loading rules from %s", filter.Path) + + file, err := os.Open(filter.Path) + if err != nil { + return nil, err + } + defer file.Close() + + count := 0 + scanner := bufio.NewScanner(file) + for scanner.Scan() { + text := scanner.Text() + + err = p.d.AddRule(text, filter.ID) + if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax { + continue + } + if err != nil { + log.Printf("Cannot add rule %s: %s", text, err) + // Just ignore invalid rules + continue + } + count++ + } + log.Printf("Added %d rules from filter ID=%d", count, filter.ID) + + if err = scanner.Err(); err != nil { + return nil, err + } + } + + log.Printf("Loading stats from querylog") + err := fillStatsFromQueryLog() + if err != nil { + log.Printf("Failed to load stats from querylog: %s", err) + return nil, err + } + + if p.settings.QueryLogEnabled { + onceQueryLog.Do(func() { + go periodicQueryLogRotate() + go periodicHourlyTopRotate() + go statsRotator() + }) + } + + onceHook.Do(func() { + caddy.RegisterEventHook("dnsfilter-reload", hook) + }) + + p.upstream, err = upstream.New(nil) + if err != nil { + return nil, err + } + + return p, nil +} + +func setup(c *caddy.Controller) error { + p, err := setupPlugin(c) + if err != nil { + return err + } + config := dnsserver.GetConfig(c) + config.AddPlugin(func(next plugin.Handler) plugin.Handler { + p.Next = next + return p + }) + + c.OnStartup(func() error { + m := dnsserver.GetConfig(c).Handler("prometheus") + if m == nil { + return nil + } + if x, ok := m.(*metrics.Metrics); ok { + x.MustRegister(requests) + x.MustRegister(filtered) + x.MustRegister(filteredLists) + x.MustRegister(filteredSafebrowsing) + x.MustRegister(filteredParental) + x.MustRegister(whitelisted) + 
x.MustRegister(safesearch) + x.MustRegister(errorsTotal) + x.MustRegister(elapsedTime) + x.MustRegister(p) + } + return nil + }) + c.OnShutdown(p.onShutdown) + c.OnFinalShutdown(p.onFinalShutdown) + + return nil +} + +func (p *plug) onShutdown() error { + p.Lock() + p.d.Destroy() + p.d = nil + p.Unlock() + return nil +} + +func (p *plug) onFinalShutdown() error { + logBufferLock.Lock() + err := flushToFile(logBuffer) + if err != nil { + log.Printf("failed to flush to file: %s", err) + return err + } + logBufferLock.Unlock() + return nil +} + +type statsFunc func(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) + +func doDesc(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) { + realch, ok := ch.(chan<- *prometheus.Desc) + if !ok { + log.Printf("Couldn't convert ch to chan<- *prometheus.Desc\n") + return + } + realch <- prometheus.NewDesc(name, text, nil, nil) +} + +func doMetric(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) { + realch, ok := ch.(chan<- prometheus.Metric) + if !ok { + log.Printf("Couldn't convert ch to chan<- prometheus.Metric\n") + return + } + desc := prometheus.NewDesc(name, text, nil, nil) + realch <- prometheus.MustNewConstMetric(desc, valueType, value) +} + +func gen(ch interface{}, doFunc statsFunc, name string, text string, value float64, valueType prometheus.ValueType) { + doFunc(ch, name, text, value, valueType) +} + +func doStatsLookup(ch interface{}, doFunc statsFunc, name string, lookupstats *dnsfilter.LookupStats) { + gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_requests", name), fmt.Sprintf("Number of %s HTTP requests that were sent", name), float64(lookupstats.Requests), prometheus.CounterValue) + gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_cachehits", name), fmt.Sprintf("Number of %s lookups that didn't need HTTP requests", name), float64(lookupstats.CacheHits), prometheus.CounterValue) + gen(ch, 
doFunc, fmt.Sprintf("coredns_dnsfilter_%s_pending", name), fmt.Sprintf("Number of currently pending %s HTTP requests", name), float64(lookupstats.Pending), prometheus.GaugeValue) + gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_pending_max", name), fmt.Sprintf("Maximum number of pending %s HTTP requests", name), float64(lookupstats.PendingMax), prometheus.GaugeValue) +} + +func (p *plug) doStats(ch interface{}, doFunc statsFunc) { + p.RLock() + stats := p.d.GetStats() + doStatsLookup(ch, doFunc, "safebrowsing", &stats.Safebrowsing) + doStatsLookup(ch, doFunc, "parental", &stats.Parental) + p.RUnlock() +} + +// Describe is called by prometheus handler to know stat types +func (p *plug) Describe(ch chan<- *prometheus.Desc) { + p.doStats(ch, doDesc) +} + +// Collect is called by prometheus handler to collect stats +func (p *plug) Collect(ch chan<- prometheus.Metric) { + p.doStats(ch, doMetric) +} + +func (p *plug) replaceHostWithValAndReply(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, host string, val string, question dns.Question) (int, error) { + // check if it's a domain name or IP address + addr := net.ParseIP(val) + var records []dns.RR + // log.Println("Will give", val, "instead of", host) // debug logging + if addr != nil { + // this is an IP address, return it + result, err := dns.NewRR(fmt.Sprintf("%s %d A %s", host, p.settings.BlockedTTL, val)) + if err != nil { + log.Printf("Got error %s\n", err) + return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err) + } + records = append(records, result) + } else { + // this is a domain name, need to look it up + req := new(dns.Msg) + req.SetQuestion(dns.Fqdn(val), question.Qtype) + req.RecursionDesired = true + reqstate := request.Request{W: w, Req: req, Context: ctx} + result, err := p.upstream.Lookup(reqstate, dns.Fqdn(val), reqstate.QType()) + if err != nil { + log.Printf("Got error %s\n", err) + return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err) + } + if result != 
nil { + for _, answer := range result.Answer { + answer.Header().Name = question.Name + } + records = result.Answer + } + } + m := new(dns.Msg) + m.SetReply(r) + m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true + m.Answer = append(m.Answer, records...) + state := request.Request{W: w, Req: r, Context: ctx} + state.SizeAndDo(m) + err := state.W.WriteMsg(m) + if err != nil { + log.Printf("Got error %s\n", err) + return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err) + } + return dns.RcodeSuccess, nil +} + +// generate SOA record that makes DNS clients cache NXdomain results +// the only value that is important is TTL in header, other values like refresh, retry, expire and minttl are irrelevant +func (p *plug) genSOA(r *dns.Msg) []dns.RR { + zone := r.Question[0].Name + header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: p.settings.BlockedTTL, Class: dns.ClassINET} + + Mbox := "hostmaster." + if zone[0] != '.' { + Mbox += zone + } + Ns := "fake-for-negative-caching.adguard.com." 
// writeNXdomain replies to the request with NXDOMAIN plus a synthesized SOA
// (see genSOA) so the negative answer is cacheable. Returns the rcode that
// was written, or RcodeServerFailure if the write itself failed.
func (p *plug) writeNXdomain(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
	state := request.Request{W: w, Req: r, Context: ctx}
	m := new(dns.Msg)
	m.SetRcode(state.Req, dns.RcodeNameError)
	m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true
	m.Ns = p.genSOA(r)

	state.SizeAndDo(m)
	err := state.W.WriteMsg(m)
	if err != nil {
		log.Printf("Got error %s\n", err)
		return dns.RcodeServerFailure, err
	}
	return dns.RcodeNameError, nil
}

// serveDNSInternal applies safesearch and filtering to the single question in
// r. It either writes a substituted/NXDOMAIN answer itself (returning the
// written rcode and the filtering Result), or passes the request down the
// plugin chain. The plug read lock is held only around accesses to p.d, and
// is released before any reply is written.
func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, dnsfilter.Result, error) {
	if len(r.Question) != 1 {
		// google DNS, bind and others do the same
		return dns.RcodeFormatError, dnsfilter.Result{}, fmt.Errorf("got a DNS request with more than one Question")
	}
	for _, question := range r.Question {
		host := strings.ToLower(strings.TrimSuffix(question.Name, "."))
		// is it a safesearch domain?
		p.RLock()
		if val, ok := p.d.SafeSearchDomain(host); ok {
			rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
			if err != nil {
				p.RUnlock()
				return rcode, dnsfilter.Result{}, err
			}
			p.RUnlock()
			return rcode, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch}, err
		}
		p.RUnlock()

		// needs to be filtered instead
		p.RLock()
		result, err := p.d.CheckHost(host)
		if err != nil {
			log.Printf("plugin/dnsfilter: %s\n", err)
			p.RUnlock()
			return dns.RcodeServerFailure, dnsfilter.Result{}, fmt.Errorf("plugin/dnsfilter: %s", err)
		}
		p.RUnlock()

		if result.IsFiltered {
			switch result.Reason {
			case dnsfilter.FilteredSafeBrowsing:
				// return cname safebrowsing.block.dns.adguard.com
				val := p.settings.SafeBrowsingBlockHost
				rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
				if err != nil {
					return rcode, dnsfilter.Result{}, err
				}
				return rcode, result, err
			case dnsfilter.FilteredParental:
				// return cname family.block.dns.adguard.com
				val := p.settings.ParentalBlockHost
				rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
				if err != nil {
					return rcode, dnsfilter.Result{}, err
				}
				return rcode, result, err
			case dnsfilter.FilteredBlackList:
				if result.Ip == nil {
					// return NXDomain
					rcode, err := p.writeNXdomain(ctx, w, r)
					if err != nil {
						return rcode, dnsfilter.Result{}, err
					}
					return rcode, result, err
				} else {
					// This is a hosts-syntax rule
					rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, result.Ip.String(), question)
					if err != nil {
						return rcode, dnsfilter.Result{}, err
					}
					return rcode, result, err
				}
			case dnsfilter.FilteredInvalid:
				// return NXdomain
				rcode, err := p.writeNXdomain(ctx, w, r)
				if err != nil {
					return rcode, dnsfilter.Result{}, err
				}
				return rcode, result, err
			default:
				log.Printf("SHOULD NOT HAPPEN -- got unknown reason for filtering host \"%s\": %v, %+v", host, result.Reason, result)
			}
		} else {
			switch result.Reason {
			case dnsfilter.NotFilteredWhiteList:
				// whitelisted: hand straight to the next plugin
				rcode, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
				return rcode, result, err
			case dnsfilter.NotFilteredNotFound:
				// do nothing, pass through to lower code
			default:
				log.Printf("SHOULD NOT HAPPEN -- got unknown reason for not filtering host \"%s\": %v, %+v", host, result.Reason, result)
			}
		}
	}
	// not filtered: delegate to the rest of the plugin chain
	rcode, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
	return rcode, dnsfilter.Result{}, err
}
// ServeDNS handles the DNS request and refuses if it's in filterlists
func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
	start := time.Now()
	requests.Inc()
	state := request.Request{W: w, Req: r}
	ip := state.IP()

	// capture the written answer
	rrw := dnstest.NewRecorder(w)
	rcode, result, err := p.serveDNSInternal(ctx, rrw, r)
	if rcode > 0 {
		// actually send the answer if we have one
		// NOTE(review): serveDNSInternal may have already written a full
		// reply (e.g. NXDOMAIN via writeNXdomain); a non-zero rcode here then
		// writes a second, answerless message -- confirm this is intended.
		answer := new(dns.Msg)
		answer.SetRcode(r, rcode)
		state.SizeAndDo(answer)
		err = w.WriteMsg(answer)
		if err != nil {
			return dns.RcodeServerFailure, err
		}
	}

	// increment counters
	switch {
	case err != nil:
		errorsTotal.Inc()
	case result.Reason == dnsfilter.FilteredBlackList:
		filtered.Inc()
		filteredLists.Inc()
	case result.Reason == dnsfilter.FilteredSafeBrowsing:
		filtered.Inc()
		filteredSafebrowsing.Inc()
	case result.Reason == dnsfilter.FilteredParental:
		filtered.Inc()
		filteredParental.Inc()
	case result.Reason == dnsfilter.FilteredInvalid:
		filtered.Inc()
		filteredInvalid.Inc()
	case result.Reason == dnsfilter.FilteredSafeSearch:
		// the request was passed through but not filtered, don't increment filtered
		safesearch.Inc()
	case result.Reason == dnsfilter.NotFilteredWhiteList:
		whitelisted.Inc()
	case result.Reason == dnsfilter.NotFilteredNotFound:
		// do nothing
	case result.Reason == dnsfilter.NotFilteredError:
		text := "SHOULD NOT HAPPEN: got DNSFILTER_NOTFILTERED_ERROR without err != nil!"
		log.Println(text)
		err = errors.New(text)
		rcode = dns.RcodeServerFailure
	}

	// log
	elapsed := time.Since(start)
	elapsedTime.Observe(elapsed.Seconds())
	if p.settings.QueryLogEnabled {
		logRequest(r, rrw.Msg, result, time.Since(start), ip)
	}
	return rcode, err
}

// Name returns name of the plugin as seen in Corefile and plugin.cfg
func (p *plug) Name() string { return "dnsfilter" }

// one-time guards for process-wide side effects in setupPlugin
var onceHook sync.Once
var onceQueryLog sync.Once
+ if err = tmpfile.Close(); err != nil { + t.Fatal(err) + } + + defer os.Remove(tmpfile.Name()) + + configText := fmt.Sprintf("dnsfilter {\nfilter 0 %s\n}", tmpfile.Name()) + c := caddy.NewTestController("dns", configText) + p, err := setupPlugin(c) + if err != nil { + t.Fatal(err) + } + + p.Next = zeroTTLBackend() + + ctx := context.TODO() + + for _, testcase := range []struct { + host string + filtered bool + }{ + {"www.doubleclick.net", false}, + {"doubleclick.net", true}, + {"www2.example.org", false}, + {"www2.example.net", false}, + {"test.www.example.org", false}, + {"test.www.example.net", false}, + {"example.org", true}, + {"example.net", true}, + {"www.example.org", true}, + {"www.example.net", true}, + } { + req := new(dns.Msg) + req.SetQuestion(testcase.host+".", dns.TypeA) + + resp := test.ResponseWriter{} + rrw := dnstest.NewRecorder(&resp) + rcode, err := p.ServeDNS(ctx, rrw, req) + if err != nil { + t.Fatalf("ServeDNS returned error: %s", err) + } + if rcode != rrw.Rcode { + t.Fatalf("ServeDNS return value for host %s has rcode %d that does not match captured rcode %d", testcase.host, rcode, rrw.Rcode) + } + A, ok := rrw.Msg.Answer[0].(*dns.A) + if !ok { + t.Fatalf("Host %s expected to have result A", testcase.host) + } + ip := net.IPv4(127, 0, 0, 1) + filtered := ip.Equal(A.A) + if testcase.filtered && testcase.filtered != filtered { + t.Fatalf("Host %s expected to be filtered, instead it is not filtered", testcase.host) + } + if !testcase.filtered && testcase.filtered != filtered { + t.Fatalf("Host %s expected to be not filtered, instead it is filtered", testcase.host) + } + } +} + +func zeroTTLBackend() plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetReply(r) + m.Response, m.RecursionAvailable = true, true + + m.Answer = []dns.RR{test.A("example.org. 
0 IN A 127.0.0.53")} + w.WriteMsg(m) + return dns.RcodeSuccess, nil + }) +} diff --git a/coredns_plugin/coredns_stats.go b/coredns_plugin/coredns_stats.go new file mode 100644 index 0000000..b138911 --- /dev/null +++ b/coredns_plugin/coredns_stats.go @@ -0,0 +1,410 @@ +package dnsfilter + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "sync" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/prometheus/client_golang/prometheus" +) + +var ( + requests = newDNSCounter("requests_total", "Count of requests seen by dnsfilter.") + filtered = newDNSCounter("filtered_total", "Count of requests filtered by dnsfilter.") + filteredLists = newDNSCounter("filtered_lists_total", "Count of requests filtered by dnsfilter using lists.") + filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total", "Count of requests filtered by dnsfilter using safebrowsing.") + filteredParental = newDNSCounter("filtered_parental_total", "Count of requests filtered by dnsfilter using parental.") + filteredInvalid = newDNSCounter("filtered_invalid_total", "Count of requests filtered by dnsfilter because they were invalid.") + whitelisted = newDNSCounter("whitelisted_total", "Count of requests not filtered by dnsfilter because they are whitelisted.") + safesearch = newDNSCounter("safesearch_total", "Count of requests replaced by dnsfilter safesearch.") + errorsTotal = newDNSCounter("errors_total", "Count of requests that dnsfilter couldn't process because of transitive errors.") + elapsedTime = newDNSHistogram("request_duration", "Histogram of the time (in seconds) each request took.") +) + +// entries for single time period (for example all per-second entries) +type statsEntries map[string][statsHistoryElements]float64 + +// how far back to keep the stats +const statsHistoryElements = 60 + 1 // +1 for calculating delta + +// each periodic stat is a map of arrays +type periodicStats struct { + Entries statsEntries + period time.Duration // how long one entry lasts + 
LastRotate time.Time // last time this data was rotated + + sync.RWMutex +} + +type stats struct { + PerSecond periodicStats + PerMinute periodicStats + PerHour periodicStats + PerDay periodicStats +} + +// per-second/per-minute/per-hour/per-day stats +var statistics stats + +func initPeriodicStats(periodic *periodicStats, period time.Duration) { + periodic.Entries = statsEntries{} + periodic.LastRotate = time.Now() + periodic.period = period +} + +func init() { + purgeStats() +} + +func purgeStats() { + initPeriodicStats(&statistics.PerSecond, time.Second) + initPeriodicStats(&statistics.PerMinute, time.Minute) + initPeriodicStats(&statistics.PerHour, time.Hour) + initPeriodicStats(&statistics.PerDay, time.Hour*24) +} + +func (p *periodicStats) Inc(name string, when time.Time) { + // calculate how many periods ago this happened + elapsed := int64(time.Since(when) / p.period) + // trace("%s: %v as %v -> [%v]", name, time.Since(when), p.period, elapsed) + if elapsed >= statsHistoryElements { + return // outside of our timeframe + } + p.Lock() + currentValues := p.Entries[name] + currentValues[elapsed]++ + p.Entries[name] = currentValues + p.Unlock() +} + +func (p *periodicStats) Observe(name string, when time.Time, value float64) { + // calculate how many periods ago this happened + elapsed := int64(time.Since(when) / p.period) + // trace("%s: %v as %v -> [%v]", name, time.Since(when), p.period, elapsed) + if elapsed >= statsHistoryElements { + return // outside of our timeframe + } + p.Lock() + { + countname := name + "_count" + currentValues := p.Entries[countname] + value := currentValues[elapsed] + // trace("Will change p.Entries[%s][%d] from %v to %v", countname, elapsed, value, value+1) + value += 1 + currentValues[elapsed] = value + p.Entries[countname] = currentValues + } + { + totalname := name + "_sum" + currentValues := p.Entries[totalname] + currentValues[elapsed] += value + p.Entries[totalname] = currentValues + } + p.Unlock() +} + +func (p 
*periodicStats) statsRotate(now time.Time) { + p.Lock() + rotations := int64(now.Sub(p.LastRotate) / p.period) + if rotations > statsHistoryElements { + rotations = statsHistoryElements + } + // calculate how many times we should rotate + for r := int64(0); r < rotations; r++ { + for key, values := range p.Entries { + newValues := [statsHistoryElements]float64{} + for i := 1; i < len(values); i++ { + newValues[i] = values[i-1] + } + p.Entries[key] = newValues + } + } + if rotations > 0 { + p.LastRotate = now + } + p.Unlock() +} + +func statsRotator() { + for range time.Tick(time.Second) { + now := time.Now() + statistics.PerSecond.statsRotate(now) + statistics.PerMinute.statsRotate(now) + statistics.PerHour.statsRotate(now) + statistics.PerDay.statsRotate(now) + } +} + +// counter that wraps around prometheus Counter but also adds to periodic stats +type counter struct { + name string // used as key in periodic stats + value int64 + prom prometheus.Counter +} + +func newDNSCounter(name string, help string) *counter { + // trace("called") + c := &counter{} + c.prom = prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "dnsfilter", + Name: name, + Help: help, + }) + c.name = name + + return c +} + +func (c *counter) IncWithTime(when time.Time) { + statistics.PerSecond.Inc(c.name, when) + statistics.PerMinute.Inc(c.name, when) + statistics.PerHour.Inc(c.name, when) + statistics.PerDay.Inc(c.name, when) + c.value++ + c.prom.Inc() +} + +func (c *counter) Inc() { + c.IncWithTime(time.Now()) +} + +func (c *counter) Describe(ch chan<- *prometheus.Desc) { + c.prom.Describe(ch) +} + +func (c *counter) Collect(ch chan<- prometheus.Metric) { + c.prom.Collect(ch) +} + +type histogram struct { + name string // used as key in periodic stats + count int64 + total float64 + prom prometheus.Histogram +} + +func newDNSHistogram(name string, help string) *histogram { + // trace("called") + h := &histogram{} + h.prom = 
prometheus.NewHistogram(prometheus.HistogramOpts{ + Namespace: plugin.Namespace, + Subsystem: "dnsfilter", + Name: name, + Help: help, + }) + h.name = name + + return h +} + +func (h *histogram) ObserveWithTime(value float64, when time.Time) { + statistics.PerSecond.Observe(h.name, when, value) + statistics.PerMinute.Observe(h.name, when, value) + statistics.PerHour.Observe(h.name, when, value) + statistics.PerDay.Observe(h.name, when, value) + h.count++ + h.total += value + h.prom.Observe(value) +} + +func (h *histogram) Observe(value float64) { + h.ObserveWithTime(value, time.Now()) +} + +func (h *histogram) Describe(ch chan<- *prometheus.Desc) { + h.prom.Describe(ch) +} + +func (h *histogram) Collect(ch chan<- prometheus.Metric) { + h.prom.Collect(ch) +} + +// ----- +// stats +// ----- +func HandleStats(w http.ResponseWriter, r *http.Request) { + const numHours = 24 + histrical := generateMapFromStats(&statistics.PerHour, 0, numHours) + // sum them up + summed := map[string]interface{}{} + for key, values := range histrical { + summedValue := 0.0 + floats, ok := values.([]float64) + if !ok { + continue + } + for _, v := range floats { + summedValue += v + } + summed[key] = summedValue + } + // don't forget to divide by number of elements in returned slice + if val, ok := summed["avg_processing_time"]; ok { + if flval, flok := val.(float64); flok { + flval /= numHours + summed["avg_processing_time"] = flval + } + } + + summed["stats_period"] = "24 hours" + + json, err := json.Marshal(summed) + if err != nil { + errortext := fmt.Sprintf("Unable to marshal status json: %s", err) + log.Println(errortext) + http.Error(w, errortext, 500) + return + } + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(json) + if err != nil { + errortext := fmt.Sprintf("Unable to write response json: %s", err) + log.Println(errortext) + http.Error(w, errortext, 500) + return + } +} + +func generateMapFromStats(stats *periodicStats, start int, end int) 
map[string]interface{} { + // clamp + start = clamp(start, 0, statsHistoryElements) + end = clamp(end, 0, statsHistoryElements) + + avgProcessingTime := make([]float64, 0) + + count := getReversedSlice(stats.Entries[elapsedTime.name+"_count"], start, end) + sum := getReversedSlice(stats.Entries[elapsedTime.name+"_sum"], start, end) + for i := 0; i < len(count); i++ { + var avg float64 + if count[i] != 0 { + avg = sum[i] / count[i] + avg *= 1000 + } + avgProcessingTime = append(avgProcessingTime, avg) + } + + result := map[string]interface{}{ + "dns_queries": getReversedSlice(stats.Entries[requests.name], start, end), + "blocked_filtering": getReversedSlice(stats.Entries[filtered.name], start, end), + "replaced_safebrowsing": getReversedSlice(stats.Entries[filteredSafebrowsing.name], start, end), + "replaced_safesearch": getReversedSlice(stats.Entries[safesearch.name], start, end), + "replaced_parental": getReversedSlice(stats.Entries[filteredParental.name], start, end), + "avg_processing_time": avgProcessingTime, + } + return result +} + +func HandleStatsHistory(w http.ResponseWriter, r *http.Request) { + // handle time unit and prepare our time window size + now := time.Now() + timeUnitString := r.URL.Query().Get("time_unit") + var stats *periodicStats + var timeUnit time.Duration + switch timeUnitString { + case "seconds": + timeUnit = time.Second + stats = &statistics.PerSecond + case "minutes": + timeUnit = time.Minute + stats = &statistics.PerMinute + case "hours": + timeUnit = time.Hour + stats = &statistics.PerHour + case "days": + timeUnit = time.Hour * 24 + stats = &statistics.PerDay + default: + http.Error(w, "Must specify valid time_unit parameter", 400) + return + } + + // parse start and end time + startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time")) + if err != nil { + errortext := fmt.Sprintf("Must specify valid start_time parameter: %s", err) + log.Println(errortext) + http.Error(w, errortext, 400) + return + } + endTime, err 
// clamp restricts value to the inclusive range [low, high].
func clamp(value, low, high int) int {
	switch {
	case value < low:
		return low
	case value > high:
		return high
	default:
		return value
	}
}
make([]float64, 0) + for i := start; i <= end; i++ { + output = append([]float64{input[i]}, output...) + } + return output +} diff --git a/coredns_plugin/querylog.go b/coredns_plugin/querylog.go new file mode 100644 index 0000000..92ba2d1 --- /dev/null +++ b/coredns_plugin/querylog.go @@ -0,0 +1,239 @@ +package dnsfilter + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "os" + "path" + "runtime" + "strconv" + "strings" + "sync" + "time" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/coredns/coredns/plugin/pkg/response" + "github.com/miekg/dns" +) + +const ( + logBufferCap = 5000 // maximum capacity of logBuffer before it's flushed to disk + queryLogTimeLimit = time.Hour * 24 // how far in the past we care about querylogs + queryLogRotationPeriod = time.Hour * 24 // rotate the log every 24 hours + queryLogFileName = "querylog.json" // .gz added during compression + queryLogSize = 5000 // maximum API response for /querylog + queryLogTopSize = 500 // Keep in memory only top N values +) + +var ( + logBufferLock sync.RWMutex + logBuffer []*logEntry + + queryLogCache []*logEntry + queryLogLock sync.RWMutex +) + +type logEntry struct { + Question []byte + Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net + Result dnsfilter.Result + Time time.Time + Elapsed time.Duration + IP string +} + +func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, elapsed time.Duration, ip string) { + var q []byte + var a []byte + var err error + + if question != nil { + q, err = question.Pack() + if err != nil { + log.Printf("failed to pack question for querylog: %s", err) + return + } + } + if answer != nil { + a, err = answer.Pack() + if err != nil { + log.Printf("failed to pack answer for querylog: %s", err) + return + } + } + + now := time.Now() + entry := logEntry{ + Question: q, + Answer: a, + Result: result, + Time: now, + Elapsed: elapsed, + IP: ip, + } + var 
flushBuffer []*logEntry + + logBufferLock.Lock() + logBuffer = append(logBuffer, &entry) + if len(logBuffer) >= logBufferCap { + flushBuffer = logBuffer + logBuffer = nil + } + logBufferLock.Unlock() + queryLogLock.Lock() + queryLogCache = append(queryLogCache, &entry) + if len(queryLogCache) > queryLogSize { + toremove := len(queryLogCache) - queryLogSize + queryLogCache = queryLogCache[toremove:] + } + queryLogLock.Unlock() + + // add it to running top + err = runningTop.addEntry(&entry, question, now) + if err != nil { + log.Printf("Failed to add entry to running top: %s", err) + // don't do failure, just log + } + + // if buffer needs to be flushed to disk, do it now + if len(flushBuffer) > 0 { + // write to file + // do it in separate goroutine -- we are stalling DNS response this whole time + go flushToFile(flushBuffer) + } +} + +func HandleQueryLog(w http.ResponseWriter, r *http.Request) { + queryLogLock.RLock() + values := make([]*logEntry, len(queryLogCache)) + copy(values, queryLogCache) + queryLogLock.RUnlock() + + // reverse it so that newest is first + for left, right := 0, len(values)-1; left < right; left, right = left+1, right-1 { + values[left], values[right] = values[right], values[left] + } + + var data = []map[string]interface{}{} + for _, entry := range values { + var q *dns.Msg + var a *dns.Msg + + if len(entry.Question) > 0 { + q = new(dns.Msg) + if err := q.Unpack(entry.Question); err != nil { + // ignore, log and move on + log.Printf("Failed to unpack dns message question: %s", err) + q = nil + } + } + if len(entry.Answer) > 0 { + a = new(dns.Msg) + if err := a.Unpack(entry.Answer); err != nil { + // ignore, log and move on + log.Printf("Failed to unpack dns message answer: %s", err) + a = nil + } + } + + jsonEntry := map[string]interface{}{ + "reason": entry.Result.Reason.String(), + "elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64), + "time": entry.Time.Format(time.RFC3339), + "client": entry.IP, + } + if q != 
nil { + jsonEntry["question"] = map[string]interface{}{ + "host": strings.ToLower(strings.TrimSuffix(q.Question[0].Name, ".")), + "type": dns.Type(q.Question[0].Qtype).String(), + "class": dns.Class(q.Question[0].Qclass).String(), + } + } + + if a != nil { + status, _ := response.Typify(a, time.Now().UTC()) + jsonEntry["status"] = status.String() + } + if len(entry.Result.Rule) > 0 { + jsonEntry["rule"] = entry.Result.Rule + jsonEntry["filterId"] = entry.Result.FilterID + } + + if a != nil && len(a.Answer) > 0 { + var answers = []map[string]interface{}{} + for _, k := range a.Answer { + header := k.Header() + answer := map[string]interface{}{ + "type": dns.TypeToString[header.Rrtype], + "ttl": header.Ttl, + } + // try most common record types + switch v := k.(type) { + case *dns.A: + answer["value"] = v.A + case *dns.AAAA: + answer["value"] = v.AAAA + case *dns.MX: + answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx) + case *dns.CNAME: + answer["value"] = v.Target + case *dns.NS: + answer["value"] = v.Ns + case *dns.SPF: + answer["value"] = v.Txt + case *dns.TXT: + answer["value"] = v.Txt + case *dns.PTR: + answer["value"] = v.Ptr + case *dns.SOA: + answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl) + case *dns.CAA: + answer["value"] = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value) + case *dns.HINFO: + answer["value"] = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os) + case *dns.RRSIG: + answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature) + default: + // type unknown, marshall it as-is + answer["value"] = v + } + answers = append(answers, answer) + } + jsonEntry["answer"] = answers + } + + data = append(data, jsonEntry) + } + + jsonVal, err := json.Marshal(data) + if err != nil { + errorText := fmt.Sprintf("Couldn't marshal data into json: %s", err) + 
log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(jsonVal) + if err != nil { + errorText := fmt.Sprintf("Unable to write response json: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) + } +} + +func trace(format string, args ...interface{}) { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + var buf strings.Builder + buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name()))) + text := fmt.Sprintf(format, args...) + buf.WriteString(text) + if len(text) == 0 || text[len(text)-1] != '\n' { + buf.WriteRune('\n') + } + fmt.Fprint(os.Stderr, buf.String()) +} diff --git a/coredns_plugin/querylog_file.go b/coredns_plugin/querylog_file.go new file mode 100644 index 0000000..a36812c --- /dev/null +++ b/coredns_plugin/querylog_file.go @@ -0,0 +1,291 @@ +package dnsfilter + +import ( + "bytes" + "compress/gzip" + "encoding/json" + "fmt" + "log" + "os" + "sync" + "time" + + "github.com/go-test/deep" +) + +var ( + fileWriteLock sync.Mutex +) + +const enableGzip = false + +func flushToFile(buffer []*logEntry) error { + if len(buffer) == 0 { + return nil + } + start := time.Now() + + var b bytes.Buffer + e := json.NewEncoder(&b) + for _, entry := range buffer { + err := e.Encode(entry) + if err != nil { + log.Printf("Failed to marshal entry: %s", err) + return err + } + } + + elapsed := time.Since(start) + log.Printf("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", len(buffer), elapsed, b.Len()/1024, float64(b.Len())/float64(len(buffer)), elapsed/time.Duration(len(buffer))) + + err := checkBuffer(buffer, b) + if err != nil { + log.Printf("failed to check buffer: %s", err) + return err + } + + var zb bytes.Buffer + filename := queryLogFileName + + // gzip enabled? 
+ if enableGzip { + filename += ".gz" + + zw := gzip.NewWriter(&zb) + zw.Name = queryLogFileName + zw.ModTime = time.Now() + + _, err = zw.Write(b.Bytes()) + if err != nil { + log.Printf("Couldn't compress to gzip: %s", err) + zw.Close() + return err + } + + if err = zw.Close(); err != nil { + log.Printf("Couldn't close gzip writer: %s", err) + return err + } + } else { + zb = b + } + + fileWriteLock.Lock() + defer fileWriteLock.Unlock() + f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) + if err != nil { + log.Printf("failed to create file \"%s\": %s", filename, err) + return err + } + defer f.Close() + + n, err := f.Write(zb.Bytes()) + if err != nil { + log.Printf("Couldn't write to file: %s", err) + return err + } + + log.Printf("ok \"%s\": %v bytes written", filename, n) + + return nil +} + +func checkBuffer(buffer []*logEntry, b bytes.Buffer) error { + l := len(buffer) + d := json.NewDecoder(&b) + + i := 0 + for d.More() { + entry := &logEntry{} + err := d.Decode(entry) + if err != nil { + log.Printf("Failed to decode: %s", err) + return err + } + if diff := deep.Equal(entry, buffer[i]); diff != nil { + log.Printf("decoded buffer differs: %s", diff) + return fmt.Errorf("decoded buffer differs: %s", diff) + } + i++ + } + if i != l { + err := fmt.Errorf("check fail: %d vs %d entries", l, i) + log.Print(err) + return err + } + log.Printf("check ok: %d entries", i) + + return nil +} + +func rotateQueryLog() error { + from := queryLogFileName + to := queryLogFileName + ".1" + + if enableGzip { + from = queryLogFileName + ".gz" + to = queryLogFileName + ".gz.1" + } + + if _, err := os.Stat(from); os.IsNotExist(err) { + // do nothing, file doesn't exist + return nil + } + + err := os.Rename(from, to) + if err != nil { + log.Printf("Failed to rename querylog: %s", err) + return err + } + + log.Printf("Rotated from %s to %s successfully", from, to) + + return nil +} + +func periodicQueryLogRotate() { + for range 
time.Tick(queryLogRotationPeriod) { + err := rotateQueryLog() + if err != nil { + log.Printf("Failed to rotate querylog: %s", err) + // do nothing, continue rotating + } + } +} + +func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, timeWindow time.Duration) error { + now := time.Now() + // read from querylog files, try newest file first + files := []string{} + + if enableGzip { + files = []string{ + queryLogFileName + ".gz", + queryLogFileName + ".gz.1", + } + } else { + files = []string{ + queryLogFileName, + queryLogFileName + ".1", + } + } + + // read from all files + for _, file := range files { + if !needMore() { + break + } + if _, err := os.Stat(file); os.IsNotExist(err) { + // do nothing, file doesn't exist + continue + } + + f, err := os.Open(file) + if err != nil { + log.Printf("Failed to open file \"%s\": %s", file, err) + // try next file + continue + } + defer f.Close() + + var d *json.Decoder + + if enableGzip { + trace("Creating gzip reader") + zr, err := gzip.NewReader(f) + if err != nil { + log.Printf("Failed to create gzip reader: %s", err) + continue + } + defer zr.Close() + + trace("Creating json decoder") + d = json.NewDecoder(zr) + } else { + d = json.NewDecoder(f) + } + + i := 0 + over := 0 + max := 10000 * time.Second + var sum time.Duration + // entries on file are in oldest->newest order + // we want maxLen newest + for d.More() { + if !needMore() { + break + } + var entry logEntry + err := d.Decode(&entry) + if err != nil { + log.Printf("Failed to decode: %s", err) + // next entry can be fine, try more + continue + } + + if now.Sub(entry.Time) > timeWindow { + // trace("skipping entry") // debug logging + continue + } + + if entry.Elapsed > max { + over++ + } else { + sum += entry.Elapsed + } + + i++ + err = onEntry(&entry) + if err != nil { + return err + } + } + elapsed := time.Since(now) + var perunit time.Duration + var avg time.Duration + if i > 0 { + perunit = elapsed / time.Duration(i) + avg = sum / 
time.Duration(i) + } + log.Printf("file \"%s\": read %d entries in %v, %v/entry, %v over %v, %v avg", file, i, elapsed, perunit, over, max, avg) + } + return nil +} + +func appendFromLogFile(values []*logEntry, maxLen int, timeWindow time.Duration) []*logEntry { + a := []*logEntry{} + + onEntry := func(entry *logEntry) error { + a = append(a, entry) + if len(a) > maxLen { + toskip := len(a) - maxLen + a = a[toskip:] + } + return nil + } + + needMore := func() bool { + return true + } + + err := genericLoader(onEntry, needMore, timeWindow) + if err != nil { + log.Printf("Failed to load entries from querylog: %s", err) + return values + } + + // now that we've read all eligible entries, reverse the slice to make it go from newest->oldest + for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 { + a[left], a[right] = a[right], a[left] + } + + // append it to values + values = append(values, a...) + + // then cut off of it is bigger than maxLen + if len(values) > maxLen { + values = values[:maxLen] + } + + return values +} diff --git a/coredns_plugin/querylog_top.go b/coredns_plugin/querylog_top.go new file mode 100644 index 0000000..d4cc6e0 --- /dev/null +++ b/coredns_plugin/querylog_top.go @@ -0,0 +1,386 @@ +package dnsfilter + +import ( + "bytes" + "fmt" + "log" + "net/http" + "os" + "path" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/bluele/gcache" + "github.com/miekg/dns" +) + +type hourTop struct { + domains gcache.Cache + blocked gcache.Cache + clients gcache.Cache + + mutex sync.RWMutex +} + +func (top *hourTop) init() { + top.domains = gcache.New(queryLogTopSize).LRU().Build() + top.blocked = gcache.New(queryLogTopSize).LRU().Build() + top.clients = gcache.New(queryLogTopSize).LRU().Build() +} + +type dayTop struct { + hours []*hourTop + hoursLock sync.RWMutex // writelock this lock ONLY WHEN rotating or intializing hours! 
+ + loaded bool + loadedLock sync.Mutex +} + +var runningTop dayTop + +func init() { + runningTop.hoursWriteLock() + for i := 0; i < 24; i++ { + hour := hourTop{} + hour.init() + runningTop.hours = append(runningTop.hours, &hour) + } + runningTop.hoursWriteUnlock() +} + +func rotateHourlyTop() { + log.Printf("Rotating hourly top") + hour := &hourTop{} + hour.init() + runningTop.hoursWriteLock() + runningTop.hours = append([]*hourTop{hour}, runningTop.hours...) + runningTop.hours = runningTop.hours[:24] + runningTop.hoursWriteUnlock() +} + +func periodicHourlyTopRotate() { + t := time.Hour + for range time.Tick(t) { + rotateHourlyTop() + } +} + +func (top *hourTop) incrementValue(key string, cache gcache.Cache) error { + top.Lock() + defer top.Unlock() + ivalue, err := cache.Get(key) + if err == gcache.KeyNotFoundError { + // we just set it and we're done + err = cache.Set(key, 1) + if err != nil { + log.Printf("Failed to set hourly top value: %s", err) + return err + } + return nil + } + + if err != nil { + log.Printf("gcache encountered an error during get: %s", err) + return err + } + + cachedValue, ok := ivalue.(int) + if !ok { + err = fmt.Errorf("SHOULD NOT HAPPEN: gcache has non-int as value: %v", ivalue) + log.Println(err) + return err + } + + err = cache.Set(key, cachedValue+1) + if err != nil { + log.Printf("Failed to set hourly top value: %s", err) + return err + } + return nil +} + +func (top *hourTop) incrementDomains(key string) error { + return top.incrementValue(key, top.domains) +} + +func (top *hourTop) incrementBlocked(key string) error { + return top.incrementValue(key, top.blocked) +} + +func (top *hourTop) incrementClients(key string) error { + return top.incrementValue(key, top.clients) +} + +// if does not exist -- return 0 +func (top *hourTop) lockedGetValue(key string, cache gcache.Cache) (int, error) { + ivalue, err := cache.Get(key) + if err == gcache.KeyNotFoundError { + return 0, nil + } + + if err != nil { + log.Printf("gcache 
encountered an error during get: %s", err) + return 0, err + } + + value, ok := ivalue.(int) + if !ok { + err := fmt.Errorf("SHOULD NOT HAPPEN: gcache has non-int as value: %v", ivalue) + log.Println(err) + return 0, err + } + + return value, nil +} + +func (top *hourTop) lockedGetDomains(key string) (int, error) { + return top.lockedGetValue(key, top.domains) +} + +func (top *hourTop) lockedGetBlocked(key string) (int, error) { + return top.lockedGetValue(key, top.blocked) +} + +func (top *hourTop) lockedGetClients(key string) (int, error) { + return top.lockedGetValue(key, top.clients) +} + +func (r *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error { + // figure out which hour bucket it belongs to + hour := int(now.Sub(entry.Time).Hours()) + if hour >= 24 { + log.Printf("t %v is >24 hours ago, ignoring", entry.Time) + return nil + } + + hostname := strings.ToLower(strings.TrimSuffix(q.Question[0].Name, ".")) + + // get value, if not set, crate one + runningTop.hoursReadLock() + defer runningTop.hoursReadUnlock() + err := runningTop.hours[hour].incrementDomains(hostname) + if err != nil { + log.Printf("Failed to increment value: %s", err) + return err + } + + if entry.Result.IsFiltered { + err := runningTop.hours[hour].incrementBlocked(hostname) + if err != nil { + log.Printf("Failed to increment value: %s", err) + return err + } + } + + if len(entry.IP) > 0 { + err := runningTop.hours[hour].incrementClients(entry.IP) + if err != nil { + log.Printf("Failed to increment value: %s", err) + return err + } + } + + return nil +} + +func fillStatsFromQueryLog() error { + now := time.Now() + runningTop.loadedWriteLock() + defer runningTop.loadedWriteUnlock() + if runningTop.loaded { + return nil + } + onEntry := func(entry *logEntry) error { + if len(entry.Question) == 0 { + log.Printf("entry question is absent, skipping") + return nil + } + + if entry.Time.After(now) { + log.Printf("t %v vs %v is in the future, ignoring", entry.Time, now) + return nil 
+ } + + q := new(dns.Msg) + if err := q.Unpack(entry.Question); err != nil { + log.Printf("failed to unpack dns message question: %s", err) + return err + } + + if len(q.Question) != 1 { + log.Printf("malformed dns message, has no questions, skipping") + return nil + } + + err := runningTop.addEntry(entry, q, now) + if err != nil { + log.Printf("Failed to add entry to running top: %s", err) + return err + } + + queryLogLock.Lock() + queryLogCache = append(queryLogCache, entry) + if len(queryLogCache) > queryLogSize { + toremove := len(queryLogCache) - queryLogSize + queryLogCache = queryLogCache[toremove:] + } + queryLogLock.Unlock() + + requests.IncWithTime(entry.Time) + if entry.Result.IsFiltered { + filtered.IncWithTime(entry.Time) + } + switch entry.Result.Reason { + case dnsfilter.NotFilteredWhiteList: + whitelisted.IncWithTime(entry.Time) + case dnsfilter.NotFilteredError: + errorsTotal.IncWithTime(entry.Time) + case dnsfilter.FilteredBlackList: + filteredLists.IncWithTime(entry.Time) + case dnsfilter.FilteredSafeBrowsing: + filteredSafebrowsing.IncWithTime(entry.Time) + case dnsfilter.FilteredParental: + filteredParental.IncWithTime(entry.Time) + case dnsfilter.FilteredInvalid: + // do nothing + case dnsfilter.FilteredSafeSearch: + safesearch.IncWithTime(entry.Time) + } + elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time) + + return nil + } + + needMore := func() bool { return true } + err := genericLoader(onEntry, needMore, queryLogTimeLimit) + if err != nil { + log.Printf("Failed to load entries from querylog: %s", err) + return err + } + + runningTop.loaded = true + + return nil +} + +func HandleStatsTop(w http.ResponseWriter, r *http.Request) { + domains := map[string]int{} + blocked := map[string]int{} + clients := map[string]int{} + + do := func(keys []interface{}, getter func(key string) (int, error), result map[string]int) { + for _, ikey := range keys { + key, ok := ikey.(string) + if !ok { + continue + } + value, err := getter(key) + 
if err != nil { + log.Printf("Failed to get top domains value for %v: %s", key, err) + return + } + result[key] += value + } + } + + runningTop.hoursReadLock() + for hour := 0; hour < 24; hour++ { + runningTop.hours[hour].RLock() + do(runningTop.hours[hour].domains.Keys(), runningTop.hours[hour].lockedGetDomains, domains) + do(runningTop.hours[hour].blocked.Keys(), runningTop.hours[hour].lockedGetBlocked, blocked) + do(runningTop.hours[hour].clients.Keys(), runningTop.hours[hour].lockedGetClients, clients) + runningTop.hours[hour].RUnlock() + } + runningTop.hoursReadUnlock() + + // use manual json marshalling because we want maps to be sorted by value + json := bytes.Buffer{} + json.WriteString("{\n") + + gen := func(json *bytes.Buffer, name string, top map[string]int, addComma bool) { + json.WriteString(" ") + json.WriteString(fmt.Sprintf("%q", name)) + json.WriteString(": {\n") + sorted := sortByValue(top) + // no more than 50 entries + if len(sorted) > 50 { + sorted = sorted[:50] + } + for i, key := range sorted { + json.WriteString(" ") + json.WriteString(fmt.Sprintf("%q", key)) + json.WriteString(": ") + json.WriteString(strconv.Itoa(top[key])) + if i+1 != len(sorted) { + json.WriteByte(',') + } + json.WriteByte('\n') + } + json.WriteString(" }") + if addComma { + json.WriteByte(',') + } + json.WriteByte('\n') + } + gen(&json, "top_queried_domains", domains, true) + gen(&json, "top_blocked_domains", blocked, true) + gen(&json, "top_clients", clients, true) + json.WriteString(" \"stats_period\": \"24 hours\"\n") + json.WriteString("}\n") + + w.Header().Set("Content-Type", "application/json") + _, err := w.Write(json.Bytes()) + if err != nil { + errortext := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errortext) + http.Error(w, errortext, http.StatusInternalServerError) + } +} + +// helper function for querylog API +func sortByValue(m map[string]int) []string { + type kv struct { + k string + v int + } + var ss []kv + for k, v := range m { + ss = 
append(ss, kv{k, v}) + } + sort.Slice(ss, func(l, r int) bool { + return ss[l].v > ss[r].v + }) + + sorted := []string{} + for _, v := range ss { + sorted = append(sorted, v.k) + } + return sorted +} + +func (d *dayTop) hoursWriteLock() { tracelock(); d.hoursLock.Lock() } +func (d *dayTop) hoursWriteUnlock() { tracelock(); d.hoursLock.Unlock() } +func (d *dayTop) hoursReadLock() { tracelock(); d.hoursLock.RLock() } +func (d *dayTop) hoursReadUnlock() { tracelock(); d.hoursLock.RUnlock() } +func (d *dayTop) loadedWriteLock() { tracelock(); d.loadedLock.Lock() } +func (d *dayTop) loadedWriteUnlock() { tracelock(); d.loadedLock.Unlock() } + +func (h *hourTop) Lock() { tracelock(); h.mutex.Lock() } +func (h *hourTop) RLock() { tracelock(); h.mutex.RLock() } +func (h *hourTop) RUnlock() { tracelock(); h.mutex.RUnlock() } +func (h *hourTop) Unlock() { tracelock(); h.mutex.Unlock() } + +func tracelock() { + if false { // not commented out to make code checked during compilation + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := path.Base(runtime.FuncForPC(pc[1]).Name()) + lockf := path.Base(runtime.FuncForPC(pc[0]).Name()) + fmt.Fprintf(os.Stderr, "%s(): %s\n", f, lockf) + } +} diff --git a/coredns_plugin/ratelimit/ratelimit.go b/coredns_plugin/ratelimit/ratelimit.go new file mode 100644 index 0000000..8d3eeec --- /dev/null +++ b/coredns_plugin/ratelimit/ratelimit.go @@ -0,0 +1,182 @@ +package ratelimit + +import ( + "errors" + "log" + "sort" + "strconv" + "time" + + // ratelimiting and per-ip buckets + "github.com/beefsack/go-rate" + "github.com/patrickmn/go-cache" + + // coredns plugin + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metrics" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/request" + "github.com/mholt/caddy" + "github.com/miekg/dns" + "github.com/prometheus/client_golang/prometheus" + "golang.org/x/net/context" +) + 
+const defaultRatelimit = 30 +const defaultResponseSize = 1000 + +var ( + tokenBuckets = cache.New(time.Hour, time.Hour) +) + +// ServeDNS handles the DNS request and refuses it if it's beyond the specified ratelimit +func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + ip := state.IP() + allow, err := p.allowRequest(ip) + if err != nil { + return 0, err + } + if !allow { + ratelimited.Inc() + return 0, nil + } + + // Record response to get status code and size of the reply. + rw := dnstest.NewRecorder(w) + status, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, rw, r) + + size := rw.Len + + if size > defaultResponseSize && state.Proto() == "udp" { + // For large UDP responses we call allowRequest more times + // The exact number of times depends on the response size + for i := 0; i < size/defaultResponseSize; i++ { + p.allowRequest(ip) + } + } + + return status, err +} + +func (p *plug) allowRequest(ip string) (bool, error) { + if len(p.whitelist) > 0 { + i := sort.SearchStrings(p.whitelist, ip) + + if i < len(p.whitelist) && p.whitelist[i] == ip { + return true, nil + } + } + + if _, found := tokenBuckets.Get(ip); !found { + tokenBuckets.Set(ip, rate.New(p.ratelimit, time.Second), time.Hour) + } + + value, found := tokenBuckets.Get(ip) + if !found { + // should not happen since we've just inserted it + text := "SHOULD NOT HAPPEN: just-inserted ratelimiter disappeared" + log.Println(text) + err := errors.New(text) + return true, err + } + + rl, ok := value.(*rate.RateLimiter) + if !ok { + text := "SHOULD NOT HAPPEN: non-RateLimiter entry found in ratelimit cache" + log.Println(text) + err := errors.New(text) + return true, err + } + + allow, _ := rl.Try() + return allow, nil +} + +// +// helper functions +// +func init() { + caddy.RegisterPlugin("ratelimit", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +type plug struct { + Next plugin.Handler + + // 
configuration for creating above + ratelimit int // in requests per second per IP + whitelist []string // a list of whitelisted IP addresses +} + +func setupPlugin(c *caddy.Controller) (*plug, error) { + p := &plug{ratelimit: defaultRatelimit} + + for c.Next() { + args := c.RemainingArgs() + if len(args) > 0 { + ratelimit, err := strconv.Atoi(args[0]) + if err != nil { + return nil, c.ArgErr() + } + p.ratelimit = ratelimit + } + for c.NextBlock() { + switch c.Val() { + case "whitelist": + p.whitelist = c.RemainingArgs() + + if len(p.whitelist) > 0 { + sort.Strings(p.whitelist) + } + } + } + } + + return p, nil +} + +func setup(c *caddy.Controller) error { + p, err := setupPlugin(c) + if err != nil { + return err + } + + config := dnsserver.GetConfig(c) + config.AddPlugin(func(next plugin.Handler) plugin.Handler { + p.Next = next + return p + }) + + c.OnStartup(func() error { + m := dnsserver.GetConfig(c).Handler("prometheus") + if m == nil { + return nil + } + if x, ok := m.(*metrics.Metrics); ok { + x.MustRegister(ratelimited) + } + return nil + }) + + return nil +} + +func newDNSCounter(name string, help string) prometheus.Counter { + return prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "ratelimit", + Name: name, + Help: help, + }) +} + +var ( + ratelimited = newDNSCounter("dropped_total", "Count of requests that have been dropped because of rate limit") +) + +// Name returns name of the plugin as seen in Corefile and plugin.cfg +func (p *plug) Name() string { return "ratelimit" } diff --git a/coredns_plugin/ratelimit/ratelimit_test.go b/coredns_plugin/ratelimit/ratelimit_test.go new file mode 100644 index 0000000..b426f2e --- /dev/null +++ b/coredns_plugin/ratelimit/ratelimit_test.go @@ -0,0 +1,80 @@ +package ratelimit + +import ( + "testing" + + "github.com/mholt/caddy" +) + +func TestSetup(t *testing.T) { + for i, testcase := range []struct { + config string + failing bool + }{ + {`ratelimit`, false}, + {`ratelimit 
100`, false}, + {`ratelimit { + whitelist 127.0.0.1 + }`, false}, + {`ratelimit 50 { + whitelist 127.0.0.1 176.103.130.130 + }`, false}, + {`ratelimit test`, true}, + } { + c := caddy.NewTestController("dns", testcase.config) + err := setup(c) + if err != nil { + if !testcase.failing { + t.Fatalf("Test #%d expected no errors, but got: %v", i, err) + } + continue + } + if testcase.failing { + t.Fatalf("Test #%d expected to fail but it didn't", i) + } + } +} + +func TestRatelimiting(t *testing.T) { + // rate limit is 1 per sec + c := caddy.NewTestController("dns", `ratelimit 1`) + p, err := setupPlugin(c) + + if err != nil { + t.Fatal("Failed to initialize the plugin") + } + + allowed, err := p.allowRequest("127.0.0.1") + + if err != nil || !allowed { + t.Fatal("First request must have been allowed") + } + + allowed, err = p.allowRequest("127.0.0.1") + + if err != nil || allowed { + t.Fatal("Second request must have been ratelimited") + } +} + +func TestWhitelist(t *testing.T) { + // rate limit is 1 per sec + c := caddy.NewTestController("dns", `ratelimit 1 { whitelist 127.0.0.2 127.0.0.1 127.0.0.125 }`) + p, err := setupPlugin(c) + + if err != nil { + t.Fatal("Failed to initialize the plugin") + } + + allowed, err := p.allowRequest("127.0.0.1") + + if err != nil || !allowed { + t.Fatal("First request must have been allowed") + } + + allowed, err = p.allowRequest("127.0.0.1") + + if err != nil || !allowed { + t.Fatal("Second request must have been allowed due to whitelist") + } +} diff --git a/coredns_plugin/refuseany/refuseany.go b/coredns_plugin/refuseany/refuseany.go new file mode 100644 index 0000000..92d5d50 --- /dev/null +++ b/coredns_plugin/refuseany/refuseany.go @@ -0,0 +1,91 @@ +package refuseany + +import ( + "fmt" + "log" + + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metrics" + "github.com/coredns/coredns/request" + "github.com/mholt/caddy" + "github.com/miekg/dns" + 
"github.com/prometheus/client_golang/prometheus" + "golang.org/x/net/context" +) + +type plug struct { + Next plugin.Handler +} + +// ServeDNS handles the DNS request and refuses if it's an ANY request +func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + if len(r.Question) != 1 { + // google DNS, bind and others do the same + return dns.RcodeFormatError, fmt.Errorf("Got DNS request with != 1 questions") + } + + q := r.Question[0] + if q.Qtype == dns.TypeANY { + state := request.Request{W: w, Req: r, Context: ctx} + rcode := dns.RcodeNotImplemented + + m := new(dns.Msg) + m.SetRcode(r, rcode) + state.SizeAndDo(m) + err := state.W.WriteMsg(m) + if err != nil { + log.Printf("Got error %s\n", err) + return dns.RcodeServerFailure, err + } + return rcode, nil + } + + return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r) +} + +func init() { + caddy.RegisterPlugin("refuseany", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { + p := &plug{} + config := dnsserver.GetConfig(c) + + config.AddPlugin(func(next plugin.Handler) plugin.Handler { + p.Next = next + return p + }) + + c.OnStartup(func() error { + m := dnsserver.GetConfig(c).Handler("prometheus") + if m == nil { + return nil + } + if x, ok := m.(*metrics.Metrics); ok { + x.MustRegister(ratelimited) + } + return nil + }) + + return nil +} + +func newDNSCounter(name string, help string) prometheus.Counter { + return prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "refuseany", + Name: name, + Help: help, + }) +} + +var ( + ratelimited = newDNSCounter("refusedany_total", "Count of ANY requests that have been dropped") +) + +// Name returns name of the plugin as seen in Corefile and plugin.cfg +func (p *plug) Name() string { return "refuseany" } diff --git a/coredns_plugin/reload.go b/coredns_plugin/reload.go new file mode 100644 index 0000000..880a3ac --- /dev/null +++ 
b/coredns_plugin/reload.go @@ -0,0 +1,36 @@ +package dnsfilter + +import ( + "log" + + "github.com/mholt/caddy" +) + +var Reload = make(chan bool) + +func hook(event caddy.EventName, info interface{}) error { + if event != caddy.InstanceStartupEvent { + return nil + } + + // this should be an instance. ok to panic if not + instance := info.(*caddy.Instance) + + go func() { + for range Reload { + corefile, err := caddy.LoadCaddyfile(instance.Caddyfile().ServerType()) + if err != nil { + continue + } + _, err = instance.Restart(corefile) + if err != nil { + log.Printf("Corefile changed but reload failed: %s", err) + continue + } + // hook will be called again from new instance + return + } + }() + + return nil +} diff --git a/plugin.cfg b/plugin.cfg index b5a1784..259f5c4 100644 --- a/plugin.cfg +++ b/plugin.cfg @@ -26,9 +26,9 @@ pprof:pprof prometheus:metrics errors:errors log:log -ratelimit:github.com/AdguardTeam/AdguardDNS/coredns_plugin/ratelimit -refuseany:github.com/AdguardTeam/AdguardDNS/coredns_plugin/refuseany -dnsfilter:github.com/AdguardTeam/AdguardDNS/coredns_plugin +ratelimit:bit.adguard.com/dns/adguard-internal-dns/coredns_plugin/ratelimit +refuseany:bit.adguard.com/dns/adguard-internal-dns/coredns_plugin/refuseany +dnsfilter:bit.adguard.com/dns/adguard-internal-dns/coredns_plugin cache:cache template:template file:file