mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2025-02-20 11:44:09 +08:00
all: client custom upstream config
This commit is contained in:
parent
2fe2d254b5
commit
d9dbaf0c9b
@ -3,7 +3,6 @@ package aghtest_test
|
||||
import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
)
|
||||
|
||||
@ -12,9 +11,6 @@ import (
|
||||
// type check
|
||||
var _ filtering.Resolver = (*aghtest.Resolver)(nil)
|
||||
|
||||
// type check
|
||||
var _ dnsforward.ClientsContainer = (*aghtest.ClientsContainer)(nil)
|
||||
|
||||
// type check
|
||||
//
|
||||
// TODO(s.chzhen): It's here to avoid the import cycle. Remove it.
|
||||
|
@ -9,6 +9,8 @@ import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
@ -16,6 +18,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
@ -58,11 +61,16 @@ func (uid *UID) UnmarshalText(data []byte) error {
|
||||
|
||||
// Persistent contains information about persistent clients.
|
||||
type Persistent struct {
|
||||
// UpstreamConfig is the custom upstream configuration for this client. If
|
||||
// mu protects upstreamConfig, customUpstreamConf.
|
||||
mu *sync.Mutex
|
||||
|
||||
// upstreamConfig is the custom upstream configuration for this client. If
|
||||
// it's nil, it has not been initialized yet. If it's non-nil and empty,
|
||||
// there are no valid upstreams. If it's non-nil and non-empty, these
|
||||
// upstream must be used.
|
||||
UpstreamConfig *proxy.CustomUpstreamConfig
|
||||
upstreamConfig *proxy.CustomUpstreamConfig
|
||||
|
||||
customUpstreamConf *UpstreamConfig
|
||||
|
||||
// SafeSearch handles search engine hosts rewrites.
|
||||
SafeSearch filtering.SafeSearch
|
||||
@ -166,6 +174,8 @@ func (c *Persistent) validate(ctx context.Context, l *slog.Logger, allTags []str
|
||||
// TODO(s.chzhen): Move to the constructor.
|
||||
slices.Sort(c.Tags)
|
||||
|
||||
c.mu = &sync.Mutex{}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -189,6 +199,106 @@ func (c *Persistent) SetIDs(ids []string) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpstreamConfig is a custom upstream configuration for the persistent client.
|
||||
type UpstreamConfig struct {
|
||||
Bootstrap upstream.Resolver
|
||||
BootstrapAddrs []string
|
||||
UpstreamTimeout time.Duration
|
||||
BootstrapPreferIPv6 bool
|
||||
EDNSClientSubnetEnabled bool
|
||||
UseHTTP3Upstreams bool
|
||||
}
|
||||
|
||||
func (conf *UpstreamConfig) Clone() (clone *UpstreamConfig) {
|
||||
c := *conf
|
||||
c.BootstrapAddrs = slices.Clone(conf.BootstrapAddrs)
|
||||
|
||||
return &c
|
||||
}
|
||||
|
||||
func equalUpstreamSettings(a, b *UpstreamConfig) (ok bool) {
|
||||
if a == nil {
|
||||
return b == nil
|
||||
} else if b == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !slices.Equal(a.BootstrapAddrs, b.BootstrapAddrs) {
|
||||
return false
|
||||
}
|
||||
|
||||
if a.UpstreamTimeout != b.UpstreamTimeout {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// isCommentOrEmpty returns true if s starts with a "#" character or is empty.
|
||||
// This function is useful for filtering out non-upstream lines from upstream
|
||||
// configs.
|
||||
func isCommentOrEmpty(s string) (ok bool) {
|
||||
return len(s) == 0 || s[0] == '#'
|
||||
}
|
||||
|
||||
// upstreamHTTPVersions returns the HTTP versions for upstream configuration
|
||||
// depending on configuration.
|
||||
func upstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
|
||||
if !http3 {
|
||||
return upstream.DefaultHTTPVersions
|
||||
}
|
||||
|
||||
return []upstream.HTTPVersion{
|
||||
upstream.HTTPVersion3,
|
||||
upstream.HTTPVersion2,
|
||||
upstream.HTTPVersion11,
|
||||
}
|
||||
}
|
||||
|
||||
// UpstreamConfig returns the custom upstream configuration for the client.
|
||||
func (c *Persistent) UpstreamConfig(upsConf *UpstreamConfig) (prxConf *proxy.CustomUpstreamConfig) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !equalUpstreamSettings(c.customUpstreamConf, upsConf) {
|
||||
c.setUpstreamConfig(upsConf)
|
||||
}
|
||||
|
||||
return c.upstreamConfig
|
||||
}
|
||||
|
||||
// setUpstreamConfig sets the custom upstream configuration for the client.
|
||||
func (c *Persistent) setUpstreamConfig(conf *UpstreamConfig) {
|
||||
upstreams := stringutil.FilterOut(c.Upstreams, isCommentOrEmpty)
|
||||
if len(upstreams) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
upsConf, err := proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
&upstream.Options{
|
||||
Bootstrap: conf.Bootstrap,
|
||||
Timeout: time.Duration(conf.UpstreamTimeout),
|
||||
HTTPVersions: upstreamHTTPVersions(conf.UseHTTP3Upstreams),
|
||||
PreferIPv6: conf.BootstrapPreferIPv6,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
// Should not happen because upstreams are already validated. See
|
||||
// [Persistent.validate].
|
||||
panic(err)
|
||||
}
|
||||
|
||||
c.upstreamConfig = proxy.NewCustomUpstreamConfig(
|
||||
upsConf,
|
||||
c.UpstreamsCacheEnabled,
|
||||
int(c.UpstreamsCacheSize),
|
||||
conf.EDNSClientSubnetEnabled,
|
||||
)
|
||||
|
||||
c.customUpstreamConf = conf.Clone()
|
||||
}
|
||||
|
||||
// subnetCompare is a comparison function for the two subnets. It returns -1 if
|
||||
// x sorts before y, 1 if x sorts after y, and 0 if their relative sorting
|
||||
// position is the same.
|
||||
@ -315,8 +425,8 @@ func (c *Persistent) ShallowClone() (clone *Persistent) {
|
||||
|
||||
// CloseUpstreams closes the client-specific upstream config of c if any.
|
||||
func (c *Persistent) CloseUpstreams() (err error) {
|
||||
if c.UpstreamConfig != nil {
|
||||
if err = c.UpstreamConfig.Close(); err != nil {
|
||||
if c.upstreamConfig != nil {
|
||||
if err = c.upstreamConfig.Close(); err != nil {
|
||||
return fmt.Errorf("closing upstreams of client %q: %w", c.Name, err)
|
||||
}
|
||||
}
|
||||
|
@ -29,19 +29,6 @@ import (
|
||||
"github.com/ameshkov/dnscrypt/v2"
|
||||
)
|
||||
|
||||
// ClientsContainer provides information about preconfigured DNS clients.
|
||||
type ClientsContainer interface {
|
||||
// UpstreamConfigByID returns the custom upstream configuration for the
|
||||
// client having id, using boot to initialize the one if necessary. It
|
||||
// returns nil if there is no custom upstream configuration for the client.
|
||||
// The id is expected to be either a string representation of an IP address
|
||||
// or the ClientID.
|
||||
UpstreamConfigByID(
|
||||
id string,
|
||||
boot upstream.Resolver,
|
||||
) (conf *proxy.CustomUpstreamConfig, err error)
|
||||
}
|
||||
|
||||
// Config represents the DNS filtering configuration of AdGuard Home. The zero
|
||||
// Config is empty and ready for use.
|
||||
type Config struct {
|
||||
@ -50,9 +37,8 @@ type Config struct {
|
||||
// FilterHandler is an optional additional filtering callback.
|
||||
FilterHandler func(cliAddr netip.Addr, clientID string, settings *filtering.Settings) `yaml:"-"`
|
||||
|
||||
// ClientsContainer stores the information about special handling of some
|
||||
// DNS clients.
|
||||
ClientsContainer ClientsContainer `yaml:"-"`
|
||||
// ClientStorage stores information about persistent clients.
|
||||
ClientStorage *client.Storage `yaml:"-"`
|
||||
|
||||
// Anti-DNS amplification
|
||||
|
||||
|
@ -682,75 +682,6 @@ func TestBlockedRequest(t *testing.T) {
|
||||
assert.True(t, reply.Answer[0].(*dns.A).A.IsUnspecified())
|
||||
}
|
||||
|
||||
func TestServerCustomClientUpstream(t *testing.T) {
|
||||
const defaultCacheSize = 1024 * 1024
|
||||
|
||||
var upsCalledCounter uint32
|
||||
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
Config: Config{
|
||||
CacheSize: defaultCacheSize,
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, forwardConf)
|
||||
|
||||
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
atomic.AddUint32(&upsCalledCounter, 1)
|
||||
|
||||
return cmp.Or(
|
||||
aghtest.MatchedResponse(req, dns.TypeA, "host", "192.168.0.1"),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
), nil
|
||||
})
|
||||
|
||||
customUpsConf := proxy.NewCustomUpstreamConfig(
|
||||
&proxy.UpstreamConfig{
|
||||
Upstreams: []upstream.Upstream{ups},
|
||||
},
|
||||
true,
|
||||
defaultCacheSize,
|
||||
forwardConf.EDNSClientSubnet.Enabled,
|
||||
)
|
||||
|
||||
s.conf.ClientsContainer = &aghtest.ClientsContainer{
|
||||
OnUpstreamConfigByID: func(
|
||||
_ string,
|
||||
_ upstream.Resolver,
|
||||
) (conf *proxy.CustomUpstreamConfig, err error) {
|
||||
return customUpsConf, nil
|
||||
},
|
||||
}
|
||||
|
||||
startDeferStop(t, s)
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
||||
|
||||
// Send test request.
|
||||
req := createTestMessage("host.")
|
||||
|
||||
reply, err := dns.Exchange(req, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, reply.Answer)
|
||||
require.Len(t, reply.Answer, 1)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||
assert.Equal(t, net.IP{192, 168, 0, 1}, reply.Answer[0].(*dns.A).A)
|
||||
assert.Equal(t, uint32(1), atomic.LoadUint32(&upsCalledCounter))
|
||||
|
||||
_, err = dns.Exchange(req, addr)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(1), atomic.LoadUint32(&upsCalledCounter))
|
||||
}
|
||||
|
||||
// testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work.
|
||||
var testCNAMEs = map[string][]string{
|
||||
"badhost.": {"NULL.example.org."},
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
@ -573,19 +574,29 @@ func (s *Server) dhcpHostFromRequest(q *dns.Question) (reqHost string) {
|
||||
|
||||
// setCustomUpstream sets custom upstream settings in pctx, if necessary.
|
||||
func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
|
||||
if !pctx.Addr.IsValid() || s.conf.ClientsContainer == nil {
|
||||
if !pctx.Addr.IsValid() || s.conf.ClientStorage == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Use the ClientID first, since it has a higher priority.
|
||||
id := cmp.Or(clientID, pctx.Addr.Addr().String())
|
||||
upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap)
|
||||
if err != nil {
|
||||
log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err)
|
||||
|
||||
c, ok := s.conf.ClientStorage.Find(id)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
upsConf := c.UpstreamConfig(&client.UpstreamConfig{
|
||||
Bootstrap: s.bootstrap,
|
||||
BootstrapAddrs: s.conf.BootstrapDNS,
|
||||
UpstreamTimeout: s.conf.UpstreamTimeout,
|
||||
BootstrapPreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
EDNSClientSubnetEnabled: s.conf.EDNSClientSubnet.Enabled,
|
||||
UseHTTP3Upstreams: s.conf.UseHTTP3Upstreams,
|
||||
})
|
||||
|
||||
if upsConf != nil {
|
||||
log.Debug("dnsforward: using custom upstreams for client %s", id)
|
||||
|
||||
|
@ -12,17 +12,13 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
|
||||
// clientsContainer is the storage of all runtime and persistent clients.
|
||||
@ -370,63 +366,6 @@ func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ dnsforward.ClientsContainer = (*clientsContainer)(nil)
|
||||
|
||||
// UpstreamConfigByID implements the [dnsforward.ClientsContainer] interface for
|
||||
// *clientsContainer. upsConf is nil if the client isn't found or if the client
|
||||
// has no custom upstreams.
|
||||
func (clients *clientsContainer) UpstreamConfigByID(
|
||||
id string,
|
||||
bootstrap upstream.Resolver,
|
||||
) (conf *proxy.CustomUpstreamConfig, err error) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
c, ok := clients.storage.Find(id)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
} else if c.UpstreamConfig != nil {
|
||||
return c.UpstreamConfig, nil
|
||||
}
|
||||
|
||||
upstreams := stringutil.FilterOut(c.Upstreams, dnsforward.IsCommentOrEmpty)
|
||||
if len(upstreams) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var upsConf *proxy.UpstreamConfig
|
||||
upsConf, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
&upstream.Options{
|
||||
Bootstrap: bootstrap,
|
||||
Timeout: time.Duration(config.DNS.UpstreamTimeout),
|
||||
HTTPVersions: dnsforward.UpstreamHTTPVersions(config.DNS.UseHTTP3Upstreams),
|
||||
PreferIPv6: config.DNS.BootstrapPreferIPv6,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conf = proxy.NewCustomUpstreamConfig(
|
||||
upsConf,
|
||||
c.UpstreamsCacheEnabled,
|
||||
int(c.UpstreamsCacheSize),
|
||||
config.DNS.EDNSClientSubnet.Enabled,
|
||||
)
|
||||
c.UpstreamConfig = conf
|
||||
|
||||
// TODO(s.chzhen): Pass context.
|
||||
err = clients.storage.Update(context.TODO(), c.Name, c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting upstream config: %w", err)
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ client.AddressUpdater = (*clientsContainer)(nil)
|
||||
|
||||
|
@ -1,15 +1,12 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -37,28 +34,3 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func TestClientsCustomUpstream(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
// Add client with upstreams.
|
||||
err := clients.storage.Add(ctx, &client.Persistent{
|
||||
Name: "client1",
|
||||
UID: client.MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},
|
||||
Upstreams: []string{
|
||||
"1.1.1.1",
|
||||
"[/example.org/]8.8.8.8",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
upsConf, err := clients.UpstreamConfigByID("1.2.3.4", net.DefaultResolver)
|
||||
assert.Nil(t, upsConf)
|
||||
assert.NoError(t, err)
|
||||
|
||||
upsConf, err = clients.UpstreamConfigByID("1.1.1.1", net.DefaultResolver)
|
||||
require.NotNil(t, upsConf)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
@ -235,7 +235,7 @@ func newServerConfig(
|
||||
|
||||
fwdConf := dnsConf.Config
|
||||
fwdConf.FilterHandler = applyAdditionalFiltering
|
||||
fwdConf.ClientsContainer = &Context.clients
|
||||
fwdConf.ClientStorage = Context.clients.storage
|
||||
|
||||
newConf = &dnsforward.ServerConfig{
|
||||
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
|
||||
|
Loading…
x
Reference in New Issue
Block a user