mirror of
https://github.com/AdguardTeam/AdGuardDNS.git
synced 2025-02-20 11:23:36 +08:00
342 lines
8.2 KiB
Go
342 lines
8.2 KiB
Go
//go:build linux
|
|
|
|
package bindtodevice
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
|
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
)
|
|
|
|
// chanPacketConn is a [netext.SessionPacketConn] that returns data sent to it
|
|
// through the channel.
|
|
//
|
|
// Connections of this type are returned by [ListenConfig.ListenPacket] and are
|
|
// used in module dnsserver to make the bind-to-device logic work in
|
|
// DNS-over-UDP.
|
|
type chanPacketConn struct {
|
|
// mu protects sessions (against closure) and isClosed.
|
|
mu *sync.Mutex
|
|
sessions chan *packetSession
|
|
|
|
writeRequests chan *packetConnWriteReq
|
|
|
|
sessionsGauge prometheus.Gauge
|
|
writeRequestsGauge prometheus.Gauge
|
|
|
|
// deadlineMu protects readDeadline and writeDeadline.
|
|
deadlineMu *sync.RWMutex
|
|
readDeadline time.Time
|
|
writeDeadline time.Time
|
|
|
|
laddr net.Addr
|
|
subnet netip.Prefix
|
|
isClosed bool
|
|
}
|
|
|
|
// newChanPacketConn returns a new properly initialized *chanPacketConn.
|
|
func newChanPacketConn(
|
|
sessions chan *packetSession,
|
|
subnet netip.Prefix,
|
|
writeRequests chan *packetConnWriteReq,
|
|
writeRequestsGauge prometheus.Gauge,
|
|
laddr net.Addr,
|
|
) (c *chanPacketConn) {
|
|
return &chanPacketConn{
|
|
mu: &sync.Mutex{},
|
|
sessions: sessions,
|
|
writeRequests: writeRequests,
|
|
|
|
sessionsGauge: metrics.BindToDeviceUDPSessionsChanSize.WithLabelValues(
|
|
subnet.String(),
|
|
),
|
|
writeRequestsGauge: writeRequestsGauge,
|
|
|
|
deadlineMu: &sync.RWMutex{},
|
|
|
|
laddr: laddr,
|
|
subnet: subnet,
|
|
}
|
|
}
|
|
|
|
// packetConnWriteReq is a request to write a piece of data to the original
|
|
// packet connection. respCh, body, and either raddr or session must be set.
|
|
type packetConnWriteReq struct {
|
|
respCh chan *packetConnWriteResp
|
|
session *packetSession
|
|
raddr net.Addr
|
|
deadline time.Time
|
|
body []byte
|
|
}
|
|
|
|
// packetConnWriteResp carries the outcome of a [packetConnWriteReq].
type packetConnWriteResp struct {
	// err is the error of the write, if any.
	err error

	// written is the byte count that is reported back to the writer.
	written int
}
|
|
|
|
// type check
|
|
// Compile-time check that *chanPacketConn satisfies
// [netext.SessionPacketConn].
var _ netext.SessionPacketConn = (*chanPacketConn)(nil)
|
|
|
|
// Close implements the [netext.SessionPacketConn] interface for
|
|
// *chanPacketConn.
|
|
func (c *chanPacketConn) Close() (err error) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if c.isClosed {
|
|
return wrapConnError(tnChanPConn, "Close", c.laddr, net.ErrClosed)
|
|
}
|
|
|
|
close(c.sessions)
|
|
c.isClosed = true
|
|
|
|
return nil
|
|
}
|
|
|
|
// LocalAddr implements the [netext.SessionPacketConn] interface for
|
|
// *chanPacketConn.
|
|
func (c *chanPacketConn) LocalAddr() (addr net.Addr) { return c.laddr }
|
|
|
|
// ReadFrom implements the [netext.SessionPacketConn] interface for
|
|
// *chanPacketConn.
|
|
func (c *chanPacketConn) ReadFrom(b []byte) (n int, raddr net.Addr, err error) {
|
|
n, sess, err := c.readFromSession(b, "ReadFrom")
|
|
if sess == nil {
|
|
return 0, nil, err
|
|
}
|
|
|
|
return n, sess.RemoteAddr(), err
|
|
}
|
|
|
|
// ReadFromSession implements the [netext.SessionPacketConn] interface for
|
|
// *chanPacketConn.
|
|
func (c *chanPacketConn) ReadFromSession(b []byte) (n int, s netext.PacketSession, err error) {
|
|
return c.readFromSession(b, "ReadFromSession")
|
|
}
|
|
|
|
// readFromSession contains the common code of [ReadFrom] and [ReadFromSession].
|
|
func (c *chanPacketConn) readFromSession(
|
|
b []byte,
|
|
fnName string,
|
|
) (n int, s netext.PacketSession, err error) {
|
|
var sess *packetSession
|
|
defer func() {
|
|
if sess != nil {
|
|
n = copy(b, sess.readBody)
|
|
}
|
|
}()
|
|
|
|
var deadline time.Time
|
|
func() {
|
|
c.deadlineMu.RLock()
|
|
defer c.deadlineMu.RUnlock()
|
|
|
|
deadline = c.readDeadline
|
|
}()
|
|
|
|
timerCh, stopTimer, err := timerFromDeadline(deadline)
|
|
if err != nil {
|
|
err = fmt.Errorf("setting deadline: %w", err)
|
|
|
|
return 0, nil, wrapConnError(tnChanPConn, fnName, c.laddr, err)
|
|
}
|
|
defer stopTimer()
|
|
|
|
sess, err = receiveWithTimer(c.sessions, timerCh)
|
|
if err != nil {
|
|
err = fmt.Errorf("receiving: %w", err)
|
|
|
|
// Prevent netext.PacketSession((*packetSession)(nil)).
|
|
if sess != nil {
|
|
s = sess
|
|
}
|
|
|
|
return 0, s, wrapConnError(tnChanPConn, fnName, c.laddr, err)
|
|
}
|
|
|
|
return 0, sess, nil
|
|
}
|
|
|
|
// timerFromDeadline converts a deadline value into a timer channel.
//
// If deadline is zero, there is no deadline: timerCh is nil, so a receive from
// it in a select never succeeds.  If deadline has already passed, err is
// [os.ErrDeadlineExceeded] and timerCh is nil.  Otherwise, timerCh is the
// channel of a timer that fires once the deadline is reached.
//
// stopTimer is never nil, is safe to call more than once, and must be deferred
// by the caller.
func timerFromDeadline(deadline time.Time) (timerCh <-chan time.Time, stopTimer func(), err error) {
	noop := func() {}
	if deadline.IsZero() {
		return nil, noop, nil
	}

	d := time.Until(deadline)
	if d <= 0 {
		return nil, noop, os.ErrDeadlineExceeded
	}

	// Arm the timer with the duration that has just been validated instead of
	// reading the clock a second time.
	timer := time.NewTimer(d)
	timerCh = timer.C
	stopTimer = func() {
		if timer.Stop() {
			return
		}

		// We don't know if the timer's value has been consumed yet or not, so
		// use a select with default to make sure that this doesn't block.
		select {
		case <-timerCh:
		default:
		}
	}

	return timerCh, stopTimer, nil
}
|
|
|
|
// SetDeadline implements the [netext.SessionPacketConn] interface for
|
|
// *chanPacketConn.
|
|
func (c *chanPacketConn) SetDeadline(t time.Time) (err error) {
|
|
c.deadlineMu.Lock()
|
|
defer c.deadlineMu.Unlock()
|
|
|
|
c.readDeadline = t
|
|
c.writeDeadline = t
|
|
|
|
return nil
|
|
}
|
|
|
|
// SetReadDeadline implements the [netext.SessionPacketConn] interface for
|
|
// *chanPacketConn.
|
|
func (c *chanPacketConn) SetReadDeadline(t time.Time) (err error) {
|
|
c.deadlineMu.Lock()
|
|
defer c.deadlineMu.Unlock()
|
|
|
|
c.readDeadline = t
|
|
|
|
return nil
|
|
}
|
|
|
|
// SetWriteDeadline implements the [netext.SessionPacketConn] interface for
|
|
// *chanPacketConn.
|
|
func (c *chanPacketConn) SetWriteDeadline(t time.Time) (err error) {
|
|
c.deadlineMu.Lock()
|
|
defer c.deadlineMu.Unlock()
|
|
|
|
c.writeDeadline = t
|
|
|
|
return nil
|
|
}
|
|
|
|
// WriteTo implements the [netext.SessionPacketConn] interface for
|
|
// *chanPacketConn.
|
|
func (c *chanPacketConn) WriteTo(b []byte, raddr net.Addr) (n int, err error) {
|
|
return c.writeToSession(b, nil, raddr, "WriteTo")
|
|
}
|
|
|
|
// WriteToSession implements the [netext.SessionPacketConn] interface for
|
|
// *chanPacketConn.
|
|
func (c *chanPacketConn) WriteToSession(
|
|
b []byte,
|
|
s netext.PacketSession,
|
|
) (n int, err error) {
|
|
return c.writeToSession(b, s.(*packetSession), nil, "WriteToSession")
|
|
}
|
|
|
|
// writeToSession contains the common code of [WriteTo] and [WriteToSession].
|
|
func (c *chanPacketConn) writeToSession(
|
|
b []byte,
|
|
s *packetSession,
|
|
raddr net.Addr,
|
|
fnName string,
|
|
) (n int, err error) {
|
|
var deadline time.Time
|
|
func() {
|
|
c.deadlineMu.RLock()
|
|
defer c.deadlineMu.RUnlock()
|
|
|
|
deadline = c.writeDeadline
|
|
}()
|
|
|
|
timerCh, stopTimer, err := timerFromDeadline(deadline)
|
|
if err != nil {
|
|
err = fmt.Errorf("setting deadline: %w", err)
|
|
|
|
return 0, wrapConnError(tnChanPConn, fnName, c.laddr, err)
|
|
}
|
|
defer stopTimer()
|
|
|
|
resp := make(chan *packetConnWriteResp, 1)
|
|
req := &packetConnWriteReq{
|
|
respCh: resp,
|
|
session: s,
|
|
raddr: raddr,
|
|
deadline: deadline,
|
|
body: b,
|
|
}
|
|
err = sendWithTimer(c.writeRequests, req, timerCh)
|
|
if err != nil {
|
|
err = fmt.Errorf("sending write request: %w", err)
|
|
|
|
return 0, wrapConnError(tnChanPConn, fnName, c.laddr, err)
|
|
}
|
|
|
|
c.writeRequestsGauge.Set(float64(len(c.writeRequests)))
|
|
|
|
r, err := receiveWithTimer(resp, timerCh)
|
|
if err != nil {
|
|
err = fmt.Errorf("receiving write response: %w", err)
|
|
|
|
return 0, wrapConnError(tnChanPConn, fnName, c.laddr, err)
|
|
}
|
|
|
|
return r.written, r.err
|
|
}
|
|
|
|
// receiveWithTimer receives a value from ch, using timerCh to bound the wait.
// If ch is closed, v is the zero value and err is [net.ErrClosed].  If the
// receive from timerCh succeeds first, v is the zero value and err is
// [os.ErrDeadlineExceeded].  A nil timerCh means that the wait is unbounded.
func receiveWithTimer[T any](ch <-chan T, timerCh <-chan time.Time) (v T, err error) {
	select {
	case <-timerCh:
		return v, os.ErrDeadlineExceeded
	case val, open := <-ch:
		if !open {
			return val, net.ErrClosed
		}

		return val, nil
	}
}
|
|
|
|
// sendWithTimer sends v to ch, using timerCh to bound the wait.  If the receive
// from timerCh succeeds first, err is [os.ErrDeadlineExceeded]; otherwise err
// is nil.  A nil timerCh means that the wait is unbounded.
func sendWithTimer[T any](ch chan<- T, v T, timerCh <-chan time.Time) (err error) {
	select {
	case <-timerCh:
		err = os.ErrDeadlineExceeded
	case ch <- v:
		// The value has been sent in time, so err remains nil.
	}

	return err
}
|
|
|
|
// send is a helper method to send a session to the packet connection's channel.
|
|
// ok is false if the listener is closed.
|
|
func (c *chanPacketConn) send(sess *packetSession) (ok bool) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if c.isClosed {
|
|
return false
|
|
}
|
|
|
|
c.sessions <- sess
|
|
|
|
c.sessionsGauge.Set(float64(len(c.sessions)))
|
|
|
|
return true
|
|
}
|