561 lines
15 KiB
Go
Raw Normal View History

2022-08-26 14:18:35 +03:00
// Package dnssvc contains AdGuard DNS's main DNS services.
//
// Prefer to keep all mentions of module dnsserver within this package and
// package agd.
package dnssvc
import (
"context"
"fmt"
"net/http"
2023-08-08 18:31:48 +03:00
"time"
2022-08-26 14:18:35 +03:00
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
2023-06-11 12:58:40 +03:00
"github.com/AdguardTeam/AdGuardDNS/internal/connlimiter"
2022-08-26 14:18:35 +03:00
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
2023-03-18 17:11:10 +03:00
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
2022-08-26 14:18:35 +03:00
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
2023-06-11 12:58:40 +03:00
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
2022-08-26 14:18:35 +03:00
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/AdGuardDNS/internal/rulestat"
"github.com/AdguardTeam/golibs/errors"
"github.com/miekg/dns"
)
// DNS Service Definition
//
// Note that the definition of a “server” differs between AdGuard DNS and the
// dnsserver module. In the latter, a server is a listener bound to a single
// address, while in AGDNS, it's a collection of these listeners.
// Config is the configuration of the AdGuard DNS service.
type Config struct {
// Messages is the message constructor used to create blocked and other
// messages for this DNS service.
Messages *dnsmsg.Constructor
2023-06-11 12:58:40 +03:00
// ControlConf is the configuration of socket options.
ControlConf *netext.ControlConfig
// ConnLimiter, if not nil, is used to limit the number of simultaneously
// active stream-connections.
ConnLimiter *connlimiter.Limiter
// SafeBrowsing is the safe browsing TXT hash matcher.
SafeBrowsing filter.HashMatcher
2022-08-26 14:18:35 +03:00
// BillStat is used to collect billing statistics.
BillStat billstat.Recorder
// ProfileDB is the AdGuard DNS profile database used to fetch data about
// profiles, devices, and so on.
2023-06-11 12:58:40 +03:00
ProfileDB profiledb.Interface
2022-08-26 14:18:35 +03:00
// DNSCheck is used by clients to check if they use AdGuard DNS.
DNSCheck dnscheck.Interface
// NonDNS is the handler for non-DNS HTTP requests.
NonDNS http.Handler
// DNSDB is used to update anonymous statistics about DNS queries.
DNSDB dnsdb.Interface
// ErrColl is the error collector that is used to collect critical and
// non-critical errors.
ErrColl agd.ErrorCollector
// FilterStorage is the storage of all filters.
FilterStorage filter.Storage
// GeoIP is the GeoIP database used to detect geographic data about IP
// addresses in requests and responses.
GeoIP geoip.Interface
// QueryLog is used to write the logs into.
QueryLog querylog.Interface
// RuleStat is used to collect statistics about matched filtering rules and
// rule lists.
RuleStat rulestat.Interface
// NewListener, when set, is used instead of the package-level function
// NewListener when creating a DNS listener.
//
2023-03-18 17:11:10 +03:00
// TODO(a.garipov): The handler and service logic should really not be
2023-06-11 12:58:40 +03:00
// intertwined in this way. See AGDNS-1327.
2022-08-26 14:18:35 +03:00
NewListener NewListenerFunc
// Handler is used as the main DNS handler instead of a simple forwarder.
// It must not be nil.
//
// TODO(a.garipov): Think of a better way to make the DNS server logic more
// testable.
Handler dnsserver.Handler
// RateLimit is used for allow or decline requests.
RateLimit ratelimit.Interface
// FilteringGroups are the DNS filtering groups. Each element must be
// non-nil.
FilteringGroups map[agd.FilteringGroupID]*agd.FilteringGroup
// ServerGroups are the DNS server groups. Each element must be non-nil.
ServerGroups []*agd.ServerGroup
// CacheSize is the size of the DNS cache for domain names that don't
// support ECS.
2023-08-08 18:31:48 +03:00
//
// TODO(a.garipov): Extract this and following fields to cache configuration
// struct.
2022-08-26 14:18:35 +03:00
CacheSize int
// ECSCacheSize is the size of the DNS cache for domain names that support
// ECS.
ECSCacheSize int
2023-08-08 18:31:48 +03:00
// CacheMinTTL is the minimum supported TTL for cache items. This setting
// is used when UseCacheTTLOverride set to true.
CacheMinTTL time.Duration
// UseCacheTTLOverride shows if the TTL overrides logic should be used.
UseCacheTTLOverride bool
2022-08-26 14:18:35 +03:00
// UseECSCache shows if the EDNS Client Subnet (ECS) aware cache should be
// used.
UseECSCache bool
2023-02-03 15:27:58 +03:00
// ResearchMetrics controls whether research metrics are enabled or not.
// This is a set of metrics that we may need temporary, so its collection is
// controlled by a separate setting.
ResearchMetrics bool
2023-08-08 18:31:48 +03:00
// ResearchLogs controls whether logging of additional info for research
// purposes is enabled. These logs may be overly verbose and are only
// required temporary, that's why it's controlled by a separate setting.
// This setting will only be used when ResearchMetrics is also set to true.
ResearchLogs bool
2022-08-26 14:18:35 +03:00
}
// New returns a new DNS service.
func New(c *Config) (svc *Service, err error) {
// Use either the configured listener initializer or the default one.
newListener := c.NewListener
if newListener == nil {
newListener = NewListener
}
// Configure the end of the request handling pipeline.
handler := c.Handler
if handler == nil {
return nil, errors.Error("handler in config must not be nil")
}
// Configure the pre-upstream middleware common for all servers of all
// groups.
preUps := &preUpstreamMw{
2023-08-08 18:31:48 +03:00
db: c.DNSDB,
geoIP: c.GeoIP,
cacheSize: c.CacheSize,
ecsCacheSize: c.ECSCacheSize,
useECSCache: c.UseECSCache,
cacheMinTTL: c.CacheMinTTL,
useCacheTTLOverride: c.UseCacheTTLOverride,
2022-08-26 14:18:35 +03:00
}
handler = preUps.Wrap(handler)
// Configure the service itself.
groups := make([]*serverGroup, len(c.ServerGroups))
svc = &Service{
messages: c.Messages,
billStat: c.BillStat,
errColl: c.ErrColl,
fltStrg: c.FilterStorage,
geoIP: c.GeoIP,
queryLog: c.QueryLog,
ruleStat: c.RuleStat,
groups: groups,
2023-02-03 15:27:58 +03:00
researchMetrics: c.ResearchMetrics,
2023-08-08 18:31:48 +03:00
researchLog: c.ResearchLogs,
2022-08-26 14:18:35 +03:00
}
for i, srvGrp := range c.ServerGroups {
// The Filtering Middlewares
//
// These are middlewares common to all filtering and server groups.
// They change the flow of request handling, so they are separated.
//
// TODO(a.garipov): Merge with some other middlewares.
dnsHdlr := dnsserver.WithMiddlewares(
handler,
2022-12-29 15:36:26 +03:00
&preServiceMw{
2023-06-11 12:58:40 +03:00
messages: c.Messages,
hashMatcher: c.SafeBrowsing,
checker: c.DNSCheck,
2022-08-26 14:18:35 +03:00
},
svc,
)
var servers []*server
servers, err = newServers(c, srvGrp, dnsHdlr, newListener)
if err != nil {
return nil, fmt.Errorf("group %q: %w", srvGrp.Name, err)
}
groups[i] = &serverGroup{
name: srvGrp.Name,
servers: servers,
}
}
return svc, nil
}
// server is a group of listeners that share a name and a handler.
type server struct {
	// name is the name of the server. Service.Handle uses it to find the
	// server within its group.
	name agd.ServerName

	// handler is the DNS handler of this server. Service.Handle calls its
	// ServeDNS method directly.
	handler dnsserver.Handler

	// listeners are the listeners of this server. They are started by
	// Service.Start and shut down by Service.Shutdown.
	listeners []*listener
}
// serverGroup is a named group of servers.
type serverGroup struct {
	// name is the name of the group. Service.Handle uses it to find the
	// group, and it is included into listener start and shutdown errors.
	name agd.ServerGroupName

	// servers are the servers that belong to this group.
	servers []*server
}
// type check
//
// This makes the compiler verify that *Service implements the agd.Service
// interface.
var _ agd.Service = (*Service)(nil)
// Service is the main DNS service of AdGuard DNS.
type Service struct {
2023-08-08 18:31:48 +03:00
messages *dnsmsg.Constructor
billStat billstat.Recorder
errColl agd.ErrorCollector
fltStrg filter.Storage
geoIP geoip.Interface
queryLog querylog.Interface
ruleStat rulestat.Interface
groups []*serverGroup
// researchMetrics enables reporting metrics that may be needed for research
// purposes.
2023-02-03 15:27:58 +03:00
researchMetrics bool
2023-08-08 18:31:48 +03:00
// researchLog enables logging of additional information that may be needed
// for research purposes. It will only be used when researchMetrics is set
// to true.
researchLog bool
2022-08-26 14:18:35 +03:00
}
// mustStartListener starts l using a background context. If l fails to start,
// it panics with an error that names the server group grp, the server srv, and
// the listener itself.
func mustStartListener(grp agd.ServerGroupName, srv agd.ServerName, l *listener) {
	if err := l.Start(context.Background()); err != nil {
		panic(fmt.Errorf("group %q: server %q: starting %q: %w", grp, srv, l.name, err))
	}
}
// Start implements the agd.Service interface for *Service. It starts every
// listener of every server in every group. The returned error is always nil:
// a listener that cannot be started makes Start panic instead.
func (svc *Service) Start() (err error) {
	for _, grp := range svc.groups {
		for _, srv := range grp.servers {
			for _, lsn := range srv.listeners {
				// A single DNS listener that fails to start is a fatal error
				// for the whole service.
				mustStartListener(grp.name, srv.name, lsn)
			}
		}
	}

	return nil
}
// shutdownListeners is a helper function that shuts down all listeners of a
// server. A failure to shut down one listener does not prevent the remaining
// ones from being shut down: every listener is attempted, and all failures are
// joined into the returned error, each annotated with the name of its
// listener. err is nil if listeners is empty or if all listeners have been
// shut down successfully.
func shutdownListeners(ctx context.Context, listeners []*listener) (err error) {
	var errs []error
	for _, l := range listeners {
		shutErr := l.Shutdown(ctx)
		if shutErr != nil {
			// Keep going, so that one misbehaving listener doesn't leave the
			// rest of them running.
			errs = append(errs, fmt.Errorf("shutting down listener %q: %w", l.name, shutErr))
		}
	}

	// Join returns nil when there are no errors to join.
	return errors.Join(errs...)
}
// Shutdown implements the agd.Service interface for *Service.
func (svc *Service) Shutdown(ctx context.Context) (err error) {
var errs []error
for _, g := range svc.groups {
for _, s := range g.servers {
err = shutdownListeners(ctx, s.listeners)
if err != nil {
2023-03-18 17:11:10 +03:00
errs = append(errs, fmt.Errorf("group %q: server %q: %w", g.name, s.name, err))
2022-08-26 14:18:35 +03:00
}
}
}
2023-03-18 17:11:10 +03:00
err = errors.Join(errs...)
if err != nil {
return fmt.Errorf("shutting down dns service: %w", err)
2022-08-26 14:18:35 +03:00
}
return nil
}
// Handle is a simple helper to test the handling of DNS requests. It looks up
// the server named srvName in the first server group named grpName and passes
// ctx, rw, and r to the ServeDNS method of that server's handler. If there is
// no such group or no such server, it returns an error without handling the
// request.
func (svc *Service) Handle(
	ctx context.Context,
	grpName agd.ServerGroupName,
	srvName agd.ServerName,
	rw dnsserver.ResponseWriter,
	r *dns.Msg,
) (err error) {
	for _, grp := range svc.groups {
		if grp.name != grpName {
			continue
		}

		// Only the first group with the matching name is searched.
		for _, srv := range grp.servers {
			if srv.name == srvName {
				return srv.handler.ServeDNS(ctx, rw, r)
			}
		}

		return errors.Error("no such server")
	}

	return errors.Error("no such server group")
}
// Listener is a type alias for dnsserver.Server to make internal naming more
// consistent. In module dnsserver, a server is a listener bound to a single
// address, while an AdGuard DNS server is a collection of such listeners.
type Listener = dnsserver.Server
// NewListenerFunc is the type for DNS listener constructors.
type NewListenerFunc func(
s *agd.Server,
name string,
2023-03-18 17:11:10 +03:00
addr string,
2022-08-26 14:18:35 +03:00
h dnsserver.Handler,
nonDNS http.Handler,
errColl agd.ErrorCollector,
2023-03-18 17:11:10 +03:00
lc netext.ListenConfig,
2022-08-26 14:18:35 +03:00
) (l Listener, err error)
// listener is a Listener along with some of its associated data.
type listener struct {
	// Listener is the underlying DNS listener. It is embedded, so its methods
	// are available on listener directly.
	Listener

	// name is the name of the listener. It is used in the start and shutdown
	// error messages.
	name string
}
// listenerName returns a standard name for a listener.
2023-03-18 17:11:10 +03:00
func listenerName(srvName agd.ServerName, addr string, proto agd.Protocol) (name string) {
return fmt.Sprintf("%s/%s/%s", srvName, proto, addr)
2022-08-26 14:18:35 +03:00
}
// NewListener returns a new Listener. It is the default DNS listener
// constructor.
func NewListener(
s *agd.Server,
name string,
2023-03-18 17:11:10 +03:00
addr string,
2022-08-26 14:18:35 +03:00
h dnsserver.Handler,
nonDNS http.Handler,
errColl agd.ErrorCollector,
2023-03-18 17:11:10 +03:00
lc netext.ListenConfig,
2022-08-26 14:18:35 +03:00
) (l Listener, err error) {
defer func() { err = errors.Annotate(err, "listener %q: %w", name) }()
dcConf := s.DNSCrypt
metricsListener := &errCollMetricsListener{
errColl: errColl,
baseListener: &prometheus.ServerMetricsListener{},
}
confBase := dnsserver.ConfigBase{
2023-03-18 17:11:10 +03:00
Name: name,
Addr: addr,
Network: dnsserver.NetworkAny,
Handler: h,
Metrics: metricsListener,
BaseContext: ctxWithReqID,
ListenConfig: lc,
2022-08-26 14:18:35 +03:00
}
switch p := s.Protocol; p {
2022-11-07 10:21:24 +03:00
case agd.ProtoDNS:
2022-08-26 14:18:35 +03:00
l = dnsserver.NewServerDNS(dnsserver.ConfigDNS{ConfigBase: confBase})
2022-11-07 10:21:24 +03:00
case agd.ProtoDNSCrypt:
2022-08-26 14:18:35 +03:00
l = dnsserver.NewServerDNSCrypt(dnsserver.ConfigDNSCrypt{
ConfigBase: confBase,
DNSCryptProviderName: dcConf.ProviderName,
DNSCryptResolverCert: dcConf.Cert,
})
case agd.ProtoDoH:
l = dnsserver.NewServerHTTPS(dnsserver.ConfigHTTPS{
ConfigBase: confBase,
TLSConfig: s.TLS,
NonDNSHandler: nonDNS,
})
case agd.ProtoDoQ:
l = dnsserver.NewServerQUIC(dnsserver.ConfigQUIC{
ConfigBase: confBase,
TLSConfig: s.TLS,
})
case agd.ProtoDoT:
l = dnsserver.NewServerTLS(dnsserver.ConfigTLS{
ConfigDNS: dnsserver.ConfigDNS{ConfigBase: confBase},
TLSConfig: s.TLS,
})
default:
return nil, fmt.Errorf("bad protocol %v", p)
}
return l, nil
}
// ctxWithReqID returns a new context, derived from the background one, with a
// newly generated request ID added to it. NewListener uses it as the base
// context constructor of the listeners it creates.
func ctxWithReqID() (ctx context.Context) {
	parent := context.Background()
	reqID := agd.NewRequestID()

	return agd.WithRequestID(parent, reqID)
}
// newServers creates a slice of servers.
func newServers(
c *Config,
srvGrp *agd.ServerGroup,
handler dnsserver.Handler,
newListener NewListenerFunc,
) (servers []*server, err error) {
servers = make([]*server, len(srvGrp.Servers))
for i, s := range srvGrp.Servers {
// The Initial Middlewares
//
// These middlewares are either specific to the server or must be the
// furthest away from the handler and thus are the first to process
// a request.
// Assume that all the validations have been made during the
// configuration validation step back in package cmd. If we ever get
// new ways of receiving configuration, remove this assumption and
// validate fg.
fg := c.FilteringGroups[srvGrp.FilteringGroup]
2022-11-07 10:21:24 +03:00
// Only apply rate-limiting logic to plain DNS.
rlProtos := []agd.Protocol{agd.ProtoDNS}
2022-08-26 14:18:35 +03:00
var rlm *ratelimit.Middleware
rlm, err = ratelimit.NewMiddleware(c.RateLimit, rlProtos)
if err != nil {
return nil, fmt.Errorf("ratelimit: %w", err)
}
rlm.Metrics = &prometheus.RateLimitMetricsListener{}
imw := &initMw{
messages: c.Messages,
fltGrp: fg,
srvGrp: srvGrp,
srv: s,
db: c.ProfileDB,
geoIP: c.GeoIP,
errColl: c.ErrColl,
}
h := dnsserver.WithMiddlewares(
handler,
// Keep the rate limiting middleware as the outer one to make sure
// that the application logic isn't touched if the request is
// ratelimited.
rlm,
imw,
)
2023-03-18 17:11:10 +03:00
listeners := make([]*listener, 0, len(s.BindData))
for _, bindData := range s.BindData {
addr := bindData.Address
if addr == "" {
addr = bindData.AddrPort.String()
}
2022-08-26 14:18:35 +03:00
name := listenerName(s.Name, addr, s.Protocol)
2023-06-11 12:58:40 +03:00
lc := bindData.ListenConfig
if lc == nil {
lc = newListenConfig(c.ControlConf, c.ConnLimiter, s.Protocol)
}
2022-08-26 14:18:35 +03:00
var l Listener
2023-06-11 12:58:40 +03:00
l, err = newListener(s, name, addr, h, c.NonDNS, c.ErrColl, lc)
2022-08-26 14:18:35 +03:00
if err != nil {
return nil, fmt.Errorf("server %q: %w", s.Name, err)
}
listeners = append(listeners, &listener{
name: name,
Listener: l,
})
}
servers[i] = &server{
name: s.Name,
handler: h,
listeners: listeners,
}
}
return servers, nil
}
2023-06-11 12:58:40 +03:00
// newListenConfig returns the netext.ListenConfig for the servers that use
// protocol p. For plain DNS, the configuration is created with
// netext.DefaultListenConfigWithOOB, and for all other protocols with
// netext.DefaultListenConfig; ctrlConf is passed to both. Additionally, if l
// is not nil, the configuration is wrapped so that l limits the number of
// simultaneously active stream-connections.
func newListenConfig(
	ctrlConf *netext.ControlConfig,
	l *connlimiter.Limiter,
	p agd.Protocol,
) (lc netext.ListenConfig) {
	switch p {
	case agd.ProtoDNS:
		lc = netext.DefaultListenConfigWithOOB(ctrlConf)
	default:
		lc = netext.DefaultListenConfig(ctrlConf)
	}

	if l == nil {
		return lc
	}

	return connlimiter.NewListenConfig(lc, l)
}