all: client custom upstream config

This commit is contained in:
Stanislav Chzhen 2025-02-13 15:46:52 +03:00
parent 2fe2d254b5
commit d9dbaf0c9b
8 changed files with 133 additions and 188 deletions

View File

@ -3,7 +3,6 @@ package aghtest_test
import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
)
@ -12,9 +11,6 @@ import (
// type check
var _ filtering.Resolver = (*aghtest.Resolver)(nil)
// type check
var _ dnsforward.ClientsContainer = (*aghtest.ClientsContainer)(nil)
// type check
//
// TODO(s.chzhen): It's here to avoid the import cycle. Remove it.

View File

@ -9,6 +9,8 @@ import (
"net/netip"
"slices"
"strings"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
@ -16,6 +18,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/google/uuid"
)
@ -58,11 +61,16 @@ func (uid *UID) UnmarshalText(data []byte) error {
// Persistent contains information about persistent clients.
type Persistent struct {
// UpstreamConfig is the custom upstream configuration for this client. If
// mu protects upstreamConfig, customUpstreamConf.
mu *sync.Mutex
// upstreamConfig is the custom upstream configuration for this client. If
// it's nil, it has not been initialized yet. If it's non-nil and empty,
// there are no valid upstreams. If it's non-nil and non-empty, these
// upstream must be used.
UpstreamConfig *proxy.CustomUpstreamConfig
upstreamConfig *proxy.CustomUpstreamConfig
customUpstreamConf *UpstreamConfig
// SafeSearch handles search engine hosts rewrites.
SafeSearch filtering.SafeSearch
@ -166,6 +174,8 @@ func (c *Persistent) validate(ctx context.Context, l *slog.Logger, allTags []str
// TODO(s.chzhen): Move to the constructor.
slices.Sort(c.Tags)
c.mu = &sync.Mutex{}
return nil
}
@ -189,6 +199,106 @@ func (c *Persistent) SetIDs(ids []string) (err error) {
return nil
}
// UpstreamConfig is a custom upstream configuration for the persistent client.
type UpstreamConfig struct {
Bootstrap upstream.Resolver
BootstrapAddrs []string
UpstreamTimeout time.Duration
BootstrapPreferIPv6 bool
EDNSClientSubnetEnabled bool
UseHTTP3Upstreams bool
}
func (conf *UpstreamConfig) Clone() (clone *UpstreamConfig) {
c := *conf
c.BootstrapAddrs = slices.Clone(conf.BootstrapAddrs)
return &c
}
func equalUpstreamSettings(a, b *UpstreamConfig) (ok bool) {
if a == nil {
return b == nil
} else if b == nil {
return false
}
if !slices.Equal(a.BootstrapAddrs, b.BootstrapAddrs) {
return false
}
if a.UpstreamTimeout != b.UpstreamTimeout {
return false
}
return true
}
// isCommentOrEmpty returns true if s starts with a "#" character or is empty.
// This function is useful for filtering out non-upstream lines from upstream
// configs.
func isCommentOrEmpty(s string) (ok bool) {
return len(s) == 0 || s[0] == '#'
}
// upstreamHTTPVersions returns the HTTP versions for upstream configuration
// depending on configuration.
func upstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
if !http3 {
return upstream.DefaultHTTPVersions
}
return []upstream.HTTPVersion{
upstream.HTTPVersion3,
upstream.HTTPVersion2,
upstream.HTTPVersion11,
}
}
// UpstreamConfig returns the custom upstream configuration for the client.
func (c *Persistent) UpstreamConfig(upsConf *UpstreamConfig) (prxConf *proxy.CustomUpstreamConfig) {
c.mu.Lock()
defer c.mu.Unlock()
if !equalUpstreamSettings(c.customUpstreamConf, upsConf) {
c.setUpstreamConfig(upsConf)
}
return c.upstreamConfig
}
// setUpstreamConfig sets the custom upstream configuration for the client.
func (c *Persistent) setUpstreamConfig(conf *UpstreamConfig) {
upstreams := stringutil.FilterOut(c.Upstreams, isCommentOrEmpty)
if len(upstreams) == 0 {
return
}
upsConf, err := proxy.ParseUpstreamsConfig(
upstreams,
&upstream.Options{
Bootstrap: conf.Bootstrap,
Timeout: time.Duration(conf.UpstreamTimeout),
HTTPVersions: upstreamHTTPVersions(conf.UseHTTP3Upstreams),
PreferIPv6: conf.BootstrapPreferIPv6,
},
)
if err != nil {
// Should not happen because upstreams are already validated. See
// [Persistent.validate].
panic(err)
}
c.upstreamConfig = proxy.NewCustomUpstreamConfig(
upsConf,
c.UpstreamsCacheEnabled,
int(c.UpstreamsCacheSize),
conf.EDNSClientSubnetEnabled,
)
c.customUpstreamConf = conf.Clone()
}
// subnetCompare is a comparison function for the two subnets. It returns -1 if
// x sorts before y, 1 if x sorts after y, and 0 if their relative sorting
// position is the same.
@ -315,8 +425,8 @@ func (c *Persistent) ShallowClone() (clone *Persistent) {
// CloseUpstreams closes the client-specific upstream config of c if any.
func (c *Persistent) CloseUpstreams() (err error) {
if c.UpstreamConfig != nil {
if err = c.UpstreamConfig.Close(); err != nil {
if c.upstreamConfig != nil {
if err = c.upstreamConfig.Close(); err != nil {
return fmt.Errorf("closing upstreams of client %q: %w", c.Name, err)
}
}

View File

@ -29,19 +29,6 @@ import (
"github.com/ameshkov/dnscrypt/v2"
)
// ClientsContainer provides information about preconfigured DNS clients.
type ClientsContainer interface {
// UpstreamConfigByID returns the custom upstream configuration for the
// client having id, using boot to initialize the one if necessary. It
// returns nil if there is no custom upstream configuration for the client.
// The id is expected to be either a string representation of an IP address
// or the ClientID.
UpstreamConfigByID(
id string,
boot upstream.Resolver,
) (conf *proxy.CustomUpstreamConfig, err error)
}
// Config represents the DNS filtering configuration of AdGuard Home. The zero
// Config is empty and ready for use.
type Config struct {
@ -50,9 +37,8 @@ type Config struct {
// FilterHandler is an optional additional filtering callback.
FilterHandler func(cliAddr netip.Addr, clientID string, settings *filtering.Settings) `yaml:"-"`
// ClientsContainer stores the information about special handling of some
// DNS clients.
ClientsContainer ClientsContainer `yaml:"-"`
// ClientStorage stores information about persistent clients.
ClientStorage *client.Storage `yaml:"-"`
// Anti-DNS amplification

View File

@ -682,75 +682,6 @@ func TestBlockedRequest(t *testing.T) {
assert.True(t, reply.Answer[0].(*dns.A).A.IsUnspecified())
}
func TestServerCustomClientUpstream(t *testing.T) {
const defaultCacheSize = 1024 * 1024
var upsCalledCounter uint32
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
CacheSize: defaultCacheSize,
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
},
ServePlainDNS: true,
}
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, forwardConf)
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
atomic.AddUint32(&upsCalledCounter, 1)
return cmp.Or(
aghtest.MatchedResponse(req, dns.TypeA, "host", "192.168.0.1"),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil
})
customUpsConf := proxy.NewCustomUpstreamConfig(
&proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ups},
},
true,
defaultCacheSize,
forwardConf.EDNSClientSubnet.Enabled,
)
s.conf.ClientsContainer = &aghtest.ClientsContainer{
OnUpstreamConfigByID: func(
_ string,
_ upstream.Resolver,
) (conf *proxy.CustomUpstreamConfig, err error) {
return customUpsConf, nil
},
}
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
// Send test request.
req := createTestMessage("host.")
reply, err := dns.Exchange(req, addr)
require.NoError(t, err)
require.NotEmpty(t, reply.Answer)
require.Len(t, reply.Answer, 1)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
assert.Equal(t, net.IP{192, 168, 0, 1}, reply.Answer[0].(*dns.A).A)
assert.Equal(t, uint32(1), atomic.LoadUint32(&upsCalledCounter))
_, err = dns.Exchange(req, addr)
require.NoError(t, err)
assert.Equal(t, uint32(1), atomic.LoadUint32(&upsCalledCounter))
}
// testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work.
var testCNAMEs = map[string][]string{
"badhost.": {"NULL.example.org."},

View File

@ -9,6 +9,7 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
@ -573,19 +574,29 @@ func (s *Server) dhcpHostFromRequest(q *dns.Question) (reqHost string) {
// setCustomUpstream sets custom upstream settings in pctx, if necessary.
func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
if !pctx.Addr.IsValid() || s.conf.ClientsContainer == nil {
if !pctx.Addr.IsValid() || s.conf.ClientStorage == nil {
return
}
// Use the ClientID first, since it has a higher priority.
id := cmp.Or(clientID, pctx.Addr.Addr().String())
upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap)
if err != nil {
log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err)
c, ok := s.conf.ClientStorage.Find(id)
if !ok {
return
}
s.serverLock.RLock()
defer s.serverLock.RUnlock()
upsConf := c.UpstreamConfig(&client.UpstreamConfig{
Bootstrap: s.bootstrap,
BootstrapAddrs: s.conf.BootstrapDNS,
UpstreamTimeout: s.conf.UpstreamTimeout,
BootstrapPreferIPv6: s.conf.BootstrapPreferIPv6,
EDNSClientSubnetEnabled: s.conf.EDNSClientSubnet.Enabled,
UseHTTP3Upstreams: s.conf.UseHTTP3Upstreams,
})
if upsConf != nil {
log.Debug("dnsforward: using custom upstreams for client %s", id)

View File

@ -12,17 +12,13 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/stringutil"
)
// clientsContainer is the storage of all runtime and persistent clients.
@ -370,63 +366,6 @@ func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) {
return true
}
// type check
var _ dnsforward.ClientsContainer = (*clientsContainer)(nil)
// UpstreamConfigByID implements the [dnsforward.ClientsContainer] interface for
// *clientsContainer. upsConf is nil if the client isn't found or if the client
// has no custom upstreams.
func (clients *clientsContainer) UpstreamConfigByID(
id string,
bootstrap upstream.Resolver,
) (conf *proxy.CustomUpstreamConfig, err error) {
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok := clients.storage.Find(id)
if !ok {
return nil, nil
} else if c.UpstreamConfig != nil {
return c.UpstreamConfig, nil
}
upstreams := stringutil.FilterOut(c.Upstreams, dnsforward.IsCommentOrEmpty)
if len(upstreams) == 0 {
return nil, nil
}
var upsConf *proxy.UpstreamConfig
upsConf, err = proxy.ParseUpstreamsConfig(
upstreams,
&upstream.Options{
Bootstrap: bootstrap,
Timeout: time.Duration(config.DNS.UpstreamTimeout),
HTTPVersions: dnsforward.UpstreamHTTPVersions(config.DNS.UseHTTP3Upstreams),
PreferIPv6: config.DNS.BootstrapPreferIPv6,
},
)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
conf = proxy.NewCustomUpstreamConfig(
upsConf,
c.UpstreamsCacheEnabled,
int(c.UpstreamsCacheSize),
config.DNS.EDNSClientSubnet.Enabled,
)
c.UpstreamConfig = conf
// TODO(s.chzhen): Pass context.
err = clients.storage.Update(context.TODO(), c.Name, c)
if err != nil {
return nil, fmt.Errorf("setting upstream config: %w", err)
}
return conf, nil
}
// type check
var _ client.AddressUpdater = (*clientsContainer)(nil)

View File

@ -1,15 +1,12 @@
package home
import (
"net"
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -37,28 +34,3 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
return c
}
func TestClientsCustomUpstream(t *testing.T) {
clients := newClientsContainer(t)
ctx := testutil.ContextWithTimeout(t, testTimeout)
// Add client with upstreams.
err := clients.storage.Add(ctx, &client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},
Upstreams: []string{
"1.1.1.1",
"[/example.org/]8.8.8.8",
},
})
require.NoError(t, err)
upsConf, err := clients.UpstreamConfigByID("1.2.3.4", net.DefaultResolver)
assert.Nil(t, upsConf)
assert.NoError(t, err)
upsConf, err = clients.UpstreamConfigByID("1.1.1.1", net.DefaultResolver)
require.NotNil(t, upsConf)
assert.NoError(t, err)
}

View File

@ -235,7 +235,7 @@ func newServerConfig(
fwdConf := dnsConf.Config
fwdConf.FilterHandler = applyAdditionalFiltering
fwdConf.ClientsContainer = &Context.clients
fwdConf.ClientStorage = Context.clients.storage
newConf = &dnsforward.ServerConfig{
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),