diff --git a/coredns_plugin/coredns_plugin.go b/coredns_plugin/coredns_plugin.go
index 17f2257..5b09c91 100644
--- a/coredns_plugin/coredns_plugin.go
+++ b/coredns_plugin/coredns_plugin.go
@@ -184,25 +184,6 @@ func setupPlugin(c *caddy.Controller) (*plug, error) {
 		}
 	}
 
-	log.Printf("Loading stats from querylog")
-	err := fillStatsFromQueryLog()
-	if err != nil {
-		log.Printf("Failed to load stats from querylog: %s", err)
-		return nil, err
-	}
-
-	if p.settings.QueryLogEnabled {
-		onceQueryLog.Do(func() {
-			go periodicQueryLogRotate()
-			go periodicHourlyTopRotate()
-			go statsRotator()
-		})
-	}
-
-	onceHook.Do(func() {
-		caddy.RegisterEventHook("dnsfilter-reload", hook)
-	})
-
 	return p, nil
 }
 
@@ -251,13 +232,6 @@ func (p *plug) onShutdown() error {
 }
 
 func (p *plug) onFinalShutdown() error {
-	logBufferLock.Lock()
-	err := flushToFile(logBuffer)
-	if err != nil {
-		log.Printf("failed to flush to file: %s", err)
-		return err
-	}
-	logBufferLock.Unlock()
 	return nil
 }
 
@@ -464,7 +438,6 @@ func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dn
 		}
 		return rcode, result, err
 	case dnsfilter.FilteredBlackList:
-
 		if result.Ip == nil {
 			// return NXDomain
 			rcode, err := p.writeNXdomain(ctx, w, r)
@@ -511,7 +484,6 @@ func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (
 	start := time.Now()
 	requests.Inc()
 	state := request.Request{W: w, Req: r}
-	ip := state.IP()
 
 	// capture the written answer
 	rrw := dnstest.NewRecorder(w)
@@ -560,14 +532,8 @@ func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (
 	// log
 	elapsed := time.Since(start)
 	elapsedTime.Observe(elapsed.Seconds())
-	if p.settings.QueryLogEnabled {
-		logRequest(r, rrw.Msg, result, time.Since(start), ip)
-	}
 	return rcode, err
 }
 
 // Name returns name of the plugin as seen in Corefile and plugin.cfg
 func (p *plug) Name() string { return "dnsfilter" }
-
-var onceHook sync.Once
-var onceQueryLog sync.Once
diff --git a/coredns_plugin/coredns_stats.go b/coredns_plugin/coredns_stats.go
index b138911..2928513 100644
--- a/coredns_plugin/coredns_stats.go
+++ b/coredns_plugin/coredns_stats.go
@@ -1,10 +1,6 @@
 package dnsfilter
 
 import (
-	"encoding/json"
-	"fmt"
-	"log"
-	"net/http"
 	"sync"
 	"time"
 
@@ -107,38 +103,6 @@ func (p *periodicStats) Observe(name string, when time.Time, value float64) {
 	p.Unlock()
 }
 
-func (p *periodicStats) statsRotate(now time.Time) {
-	p.Lock()
-	rotations := int64(now.Sub(p.LastRotate) / p.period)
-	if rotations > statsHistoryElements {
-		rotations = statsHistoryElements
-	}
-	// calculate how many times we should rotate
-	for r := int64(0); r < rotations; r++ {
-		for key, values := range p.Entries {
-			newValues := [statsHistoryElements]float64{}
-			for i := 1; i < len(values); i++ {
-				newValues[i] = values[i-1]
-			}
-			p.Entries[key] = newValues
-		}
-	}
-	if rotations > 0 {
-		p.LastRotate = now
-	}
-	p.Unlock()
-}
-
-func statsRotator() {
-	for range time.Tick(time.Second) {
-		now := time.Now()
-		statistics.PerSecond.statsRotate(now)
-		statistics.PerMinute.statsRotate(now)
-		statistics.PerHour.statsRotate(now)
-		statistics.PerDay.statsRotate(now)
-	}
-}
-
 // counter that wraps around prometheus Counter but also adds to periodic stats
 type counter struct {
 	name string // used as key in periodic stats
@@ -223,188 +187,3 @@ func (h *histogram) Describe(ch chan<- *prometheus.Desc) {
 func (h *histogram) Collect(ch chan<- prometheus.Metric) {
 	h.prom.Collect(ch)
 }
-
-// -----
-// stats
-// -----
-func HandleStats(w http.ResponseWriter, r *http.Request) {
-	const numHours = 24
-	histrical := generateMapFromStats(&statistics.PerHour, 0, numHours)
-	// sum them up
-	summed := map[string]interface{}{}
-	for key, values := range histrical {
-		summedValue := 0.0
-		floats, ok := values.([]float64)
-		if !ok {
-			continue
-		}
-		for _, v := range floats {
-			summedValue += v
-		}
-		summed[key] = summedValue
-	}
-	// don't forget to divide by number of elements in returned slice
-	if val, ok := summed["avg_processing_time"]; ok {
-		if flval, flok := val.(float64); flok {
-			flval /= numHours
-			summed["avg_processing_time"] = flval
-		}
-	}
-
-	summed["stats_period"] = "24 hours"
-
-	json, err := json.Marshal(summed)
-	if err != nil {
-		errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, 500)
-		return
-	}
-	w.Header().Set("Content-Type", "application/json")
-	_, err = w.Write(json)
-	if err != nil {
-		errortext := fmt.Sprintf("Unable to write response json: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, 500)
-		return
-	}
-}
-
-func generateMapFromStats(stats *periodicStats, start int, end int) map[string]interface{} {
-	// clamp
-	start = clamp(start, 0, statsHistoryElements)
-	end = clamp(end, 0, statsHistoryElements)
-
-	avgProcessingTime := make([]float64, 0)
-
-	count := getReversedSlice(stats.Entries[elapsedTime.name+"_count"], start, end)
-	sum := getReversedSlice(stats.Entries[elapsedTime.name+"_sum"], start, end)
-	for i := 0; i < len(count); i++ {
-		var avg float64
-		if count[i] != 0 {
-			avg = sum[i] / count[i]
-			avg *= 1000
-		}
-		avgProcessingTime = append(avgProcessingTime, avg)
-	}
-
-	result := map[string]interface{}{
-		"dns_queries":           getReversedSlice(stats.Entries[requests.name], start, end),
-		"blocked_filtering":     getReversedSlice(stats.Entries[filtered.name], start, end),
-		"replaced_safebrowsing": getReversedSlice(stats.Entries[filteredSafebrowsing.name], start, end),
-		"replaced_safesearch":   getReversedSlice(stats.Entries[safesearch.name], start, end),
-		"replaced_parental":     getReversedSlice(stats.Entries[filteredParental.name], start, end),
-		"avg_processing_time":   avgProcessingTime,
-	}
-	return result
-}
-
-func HandleStatsHistory(w http.ResponseWriter, r *http.Request) {
-	// handle time unit and prepare our time window size
-	now := time.Now()
-	timeUnitString := r.URL.Query().Get("time_unit")
-	var stats *periodicStats
-	var timeUnit time.Duration
-	switch timeUnitString {
-	case "seconds":
-		timeUnit = time.Second
-		stats = &statistics.PerSecond
-	case "minutes":
-		timeUnit = time.Minute
-		stats = &statistics.PerMinute
-	case "hours":
-		timeUnit = time.Hour
-		stats = &statistics.PerHour
-	case "days":
-		timeUnit = time.Hour * 24
-		stats = &statistics.PerDay
-	default:
-		http.Error(w, "Must specify valid time_unit parameter", 400)
-		return
-	}
-
-	// parse start and end time
-	startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time"))
-	if err != nil {
-		errortext := fmt.Sprintf("Must specify valid start_time parameter: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, 400)
-		return
-	}
-	endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time"))
-	if err != nil {
-		errortext := fmt.Sprintf("Must specify valid end_time parameter: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, 400)
-		return
-	}
-
-	// check if start and time times are within supported time range
-	timeRange := timeUnit * statsHistoryElements
-	if startTime.Add(timeRange).Before(now) {
-		http.Error(w, "start_time parameter is outside of supported range", 501)
-		return
-	}
-	if endTime.Add(timeRange).Before(now) {
-		http.Error(w, "end_time parameter is outside of supported range", 501)
-		return
-	}
-
-	// calculate start and end of our array
-	// basically it's how many hours/minutes/etc have passed since now
-	start := int(now.Sub(endTime) / timeUnit)
-	end := int(now.Sub(startTime) / timeUnit)
-
-	// swap them around if they're inverted
-	if start > end {
-		start, end = end, start
-	}
-
-	data := generateMapFromStats(stats, start, end)
-	json, err := json.Marshal(data)
-	if err != nil {
-		errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, 500)
-		return
-	}
-	w.Header().Set("Content-Type", "application/json")
-	_, err = w.Write(json)
-	if err != nil {
-		errortext := fmt.Sprintf("Unable to write response json: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, 500)
-		return
-	}
-}
-
-func HandleStatsReset(w http.ResponseWriter, r *http.Request) {
-	purgeStats()
-	_, err := fmt.Fprintf(w, "OK\n")
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write body: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-	}
-}
-
-func clamp(value, low, high int) int {
-	if value < low {
-		return low
-	}
-	if value > high {
-		return high
-	}
-	return value
-}
-
-// --------------------------
-// helper functions for stats
-// --------------------------
-func getReversedSlice(input [statsHistoryElements]float64, start int, end int) []float64 {
-	output := make([]float64, 0)
-	for i := start; i <= end; i++ {
-		output = append([]float64{input[i]}, output...)
-	}
-	return output
-}
diff --git a/coredns_plugin/querylog.go b/coredns_plugin/querylog.go
deleted file mode 100644
index 92ba2d1..0000000
--- a/coredns_plugin/querylog.go
+++ /dev/null
@@ -1,239 +0,0 @@
-package dnsfilter
-
-import (
-	"encoding/json"
-	"fmt"
-	"log"
-	"net/http"
-	"os"
-	"path"
-	"runtime"
-	"strconv"
-	"strings"
-	"sync"
-	"time"
-
-	"github.com/AdguardTeam/AdGuardHome/dnsfilter"
-	"github.com/coredns/coredns/plugin/pkg/response"
-	"github.com/miekg/dns"
-)
-
-const (
-	logBufferCap           = 5000            // maximum capacity of logBuffer before it's flushed to disk
-	queryLogTimeLimit      = time.Hour * 24  // how far in the past we care about querylogs
-	queryLogRotationPeriod = time.Hour * 24  // rotate the log every 24 hours
-	queryLogFileName       = "querylog.json" // .gz added during compression
-	queryLogSize           = 5000            // maximum API response for /querylog
-	queryLogTopSize        = 500             // Keep in memory only top N values
-)
-
-var (
-	logBufferLock sync.RWMutex
-	logBuffer     []*logEntry
-
-	queryLogCache []*logEntry
-	queryLogLock  sync.RWMutex
-)
-
-type logEntry struct {
-	Question []byte
-	Answer   []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net
-	Result   dnsfilter.Result
-	Time     time.Time
-	Elapsed  time.Duration
-	IP       string
-}
-
-func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, elapsed time.Duration, ip string) {
-	var q []byte
-	var a []byte
-	var err error
-
-	if question != nil {
-		q, err = question.Pack()
-		if err != nil {
-			log.Printf("failed to pack question for querylog: %s", err)
-			return
-		}
-	}
-	if answer != nil {
-		a, err = answer.Pack()
-		if err != nil {
-			log.Printf("failed to pack answer for querylog: %s", err)
-			return
-		}
-	}
-
-	now := time.Now()
-	entry := logEntry{
-		Question: q,
-		Answer:   a,
-		Result:   result,
-		Time:     now,
-		Elapsed:  elapsed,
-		IP:       ip,
-	}
-	var flushBuffer []*logEntry
-
-	logBufferLock.Lock()
-	logBuffer = append(logBuffer, &entry)
-	if len(logBuffer) >= logBufferCap {
-		flushBuffer = logBuffer
-		logBuffer = nil
-	}
-	logBufferLock.Unlock()
-	queryLogLock.Lock()
-	queryLogCache = append(queryLogCache, &entry)
-	if len(queryLogCache) > queryLogSize {
-		toremove := len(queryLogCache) - queryLogSize
-		queryLogCache = queryLogCache[toremove:]
-	}
-	queryLogLock.Unlock()
-
-	// add it to running top
-	err = runningTop.addEntry(&entry, question, now)
-	if err != nil {
-		log.Printf("Failed to add entry to running top: %s", err)
-		// don't do failure, just log
-	}
-
-	// if buffer needs to be flushed to disk, do it now
-	if len(flushBuffer) > 0 {
-		// write to file
-		// do it in separate goroutine -- we are stalling DNS response this whole time
-		go flushToFile(flushBuffer)
-	}
-}
-
-func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
-	queryLogLock.RLock()
-	values := make([]*logEntry, len(queryLogCache))
-	copy(values, queryLogCache)
-	queryLogLock.RUnlock()
-
-	// reverse it so that newest is first
-	for left, right := 0, len(values)-1; left < right; left, right = left+1, right-1 {
-		values[left], values[right] = values[right], values[left]
-	}
-
-	var data = []map[string]interface{}{}
-	for _, entry := range values {
-		var q *dns.Msg
-		var a *dns.Msg
-
-		if len(entry.Question) > 0 {
-			q = new(dns.Msg)
-			if err := q.Unpack(entry.Question); err != nil {
-				// ignore, log and move on
-				log.Printf("Failed to unpack dns message question: %s", err)
-				q = nil
-			}
-		}
-		if len(entry.Answer) > 0 {
-			a = new(dns.Msg)
-			if err := a.Unpack(entry.Answer); err != nil {
-				// ignore, log and move on
-				log.Printf("Failed to unpack dns message question: %s", err)
-				a = nil
-			}
-		}
-
-		jsonEntry := map[string]interface{}{
-			"reason":    entry.Result.Reason.String(),
-			"elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64),
-			"time":      entry.Time.Format(time.RFC3339),
-			"client":    entry.IP,
-		}
-		if q != nil {
-			jsonEntry["question"] = map[string]interface{}{
-				"host":  strings.ToLower(strings.TrimSuffix(q.Question[0].Name, ".")),
-				"type":  dns.Type(q.Question[0].Qtype).String(),
-				"class": dns.Class(q.Question[0].Qclass).String(),
-			}
-		}
-
-		if a != nil {
-			status, _ := response.Typify(a, time.Now().UTC())
-			jsonEntry["status"] = status.String()
-		}
-		if len(entry.Result.Rule) > 0 {
-			jsonEntry["rule"] = entry.Result.Rule
-			jsonEntry["filterId"] = entry.Result.FilterID
-		}
-
-		if a != nil && len(a.Answer) > 0 {
-			var answers = []map[string]interface{}{}
-			for _, k := range a.Answer {
-				header := k.Header()
-				answer := map[string]interface{}{
-					"type": dns.TypeToString[header.Rrtype],
-					"ttl":  header.Ttl,
-				}
-				// try most common record types
-				switch v := k.(type) {
-				case *dns.A:
-					answer["value"] = v.A
-				case *dns.AAAA:
-					answer["value"] = v.AAAA
-				case *dns.MX:
-					answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx)
-				case *dns.CNAME:
-					answer["value"] = v.Target
-				case *dns.NS:
-					answer["value"] = v.Ns
-				case *dns.SPF:
-					answer["value"] = v.Txt
-				case *dns.TXT:
-					answer["value"] = v.Txt
-				case *dns.PTR:
-					answer["value"] = v.Ptr
-				case *dns.SOA:
-					answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl)
-				case *dns.CAA:
-					answer["value"] = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value)
-				case *dns.HINFO:
-					answer["value"] = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os)
-				case *dns.RRSIG:
-					answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature)
-				default:
-					// type unknown, marshall it as-is
-					answer["value"] = v
-				}
-				answers = append(answers, answer)
-			}
-			jsonEntry["answer"] = answers
-		}
-
-		data = append(data, jsonEntry)
-	}
-
-	jsonVal, err := json.Marshal(data)
-	if err != nil {
-		errorText := fmt.Sprintf("Couldn't marshal data into json: %s", err)
-		log.Println(errorText)
-		http.Error(w, errorText, http.StatusInternalServerError)
-		return
-	}
-
-	w.Header().Set("Content-Type", "application/json")
-	_, err = w.Write(jsonVal)
-	if err != nil {
-		errorText := fmt.Sprintf("Unable to write response json: %s", err)
-		log.Println(errorText)
-		http.Error(w, errorText, http.StatusInternalServerError)
-	}
-}
-
-func trace(format string, args ...interface{}) {
-	pc := make([]uintptr, 10) // at least 1 entry needed
-	runtime.Callers(2, pc)
-	f := runtime.FuncForPC(pc[0])
-	var buf strings.Builder
-	buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name())))
-	text := fmt.Sprintf(format, args...)
-	buf.WriteString(text)
-	if len(text) == 0 || text[len(text)-1] != '\n' {
-		buf.WriteRune('\n')
-	}
-	fmt.Fprint(os.Stderr, buf.String())
-}
diff --git a/coredns_plugin/querylog_file.go b/coredns_plugin/querylog_file.go
deleted file mode 100644
index a36812c..0000000
--- a/coredns_plugin/querylog_file.go
+++ /dev/null
@@ -1,291 +0,0 @@
-package dnsfilter
-
-import (
-	"bytes"
-	"compress/gzip"
-	"encoding/json"
-	"fmt"
-	"log"
-	"os"
-	"sync"
-	"time"
-
-	"github.com/go-test/deep"
-)
-
-var (
-	fileWriteLock sync.Mutex
-)
-
-const enableGzip = false
-
-func flushToFile(buffer []*logEntry) error {
-	if len(buffer) == 0 {
-		return nil
-	}
-	start := time.Now()
-
-	var b bytes.Buffer
-	e := json.NewEncoder(&b)
-	for _, entry := range buffer {
-		err := e.Encode(entry)
-		if err != nil {
-			log.Printf("Failed to marshal entry: %s", err)
-			return err
-		}
-	}
-
-	elapsed := time.Since(start)
-	log.Printf("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", len(buffer), elapsed, b.Len()/1024, float64(b.Len())/float64(len(buffer)), elapsed/time.Duration(len(buffer)))
-
-	err := checkBuffer(buffer, b)
-	if err != nil {
-		log.Printf("failed to check buffer: %s", err)
-		return err
-	}
-
-	var zb bytes.Buffer
-	filename := queryLogFileName
-
-	// gzip enabled?
-	if enableGzip {
-		filename += ".gz"
-
-		zw := gzip.NewWriter(&zb)
-		zw.Name = queryLogFileName
-		zw.ModTime = time.Now()
-
-		_, err = zw.Write(b.Bytes())
-		if err != nil {
-			log.Printf("Couldn't compress to gzip: %s", err)
-			zw.Close()
-			return err
-		}
-
-		if err = zw.Close(); err != nil {
-			log.Printf("Couldn't close gzip writer: %s", err)
-			return err
-		}
-	} else {
-		zb = b
-	}
-
-	fileWriteLock.Lock()
-	defer fileWriteLock.Unlock()
-	f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
-	if err != nil {
-		log.Printf("failed to create file \"%s\": %s", filename, err)
-		return err
-	}
-	defer f.Close()
-
-	n, err := f.Write(zb.Bytes())
-	if err != nil {
-		log.Printf("Couldn't write to file: %s", err)
-		return err
-	}
-
-	log.Printf("ok \"%s\": %v bytes written", filename, n)
-
-	return nil
-}
-
-func checkBuffer(buffer []*logEntry, b bytes.Buffer) error {
-	l := len(buffer)
-	d := json.NewDecoder(&b)
-
-	i := 0
-	for d.More() {
-		entry := &logEntry{}
-		err := d.Decode(entry)
-		if err != nil {
-			log.Printf("Failed to decode: %s", err)
-			return err
-		}
-		if diff := deep.Equal(entry, buffer[i]); diff != nil {
-			log.Printf("decoded buffer differs: %s", diff)
-			return fmt.Errorf("decoded buffer differs: %s", diff)
-		}
-		i++
-	}
-	if i != l {
-		err := fmt.Errorf("check fail: %d vs %d entries", l, i)
-		log.Print(err)
-		return err
-	}
-	log.Printf("check ok: %d entries", i)
-
-	return nil
-}
-
-func rotateQueryLog() error {
-	from := queryLogFileName
-	to := queryLogFileName + ".1"
-
-	if enableGzip {
-		from = queryLogFileName + ".gz"
-		to = queryLogFileName + ".gz.1"
-	}
-
-	if _, err := os.Stat(from); os.IsNotExist(err) {
-		// do nothing, file doesn't exist
-		return nil
-	}
-
-	err := os.Rename(from, to)
-	if err != nil {
-		log.Printf("Failed to rename querylog: %s", err)
-		return err
-	}
-
-	log.Printf("Rotated from %s to %s successfully", from, to)
-
-	return nil
-}
-
-func periodicQueryLogRotate() {
-	for range time.Tick(queryLogRotationPeriod) {
-		err := rotateQueryLog()
-		if err != nil {
-			log.Printf("Failed to rotate querylog: %s", err)
-			// do nothing, continue rotating
-		}
-	}
-}
-
-func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, timeWindow time.Duration) error {
-	now := time.Now()
-	// read from querylog files, try newest file first
-	files := []string{}
-
-	if enableGzip {
-		files = []string{
-			queryLogFileName + ".gz",
-			queryLogFileName + ".gz.1",
-		}
-	} else {
-		files = []string{
-			queryLogFileName,
-			queryLogFileName + ".1",
-		}
-	}
-
-	// read from all files
-	for _, file := range files {
-		if !needMore() {
-			break
-		}
-		if _, err := os.Stat(file); os.IsNotExist(err) {
-			// do nothing, file doesn't exist
-			continue
-		}
-
-		f, err := os.Open(file)
-		if err != nil {
-			log.Printf("Failed to open file \"%s\": %s", file, err)
-			// try next file
-			continue
-		}
-		defer f.Close()
-
-		var d *json.Decoder
-
-		if enableGzip {
-			trace("Creating gzip reader")
-			zr, err := gzip.NewReader(f)
-			if err != nil {
-				log.Printf("Failed to create gzip reader: %s", err)
-				continue
-			}
-			defer zr.Close()
-
-			trace("Creating json decoder")
-			d = json.NewDecoder(zr)
-		} else {
-			d = json.NewDecoder(f)
-		}
-
-		i := 0
-		over := 0
-		max := 10000 * time.Second
-		var sum time.Duration
-		// entries on file are in oldest->newest order
-		// we want maxLen newest
-		for d.More() {
-			if !needMore() {
-				break
-			}
-			var entry logEntry
-			err := d.Decode(&entry)
-			if err != nil {
-				log.Printf("Failed to decode: %s", err)
-				// next entry can be fine, try more
-				continue
-			}
-
-			if now.Sub(entry.Time) > timeWindow {
-				// trace("skipping entry") // debug logging
-				continue
-			}
-
-			if entry.Elapsed > max {
-				over++
-			} else {
-				sum += entry.Elapsed
-			}
-
-			i++
-			err = onEntry(&entry)
-			if err != nil {
-				return err
-			}
-		}
-		elapsed := time.Since(now)
-		var perunit time.Duration
-		var avg time.Duration
-		if i > 0 {
-			perunit = elapsed / time.Duration(i)
-			avg = sum / time.Duration(i)
-		}
-		log.Printf("file \"%s\": read %d entries in %v, %v/entry, %v over %v, %v avg", file, i, elapsed, perunit, over, max, avg)
-	}
-	return nil
-}
-
-func appendFromLogFile(values []*logEntry, maxLen int, timeWindow time.Duration) []*logEntry {
-	a := []*logEntry{}
-
-	onEntry := func(entry *logEntry) error {
-		a = append(a, entry)
-		if len(a) > maxLen {
-			toskip := len(a) - maxLen
-			a = a[toskip:]
-		}
-		return nil
-	}
-
-	needMore := func() bool {
-		return true
-	}
-
-	err := genericLoader(onEntry, needMore, timeWindow)
-	if err != nil {
-		log.Printf("Failed to load entries from querylog: %s", err)
-		return values
-	}
-
-	// now that we've read all eligible entries, reverse the slice to make it go from newest->oldest
-	for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 {
-		a[left], a[right] = a[right], a[left]
-	}
-
-	// append it to values
-	values = append(values, a...)
-
-	// then cut off of it is bigger than maxLen
-	if len(values) > maxLen {
-		values = values[:maxLen]
-	}
-
-	return values
-}
diff --git a/coredns_plugin/querylog_top.go b/coredns_plugin/querylog_top.go
deleted file mode 100644
index d4cc6e0..0000000
--- a/coredns_plugin/querylog_top.go
+++ /dev/null
@@ -1,386 +0,0 @@
-package dnsfilter
-
-import (
-	"bytes"
-	"fmt"
-	"log"
-	"net/http"
-	"os"
-	"path"
-	"runtime"
-	"sort"
-	"strconv"
-	"strings"
-	"sync"
-	"time"
-
-	"github.com/AdguardTeam/AdGuardHome/dnsfilter"
-	"github.com/bluele/gcache"
-	"github.com/miekg/dns"
-)
-
-type hourTop struct {
-	domains gcache.Cache
-	blocked gcache.Cache
-	clients gcache.Cache
-
-	mutex sync.RWMutex
-}
-
-func (top *hourTop) init() {
-	top.domains = gcache.New(queryLogTopSize).LRU().Build()
-	top.blocked = gcache.New(queryLogTopSize).LRU().Build()
-	top.clients = gcache.New(queryLogTopSize).LRU().Build()
-}
-
-type dayTop struct {
-	hours     []*hourTop
-	hoursLock sync.RWMutex // writelock this lock ONLY WHEN rotating or intializing hours!
-
-	loaded     bool
-	loadedLock sync.Mutex
-}
-
-var runningTop dayTop
-
-func init() {
-	runningTop.hoursWriteLock()
-	for i := 0; i < 24; i++ {
-		hour := hourTop{}
-		hour.init()
-		runningTop.hours = append(runningTop.hours, &hour)
-	}
-	runningTop.hoursWriteUnlock()
-}
-
-func rotateHourlyTop() {
-	log.Printf("Rotating hourly top")
-	hour := &hourTop{}
-	hour.init()
-	runningTop.hoursWriteLock()
-	runningTop.hours = append([]*hourTop{hour}, runningTop.hours...)
-	runningTop.hours = runningTop.hours[:24]
-	runningTop.hoursWriteUnlock()
-}
-
-func periodicHourlyTopRotate() {
-	t := time.Hour
-	for range time.Tick(t) {
-		rotateHourlyTop()
-	}
-}
-
-func (top *hourTop) incrementValue(key string, cache gcache.Cache) error {
-	top.Lock()
-	defer top.Unlock()
-	ivalue, err := cache.Get(key)
-	if err == gcache.KeyNotFoundError {
-		// we just set it and we're done
-		err = cache.Set(key, 1)
-		if err != nil {
-			log.Printf("Failed to set hourly top value: %s", err)
-			return err
-		}
-		return nil
-	}
-
-	if err != nil {
-		log.Printf("gcache encountered an error during get: %s", err)
-		return err
-	}
-
-	cachedValue, ok := ivalue.(int)
-	if !ok {
-		err = fmt.Errorf("SHOULD NOT HAPPEN: gcache has non-int as value: %v", ivalue)
-		log.Println(err)
-		return err
-	}
-
-	err = cache.Set(key, cachedValue+1)
-	if err != nil {
-		log.Printf("Failed to set hourly top value: %s", err)
-		return err
-	}
-	return nil
-}
-
-func (top *hourTop) incrementDomains(key string) error {
-	return top.incrementValue(key, top.domains)
-}
-
-func (top *hourTop) incrementBlocked(key string) error {
-	return top.incrementValue(key, top.blocked)
-}
-
-func (top *hourTop) incrementClients(key string) error {
-	return top.incrementValue(key, top.clients)
-}
-
-// if does not exist -- return 0
-func (top *hourTop) lockedGetValue(key string, cache gcache.Cache) (int, error) {
-	ivalue, err := cache.Get(key)
-	if err == gcache.KeyNotFoundError {
-		return 0, nil
-	}
-
-	if err != nil {
-		log.Printf("gcache encountered an error during get: %s", err)
-		return 0, err
-	}
-
-	value, ok := ivalue.(int)
-	if !ok {
-		err := fmt.Errorf("SHOULD NOT HAPPEN: gcache has non-int as value: %v", ivalue)
-		log.Println(err)
-		return 0, err
-	}
-
-	return value, nil
-}
-
-func (top *hourTop) lockedGetDomains(key string) (int, error) {
-	return top.lockedGetValue(key, top.domains)
-}
-
-func (top *hourTop) lockedGetBlocked(key string) (int, error) {
-	return top.lockedGetValue(key, top.blocked)
-}
-
-func (top *hourTop) lockedGetClients(key string) (int, error) {
-	return top.lockedGetValue(key, top.clients)
-}
-
-func (r *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error {
-	// figure out which hour bucket it belongs to
-	hour := int(now.Sub(entry.Time).Hours())
-	if hour >= 24 {
-		log.Printf("t %v is >24 hours ago, ignoring", entry.Time)
-		return nil
-	}
-
-	hostname := strings.ToLower(strings.TrimSuffix(q.Question[0].Name, "."))
-
-	// get value, if not set, crate one
-	runningTop.hoursReadLock()
-	defer runningTop.hoursReadUnlock()
-	err := runningTop.hours[hour].incrementDomains(hostname)
-	if err != nil {
-		log.Printf("Failed to increment value: %s", err)
-		return err
-	}
-
-	if entry.Result.IsFiltered {
-		err := runningTop.hours[hour].incrementBlocked(hostname)
-		if err != nil {
-			log.Printf("Failed to increment value: %s", err)
-			return err
-		}
-	}
-
-	if len(entry.IP) > 0 {
-		err := runningTop.hours[hour].incrementClients(entry.IP)
-		if err != nil {
-			log.Printf("Failed to increment value: %s", err)
-			return err
-		}
-	}
-
-	return nil
-}
-
-func fillStatsFromQueryLog() error {
-	now := time.Now()
-	runningTop.loadedWriteLock()
-	defer runningTop.loadedWriteUnlock()
-	if runningTop.loaded {
-		return nil
-	}
-	onEntry := func(entry *logEntry) error {
-		if len(entry.Question) == 0 {
-			log.Printf("entry question is absent, skipping")
-			return nil
-		}
-
-		if entry.Time.After(now) {
-			log.Printf("t %v vs %v is in the future, ignoring", entry.Time, now)
-			return nil
-		}
-
-		q := new(dns.Msg)
-		if err := q.Unpack(entry.Question); err != nil {
-			log.Printf("failed to unpack dns message question: %s", err)
-			return err
-		}
-
-		if len(q.Question) != 1 {
-			log.Printf("malformed dns message, has no questions, skipping")
-			return nil
-		}
-
-		err := runningTop.addEntry(entry, q, now)
-		if err != nil {
-			log.Printf("Failed to add entry to running top: %s", err)
-			return err
-		}
-
-		queryLogLock.Lock()
-		queryLogCache = append(queryLogCache, entry)
-		if len(queryLogCache) > queryLogSize {
-			toremove := len(queryLogCache) - queryLogSize
-			queryLogCache = queryLogCache[toremove:]
-		}
-		queryLogLock.Unlock()
-
-		requests.IncWithTime(entry.Time)
-		if entry.Result.IsFiltered {
-			filtered.IncWithTime(entry.Time)
-		}
-		switch entry.Result.Reason {
-		case dnsfilter.NotFilteredWhiteList:
-			whitelisted.IncWithTime(entry.Time)
-		case dnsfilter.NotFilteredError:
-			errorsTotal.IncWithTime(entry.Time)
-		case dnsfilter.FilteredBlackList:
-			filteredLists.IncWithTime(entry.Time)
-		case dnsfilter.FilteredSafeBrowsing:
-			filteredSafebrowsing.IncWithTime(entry.Time)
-		case dnsfilter.FilteredParental:
-			filteredParental.IncWithTime(entry.Time)
-		case dnsfilter.FilteredInvalid:
-			// do nothing
-		case dnsfilter.FilteredSafeSearch:
-			safesearch.IncWithTime(entry.Time)
-		}
-		elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time)
-
-		return nil
-	}
-
-	needMore := func() bool { return true }
-	err := genericLoader(onEntry, needMore, queryLogTimeLimit)
-	if err != nil {
-		log.Printf("Failed to load entries from querylog: %s", err)
-		return err
-	}
-
-	runningTop.loaded = true
-
-	return nil
-}
-
-func HandleStatsTop(w http.ResponseWriter, r *http.Request) {
-	domains := map[string]int{}
-	blocked := map[string]int{}
-	clients := map[string]int{}
-
-	do := func(keys []interface{}, getter func(key string) (int, error), result map[string]int) {
-		for _, ikey := range keys {
-			key, ok := ikey.(string)
-			if !ok {
-				continue
-			}
-			value, err := getter(key)
-			if err != nil {
-				log.Printf("Failed to get top domains value for %v: %s", key, err)
-				return
-			}
-			result[key] += value
-		}
-	}
-
-	runningTop.hoursReadLock()
-	for hour := 0; hour < 24; hour++ {
-		runningTop.hours[hour].RLock()
-		do(runningTop.hours[hour].domains.Keys(), runningTop.hours[hour].lockedGetDomains, domains)
-		do(runningTop.hours[hour].blocked.Keys(), runningTop.hours[hour].lockedGetBlocked, blocked)
-		do(runningTop.hours[hour].clients.Keys(), runningTop.hours[hour].lockedGetClients, clients)
-		runningTop.hours[hour].RUnlock()
-	}
-	runningTop.hoursReadUnlock()
-
-	// use manual json marshalling because we want maps to be sorted by value
-	json := bytes.Buffer{}
-	json.WriteString("{\n")
-
-	gen := func(json *bytes.Buffer, name string, top map[string]int, addComma bool) {
-		json.WriteString("  ")
-		json.WriteString(fmt.Sprintf("%q", name))
-		json.WriteString(": {\n")
-		sorted := sortByValue(top)
-		// no more than 50 entries
-		if len(sorted) > 50 {
-			sorted = sorted[:50]
-		}
-		for i, key := range sorted {
-			json.WriteString("    ")
-			json.WriteString(fmt.Sprintf("%q", key))
-			json.WriteString(": ")
-			json.WriteString(strconv.Itoa(top[key]))
-			if i+1 != len(sorted) {
-				json.WriteByte(',')
-			}
-			json.WriteByte('\n')
-		}
-		json.WriteString("  }")
-		if addComma {
-			json.WriteByte(',')
-		}
-		json.WriteByte('\n')
-	}
-	gen(&json, "top_queried_domains", domains, true)
-	gen(&json, "top_blocked_domains", blocked, true)
-	gen(&json, "top_clients", clients, true)
-	json.WriteString("  \"stats_period\": \"24 hours\"\n")
-	json.WriteString("}\n")
-
-	w.Header().Set("Content-Type", "application/json")
-	_, err := w.Write(json.Bytes())
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write body: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-	}
-}
-
-// helper function for querylog API
-func sortByValue(m map[string]int) []string {
-	type kv struct {
-		k string
-		v int
-	}
-	var ss []kv
-	for k, v := range m {
-		ss = append(ss, kv{k, v})
-	}
-	sort.Slice(ss, func(l, r int) bool {
-		return ss[l].v > ss[r].v
-	})
-
-	sorted := []string{}
-	for _, v := range ss {
-		sorted = append(sorted, v.k)
-	}
-	return sorted
-}
-
-func (d *dayTop) hoursWriteLock()    { tracelock(); d.hoursLock.Lock() }
-func (d *dayTop) hoursWriteUnlock()  { tracelock(); d.hoursLock.Unlock() }
-func (d *dayTop) hoursReadLock()     { tracelock(); d.hoursLock.RLock() }
-func (d *dayTop) hoursReadUnlock()   { tracelock(); d.hoursLock.RUnlock() }
-func (d *dayTop) loadedWriteLock()   { tracelock(); d.loadedLock.Lock() }
-func (d *dayTop) loadedWriteUnlock() { tracelock(); d.loadedLock.Unlock() }
-
-func (h *hourTop) Lock()    { tracelock(); h.mutex.Lock() }
-func (h *hourTop) RLock()   { tracelock(); h.mutex.RLock() }
-func (h *hourTop) RUnlock() { tracelock(); h.mutex.RUnlock() }
-func (h *hourTop) Unlock()  { tracelock(); h.mutex.Unlock() }
-
-func tracelock() {
-	if false { // not commented out to make code checked during compilation
-		pc := make([]uintptr, 10) // at least 1 entry needed
-		runtime.Callers(2, pc)
-		f := path.Base(runtime.FuncForPC(pc[1]).Name())
-		lockf := path.Base(runtime.FuncForPC(pc[0]).Name())
-		fmt.Fprintf(os.Stderr, "%s(): %s\n", f, lockf)
-	}
-}
diff --git a/coredns_plugin/reload.go b/coredns_plugin/reload.go
deleted file mode 100644
index 880a3ac..0000000
--- a/coredns_plugin/reload.go
+++ /dev/null
@@ -1,36 +0,0 @@
-package dnsfilter
-
-import (
-	"log"
-
-	"github.com/mholt/caddy"
-)
-
-var Reload = make(chan bool)
-
-func hook(event caddy.EventName, info interface{}) error {
-	if event != caddy.InstanceStartupEvent {
-		return nil
-	}
-
-	// this should be an instance. ok to panic if not
-	instance := info.(*caddy.Instance)
-
-	go func() {
-		for range Reload {
-			corefile, err := caddy.LoadCaddyfile(instance.Caddyfile().ServerType())
-			if err != nil {
-				continue
-			}
-			_, err = instance.Restart(corefile)
-			if err != nil {
-				log.Printf("Corefile changed but reload failed: %s", err)
-				continue
-			}
-			// hook will be called again from new instance
-			return
-		}
-	}()
-
-	return nil
-}