Andrey Meshkov f1791135af Sync v2.11.0
2024-12-05 14:19:25 +03:00

354 lines
8.9 KiB
Go

package dnscheck
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"net/netip"
"slices"
"sync"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/remotekv"
"github.com/AdguardTeam/AdGuardDNS/internal/remotekv/consulkv"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
cache "github.com/patrickmn/go-cache"
"github.com/prometheus/client_golang/prometheus"
)
// RemoteKV is the remote KV based DNS checker. It answers DNS check requests,
// records information about them in a local cache and in the remote KV
// storage, and serves that information through its HTTP handler.
type RemoteKV struct {
// logger is used for debug logging and is passed to the error collector
// helper when reporting non-critical errors.
logger *slog.Logger
// mu protects cache. Don't use an RWMutex here, since it is expected that
// there are about as many reads as there are writes.
mu *sync.Mutex
// cache stores the JSON-encoded information records by their random IDs.
cache *cache.Cache
// kv is the remote storage into which the information records are written
// on DNS check requests and from which they are read on cache misses.
kv remotekv.Interface
// messages is used to construct the responses to A and AAAA requests.
messages *dnsmsg.Constructor
// errColl collects non-critical errors, such as remote KV failures and HTTP
// response write failures.
errColl errcoll.Interface
// domains are the domain names used to detect DNS check requests.
domains []string
// nodeLocation is the location of this node. It is put into every
// information record.
nodeLocation string
// nodeName is the name of this node. It is put into every information
// record.
nodeName string
// ipv4 are the addresses used in responses to A requests.
ipv4 []netip.Addr
// ipv6 are the addresses used in responses to AAAA requests.
ipv6 []netip.Addr
}
// RemoteKVConfig is the configuration structure for remote KV based DNS
// checker. All fields must be non-empty.
type RemoteKVConfig struct {
// Logger is used to log the operation of the DNS checker.
Logger *slog.Logger
// Messages is the message constructor used to create DNS responses with
// IPv4 and IPv6 IPs.
Messages *dnsmsg.Constructor
// RemoteKV for DNS server checking. Information records are written to it
// and read from it by their random IDs.
RemoteKV remotekv.Interface
// ErrColl is the error collector that is used to collect non-critical
// errors.
ErrColl errcoll.Interface
// Domains are the lower-cased domain names used to detect DNS check requests.
// NewRemoteKV retains this slice as is, without copying it.
Domains []string
// NodeLocation is the location of this server node.
NodeLocation string
// NodeName is the name of this server node.
NodeName string
// IPv4 are the IPv4 addresses to respond with to A requests. NewRemoteKV
// stores a copy of this slice.
IPv4 []netip.Addr
// IPv6 are the IPv6 addresses to respond with to AAAA requests. NewRemoteKV
// stores a copy of this slice.
IPv6 []netip.Addr
}
// Default cache parameters. Both are passed to cache.New when the checker is
// created.
//
// TODO(ameshkov): Consider making configurable.
const (
// defaultCacheExp is the default expiration time of the cached information
// records.
defaultCacheExp = 1 * time.Minute
// defaultCacheGC is the interval at which expired items are purged from the
// cache.
defaultCacheGC = 1 * time.Minute
)
// NewRemoteKV returns a new remote KV based DNS checker built from c, which
// must not be nil. The returned checker owns a fresh mutex and a fresh cache
// with the default parameters. c.IPv4 and c.IPv6 are cloned, while c.Domains
// is retained as is.
func NewRemoteKV(c *RemoteKVConfig) (dc *RemoteKV) {
	dc = &RemoteKV{
		// Dependencies taken from the configuration.
		logger:   c.Logger,
		kv:       c.RemoteKV,
		messages: c.Messages,
		errColl:  c.ErrColl,

		// State owned by the checker itself.
		mu:    new(sync.Mutex),
		cache: cache.New(defaultCacheExp, defaultCacheGC),

		// Static node and response data.
		domains:      c.Domains,
		nodeLocation: c.NodeLocation,
		nodeName:     c.NodeName,
		ipv4:         slices.Clone(c.IPv4),
		ipv6:         slices.Clone(c.IPv6),
	}

	return dc
}
// type check: *RemoteKV must implement Interface.
var _ Interface = (*RemoteKV)(nil)
// Check implements the Interface interface for *RemoteKV. If the request host
// is not a DNS check domain, both resp and err are nil. If it is a DNS check
// domain without a random ID, only the response is built. Otherwise, the
// information record about the request is additionally saved to the local
// cache and to the remote KV; a failure to write to the remote KV is reported
// to the error collector and does not fail the request.
func (dc *RemoteKV) Check(
	ctx context.Context,
	req *dns.Msg,
	ri *agd.RequestInfo,
) (resp *dns.Msg, err error) {
	var (
		matched  bool
		randomID string
	)

	// Update the metrics once the final values of matched and err are known.
	// The request counter only includes requests for DNS check domains.
	defer func() {
		incErrMetrics("dns", err)

		if matched {
			metrics.DNSCheckRequestTotal.With(prometheus.Labels{
				"type":  "dns",
				"valid": metrics.BoolString(err == nil),
			}).Inc()
		}
	}()

	randomID, matched, err = randomIDFromDomain(ri.Host, dc.domains)
	switch {
	case err != nil:
		// Don't wrap the error, because it's informative enough as is.
		return nil, err
	case !matched:
		// Not a dnscheck domain, just ignore the request.
		return nil, nil
	case randomID == "":
		// Nothing to record without an ID, so only respond.
		return dc.resp(ri, req)
	}

	data, err := json.Marshal(dc.newInfo(ri))
	if err != nil {
		return nil, fmt.Errorf("encoding value for key %q for remote kv: %w", randomID, err)
	}

	dc.addToCache(randomID, data)

	// A remote KV failure is not critical, since the record is already in the
	// local cache, so only report it.
	if setErr := dc.kv.Set(ctx, randomID, data); setErr != nil {
		errcoll.Collect(ctx, dc.errColl, dc.logger, "remote kv setting", setErr)
	}

	return dc.resp(ri, req)
}
// addToCache stores inf in the cache under the key randomID with the cache's
// default expiration. It takes dc.mu for the duration of the write, so it is
// safe for concurrent use.
func (dc *RemoteKV) addToCache(randomID string, inf []byte) {
	dc.mu.Lock()
	defer dc.mu.Unlock()

	dc.cache.Set(randomID, inf, cache.DefaultExpiration)
}
// serverType is a type for the enum of server types in the DNS checker HTTP
// API. It is the type of the "server_type" field of the information records.
type serverType string
// Valid serverType values.
const (
// serverTypePrivate is used for server groups that have profiles enabled.
serverTypePrivate serverType = "private"
// serverTypePublic is used for server groups that don't have profiles
// enabled.
serverTypePublic serverType = "public"
)
// newInfo builds the information record describing this node, the server that
// has received the request, and the client. The server type is private when
// the server group has profiles enabled and public otherwise. The profile and
// device IDs are only filled in when the request has a profile. ri must not
// be nil.
func (dc *RemoteKV) newInfo(ri *agd.RequestInfo) (inf *info) {
	grp := ri.ServerGroup

	inf = &info{
		ClientIP:        ri.RemoteIP,
		ServerGroupName: grp.Name,
		ServerName:      ri.Server,
		ServerType:      serverTypePublic,
		Protocol:        ri.Proto.String(),
		NodeLocation:    dc.nodeLocation,
		NodeName:        dc.nodeName,
	}

	if grp.ProfilesEnabled {
		inf.ServerType = serverTypePrivate
	}

	prof, dev := ri.DeviceData()
	if prof == nil {
		return inf
	}

	inf.ProfileID = prof.ID
	inf.DeviceID = dev.ID

	return inf
}
// resp builds the response to the DNS check request req. A requests are
// answered with the configured IPv4 addresses and AAAA requests with the
// configured IPv6 addresses. All other question types get an empty response
// with the success response code.
//
// TODO(e.burkov): Inspect the reason for using different message constructors
// for different DNS types, and consider using only one of them.
func (dc *RemoteKV) resp(ri *agd.RequestInfo, req *dns.Msg) (resp *dns.Msg, err error) {
	switch ri.QType {
	case dns.TypeA:
		return dc.messages.NewRespIP(req, dc.ipv4...)
	case dns.TypeAAAA:
		return dc.messages.NewRespIP(req, dc.ipv6...)
	default:
		return ri.Messages.NewRespRCode(req, dns.RcodeSuccess), nil
	}
}
// type check: *RemoteKV must implement http.Handler.
var _ http.Handler = (*RemoteKV)(nil)
// ServeHTTP implements the http.Handler interface for *RemoteKV. Only the
// client DNS check API path is served; requests to any other path get a
// 404 Not Found response.
//
// TODO(a.garipov): Consider using the websvc logger once it switches to
// log/slog.
func (dc *RemoteKV) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// TODO(a.garipov): Put this into constant here and in package dnssvc.
	if r.URL.Path != "/dnscheck/test" {
		http.NotFound(w, r)

		return
	}

	dc.serveCheckTest(r.Context(), w, r)
}
// serveCheckTest serves the client DNS check API. The random ID is extracted
// from the host of the request and is used to look up the information record,
// which is then written to w as JSON.
//
// The response statuses are:
//
//   - 404 Not Found if the host is invalid, is not a DNS check domain, has no
//     random ID, or if there is no record for the ID;
//   - 429 Too Many Requests if the remote KV lookup has been rate limited;
//   - 500 Internal Server Error on any other lookup error.
//
// TODO(a.garipov): Refactor this and other HTTP handlers to return wrapped
// errors and centralize the error handling.
func (dc *RemoteKV) serveCheckTest(ctx context.Context, w http.ResponseWriter, r *http.Request) {
	l := dc.logger.With("raddr", r.RemoteAddr)

	host, err := netutil.SplitHost(r.Host)
	if err != nil {
		l.DebugContext(ctx, "bad host", "hostport", r.Host, slogutil.KeyError, err)
		http.NotFound(w, r)

		return
	}

	randomID, matched, err := randomIDFromDomain(host, dc.domains)
	switch {
	case err != nil:
		l.DebugContext(ctx, "bad request", "host", host, slogutil.KeyError, err)
		http.NotFound(w, r)

		return
	case !matched || randomID == "":
		// We expect dnscheck requests to have a unique ID in the domain name.
		//
		// Don't log the error here, since it is always nil in this branch.
		l.DebugContext(ctx, "bad domain", "host", host)
		http.NotFound(w, r)

		return
	}

	inf, ok, err := dc.info(ctx, randomID)
	switch {
	// TODO(s.chzhen): Use error interface instead of error value.
	case errors.Is(err, consulkv.ErrRateLimited):
		http.Error(w, err.Error(), http.StatusTooManyRequests)

		return
	case err != nil:
		l.DebugContext(ctx, "getting info", "random_id", randomID, slogutil.KeyError, err)
		http.Error(w, err.Error(), http.StatusInternalServerError)

		return
	case !ok:
		// Don't log the error here, since it is always nil in this branch.
		l.DebugContext(ctx, "no info", "random_id", randomID)
		http.NotFound(w, r)

		return
	}

	// Headers must be set before the first write to w.
	h := w.Header()
	h.Set(httphdr.ContentType, agdhttp.HdrValApplicationJSON)
	h.Set(httphdr.AccessControlAllowOrigin, agdhttp.HdrValWildcard)

	_, err = w.Write(inf)
	if err != nil {
		errcoll.Collect(ctx, dc.errColl, dc.logger, "http resp write", err)
	}
}
// info returns an information record by the random request ID. The local
// cache is consulted first, and the remote KV is only queried on a cache miss.
// ok is true if the record has been found in either of them. A non-nil err
// comes from the remote KV; it is reported to the error collector and is
// returned annotated with the "getting from remote kv" prefix.
//
// The HTTP request and error metrics are updated on every call.
func (dc *RemoteKV) info(ctx context.Context, randomID string) (inf []byte, ok bool, err error) {
	// Deferred calls run in the reverse order, so the metrics are updated after
	// the error has been annotated.
	defer func() {
		metrics.DNSCheckRequestTotal.With(prometheus.Labels{
			"type":  "http",
			"valid": metrics.BoolString(err == nil),
		}).Inc()

		incErrMetrics("http", err)
	}()

	defer func() { err = errors.Annotate(err, "getting from remote kv: %w") }()

	// Only hold the mutex for the cache lookup itself. Holding it during the
	// remote KV request would serialize all HTTP lookups and block the DNS
	// request path in addToCache for the duration of a remote round trip. The
	// remote KV is already used without the mutex when records are written.
	dc.mu.Lock()
	infoVal, ok := dc.cache.Get(randomID)
	dc.mu.Unlock()

	if ok {
		// Only byte slices are ever put into the cache by addToCache.
		return infoVal.([]byte), true, nil
	}

	inf, ok, err = dc.kv.Get(ctx, randomID)
	if err != nil {
		errcoll.Collect(ctx, dc.errColl, dc.logger, "remote kv getting", err)

		// Don't wrap the error, as it will get annotated.
		return nil, false, err
	}

	return inf, ok, nil
}
// info is a single DNS client and server information record. It is encoded as
// JSON before being saved to the cache and to the remote KV.
type info struct {
// ClientIP is the remote IP address of the DNS request.
ClientIP netip.Addr `json:"client_ip"`
// DeviceID is the ID of the device of the request. It is only set when the
// request has a profile.
DeviceID agd.DeviceID `json:"device_id"`
// ProfileID is the ID of the profile of the request. It is only set when
// the request has a profile.
ProfileID agd.ProfileID `json:"profile_id"`
// ServerGroupName is the name of the server group of the request.
ServerGroupName agd.ServerGroupName `json:"server_group_name"`
// ServerName is the name of the server that has received the request.
ServerName agd.ServerName `json:"server_name"`
// ServerType is private for server groups with profiles enabled and public
// otherwise.
ServerType serverType `json:"server_type"`
// Protocol is the string form of the protocol of the request.
Protocol string `json:"protocol"`
// NodeLocation is the location of the node that has processed the request.
NodeLocation string `json:"node_location"`
// NodeName is the name of the node that has processed the request.
NodeName string `json:"node_name"`
}