all: client upstream manager

This commit is contained in:
Stanislav Chzhen 2025-02-19 15:58:11 +03:00
parent a5b073d070
commit 1393417663
8 changed files with 221 additions and 75 deletions

View File

@ -0,0 +1,26 @@
package aghnet
import "github.com/AdguardTeam/dnsproxy/upstream"
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
// depending on configuration.
//
// TODO(s.chzhen): !! Use in the dnsforward package.
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
if !http3 {
return upstream.DefaultHTTPVersions
}
return []upstream.HTTPVersion{
upstream.HTTPVersion3,
upstream.HTTPVersion2,
upstream.HTTPVersion11,
}
}
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
// This function is useful for filtering out non-upstream lines from upstream
// configs.
func IsCommentOrEmpty(s string) (ok bool) {
return len(s) == 0 || s[0] == '#'
}

View File

@ -13,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/logutil/slogutil"
@ -126,6 +127,8 @@ type Storage struct {
// runtimeIndex contains information about runtime clients.
runtimeIndex *runtimeIndex
upstreamManager *upstreamManager
// dhcp is used to update [SourceDHCP] runtime client information.
dhcp DHCP
@ -163,6 +166,7 @@ func NewStorage(ctx context.Context, conf *StorageConfig) (s *Storage, err error
mu: &sync.Mutex{},
index: newIndex(),
runtimeIndex: newRuntimeIndex(),
upstreamManager: newUpstreamManager(),
dhcp: conf.DHCP,
etcHosts: conf.EtcHosts,
arpDB: conf.ARPDB,
@ -626,3 +630,38 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
func (s *Storage) AllowedTags() (tags []string) {
return s.allowedTags
}
func (s *Storage) UpstreamConfigByID(ids []string) (prxConf *proxy.CustomUpstreamConfig) {
s.mu.Lock()
defer s.mu.Unlock()
var c *Persistent
var ok bool
for _, id := range ids {
c, ok = s.index.find(id)
if ok {
break
}
}
if !ok {
return nil
}
return s.upstreamManager.customUpstreamConfig(c)
}
func (s *Storage) LatestUpstreamConfigUpdate() (t time.Time) {
s.mu.Lock()
defer s.mu.Unlock()
return s.upstreamManager.latestConfigUpdate()
}
func (s *Storage) UpdateUpstreamConfig(conf *UpstreamConfig) {
s.mu.Lock()
defer s.mu.Unlock()
s.upstreamManager.updateConfig(conf)
}

View File

@ -0,0 +1,122 @@
package client
import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/stringutil"
)
// TODO(s.chzhen): !! Improve documentation, naming.
type UpstreamConfig struct {
Bootstrap upstream.Resolver
LatestUpstreamConfUpdate time.Time
UpstreamTimeout time.Duration
BootstrapPreferIPv6 bool
EDNSClientSubnetEnabled bool
UseHTTP3Upstreams bool
}
type clientUpstreamConfig struct {
// TODO(s.chzhen): !! Store a list of upstreams and cache settings instead.
client *Persistent
prxConf *proxy.CustomUpstreamConfig
}
type upstreamManager struct {
uidToClientConf map[UID]*clientUpstreamConfig
conf *UpstreamConfig
}
func newUpstreamManager() (m *upstreamManager) {
return &upstreamManager{
uidToClientConf: make(map[UID]*clientUpstreamConfig),
}
}
func (m *upstreamManager) latestConfigUpdate() (t time.Time) {
if m.conf == nil {
return time.Time{}
}
return m.conf.LatestUpstreamConfUpdate
}
func (m *upstreamManager) updateConfig(conf *UpstreamConfig) {
m.conf = conf
for uid, c := range m.uidToClientConf {
prxConf := newCustomUpstreamConfig(c.client, m.conf)
m.uidToClientConf[uid] = &clientUpstreamConfig{
client: c.client,
prxConf: prxConf,
}
}
}
func (m *upstreamManager) customUpstreamConfig(
c *Persistent,
) (prxConf *proxy.CustomUpstreamConfig) {
cliConf, ok := m.uidToClientConf[c.UID]
if ok {
return cliConf.prxConf
}
prxConf = newCustomUpstreamConfig(c, m.conf)
m.uidToClientConf[c.UID] = &clientUpstreamConfig{
client: c,
prxConf: prxConf,
}
return prxConf
}
// TODO(s.chzhen): !! Use it.
func (m *upstreamManager) clearCache() {
for _, c := range m.uidToClientConf {
c.prxConf.ClearCache()
}
}
// TODO(s.chzhen): !! Use it.
func (m *upstreamManager) close() {
for _, c := range m.uidToClientConf {
c.prxConf.Close()
}
}
// newCustomUpstreamConfig returns the new properly initialized custom proxy
// upstream configuration for the client.
func newCustomUpstreamConfig(
c *Persistent,
conf *UpstreamConfig,
) (prxConf *proxy.CustomUpstreamConfig) {
upstreams := stringutil.FilterOut(c.Upstreams, aghnet.IsCommentOrEmpty)
if len(upstreams) == 0 {
return nil
}
upsConf, err := proxy.ParseUpstreamsConfig(
upstreams,
&upstream.Options{
Bootstrap: conf.Bootstrap,
Timeout: time.Duration(conf.UpstreamTimeout),
HTTPVersions: aghnet.UpstreamHTTPVersions(conf.UseHTTP3Upstreams),
PreferIPv6: conf.BootstrapPreferIPv6,
},
)
if err != nil {
// Should not happen because upstreams are already validated. See
// [Persistent.validate].
panic(err)
}
return proxy.NewCustomUpstreamConfig(
upsConf,
c.UpstreamsCacheEnabled,
int(c.UpstreamsCacheSize),
conf.EDNSClientSubnetEnabled,
)
}

View File

@ -36,10 +36,11 @@ type ClientsContainer interface {
// returns nil if there is no custom upstream configuration for the client.
// The id is expected to be either a string representation of an IP address
// or the ClientID.
UpstreamConfigByID(
id string,
boot upstream.Resolver,
) (conf *proxy.CustomUpstreamConfig, err error)
UpstreamConfigByID(ids []string) (conf *proxy.CustomUpstreamConfig)
LatestUpstreamConfigUpdate() (t time.Time)
UpdateUpstreamConfig(conf *client.UpstreamConfig)
}
// Config represents the DNS filtering configuration of AdGuard Home. The zero

View File

@ -185,6 +185,8 @@ type Server struct {
// conf is the current configuration of the server.
conf ServerConfig
latestUpstreamConfUpdate time.Time
// serverLock protects Server.
serverLock sync.RWMutex
}
@ -558,6 +560,8 @@ func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
s.conf.UpstreamConfig = uc
s.latestUpstreamConfUpdate = time.Now()
return nil
}

View File

@ -1,7 +1,6 @@
package dnsforward
import (
"cmp"
"context"
"encoding/binary"
"net"
@ -9,6 +8,7 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
@ -577,17 +577,32 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
return
}
// Use the ClientID first, since it has a higher priority.
id := cmp.Or(clientID, pctx.Addr.Addr().String())
upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap)
if err != nil {
log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err)
func() {
t := s.conf.ClientsContainer.LatestUpstreamConfigUpdate()
return
}
s.serverLock.RLock()
defer s.serverLock.RUnlock()
if t == s.latestUpstreamConfUpdate {
return
}
s.conf.ClientsContainer.UpdateUpstreamConfig(&client.UpstreamConfig{
Bootstrap: s.bootstrap,
LatestUpstreamConfUpdate: s.latestUpstreamConfUpdate,
UpstreamTimeout: s.conf.UpstreamTimeout,
BootstrapPreferIPv6: s.conf.BootstrapPreferIPv6,
EDNSClientSubnetEnabled: s.conf.EDNSClientSubnet.Enabled,
UseHTTP3Upstreams: s.conf.UseHTTP3Upstreams,
})
}()
// Use the ClientID first, since it has a higher priority.
ids := []string{clientID, pctx.Addr.Addr().String()}
upsConf := s.conf.ClientsContainer.UpstreamConfigByID(ids)
if upsConf != nil {
log.Debug("dnsforward: using custom upstreams for client %s", id)
log.Debug("dnsforward: using custom upstreams for client %s", ids)
pctx.CustomUpstreamConfig = upsConf
}

View File

@ -12,17 +12,13 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/stringutil"
)
// clientsContainer is the storage of all runtime and persistent clients.
@ -370,63 +366,6 @@ func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) {
return true
}
// type check
var _ dnsforward.ClientsContainer = (*clientsContainer)(nil)
// UpstreamConfigByID implements the [dnsforward.ClientsContainer] interface for
// *clientsContainer. upsConf is nil if the client isn't found or if the client
// has no custom upstreams.
func (clients *clientsContainer) UpstreamConfigByID(
id string,
bootstrap upstream.Resolver,
) (conf *proxy.CustomUpstreamConfig, err error) {
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok := clients.storage.Find(id)
if !ok {
return nil, nil
} else if c.UpstreamConfig != nil {
return c.UpstreamConfig, nil
}
upstreams := stringutil.FilterOut(c.Upstreams, dnsforward.IsCommentOrEmpty)
if len(upstreams) == 0 {
return nil, nil
}
var upsConf *proxy.UpstreamConfig
upsConf, err = proxy.ParseUpstreamsConfig(
upstreams,
&upstream.Options{
Bootstrap: bootstrap,
Timeout: time.Duration(config.DNS.UpstreamTimeout),
HTTPVersions: dnsforward.UpstreamHTTPVersions(config.DNS.UseHTTP3Upstreams),
PreferIPv6: config.DNS.BootstrapPreferIPv6,
},
)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
conf = proxy.NewCustomUpstreamConfig(
upsConf,
c.UpstreamsCacheEnabled,
int(c.UpstreamsCacheSize),
config.DNS.EDNSClientSubnet.Enabled,
)
c.UpstreamConfig = conf
// TODO(s.chzhen): Pass context.
err = clients.storage.Update(context.TODO(), c.Name, c)
if err != nil {
return nil, fmt.Errorf("setting upstream config: %w", err)
}
return conf, nil
}
// type check
var _ client.AddressUpdater = (*clientsContainer)(nil)

View File

@ -235,7 +235,7 @@ func newServerConfig(
fwdConf := dnsConf.Config
fwdConf.FilterHandler = applyAdditionalFiltering
fwdConf.ClientsContainer = &Context.clients
fwdConf.ClientsContainer = Context.clients.storage
newConf = &dnsforward.ServerConfig{
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),