mirror of
https://github.com/AdguardTeam/AdGuardDNS.git
synced 2025-02-20 11:23:36 +08:00
354 lines
8.9 KiB
Go
354 lines
8.9 KiB
Go
package dnscheck
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/netip"
|
|
"slices"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
|
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
|
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
|
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
|
|
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
|
"github.com/AdguardTeam/AdGuardDNS/internal/remotekv"
|
|
"github.com/AdguardTeam/AdGuardDNS/internal/remotekv/consulkv"
|
|
"github.com/AdguardTeam/golibs/errors"
|
|
"github.com/AdguardTeam/golibs/httphdr"
|
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
|
"github.com/AdguardTeam/golibs/netutil"
|
|
"github.com/miekg/dns"
|
|
cache "github.com/patrickmn/go-cache"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
)
|
|
|
|
// RemoteKV is the RemoteKV KV based DNS checker.
|
|
type RemoteKV struct {
|
|
logger *slog.Logger
|
|
|
|
// mu protects cache. Don't use an RWMutex here, since it is expected that
|
|
// there are about as many reads as there are writes.
|
|
mu *sync.Mutex
|
|
cache *cache.Cache
|
|
|
|
kv remotekv.Interface
|
|
messages *dnsmsg.Constructor
|
|
|
|
errColl errcoll.Interface
|
|
|
|
domains []string
|
|
nodeLocation string
|
|
nodeName string
|
|
|
|
ipv4 []netip.Addr
|
|
ipv6 []netip.Addr
|
|
}
|
|
|
|
// RemoteKVConfig is the configuration structure for remote KV based DNS
|
|
// checker. All fields must be non-empty.
|
|
type RemoteKVConfig struct {
|
|
// Logger is used to log the operation of the DNS checker.
|
|
Logger *slog.Logger
|
|
|
|
// Messages is the message constructor used to create DNS responses with
|
|
// IPv4 and IPv6 IPs.
|
|
Messages *dnsmsg.Constructor
|
|
|
|
// RemoteKV for DNS server checking.
|
|
RemoteKV remotekv.Interface
|
|
|
|
// ErrColl is the error collector that is used to collect non-critical
|
|
// errors.
|
|
ErrColl errcoll.Interface
|
|
|
|
// Domains are the lower-cased domain names used to detect DNS check requests.
|
|
Domains []string
|
|
|
|
// NodeLocation is the location of this server node.
|
|
NodeLocation string
|
|
|
|
// NodeName is the name of this server node.
|
|
NodeName string
|
|
|
|
// IPv4 are the IPv4 addresses to respond with to A requests.
|
|
IPv4 []netip.Addr
|
|
|
|
// IPv6 are the IPv6 addresses to respond with to AAAA requests.
|
|
IPv6 []netip.Addr
|
|
}
|
|
|
|
// Default parameters of the local information record cache, passed to
// cache.New in NewRemoteKV: defaultCacheExp is the default expiration of a
// record and defaultCacheGC is the interval between cleanups of expired
// records.
//
// TODO(ameshkov): Consider making configurable.
const (
	defaultCacheExp = time.Minute
	defaultCacheGC  = time.Minute
)
|
|
|
|
// NewRemoteKV creates a new remote KV based DNS checker. c must be non-nil.
|
|
func NewRemoteKV(c *RemoteKVConfig) (dc *RemoteKV) {
|
|
return &RemoteKV{
|
|
logger: c.Logger,
|
|
mu: &sync.Mutex{},
|
|
cache: cache.New(defaultCacheExp, defaultCacheGC),
|
|
kv: c.RemoteKV,
|
|
messages: c.Messages,
|
|
errColl: c.ErrColl,
|
|
domains: c.Domains,
|
|
nodeLocation: c.NodeLocation,
|
|
nodeName: c.NodeName,
|
|
ipv4: slices.Clone(c.IPv4),
|
|
ipv6: slices.Clone(c.IPv6),
|
|
}
|
|
}
|
|
|
|
// type check
|
|
var _ Interface = (*RemoteKV)(nil)
|
|
|
|
// Check implements the Interface interface for *RemoteKV.
|
|
func (dc *RemoteKV) Check(
|
|
ctx context.Context,
|
|
req *dns.Msg,
|
|
ri *agd.RequestInfo,
|
|
) (resp *dns.Msg, err error) {
|
|
var matched bool
|
|
defer func() {
|
|
incErrMetrics("dns", err)
|
|
|
|
if !matched {
|
|
return
|
|
}
|
|
|
|
metrics.DNSCheckRequestTotal.With(prometheus.Labels{
|
|
"type": "dns",
|
|
"valid": metrics.BoolString(err == nil),
|
|
}).Inc()
|
|
}()
|
|
|
|
var randomID string
|
|
randomID, matched, err = randomIDFromDomain(ri.Host, dc.domains)
|
|
if err != nil {
|
|
// Don't wrap the error, because it's informative enough as is.
|
|
return nil, err
|
|
} else if !matched {
|
|
// Not a dnscheck domain, just ignore the request.
|
|
return nil, nil
|
|
} else if randomID == "" {
|
|
return dc.resp(ri, req)
|
|
}
|
|
|
|
inf := dc.newInfo(ri)
|
|
b, err := json.Marshal(inf)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("encoding value for key %q for remote kv: %w", randomID, err)
|
|
}
|
|
|
|
dc.addToCache(randomID, b)
|
|
|
|
err = dc.kv.Set(ctx, randomID, b)
|
|
if err != nil {
|
|
errcoll.Collect(ctx, dc.errColl, dc.logger, "remote kv setting", err)
|
|
}
|
|
|
|
return dc.resp(ri, req)
|
|
}
|
|
|
|
// addToCache adds inf into cache using randomID as key. It's safe for
|
|
// concurrent use.
|
|
func (dc *RemoteKV) addToCache(randomID string, inf []byte) {
|
|
dc.mu.Lock()
|
|
defer dc.mu.Unlock()
|
|
|
|
dc.cache.SetDefault(randomID, inf)
|
|
}
|
|
|
|
// serverType is the kind of a server as reported by the DNS checker HTTP API.
type serverType string

// Valid serverType values.  A server is private when its server group has
// profiles enabled and public otherwise.
const (
	serverTypePublic  serverType = "public"
	serverTypePrivate serverType = "private"
)
|
|
|
|
// newInfo returns an information record with all available data about the
|
|
// server and the request. ri must not be nil.
|
|
func (dc *RemoteKV) newInfo(ri *agd.RequestInfo) (inf *info) {
|
|
g := ri.ServerGroup
|
|
|
|
srvType := serverTypePublic
|
|
if g.ProfilesEnabled {
|
|
srvType = serverTypePrivate
|
|
}
|
|
|
|
inf = &info{
|
|
ServerGroupName: g.Name,
|
|
ServerName: ri.Server,
|
|
ServerType: srvType,
|
|
|
|
Protocol: ri.Proto.String(),
|
|
NodeLocation: dc.nodeLocation,
|
|
NodeName: dc.nodeName,
|
|
|
|
ClientIP: ri.RemoteIP,
|
|
}
|
|
|
|
if p, d := ri.DeviceData(); p != nil {
|
|
inf.ProfileID = p.ID
|
|
inf.DeviceID = d.ID
|
|
}
|
|
|
|
return inf
|
|
}
|
|
|
|
// resp returns the corresponding response.
|
|
//
|
|
// TODO(e.burkov): Inspect the reason for using different message constructors
|
|
// for different DNS types, and consider using only one of them.
|
|
func (dc *RemoteKV) resp(ri *agd.RequestInfo, req *dns.Msg) (resp *dns.Msg, err error) {
|
|
qt := ri.QType
|
|
|
|
if qt != dns.TypeA && qt != dns.TypeAAAA {
|
|
return ri.Messages.NewRespRCode(req, dns.RcodeSuccess), nil
|
|
}
|
|
|
|
if qt == dns.TypeA {
|
|
return dc.messages.NewRespIP(req, dc.ipv4...)
|
|
}
|
|
|
|
return dc.messages.NewRespIP(req, dc.ipv6...)
|
|
}
|
|
|
|
// type check
|
|
var _ http.Handler = (*RemoteKV)(nil)
|
|
|
|
// ServeHTTP implements the http.Handler interface for *RemoteKV.
|
|
//
|
|
// TODO(a.garipov): Consider using the websvc logger once it switches to
|
|
// log/slog.
|
|
func (dc *RemoteKV) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
// TODO(a.garipov): Put this into constant here and in package dnssvc.
|
|
if r.URL.Path == "/dnscheck/test" {
|
|
dc.serveCheckTest(r.Context(), w, r)
|
|
|
|
return
|
|
}
|
|
|
|
http.NotFound(w, r)
|
|
}
|
|
|
|
// serveCheckTest serves the client DNS check API.
|
|
//
|
|
// TODO(a.garipov): Refactor this and other HTTP handlers to return wrapped
|
|
// errors and centralize the error handling.
|
|
func (dc *RemoteKV) serveCheckTest(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
|
l := dc.logger.With("raddr", r.RemoteAddr)
|
|
|
|
host, err := netutil.SplitHost(r.Host)
|
|
if err != nil {
|
|
l.DebugContext(ctx, "bad host", "hostport", r.Host, slogutil.KeyError, err)
|
|
|
|
http.NotFound(w, r)
|
|
|
|
return
|
|
}
|
|
|
|
randomID, matched, err := randomIDFromDomain(host, dc.domains)
|
|
if err != nil {
|
|
l.DebugContext(ctx, "bad request", "host", host, slogutil.KeyError, err)
|
|
|
|
http.NotFound(w, r)
|
|
|
|
return
|
|
} else if !matched || randomID == "" {
|
|
// We expect dnscheck requests to have a unique ID in the domain name.
|
|
l.DebugContext(ctx, "bad domain", "host", host, slogutil.KeyError, err)
|
|
|
|
http.NotFound(w, r)
|
|
|
|
return
|
|
}
|
|
|
|
inf, ok, err := dc.info(ctx, randomID)
|
|
// TODO(s.chzhen): Use error interface instead of error value.
|
|
if errors.Is(err, consulkv.ErrRateLimited) {
|
|
http.Error(w, err.Error(), http.StatusTooManyRequests)
|
|
|
|
return
|
|
} else if err != nil {
|
|
l.DebugContext(ctx, "getting info", "random_id", randomID, slogutil.KeyError, err)
|
|
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
|
|
return
|
|
} else if !ok {
|
|
l.DebugContext(ctx, "no info", "random_id", randomID, slogutil.KeyError, err)
|
|
|
|
http.NotFound(w, r)
|
|
|
|
return
|
|
}
|
|
|
|
h := w.Header()
|
|
h.Set(httphdr.ContentType, agdhttp.HdrValApplicationJSON)
|
|
h.Set(httphdr.AccessControlAllowOrigin, agdhttp.HdrValWildcard)
|
|
|
|
_, err = w.Write(inf)
|
|
if err != nil {
|
|
errcoll.Collect(ctx, dc.errColl, dc.logger, "http resp write", err)
|
|
}
|
|
}
|
|
|
|
// info returns an information record by the random request ID.
|
|
func (dc *RemoteKV) info(ctx context.Context, randomID string) (inf []byte, ok bool, err error) {
|
|
defer func() {
|
|
metrics.DNSCheckRequestTotal.With(prometheus.Labels{
|
|
"type": "http",
|
|
"valid": metrics.BoolString(err == nil),
|
|
}).Inc()
|
|
|
|
incErrMetrics("http", err)
|
|
}()
|
|
|
|
defer func() { err = errors.Annotate(err, "getting from remote kv: %w") }()
|
|
|
|
dc.mu.Lock()
|
|
defer dc.mu.Unlock()
|
|
|
|
infoVal, ok := dc.cache.Get(randomID)
|
|
if ok {
|
|
return infoVal.([]byte), true, nil
|
|
}
|
|
|
|
inf, ok, err = dc.kv.Get(ctx, randomID)
|
|
if err != nil {
|
|
errcoll.Collect(ctx, dc.errColl, dc.logger, "remote kv getting", err)
|
|
|
|
// Don't wrap the error, as it will get annotated.
|
|
return nil, false, err
|
|
}
|
|
|
|
return inf, ok, nil
|
|
}
|
|
|
|
// info is a single DNS client and server information record.
|
|
type info struct {
|
|
ClientIP netip.Addr `json:"client_ip"`
|
|
|
|
DeviceID agd.DeviceID `json:"device_id"`
|
|
ProfileID agd.ProfileID `json:"profile_id"`
|
|
ServerGroupName agd.ServerGroupName `json:"server_group_name"`
|
|
ServerName agd.ServerName `json:"server_name"`
|
|
ServerType serverType `json:"server_type"`
|
|
|
|
Protocol string `json:"protocol"`
|
|
NodeLocation string `json:"node_location"`
|
|
NodeName string `json:"node_name"`
|
|
}
|