mirror of
https://github.com/AdguardTeam/AdGuardDNS.git
synced 2025-02-20 11:23:36 +08:00
505 lines
14 KiB
Go
505 lines
14 KiB
Go
package dnsserver
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"os"
|
|
"runtime/debug"
|
|
"sync"
|
|
|
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
|
"github.com/AdguardTeam/golibs/log"
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
// ConfigBase contains the necessary minimum that every Server needs to
|
|
// be initialized.
|
|
type ConfigBase struct {
|
|
// Handler is a handler that processes incoming DNS messages. If not set,
|
|
// the default handler, which returns error response to any query, is used.
|
|
Handler Handler
|
|
|
|
// Metrics is the object we use for collecting performance metrics. If not
|
|
// set, [EmptyMetricsListener] is used.
|
|
Metrics MetricsListener
|
|
|
|
// Disposer is used to help module users reuse parts of DNS responses. If
|
|
// not set, EmptyDisposer is used.
|
|
Disposer Disposer
|
|
|
|
// RequestContext is a ContextConstructor that returns contexts for
|
|
// requests. If not set, the server uses [DefaultContextConstructor].
|
|
RequestContext ContextConstructor
|
|
|
|
// ListenConfig, when set, is used to set options of connections used by the
|
|
// DNS server. If nil, an appropriate default ListenConfig is used.
|
|
ListenConfig netext.ListenConfig
|
|
|
|
// Network is the network this server listens to. If empty, the server will
|
|
// listen to all networks that are supposed to be used by the server's
|
|
// protocol. Note, that it only makes sense for [ServerDNS],
|
|
// [ServerDNSCrypt], and [ServerHTTPS].
|
|
Network Network
|
|
|
|
// Name is used for logging, and it may be used for perf counters reporting.
|
|
Name string
|
|
|
|
// Addr is the address the server listens to. See [net.Dial] for the
|
|
// documentation on the address format.
|
|
Addr string
|
|
}
|
|
|
|
// ServerBase implements base methods that every Server implementation uses.
|
|
type ServerBase struct {
|
|
// handler is a handler that processes incoming DNS messages.
|
|
handler Handler
|
|
|
|
// reqCtx is a function that should return the base context.
|
|
reqCtx ContextConstructor
|
|
|
|
// metrics is the object we use for collecting performance metrics.
|
|
metrics MetricsListener
|
|
|
|
// disposer is used to help module users reuse parts of DNS responses.
|
|
disposer Disposer
|
|
|
|
// listenConfig is used to set tcpListener and udpListener.
|
|
listenConfig netext.ListenConfig
|
|
|
|
// tcpListener is used to accept new TCP connections. It is nil for servers
|
|
// that don't use TCP.
|
|
tcpListener net.Listener
|
|
|
|
// udpListener is used to accept new UDP messages. It is nil for servers
|
|
// that don't use UDP.
|
|
udpListener net.PacketConn
|
|
|
|
// mu protects started, tcpListener, and udpListener.
|
|
mu *sync.RWMutex
|
|
|
|
// wg tracks active workers (listeners or query processing). Shutdown
|
|
// won't finish until there's at least one active worker.
|
|
wg *sync.WaitGroup
|
|
|
|
// name is used for logging and it may be used for perf counters reporting.
|
|
name string
|
|
|
|
// addr is the address the server listens to.
|
|
addr string
|
|
|
|
// network is the network to listen to. It only makes sense for the
|
|
// following protocols: [ProtoDNS], [ProtoDNSCrypt], [ProtoDoH].
|
|
network Network
|
|
|
|
// proto is the server protocol.
|
|
proto Protocol
|
|
|
|
started bool
|
|
}
|
|
|
|
// type check
|
|
var _ Server = (*ServerBase)(nil)
|
|
|
|
// newServerBase creates a new instance of ServerBase and initializes
|
|
// some of its internal properties.
|
|
func newServerBase(proto Protocol, conf ConfigBase) (s *ServerBase) {
|
|
s = &ServerBase{
|
|
handler: conf.Handler,
|
|
reqCtx: conf.RequestContext,
|
|
metrics: conf.Metrics,
|
|
disposer: conf.Disposer,
|
|
listenConfig: conf.ListenConfig,
|
|
mu: &sync.RWMutex{},
|
|
wg: &sync.WaitGroup{},
|
|
name: conf.Name,
|
|
addr: conf.Addr,
|
|
network: conf.Network,
|
|
proto: proto,
|
|
}
|
|
|
|
if s.reqCtx == nil {
|
|
s.reqCtx = DefaultContextConstructor{}
|
|
}
|
|
|
|
if s.metrics == nil {
|
|
s.metrics = &EmptyMetricsListener{}
|
|
}
|
|
|
|
if s.disposer == nil {
|
|
s.disposer = EmptyDisposer{}
|
|
}
|
|
|
|
if s.handler == nil {
|
|
s.handler = notImplementedHandlerFunc
|
|
}
|
|
|
|
return s
|
|
}
|
|
|
|
// Name implements the [dnsserver.Server] interface for *ServerBase.
|
|
func (s *ServerBase) Name() (name string) {
|
|
return s.name
|
|
}
|
|
|
|
// Proto implements the [dnsserver.Server] interface for *ServerBase.
|
|
func (s *ServerBase) Proto() (proto Protocol) {
|
|
return s.proto
|
|
}
|
|
|
|
// Network implements the [dnsserver.Server] interface for *ServerBase.
|
|
func (s *ServerBase) Network() (network Network) {
|
|
return s.network
|
|
}
|
|
|
|
// Addr implements the [dnsserver.Server] interface for *ServerBase.
|
|
func (s *ServerBase) Addr() (addr string) {
|
|
return s.addr
|
|
}
|
|
|
|
// Start implements the [dnsserver.Server] interface for *ServerBase.
|
|
func (s *ServerBase) Start(_ context.Context) (err error) {
|
|
panic("*ServerBase must not be used directly")
|
|
}
|
|
|
|
// Shutdown implements the [dnsserver.Server] interface for *ServerBase.
|
|
func (s *ServerBase) Shutdown(_ context.Context) (err error) {
|
|
panic("*ServerBase must not be used directly")
|
|
}
|
|
|
|
// LocalTCPAddr implements the [dnsserver.Server] interface for *ServerBase.
|
|
func (s *ServerBase) LocalTCPAddr() (addr net.Addr) {
|
|
if s.tcpListener != nil {
|
|
return s.tcpListener.Addr()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// LocalUDPAddr implements the [dnsserver.Server] interface for *ServerBase.
|
|
func (s *ServerBase) LocalUDPAddr() (addr net.Addr) {
|
|
if s.udpListener != nil {
|
|
return s.udpListener.LocalAddr()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// requestContext returns a context for one request and adds server information.
|
|
func (s *ServerBase) requestContext() (ctx context.Context, cancel context.CancelFunc) {
|
|
ctx, cancel = s.reqCtx.New()
|
|
ctx = ContextWithServerInfo(ctx, &ServerInfo{
|
|
Name: s.name,
|
|
Addr: s.addr,
|
|
Proto: s.proto,
|
|
})
|
|
|
|
return ctx, cancel
|
|
}
|
|
|
|
// serveDNS processes the incoming DNS query and writes the response to the
|
|
// specified ResponseWriter. written is false if no response was written.
|
|
func (s *ServerBase) serveDNS(ctx context.Context, buf []byte, rw ResponseWriter) (written bool) {
|
|
req := &dns.Msg{}
|
|
if err := req.Unpack(buf); err != nil {
|
|
// Ignore the incoming message and let the connection hang as it may be
|
|
// used to amplify.
|
|
s.metrics.OnInvalidMsg(ctx)
|
|
|
|
return false
|
|
}
|
|
|
|
return s.serveDNSMsg(ctx, req, rw)
|
|
}
|
|
|
|
// serveDNSMsg processes the incoming DNS query and writes the response to the
|
|
// specified ResponseWriter. written is false if no response was written.
|
|
func (s *ServerBase) serveDNSMsg(
|
|
ctx context.Context,
|
|
req *dns.Msg,
|
|
rw ResponseWriter,
|
|
) (written bool) {
|
|
hostname, qType := questionData(req)
|
|
log.Debug("[%d] processing \"%s %s\"", req.Id, qType, hostname)
|
|
|
|
recW := NewRecorderResponseWriter(rw)
|
|
s.serveDNSMsgInternal(ctx, req, recW)
|
|
|
|
resp := recW.Resp
|
|
written = resp != nil
|
|
|
|
var respLen int
|
|
if written {
|
|
// TODO(a.garipov): Use the real number of bytes written by
|
|
// [ResponseWriter] to the socket.
|
|
respLen = resp.Len()
|
|
}
|
|
|
|
s.metrics.OnRequest(ctx, &QueryInfo{
|
|
Request: req,
|
|
RequestSize: req.Len(),
|
|
Response: resp,
|
|
ResponseSize: respLen,
|
|
}, rw)
|
|
|
|
log.Debug("[%d]: finished processing \"%s %s\"", req.Id, qType, hostname)
|
|
|
|
s.dispose(rw, resp)
|
|
|
|
return written
|
|
}
|
|
|
|
// dispose is a helper for disposing a DNS response right after writing it to a
|
|
// connection. Disposal of a response is only safe assuming that there is no
|
|
// further processing up the stack. Currently, this is only true for plain DNS
|
|
// and DoT at this point in the code.
|
|
//
|
|
// TODO(a.garipov): Add DoQ as well once the legacy format is removed.
|
|
func (s *ServerBase) dispose(rw ResponseWriter, resp *dns.Msg) {
|
|
switch rw.(type) {
|
|
case
|
|
*tcpResponseWriter,
|
|
*udpResponseWriter:
|
|
s.disposer.Dispose(resp)
|
|
default:
|
|
// Go on.
|
|
}
|
|
}
|
|
|
|
// serveDNSMsgInternal serves the DNS request and uses recorder as a
|
|
// ResponseWriter. This method is supposed to be called from serveDNSMsg,
|
|
// the recorded response is used for counting metrics.
|
|
func (s *ServerBase) serveDNSMsgInternal(
|
|
ctx context.Context,
|
|
req *dns.Msg,
|
|
rw *RecorderResponseWriter,
|
|
) {
|
|
var resp *dns.Msg
|
|
|
|
// Check if we can accept this message
|
|
switch action := s.acceptMsg(req); action {
|
|
case dns.MsgReject:
|
|
log.Debug("[%d] Query format is invalid", req.Id)
|
|
resp = genErrorResponse(req, dns.RcodeFormatError)
|
|
case dns.MsgRejectNotImplemented:
|
|
log.Debug("[%d] Rejecting this query", req.Id)
|
|
resp = genErrorResponse(req, dns.RcodeNotImplemented)
|
|
case dns.MsgIgnore:
|
|
log.Debug("[%d] Ignoring this query", req.Id)
|
|
s.metrics.OnInvalidMsg(ctx)
|
|
|
|
return
|
|
}
|
|
|
|
// If resp is not empty at this stage, the request is invalid and we should
|
|
// simply exit here.
|
|
if resp != nil {
|
|
// Ignore errors and just write the message
|
|
log.Debug("[%d]: writing DNS response code %d", req.Id, resp.Rcode)
|
|
err := rw.WriteMsg(ctx, req, resp)
|
|
if err != nil {
|
|
log.Debug("[%d]: error writing a response: %v", req.Id, err)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
err := s.handler.ServeDNS(ctx, rw, req)
|
|
if err != nil {
|
|
log.Debug("[%d]: handler returned an error: %s", req.Id, err)
|
|
s.metrics.OnError(ctx, err)
|
|
|
|
resp = genErrorResponse(req, dns.RcodeServerFailure)
|
|
if isNonCriticalNetError(err) {
|
|
addEDE(req, resp, dns.ExtendedErrorCodeNetworkError, "")
|
|
}
|
|
|
|
err = rw.WriteMsg(ctx, req, resp)
|
|
if err != nil {
|
|
log.Debug("[%d]: error writing a response: %s", req.Id, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// addEDE adds an Extended DNS Error (EDE) option to the blocked response
|
|
// message, if the request indicates EDNS support.
|
|
func addEDE(req, resp *dns.Msg, code uint16, text string) {
|
|
reqOpt := req.IsEdns0()
|
|
if reqOpt == nil {
|
|
// Requestor doesn't implement EDNS, see
|
|
// https://datatracker.ietf.org/doc/html/rfc6891#section-7.
|
|
return
|
|
}
|
|
|
|
respOpt := resp.IsEdns0()
|
|
if respOpt == nil {
|
|
resp.SetEdns0(reqOpt.UDPSize(), reqOpt.Do())
|
|
respOpt = resp.Extra[len(resp.Extra)-1].(*dns.OPT)
|
|
}
|
|
|
|
respOpt.Option = append(respOpt.Option, &dns.EDNS0_EDE{
|
|
InfoCode: code,
|
|
ExtraText: text,
|
|
})
|
|
}
|
|
|
|
// acceptMsg checks if we should process the incoming DNS query.
|
|
func (s *ServerBase) acceptMsg(m *dns.Msg) (action dns.MsgAcceptAction) {
|
|
if m.Response {
|
|
log.Debug("[%d]: message rejected since this is a response", m.Id)
|
|
|
|
return dns.MsgIgnore
|
|
}
|
|
|
|
if m.Opcode != dns.OpcodeQuery && m.Opcode != dns.OpcodeNotify {
|
|
log.Debug("[%d]: rejected due to unsupported opcode", m.Opcode)
|
|
|
|
return dns.MsgRejectNotImplemented
|
|
}
|
|
|
|
// There can only be one question in request, unless DNS Cookies are
|
|
// involved. See AGDNS-738.
|
|
if len(m.Question) != 1 {
|
|
log.Debug("[%d]: message rejected due to wrong number of questions", m.Id)
|
|
|
|
return dns.MsgReject
|
|
}
|
|
|
|
// NOTIFY requests can have a SOA in the ANSWER section. See RFC 1996 Section 3.7 and 3.11.
|
|
if len(m.Answer) > 1 {
|
|
log.Debug("[%d]: message rejected due to wrong number of answers", m.Id)
|
|
|
|
return dns.MsgReject
|
|
}
|
|
|
|
// IXFR request could have one SOA RR in the NS section. See RFC 1995, section 3.
|
|
if len(m.Ns) > 1 {
|
|
log.Debug("[%d]: message rejected due to wrong number of NS records", m.Id)
|
|
|
|
return dns.MsgReject
|
|
}
|
|
|
|
return dns.MsgAccept
|
|
}
|
|
|
|
// handlePanicAndExit writes panic info to log, reports it to the registered
|
|
// MetricsListener and calls os.Exit with a positive exit code.
|
|
func (s *ServerBase) handlePanicAndExit(ctx context.Context) {
|
|
if v := recover(); v != nil {
|
|
log.Error(
|
|
"%q(%s://%s): panic encountered, exiting: %v\n%s",
|
|
s.name,
|
|
s.proto,
|
|
s.addr,
|
|
v,
|
|
string(debug.Stack()),
|
|
)
|
|
|
|
s.metrics.OnPanic(ctx, v)
|
|
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
// handlePanicAndRecover writes panic info to log, reports it to the registered
|
|
// MetricsListener.
|
|
func (s *ServerBase) handlePanicAndRecover(ctx context.Context) {
|
|
if v := recover(); v != nil {
|
|
log.Error(
|
|
"%s %s://%s: panic encountered, recovered: %s\n%s",
|
|
s.name,
|
|
s.addr,
|
|
s.proto,
|
|
v,
|
|
string(debug.Stack()),
|
|
)
|
|
s.metrics.OnPanic(ctx, v)
|
|
}
|
|
}
|
|
|
|
// listenUDP initializes and starts s.udpListener using s.addr. If the TCP
|
|
// listener is already running, its address is used instead to properly handle
|
|
// the case when port 0 is used as both listeners should use the same port, and
|
|
// we only learn it after the first one was started.
|
|
func (s *ServerBase) listenUDP(ctx context.Context) (err error) {
|
|
addr := s.addr
|
|
if s.tcpListener != nil {
|
|
addr = s.tcpListener.Addr().String()
|
|
}
|
|
|
|
conn, err := s.listenConfig.ListenPacket(ctx, "udp", addr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.udpListener = conn
|
|
|
|
return nil
|
|
}
|
|
|
|
// listenTCP initializes and starts s.tcpListener using s.addr. If the UDP
|
|
// listener is already running, its address is used instead to properly handle
|
|
// the case when port 0 is used as both listeners should use the same port, and
|
|
// we only learn it after the first one was started.
|
|
func (s *ServerBase) listenTCP(ctx context.Context) (err error) {
|
|
addr := s.addr
|
|
if s.udpListener != nil {
|
|
addr = s.udpListener.LocalAddr().String()
|
|
}
|
|
|
|
l, err := s.listenConfig.Listen(ctx, "tcp", addr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.tcpListener = l
|
|
|
|
return nil
|
|
}
|
|
|
|
// closeListeners stops UDP and TCP listeners.
|
|
func (s *ServerBase) closeListeners() {
|
|
if s.udpListener != nil {
|
|
err := s.udpListener.Close()
|
|
if err != nil {
|
|
log.Info("[%s]: Failed to close NetworkUDP listener: %v", s.Name(), err)
|
|
}
|
|
}
|
|
if s.tcpListener != nil {
|
|
err := s.tcpListener.Close()
|
|
if err != nil {
|
|
log.Info("[%s]: Failed to close NetworkTCP listener: %v", s.Name(), err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// waitShutdown waits either until context deadline OR ServerBase.wg.
|
|
func (s *ServerBase) waitShutdown(ctx context.Context) (err error) {
|
|
// Using this channel to wait until all goroutines finish their work
|
|
closed := make(chan struct{})
|
|
go func() {
|
|
defer log.OnPanic("waitShutdown")
|
|
|
|
// wait until all queries are processed
|
|
s.wg.Wait()
|
|
close(closed)
|
|
}()
|
|
|
|
var ctxErr error
|
|
select {
|
|
case <-closed:
|
|
// Do nothing here
|
|
case <-ctx.Done():
|
|
ctxErr = ctx.Err()
|
|
}
|
|
|
|
return ctxErr
|
|
}
|
|
|
|
// isStarted returns true if the server is started.
|
|
func (s *ServerBase) isStarted() (started bool) {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
return s.started
|
|
}
|