AdGuardDNS/lrucache/cache.go
2020-08-17 21:16:03 +03:00

264 lines
5.2 KiB
Go

package lrucache
import (
"encoding/binary"
"math"
"strings"
"sync"
"time"
"github.com/bluele/gcache"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/miekg/dns"
)
// defaultCacheSize is the cache capacity, in number of elements, that Set
// uses when cache.cacheSize is not positive.
const defaultCacheSize = 1000 // in number of elements
// item is a single cache entry: a DNS message together with the moment it
// was stored, which fromItem uses to compute the remaining TTL.
type item struct {
m *dns.Msg // cached DNS message (stored by pointer, not copied)
when time.Time // time when m was cached
}
// cache is an LRU cache of DNS responses keyed by the request's DO bit,
// qtype, qclass and lowercased name (see key).
//
// The underlying gcache instance is created lazily by Set, so a zero-value
// cache is usable: Get simply reports a miss until something has been stored.
type cache struct {
items gcache.Cache // underlying LRU cache; nil until Set first stores a message
cacheSize int // capacity in elements; values <= 0 select defaultCacheSize
sync.RWMutex // guards the lazy initialization of items
}
// Get looks up a cached response for request.
//
// On a hit it returns a fresh reply built for request by item.fromItem and
// true. It returns nil and false when request is nil, when no key can be
// derived from it (it does not carry exactly one question), when the cache
// has not been initialized yet, when the key is absent, or when the lookup
// fails for any other reason.
func (c *cache) Get(request *dns.Msg) (*dns.Msg, bool) {
	if request == nil {
		return nil, false
	}

	// create key for request
	ok, k := key(request)
	if !ok {
		clog.Debug("key returned !ok")
		return nil, false
	}

	// Only the items field is read here, so a shared lock is sufficient and
	// keeps concurrent lookups from serializing on each other. The value is
	// snapshotted so that it is never read outside the lock.
	c.RLock()
	items := c.items
	c.RUnlock()
	if items == nil {
		return nil, false
	}

	rawValue, err := items.Get(k)
	if err == gcache.KeyNotFoundError {
		// not a real error, just no key found
		return nil, false
	}
	if err != nil {
		// real error; key() guarantees exactly one question, so indexing is safe
		clog.Errorf("can't get response for %s from cache: %s", request.Question[0].Name, err)
		return nil, false
	}

	cachedValue, ok := rawValue.(item)
	if !ok {
		clog.Errorf("entry with invalid type in cache for %s", request.Question[0].Name)
		return nil, false
	}

	return cachedValue.fromItem(request), true
}
// Set stores m in the cache, expiring it after the lowest TTL found among
// its records.
//
// It is a no-op when m is nil, when isCacheable rejects m, or when no key can
// be derived from m. The underlying LRU cache is created on first use with
// c.cacheSize elements, or defaultCacheSize when c.cacheSize is not positive.
func (c *cache) Set(m *dns.Msg) {
	if m == nil {
		return // no-op
	}
	if !isCacheable(m) {
		return
	}
	ok, k := key(m)
	if !ok {
		return
	}
	entry := toItem(m)

	// lazy initialization for cache; snapshot the field while still holding
	// the lock so it is never read unguarded
	c.Lock()
	if c.items == nil {
		size := defaultCacheSize
		if c.cacheSize > 0 {
			size = c.cacheSize
		}
		c.items = gcache.New(size).LRU().Build()
	}
	items := c.items
	c.Unlock()

	// set ttl as expiration time for item
	ttl := time.Duration(findLowestTTL(m)) * time.Second
	if err := items.SetWithExpire(k, entry, ttl); err != nil {
		// include the cause, otherwise the failure cannot be diagnosed
		clog.Warningf("Couldn't set cache item: %s", err)
	}
}
// isCacheable reports whether the response m may be stored in the cache.
//
// A message is rejected when it is truncated, does not carry exactly one
// question, has no usable TTL (findLowestTTL returns 0), or has an rcode
// other than NOERROR or NXDOMAIN. A NOERROR response to an A or AAAA
// question must additionally contain at least one A or AAAA answer record.
func isCacheable(m *dns.Msg) bool {
	// truncated messages aren't valid
	if m.Truncated {
		clog.Debug("Refusing to cache truncated message")
		return false
	}

	// if has wrong number of questions, also don't cache
	if len(m.Question) != 1 {
		clog.Debugf("Refusing to cache message with wrong number of questions")
		return false
	}
	question := m.Question[0]

	if findLowestTTL(m) == 0 {
		return false
	}

	switch m.Rcode {
	case dns.RcodeSuccess, dns.RcodeNameError:
		// cacheable response codes
	default:
		clog.Debugf("%s: refusing to cache message with response type %s", question.Name, dns.RcodeToString[m.Rcode])
		return false
	}

	// Only successful address lookups need the extra answer check.
	addressQuery := question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA
	if m.Rcode != dns.RcodeSuccess || !addressQuery {
		return true
	}

	// Now verify that it contains at least one A or AAAA record
	if len(m.Answer) == 0 {
		clog.Debugf("%s: refusing to cache a NOERROR response with no answers", question.Name)
		return false
	}
	for _, rr := range m.Answer {
		switch rr.Header().Rrtype {
		case dns.TypeA, dns.TypeAAAA:
			return true
		}
	}

	clog.Debugf("%s: refusing to cache a response with no A and AAAA answers", question.Name)
	return false
}
// findLowestTTL returns the smallest TTL among the records in the answer,
// authority and additional sections of m, ignoring OPT records (see
// getTTLIfLower). It returns 0 when no record lowers the initial
// math.MaxUint32 sentinel, e.g. when m has no records at all.
func findLowestTTL(m *dns.Msg) uint32 {
	lowest := uint32(math.MaxUint32)

	for _, section := range [][]dns.RR{m.Answer, m.Ns, m.Extra} {
		for _, rr := range section {
			lowest = getTTLIfLower(rr.Header(), lowest)
		}
	}

	if lowest == math.MaxUint32 {
		return 0
	}
	return lowest
}
// getTTLIfLower returns h.Ttl when it is lower than ttl, and ttl otherwise.
// OPT records never lower the result; their header is skipped entirely.
func getTTLIfLower(h *dns.RR_Header, ttl uint32) uint32 {
	if h.Rrtype != dns.TypeOPT && h.Ttl < ttl {
		return h.Ttl
	}
	return ttl
}
// key builds the cache key for m. It returns false and an empty string when
// m does not carry exactly one question.
//
// The key is the following fields in sequence, integers in big endian:
//
//	uint8(do)
//	uint16(qtype)
//	uint16(qclass)
//	lowercased name
func key(m *dns.Msg) (bool, string) {
	if len(m.Question) != 1 {
		clog.Debugf("got msg with len(m.Question) != 1: %d", len(m.Question))
		return false, ""
	}
	q := m.Question[0]

	// Lowercase before sizing the buffer: strings.ToLower can change the
	// byte length of non-ASCII input, and sizing from q.Name would then
	// truncate the name or leave trailing zero bytes in the key.
	name := strings.ToLower(q.Name)
	b := make([]byte, 1+2+2+len(name))

	// put do; b[0] is already 0 when the DO bit is absent
	if opt := m.IsEdns0(); opt != nil && opt.Do() {
		b[0] = 1
	}

	// put qtype, qclass, name
	binary.BigEndian.PutUint16(b[1:], q.Qtype)
	binary.BigEndian.PutUint16(b[3:], q.Qclass)
	copy(b[5:], name)

	return true, string(b)
}
// toItem wraps m in a cache entry stamped with the current time. The message
// is stored by pointer; it is not copied.
func toItem(m *dns.Msg) item {
	var entry item
	entry.m = m
	entry.when = time.Now()
	return entry
}
// fromItem builds a reply to request from the cached message.
//
// The reply takes its header from request via SetReply, is never marked
// authoritative, and carries over the cached AD and RA flags and rcode.
// Every record of the answer, authority and additional sections is copied,
// except OPT records, and all copies get the same TTL: the lowest TTL of the
// cached message minus the time spent in the cache, rounded to the nearest
// second and never below zero.
func (i *item) fromItem(request *dns.Msg) *dns.Msg {
	cached := i.m

	response := &dns.Msg{}
	response.SetReply(request)
	response.Authoritative = false
	response.AuthenticatedData = cached.AuthenticatedData
	response.RecursionAvailable = cached.RecursionAvailable
	response.Rcode = cached.Rcode

	// remaining lifetime in seconds; stays 0 once the entry has aged out
	var newttl uint32
	lowest := findLowestTTL(cached)
	if remaining := math.Round(float64(lowest) - time.Since(i.when).Seconds()); remaining > 0 {
		newttl = uint32(remaining)
	}

	// withTTL returns a copy of rr whose TTL is the remaining lifetime, so
	// the cached record itself is never modified.
	withTTL := func(rr dns.RR) dns.RR {
		dup := dns.Copy(rr)
		dup.Header().Ttl = newttl
		return dup
	}

	for _, rr := range cached.Answer {
		response.Answer = append(response.Answer, withTTL(rr))
	}
	for _, rr := range cached.Ns {
		response.Ns = append(response.Ns, withTTL(rr))
	}
	for _, rr := range cached.Extra {
		// don't return OPT records as these are hop-by-hop
		if rr.Header().Rrtype == dns.TypeOPT {
			continue
		}
		response.Extra = append(response.Extra, withTTL(rr))
	}

	return response
}