AdGuardDNS/internal/dnsserver/serverhttps.go
Andrey Meshkov f1791135af Sync v2.11.0
2024-12-05 14:19:25 +03:00

645 lines
17 KiB
Go

package dnsserver
import (
"context"
"crypto/tls"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"golang.org/x/net/http2"
)
const (
// MimeTypeDoH is a Content-Type that DoH wireformat requests and responses
// must use.
MimeTypeDoH = "application/dns-message"
// MimeTypeJSON is a Content-Type that DoH JSON requests and responses
// must use.
MimeTypeJSON = "application/x-javascript"
// PathDoH is a relative path we use to accept DoH wireformat requests.
PathDoH = "/dns-query"
// PathJSON is a relative path we use to accept DoH JSON requests.
PathJSON = "/resolve"
httpReadTimeout = 5 * time.Second
httpWriteTimeout = 5 * time.Second
httpIdleTimeout = 120 * time.Second
)
// NextProtoDoH is a list of ALPN protocols added by default to the server's
// *tls.Config if no NextProto is specified there. Note that with this order,
// HTTP/2 is prioritized over HTTP/1.1.
var NextProtoDoH = []string{http2.NextProtoTLS, "http/1.1"}
// NextProtoDoH3 is a list of ALPN protocols added by default to the server's
// *tls.Config if no NextProto is specified there and DoH3 is supposed to be
// used. Note that with this order, HTTP/2 is prioritized over HTTP/1.1.
var NextProtoDoH3 = []string{http3.NextProtoH3, http2.NextProtoTLS, "http/1.1"}
// ConfigHTTPS is a struct that needs to be passed to NewServerHTTPS to
// initialize a new ServerHTTPS instance. You can choose whether HTTP/3 is
// enabled or not by specifying [ConfigBase.Network]. By default, the server
// will listen to both HTTP/2 and HTTP/3, but if you set it to NetworkTCP, the
// server will only use HTTP/2 and NetworkUDP will mean HTTP/3 only.
type ConfigHTTPS struct {
// TLSConfDefault is the TLS configuration for HTTPS. If not set and
// [ConfigBase.Network] is set to NetworkTCP the server will listen to plain
// HTTP. If it is not nil, it must be set to [NextProtoDoH].
TLSConfDefault *tls.Config
// TLSConfH3 is the TLS configuration for DoH3. If it is not nil, it must
// be set to [NextProtoDoH3].
TLSConfH3 *tls.Config
// NonDNSHandler handles requests with the path not equal to /dns-query.
// If it is empty, the server will return 404 for requests like that.
NonDNSHandler http.Handler
ConfigBase
// MaxStreamsPerPeer is the maximum number of concurrent streams that a peer
// is allowed to open.
MaxStreamsPerPeer int
// QUICLimitsEnabled, if true, enables QUIC limiting.
QUICLimitsEnabled bool
}
// ServerHTTPS is a DoH server implementation. It supports both DNS Wireformat
// and DNS JSON format. Regular DoH (wireformat) will be available at the
// /dns-query location. JSON format will be available at the "/resolve"
// location.
type ServerHTTPS struct {
*ServerBase
// httpServer is an instance of an *http.Server that is responsible for
// handling HTTP/1.1 and HTTP/2 requests.
httpServer *http.Server
// h3Server is an instance of an *http.Server that is responsible for
// handling HTTP/3 requests.
h3Server *http3.Server
// quicListener is a listener that we use to serve DoH3 requests.
quicListener *quic.EarlyListener
// quicTransport is saved here to close it later.
quicTransport *quic.Transport
conf ConfigHTTPS
}
// type check
var _ Server = (*ServerHTTPS)(nil)
// NewServerHTTPS creates a new ServerHTTPS instance.
func NewServerHTTPS(conf ConfigHTTPS) (s *ServerHTTPS) {
if conf.ListenConfig == nil {
// Do not enable OOB here, because ListenPacket is only used by HTTP/3,
// and quic-go sets the necessary flags.
conf.ListenConfig = netext.DefaultListenConfig(nil)
}
s = &ServerHTTPS{
ServerBase: newServerBase(ProtoDoH, conf.ConfigBase),
conf: conf,
}
return s
}
// Start implements the dnsserver.Server interface for *ServerHTTPS.
func (s *ServerHTTPS) Start(ctx context.Context) (err error) {
defer func() { err = errors.Annotate(err, "starting doh server: %w") }()
s.mu.Lock()
defer s.mu.Unlock()
if s.started {
return ErrServerAlreadyStarted
}
log.Info("[%s]: Starting the server", s.addr)
ctx = ContextWithServerInfo(ctx, &ServerInfo{
Name: s.name,
Addr: s.addr,
Proto: s.proto,
})
if s.proto != ProtoDoH {
return ErrInvalidArgument
}
if s.network.CanTCP() {
err = s.startHTTPSServer(ctx)
if err != nil {
return err
}
}
if s.network.CanUDP() {
err = s.startH3Server(ctx)
if err != nil {
return err
}
}
s.started = true
log.Info("[%s]: Server has been started", s.Name())
return nil
}
// Shutdown implements the dnsserver.Server interface for *ServerHTTPS.
func (s *ServerHTTPS) Shutdown(ctx context.Context) (err error) {
defer func() { err = errors.Annotate(err, "shutting down doh server: %w") }()
log.Info("[%s]: Stopping the server", s.Name())
err = s.shutdown(ctx)
if err != nil {
log.Info("[%s]: Failed to shutdown: %v", s.Name(), err)
return err
}
err = s.waitShutdown(ctx)
log.Info("[%s]: Finished stopping the server", s.Name())
return err
}
// startHTTPSServer starts the HTTPS server that will handle HTTP/1.1 and HTTP2.
func (s *ServerHTTPS) startHTTPSServer(ctx context.Context) (err error) {
// Start the TLS or TCP listener.
err = s.listenTLS(ctx)
if err != nil {
return err
}
// Prepare and run the HTTP server.
handler := &httpHandler{
srv: s,
localAddr: s.tcpListener.Addr(),
}
// Create an instance of the HTTP server.
s.httpServer = &http.Server{
Handler: handler,
ReadTimeout: httpReadTimeout,
ReadHeaderTimeout: httpReadTimeout,
WriteTimeout: httpWriteTimeout,
IdleTimeout: httpIdleTimeout,
ErrorLog: log.StdLog("dnsserver/serverhttps: "+s.name, log.DEBUG),
}
// Start the server worker goroutine.
s.wg.Add(1)
go s.serveHTTPS(ctx, s.httpServer, s.tcpListener)
return nil
}
// startH3Server starts the HTTP/3 server.
func (s *ServerHTTPS) startH3Server(ctx context.Context) (err error) {
// Start the QUIC listener.
err = s.listenQUIC(ctx)
if err != nil {
return err
}
// Prepare and run the HTTP/3 server.
handler := &httpHandler{
srv: s,
localAddr: s.quicListener.Addr(),
}
// Create an instance of the HTTP/3 server.
s.h3Server = &http3.Server{
Handler: handler,
}
// Start the server worker goroutine.
s.wg.Add(1)
go s.serveH3(ctx, s.h3Server, s.quicListener)
return nil
}
// shutdown marks the server as stopped and closes active listeners.
func (s *ServerHTTPS) shutdown(ctx context.Context) (err error) {
s.mu.Lock()
defer s.mu.Unlock()
if !s.started {
return ErrServerNotStarted
}
s.started = false
// First step, close the active TCP listener right away. Don't close the
// UDP one, as if there is one, it is closed by closing the QUIC listener.
//
// TODO(a.garipov): Fix this mess.
if s.tcpListener != nil {
err = s.tcpListener.Close()
if err != nil {
log.Info("[%s]: Failed to close NetworkTCP listener: %v", s.Name(), err)
}
}
// Second, shutdown the HTTP server.
err = s.httpServer.Shutdown(ctx)
if err != nil {
log.Debug("[%s]: http server shutdown: %v", s.Name(), err)
}
// Finally, shutdown the HTTP/3 server.
s.shutdownH3()
return nil
}
// shutdownH3 shuts down the HTTP/3 server, if enabled, and logs all errors.
func (s *ServerHTTPS) shutdownH3() {
if s.h3Server == nil {
return
}
err := s.quicListener.Close()
if err != nil {
log.Debug("[%s]: quic listener shutdown: %s", s.Name(), err)
}
err = s.quicTransport.Close()
if err != nil {
log.Debug("[%s]: quic transport shutdown: %s", s.Name(), err)
}
err = s.h3Server.Close()
if err != nil {
log.Debug("[%s]: http/3 server shutdown: %s", s.Name(), err)
}
}
// serveHTTPS is launched in a worker goroutine and serves HTTP/1.1 and HTTP/2
// requests.
func (s *ServerHTTPS) serveHTTPS(ctx context.Context, hs *http.Server, l net.Listener) {
defer s.wg.Done()
// Do not recover from panics here since if this goroutine panics, the
// application won't be able to continue listening to DoH.
defer s.handlePanicAndExit(ctx)
scheme := urlutil.SchemeHTTPS
if s.conf.TLSConfDefault == nil {
scheme = urlutil.SchemeHTTP
}
u := &url.URL{
Scheme: scheme,
Host: s.addr,
}
log.Info("[%s]: Start listening to %s", s.name, u)
err := hs.Serve(l)
if err != nil {
log.Info("[%s]: Finished listening to %s due to %v", s.name, u, err)
}
}
// serveH3 is launched in a worker goroutine and serves HTTP/3 requests.
func (s *ServerHTTPS) serveH3(ctx context.Context, hs *http3.Server, ql *quic.EarlyListener) {
defer s.wg.Done()
// Do not recover from panics here since if this goroutine panics, the
// application won't be able to continue listening to DoH.
defer s.handlePanicAndExit(ctx)
u := &url.URL{
Scheme: http3.NextProtoH3,
Host: s.addr,
}
log.Info("[%s]: Start listening to %s", s.name, u)
err := hs.ServeListener(ql)
if err != nil {
log.Info("[%s]: Finished listening to %s due to %v", s.name, u, err)
}
}
// httpHandler is a helper structure that implements http.Handler
// and holds pointers to ServerHTTPS, net.Listener.
type httpHandler struct {
srv *ServerHTTPS
localAddr net.Addr
}
// type check
var _ http.Handler = (*httpHandler)(nil)
// remoteAddr gets HTTP request's remote address.
//
// TODO(a.garipov): Add trusted proxies and real IP extraction logic. Perhaps
// just copy from module dnsproxy if that one fits. Also, perhaps make that
// logic pluggable and put it into a new package in module golibs.
func (h *httpHandler) remoteAddr(r *http.Request) (addr net.Addr) {
// Consider that the http.Request.RemoteAddr documentation is correct and
// that it is always a valid ip:port value. Panic if it isn't so.
ipStr, port, err := netutil.SplitHostPort(r.RemoteAddr)
if err != nil {
panic(fmt.Sprintf("failed to split host:port %s: %v", r.RemoteAddr, err))
}
ip, err := netutil.ParseIP(ipStr)
if err != nil {
panic(fmt.Sprintf("failed to parse IP %s: %v", ipStr, err))
}
if NetworkFromAddr(h.localAddr) == NetworkUDP {
// This means that we're extracting remoteAddr from an HTTP/3 request.
return &net.UDPAddr{IP: ip, Port: int(port)}
}
return &net.TCPAddr{IP: ip, Port: int(port)}
}
// ServeHTTP implements the http.Handler interface for *httpHandler. It reads
// the DNS data from the request, resolves it, and sends a response.
//
// NOTE: r.Context() is only used to control cancelation. To add values to the
// context, use the BaseContext of this handler's ServerHTTPS.
func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, cancel := h.srv.requestContext()
defer cancel()
if dl, ok := r.Context().Deadline(); ok {
ctx, cancel = context.WithDeadline(ctx, dl)
defer cancel()
}
defer h.srv.handlePanicAndRecover(ctx)
log.Debug("Received a request to %s", r.URL)
// TODO(ameshkov): Consider using ants.Pool here.
isDNS, _, _ := isDoH(r)
if isDNS {
h.serveDoH(ctx, w, r)
return
}
if h.srv.conf.NonDNSHandler != nil {
h.srv.conf.NonDNSHandler.ServeHTTP(w, r)
} else {
h.srv.metrics.OnInvalidMsg(ctx)
http.Error(w, "", http.StatusNotFound)
}
}
// serveDoH processes the incoming DNS message and writes the response back to
// the client.
func (h *httpHandler) serveDoH(ctx context.Context, w http.ResponseWriter, r *http.Request) {
m, err := httpRequestToMsg(r)
if err != nil {
log.Debug("Failed to convert request to a DNS message: %v", err)
h.srv.metrics.OnInvalidMsg(ctx)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
rAddr := h.remoteAddr(r)
lAddr := h.localAddr
rw := NewNonWriterResponseWriter(lAddr, rAddr)
ctx = addRequestInfo(ctx, r)
// Serve the query
written := h.srv.serveDNS(ctx, m, rw)
// If no response were written, indicate it via an internal server error.
if !written {
log.Debug("No response has been written by the handler")
http.Error(w, "No response", http.StatusInternalServerError)
return
}
// Get the response that has been written.
resp := rw.Msg()
req := rw.req
// Write the response to the client
err = h.writeResponse(req, resp, r, w)
if err != nil {
log.Debug("[%d] Failed to write HTTP response: %v", req.Id, err)
// Try writing an error response just in case.
http.Error(w, "Internal error", http.StatusInternalServerError)
return
}
h.srv.disposer.Dispose(resp)
}
// writeResponse writes the actual DNS response to the client and takes care of
// the response serialization, i.e. writes different content depending on the
// requested mime type (wireformat or JSON).
func (h *httpHandler) writeResponse(
req *dns.Msg,
resp *dns.Msg,
r *http.Request,
w http.ResponseWriter,
) (err error) {
// normalize the response
normalizeTCP(ProtoDoH, req, resp)
isDNS, _, ct := isDoH(r)
if !isDNS {
return fmt.Errorf("invalid request path: %s", r.URL.Path)
}
var buf []byte
switch ct {
case MimeTypeDoH:
buf, err = resp.Pack()
w.Header().Set(httphdr.ContentType, MimeTypeDoH)
case MimeTypeJSON:
buf, err = dnsMsgToJSON(resp)
w.Header().Set(httphdr.ContentType, MimeTypeJSON)
default:
err = fmt.Errorf("invalid content type: %q", ct)
}
if err != nil {
return err
}
// From RFC8484, Section 5.1:
// DoH servers SHOULD assign an explicit HTTP freshness
// lifetime (see Section 4.2 of [RFC7234]) so that the DoH client is
// more likely to use fresh DNS data.
maxAge := minimalTTL(resp)
w.Header().Set(httphdr.CacheControl, fmt.Sprintf("max-age=%f", maxAge.Seconds()))
w.Header().Set(httphdr.ContentLength, strconv.Itoa(len(buf)))
w.WriteHeader(http.StatusOK)
// Write the actual response
log.Debug("[%d] Writing HTTP response", req.Id)
_, err = w.Write(buf)
return err
}
// listenTCP starts the TCP/TLS listener. Note that if there's no TLS config,
// a plain TCP listener will be started instead.
func (s *ServerHTTPS) listenTLS(ctx context.Context) (err error) {
err = s.listenTCP(ctx)
if err != nil {
return err
}
// Prepare the TLS configuration of the server.
tlsConf := s.conf.TLSConfDefault
if tlsConf == nil {
return nil
}
s.tcpListener = tls.NewListener(s.tcpListener, tlsConf)
return nil
}
// listenQUIC starts a QUIC listener that will be used to serve HTTP/3 requests.
func (s *ServerHTTPS) listenQUIC(ctx context.Context) (err error) {
// Prepare the TLS configuration of the server.
tlsConf := s.conf.TLSConfH3
conn, err := s.listenConfig.ListenPacket(ctx, "udp", s.addr)
if err != nil {
return fmt.Errorf("listening udp for quic: %w", err)
}
v := newQUICAddrValidator(quicAddrValidatorCacheSize, s.metrics, quicAddrValidatorCacheTTL)
transport := &quic.Transport{
Conn: conn,
VerifySourceAddress: v.requiresValidation,
}
qConf := newServerQUICConfig(s.conf.QUICLimitsEnabled, s.conf.MaxStreamsPerPeer)
ql, err := transport.ListenEarly(tlsConf, qConf)
if err != nil {
return fmt.Errorf("listening quic: %w", err)
}
s.udpListener = conn
s.quicTransport = transport
s.quicListener = ql
return nil
}
// addRequestInfo adds request info to the context.
func addRequestInfo(parent context.Context, r *http.Request) (ctx context.Context) {
ctx = parent
ri := &RequestInfo{
StartTime: time.Now(),
URL: netutil.CloneURL(r.URL),
}
if r.TLS != nil {
ri.TLSServerName = r.TLS.ServerName
}
if username, pass, ok := r.BasicAuth(); ok {
ri.Userinfo = url.UserPassword(username, pass)
}
return ContextWithRequestInfo(ctx, ri)
}
// httpRequestToMsg reads the DNS message from http.Request.
func httpRequestToMsg(req *http.Request) (b []byte, err error) {
_, isJSON, _ := isDoH(req)
if isJSON {
return httpRequestToMsgJSON(req)
}
switch req.Method {
case http.MethodGet:
return httpRequestToMsgGet(req)
case http.MethodPost:
return httpRequestToMsgPost(req)
default:
return nil, fmt.Errorf("method not allowed: %s", req.Method)
}
}
// httpRequestToMsgPost extracts the DNS message from a request body.
func httpRequestToMsgPost(req *http.Request) (b []byte, err error) {
buf, err := io.ReadAll(req.Body)
defer log.OnCloserError(req.Body, log.DEBUG)
return buf, err
}
// httpRequestToMsgGet extracts the DNS message from a GET request.
func httpRequestToMsgGet(req *http.Request) (b []byte, err error) {
values := req.URL.Query()
b64, ok := values["dns"]
if !ok {
return nil, fmt.Errorf("no 'dns' query parameter found")
}
if len(b64) != 1 {
return nil, fmt.Errorf("multiple 'dns' query values found")
}
return base64.RawURLEncoding.DecodeString(b64[0])
}
// isDoH returns true if r.URL.Path contains DNS-over-HTTP paths, and also what
// content type is desired by the user. isJSON is true if the user uses the
// JSON API. ct can be either MimeTypeDoH or MimeTypeJSON.
func isDoH(r *http.Request) (ok, isJSON bool, ct string) {
parts := strings.Split(path.Clean(r.URL.Path), "/")
if parts[0] == "" {
parts = parts[1:]
}
switch {
case parts[0] == "":
return false, false, ""
case strings.HasSuffix(PathDoH, parts[0]):
return true, false, MimeTypeDoH
case strings.HasSuffix(PathJSON, parts[0]):
desiredCt := r.URL.Query().Get("ct")
if desiredCt == MimeTypeDoH {
return true, true, MimeTypeDoH
}
return true, true, MimeTypeJSON
default:
return false, false, ""
}
}