AdGuardDNS/internal/bindtodevice/socket_linux_internal_test.go
Andrey Meshkov 5690301129 Sync v2.7.0
2024-06-07 14:27:46 +03:00

545 lines
14 KiB
Go

//go:build linux
package bindtodevice
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"net"
"net/netip"
"os"
"slices"
"strings"
"syscall"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
// TestInterfaceEnvVarName is the environment variable name the presence and
// value of which define whether to run the SO_BINDTODEVICE tests and on which
// network interface.  When the variable is not set, [InterfaceForTests]
// returns nil values, and the tests in this file that depend on it are
// skipped.
const TestInterfaceEnvVarName = "ADGUARD_DNS_TEST_NET_INTERFACE"
// InterfaceForTests looks up the network interface named by the environment
// variable [TestInterfaceEnvVarName].  If the variable is not set, both
// results are nil.  Otherwise, iface is the interface with that name and
// ifaceNet is the first of its addresses, which must be a [*net.IPNet].  Any
// lookup failure fails t immediately.
//
// It also logs the routing commands that are expected to have been run before
// the test, using the masked form of ifaceNet.
func InterfaceForTests(t testing.TB) (iface *net.Interface, ifaceNet *net.IPNet) {
	t.Helper()

	name, isSet := os.LookupEnv(TestInterfaceEnvVarName)
	if !isSet {
		return nil, nil
	}

	var err error
	iface, err = net.InterfaceByName(name)
	require.NoError(t, err)

	addrs, err := iface.Addrs()
	require.NoError(t, err)
	require.NotEmpty(t, addrs)

	ifaceNet = testutil.RequireTypeAssert[*net.IPNet](t, addrs[0])

	// Only the logged route uses the masked network; the returned ifaceNet
	// keeps the address of the interface itself.
	route := &net.IPNet{
		IP:   ifaceNet.IP.Mask(ifaceNet.Mask),
		Mask: ifaceNet.Mask,
	}

	t.Logf(
		"assuming following command has been called:\n"+
			"ip route add local %[1]s dev %[2]s\n"+
			"after the test:\n"+
			"ip route del local %[1]s dev %[2]s",
		route,
		name,
	)

	return iface, ifaceNet
}
// TestListenControl checks the SO_BINDTODEVICE handling.  The test assumes
// that the correct routing has already been set up on the machine.  To test
// the package an actual network interface is required.  To set that up:
//
//  1. Run ip a to locate the interface you want to use and its subnet.  For
//     example, "wlp3s0" and "192.168.10.0/23".
//
//  2. Add a route for that interface: "ip route add local 192.168.10.0/23 dev
//     wlp3s0".  You might need sudo for that.
//
//  3. Run the test itself: "env ADGUARD_DNS_TEST_NET_INTERFACE='wlp3s0' go
//     test -v ./internal/bindtodevice/".
//
//  4. Delete the route you added in step 2: "ip route del local
//     192.168.10.0/23 dev wlp3s0".  You might need sudo for that.
//
// An all-in-one example, with sudo:
//
//	sudo ip route add local 192.168.10.0/23 dev wlp3s0\
//		; env ADGUARD_DNS_TEST_NET_INTERFACE='wlp3s0'\
//		go test ./internal/bindtodevice/\
//		; sudo ip route del local 192.168.10.0/23 dev wlp3s0
func TestListenControl(t *testing.T) {
	iface, ifaceNet := InterfaceForTests(t)
	if iface == nil {
		t.Skipf("test %s skipped: please set env var %s", t.Name(), TestInterfaceEnvVarName)
	}

	name := iface.Name
	lc := newListenConfig(name, &ControlConfig{})
	require.NotNil(t, lc)

	// The shared subtests are run in a fixed order, one per network.
	subtests := []struct {
		run  func(t *testing.T, lc netext.ListenConfig, ifaceName string, ifaceNet *net.IPNet)
		name string
	}{{
		run:  SubtestListenControlTCP,
		name: "tcp",
	}, {
		run:  SubtestListenControlUDP,
		name: "udp",
	}}

	for _, st := range subtests {
		t.Run(st.name, func(t *testing.T) {
			st.run(t, lc, name, ifaceNet)
		})
	}
}
// SubtestListenControlTCP is a shared subtest that uses lc to create a TCP
// listener on an ephemeral port and then performs two-way communication
// through it, first using the address of ifaceNet and then using a
// neighboring address from the same network.  ifaceName is only used for
// logging.
func SubtestListenControlTCP(
	t *testing.T,
	lc netext.ListenConfig,
	ifaceName string,
	ifaceNet *net.IPNet,
) {
	ctx := testutil.ContextWithTimeout(t, testTimeout)

	l, err := lc.Listen(ctx, "tcp", "0.0.0.0:0")
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, l.Close)

	// Make sure we can work with [agdnet.PrefixNetAddr] as well: drop
	// everything after the first slash, if there is one.
	hostPort, _, _ := strings.Cut(l.Addr().String(), "/")
	lsnrAddrPort, err := netip.ParseAddrPort(hostPort)
	require.NoError(t, err)

	mainAddr := &net.TCPAddr{
		IP:   ifaceNet.IP,
		Port: int(lsnrAddrPort.Port()),
	}
	normalize(mainAddr)

	t.Run("main_interface_addr", func(t *testing.T) {
		t.Logf("using addr %s for iface %s", mainAddr, ifaceName)

		testListenControlTCPQuery(t, l, mainAddr)
	})

	t.Run("other_interface_addr", func(t *testing.T) {
		neighborAddr := &net.TCPAddr{
			IP:   closestIP(t, ifaceNet, mainAddr.IP),
			Port: mainAddr.Port,
		}
		normalize(neighborAddr)

		t.Logf("using addr %s for iface %s", neighborAddr, ifaceName)

		testListenControlTCPQuery(t, l, neighborAddr)
	})
}
// testListenControlTCPQuery starts a client goroutine that connects to
// reqAddr, accepts the resulting connection on lsnr, checks that the local
// address of the accepted connection equals reqAddr, and exchanges a fixed
// request and response over the connection.
func testListenControlTCPQuery(t *testing.T, lsnr net.Listener, reqAddr *net.TCPAddr) {
	var (
		wantReq = []byte("hello")
		resp    = []byte("world")
	)

	// The client gets its own copies of the payloads.
	go requestTCP(reqAddr, slices.Clone(wantReq), slices.Clone(resp))

	conn, err := lsnr.Accept()
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, conn.Close)

	localAddr := testutil.RequireTypeAssert[*net.TCPAddr](t, conn.LocalAddr())
	normalize(localAddr)
	assert.Equal(t, reqAddr, localAddr)

	require.NoError(t, conn.SetReadDeadline(time.Now().Add(testTimeout)))

	buf := make([]byte, len(wantReq))
	n, err := conn.Read(buf)
	require.NoError(t, err)
	assert.Equal(t, len(wantReq), n)
	assert.Equal(t, wantReq, buf)

	require.NoError(t, conn.SetWriteDeadline(time.Now().Add(testTimeout)))

	n, err = conn.Write(resp)
	require.NoError(t, err)
	assert.Equal(t, len(resp), n)
}
// SubtestListenControlUDP is a shared subtest that uses lc to create a UDP
// packet connection on an ephemeral port and then performs two-way
// communication through it, first using the address of ifaceNet and then using
// a neighboring address from the same network.  ifaceName is only used for
// logging.
func SubtestListenControlUDP(
	t *testing.T,
	lc netext.ListenConfig,
	ifaceName string,
	ifaceNet *net.IPNet,
) {
	ctx := testutil.ContextWithTimeout(t, testTimeout)

	pc, err := lc.ListenPacket(ctx, "udp", "0.0.0.0:0")
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, pc.Close)

	// Make sure we can work with [agdnet.PrefixNetAddr] as well: drop
	// everything after the first slash, if there is one.
	hostPort, _, _ := strings.Cut(pc.LocalAddr().String(), "/")
	localAddrPort, err := netip.ParseAddrPort(hostPort)
	require.NoError(t, err)

	mainAddr := &net.UDPAddr{
		IP:   ifaceNet.IP,
		Port: int(localAddrPort.Port()),
	}
	normalize(mainAddr)

	t.Run("main_interface_addr", func(t *testing.T) {
		t.Logf("using addr %s for iface %s", mainAddr, ifaceName)

		testListenControlUDPQuery(t, pc, mainAddr)
	})

	t.Run("other_interface_addr", func(t *testing.T) {
		neighborAddr := &net.UDPAddr{
			IP:   closestIP(t, ifaceNet, mainAddr.IP),
			Port: mainAddr.Port,
		}
		normalize(neighborAddr)

		t.Logf("using addr %s for iface %s", neighborAddr, ifaceName)

		testListenControlUDPQuery(t, pc, neighborAddr)
	})
}
// testListenControlUDPQuery starts a client goroutine that sends a fixed
// request to reqAddr, reads that request from packetConn as a packet session,
// checks that the session's local address equals reqAddr and that its body
// equals the request, and then writes a fixed response back through the same
// session.
//
// packetConn must be either a [*net.UDPConn] or a [netext.SessionPacketConn];
// any other type fails the test.
func testListenControlUDPQuery(t *testing.T, packetConn net.PacketConn, reqAddr *net.UDPAddr) {
	req, resp := []byte("hello"), []byte("world")
	reqLen, respLen := len(req), len(resp)

	// The client gets its own copies of the payloads.
	go requestUDP(reqAddr, slices.Clone(req), slices.Clone(resp))

	err := packetConn.SetReadDeadline(time.Now().Add(testTimeout))
	require.NoError(t, err)

	b := make([]byte, reqLen)
	oob := make([]byte, netext.IPDstOOBSize)

	var sess *packetSession
	switch c := packetConn.(type) {
	case *net.UDPConn:
		sess, err = readPacketSession(c, b, oob)
		require.NoError(t, err)
	case netext.SessionPacketConn:
		var s netext.PacketSession
		// Read into the scratch buffer and not into req: req is the expected
		// value, so using it as the destination would turn the body
		// comparison below into a comparison of a buffer with itself.
		_, s, err = c.ReadFromSession(b)
		require.NoError(t, err)

		sess = testutil.RequireTypeAssert[*packetSession](t, s)
	default:
		t.Fatalf("bad packet conn type %T(%[1]v)", c)
	}

	assert.Equal(t, reqAddr, sess.laddr)
	assert.Equal(t, req, sess.readBody)

	err = packetConn.SetWriteDeadline(time.Now().Add(testTimeout))
	require.NoError(t, err)

	// No default case is needed here, since unsupported types have already
	// failed the test in the switch above.
	var n int
	switch c := packetConn.(type) {
	case *net.UDPConn:
		n, _, err = c.WriteMsgUDP(resp, sess.respOOB, sess.raddr)
		require.NoError(t, err)
	case netext.SessionPacketConn:
		n, err = c.WriteToSession(resp, sess)
		require.NoError(t, err)
	}

	assert.Equal(t, respLen, n)
}
// requestTCP is a test helper that connects to raddr over TCP, writes req,
// and then reads a response of the same length as wantResp and compares it
// with wantResp.  It is intended to be used as a goroutine, so it reports
// failures through [testutil.PanicT] rather than through a test instance.
func requestTCP(raddr *net.TCPAddr, req, wantResp []byte) {
	pt := testutil.PanicT{}

	conn, err := net.DialTCP("tcp", nil, raddr)
	require.NoError(pt, err)

	defer func() { require.NoError(pt, conn.Close()) }()

	require.NoError(pt, conn.SetWriteDeadline(time.Now().Add(testTimeout)))

	written, err := conn.Write(req)
	require.NoError(pt, err)
	assert.Equal(pt, len(req), written)

	got := make([]byte, len(wantResp))

	require.NoError(pt, conn.SetReadDeadline(time.Now().Add(testTimeout)))

	nRead, err := conn.Read(got)
	require.NoError(pt, err)
	assert.Equal(pt, len(wantResp), nRead)
	assert.Equal(pt, wantResp, got)
}
// requestUDP is a test helper that opens a UDP socket connected to raddr,
// writes req, and then reads a response of the same length as wantResp and
// compares it with wantResp.  It is intended to be used as a goroutine, so it
// reports failures through [testutil.PanicT] rather than through a test
// instance.
func requestUDP(raddr *net.UDPAddr, req, wantResp []byte) {
	pt := testutil.PanicT{}

	conn, err := net.DialUDP("udp", nil, raddr)
	require.NoError(pt, err)

	defer func() { require.NoError(pt, conn.Close()) }()

	require.NoError(pt, conn.SetWriteDeadline(time.Now().Add(testTimeout)))

	written, err := conn.Write(req)
	require.NoError(pt, err)
	assert.Equal(pt, len(req), written)

	got := make([]byte, len(wantResp))

	require.NoError(pt, conn.SetReadDeadline(time.Now().Add(testTimeout)))

	nRead, err := conn.Read(got)
	require.NoError(pt, err)
	assert.Equal(pt, len(wantResp), nRead)
	assert.Equal(pt, wantResp, got)
}
// normalize replaces the IP address of addr with its 4-byte form if the
// address is an IPv4 one, and leaves it untouched otherwise.  addr must be
// either a [*net.TCPAddr] or a [*net.UDPAddr]; normalize panics on any other
// type.
func normalize(addr net.Addr) {
	// Both supported types keep the address in an IP field, so pick that
	// field once and share the conversion.
	var ipField *net.IP
	switch a := addr.(type) {
	case *net.TCPAddr:
		ipField = &a.IP
	case *net.UDPAddr:
		ipField = &a.IP
	default:
		panic(fmt.Errorf("bad type %T", addr))
	}

	if v4 := (*ipField).To4(); v4 != nil {
		*ipField = v4
	}
}
// closestIP is a test helper that returns an IP address adjacent to ip that
// belongs to the network n.  The address following ip is preferred; the one
// preceding ip is used if the following one is not within n.  If neither of
// them is within n, or if ip or n cannot be converted, the test fails.
func closestIP(t testing.TB, n *net.IPNet, ip net.IP) (closest net.IP) {
	t.Helper()

	addr, err := netutil.IPToAddrNoMapped(ip)
	require.NoError(t, err)

	prefix, err := netutil.IPNetToPrefixNoMapped(n)
	require.NoError(t, err)

	next, prev := addr.Next(), addr.Prev()
	switch {
	case prefix.Contains(next):
		return next.AsSlice()
	case prefix.Contains(prev):
		return prev.AsSlice()
	default:
		t.Fatalf("neither %s nor %s are in %s", next, prev, prefix)

		return nil
	}
}
// TestListenControlWithSO checks that the send and receive buffer sizes from
// [ControlConfig] are applied to both the UDP and the TCP sockets created by
// the listen configuration.  Like [TestListenControl], it is skipped unless
// the environment variable [TestInterfaceEnvVarName] is set.
func TestListenControlWithSO(t *testing.T) {
	const (
		sndBufSize = 10000
		rcvBufSize = 20000
	)

	iface, _ := InterfaceForTests(t)
	if iface == nil {
		t.Skipf("test %s skipped: please set env var %s", t.Name(), TestInterfaceEnvVarName)
	}

	ifaceName := iface.Name
	lc := newListenConfig(
		ifaceName,
		&ControlConfig{
			RcvBufSize: rcvBufSize,
			SndBufSize: sndBufSize,
		},
	)
	require.NotNil(t, lc)

	type syscallConner interface {
		SyscallConn() (c syscall.RawConn, err error)
	}

	// requireSockOpt checks that the integer SOL_SOCKET option opt of the
	// socket behind sc has the value want.
	requireSockOpt := func(t *testing.T, sc syscall.RawConn, opt, want int) {
		t.Helper()

		err := sc.Control(func(fd uintptr) {
			val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, opt)
			require.NoError(t, opErr)

			assert.Equal(t, want, val)
		})
		require.NoError(t, err)
	}

	// The expected values are twice the configured ones, because the Linux
	// kernel doubles the value passed to SO_SNDBUF and SO_RCVBUF and returns
	// the doubled value from getsockopt; see socket(7).

	t.Run("udp", func(t *testing.T) {
		c, err := lc.ListenPacket(context.Background(), "udp", "0.0.0.0:0")
		require.NoError(t, err)
		require.NotNil(t, c)

		// Close the socket once the subtest is done, so that it isn't leaked.
		testutil.CleanupAndRequireSuccess(t, c.Close)

		require.Implements(t, (*syscallConner)(nil), c)

		sc, err := c.(syscallConner).SyscallConn()
		require.NoError(t, err)

		requireSockOpt(t, sc, unix.SO_SNDBUF, sndBufSize*2)
		requireSockOpt(t, sc, unix.SO_RCVBUF, rcvBufSize*2)
	})

	t.Run("tcp", func(t *testing.T) {
		c, err := lc.Listen(context.Background(), "tcp", "0.0.0.0:0")
		require.NoError(t, err)
		require.NotNil(t, c)

		// Close the listener once the subtest is done, so that it isn't
		// leaked.
		testutil.CleanupAndRequireSuccess(t, c.Close)

		require.Implements(t, (*syscallConner)(nil), c)

		sc, err := c.(syscallConner).SyscallConn()
		require.NoError(t, err)

		requireSockOpt(t, sc, unix.SO_SNDBUF, sndBufSize*2)
		requireSockOpt(t, sc, unix.SO_RCVBUF, rcvBufSize*2)
	})
}
// testMsgUDPReader is a [msgUDPReader] for tests.  Its behavior is fully
// defined by the callback stored in it.
type testMsgUDPReader struct {
	// onReadMsgUDP is called by ReadMsgUDP with the same arguments, and its
	// results are returned as is.  It must not be nil, since ReadMsgUDP calls
	// it unconditionally.
	onReadMsgUDP func(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error)
}

// type check
var _ msgUDPReader = (*testMsgUDPReader)(nil)

// ReadMsgUDP implements the [msgUDPReader] interface for *testMsgUDPReader.
// It delegates to r.onReadMsgUDP.
func (r *testMsgUDPReader) ReadMsgUDP(
	b []byte,
	oob []byte,
) (n, oobn, flags int, addr *net.UDPAddr, err error) {
	return r.onReadMsgUDP(b, oob)
}
// Sinks for benchmarks.  [BenchmarkReadPacketSession] stores the results of
// every readPacketSession call in these package-level variables, which keeps
// the results observable outside of the benchmark loop.
var (
	// sessSink receives the session result.
	sessSink *packetSession

	// errSink receives the error result.
	errSink error
)
// BenchmarkReadPacketSession measures readPacketSession using a stub
// [msgUDPReader] that copies a fixed body and a fixed control message into the
// provided buffers on every call and reports testRAddr as the remote address.
// After the loop, the last result is checked against the expected values.
func BenchmarkReadPacketSession(b *testing.B) {
	bodyData := []byte("message body data")

	// TODO(a.garipov): Find a better way to pack these control messages than
	// just [binary.Write].
	oobBuf := &bytes.Buffer{}
	ctrlMsgHdr := unix.Cmsghdr{
		Len:   24,
		Level: unix.SOL_IP,
		Type:  unix.IP_ORIGDSTADDR,
	}
	err := binary.Write(oobBuf, binary.NativeEndian, ctrlMsgHdr)
	require.NoError(b, err)

	pktInfo := unix.Inet4Pktinfo{
		Spec_dst: *(*[4]byte)(testRAddr.IP),
		Addr:     *(*[4]byte)(testRAddr.IP),
	}
	err = binary.Write(oobBuf, binary.NativeEndian, pktInfo)
	require.NoError(b, err)

	oobData := oobBuf.Bytes()

	c := &testMsgUDPReader{
		onReadMsgUDP: func(body, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) {
			copy(body, bodyData)
			copy(oob, oobData)

			return len(bodyData), len(oobData), 0, testRAddr, nil
		},
	}

	body := make([]byte, dns.DefaultMsgSize)
	oob := make([]byte, netext.IPDstOOBSize)

	b.ReportAllocs()
	b.ResetTimer()
	for range b.N {
		sessSink, errSink = readPacketSession(c, body, oob)
	}

	require.NoError(b, errSink)
	require.NotNil(b, sessSink)

	// The expected value goes first, so that a failure message labels the
	// expected and the actual values correctly.
	assert.Equal(b, testRAddr, sessSink.raddr)
	assert.Equal(b, bodyData, sessSink.readBody)

	// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
	//	goos: linux
	//	goarch: amd64
	//	pkg: github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice
	//	cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
	//	BenchmarkReadPacketSession
	//	BenchmarkReadPacketSession-16	3311841	458.1 ns/op	224 B/op	5 allocs/op
}