AdGuardDNS/internal/ecscache/ecscache_test.go
Andrey Meshkov 87137bddcf Sync v2.10.0
2024-11-08 16:26:22 +03:00

694 lines
17 KiB
Go

package ecscache_test
import (
"context"
"fmt"
"net"
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdcache"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/AdGuardDNS/internal/ecscache"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Common test domain names.
const (
reqHostname = "example.com."
reqCNAME = "cname.example.com."
reqNS1 = "ns1.example.com."
reqNS2 = "ns2.example.com."
)
// defaultTTL is the default TTL to use in tests.
const defaultTTL uint32 = 3600
// remoteIP is the IP to use for tests.
var remoteIP = netip.MustParseAddr("1.2.3.4")
func TestMiddleware_Wrap_noECS(t *testing.T) {
aReq := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
cnameReq := dnsservertest.NewReq(reqHostname, dns.TypeCNAME, dns.ClassINET)
cnameAns := dnsservertest.SectionAnswer{
dnsservertest.NewCNAME(reqHostname, defaultTTL, reqCNAME),
}
soaNS := dnsservertest.SectionNs{
dnsservertest.NewSOA(reqHostname, defaultTTL, reqNS1, reqNS2),
}
knownIP := netip.MustParseAddr("1.2.3.4")
testTTL := 60 * time.Second
const N = 5
testCases := []struct {
req *dns.Msg
resp *dns.Msg
minTTL *time.Duration
name string
wantNumReq int
wantTTL uint32
}{{
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
dnsservertest.NewA(reqHostname, defaultTTL, knownIP),
}),
name: "simple_a",
wantNumReq: 1,
wantTTL: defaultTTL,
minTTL: nil,
}, {
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq),
name: "empty_answer",
wantNumReq: N,
wantTTL: 0,
minTTL: nil,
}, {
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, soaNS),
name: "authoritative_nodata",
wantNumReq: 1,
wantTTL: defaultTTL,
minTTL: nil,
}, {
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, cnameAns, soaNS),
name: "nodata_with_cname",
wantNumReq: 1,
wantTTL: defaultTTL,
minTTL: nil,
}, {
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, cnameAns),
name: "nodata_with_cname_no_soa",
wantNumReq: N,
wantTTL: defaultTTL,
minTTL: nil,
}, {
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeNameError, aReq, dnsservertest.SectionNs{
dnsservertest.NewNS(reqHostname, defaultTTL, reqNS1),
}),
name: "non_authoritative_nxdomain",
// TODO(ameshkov): Consider https://datatracker.ietf.org/doc/html/rfc2308#section-3.
wantNumReq: 1,
wantTTL: 0,
minTTL: nil,
}, {
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeNameError, aReq, soaNS),
name: "authoritative_nxdomain",
wantNumReq: 1,
wantTTL: defaultTTL,
minTTL: nil,
}, {
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeServerFailure, aReq),
name: "simple_server_failure",
wantNumReq: 1,
wantTTL: dnsmsg.ServFailMaxCacheTTL,
minTTL: nil,
}, {
req: cnameReq,
resp: dnsservertest.NewResp(dns.RcodeSuccess, cnameReq, dnsservertest.SectionAnswer{
dnsservertest.NewCNAME(reqHostname, defaultTTL, reqCNAME),
}),
name: "simple_cname_ans",
wantNumReq: 1,
wantTTL: defaultTTL,
minTTL: nil,
}, {
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
dnsservertest.NewA(reqHostname, 0, knownIP),
}),
name: "expired_one",
wantNumReq: N,
wantTTL: 0,
minTTL: nil,
}, {
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
dnsservertest.NewA(reqHostname, 10, knownIP),
}),
name: "override_ttl_ok",
wantNumReq: 1,
minTTL: &testTTL,
wantTTL: uint32(testTTL.Seconds()),
}, {
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
dnsservertest.NewA(reqHostname, 1000, knownIP),
}),
name: "override_ttl_max",
wantNumReq: 1,
minTTL: &testTTL,
wantTTL: 1000,
}, {
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
dnsservertest.NewA(reqHostname, 0, knownIP),
}),
name: "override_ttl_zero",
wantNumReq: N,
minTTL: &testTTL,
wantTTL: 0,
}, {
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeServerFailure, aReq, dnsservertest.SectionAnswer{
dnsservertest.NewA(reqHostname, dnsmsg.ServFailMaxCacheTTL, knownIP),
}),
name: "override_ttl_servfail",
wantNumReq: 1,
minTTL: nil,
wantTTL: dnsmsg.ServFailMaxCacheTTL,
}, {
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeNotImplemented, aReq, soaNS),
name: "unexpected_response",
wantNumReq: N,
wantTTL: 0,
minTTL: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
numReq := 0
handler := dnsserver.HandlerFunc(
func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) error {
numReq++
return rw.WriteMsg(ctx, req, tc.resp)
},
)
var minTTL time.Duration
if tc.minTTL != nil {
minTTL = *tc.minTTL
}
withCache := newWithCache(
t,
handler,
geoip.CountryNone,
netutil.ZeroPrefix(netutil.AddrFamilyIPv4),
minTTL,
tc.minTTL != nil,
)
ri := &agd.RequestInfo{
Host: tc.req.Question[0].Name,
RemoteIP: remoteIP,
}
var msg *dns.Msg
for range N {
msg = exchange(t, ri, withCache, tc.req)
}
assert.Equal(t, tc.resp, msg)
assert.Equal(t, tc.wantNumReq, numReq)
if len(msg.Answer) > 0 {
assert.Equal(t, tc.wantTTL, msg.Answer[0].Header().Ttl)
}
})
}
}
const prefixLen = 24
// newAReq returns new test A request with ECS option.
func newAReq(hostname string, ip net.IP) (req *dns.Msg) {
aReqNoECS := dnsservertest.NewReq(hostname, dns.TypeA, dns.ClassINET)
aReqNoECS.SetEdns0(dnsmsg.DefaultEDNSUDPSize, false)
req = aReqNoECS.Copy()
opt := req.Extra[len(req.Extra)-1].(*dns.OPT)
opt.Option = append(opt.Option, &dns.EDNS0_SUBNET{
Code: dns.EDNS0SUBNET,
Family: uint16(netutil.AddrFamilyIPv4),
SourceNetmask: prefixLen,
SourceScope: 0,
Address: ip,
})
return req
}
func TestMiddleware_Wrap_ecs(t *testing.T) {
aReqNoECS := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
ip := net.IP{1, 2, 3, 0}
aReq := newAReq(reqHostname, ip)
var fakeECSFQDN string
ecscache.FakeECSFQDNs.Range(func(s string) (cont bool) {
fakeECSFQDN = s
return false
})
require.NotZero(t, fakeECSFQDN)
fakeECSReq := newAReq(fakeECSFQDN, ip)
subnet := netip.PrefixFrom(netip.AddrFrom4([4]byte(ip)), prefixLen)
const ctry = geoip.CountryAD
defaultCtrySubnet := netip.MustParsePrefix("1.2.0.0/16")
ecsExtra := dnsservertest.NewECSExtra(
net.IP{1, 2, 0, 0},
uint16(netutil.AddrFamilyIPv4),
20,
20,
)
testCases := []struct {
req *dns.Msg
respECS dns.RR
wantECS *dnsmsg.ECS
ctrySubnet netip.Prefix
name string
}{{
req: aReq,
respECS: ecsExtra,
wantECS: &dnsmsg.ECS{
Location: &geoip.Location{
Country: ctry,
},
Subnet: subnet,
Scope: 0,
},
ctrySubnet: defaultCtrySubnet,
name: "with_country",
}, {
req: aReq,
respECS: ecsExtra,
wantECS: &dnsmsg.ECS{
Location: &geoip.Location{
Country: ctry,
},
Subnet: subnet,
Scope: 0,
},
ctrySubnet: netutil.ZeroPrefix(netutil.AddrFamilyIPv4),
name: "no_country",
}, {
req: aReqNoECS,
respECS: ecsExtra,
wantECS: nil,
ctrySubnet: defaultCtrySubnet,
name: "edns_no_ecs",
}, {
req: aReq,
respECS: ecsExtra,
wantECS: nil,
ctrySubnet: defaultCtrySubnet,
name: "country_from_ip",
}, {
req: aReq,
respECS: dnsservertest.NewECSExtra(
netutil.IPv4Zero(),
uint16(netutil.AddrFamilyIPv4),
0,
0,
),
wantECS: &dnsmsg.ECS{
Location: &geoip.Location{
Country: geoip.CountryNone,
},
Subnet: netutil.ZeroPrefix(netutil.AddrFamilyIPv4),
Scope: 0,
},
ctrySubnet: defaultCtrySubnet,
name: "zero_ecs",
}, {
req: fakeECSReq,
respECS: ecsExtra,
wantECS: &dnsmsg.ECS{
Location: &geoip.Location{
Country: ctry,
},
Subnet: netutil.ZeroPrefix(netutil.AddrFamilyIPv4),
Scope: 0,
},
ctrySubnet: defaultCtrySubnet,
name: "fake_ecs_domain",
}}
const N = 5
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
resp := dnsservertest.NewResp(
dns.RcodeSuccess,
aReq,
dnsservertest.SectionAnswer{dnsservertest.NewA(
reqHostname,
defaultTTL,
netip.MustParseAddr("1.2.3.4"),
)},
dnsservertest.SectionExtra{tc.respECS},
)
numReq := 0
handler := dnsserver.HandlerFunc(
func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) error {
numReq++
return rw.WriteMsg(ctx, req, resp)
},
)
withCache := newWithCache(t, handler, ctry, tc.ctrySubnet, 0, false)
ri := &agd.RequestInfo{
Location: &geoip.Location{
Country: ctry,
},
ECS: tc.wantECS,
Host: tc.req.Question[0].Name,
RemoteIP: remoteIP,
}
var msg *dns.Msg
for range N {
msg = exchange(t, ri, withCache, tc.req)
}
require.NotNil(t, msg)
assert.Equal(t, 1, numReq)
require.NotEmpty(t, msg.Answer)
assert.Equal(t, defaultTTL, msg.Answer[0].Header().Ttl)
assertEDNSOpt(t, tc.wantECS, msg.IsEdns0())
})
}
}
// assertEDNSOpt is a helper function that checks ECS and EDNS0 options.
func assertEDNSOpt(t *testing.T, ecs *dnsmsg.ECS, edns *dns.OPT) {
t.Helper()
if ecs == nil {
if edns != nil {
assert.Empty(t, edns.Option)
}
return
}
require.Len(t, edns.Option, 1)
subnetOpt := testutil.RequireTypeAssert[*dns.EDNS0_SUBNET](t, edns.Option[0])
assert.Equal(t, net.IP(ecs.Subnet.Addr().AsSlice()), subnetOpt.Address)
assert.Equal(t, uint8(ecs.Subnet.Bits()), subnetOpt.SourceNetmask)
assert.Equal(t, uint8(ecs.Subnet.Bits()), subnetOpt.SourceScope)
}
func TestMiddleware_Wrap_ecsOrder(t *testing.T) {
// Helper values and functions
reqNoECS := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
reqNoECS.SetEdns0(dnsmsg.DefaultEDNSUDPSize, false)
const prefixLen = 24
reqWithECS := reqNoECS.Copy()
opt := testutil.RequireTypeAssert[*dns.OPT](t, reqWithECS.Extra[len(reqWithECS.Extra)-1])
opt.Option = append(opt.Option, &dns.EDNS0_SUBNET{
Code: dns.EDNS0SUBNET,
Family: uint16(netutil.AddrFamilyIPv4),
SourceNetmask: prefixLen,
SourceScope: 0,
Address: netip.PrefixFrom(remoteIP, prefixLen).Masked().Addr().AsSlice(),
})
reqZeroECS := reqNoECS.Copy()
opt = testutil.RequireTypeAssert[*dns.OPT](t, reqZeroECS.Extra[len(reqZeroECS.Extra)-1])
opt.Option = append(opt.Option, &dns.EDNS0_SUBNET{
Code: dns.EDNS0SUBNET,
Family: uint16(netutil.AddrFamilyIPv4),
SourceNetmask: 0,
SourceScope: 0,
Address: netutil.IPv4Zero(),
})
const ctry = geoip.CountryAD
ctrySubnet := netip.PrefixFrom(remoteIP, 16).Masked()
ctryECS := dnsservertest.NewECSExtra(
ctrySubnet.Addr().AsSlice(),
uint16(netutil.AddrFamilyIPv4),
prefixLen,
16,
)
zeroECS := dnsservertest.NewECSExtra(netutil.IPv4Zero(), uint16(netutil.AddrFamilyIPv4), 0, 0)
answerA := dnsservertest.NewA(reqHostname, defaultTTL, netip.MustParseAddr("1.2.3.4"))
answerB := dnsservertest.NewA(reqHostname, defaultTTL, netip.MustParseAddr("5.6.7.8"))
// Tests
testCases := []struct {
name string
sequence sequence
}{{
name: "no_ecs_first",
sequence: sequence{{
answer: answerA,
extra: ctryECS,
msg: reqNoECS,
wantAns: []dns.RR{answerA},
}, {
answer: nil,
extra: nil,
msg: reqWithECS,
wantAns: []dns.RR{answerA},
}},
}, {
name: "ecs_first",
sequence: sequence{{
answer: answerA,
extra: ctryECS,
msg: reqWithECS,
wantAns: []dns.RR{answerA},
}, {
answer: nil,
extra: nil,
msg: reqNoECS,
wantAns: []dns.RR{answerA},
}},
}, {
name: "zero_after_no_ecs",
sequence: sequence{{
answer: answerA,
extra: ctryECS,
msg: reqNoECS,
wantAns: []dns.RR{answerA},
}, {
answer: answerB,
extra: zeroECS,
msg: reqZeroECS,
wantAns: []dns.RR{answerB},
}},
}, {
name: "different_caches",
sequence: sequence{{
answer: answerA,
extra: ctryECS,
msg: reqWithECS,
wantAns: []dns.RR{answerA},
}, {
answer: answerB,
extra: zeroECS,
msg: reqZeroECS,
wantAns: []dns.RR{answerB},
}, {
answer: nil,
extra: nil,
msg: reqWithECS,
wantAns: []dns.RR{answerA},
}, {
answer: nil,
extra: nil,
msg: reqZeroECS,
wantAns: []dns.RR{answerB},
}},
}, {
name: "no_ecs_upstream",
sequence: sequence{{
answer: answerA,
extra: zeroECS,
msg: reqZeroECS,
wantAns: []dns.RR{answerA},
}, {
answer: answerB,
extra: zeroECS,
msg: reqNoECS,
wantAns: []dns.RR{answerB},
}, {
answer: nil,
extra: nil,
msg: reqZeroECS,
wantAns: []dns.RR{answerA},
}, {
answer: nil,
extra: nil,
msg: reqNoECS,
wantAns: []dns.RR{answerB},
}},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.sequence.run(t, ctry, ctrySubnet)
})
}
}
// request is a single request in a sequence. answer and extra are
// prerequisites for configuring handler's response before resolving msg,
// those should be nil when the response is expected to come from cache.
type request = struct {
answer dns.RR
extra dns.RR
msg *dns.Msg
wantAns []dns.RR
}
// sequence is a list of requests.
type sequence []request
// run is a helper method for testing ECS cache middleware with sequence of
// ordered requests.
func (s sequence) run(t *testing.T, ctry geoip.Country, ctrySubnet netip.Prefix) {
t.Helper()
const respSendTimeout = 1 * time.Second
newResp := func(t *testing.T, req *dns.Msg, answer, extra dns.RR) (resp *dns.Msg) {
t.Helper()
return dnsservertest.NewResp(
dns.RcodeSuccess,
req,
dnsservertest.SectionAnswer{answer},
dnsservertest.SectionExtra{extra},
)
}
pt := testutil.PanicT{}
respCh := make(chan *dns.Msg, 1)
handler := dnsserver.HandlerFunc(
func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) error {
resp, ok := testutil.RequireReceive(pt, respCh, respSendTimeout)
require.True(pt, ok)
return rw.WriteMsg(ctx, req, resp)
},
)
withCache := newWithCache(t, handler, ctry, ctrySubnet, 0, false)
for i, req := range s {
if req.answer != nil && req.extra != nil {
resp := newResp(t, req.msg, req.answer, req.extra)
testutil.RequireSend(t, respCh, resp, respSendTimeout)
}
subnet, _, err := dnsmsg.ECSFromMsg(req.msg)
require.NoError(t, err)
ri := &agd.RequestInfo{
Location: &geoip.Location{Country: ctry},
ECS: nil,
Host: req.msg.Question[0].Name,
RemoteIP: remoteIP,
}
if subnet != (netip.Prefix{}) {
ri.ECS = &dnsmsg.ECS{Subnet: subnet, Scope: 0}
}
// Make sure each step succeeded.
require.True(t, t.Run(fmt.Sprintf("step_%d", i), func(t *testing.T) {
got := exchange(t, ri, withCache, req.msg)
assert.Equal(t, req.wantAns, got.Answer)
}))
}
}
// exchange resolves req with h using context with ri.
func exchange(
t testing.TB,
ri *agd.RequestInfo,
h dnsserver.Handler,
req *dns.Msg,
) (resp *dns.Msg) {
t.Helper()
// TODO(a.garipov): Propose netip.Addr.WithPort.
addr := &net.UDPAddr{IP: remoteIP.AsSlice(), Port: 53}
nrw := dnsserver.NewNonWriterResponseWriter(addr, addr)
ctx := agd.ContextWithRequestInfo(context.Background(), ri)
err := h.ServeDNS(ctx, nrw, req)
require.NoError(t, err)
msg := nrw.Msg()
require.NotNil(t, msg)
return msg
}
// newWithCache is a helper constructor of a handler for tests.
func newWithCache(
t testing.TB,
h dnsserver.Handler,
wantCtry geoip.Country,
geoIPNet netip.Prefix,
minTTL time.Duration,
useTTLOverride bool,
) (wrapped dnsserver.Handler) {
t.Helper()
pt := testutil.PanicT{}
geoIP := agdtest.NewGeoIP()
// TODO(a.garipov): Actually test ASNs once we have the data.
geoIP.OnSubnetByLocation = func(
l *geoip.Location,
_ netutil.AddrFamily,
) (n netip.Prefix, err error) {
require.Equal(pt, wantCtry, l.Country)
return geoIPNet, nil
}
return dnsserver.WithMiddlewares(
h,
ecscache.NewMiddleware(&ecscache.MiddlewareConfig{
Cloner: agdtest.NewCloner(),
Logger: slogutil.NewDiscardLogger(),
CacheManager: agdcache.EmptyManager{},
GeoIP: geoIP,
NoECSCount: 100,
ECSCount: 100,
MinTTL: minTTL,
OverrideTTL: useTTLOverride,
}),
)
}