Sync v2.3

This commit is contained in:
Andrey Meshkov 2023-09-06 08:22:07 +03:00
parent 1cc340ddb1
commit cfb4caf935
151 changed files with 10639 additions and 2728 deletions

View File

@ -11,7 +11,49 @@ The format is **not** based on [Keep a Changelog][kec], since the project
## AGDNS-1537 / Build 580
## AGDNS-1607 / Build 617
* New configuration `access` has been added, it has an a list of AdBlock rules
to block requests, and a lists of client subnets to block access from.
Example configuration:
```yaml
access:
blocked_question_domains:
- 'test.org'
- '||example.org^$dnstype=AAAA'
blocked_client_subnets:
- '1.1.1.1'
- '2.2.2.0/8'
```
## AGDNS-1619 / Build 611
* Added a new metric `bill_stat_upload_duration` that counts the duration of
billing statistics upload.
* The environment variable `BILLSTAT_URL`, which describes the endpoint for
backend billing statistics uploader API, now supports GRPC endpoints.
## AGDNS-1600 / Build 582
* The environment variable `PROFILES_CACHE_PATH` no longer supports JSON
files. Use protobuf with `.pb` extension instead. The default value has
been changed to `./profilecache.pb`.
## AGDNS-1539 / Build 581
* The environment variable `PROFILES_URL`, which describes the endpoint for
profiles sync API, now supports GRPC endpoints.
## AGDNS-1579 / Build 580
* The optional property `bind_interfaces` of `server_groups.*.servers`
objects has been changed, property `subnet` is now an array and has been

View File

@ -51,17 +51,11 @@ test: go-test
go-bench: ; $(ENV) "$(SHELL)" ./scripts/make/go-bench.sh
go-build: ; $(ENV) "$(SHELL)" ./scripts/make/go-build.sh
go-deps: ; $(ENV) "$(SHELL)" ./scripts/make/go-deps.sh
go-gen: ; $(ENV) "$(SHELL)" ./scripts/make/go-gen.sh
go-lint: ; $(ENV) "$(SHELL)" ./scripts/make/go-lint.sh
go-test: ; $(ENV) RACE='1' "$(SHELL)" ./scripts/make/go-test.sh
go-tools: ; $(ENV) "$(SHELL)" ./scripts/make/go-tools.sh
go-gen:
cd ./internal/agd/ && "$(GO.MACRO)" run ./country_generate.go
cd ./internal/geoip/ && "$(GO.MACRO)" run ./asntops_generate.go
cd ./internal/profiledb/internal/filecachepb/ &&\
protoc --go_opt=paths=source_relative --go_out=. ./filecache.proto
go-check: go-tools go-lint go-test
# A quick check to make sure that all operating systems relevant to the

View File

@ -52,6 +52,15 @@ ratelimit:
stop: 1000
resume: 800
# Access settings.
access:
# Domains to block.
blocked_question_domains:
- 'test.org'
# Client subnets to block.
blocked_client_subnets:
- '1.2.3.0/8'
# DNS cache configuration.
cache:
# The type of cache to use. Can be 'simple' (a simple LRU cache) or 'ecs'

View File

@ -32,6 +32,7 @@ configuration file with comments.
* [Servers](#server_groups-*-servers-*)
* [Connectivity check](#connectivity-check)
* [Network settings](#network)
* [Access settings](#access)
* [Additional metrics information](#additional_metrics_info)
[dist]: ../config.dist.yml
@ -1093,6 +1094,22 @@ The `network` object has the following properties:
## <a href="#access" id="access" name="access">Access settings</a>
The `access` object has the following properties:
* <a href="#access-blocked_question_domains" id="access-blocked_question_domains" name="access-blocked_question_domains">`blocked_question_domains`</a>:
The list of domains or AdBlock rules to block requests.
**Examples:** `test.org`, `||example.org^$dnstype=AAAA`.
* <a href="#access-blocked_client_subnets" id="access-blocked_client_subnets" name="access-blocked_client_subnets">`blocked_client_subnets`</a>:
The list of IP addresses or CIDR-es to block.
**Example:** `127.0.0.1`.
## <a href="#additional_metrics_info" id="additional_metrics_info" name="additional_metrics_info">Additional metrics information</a>
The `additional_metrics_info` object is a map of strings with extra information

View File

@ -241,27 +241,33 @@ Examples below are for the configuration with the following changes:
You may also need to remove `probe_ipv6` if your network does not support IPv6.
If you're using an OS different from Linux, you also need to make these changes:
* Remove the `interface_listeners` section.
* Remove `bind_interfaces` from the `default_dns` server configuration and
replace it with `bind_addresses`.
```sh
env \
ADULT_BLOCKING_URL='https://raw.githubusercontent.com/ameshkov/PersonalFilters/master/adult_test.txt' \
BILLSTAT_URL='PUT BILLSTAT API BACKEND URL HERE' \
BLOCKED_SERVICE_INDEX_URL='https://adguardteam.github.io/HostlistsRegistry/assets/services.json'\
CONSUL_ALLOWLIST_URL='PUT CONSUL ALLOWLIST URL HERE' \
ADULT_BLOCKING_URL='https://raw.githubusercontent.com/ameshkov/stuff/master/DNS/adult_blocking.txt' \
BILLSTAT_URL='https://httpbin.agrd.workers.dev/post' \
BLOCKED_SERVICE_INDEX_URL='https://adguardteam.github.io/HostlistsRegistry/assets/services.json' \
CONSUL_ALLOWLIST_URL='https://raw.githubusercontent.com/ameshkov/stuff/master/DNS/consul_allowlist.json' \
CONFIG_PATH='./config.yaml' \
FILTER_INDEX_URL='https://adguardteam.github.io/HostlistsRegistry/assets/filters.json' \
FILTER_CACHE_PATH='./test/cache' \
NEW_REG_DOMAINS_URL='PUT NEWLY REGISTERED DOMAINS FILTER URL HERE' \
PROFILES_CACHE_PATH='./test/profilecache.json' \
PROFILES_URL='PUT PROFILES API BACKEND URL HERE' \
SAFE_BROWSING_URL='https://raw.githubusercontent.com/ameshkov/PersonalFilters/master/safebrowsing_test.txt' \
NEW_REG_DOMAINS_URL='https://raw.githubusercontent.com/ameshkov/stuff/master/DNS/nrd.txt' \
PROFILES_CACHE_PATH='./test/profilecache.pb' \
PROFILES_URL='https://raw.githubusercontent.com/ameshkov/stuff/master/DNS/profiles' \
SAFE_BROWSING_URL='https://raw.githubusercontent.com/ameshkov/stuff/master/DNS/safe_browsing.txt' \
GENERAL_SAFE_SEARCH_URL='https://adguardteam.github.io/HostlistsRegistry/assets/engines_safe_search.txt' \
GEOIP_ASN_PATH='./test/GeoLite2-ASN-Test.mmdb' \
GEOIP_COUNTRY_PATH='./test/GeoIP2-City-Test.mmdb' \
QUERYLOG_PATH='./test/cache/querylog.jsonl' \
LINKED_IP_TARGET_URL='PUT LINKED IP TARGET URL HERE' \
LINKED_IP_TARGET_URL='https://httpbin.agrd.workers.dev/anything' \
LISTEN_ADDR='127.0.0.1' \
LISTEN_PORT='8081' \
RULESTAT_URL='https://testchrome.adtidy.org/api/1.0/rulestats.html' \
RULESTAT_URL='https://httpbin.agrd.workers.dev/post' \
SENTRY_DSN='https://1:1@localhost/1' \
VERBOSE='1' \
YOUTUBE_SAFE_SEARCH_URL='https://adguardteam.github.io/HostlistsRegistry/assets/youtube_safe_search.txt' \

View File

@ -49,9 +49,10 @@ The URL of source list of rules for adult blocking filter.
## <a href="#BILLSTAT_URL" id="BILLSTAT_URL" name="BILLSTAT_URL">`BILLSTAT_URL`</a>
The base backend URL for backend billing statistics uploader API. The backend
endpoints must reply with a 200 status code on success. See the [external HTTP
API requirements section][ext-billstat]
The base backend URL for backend billing statistics uploader API. Supports HTTP
and GRPC protocols. In case of HTTP the backend endpoint must reply with a 200
status code on success. See the [external HTTP API requirements
section][ext-billstat].
**Default:** No default value, the variable is **required.**
@ -227,13 +228,10 @@ The path to the profile cache file:
< /path/to/profilecache.pb
```
* A file with the extension `.json` means that the profiles are cached in the
JSON format. This format is **deprecated** and is not recommended.
The profile cache is read on start and is later updated on every
[full refresh][conf-backend-full_refresh_interval].
**Default:** `./profilecache.json`.
**Default:** `./profilecache.pb`.
[conf-backend-full_refresh_interval]: configuration.md#backend-full_refresh_interval
@ -241,9 +239,10 @@ The profile cache is read on start and is later updated on every
## <a href="#PROFILES_URL" id="PROFILES_URL" name="PROFILES_URL">`PROFILES_URL`</a>
The base backend URL for profiles API. The backend endpoints must reply with a
200 status code on success. See the [external HTTP API requirements
section][ext-profiles].
The base backend URL for profiles API. Supports HTTP (`http://` and `https://`)
and GRPC (`grpc://` and `grpcs://`) URLs. In case of HTTP the backend endpoint
must reply with a 200 status code on success. See the [external API
requirements section][ext-profiles].
**Default:** No default value, the variable is **required.**

View File

@ -29,7 +29,9 @@ document should set the `Server` header in their replies.
## <a href="#backend-billstat" id="backend-billstat" name="backend-billstat">Backend Billing Statistics</a>
This is the service to which the [`BILLSTAT_URL`][env-billstat_url] environment
variable points. This service must provide one endpoint:
variable points. Supports `http(s):` and `grpc(s)` URLs. In case of GRPC
protocol, the service must correspond to `./internal/backendpb/backend.proto`.
In case of HTTP protocol this service must provide one endpoint:
`POST /dns_api/v1/devices_activity`, it must respond with a `200 OK` response
code and accept a JSON document in the following format:
@ -55,7 +57,9 @@ code and accept a JSON document in the following format:
## <a href="#backend-profiles" id="backend-profiles" name="backend-profiles">Backend Profiles Service</a>
This is the service to which the [`PROFILES_URL`][env-profiles_url] environment
variable points. This service must provide one endpoint:
variable points. Supports `http(s):` and `grpc(s)` URLs. In case of GRPC
protocol, the service must correspond to `./internal/backendpb/backend.proto`.
In case of HTTP protocol this service must provide one endpoint:
`GET /dns_api/v1/settings`, it must respond with a `200 OK` response code and
accept a JSON document in the following format:

29
go.mod
View File

@ -4,8 +4,8 @@ go 1.20
require (
github.com/AdguardTeam/AdGuardDNS/internal/dnsserver v0.100.0
github.com/AdguardTeam/golibs v0.13.6
github.com/AdguardTeam/urlfilter v0.16.2
github.com/AdguardTeam/golibs v0.15.0
github.com/AdguardTeam/urlfilter v0.17.0
github.com/ameshkov/dnscrypt/v2 v2.2.7
github.com/axiomhq/hyperloglog v0.0.0-20230201085229-3ddf4bad03dc
github.com/bluele/gcache v0.0.2
@ -19,12 +19,13 @@ require (
github.com/prometheus/client_golang v1.15.1
github.com/prometheus/client_model v0.4.0
github.com/prometheus/common v0.44.0
github.com/quic-go/quic-go v0.35.1
github.com/quic-go/quic-go v0.38.0
github.com/stretchr/testify v1.8.4
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1
golang.org/x/net v0.12.0
golang.org/x/sys v0.10.0
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
golang.org/x/net v0.14.0
golang.org/x/sys v0.11.0
golang.org/x/time v0.3.0
google.golang.org/grpc v1.56.2
google.golang.org/protobuf v1.30.0
gopkg.in/yaml.v2 v2.4.0
)
@ -40,19 +41,19 @@ require (
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/golang/mock v1.6.0 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect
github.com/google/pprof v0.0.0-20230821062121-407c9e7a662f // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/onsi/ginkgo/v2 v2.10.0 // indirect
github.com/onsi/ginkgo/v2 v2.11.0 // indirect
github.com/panjf2000/ants/v2 v2.7.5 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/procfs v0.10.1 // indirect
github.com/quic-go/qpack v0.4.0 // indirect
github.com/quic-go/qtls-go1-19 v0.3.2 // indirect
github.com/quic-go/qtls-go1-20 v0.2.2 // indirect
golang.org/x/crypto v0.11.0 // indirect
golang.org/x/mod v0.11.0 // indirect
golang.org/x/text v0.11.0 // indirect
golang.org/x/tools v0.10.0 // indirect
github.com/quic-go/qtls-go1-20 v0.3.3 // indirect
golang.org/x/crypto v0.12.0 // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/text v0.12.0 // indirect
golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 // indirect
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

104
go.sum
View File

@ -1,12 +1,7 @@
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
github.com/AdguardTeam/golibs v0.13.6 h1:z/0Q25pRLdaQxtoxvfSaooz5mdv8wj0R8KREj54q8yQ=
github.com/AdguardTeam/golibs v0.13.6/go.mod h1:hOtcb8dPfKcFjWTPA904hTA4dl1aWvzeebdJpE72IPk=
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
github.com/AdguardTeam/urlfilter v0.16.2 h1:k9m9dUYVJ3sTswYa2/ukVNjicfGcz0oqFDO13hPmfHE=
github.com/AdguardTeam/urlfilter v0.16.2/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI=
github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDOSA=
github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8=
github.com/AdguardTeam/golibs v0.15.0 h1:yOv/fdVkJIOWKr0NlUXAE9RA0DK9GKiBbiGzq47vY7o=
github.com/AdguardTeam/golibs v0.15.0/go.mod h1:66ZLs8P7nk/3IfKroQ1rqtieLk+5eXYXMBKXlVL7KeI=
github.com/AdguardTeam/urlfilter v0.17.0 h1:tUzhtR9wMx704GIP3cibsDQJrixlMHfwoQbYJfPdFow=
github.com/AdguardTeam/urlfilter v0.17.0/go.mod h1:bbuZjPUzm/Ip+nz5qPPbwIP+9rZyQbQad8Lt/0fCulU=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw=
@ -27,7 +22,6 @@ github.com/caarlos0/env/v7 v7.1.0 h1:9lzTF5amyQeWHZzuZeKlCb5FWSUxpG1js43mhbY8ozg
github.com/caarlos0/env/v7 v7.1.0/go.mod h1:LPPWniDUq4JaO6Q41vtlyikhMknqymCLBw0eX4dcH1E=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -38,8 +32,7 @@ github.com/getsentry/sentry-go v0.21.0 h1:c9l5F1nPF30JIppulk4veau90PK6Smu3abgVtV
github.com/getsentry/sentry-go v0.21.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY=
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-ole/go-ole v1.2.5 h1:t4MGB5xEDZvXI+0rMjjsfBsD7yAgp/s9ZDkL1JndXwY=
github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
@ -50,25 +43,20 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs=
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
github.com/google/pprof v0.0.0-20230821062121-407c9e7a662f h1:pDhu5sgp8yJlEF/g6osliIIpF9K4F5jvkULXa4daRDQ=
github.com/google/pprof v0.0.0-20230821062121-407c9e7a662f/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik=
github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg=
github.com/google/renameio/v2 v2.0.0/go.mod h1:BtmJXm5YlszgC+TD4HOEEUFgkJP3nLxehU6hfe7jRt4=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
github.com/miekg/dns v1.1.43/go.mod h1:+evo5L0630/F6ca/Z9+GAqzhjGyn8/c+TBaOyfEl0V4=
github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo=
github.com/miekg/dns v1.1.55/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs=
github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE=
github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU=
github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU=
github.com/onsi/ginkgo/v2 v2.11.0/go.mod h1:ZhrRA5XmEE3x3rhlzamx/JJvujdZoJ2uvgI7kR0iZvM=
github.com/onsi/gomega v1.27.8 h1:gegWiwZjBsf2DgiSbf5hpokZ98JVDMcWkUiigk6/KXc=
github.com/oschwald/maxminddb-golang v1.10.0 h1:Xp1u0ZhqkSuopaKmk1WwHtjF0H9Hd9181uj2MQ5Vndg=
github.com/oschwald/maxminddb-golang v1.10.0/go.mod h1:Y2ELenReaLAZ0b400URyGwvYxHV1dLIxBuyOsyYjHK0=
github.com/panjf2000/ants/v2 v2.7.5 h1:/vhh0Hza9G1vP1PdCj9hl6MUzCRbmtcTJL0OsnmytuU=
@ -77,9 +65,9 @@ github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible
github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/prometheus/client_golang v1.15.1 h1:8tXpTmJbyH5lydzFPoxSIJ0J46jdh3tylbvM1xCv0LI=
github.com/prometheus/client_golang v1.15.1/go.mod h1:e9yaBhRPU2pPNsZwE+JdQl0KEt1N9XgF6zxWmaC0xOk=
github.com/prometheus/client_model v0.4.0 h1:5lQXD3cAg1OXBf4Wq03gTrXHeaV0TQvGfUooCfx1yqY=
@ -90,47 +78,40 @@ github.com/prometheus/procfs v0.10.1 h1:kYK1Va/YMlutzCGazswoHKo//tZVlFpKYh+Pymzi
github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM=
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U=
github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo=
github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g=
github.com/quic-go/qtls-go1-20 v0.3.3 h1:17/glZSLI9P9fDAeyCHBFSWSqJcwx1byhLwP5eUIDCM=
github.com/quic-go/qtls-go1-20 v0.3.3/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
github.com/quic-go/quic-go v0.38.0 h1:T45lASr5q/TrVwt+jrVccmqHhPL2XuSyoCLVCpfOSLc=
github.com/quic-go/quic-go v0.38.0/go.mod h1:MPCuRq7KBK2hNcfKj/1iD1BGuN3eAYMeNxp3T42LRUg=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/shirou/gopsutil/v3 v3.21.8 h1:nKct+uP0TV8DjjNiHanKf8SAuub+GNsbrOtM9Nl9biA=
github.com/shirou/gopsutil/v3 v3.21.8/go.mod h1:YWp/H8Qs5fVmf17v7JNZzA0mPJ+mS2e9JdiUF9LlKzQ=
github.com/shirou/gopsutil/v3 v3.23.7 h1:C+fHO8hfIppoJ1WdsVm1RoI0RwXoNdfTK7yWXV0wVj4=
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tklauser/go-sysconf v0.3.9 h1:JeUVdAOWhhxVcU6Eqr/ATFHgXk/mmiItdKeJPev3vTo=
github.com/tklauser/go-sysconf v0.3.9/go.mod h1:11DU/5sG7UexIrp/O6g35hrWzu0JxlwQ3LSFUzyeuhs=
github.com/tklauser/numcpus v0.3.0 h1:ILuRUQBtssgnxw0XXIjKUC56fgnOrFoQQ/4+DeU2biQ=
github.com/tklauser/numcpus v0.3.0/go.mod h1:yFGUr7TUHQRAhyqBcEg0Ge34zDBAsIvJJcyE6boqnA8=
github.com/tklauser/go-sysconf v0.3.11 h1:89WgdJhk5SNwJfu+GKyYveZ4IaJ7xAkecBo+KdJV0CM=
github.com/tklauser/numcpus v0.6.0 h1:kebhY2Qt+3U6RNK7UqpYNA+tJ23IBEGKkB7JQBfDYms=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA=
golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw=
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ=
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=
golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210929193557-e81a3d93ecf6/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50=
golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -138,44 +119,37 @@ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210816074244-15123e1e1f71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210909193231-528a39cd75f3/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg=
golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM=
golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 h1:Vve/L0v7CXXuxUmaMGIEK/dEeq7uiqb5qBgQrZzIE7E=
golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 h1:KpwkzHKEF7B9Zxg18WzOa7djJ+Ha5DzthMyZYQfEn2A=
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU=
google.golang.org/grpc v1.56.2 h1:fVRFRnXvU+x6C4IlHZewvJOVHoOv1TUuQyoRsYnB4bI=
google.golang.org/grpc v1.56.2/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -119,7 +119,6 @@ github.com/go-kit/log v0.2.0 h1:7i2K3eKTos3Vc0enKCfnVcgHh2olr/MyfboYq7cAcFw=
github.com/go-kit/log v0.2.1 h1:MRVx0/zhvdseW+Gza6N9rVzU/IVzaeE1SFI4raAhmBU=
github.com/go-kit/log v0.2.1/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0=
github.com/go-logfmt/logfmt v0.5.1 h1:otpy5pqBCBZ1ng9RQ0dPu4PN7ba75Y/aA+UpowDyNVA=
github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab h1:xveKWz2iaueeTaUgdetzel+U7exyigDYBryyVfV/rZk=
@ -150,6 +149,7 @@ github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY=
github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ=
@ -200,14 +200,11 @@ github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
github.com/jstemmer/go-junit-report v0.9.1 h1:6QPYqodiu3GuPL+7mfx+NwDdp2eTkp9IfEUpgAwUN0o=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kataras/blocks v0.0.7 h1:cF3RDY/vxnSRezc7vLFlQFTYXG/yAr1o7WImJuZbzC4=
github.com/kataras/blocks v0.0.7/go.mod h1:UJIU97CluDo0f+zEjbnbkeMRlvYORtmc1304EeyXf4I=
github.com/kataras/golog v0.1.7 h1:0TY5tHn5L5DlRIikepcaRR/6oInIr9AiWsxzt0vvlBE=
@ -234,6 +231,7 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJ
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515 h1:T+h1c/A9Gawja4Y9mFVWj2vyii2bbUNDw3kt9VxK2EY=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1 h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw=
github.com/kr/pty v1.1.3 h1:/Um6a/ZmD5tF7peoOJ5oN5KMQ0DrGVQSXLNwyckutPk=
@ -283,12 +281,9 @@ github.com/microcosm-cc/bluemonday v1.0.23/go.mod h1:mN70sk7UkkF8TUr2IGBpNN0jAgS
github.com/miekg/dns v1.1.47/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f h1:KUppIJq7/+SVif2QVs3tOP0zanoHgBEVAwHxUSIzRqU=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86 h1:D6paGObi5Wud7xg83MaEFyjxQB1W5bz5d0IFppr+ymk=
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab h1:eFXv9Nu1lGbrNbj619aWwZfVF5HBrm9Plte8aNptuTI=
@ -296,13 +291,18 @@ github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk=
github.com/onsi/ginkgo/v2 v2.8.1/go.mod h1:N1/NbDngAFcSLdyZ+/aYTYGSlq9qMCS/cNKGJjy+csc=
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
github.com/onsi/ginkgo/v2 v2.11.0/go.mod h1:ZhrRA5XmEE3x3rhlzamx/JJvujdZoJ2uvgI7kR0iZvM=
github.com/onsi/gomega v1.20.1/go.mod h1:DtrZpjmvpn2mPm4YWQa0/ALMDj9v4YxLgojwPeREyVo=
github.com/onsi/gomega v1.22.1/go.mod h1:x6n7VNe4hw0vkyYUM4mjIXx3JbLiPaBPNgB7PRQ1tuM=
github.com/onsi/gomega v1.24.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2vQAg=
github.com/onsi/gomega v1.27.1/go.mod h1:aHX5xOykVYzWOV4WqQy0sy8BQptgukenXpCXfadcIAw=
github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
github.com/onsi/gomega v1.27.7/go.mod h1:1p8OOlwo2iUUDsHnOrjE5UKYJ+e3W8eQ3qSlRahPmr4=
github.com/openzipkin/zipkin-go v0.1.1 h1:A/ADD6HaPnAKj3yS7HjGHRK77qi41Hi0DirOOIQAeIw=
github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8=
github.com/oschwald/maxminddb-golang v1.10.0 h1:Xp1u0ZhqkSuopaKmk1WwHtjF0H9Hd9181uj2MQ5Vndg=
github.com/oschwald/maxminddb-golang v1.10.0/go.mod h1:Y2ELenReaLAZ0b400URyGwvYxHV1dLIxBuyOsyYjHK0=
github.com/pelletier/go-toml/v2 v2.0.5 h1:ipoSadvV8oGUjnUbMub59IDPPwfxF694nG/jwbMiyQg=
github.com/pelletier/go-toml/v2 v2.0.5/go.mod h1:OMHamSCAODeSsVrwwvcJOaoN0LIUIaFVNZzmWyNfXas=
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
@ -313,6 +313,8 @@ github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:
github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro=
github.com/prometheus/common v0.42.0/go.mod h1:xBwqVerjNdUDjgODMpudtOMwlOwf2SaTr1yjz4b7Zbc=
github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/quic-go/qtls-go1-20 v0.3.3/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
github.com/quic-go/quic-go v0.38.0/go.mod h1:MPCuRq7KBK2hNcfKj/1iD1BGuN3eAYMeNxp3T42LRUg=
github.com/rogpeppe/go-internal v1.3.0 h1:RR9dF3JtopPvtkroDZuVD7qquD0bnHlKSqaQhgwt8yk=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
@ -324,6 +326,8 @@ github.com/schollz/closestmatch v2.1.0+incompatible h1:Uel2GXEpJqOWBrlyI+oY9LTiy
github.com/schollz/closestmatch v2.1.0+incompatible/go.mod h1:RtP1ddjLong6gTkbtmuhtR2uUrrJOpYzYRvbcPAid+g=
github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ=
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/shirou/gopsutil/v3 v3.23.7 h1:C+fHO8hfIppoJ1WdsVm1RoI0RwXoNdfTK7yWXV0wVj4=
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4 h1:Fth6mevc5rX7glNLpbAMJnqKlfIkcTjZCSHEeqvKbcI=
github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY=
github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48 h1:vabduItPAIz9px5iryD5peyx7O3Ya8TBThapgXim98o=
@ -388,6 +392,7 @@ github.com/tdewolff/minify/v2 v2.12.4 h1:kejsHQMM17n6/gwdw53qsi6lg0TGddZADVyQOz1
github.com/tdewolff/minify/v2 v2.12.4/go.mod h1:h+SRvSIX3kwgwTFOpSckvSxgax3uy8kZTSF1Ojrr3bk=
github.com/tdewolff/parse/v2 v2.6.4 h1:KCkDvNUMof10e3QExio9OPZJT8SbdKojLBumw8YZycQ=
github.com/tdewolff/parse/v2 v2.6.4/go.mod h1:woz0cgbLwFdtbjJu8PIKxhW05KplTFQkOdX78o+Jgrs=
github.com/tklauser/go-sysconf v0.3.11 h1:89WgdJhk5SNwJfu+GKyYveZ4IaJ7xAkecBo+KdJV0CM=
github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0=
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
github.com/urfave/negroni v1.0.0 h1:kIimOitoypq34K7TG7DUaJ9kq/N4Ofuwi1sjz0KipXc=
@ -429,10 +434,13 @@ golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20221019170559-20944726eadf/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/exp v0.0.0-20230306221820-f0f767cdffd6/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/exp v0.0.0-20230807204917-050eac23e9de/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b h1:+qEpEAPhDZ1o0x3tHzZTQDArnOixOzGD9HUJfcg0mb4=
golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
@ -441,6 +449,9 @@ golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 h1:4+4C/Iv2U4fMZBiMCc98MG1In4gJY5YRhtpDNeDeHWs=
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@ -456,6 +467,10 @@ golang.org/x/net v0.0.0-20220516155154-20f960328961/go.mod h1:CfG3xpIq0wQ8r1q4Su
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ=
golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
@ -481,9 +496,14 @@ golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 h1:v6hYoSR9T5oet+pMXwUWkbiVqx/63mlHjefrHmxwfeY=
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw=
golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI=
@ -491,15 +511,27 @@ golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA=
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c=
golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA=
golang.org/x/tools v0.3.0/go.mod h1:/rWhSS2+zyEVwoJf8YAX6L2f0ntZ7Kn/mGgAWcipA5k=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
golang.org/x/tools v0.9.3/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0=
@ -519,11 +551,15 @@ google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoA
google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg=
google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
google.golang.org/genproto v0.0.0-20200825200019-8632dd797987 h1:PDIOdWxZ8eRizhKa1AAvY53xsvLB1cWorMjslvY3VA8=
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 h1:KpwkzHKEF7B9Zxg18WzOa7djJ+Ha5DzthMyZYQfEn2A=
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU=
google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio=
google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.31.0 h1:T7P4R73V3SSDPhH7WW7ATbfViLtmamH0DKrP3f9AuDI=
google.golang.org/grpc v1.56.2 h1:fVRFRnXvU+x6C4IlHZewvJOVHoOv1TUuQyoRsYnB4bI=
google.golang.org/grpc v1.56.2/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s=
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/alecthomas/kingpin.v2 v2.2.6 h1:jMFz6MfLP0/4fUyZle81rXUoxOBFi19VUFKVDOQfozc=
@ -536,6 +572,8 @@ gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
grpc.go4.org v0.0.0-20170609214715-11d0a25b4919 h1:tmXTu+dfa+d9Evp8NpJdgOy6+rt8/x4yG7qPBrtNfLY=
grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o=
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

108
internal/access/access.go Normal file
View File

@ -0,0 +1,108 @@
// Package access contains structures for access control management.
package access
import (
"fmt"
"net/netip"
"strings"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/urlfilter"
"github.com/AdguardTeam/urlfilter/filterlist"
)
// unit is a convenient alias for struct{}
type unit = struct{}
// Interface is the access manager interface.
type Interface interface {
// IsBlockedHost returns true if host should be blocked.
IsBlockedHost(host string, qt uint16) (blocked bool)
// IsBlockedIP returns the status of the IP address blocking as well as the
// rule that blocked it.
IsBlockedIP(ip netip.Addr) (blocked bool, rule string)
}
// type check
var _ Interface = (*Manager)(nil)
// Manager controls IP and client blocking that takes place before all
// other processing. An Manager is safe for concurrent use.
type Manager struct {
blockedIPs map[netip.Addr]unit
blockedHostsEng *urlfilter.DNSEngine
blockedNets []netip.Prefix
}
// New create an Manager. The parameters assumed to be valid.
func New(blockedDomains, blockedSubnets []string) (am *Manager, err error) {
am = &Manager{
blockedIPs: map[netip.Addr]unit{},
}
processAccessList(blockedSubnets, am.blockedIPs, &am.blockedNets)
b := &strings.Builder{}
for _, h := range blockedDomains {
stringutil.WriteToBuilder(b, strings.ToLower(h), "\n")
}
lists := []filterlist.RuleList{
&filterlist.StringRuleList{
ID: 0,
RulesText: b.String(),
IgnoreCosmetic: true,
},
}
rulesStrg, err := filterlist.NewRuleStorage(lists)
if err != nil {
return nil, fmt.Errorf("adding blocked hosts: %w", err)
}
am.blockedHostsEng = urlfilter.NewDNSEngine(rulesStrg)
return am, nil
}
// processAccessList is a helper for processing a list of strings, each of them
// assumed be a valid IP address or a valid CIDR.
func processAccessList(strs []string, ips map[netip.Addr]unit, nets *[]netip.Prefix) {
for _, s := range strs {
var err error
var ip netip.Addr
var ipnet netip.Prefix
if ip, err = netip.ParseAddr(s); err == nil {
ips[ip] = unit{}
} else if ipnet, err = netip.ParsePrefix(s); err == nil {
*nets = append(*nets, ipnet)
}
}
}
// IsBlockedHost returns true if host should be blocked.
func (am *Manager) IsBlockedHost(host string, qt uint16) (blocked bool) {
_, blocked = am.blockedHostsEng.MatchRequest(&urlfilter.DNSRequest{
Hostname: host,
DNSType: qt,
})
return blocked
}
// IsBlockedIP returns the status of the IP address blocking as well as the rule
// that blocked it.
func (am *Manager) IsBlockedIP(ip netip.Addr) (blocked bool, rule string) {
if _, ok := am.blockedIPs[ip]; ok {
return true, ip.String()
}
for _, ipnet := range am.blockedNets {
if ipnet.Contains(ip) {
return true, ipnet.String()
}
}
return false, ""
}

View File

@ -0,0 +1,107 @@
package access_test
import (
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/access"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAccessManager_IsBlockedHost(t *testing.T) {
am, err := access.New([]string{
"block.test",
"UPPERCASE.test",
"||block_aaaa.test^$dnstype=AAAA",
}, []string{})
require.NoError(t, err)
testCases := []struct {
want assert.BoolAssertionFunc
name string
host string
qt uint16
}{{
want: assert.False,
name: "pass",
host: "pass.test",
qt: dns.TypeA,
}, {
want: assert.True,
name: "blocked_domain_A",
host: "block.test",
qt: dns.TypeA,
}, {
want: assert.True,
name: "blocked_domain_HTTPS",
host: "block.test",
qt: dns.TypeHTTPS,
}, {
want: assert.True,
name: "uppercase_domain",
host: "uppercase.test",
qt: dns.TypeHTTPS,
}, {
want: assert.False,
name: "pass_qt",
host: "block_aaaa.test",
qt: dns.TypeA,
}, {
want: assert.True,
name: "block_qt",
host: "block_aaaa.test",
qt: dns.TypeAAAA,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
blocked := am.IsBlockedHost(tc.host, tc.qt)
tc.want(t, blocked)
})
}
}
func TestAccessManager_IsBlockedIP(t *testing.T) {
am, err := access.New([]string{}, []string{
"1.1.1.1",
"2.2.2.0/8",
})
require.NoError(t, err)
testCases := []struct {
want assert.BoolAssertionFunc
ip netip.Addr
wantRule string
name string
}{{
want: assert.False,
wantRule: "",
name: "pass",
ip: netip.MustParseAddr("1.1.1.0"),
}, {
want: assert.True,
wantRule: "1.1.1.1",
name: "block_ip",
ip: netip.MustParseAddr("1.1.1.1"),
}, {
want: assert.False,
wantRule: "",
name: "pass_subnet",
ip: netip.MustParseAddr("1.2.2.2"),
}, {
want: assert.True,
wantRule: "2.2.2.0/8",
name: "block_subnet",
ip: netip.MustParseAddr("2.2.2.2"),
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
blocked, rule := am.IsBlockedIP(tc.ip)
tc.want(t, blocked)
assert.Equal(t, tc.wantRule, rule)
})
}
}

View File

@ -6,6 +6,7 @@ import (
"encoding/csv"
"net/http"
"os"
"strings"
"text/template"
"time"
@ -40,10 +41,10 @@ func main() {
// Skip the first row, as it is a header.
rows = rows[1:]
// Sort by the code to make the output more predictable and easier to look
// through.
slices.SortFunc(rows, func(a, b []string) (less bool) {
return a[1] < b[1]
slices.SortFunc(rows, func(a, b []string) (res int) {
// Sort by the code to make the output more predictable and easier to
// look through.
return strings.Compare(a[1], b[1])
})
tmpl, err := template.New("main").Parse(tmplStr)

View File

@ -74,3 +74,10 @@ func ParseSubnets(strs ...string) (subnets []netip.Prefix, err error) {
return subnets, nil
}
// NormalizeDomain returns lowercased version of the host without the final dot.
//
// TODO(a.garipov): Move to golibs.
func NormalizeDomain(fqdn string) (host string) {
return strings.ToLower(strings.TrimSuffix(fqdn, "."))
}

View File

@ -0,0 +1,19 @@
package agdnet
import (
"fmt"
"net/netip"
)
// FormatPrefixAddr returns either a simple IP:port address or one with the
// prefix length appended after a slash, depending on whether or not subnet is a
// single-address subnet. This is done to make using the IP:port part easier to
// split off using functions like [strings.Cut].
func FormatPrefixAddr(subnet netip.Prefix, port uint16) (s string) {
addrPort := netip.AddrPortFrom(subnet.Addr(), port)
if subnet.IsSingleIP() {
return addrPort.String()
}
return fmt.Sprintf("%s/%d", addrPort, subnet.Bits())
}

View File

@ -0,0 +1,17 @@
package agdnet_test
import (
"fmt"
"net/netip"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
)
func ExampeFormatPrefixAddr() {
fmt.Println(agdnet.FormatPrefixAddr(netip.MustParsePrefix("1.2.3.4/32"), 5678))
fmt.Println(agdnet.FormatPrefixAddr(netip.MustParsePrefix("1.2.3.0/24"), 5678))
// Output:
// 1.2.3.4:5678
// 1.2.3.0:5678/24
}

View File

@ -5,6 +5,7 @@ import (
"fmt"
"math"
"net"
"net/netip"
"sync"
"time"
@ -19,7 +20,14 @@ import (
//
// See go doc net.Resolver.
type Resolver interface {
LookupIP(ctx context.Context, fam netutil.AddrFamily, host string) (ips []net.IP, err error)
// LookupNetIP returns a slice of host's IP addresses of family specified by
// fam, which must be either [netutil.AddrFamilyIPv4] or
// [netutil.AddrFamilyIPv6].
LookupNetIP(
ctx context.Context,
fam netutil.AddrFamily,
host string,
) (ips []netip.Addr, err error)
}
// DefaultResolver uses [net.DefaultResolver] to resolve addresses.
@ -28,17 +36,17 @@ type DefaultResolver struct{}
// type check
var _ Resolver = DefaultResolver{}
// LookupIP implements the [Resolver] interface for DefaultResolver.
func (DefaultResolver) LookupIP(
// LookupNetIP implements the [Resolver] interface for DefaultResolver.
func (DefaultResolver) LookupNetIP(
ctx context.Context,
fam netutil.AddrFamily,
host string,
) (ips []net.IP, err error) {
) (ips []netip.Addr, err error) {
switch fam {
case netutil.AddrFamilyIPv4:
return net.DefaultResolver.LookupIP(ctx, "ip4", host)
return net.DefaultResolver.LookupNetIP(ctx, "ip4", host)
case netutil.AddrFamilyIPv6:
return net.DefaultResolver.LookupIP(ctx, "ip6", host)
return net.DefaultResolver.LookupNetIP(ctx, "ip6", host)
default:
return nil, net.UnknownNetworkError(fam.String())
}
@ -50,7 +58,7 @@ type resolveCache map[string]*resolveCacheItem
// resolveCacheItem is an item of [resolveCache].
type resolveCacheItem struct {
refrTime time.Time
ips []net.IP
ips []netip.Addr
}
// CachingResolver caches resolved results for hosts for a certain time,
@ -83,13 +91,13 @@ func NewCachingResolver(resolver Resolver, ttl time.Duration) (c *CachingResolve
// type check
var _ Resolver = (*CachingResolver)(nil)
// LookupIP implements the [Resolver] interface for *CachingResolver. host
// LookupNetIP implements the [Resolver] interface for *CachingResolver. host
// should be normalized. Slice ips and its elements must not be mutated.
func (c *CachingResolver) LookupIP(
func (c *CachingResolver) LookupNetIP(
ctx context.Context,
fam netutil.AddrFamily,
host string,
) (ips []net.IP, err error) {
) (ips []netip.Addr, err error) {
c.mu.Lock()
defer c.mu.Unlock()
@ -119,20 +127,20 @@ func (c *CachingResolver) resolve(
fam netutil.AddrFamily,
host string,
) (item *resolveCacheItem, err error) {
var ips []net.IP
var ips []netip.Addr
refrTime := time.Now()
// Don't resolve IP addresses.
ip := ipFromHost(host, fam)
if ip != nil {
ips = []net.IP{ip}
if ip != (netip.Addr{}) {
ips = []netip.Addr{ip}
// Set the refresh time to the maximum date that time.Duration allows to
// prevent this item from refreshing.
refrTime = time.Unix(0, math.MaxInt64)
} else {
ips, err = c.resolver.LookupIP(ctx, fam, host)
ips, err = c.resolver.LookupNetIP(ctx, fam, host)
if err != nil {
if !isExpectedLookupError(fam, err) {
return nil, fmt.Errorf("resolving %s addr for %q: %w", fam, host, err)
@ -159,22 +167,25 @@ func (c *CachingResolver) resolve(
return item, nil
}
// ipFromHost returns a normalized IP address if host contains an IP address of
// the given address family.
func ipFromHost(host string, fam netutil.AddrFamily) (ip net.IP) {
ip = net.ParseIP(host)
if ip == nil {
return nil
// ipFromHost parses host as if it'd be an IP address of specified fam. It
// returns an empty netip.
func ipFromHost(host string, fam netutil.AddrFamily) (ip netip.Addr) {
var famFunc func(netip.Addr) (ok bool)
switch fam {
case netutil.AddrFamilyIPv4:
famFunc = netip.Addr.Is4
case netutil.AddrFamilyIPv6:
famFunc = netip.Addr.Is6
default:
return netip.Addr{}
}
ip, err := netip.ParseAddr(host)
if err != nil || !famFunc(ip) {
return netip.Addr{}
}
ip4 := ip.To4()
if fam == netutil.AddrFamilyIPv4 && ip4 != nil {
return ip4
} else if fam == netutil.AddrFamilyIPv6 && ip4 == nil {
return ip
}
return nil
}
// isExpectedLookupError returns true if the error is an expected lookup error.

View File

@ -2,7 +2,7 @@ package agdnet_test
import (
"context"
"net"
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
@ -17,14 +17,14 @@ func TestCachingResolver_Resolve(t *testing.T) {
const testHost = "addr.example"
var numLookups uint64
wantIPv4 := []net.IP{{1, 2, 3, 4}}
wantIPv6 := []net.IP{net.ParseIP("1234::5678")}
wantIPv4 := []netip.Addr{netip.MustParseAddr("1.2.3.4")}
wantIPv6 := []netip.Addr{netip.MustParseAddr("1234::5678")}
r := &agdtest.Resolver{
OnLookupIP: func(
_ context.Context,
OnLookupNetIP: func(
ctx context.Context,
fam netutil.AddrFamily,
_ string,
) (ips []net.IP, err error) {
host string,
) (ips []netip.Addr, err error) {
numLookups++
if fam == netutil.AddrFamilyIPv4 {
@ -40,7 +40,7 @@ func TestCachingResolver_Resolve(t *testing.T) {
testCases := []struct {
name string
host string
wantIPs []net.IP
wantIPs []netip.Addr
wantNum uint64
fam netutil.AddrFamily
}{{
@ -78,7 +78,7 @@ func TestCachingResolver_Resolve(t *testing.T) {
ctx := context.Background()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := cached.LookupIP(ctx, tc.fam, tc.host)
got, err := cached.LookupNetIP(ctx, tc.fam, tc.host)
require.NoError(t, err)
assert.Equal(t, tc.wantNum, numLookups)

View File

@ -0,0 +1,27 @@
// Package agdprotobuf contains protobuf utils.
package agdprotobuf
import (
"fmt"
"net/netip"
)
// ByteSlicesToIPs converts a slice of byte slices into a slice of netip.Addrs.
func ByteSlicesToIPs(data [][]byte) (ips []netip.Addr, err error) {
if data == nil {
return nil, nil
}
ips = make([]netip.Addr, 0, len(data))
for i, ipData := range data {
var ip netip.Addr
err = ip.UnmarshalBinary(ipData)
if err != nil {
return nil, fmt.Errorf("ip at index %d: %w", i, err)
}
ips = append(ips, ip)
}
return ips, nil
}

View File

@ -0,0 +1,37 @@
// Package agdsync contains extensions and utilities for package sync from the
// standard library.
//
// TODO(a.garipov): Move to module golibs.
package agdsync
import "sync"
// TypedPool is the strongly typed version of [sync.Pool] that manages pointers
// to T.
type TypedPool[T any] struct {
pool *sync.Pool
}
// NewTypedPool returns a new strongly typed pool. newFunc must not be nil.
func NewTypedPool[T any](newFunc func() (v *T)) (p *TypedPool[T]) {
return &TypedPool[T]{
pool: &sync.Pool{
New: func() (v any) { return newFunc() },
},
}
}
// Get selects an arbitrary item from the pool, removes it from the pool, and
// returns it to the caller.
//
// See [sync.Pool.Get].
func (p *TypedPool[T]) Get() (v *T) {
return p.pool.Get().(*T)
}
// Put adds v to the pool.
//
// See [sync.Pool.Put].
func (p *TypedPool[T]) Put(v *T) {
p.pool.Put(v)
}

View File

@ -6,6 +6,7 @@ import (
"net/netip"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/access"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
@ -56,6 +57,27 @@ func (r *Refresher) Refresh(ctx context.Context) (err error) {
return r.OnRefresh(ctx)
}
// Package access
// type check
var _ access.Interface = (*AccessManager)(nil)
// AccessManager is a [access.Interface] for tests.
type AccessManager struct {
OnIsBlockedHost func(host string, qt uint16) (blocked bool)
OnIsBlockedIP func(ip netip.Addr) (blocked bool, rule string)
}
// IsBlockedHost implements the [access.Interface] interface for *AccessManager.
func (a *AccessManager) IsBlockedHost(host string, qt uint16) (blocked bool) {
return a.OnIsBlockedHost(host, qt)
}
// IsBlockedIP implements the [access.Interface] interface for *AccessManager.
func (a *AccessManager) IsBlockedIP(ip netip.Addr) (blocked bool, rule string) {
return a.OnIsBlockedIP(ip)
}
// Package agdnet
// type check
@ -63,20 +85,20 @@ var _ agdnet.Resolver = (*Resolver)(nil)
// Resolver is an agd.Resolver for tests.
type Resolver struct {
OnLookupIP func(
OnLookupNetIP func(
ctx context.Context,
fam netutil.AddrFamily,
host string,
) (ips []net.IP, err error)
) (ips []netip.Addr, err error)
}
// LookupIP implements the agd.Resolver interface for *Resolver.
func (r *Resolver) LookupIP(
// LookupNetIP implements the [agd.Resolver] interface for *Resolver.
func (r *Resolver) LookupNetIP(
ctx context.Context,
fam netutil.AddrFamily,
host string,
) (ips []net.IP, err error) {
return r.OnLookupIP(ctx, fam, host)
) (ips []netip.Addr, err error) {
return r.OnLookupNetIP(ctx, fam, host)
}
// Package billstat

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,124 @@
syntax = "proto3";
option go_package = "./backendpb";
import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/empty.proto";
option java_multiple_files = true;
option java_package = "com.adguard.backend.dns.generated";
option java_outer_classname = "DNSProfilesProto";
option objc_class_prefix = "DNS";
service DNSService {
/*
Gets DNS profiles.
Field "sync_time" in DNSProfilesRequest - pass to return the latest updates after this time moment.
The trailers headers will include a "sync_time", given in milliseconds,
that should be used for subsequent incremental DNS profile synchronization requests.
*/
rpc getDNSProfiles(DNSProfilesRequest) returns (stream DNSProfile);
/*
Stores devices activity.
*/
rpc saveDevicesBillingStat(stream DeviceBillingStat) returns (google.protobuf.Empty);
}
message DNSProfilesRequest {
google.protobuf.Timestamp sync_time = 1;
}
message DNSProfile {
string dns_id = 1;
bool filtering_enabled = 2;
bool query_log_enabled = 3;
bool deleted = 4;
SafeBrowsingSettings safe_browsing = 5;
ParentalSettings parental = 6;
RuleListsSettings rule_lists = 7;
repeated DeviceSettings devices = 8;
repeated string custom_rules = 9;
google.protobuf.Duration filtered_response_ttl = 10;
bool block_private_relay = 11;
bool block_firefox_canary = 12;
oneof blocking_mode {
BlockingModeCustomIP blocking_mode_custom_ip = 13;
BlockingModeNXDOMAIN blocking_mode_nxdomain = 14;
BlockingModeNullIP blocking_mode_null_ip = 15;
BlockingModeREFUSED blocking_mode_refused = 16;
}
}
message SafeBrowsingSettings {
bool enabled = 1;
bool block_dangerous_domains = 2;
bool block_nrd = 3;
}
message DeviceSettings {
string id = 1;
string name = 2;
bool filtering_enabled = 3;
bytes linked_ip = 4;
repeated bytes dedicated_ips = 5;
}
message ParentalSettings {
bool enabled = 1;
bool block_adult = 2;
bool general_safe_search = 3;
bool youtube_safe_search = 4;
repeated string blocked_services = 5;
ScheduleSettings schedule = 6;
}
message ScheduleSettings {
string tmz = 1;
WeeklyRange weeklyRange = 2;
}
message WeeklyRange {
DayRange mon = 1;
DayRange tue = 2;
DayRange wed = 3;
DayRange thu = 4;
DayRange fri = 5;
DayRange sat = 6;
DayRange sun = 7;
}
message DayRange {
google.protobuf.Duration start = 1;
google.protobuf.Duration end = 2;
}
message RuleListsSettings {
bool enabled = 1;
repeated string ids = 2;
}
message BlockingModeCustomIP {
bytes ipv4 = 1;
bytes ipv6 = 2;
}
message BlockingModeNXDOMAIN {}
message BlockingModeNullIP {}
message BlockingModeREFUSED {}
message DeviceBillingStat {
google.protobuf.Timestamp last_activity_time = 1;
string device_id = 2;
string client_country = 3;
// Protocol type. Possible values see here: https://bit.adguard.com/projects/DNS/repos/dns-server/browse#ql-properties
uint32 proto = 4;
uint32 asn = 5;
uint32 queries = 6;
}

View File

@ -0,0 +1,222 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.3.0
// - protoc v4.23.4
// source: backend.proto
package backendpb
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
emptypb "google.golang.org/protobuf/types/known/emptypb"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
const (
DNSService_GetDNSProfiles_FullMethodName = "/DNSService/getDNSProfiles"
DNSService_SaveDevicesBillingStat_FullMethodName = "/DNSService/saveDevicesBillingStat"
)
// DNSServiceClient is the client API for DNSService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type DNSServiceClient interface {
// Gets DNS profiles.
//
// Field "sync_time" in DNSProfilesRequest - pass to return the latest updates after this time moment.
//
// The trailers headers will include a "sync_time", given in milliseconds,
// that should be used for subsequent incremental DNS profile synchronization requests.
GetDNSProfiles(ctx context.Context, in *DNSProfilesRequest, opts ...grpc.CallOption) (DNSService_GetDNSProfilesClient, error)
// Stores devices activity.
SaveDevicesBillingStat(ctx context.Context, opts ...grpc.CallOption) (DNSService_SaveDevicesBillingStatClient, error)
}
type dNSServiceClient struct {
cc grpc.ClientConnInterface
}
func NewDNSServiceClient(cc grpc.ClientConnInterface) DNSServiceClient {
return &dNSServiceClient{cc}
}
func (c *dNSServiceClient) GetDNSProfiles(ctx context.Context, in *DNSProfilesRequest, opts ...grpc.CallOption) (DNSService_GetDNSProfilesClient, error) {
stream, err := c.cc.NewStream(ctx, &DNSService_ServiceDesc.Streams[0], DNSService_GetDNSProfiles_FullMethodName, opts...)
if err != nil {
return nil, err
}
x := &dNSServiceGetDNSProfilesClient{stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
type DNSService_GetDNSProfilesClient interface {
Recv() (*DNSProfile, error)
grpc.ClientStream
}
type dNSServiceGetDNSProfilesClient struct {
grpc.ClientStream
}
func (x *dNSServiceGetDNSProfilesClient) Recv() (*DNSProfile, error) {
m := new(DNSProfile)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
func (c *dNSServiceClient) SaveDevicesBillingStat(ctx context.Context, opts ...grpc.CallOption) (DNSService_SaveDevicesBillingStatClient, error) {
stream, err := c.cc.NewStream(ctx, &DNSService_ServiceDesc.Streams[1], DNSService_SaveDevicesBillingStat_FullMethodName, opts...)
if err != nil {
return nil, err
}
x := &dNSServiceSaveDevicesBillingStatClient{stream}
return x, nil
}
type DNSService_SaveDevicesBillingStatClient interface {
Send(*DeviceBillingStat) error
CloseAndRecv() (*emptypb.Empty, error)
grpc.ClientStream
}
type dNSServiceSaveDevicesBillingStatClient struct {
grpc.ClientStream
}
func (x *dNSServiceSaveDevicesBillingStatClient) Send(m *DeviceBillingStat) error {
return x.ClientStream.SendMsg(m)
}
func (x *dNSServiceSaveDevicesBillingStatClient) CloseAndRecv() (*emptypb.Empty, error) {
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
m := new(emptypb.Empty)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// DNSServiceServer is the server API for DNSService service.
// All implementations must embed UnimplementedDNSServiceServer
// for forward compatibility
type DNSServiceServer interface {
// Gets DNS profiles.
//
// Field "sync_time" in DNSProfilesRequest - pass to return the latest updates after this time moment.
//
// The trailers headers will include a "sync_time", given in milliseconds,
// that should be used for subsequent incremental DNS profile synchronization requests.
GetDNSProfiles(*DNSProfilesRequest, DNSService_GetDNSProfilesServer) error
// Stores devices activity.
SaveDevicesBillingStat(DNSService_SaveDevicesBillingStatServer) error
mustEmbedUnimplementedDNSServiceServer()
}
// UnimplementedDNSServiceServer must be embedded to have forward compatible implementations.
type UnimplementedDNSServiceServer struct {
}
func (UnimplementedDNSServiceServer) GetDNSProfiles(*DNSProfilesRequest, DNSService_GetDNSProfilesServer) error {
return status.Errorf(codes.Unimplemented, "method GetDNSProfiles not implemented")
}
func (UnimplementedDNSServiceServer) SaveDevicesBillingStat(DNSService_SaveDevicesBillingStatServer) error {
return status.Errorf(codes.Unimplemented, "method SaveDevicesBillingStat not implemented")
}
func (UnimplementedDNSServiceServer) mustEmbedUnimplementedDNSServiceServer() {}
// UnsafeDNSServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to DNSServiceServer will
// result in compilation errors.
type UnsafeDNSServiceServer interface {
mustEmbedUnimplementedDNSServiceServer()
}
func RegisterDNSServiceServer(s grpc.ServiceRegistrar, srv DNSServiceServer) {
s.RegisterService(&DNSService_ServiceDesc, srv)
}
func _DNSService_GetDNSProfiles_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(DNSProfilesRequest)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(DNSServiceServer).GetDNSProfiles(m, &dNSServiceGetDNSProfilesServer{stream})
}
type DNSService_GetDNSProfilesServer interface {
Send(*DNSProfile) error
grpc.ServerStream
}
type dNSServiceGetDNSProfilesServer struct {
grpc.ServerStream
}
func (x *dNSServiceGetDNSProfilesServer) Send(m *DNSProfile) error {
return x.ServerStream.SendMsg(m)
}
func _DNSService_SaveDevicesBillingStat_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(DNSServiceServer).SaveDevicesBillingStat(&dNSServiceSaveDevicesBillingStatServer{stream})
}
type DNSService_SaveDevicesBillingStatServer interface {
SendAndClose(*emptypb.Empty) error
Recv() (*DeviceBillingStat, error)
grpc.ServerStream
}
type dNSServiceSaveDevicesBillingStatServer struct {
grpc.ServerStream
}
func (x *dNSServiceSaveDevicesBillingStatServer) SendAndClose(m *emptypb.Empty) error {
return x.ServerStream.SendMsg(m)
}
func (x *dNSServiceSaveDevicesBillingStatServer) Recv() (*DeviceBillingStat, error) {
m := new(DeviceBillingStat)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// DNSService_ServiceDesc is the grpc.ServiceDesc for DNSService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var DNSService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "DNSService",
HandlerType: (*DNSServiceServer)(nil),
Methods: []grpc.MethodDesc{},
Streams: []grpc.StreamDesc{
{
StreamName: "getDNSProfiles",
Handler: _DNSService_GetDNSProfiles_Handler,
ServerStreams: true,
},
{
StreamName: "saveDevicesBillingStat",
Handler: _DNSService_SaveDevicesBillingStat_Handler,
ClientStreams: true,
},
},
Metadata: "backend.proto",
}

View File

@ -0,0 +1,43 @@
// Package backendpb contains the protobuf structures for the backend API.
package backendpb
import (
"context"
"fmt"
"net/url"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
)
// newClient returns new properly initialized DNSServiceClient.
func newClient(apiURL *url.URL) (client DNSServiceClient, err error) {
var creds credentials.TransportCredentials
switch s := apiURL.Scheme; s {
case "grpc":
creds = insecure.NewCredentials()
case "grpcs":
// Use a nil [tls.Config] to get the default TLS configuration.
creds = credentials.NewTLS(nil)
default:
return nil, fmt.Errorf("bad grpc url scheme %q", s)
}
conn, err := grpc.Dial(apiURL.Host, grpc.WithTransportCredentials(creds))
if err != nil {
return nil, fmt.Errorf("dialing: %w", err)
}
// Immediately make a connection attempt, since the constructor is often
// called right before the initial refresh.
conn.Connect()
return NewDNSServiceClient(conn), nil
}
// reportf is a helper method for reporting non-critical errors.
func reportf(ctx context.Context, errColl agd.ErrorCollector, format string, args ...any) {
agd.Collectf(ctx, errColl, "backendpb: "+format, args...)
}

View File

@ -0,0 +1,38 @@
package backendpb_test
import "github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
// testDNSServiceServer is the [backendpb.DNSServiceServer] for tests.
//
// TODO(d.kolyshev): Use this to remove as much as possible from the internal
// test.
type testDNSServiceServer struct {
backendpb.UnimplementedDNSServiceServer
OnGetDNSProfiles func(
req *backendpb.DNSProfilesRequest,
srv backendpb.DNSService_GetDNSProfilesServer,
) (err error)
OnSaveDevicesBillingStat func(
srv backendpb.DNSService_SaveDevicesBillingStatServer,
) (err error)
}
// type check
var _ backendpb.DNSServiceServer = (*testDNSServiceServer)(nil)
// GetDNSProfiles implements the [backendpb.DNSServiceServer] interface for
// *testDNSServiceServer
func (s *testDNSServiceServer) GetDNSProfiles(
req *backendpb.DNSProfilesRequest,
srv backendpb.DNSService_GetDNSProfilesServer,
) (err error) {
return s.OnGetDNSProfiles(req, srv)
}
// SaveDevicesBillingStat implements the [backendpb.DNSServiceServer] interface
// for *testDNSServiceServer
func (s *testDNSServiceServer) SaveDevicesBillingStat(
srv backendpb.DNSService_SaveDevicesBillingStatServer,
) (err error) {
return s.OnSaveDevicesBillingStat(srv)
}

View File

@ -0,0 +1,101 @@
package backendpb
import (
"context"
"fmt"
"io"
"net/url"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
"github.com/AdguardTeam/golibs/errors"
"google.golang.org/protobuf/types/known/timestamppb"
)
// BillStatConfig is the configuration structure for the business logic backend
// billing statistics uploader.
type BillStatConfig struct {
// ErrColl is the error collector that is used to collect critical and
// non-critical errors.
ErrColl agd.ErrorCollector
// Endpoint is the backend API URL. The scheme should be either "grpc" or
// "grpcs".
Endpoint *url.URL
}
// NewBillStat creates a new billing statistics uploader. c must not be nil.
func NewBillStat(c *BillStatConfig) (b *BillStat, err error) {
b = &BillStat{
errColl: c.ErrColl,
}
b.client, err = newClient(c.Endpoint)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
return b, nil
}
// BillStat is the implementation of the [billstat.Uploader] interface that
// uploads the billing statistics to the business logic backend. It is safe for
// concurrent use.
//
// TODO(a.garipov): Consider uniting with [ProfileStorage] into a single
// backendpb.Client.
type BillStat struct {
errColl agd.ErrorCollector
// client is the current GRPC client.
client DNSServiceClient
}
// type check
var _ billstat.Uploader = (*BillStat)(nil)
// Upload implements the [billstat.Uploader] interface for *BillStat.
func (b *BillStat) Upload(ctx context.Context, records billstat.Records) (err error) {
if len(records) == 0 {
return nil
}
stream, err := b.client.SaveDevicesBillingStat(ctx)
if err != nil {
return fmt.Errorf("opening stream: %w", err)
}
for deviceID, record := range records {
if record == nil {
reportf(ctx, b.errColl, "device %q: null record", deviceID)
continue
}
sendErr := stream.Send(recordToProtobuf(record, deviceID))
if sendErr != nil {
return fmt.Errorf("uploading device %q record: %w", deviceID, sendErr)
}
}
_, err = stream.CloseAndRecv()
if err != nil && !errors.Is(err, io.EOF) {
return fmt.Errorf("finishing stream: %w", err)
}
return nil
}
// recordToProtobuf converts a billstat record structure into the protobuf
// structure.
func recordToProtobuf(r *billstat.Record, devID agd.DeviceID) (s *DeviceBillingStat) {
return &DeviceBillingStat{
LastActivityTime: timestamppb.New(r.Time),
DeviceId: string(devID),
ClientCountry: string(r.Country),
Proto: uint32(r.Proto),
Asn: uint32(r.ASN),
Queries: uint32(r.Queries),
}
}

View File

@ -0,0 +1,104 @@
package backendpb_test
import (
"context"
"io"
"net"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
)
func TestBillStat_Upload(t *testing.T) {
const (
wantDeviceID = "test"
invalidDeviceID = "invalid"
)
wantRecord := &billstat.Record{
Time: time.Time{},
Country: agd.CountryCY,
ASN: 1221,
Queries: 1122,
Proto: agd.ProtoDNS,
}
records := billstat.Records{
wantDeviceID: wantRecord,
invalidDeviceID: nil,
}
srv := &testDNSServiceServer{
OnSaveDevicesBillingStat: func(
srv backendpb.DNSService_SaveDevicesBillingStatServer,
) (err error) {
pt := &testutil.PanicT{}
for {
data, recvErr := srv.Recv()
if recvErr != nil && errors.Is(recvErr, io.EOF) {
return srv.SendAndClose(&emptypb.Empty{})
}
require.NoError(t, recvErr)
assert.Equal(pt, wantDeviceID, data.DeviceId)
assert.Equal(pt, uint32(wantRecord.ASN), data.Asn)
assert.Equal(pt, string(wantRecord.Country), data.ClientCountry)
assert.Equal(pt, timestamppb.New(wantRecord.Time), data.LastActivityTime)
assert.Equal(pt, uint32(wantRecord.Proto), data.Proto)
assert.Equal(pt, uint32(wantRecord.Queries), data.Queries)
}
},
}
l, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
grpcSrv := grpc.NewServer(
grpc.ConnectionTimeout(1*time.Second),
grpc.Creds(insecure.NewCredentials()),
)
backendpb.RegisterDNSServiceServer(grpcSrv, srv)
go func() {
pt := &testutil.PanicT{}
srvErr := grpcSrv.Serve(l)
require.NoError(pt, srvErr)
}()
t.Cleanup(grpcSrv.GracefulStop)
errColl := &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, err error) {
testutil.AssertErrorMsg(t, `backendpb: device "invalid": null record`, err)
},
}
b, err := backendpb.NewBillStat(&backendpb.BillStatConfig{
ErrColl: errColl,
Endpoint: &url.URL{
Scheme: "grpc",
Host: l.Addr().String(),
},
})
require.NoError(t, err)
ctx := context.Background()
err = b.Upload(ctx, records)
require.NoError(t, err)
}

View File

@ -0,0 +1,457 @@
package backendpb
import (
"context"
"fmt"
"io"
"net/netip"
"net/url"
"strconv"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdprotobuf"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtime"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/golibs/errors"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/types/known/timestamppb"
)
// ProfileStorageConfig is the configuration for the business logic backend
// profile storage.
type ProfileStorageConfig struct {
// ErrColl is the error collector that is used to collect critical and
// non-critical errors.
ErrColl agd.ErrorCollector
// Endpoint is the backend API URL. The scheme should be either "grpc" or
// "grpcs".
Endpoint *url.URL
}
// ProfileStorage is the implementation of the [profiledb.Storage] interface
// that retrieves the profile and device information from the business logic
// backend. It is safe for concurrent use.
type ProfileStorage struct {
errColl agd.ErrorCollector
// client is the current GRPC client.
client DNSServiceClient
}
// NewProfileStorage returns a new [ProfileStorage] that retrieves information
// from the business logic backend.
func NewProfileStorage(c *ProfileStorageConfig) (s *ProfileStorage, err error) {
client, err := newClient(c.Endpoint)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
return &ProfileStorage{
client: client,
errColl: c.ErrColl,
}, nil
}
// type check
var _ profiledb.Storage = (*ProfileStorage)(nil)
// Profiles implements the [profiledb.Storage] interface for *ProfileStorage.
func (s *ProfileStorage) Profiles(
ctx context.Context,
req *profiledb.StorageRequest,
) (resp *profiledb.StorageResponse, err error) {
stream, err := s.client.GetDNSProfiles(ctx, toProtobuf(req))
if err != nil {
return nil, fmt.Errorf("loading profiles: %w", err)
}
defer func() { err = errors.WithDeferred(err, stream.CloseSend()) }()
resp = &profiledb.StorageResponse{
Profiles: []*agd.Profile{},
Devices: []*agd.Device{},
}
stats := &profilesCallStats{
isFullSync: req.SyncTime == time.Time{},
}
for {
stats.startRecv()
profile, profErr := stream.Recv()
if profErr != nil {
if errors.Is(profErr, io.EOF) {
break
}
return nil, fmt.Errorf("receiving profile: %w", profErr)
}
stats.endRecv()
stats.startDec()
prof, devices, profErr := profile.toInternal(ctx, time.Now(), s.errColl)
if profErr != nil {
reportf(ctx, s.errColl, "loading profile: %w", profErr)
continue
}
stats.endDec()
resp.Profiles = append(resp.Profiles, prof)
resp.Devices = append(resp.Devices, devices...)
}
stats.report()
trailer := stream.Trailer()
resp.SyncTime, err = syncTimeFromTrailer(trailer)
if err != nil {
return nil, fmt.Errorf("retrieving sync_time: %w", err)
}
return resp, nil
}
// toInternal converts the protobuf-encoded data into a profile structure.
func (x *DNSProfile) toInternal(
ctx context.Context,
updTime time.Time,
errColl agd.ErrorCollector,
) (profile *agd.Profile, devices []*agd.Device, err error) {
if x == nil {
return nil, nil, fmt.Errorf("profile is nil")
}
parental, err := x.Parental.toInternal(ctx, errColl)
if err != nil {
return nil, nil, fmt.Errorf("parental: %w", err)
}
m, err := blockingModeToInternal(x.BlockingMode)
if err != nil {
return nil, nil, fmt.Errorf("blocking mode: %w", err)
}
devices, deviceIds := devicesToInternal(ctx, x.Devices, errColl)
listsEnabled, listIDs := x.RuleLists.toInternal(ctx, errColl)
profID, err := agd.NewProfileID(x.DnsId)
if err != nil {
return nil, nil, fmt.Errorf("id: %w", err)
}
var fltRespTTL time.Duration
if respTTL := x.FilteredResponseTtl; respTTL != nil {
fltRespTTL = respTTL.AsDuration()
}
return &agd.Profile{
Parental: parental,
BlockingMode: m,
ID: profID,
UpdateTime: updTime,
DeviceIDs: deviceIds,
RuleListIDs: listIDs,
CustomRules: rulesToInternal(ctx, x.CustomRules, errColl),
FilteredResponseTTL: fltRespTTL,
FilteringEnabled: x.FilteringEnabled,
SafeBrowsing: x.SafeBrowsing.toInternal(),
RuleListsEnabled: listsEnabled,
QueryLogEnabled: x.QueryLogEnabled,
Deleted: x.Deleted,
BlockPrivateRelay: x.BlockPrivateRelay,
BlockFirefoxCanary: x.BlockFirefoxCanary,
}, devices, nil
}
// toInternal converts a protobuf parental-settings structure to an internal
// one. If x is nil, toInternal returns nil.
func (x *ParentalSettings) toInternal(
ctx context.Context,
errColl agd.ErrorCollector,
) (s *agd.ParentalProtectionSettings, err error) {
if x == nil {
return nil, nil
}
schedule, err := x.Schedule.toInternal()
if err != nil {
return nil, fmt.Errorf("schedule: %w", err)
}
return &agd.ParentalProtectionSettings{
Schedule: schedule,
BlockedServices: blockedSvcsToInternal(ctx, errColl, x.BlockedServices),
Enabled: x.Enabled,
BlockAdult: x.BlockAdult,
GeneralSafeSearch: x.GeneralSafeSearch,
YoutubeSafeSearch: x.YoutubeSafeSearch,
}, nil
}
// toInternal converts protobuf safe-browsing settings to an internal structure.
// If x is nil, toInternal returns nil.
func (x *SafeBrowsingSettings) toInternal() (sb *agd.SafeBrowsingSettings) {
if x == nil {
return nil
}
return &agd.SafeBrowsingSettings{
Enabled: x.Enabled,
BlockDangerousDomains: x.BlockDangerousDomains,
BlockNewlyRegisteredDomains: x.BlockNrd,
}
}
// blockedSvcsToInternal is a helper that converts the blocked service IDs from
// the backend response to AdGuard DNS blocked service IDs.
func blockedSvcsToInternal(
ctx context.Context,
errColl agd.ErrorCollector,
respSvcs []string,
) (svcs []agd.BlockedServiceID) {
l := len(respSvcs)
if l == 0 {
return nil
}
svcs = make([]agd.BlockedServiceID, 0, l)
for i, s := range respSvcs {
id, err := agd.NewBlockedServiceID(s)
if err != nil {
reportf(ctx, errColl, "blocked service at index %d: %w", i, err)
continue
}
svcs = append(svcs, id)
}
return svcs
}
// toInternal converts a protobuf protection-schedule structure to an internal
// one. If x is nil, toInternal returns nil.
func (x *ScheduleSettings) toInternal() (sch *agd.ParentalProtectionSchedule, err error) {
if x == nil {
return nil, nil
}
sch = &agd.ParentalProtectionSchedule{}
sch.TimeZone, err = agdtime.LoadLocation(x.Tmz)
if err != nil {
return nil, fmt.Errorf("loading timezone: %w", err)
}
sch.Week = &agd.WeeklySchedule{}
w := x.WeeklyRange
days := []*DayRange{w.Sun, w.Mon, w.Tue, w.Wed, w.Thu, w.Fri, w.Sat}
for i, d := range days {
if d == nil {
sch.Week[i] = agd.ZeroLengthDayRange()
continue
}
sch.Week[i] = agd.DayRange{
Start: uint16(d.Start.AsDuration().Minutes()),
End: uint16(d.End.AsDuration().Minutes()),
}
}
for i, r := range sch.Week {
err = r.Validate()
if err != nil {
return nil, fmt.Errorf("weekday %s: %w", time.Weekday(i), err)
}
}
return sch, nil
}
// blockingModeToInternal converts a protobuf blocking-mode sum-type to an
// internal one. If pbm is nil, blockingModeToInternal returns a null-IP
// blocking mode.
func blockingModeToInternal(pbm isDNSProfile_BlockingMode) (m dnsmsg.BlockingModeCodec, err error) {
switch pbm := pbm.(type) {
case nil:
m.Mode = &dnsmsg.BlockingModeNullIP{}
case *DNSProfile_BlockingModeCustomIp:
custom := &dnsmsg.BlockingModeCustomIP{}
err = custom.IPv4.UnmarshalBinary(pbm.BlockingModeCustomIp.Ipv4)
if err != nil {
return dnsmsg.BlockingModeCodec{}, fmt.Errorf("bad custom ipv4: %w", err)
}
err = custom.IPv6.UnmarshalBinary(pbm.BlockingModeCustomIp.Ipv6)
if err != nil {
return dnsmsg.BlockingModeCodec{}, fmt.Errorf("bad custom ipv6: %w", err)
}
m.Mode = custom
case *DNSProfile_BlockingModeNxdomain:
m.Mode = &dnsmsg.BlockingModeNXDOMAIN{}
case *DNSProfile_BlockingModeNullIp:
m.Mode = &dnsmsg.BlockingModeNullIP{}
case *DNSProfile_BlockingModeRefused:
m.Mode = &dnsmsg.BlockingModeREFUSED{}
default:
// Consider unhandled type-switch cases programmer errors.
panic(fmt.Errorf("bad pb blocking mode %T(%[1]v)", pbm))
}
return m, nil
}
// devicesToInternal is a helper that converts the devices from protobuf to
// AdGuard DNS devices.
func devicesToInternal(
ctx context.Context,
ds []*DeviceSettings,
errColl agd.ErrorCollector,
) (out []*agd.Device, ids []agd.DeviceID) {
l := len(ds)
if l == 0 {
return nil, nil
}
out = make([]*agd.Device, 0, l)
for _, d := range ds {
dev, err := d.toInternal()
if err != nil {
reportf(ctx, errColl, "invalid device settings: %w", err)
continue
}
ids = append(ids, dev.ID)
out = append(out, dev)
}
return out, ids
}
// toInternal is a helper that converts device settings from backend protobuf
// response to AdGuard DNS device object.
func (ds *DeviceSettings) toInternal() (dev *agd.Device, err error) {
if ds == nil {
return nil, fmt.Errorf("device is nil")
}
var linkedIP netip.Addr
err = linkedIP.UnmarshalBinary(ds.LinkedIp)
if err != nil {
return nil, fmt.Errorf("linked ip: %w", err)
}
var dedicatedIPs []netip.Addr
dedicatedIPs, err = agdprotobuf.ByteSlicesToIPs(ds.DedicatedIps)
if err != nil {
return nil, fmt.Errorf("dedicated ips: %w", err)
}
id, err := agd.NewDeviceID(ds.Id)
if err != nil {
return nil, fmt.Errorf("device id: %s: %w", ds.Id, err)
}
name, err := agd.NewDeviceName(ds.Name)
if err != nil {
return nil, fmt.Errorf("device name: %s: %w", ds.Name, err)
}
return &agd.Device{
ID: id,
Name: name,
LinkedIP: linkedIP,
DedicatedIPs: dedicatedIPs,
FilteringEnabled: ds.FilteringEnabled,
}, nil
}
// rulesToInternal is a helper that converts the filter rules from the backend
// response to AdGuard DNS filtering rules.
func rulesToInternal(
ctx context.Context,
respRules []string,
errColl agd.ErrorCollector,
) (rules []agd.FilterRuleText) {
l := len(respRules)
if l == 0 {
return nil
}
rules = make([]agd.FilterRuleText, 0, l)
for i, r := range respRules {
text, err := agd.NewFilterRuleText(r)
if err != nil {
reportf(ctx, errColl, "rule at index %d: %w", i, err)
continue
}
rules = append(rules, text)
}
return rules
}
// toInternal is a helper that converts the filter lists from the backend
// response to AdGuard DNS filter list ids. If x is nil, toInternal returns
// false and nil.
func (x *RuleListsSettings) toInternal(
ctx context.Context,
errColl agd.ErrorCollector,
) (enabled bool, filterLists []agd.FilterListID) {
if x == nil {
return false, nil
}
l := len(x.Ids)
if l == 0 {
return x.Enabled, nil
}
filterLists = make([]agd.FilterListID, 0, l)
for _, f := range x.Ids {
id, err := agd.NewFilterListID(f)
if err != nil {
reportf(ctx, errColl, "invalid filter id: %s: %w", f, err)
continue
}
filterLists = append(filterLists, id)
}
return x.Enabled, filterLists
}
// toProtobuf converts a storage request structure into the protobuf structure.
func toProtobuf(r *profiledb.StorageRequest) (req *DNSProfilesRequest) {
return &DNSProfilesRequest{
SyncTime: timestamppb.New(r.SyncTime),
}
}
// syncTimeFromTrailer returns sync time from trailer metadata. Trailer
// metadata must contain "sync_time" field with milliseconds presentation of
// sync time.
func syncTimeFromTrailer(trailer metadata.MD) (syncTime time.Time, err error) {
st := trailer.Get("sync_time")
if len(st) == 0 {
return syncTime, fmt.Errorf("empty value")
}
syncTimeMs, err := strconv.ParseInt(st[0], 10, 64)
if err != nil {
return syncTime, fmt.Errorf("invalid value: %w", err)
}
return time.Unix(0, syncTimeMs*time.Millisecond.Nanoseconds()), nil
}

View File

@ -0,0 +1,459 @@
package backendpb
import (
"context"
"net/netip"
"strconv"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtime"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/types/known/durationpb"
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// testProfileID is the common profile ID for tests.
const testProfileID agd.ProfileID = "prof1234"
// TestUpdTime is the common update time for tests.
var TestUpdTime = time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
func TestDNSProfile_ToInternal(t *testing.T) {
ctx := context.Background()
errColl := &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, err error) {
panic(err)
},
}
t.Run("success", func(t *testing.T) {
got, gotDevices, err := NewTestDNSProfile(t).toInternal(ctx, TestUpdTime, errColl)
require.NoError(t, err)
assert.Equal(t, newProfile(t), got)
assert.Equal(t, newDevices(t), gotDevices)
})
t.Run("success_bad_data", func(t *testing.T) {
var errCollErr error
savingErrColl := &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, err error) {
errCollErr = err
},
}
got, gotDevices, err := newDNSProfileWithBadData(t).toInternal(
ctx,
TestUpdTime,
savingErrColl,
)
require.NoError(t, err)
require.Error(t, errCollErr)
// See the TODO in [blockingModeToInternal].
wantProf := newProfile(t)
wantProf.BlockingMode = dnsmsg.BlockingModeCodec{
Mode: &dnsmsg.BlockingModeNullIP{},
}
assert.Equal(t, wantProf, got)
assert.Equal(t, newDevices(t), gotDevices)
})
t.Run("empty", func(t *testing.T) {
var emptyDNSProfile *DNSProfile
_, _, err := emptyDNSProfile.toInternal(ctx, TestUpdTime, errColl)
testutil.AssertErrorMsg(t, "profile is nil", err)
})
t.Run("deleted", func(t *testing.T) {
dp := &DNSProfile{
DnsId: string(testProfileID),
Deleted: true,
}
got, gotDevices, err := dp.toInternal(ctx, TestUpdTime, errColl)
require.NoError(t, err)
require.NotNil(t, got)
assert.Equal(t, got.ID, testProfileID)
assert.True(t, got.Deleted)
assert.Empty(t, gotDevices)
})
t.Run("inv_parental_sch_tmz", func(t *testing.T) {
dp := NewTestDNSProfile(t)
dp.Parental.Schedule.Tmz = "invalid"
_, _, err := dp.toInternal(ctx, TestUpdTime, errColl)
testutil.AssertErrorMsg(t, "parental: schedule: loading timezone: unknown time zone invalid", err)
})
t.Run("inv_parental_sch_day_range", func(t *testing.T) {
dp := NewTestDNSProfile(t)
dp.Parental.Schedule.WeeklyRange.Sun = &DayRange{
Start: durationpb.New(1000000000000),
End: nil,
}
_, _, err := dp.toInternal(ctx, TestUpdTime, errColl)
testutil.AssertErrorMsg(t, "parental: schedule: weekday Sunday: bad day range: end 0 less than start 16", err)
})
t.Run("inv_blocking_mode_v4", func(t *testing.T) {
dp := NewTestDNSProfile(t)
bm := dp.BlockingMode.(*DNSProfile_BlockingModeCustomIp)
bm.BlockingModeCustomIp.Ipv4 = []byte("1")
_, _, err := dp.toInternal(ctx, TestUpdTime, errColl)
testutil.AssertErrorMsg(t, "blocking mode: bad custom ipv4: unexpected slice size", err)
})
t.Run("inv_blocking_mode_v6", func(t *testing.T) {
dp := NewTestDNSProfile(t)
bm := dp.BlockingMode.(*DNSProfile_BlockingModeCustomIp)
bm.BlockingModeCustomIp.Ipv6 = []byte("1")
_, _, err := dp.toInternal(ctx, TestUpdTime, errColl)
testutil.AssertErrorMsg(t, "blocking mode: bad custom ipv6: unexpected slice size", err)
})
}
// newDNSProfileWithBadData returns a new instance of *DNSProfile with bad data
// for tests.
func newDNSProfileWithBadData(tb testing.TB) (dp *DNSProfile) {
tb.Helper()
dayRange := &DayRange{
Start: durationpb.New(0),
End: durationpb.New(59 * time.Minute),
}
devices := []*DeviceSettings{{
Id: "118ffe93",
Name: "118ffe93-name",
FilteringEnabled: false,
LinkedIp: ipToBytes(tb, netip.MustParseAddr("1.1.1.1")),
DedicatedIps: [][]byte{ipToBytes(tb, netip.MustParseAddr("1.1.1.2"))},
}, {
Id: "b9e1a762",
Name: "b9e1a762-name",
FilteringEnabled: true,
LinkedIp: ipToBytes(tb, netip.MustParseAddr("2.2.2.2")),
DedicatedIps: nil,
}, {
Id: "invalid-too-long-device-id",
Name: "device_name",
FilteringEnabled: true,
LinkedIp: ipToBytes(tb, netip.MustParseAddr("1.1.1.1")),
DedicatedIps: nil,
}, {
Id: "dev-name",
Name: "invalid-too-long-device-name-invalid-too-long-device-name-" +
"invalid-too-long-device-name-invalid-too-long-device-name-" +
"invalid-too-long-device-name-invalid-too-long-device-name",
FilteringEnabled: true,
LinkedIp: ipToBytes(tb, netip.MustParseAddr("1.1.1.1")),
DedicatedIps: nil,
}, {
Id: "inv-ip",
Name: "test-name",
FilteringEnabled: true,
LinkedIp: []byte("1"),
DedicatedIps: nil,
}, {
Id: "inv-d-ip",
Name: "test-name",
FilteringEnabled: true,
LinkedIp: ipToBytes(tb, netip.MustParseAddr("1.1.1.1")),
DedicatedIps: [][]byte{[]byte("1")},
}}
return &DNSProfile{
DnsId: string(testProfileID),
FilteringEnabled: true,
QueryLogEnabled: true,
Deleted: false,
SafeBrowsing: &SafeBrowsingSettings{
Enabled: true,
BlockDangerousDomains: true,
BlockNrd: false,
},
Parental: &ParentalSettings{
Enabled: false,
BlockAdult: false,
GeneralSafeSearch: false,
YoutubeSafeSearch: false,
BlockedServices: []string{"youtube", "inv_blocked_svc\r"},
Schedule: &ScheduleSettings{
Tmz: "GMT",
WeeklyRange: &WeeklyRange{
Sun: nil,
Mon: dayRange,
Tue: dayRange,
Wed: dayRange,
Thu: dayRange,
Fri: dayRange,
Sat: nil,
},
},
},
RuleLists: &RuleListsSettings{
Enabled: true,
Ids: []string{"1", "inv_filter_id\r"},
},
Devices: devices,
CustomRules: []string{"||example.org^"},
FilteredResponseTtl: durationpb.New(10 * time.Second),
BlockPrivateRelay: true,
BlockFirefoxCanary: true,
}
}
// NewTestDNSProfile returns a new instance of *DNSProfile for tests.
func NewTestDNSProfile(tb testing.TB) (dp *DNSProfile) {
tb.Helper()
dayRange := &DayRange{
Start: durationpb.New(0),
End: durationpb.New(59 * time.Minute),
}
devices := []*DeviceSettings{{
Id: "118ffe93",
Name: "118ffe93-name",
FilteringEnabled: false,
LinkedIp: ipToBytes(tb, netip.MustParseAddr("1.1.1.1")),
DedicatedIps: [][]byte{ipToBytes(tb, netip.MustParseAddr("1.1.1.2"))},
}, {
Id: "b9e1a762",
Name: "b9e1a762-name",
FilteringEnabled: true,
LinkedIp: ipToBytes(tb, netip.MustParseAddr("2.2.2.2")),
DedicatedIps: nil,
}}
return &DNSProfile{
DnsId: string(testProfileID),
FilteringEnabled: true,
QueryLogEnabled: true,
Deleted: false,
SafeBrowsing: &SafeBrowsingSettings{
Enabled: true,
BlockDangerousDomains: true,
BlockNrd: false,
},
Parental: &ParentalSettings{
Enabled: false,
BlockAdult: false,
GeneralSafeSearch: false,
YoutubeSafeSearch: false,
BlockedServices: []string{"youtube"},
Schedule: &ScheduleSettings{
Tmz: "GMT",
WeeklyRange: &WeeklyRange{
Sun: nil,
Mon: dayRange,
Tue: dayRange,
Wed: dayRange,
Thu: dayRange,
Fri: dayRange,
Sat: nil,
},
},
},
RuleLists: &RuleListsSettings{
Enabled: true,
Ids: []string{"1"},
},
Devices: devices,
CustomRules: []string{"||example.org^"},
FilteredResponseTtl: durationpb.New(10 * time.Second),
BlockPrivateRelay: true,
BlockFirefoxCanary: true,
BlockingMode: &DNSProfile_BlockingModeCustomIp{
BlockingModeCustomIp: &BlockingModeCustomIP{
Ipv4: ipToBytes(tb, netip.MustParseAddr("1.2.3.4")),
Ipv6: ipToBytes(tb, netip.MustParseAddr("1234::cdef")),
},
},
}
}
// newProfile returns a new profile for tests.
func newProfile(tb testing.TB) (p *agd.Profile) {
tb.Helper()
wantLoc, err := agdtime.LoadLocation("GMT")
require.NoError(tb, err)
dayRange := agd.DayRange{
Start: 0,
End: 59,
}
wantParental := &agd.ParentalProtectionSettings{
Schedule: &agd.ParentalProtectionSchedule{
Week: &agd.WeeklySchedule{
agd.ZeroLengthDayRange(),
dayRange,
dayRange,
dayRange,
dayRange,
dayRange,
agd.ZeroLengthDayRange(),
},
TimeZone: wantLoc,
},
BlockedServices: []agd.BlockedServiceID{"youtube"},
Enabled: false,
BlockAdult: false,
GeneralSafeSearch: false,
YoutubeSafeSearch: false,
}
wantSafeBrowsing := &agd.SafeBrowsingSettings{
Enabled: true,
BlockDangerousDomains: true,
BlockNewlyRegisteredDomains: false,
}
wantBlockingMode := dnsmsg.BlockingModeCodec{
Mode: &dnsmsg.BlockingModeCustomIP{
IPv4: netip.MustParseAddr("1.2.3.4"),
IPv6: netip.MustParseAddr("1234::cdef"),
},
}
return &agd.Profile{
Parental: wantParental,
BlockingMode: wantBlockingMode,
ID: testProfileID,
UpdateTime: TestUpdTime,
DeviceIDs: []agd.DeviceID{
"118ffe93",
"b9e1a762",
},
RuleListIDs: []agd.FilterListID{"1"},
CustomRules: []agd.FilterRuleText{"||example.org^"},
FilteredResponseTTL: 10 * time.Second,
SafeBrowsing: wantSafeBrowsing,
RuleListsEnabled: true,
FilteringEnabled: true,
QueryLogEnabled: true,
Deleted: false,
BlockPrivateRelay: true,
BlockFirefoxCanary: true,
}
}
// newDevices returns a slice of test devices.
func newDevices(t *testing.T) (d []*agd.Device) {
t.Helper()
return []*agd.Device{{
ID: "118ffe93",
LinkedIP: netip.MustParseAddr("1.1.1.1"),
Name: "118ffe93-name",
DedicatedIPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
FilteringEnabled: false,
}, {
ID: "b9e1a762",
LinkedIP: netip.MustParseAddr("2.2.2.2"),
Name: "b9e1a762-name",
DedicatedIPs: nil,
FilteringEnabled: true,
}}
}
// ipToBytes is a wrapper around netip.Addr.MarshalBinary.
func ipToBytes(tb testing.TB, ip netip.Addr) (b []byte) {
tb.Helper()
b, err := ip.MarshalBinary()
require.NoError(tb, err)
return b
}
func TestSyncTimeFromTrailer(t *testing.T) {
milliseconds := strconv.FormatInt(TestUpdTime.UnixMilli(), 10)
testCases := []struct {
wantError string
want time.Time
name string
in metadata.MD
}{{
wantError: "empty value",
want: time.Time{},
name: "no_key",
in: metadata.MD{},
}, {
wantError: "empty value",
want: time.Time{},
name: "empty_key",
in: metadata.MD{"sync_time": []string{}},
}, {
wantError: `invalid value: strconv.ParseInt: parsing "": invalid syntax`,
want: time.Time{},
name: "empty_value",
in: metadata.MD{"sync_time": []string{""}},
}, {
wantError: "",
want: TestUpdTime,
name: "success",
in: metadata.MD{"sync_time": []string{milliseconds}},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
syncTime, err := syncTimeFromTrailer(tc.in)
testutil.AssertErrorMsg(t, tc.wantError, err)
assert.True(t, tc.want.Equal(syncTime), "want %s; got %s", tc.want, syncTime)
})
}
}
var (
errSink error
profSink *agd.Profile
)
func BenchmarkDNSProfile_ToInternal(b *testing.B) {
dp := NewTestDNSProfile(b)
ctx := context.Background()
errColl := &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, err error) {
panic(err)
},
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
profSink, _, errSink = dp.toInternal(ctx, TestUpdTime, errColl)
}
require.NotNil(b, profSink)
require.NoError(b, errSink)
// Most recent result, on a ThinkPad X13:
// goos: linux
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/backendpb
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
// BenchmarkDNSProfile_ToInternal
// BenchmarkDNSProfile_ToInternal-16 157513 10340 ns/op 1148 B/op 27 allocs/op
}

View File

@ -0,0 +1,96 @@
package backendpb_test
import (
context "context"
"net"
"net/url"
"strconv"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
)
var (
errSink error
respSink *profiledb.StorageResponse
)
func BenchmarkProfileStorage_Profiles(b *testing.B) {
syncTime := strconv.FormatInt(backendpb.TestUpdTime.UnixMilli(), 10)
srvProf := backendpb.NewTestDNSProfile(b)
trailerMD := metadata.MD{
"sync_time": []string{syncTime},
}
srv := &testDNSServiceServer{
OnGetDNSProfiles: func(
req *backendpb.DNSProfilesRequest,
srv backendpb.DNSService_GetDNSProfilesServer,
) (err error) {
sendErr := srv.Send(srvProf)
srv.SetTrailer(trailerMD)
return sendErr
},
}
errColl := &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, err error) {
panic(err)
},
}
l, err := net.Listen("tcp", "localhost:0")
require.NoError(b, err)
s, err := backendpb.NewProfileStorage(&backendpb.ProfileStorageConfig{
ErrColl: errColl,
Endpoint: &url.URL{
Scheme: "grpc",
Host: l.Addr().String(),
},
})
require.NoError(b, err)
grpcSrv := grpc.NewServer(
grpc.ConnectionTimeout(1*time.Second),
grpc.Creds(insecure.NewCredentials()),
)
backendpb.RegisterDNSServiceServer(grpcSrv, srv)
go func() {
pt := &testutil.PanicT{}
srvErr := grpcSrv.Serve(l)
require.NoError(pt, srvErr)
}()
b.Cleanup(grpcSrv.GracefulStop)
ctx := context.Background()
req := &profiledb.StorageRequest{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
respSink, errSink = s.Profiles(ctx, req)
}
require.NoError(b, errSink)
require.NotNil(b, respSink)
// Most recent result, on a ThinkPad X13:
// goos: linux
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/backendpb
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
// BenchmarkProfileStorage_Profiles
// BenchmarkProfileStorage_Profiles-16 5347 245341 ns/op 15129 B/op 265 allocs/op
}

View File

@ -0,0 +1,82 @@
package backendpb
import (
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/golibs/log"
)
// profilesCallStats is a stateful structure that collects and reports
// statistics about a [ProfileStorage.Profiles] call.
type profilesCallStats struct {
recvStart time.Time
decStart time.Time
initRecv time.Duration
totalRecv time.Duration
totalDec time.Duration
numRecv int
isFullSync bool
}
// startRecv starts the receive timer.
func (s *profilesCallStats) startRecv() {
s.recvStart = time.Now()
}
// endRecv ends the receive timer and records the results.
func (s *profilesCallStats) endRecv() {
d := time.Since(s.recvStart)
if s.numRecv == 0 {
// Count the initial receive separately, since it is often not
// representative of an average receive, because this is when gRPC
// actually performs the call.
s.initRecv = d
} else {
s.totalRecv += d
}
s.numRecv++
}
// startDec starts the decoding timer.
func (s *profilesCallStats) startDec() {
s.decStart = time.Now()
}
// endDec ends the decoding timer and records the results.
func (s *profilesCallStats) endDec() {
s.totalDec += time.Since(s.decStart)
}
// report writes the statistics to the log and the metrics.
func (s *profilesCallStats) report() {
logFunc := log.Debug
if s.isFullSync {
logFunc = log.Info
}
if s.numRecv == 0 {
logFunc("backendpb: no recv")
return
}
n := time.Duration(s.numRecv)
avgRecv := s.totalRecv / n
avgDec := s.totalDec / n
logFunc(
"backendpb: total recv: %s; agv recv: %s; init recv: %s",
s.totalRecv,
avgRecv,
s.initRecv,
)
logFunc("backendpb: total dec: %s; agv dec: %s", s.totalDec, avgDec)
metrics.GRPCAvgProfileRecvDuration.Observe(avgRecv.Seconds())
metrics.GRPCAvgProfileDecDuration.Observe(avgDec.Seconds())
}

View File

@ -86,10 +86,17 @@ var _ agd.Refresher = (*RuntimeRecorder)(nil)
// uploads the currently available data and resets it.
func (r *RuntimeRecorder) Refresh(ctx context.Context) (err error) {
records := r.resetRecords()
startTime := time.Now()
defer func() {
dur := time.Since(startTime).Seconds()
metrics.BillStatUploadDuration.Observe(dur)
if err != nil {
r.remergeRecords(records)
log.Info("billstat: refresh failed, records remerged")
} else {
metrics.BillStatUploadTimestamp.SetToCurrentTime()
}
metrics.SetStatusGauge(metrics.BillStatUploadStatus, err)
@ -100,8 +107,6 @@ func (r *RuntimeRecorder) Refresh(ctx context.Context) (err error) {
return fmt.Errorf("uploading billstat records: %w", err)
}
metrics.BillStatUploadTimestamp.SetToCurrentTime()
return nil
}

View File

@ -6,17 +6,21 @@ import (
"net"
"net/netip"
"sync"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/prometheus/client_golang/prometheus"
)
// chanListener is a [net.Listener] that returns data sent to it through a
// channel.
//
// Listeners of this type are returned by [chanListenConfig.Listen] and are used
// in module dnsserver to make the bind-to-device logic work in DNS-over-TCP.
// Listeners of this type are returned by [ListenConfig.Listen] and are used in
// module dnsserver to make the bind-to-device logic work in DNS-over-TCP.
type chanListener struct {
// mu protects conns (against closure) and isClosed.
mu *sync.Mutex
conns chan net.Conn
connsGauge prometheus.Gauge
laddr net.Addr
subnet netip.Prefix
isClosed bool
@ -27,6 +31,7 @@ func newChanListener(conns chan net.Conn, subnet netip.Prefix, laddr net.Addr) (
return &chanListener{
mu: &sync.Mutex{},
conns: conns,
connsGauge: metrics.BindToDeviceTCPConnsChanSize.WithLabelValues(subnet.String()),
laddr: laddr,
subnet: subnet,
isClosed: false,
@ -77,5 +82,7 @@ func (l *chanListener) send(conn net.Conn) (ok bool) {
l.conns <- conn
l.connsGauge.Set(float64(len(l.conns)))
return true
}

View File

@ -11,13 +11,15 @@ import (
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/prometheus/client_golang/prometheus"
)
// chanPacketConn is a [netext.SessionPacketConn] that returns data sent to it
// through the channel.
//
// Connections of this type are returned by [chanListenConfig.ListenPacket] and
// are used in module dnsserver to make the bind-to-device logic work in
// Connections of this type are returned by [ListenConfig.ListenPacket] and are
// used in module dnsserver to make the bind-to-device logic work in
// DNS-over-UDP.
type chanPacketConn struct {
// mu protects sessions (against closure) and isClosed.
@ -26,6 +28,9 @@ type chanPacketConn struct {
writeRequests chan *packetConnWriteReq
sessionsGauge prometheus.Gauge
writeRequestsGauge prometheus.Gauge
// deadlineMu protects readDeadline and writeDeadline.
deadlineMu *sync.RWMutex
readDeadline time.Time
@ -48,6 +53,9 @@ func newChanPacketConn(
sessions: sessions,
writeRequests: writeRequests,
sessionsGauge: metrics.BindToDeviceUDPSessionsChanSize.WithLabelValues(subnet.String()),
writeRequestsGauge: metrics.BindToDeviceUDPWriteRequestsChanSize.WithLabelValues(subnet.String()),
deadlineMu: &sync.RWMutex{},
laddr: laddr,
@ -257,6 +265,8 @@ func (c *chanPacketConn) writeToSession(
return 0, wrapConnError(tnChanPConn, fnName, c.laddr, err)
}
c.writeRequestsGauge.Set(float64(len(c.writeRequests)))
r, err := receiveWithTimer(resp, timerCh)
if err != nil {
err = fmt.Errorf("receiving write response: %w", err)
@ -309,5 +319,7 @@ func (c *chanPacketConn) send(sess *packetSession) (ok bool) {
c.sessions <- sess
c.sessionsGauge.Set(float64(len(c.sessions)))
return true
}

View File

@ -20,27 +20,23 @@ type connIndex struct {
listeners []*chanListener
}
// subnetSortsBefore returns true if subnet x sorts before subnet y.
func subnetSortsBefore(x, y netip.Prefix) (isBefore bool) {
xAddr, xBits := x.Addr(), x.Bits()
yAddr, yBits := y.Addr(), y.Bits()
if xBits == yBits {
return xAddr.Less(yAddr)
}
return xBits > yBits
}
// subnetCompare is a comparison function for the two subnets. It returns -1 if
// x sorts before y, 1 if x sorts after y, and 0 if their relative sorting
// position is the same.
func subnetCompare(x, y netip.Prefix) (cmp int) {
switch {
case x == y:
if x == y {
return 0
case subnetSortsBefore(x, y):
}
xAddr, xBits := x.Addr(), x.Bits()
yAddr, yBits := y.Addr(), y.Bits()
if xBits == yBits {
return xAddr.Compare(yAddr)
}
if xBits > yBits {
return -1
default:
} else {
return 1
}
}

View File

@ -10,7 +10,7 @@ import (
"golang.org/x/exp/slices"
)
func TestSubnetSortsBefore(t *testing.T) {
func TestSubnetCompare(t *testing.T) {
want := []netip.Prefix{
netip.MustParsePrefix("1.0.0.0/24"),
netip.MustParsePrefix("1.2.3.0/24"),
@ -24,6 +24,6 @@ func TestSubnetSortsBefore(t *testing.T) {
netip.MustParsePrefix("1.2.3.0/24"),
}
slices.SortFunc(got, subnetSortsBefore)
slices.SortFunc(got, subnetCompare)
assert.Equalf(t, want, got, "got (as strings): %q", got)
}

View File

@ -9,18 +9,22 @@ import (
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdsync"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
// interfaceListener contains information about a single interface listener.
type interfaceListener struct {
conns *connIndex
listenConf *net.ListenConfig
bodyPool *agdsync.TypedPool[[]byte]
oobPool *agdsync.TypedPool[[]byte]
writeRequests chan *packetConnWriteReq
done chan unit
listenConf *net.ListenConfig
errColl agd.ErrorCollector
ifaceName string
port uint16
@ -63,17 +67,30 @@ func (l *interfaceListener) listenTCP(errCh chan<- error) {
continue
}
l.processConn(conn, logPrefix)
}
}
// processConn processes a single connection. If the connection doesn't have a
// connected channel-listener, it is closed.
func (l *interfaceListener) processConn(conn net.Conn, logPrefix string) {
laddr := netutil.NetAddrToAddrPort(conn.LocalAddr())
lsnr := l.conns.listener(laddr.Addr())
if lsnr == nil {
log.Info("%s: no channel for laddr %s", logPrefix, laddr)
continue
}
raddr := conn.RemoteAddr()
if lsnr := l.conns.listener(laddr.Addr()); lsnr != nil {
if !lsnr.send(conn) {
log.Info("%s: channel for laddr %s is closed", logPrefix, laddr)
log.Info("%s: from raddr %s: channel for laddr %s is closed", logPrefix, raddr, laddr)
}
return
}
metrics.BindToDeviceUnknownTCPRequestsTotal.Inc()
optlog.Debug3("%s: from raddr %s: no stream channel for laddr %s", logPrefix, raddr, laddr)
err := conn.Close()
if err != nil {
log.Debug("%s: from raddr %s: closing: %s", logPrefix, raddr, err)
}
}
@ -112,27 +129,60 @@ func (l *interfaceListener) listenUDP(errCh chan<- error) {
// Go on.
}
// TODO(a.garipov): Consider customization of body sizes.
var sess *packetSession
sess, err = readPacketSession(udpConn, dns.DefaultMsgSize)
err = l.readUDP(udpConn, logPrefix)
if err != nil {
agd.Collectf(ctx, l.errColl, "%s: reading session: %w", logPrefix, err)
}
}
}
continue
// readUDP reads a UDP session from c and sends it to the appropriate channel.
func (l *interfaceListener) readUDP(c *net.UDPConn, logPrefix string) (err error) {
bodyPtr := l.bodyPool.Get()
body := *bodyPtr
// Extend body to the capacity in case it had already been used and sliced
// by [readPacketSession].
body = body[:cap(body)]
oobPtr := l.oobPool.Get()
oob := *oobPtr
defer func() {
l.oobPool.Put(oobPtr)
// Only return the body to the pool in case of error here. The actual
// return is done in writeUDP.
if err != nil {
l.bodyPool.Put(bodyPtr)
}
}()
sess, err := readPacketSession(c, body, oob)
if err != nil {
return fmt.Errorf("reading session: %w", err)
}
laddr := sess.laddr.AddrPort().Addr()
chanPConn := l.conns.packetConn(laddr)
if chanPConn == nil {
log.Info("%s: no channel for laddr %s", logPrefix, laddr)
chanPacketConn := l.conns.packetConn(laddr)
if chanPacketConn == nil {
metrics.BindToDeviceUnknownUDPRequestsTotal.Inc()
continue
optlog.Debug3(
"%s: from raddr %s: no packet channel for laddr %s",
logPrefix,
sess.raddr,
laddr,
)
return nil
}
if !chanPConn.send(sess) {
if !chanPacketConn.send(sess) {
log.Info("%s: channel for laddr %s is closed", logPrefix, laddr)
}
}
return nil
}
// writeUDP runs the UDP write loop. It is intended to be used as a goroutine.
@ -167,6 +217,8 @@ func (l *interfaceListener) writeUDP(c *net.UDPConn) {
s.respOOB,
req.session.raddr,
)
l.bodyPool.Put(&s.readBody)
}
resetDeadlineErr := c.SetWriteDeadline(time.Time{})

View File

@ -18,7 +18,7 @@ type NetInterface interface {
// type check
var _ NetInterface = osInterface{}
// osInterface is a wapper around [*net.Interface] that implements the
// osInterface is a wrapper around [*net.Interface] that implements the
// [NetInterface] interface.
type osInterface struct {
iface *net.Interface

View File

@ -9,23 +9,24 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
)
// chanListenConfig is a [netext.ListenConfig] implementation that uses the
// ListenConfig is a [netext.ListenConfig] implementation that uses the
// provided channel-based packet connection and listener to implement the
// methods of the interface.
//
// netext.ListenConfig instances of this type are the ones that are going to be
// set as [dnsserver.ConfigBase.ListenConfig] to make the bind-to-device logic
// work.
type chanListenConfig struct {
type ListenConfig struct {
packetConn *chanPacketConn
listener *chanListener
addr string
}
// type check
var _ netext.ListenConfig = (*chanListenConfig)(nil)
var _ netext.ListenConfig = (*ListenConfig)(nil)
// Listen implements the [netext.ListenConfig] interface for *chanListenConfig.
func (lc *chanListenConfig) Listen(
// Listen implements the [netext.ListenConfig] interface for *ListenConfig.
func (lc *ListenConfig) Listen(
ctx context.Context,
network string,
address string,
@ -34,11 +35,17 @@ func (lc *chanListenConfig) Listen(
}
// ListenPacket implements the [netext.ListenConfig] interface for
// *chanListenConfig.
func (lc *chanListenConfig) ListenPacket(
// *ListenConfig.
func (lc *ListenConfig) ListenPacket(
ctx context.Context,
network string,
address string,
) (c net.PacketConn, err error) {
return lc.packetConn, nil
}
// Addr returns the address on which lc accepts connections. See
// [agdnet.FormatPrefixAddr] for the format.
func (lc *ListenConfig) Addr() (addr string) {
return lc.addr
}

View File

@ -6,16 +6,19 @@ import (
"context"
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestChanListenConfig(t *testing.T) {
func TestListenConfig(t *testing.T) {
pc := newChanPacketConn(nil, testSubnetIPv4, nil, testLAddr)
lsnr := newChanListener(nil, testSubnetIPv4, testLAddr)
c := chanListenConfig{
addr := agdnet.FormatPrefixAddr(testSubnetIPv4, 1234)
c := &ListenConfig{
packetConn: pc,
listener: lsnr,
addr: addr,
}
ctx := context.Background()
@ -29,4 +32,7 @@ func TestChanListenConfig(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, lsnr, gotLsnr)
gotAddr := c.Addr()
assert.Equal(t, addr, gotAddr)
}

View File

@ -0,0 +1,55 @@
//go:build !linux
package bindtodevice
import (
"context"
"net"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
)
// ListenConfig is a [netext.ListenConfig] implementation that uses the
// provided channel-based packet connection and listener to implement the
// methods of the interface.
//
// netext.ListenConfig instances of this type are the ones that are going to be
// set as [dnsserver.ConfigBase.ListenConfig] to make the bind-to-device logic
// work.
//
// It is only supported on Linux.
type ListenConfig struct{}
// type check
var _ netext.ListenConfig = (*ListenConfig)(nil)
// Listen implements the [netext.ListenConfig] interface for *ListenConfig.
//
// It is only supported on Linux.
func (lc *ListenConfig) Listen(
ctx context.Context,
network string,
address string,
) (l net.Listener, err error) {
return nil, errUnsupported
}
// ListenPacket implements the [netext.ListenConfig] interface for
// *ListenConfig.
//
// It is only supported on Linux.
func (lc *ListenConfig) ListenPacket(
ctx context.Context,
network string,
address string,
) (c net.PacketConn, err error) {
return nil, errUnsupported
}
// Addr returns the address on which lc accepts connections. See
// [agdnet.FormatPrefixAddr] for the format.
//
// It is only supported on Linux.
func (lc *ListenConfig) Addr() (addr string) {
return ""
}

View File

@ -10,10 +10,13 @@ import (
"sync"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
"github.com/AdguardTeam/AdGuardDNS/internal/agdsync"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/mapsutil"
"github.com/miekg/dns"
)
// Manager creates individual listeners and dispatches connections to them.
@ -49,7 +52,7 @@ var defaultCtrlConf = &ControlConfig{
// configuration is used.
//
// Add must not be called after Start is called.
func (m *Manager) Add(id ID, ifaceName string, port uint16, conf *ControlConfig) (err error) {
func (m *Manager) Add(id ID, ifaceName string, port uint16, ctrlConf *ControlConfig) (err error) {
defer func() { err = errors.Annotate(err, "adding interface listener with id %q: %w", id) }()
_, err = m.interfaces.InterfaceByName(ifaceName)
@ -86,29 +89,51 @@ func (m *Manager) Add(id ID, ifaceName string, port uint16, conf *ControlConfig)
return err
}
if conf == nil {
conf = defaultCtrlConf
if ctrlConf == nil {
ctrlConf = defaultCtrlConf
}
m.ifaceListeners[id] = &interfaceListener{
conns: &connIndex{},
writeRequests: make(chan *packetConnWriteReq, m.chanBufSize),
done: m.done,
listenConf: newListenConfig(ifaceName, conf),
errColl: m.errColl,
ifaceName: ifaceName,
port: port,
}
// TODO(a.garipov): Consider customization of body sizes.
m.ifaceListeners[id] = m.newInterfaceListener(ctrlConf, ifaceName, dns.DefaultMsgSize, port)
return nil
}
// ListenConfig returns a new netext.ListenConfig that receives connections from
// the interface listener with the given id and the destination addresses of
// which fall within subnet. subnet should be masked.
// newInterfaceListener returns a new properly initialized *interfaceListener
// for this manager.
func (m *Manager) newInterfaceListener(
ctrlConf *ControlConfig,
ifaceName string,
bodySize int,
port uint16,
) (l *interfaceListener) {
return &interfaceListener{
conns: &connIndex{},
listenConf: newListenConfig(ifaceName, ctrlConf),
bodyPool: agdsync.NewTypedPool(func() (v *[]byte) {
b := make([]byte, bodySize)
return &b
}),
oobPool: agdsync.NewTypedPool(func() (v *[]byte) {
b := make([]byte, netext.IPDstOOBSize)
return &b
}),
writeRequests: make(chan *packetConnWriteReq, m.chanBufSize),
done: m.done,
errColl: m.errColl,
ifaceName: ifaceName,
port: port,
}
}
// ListenConfig returns a new *ListenConfig that receives connections from the
// interface listener with the given id and the destination addresses of which
// fall within subnet. subnet should be masked.
//
// ListenConfig must not be called after Start is called.
func (m *Manager) ListenConfig(id ID, subnet netip.Prefix) (c netext.ListenConfig, err error) {
func (m *Manager) ListenConfig(id ID, subnet netip.Prefix) (c *ListenConfig, err error) {
defer func() {
err = errors.Annotate(
err,
@ -154,9 +179,10 @@ func (m *Manager) ListenConfig(id ID, subnet netip.Prefix) (c netext.ListenConfi
return nil, fmt.Errorf("adding udp conn: %w", err)
}
return &chanListenConfig{
return &ListenConfig{
packetConn: pConn,
listener: lsnr,
addr: agdnet.FormatPrefixAddr(subnet, l.port),
}, nil
}

View File

@ -221,6 +221,8 @@ func TestManager(t *testing.T) {
err := m.Add(testID1, ifaceName, testPort1, nil)
require.NoError(t, err)
// TODO(a.garipov): Add tests for addresses within ifaceNet but outside of a
// narrower subnet.
subnet, err := netutil.IPNetToPrefixNoMapped(&net.IPNet{
IP: ifaceNet.IP.Mask(ifaceNet.Mask),
Mask: ifaceNet.Mask,

View File

@ -7,7 +7,6 @@ import (
"net/netip"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors"
)
@ -25,6 +24,8 @@ func NewManager(c *ManagerConfig) (m *Manager) {
// errUnsupported is returned from all [Manager] methods on OSs other than
// Linux.
//
// TODO(a.garipov): Consider using [errors.ErrUnsupported] in Go 1.21.
const errUnsupported errors.Error = "bindtodevice is only supported on linux"
// Add creates a new interface-listener record in m.
@ -34,12 +35,12 @@ func (m *Manager) Add(id ID, ifaceName string, port uint16, cc *ControlConfig) (
return errUnsupported
}
// ListenConfig returns a new netext.ListenConfig that receives connections from
// the interface listener with the given id and the destination addresses of
// which fall within subnet. subnet should be masked.
// ListenConfig returns a new *ListenConfig that receives connections from the
// interface listener with the given id and the destination addresses of which
// fall within subnet. subnet should be masked.
//
// It is only supported on Linux.
func (m *Manager) ListenConfig(id ID, subnet netip.Prefix) (c netext.ListenConfig, err error) {
func (m *Manager) ListenConfig(id ID, subnet netip.Prefix) (c *ListenConfig, err error) {
return nil, errUnsupported
}

View File

@ -3,9 +3,10 @@
package bindtodevice
import (
"fmt"
"net"
"net/netip"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
)
// prefixNetAddr is a wrapper around netip.Prefix that makes it a [net.Addr].
@ -21,16 +22,11 @@ type prefixNetAddr struct {
// type check
var _ net.Addr = (*prefixNetAddr)(nil)
// String implements the [net.Addr] interface for *prefixNetAddr. It returns an
// address of the form "1.2.3.0:56789/24". That is, IP:port with a subnet after
// a slash. This is done to make using the IP:port part easier to split off
// using something like [strings.Cut].
// String implements the [net.Addr] interface for *prefixNetAddr.
//
// See [agdnet.FormatPrefixAddr] for the format.
func (addr *prefixNetAddr) String() (n string) {
return fmt.Sprintf(
"%s/%d",
netip.AddrPortFrom(addr.prefix.Addr(), addr.port),
addr.prefix.Bits(),
)
return agdnet.FormatPrefixAddr(addr.prefix, addr.port)
}
// Network implements the [net.Addr] interface for *prefixNetAddr.

View File

@ -16,6 +16,8 @@ func TestPrefixAddr(t *testing.T) {
network = "tcp"
)
fullPrefix := netip.MustParsePrefix("1.2.3.4/32")
testCases := []struct {
in *prefixNetAddr
want string
@ -42,6 +44,14 @@ func TestPrefixAddr(t *testing.T) {
netip.AddrPortFrom(testSubnetIPv6.Addr(), port), testSubnetIPv6.Bits(),
),
name: "ipv6",
}, {
in: &prefixNetAddr{
prefix: fullPrefix,
network: network,
port: port,
},
want: netip.AddrPortFrom(fullPrefix.Addr(), port).String(),
name: "ipv4_full",
}}
for _, tc := range testCases {

View File

@ -7,7 +7,6 @@ import (
"net"
"syscall"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/sys/unix"
)
@ -106,14 +105,16 @@ func listenControlWithSO(
return errors.WithDeferred(opErr, err)
}
// msgUDPReader is an interface for types of connections that can read UDP
// messages. See [*net.UDPConn].
type msgUDPReader interface {
ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error)
}
// readPacketSession is a helper that reads a packet-session data from a UDP
// connection.
func readPacketSession(c *net.UDPConn, bodySize int) (sess *packetSession, err error) {
// TODO(a.garipov): Consider adding pooling.
b := make([]byte, bodySize)
oob := make([]byte, netext.IPDstOOBSize)
n, oobn, _, raddr, err := c.ReadMsgUDP(b, oob)
func readPacketSession(c msgUDPReader, body, oob []byte) (sess *packetSession, err error) {
n, oobn, _, raddr, err := c.ReadMsgUDP(body, oob)
if err != nil {
return nil, fmt.Errorf("reading: %w", err)
}
@ -142,7 +143,7 @@ func readPacketSession(c *net.UDPConn, bodySize int) (sess *packetSession, err e
sess = &packetSession{
laddr: origDstAddr,
raddr: raddr,
readBody: b[:n],
readBody: body[:n],
respOOB: respOOB,
}

View File

@ -3,7 +3,9 @@
package bindtodevice
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"net"
"net/netip"
@ -16,6 +18,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
@ -240,10 +243,13 @@ func testListenControlUDPQuery(t *testing.T, packetConn net.PacketConn, reqAddr
err := packetConn.SetReadDeadline(time.Now().Add(testTimeout))
require.NoError(t, err)
b := make([]byte, reqLen)
oob := make([]byte, netext.IPDstOOBSize)
var sess *packetSession
switch c := packetConn.(type) {
case *net.UDPConn:
sess, err = readPacketSession(c, reqLen)
sess, err = readPacketSession(c, b, oob)
require.NoError(t, err)
case netext.SessionPacketConn:
var s netext.PacketSession
@ -460,3 +466,84 @@ func TestListenControlWithSO(t *testing.T) {
require.NoError(t, err)
})
}
// testMsgUDPReader is a [msgUDPReader] for tests.
type testMsgUDPReader struct {
onReadMsgUDP func(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error)
}
// type check
var _ msgUDPReader = (*testMsgUDPReader)(nil)
// ReadMsgUDP implements the [msgUDPReader] interface for *testMsgUDPReader.
func (r *testMsgUDPReader) ReadMsgUDP(
b []byte,
oob []byte,
) (n, oobn, flags int, addr *net.UDPAddr, err error) {
return r.onReadMsgUDP(b, oob)
}
// Sinks for benchmarks.
var (
sessSink *packetSession
errSink error
)
func BenchmarkReadPacketSession(b *testing.B) {
bodyData := []byte("message body data")
// TODO(a.garipov): Find a better way to pack these control messages than
// just [binary.Write].
oobBuf := &bytes.Buffer{}
ctrlMsgHdr := unix.Cmsghdr{
Len: 24,
Level: unix.SOL_IP,
Type: unix.IP_ORIGDSTADDR,
}
// TODO(a.garipov): Use binary.NativeEndian in Go 1.21 here and below.
err := binary.Write(oobBuf, binary.LittleEndian, ctrlMsgHdr)
require.NoError(b, err)
pktInfo := unix.Inet4Pktinfo{
Spec_dst: *(*[4]byte)(testRAddr.IP),
Addr: *(*[4]byte)(testRAddr.IP),
}
err = binary.Write(oobBuf, binary.LittleEndian, pktInfo)
require.NoError(b, err)
oobData := oobBuf.Bytes()
c := &testMsgUDPReader{
onReadMsgUDP: func(body, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) {
copy(body, bodyData)
copy(oob, oobData)
return len(bodyData), len(oobData), 0, testRAddr, nil
},
}
body := make([]byte, dns.DefaultMsgSize)
oob := make([]byte, netext.IPDstOOBSize)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
sessSink, errSink = readPacketSession(c, body, oob)
}
require.NoError(b, errSink)
require.NotNil(b, sessSink)
assert.Equal(b, sessSink.raddr, testRAddr)
assert.Equal(b, sessSink.readBody, bodyData)
// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
// goos: linux
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
// BenchmarkReadPacketSession
// BenchmarkReadPacketSession-16 3311841 458.1 ns/op 224 B/op 5 allocs/op
}

40
internal/cmd/access.go Normal file
View File

@ -0,0 +1,40 @@
package cmd
import (
"fmt"
"net/netip"
)
// accessConfig is the configuration that controls IP and hosts blocking.
type accessConfig struct {
// BlockedQuestionDomains is a list of AdBlock rules used to block access.
BlockedQuestionDomains []string `yaml:"blocked_question_domains"`
// BlockedClientSubnets is a list of IP addresses or subnets to block.
BlockedClientSubnets []string `yaml:"blocked_client_subnets"`
}
// validate returns an error if the access configuration is invalid.
func (a *accessConfig) validate() (err error) {
if a == nil {
return errNilConfig
}
for i, s := range a.BlockedClientSubnets {
// TODO(a.garipov): Use [netutil.ParseSubnet] after refactoring it to
// [netip.Addr].
_, parseErr := netip.ParseAddr(s)
if parseErr == nil {
continue
}
_, parseErr = netip.ParsePrefix(s)
if parseErr == nil {
continue
}
return fmt.Errorf("value %q at index %d: bad ip or cidr: %w", s, i, parseErr)
}
return nil
}

View File

@ -3,10 +3,12 @@ package cmd
import (
"context"
"fmt"
"net/url"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/backend"
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/golibs/netutil"
@ -87,12 +89,14 @@ func setupBillStat(
sigHdlr signalHandler,
errColl agd.ErrorCollector,
) (rec *billstat.RuntimeRecorder, err error) {
billStatConf := &backend.BillStatConfig{
BaseEndpoint: netutil.CloneURL(&envs.BillStatURL.URL),
apiURL := netutil.CloneURL(&envs.BillStatURL.URL)
billStatUploader, err := setupBillStatUploader(apiURL, errColl)
if err != nil {
return nil, fmt.Errorf("creating bill stat uploader: %w", err)
}
rec = billstat.NewRuntimeRecorder(&billstat.RuntimeRecorderConfig{
Uploader: backend.NewBillStat(billStatConf),
Uploader: billStatUploader,
})
refrIvl := conf.RefreshIvl.Duration
@ -127,13 +131,12 @@ func setupProfDB(
sigHdlr signalHandler,
errColl agd.ErrorCollector,
) (profDB *profiledb.Default, err error) {
profStrgConf := &backend.ProfileStorageConfig{
BaseEndpoint: netutil.CloneURL(&envs.ProfilesURL.URL),
Now: time.Now,
ErrColl: errColl,
apiURL := netutil.CloneURL(&envs.ProfilesURL.URL)
profStrg, err := setupProfStorage(apiURL, errColl)
if err != nil {
return nil, fmt.Errorf("creating profile storage: %w", err)
}
profStrg := backend.NewProfileStorage(profStrgConf)
profDB, err = profiledb.New(profStrg, conf.FullRefreshIvl.Duration, envs.ProfilesCachePath)
if err != nil {
return nil, fmt.Errorf("creating default profile database: %w", err)
@ -162,3 +165,55 @@ func setupProfDB(
return profDB, nil
}
// Backend API URL schemes.
const (
schemeHTTP = "http"
schemeHTTPS = "https"
schemeGRPC = "grpc"
schemeGRPCS = "grpcs"
)
// setupProfStorage creates and returns a profile storage depending on the
// provided API URL.
func setupProfStorage(
apiURL *url.URL,
errColl agd.ErrorCollector,
) (s profiledb.Storage, err error) {
switch apiURL.Scheme {
case schemeGRPC, schemeGRPCS:
return backendpb.NewProfileStorage(&backendpb.ProfileStorageConfig{
Endpoint: apiURL,
ErrColl: errColl,
})
case schemeHTTP, schemeHTTPS:
return backend.NewProfileStorage(&backend.ProfileStorageConfig{
BaseEndpoint: apiURL,
Now: time.Now,
ErrColl: errColl,
}), nil
default:
return nil, fmt.Errorf("invalid backend api url: %s", apiURL)
}
}
// setupBillStatUploader creates and returns a billstat uploader depending on
// the provided API URL.
func setupBillStatUploader(
apiURL *url.URL,
errColl agd.ErrorCollector,
) (s billstat.Uploader, err error) {
switch apiURL.Scheme {
case schemeGRPC, schemeGRPCS:
return backendpb.NewBillStat(&backendpb.BillStatConfig{
ErrColl: errColl,
Endpoint: apiURL,
})
case schemeHTTP, schemeHTTPS:
return backend.NewBillStat(&backend.BillStatConfig{
BaseEndpoint: apiURL,
}), nil
default:
return nil, fmt.Errorf("invalid backend api url: %s", apiURL)
}
}

View File

@ -2,7 +2,6 @@ package cmd
import (
"fmt"
"net"
"net/netip"
"net/url"
"strings"
@ -53,18 +52,6 @@ func (c *checkConfig) toInternal(
sessURL = netutil.CloneURL(&envs.ConsulDNSCheckSessionURL.URL)
}
// TODO(a.garipov): Use netip.Addrs in dnscheck, which also means using it
// in dnsmsg.Constructor.
ipv4 := make([]net.IP, len(c.IPv4))
for i, ip := range c.IPv4 {
ipv4[i] = ip.AsSlice()
}
ipv6 := make([]net.IP, len(c.IPv6))
for i, ip := range c.IPv6 {
ipv6[i] = ip.AsSlice()
}
domains := make([]string, len(c.Domains))
for i, d := range c.Domains {
domains[i] = strings.ToLower(d)
@ -78,8 +65,8 @@ func (c *checkConfig) toInternal(
Domains: domains,
NodeLocation: c.NodeLocation,
NodeName: c.NodeName,
IPv4: ipv4,
IPv6: ipv6,
IPv4: c.IPv4,
IPv6: c.IPv6,
TTL: c.TTL.Duration,
}
}

View File

@ -8,6 +8,7 @@ import (
"os"
"runtime"
"github.com/AdguardTeam/AdGuardDNS/internal/access"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
"github.com/AdguardTeam/AdGuardDNS/internal/debugsvc"
@ -15,6 +16,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/forward"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
@ -160,6 +162,11 @@ func Main() {
sigHdlr.add(btdMgr)
// access
accessManager, err := access.New(c.Access.BlockedQuestionDomains, c.Access.BlockedClientSubnets)
check(err)
// Server groups
messages := dnsmsg.NewConstructor(&dnsmsg.BlockingModeNullIP{}, c.Filters.ResponseTTL.Duration)
@ -239,6 +246,7 @@ func Main() {
}
dnsConf := &dnssvc.Config{
AccessManager: accessManager,
Messages: messages,
SafeBrowsing: hashprefix.NewMatcher(hashStorages),
BillStat: billStatRec,
@ -327,5 +335,10 @@ func collectPanics(errColl agd.ErrorCollector) {
errColl.Collect(context.Background(), err)
errFlushColl, ok := errColl.(errcoll.ErrorFlushCollector)
if ok {
errFlushColl.Flush()
}
panic(v)
}

View File

@ -71,6 +71,9 @@ type configuration struct {
// Network is the configuration for network listeners.
Network *network `yaml:"network"`
// Access is the configuration of the service managing access control.
Access *accessConfig `yaml:"access"`
// AdditionalMetricsInfo is extra information, which is exposed by metrics.
AdditionalMetricsInfo additionalInfo `yaml:"additional_metrics_info"`
@ -159,6 +162,9 @@ func (c *configuration) validate() (err error) {
}, {
validate: c.Network.validate,
name: "network",
}, {
validate: c.Access.validate,
name: "access",
}, {
validate: c.AdditionalMetricsInfo.validate,
name: "additional_metrics_info",

View File

@ -36,10 +36,10 @@ func (c *connCheckConfig) validate() (err error) {
// connectivityCheck performs connectivity checks for bind addresses with
// provided dialer and probe addresses. For each server group it reviews each
// server bind addresses looking up for an IPv6 addresses. If an IPv6 address
// is found, then additionally to a general probe to IPv4 it will perform a
// check to IPv6 probe address.
func connectivityCheck(c *dnssvc.Config, connCheck *connCheckConfig) error {
// server bind addresses looking up for IPv6 addresses. If an IPv6 address is
// found, then additionally to a general probe to IPv4 it will perform a check
// to IPv6 probe address.
func connectivityCheck(c *dnssvc.Config, connCheck *connCheckConfig) (err error) {
probeIPv4 := net.TCPAddrFromAddrPort(connCheck.ProbeIPv4)
// General check to IPv4 probe address.
@ -49,13 +49,16 @@ func connectivityCheck(c *dnssvc.Config, connCheck *connCheckConfig) error {
}
defer func() {
err = conn.Close()
if err != nil {
log.Fatalf("connectivity check: ipv4: %v", err)
closeErr := conn.Close()
if closeErr != nil {
log.Fatalf("connectivity check: ipv4: %v", closeErr)
}
}()
if containsIPv6BindAddress(c.ServerGroups) {
if !requireIPv6ConnCheck(c.ServerGroups) {
return nil
}
if (connCheck.ProbeIPv6 == netip.AddrPort{}) {
log.Fatal("connectivity check: no ipv6 probe address in config")
}
@ -63,34 +66,43 @@ func connectivityCheck(c *dnssvc.Config, connCheck *connCheckConfig) error {
probeIPv6 := net.TCPAddrFromAddrPort(connCheck.ProbeIPv6)
// Check to IPv6 probe address.
connV6, errV6 := net.DialTCP("tcp6", nil, probeIPv6)
if errV6 != nil {
return fmt.Errorf("connectivity check: ipv6: %w", errV6)
connV6, err := net.DialTCP("tcp6", nil, probeIPv6)
if err != nil {
return fmt.Errorf("connectivity check: ipv6: %w", err)
}
defer func() {
err = connV6.Close()
if err != nil {
log.Fatalf("connectivity check: ipv6: %v", err)
closeErr := connV6.Close()
if closeErr != nil {
log.Fatalf("connectivity check: ipv6: %v", closeErr)
}
}()
}
return nil
}
// containsIPv6BindAddress returns true if provided serverGroups require
// IPv6 connectivity check.
func containsIPv6BindAddress(serverGroups []*agd.ServerGroup) (ok bool) {
// requireIPv6ConnCheck returns true if provided serverGroups require IPv6
// connectivity check.
func requireIPv6ConnCheck(serverGroups []*agd.ServerGroup) (ok bool) {
for _, srvGrp := range serverGroups {
for _, s := range srvGrp.Servers {
for _, bindData := range s.BindData {
if addr := bindData.AddrPort; addr.IsValid() && addr.Addr().Is6() {
if containsIPv6BindAddress(s.BindData) {
return true
}
}
}
}
return false
}
// containsIPv6BindAddress returns true if provided bindData contains valid IPv6
// address.
func containsIPv6BindAddress(bindData []*agd.ServerBindData) (ok bool) {
for _, bData := range bindData {
if addr := bData.AddrPort; addr.Addr().Is6() {
return true
}
}
return false
}

View File

@ -57,8 +57,16 @@ func ddrRecsToSVCBTmpls(
tmpls = appendDDRSVCBTmpls(tmpls, msgs, r, target)
}
slices.SortStableFunc(tmpls, func(a, b *dns.SVCB) (less bool) {
return a.Priority < b.Priority
// TODO(e.burkov): Use cmp.Compare when updated to go1.21.
slices.SortStableFunc(tmpls, func(a, b *dns.SVCB) (res int) {
switch x, y := a.Priority, b.Priority; {
case x < y:
return -1
case x > y:
return +1
default:
return 0
}
})
return targets, tmpls
@ -118,26 +126,16 @@ func (c *ddrConfig) validate() (err error) {
}
domainSuf := wildcard[2:]
err = netutil.ValidateHostname(domainSuf)
err = errors.Join(netutil.ValidateHostname(domainSuf), r.validate())
if err != nil {
return fmt.Errorf("device_records: %w", err)
}
err = r.validate()
if err != nil {
return fmt.Errorf("device_records: record for wildcard %q: %w", wildcard, err)
return fmt.Errorf("device_records: wildcard %q: %w", wildcard, err)
}
}
for domain, r := range c.PublicRecords {
err = netutil.ValidateHostname(domain)
err = errors.Join(netutil.ValidateHostname(domain), r.validate())
if err != nil {
return fmt.Errorf("public_records: %w", err)
}
err = r.validate()
if err != nil {
return fmt.Errorf("public_records: record for domain %q: %w", domain, err)
return fmt.Errorf("public_records: domain %q: %w", domain, err)
}
}

View File

@ -41,7 +41,7 @@ type environments struct {
ConfPath string `env:"CONFIG_PATH" envDefault:"./config.yaml"`
FilterCachePath string `env:"FILTER_CACHE_PATH" envDefault:"./filters/"`
ProfilesCachePath string `env:"PROFILES_CACHE_PATH" envDefault:"./profilecache.json"`
ProfilesCachePath string `env:"PROFILES_CACHE_PATH" envDefault:"./profilecache.pb"`
GeoIPASNPath string `env:"GEOIP_ASN_PATH" envDefault:"./asn.mmdb"`
GeoIPCountryPath string `env:"GEOIP_COUNTRY_PATH" envDefault:"./country.mmdb"`
QueryLogPath string `env:"QUERYLOG_PATH" envDefault:"./querylog.jsonl"`

View File

@ -6,7 +6,6 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/stringutil"
@ -205,10 +204,8 @@ func (s *server) bindData(
ifaces := s.BindInterfaces
bindData = make([]*agd.ServerBindData, 0, len(ifaces))
for i, iface := range ifaces {
address := string(iface.ID)
for j, subnet := range iface.Subnets {
var lc netext.ListenConfig
var lc *bindtodevice.ListenConfig
lc, err = btdMgr.ListenConfig(iface.ID, subnet)
if err != nil {
const errStr = "bind_interface at index %d: subnet at index %d: %w"
@ -218,7 +215,7 @@ func (s *server) bindData(
bindData = append(bindData, &agd.ServerBindData{
ListenConfig: lc,
Address: address,
Address: lc.Addr(),
})
}
}

View File

@ -136,14 +136,15 @@ func (certs tlsConfigCerts) toInternal() (conf *tls.Config, err error) {
return nil, fmt.Errorf("certificate at index %d: %w", i, err)
}
tlsCerts[i] = cert
var leaf *x509.Certificate
leaf, err = x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return nil, fmt.Errorf("invalid leaf, certificate at index %d: %w", i, err)
}
cert.Leaf = leaf
tlsCerts[i] = cert
authAlgo, subj := leaf.PublicKeyAlgorithm.String(), leaf.Subject.String()
metrics.TLSCertificateInfo.With(prometheus.Labels{
"auth_algo": authAlgo,

View File

@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/netip"
"net/url"
@ -23,6 +22,7 @@ import (
"github.com/miekg/dns"
cache "github.com/patrickmn/go-cache"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/exp/slices"
"golang.org/x/time/rate"
)
@ -46,8 +46,8 @@ type Consul struct {
nodeLocation string
nodeName string
ipv4 []net.IP
ipv6 []net.IP
ipv4 []netip.Addr
ipv6 []netip.Addr
}
// ConsulConfig is the configuration structure for Consul KV based DNS checker.
@ -77,10 +77,10 @@ type ConsulConfig struct {
NodeName string
// IPv4 are the IPv4 addresses to respond with to A requests.
IPv4 []net.IP
IPv4 []netip.Addr
// IPv6 are the IPv6 addresses to respond with to AAAA requests.
IPv6 []net.IP
IPv6 []netip.Addr
// TTL defines, for how long to keep the information about a single client.
TTL time.Duration
@ -108,8 +108,8 @@ func NewConsul(c *ConsulConfig) (cc *Consul, err error) {
nodeLocation: c.NodeLocation,
nodeName: c.NodeName,
ipv4: netutil.CloneIPs(c.IPv4),
ipv6: netutil.CloneIPs(c.IPv6),
ipv4: slices.Clone(c.IPv4),
ipv6: slices.Clone(c.IPv6),
}
// TODO(e.burkov): Validate also c.ConsulSessionURL?

View File

@ -6,6 +6,7 @@ import (
"net"
"net/http"
"net/http/httptest"
"net/netip"
"net/url"
"strings"
"testing"
@ -222,8 +223,8 @@ func TestConsul_Check(t *testing.T) {
conf := &dnscheck.ConsulConfig{
Messages: dnsmsg.NewConstructor(&dnsmsg.BlockingModeNullIP{}, ttl*time.Second),
Domains: []string{checkDomain},
IPv4: []net.IP{{1, 2, 3, 4}},
IPv6: []net.IP{net.ParseIP("1234::5678")},
IPv4: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
IPv6: []netip.Addr{netip.MustParseAddr("1234::5678")},
}
dnsCk, err := dnscheck.NewConsul(conf)

View File

@ -5,9 +5,9 @@ import (
"compress/gzip"
"context"
"io"
"net"
"net/http"
"net/http/httptest"
"net/netip"
"net/url"
"strings"
"testing"
@ -25,7 +25,8 @@ import (
func TestDefault_ServeHTTP(t *testing.T) {
const dname = "some-domain.name"
testIP := net.IP{1, 2, 3, 4}
testIP := netip.MustParseAddr("1.2.3.4")
successHdr := http.Header{
httphdr.ContentType: []string{agdhttp.HdrValTextCSV},
@ -100,7 +101,7 @@ func TestDefault_ServeHTTP(t *testing.T) {
db.Record(ctx, m, &agd.RequestInfo{
// Emulate the logic from init middleware.
//
// See [dnssvc.initMw.newRequestInfo].
// See [initial.Middleware.newRequestInfo].
Host: strings.TrimSuffix(m.Question[0].Name, "."),
})
}

View File

@ -56,6 +56,7 @@ func answerString(rr dns.RR) (s string) {
case *dns.AAAA:
return v.AAAA.String()
case *dns.CNAME:
// TODO(a.garipov): Consider lowercasing target hostname.
return strings.TrimSuffix(v.Target, ".")
default:
return ""

View File

@ -21,6 +21,8 @@ type BlockingMode interface {
// BlockingModeCodec is a wrapper around a BlockingMode that implements the
// [json.Marshaler] and [json.Unmarshaler] interfaces.
//
// TODO(s.chzhen): Remove once it's not used anymore.
type BlockingModeCodec struct {
Mode BlockingMode
}

510
internal/dnsmsg/cloner.go Normal file
View File

@ -0,0 +1,510 @@
package dnsmsg
import (
"fmt"
"net"
"github.com/AdguardTeam/AdGuardDNS/internal/agdsync"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
// Cloner is a pool that can clone common parts of DNS messages with fewer
// allocations.
//
// TODO(a.garipov): Add ECS/OPT.
//
// TODO(a.garipov): Use.
//
// TODO(a.garipov): Consider merging into [Constructor].
type Cloner struct {
// Top-level structures.
msg *agdsync.TypedPool[dns.Msg]
question *agdsync.TypedPool[[]dns.Question]
// Mostly-answer structures.
a *agdsync.TypedPool[dns.A]
aaaa *agdsync.TypedPool[dns.AAAA]
cname *agdsync.TypedPool[dns.CNAME]
ptr *agdsync.TypedPool[dns.PTR]
srv *agdsync.TypedPool[dns.SRV]
txt *agdsync.TypedPool[dns.TXT]
// Mostly-answer custom cloners.
https *httpsCloner
// Mostly-NS structures.
soa *agdsync.TypedPool[dns.SOA]
}
// NewCloner returns a new properly initialized *Cloner.
func NewCloner() (c *Cloner) {
return &Cloner{
msg: agdsync.NewTypedPool(func() (v *dns.Msg) {
return &dns.Msg{}
}),
question: agdsync.NewTypedPool(func() (v *[]dns.Question) {
q := make([]dns.Question, 1)
return &q
}),
a: agdsync.NewTypedPool(func() (v *dns.A) {
return &dns.A{}
}),
aaaa: agdsync.NewTypedPool(func() (v *dns.AAAA) {
return &dns.AAAA{}
}),
cname: agdsync.NewTypedPool(func() (v *dns.CNAME) {
return &dns.CNAME{}
}),
ptr: agdsync.NewTypedPool(func() (v *dns.PTR) {
return &dns.PTR{}
}),
srv: agdsync.NewTypedPool(func() (v *dns.SRV) {
return &dns.SRV{}
}),
txt: agdsync.NewTypedPool(func() (v *dns.TXT) {
return &dns.TXT{}
}),
https: newHTTPSCloner(),
soa: agdsync.NewTypedPool(func() (v *dns.SOA) {
return &dns.SOA{}
}),
}
}
// Clone returns a deep clone of msg. full is true if msg was cloned entirely
// without the use of [dns.Copy].
//
// msg must have exactly one question.
//
// TODO(a.garipov): Don't require one question?
func (c *Cloner) Clone(msg *dns.Msg) (clone *dns.Msg, full bool) {
if msg == nil {
return nil, true
}
clone = c.msg.Get()
clone.MsgHdr = msg.MsgHdr
clone.Compress = msg.Compress
clone.Question = *c.question.Get()
clone.Question[0] = msg.Question[0]
clone.Answer, full = c.appendAnswer(clone.Answer[:0], msg.Answer)
clone.Ns = clone.Ns[:0]
for _, orig := range msg.Ns {
var nsClone dns.RR
switch orig := orig.(type) {
case *dns.SOA:
ns := c.soa.Get()
*ns = *orig
nsClone = ns
// TODO(a.garipov): Add more if necessary.
default:
nsClone = dns.Copy(orig)
full = false
}
clone.Ns = append(clone.Ns, nsClone)
}
clone.Extra = clone.Extra[:0]
for _, orig := range msg.Extra {
var exClone dns.RR
switch orig := orig.(type) {
// TODO(a.garipov): Add more if necessary.
default:
exClone = dns.Copy(orig)
full = false
}
clone.Extra = append(clone.Extra, exClone)
}
return clone, full
}
// appendAnswer appends deep clones of all resource recornds from original to
// clones and returns it.
func (c *Cloner) appendAnswer(clones, original []dns.RR) (res []dns.RR, full bool) {
full = true
for _, orig := range original {
var ansClone dns.RR
switch orig := orig.(type) {
case *dns.A:
ans := c.a.Get()
ans.Hdr = orig.Hdr
ans.A = append(ans.A[:0], orig.A...)
ansClone = ans
case *dns.AAAA:
ans := c.aaaa.Get()
ans.Hdr = orig.Hdr
ans.AAAA = append(ans.AAAA[:0], orig.AAAA...)
ansClone = ans
case *dns.CNAME:
ans := c.cname.Get()
*ans = *orig
ansClone = ans
case *dns.HTTPS:
var httpsFull bool
ansClone, httpsFull = c.https.clone(orig)
full = full && httpsFull
case *dns.PTR:
ans := c.ptr.Get()
*ans = *orig
ansClone = ans
case *dns.SRV:
ans := c.srv.Get()
*ans = *orig
ansClone = ans
case *dns.TXT:
ans := c.txt.Get()
ans.Hdr = orig.Hdr
ans.Txt = append(ans.Txt[:0], orig.Txt...)
ansClone = ans
default:
ansClone = dns.Copy(orig)
full = false
}
clones = append(clones, ansClone)
}
return clones, full
}
// Put returns structures from msg into c's pools. Neither msg nor any of its
// parts must not be used after this.
//
// msg must have exactly one question.
//
// TODO(a.garipov): Don't require one question?
func (c *Cloner) Put(msg *dns.Msg) {
if msg == nil {
return
}
c.putAnswers(msg.Answer)
for _, ns := range msg.Ns {
switch ns := ns.(type) {
case *dns.SOA:
c.soa.Put(ns)
default:
// Go on.
}
}
for _, ex := range msg.Extra {
// TODO(a.garipov): Add OPT.
_ = ex
}
c.question.Put(&msg.Question)
c.msg.Put(msg)
}
// putAnswers returns answers into c's pools.
func (c *Cloner) putAnswers(answers []dns.RR) {
for _, ans := range answers {
switch ans := ans.(type) {
case *dns.A:
c.a.Put(ans)
case *dns.AAAA:
c.aaaa.Put(ans)
case *dns.CNAME:
c.cname.Put(ans)
case *dns.HTTPS:
c.https.put(ans)
case *dns.PTR:
c.ptr.Put(ans)
case *dns.SRV:
c.srv.Put(ans)
case *dns.TXT:
c.txt.Put(ans)
default:
// Go on.
}
}
}
// httpsCloner is a pool that can clone common parts of DNS messages of type
// HTTPS with fewer allocations.
type httpsCloner struct {
// Top-level structures.
rr *agdsync.TypedPool[dns.HTTPS]
// Values.
alpn *agdsync.TypedPool[dns.SVCBAlpn]
dohpath *agdsync.TypedPool[dns.SVCBDoHPath]
echconfig *agdsync.TypedPool[dns.SVCBECHConfig]
ipv4hint *agdsync.TypedPool[dns.SVCBIPv4Hint]
ipv6hint *agdsync.TypedPool[dns.SVCBIPv6Hint]
local *agdsync.TypedPool[dns.SVCBLocal]
mandatory *agdsync.TypedPool[dns.SVCBMandatory]
noDefALPN *agdsync.TypedPool[dns.SVCBNoDefaultAlpn]
port *agdsync.TypedPool[dns.SVCBPort]
// Miscellaneous.
ip *agdsync.TypedPool[net.IP]
}
// newHTTPSCloner returns a new properly initialized *httpsCloner.
func newHTTPSCloner() (c *httpsCloner) {
return &httpsCloner{
rr: agdsync.NewTypedPool(func() (v *dns.HTTPS) {
return &dns.HTTPS{}
}),
alpn: agdsync.NewTypedPool(func() (v *dns.SVCBAlpn) {
return &dns.SVCBAlpn{}
}),
dohpath: agdsync.NewTypedPool(func() (v *dns.SVCBDoHPath) {
return &dns.SVCBDoHPath{}
}),
echconfig: agdsync.NewTypedPool(func() (v *dns.SVCBECHConfig) {
return &dns.SVCBECHConfig{}
}),
ipv4hint: agdsync.NewTypedPool(func() (v *dns.SVCBIPv4Hint) {
return &dns.SVCBIPv4Hint{}
}),
ipv6hint: agdsync.NewTypedPool(func() (v *dns.SVCBIPv6Hint) {
return &dns.SVCBIPv6Hint{}
}),
local: agdsync.NewTypedPool(func() (v *dns.SVCBLocal) {
return &dns.SVCBLocal{}
}),
mandatory: agdsync.NewTypedPool(func() (v *dns.SVCBMandatory) {
return &dns.SVCBMandatory{}
}),
noDefALPN: agdsync.NewTypedPool(func() (v *dns.SVCBNoDefaultAlpn) {
return &dns.SVCBNoDefaultAlpn{}
}),
port: agdsync.NewTypedPool(func() (v *dns.SVCBPort) {
return &dns.SVCBPort{}
}),
ip: agdsync.NewTypedPool(func() (v *net.IP) {
// Use the IPv6 length to increase the effectiveness of the pool.
ip := make(net.IP, 16)
return &ip
}),
}
}
// clone returns a deep clone of rr. full is true if rr was cloned entirely
// without the use of [dns.Copy].
func (c *httpsCloner) clone(rr *dns.HTTPS) (clone *dns.HTTPS, full bool) {
if rr == nil {
return nil, true
}
clone = c.rr.Get()
clone.Hdr = rr.Hdr
clone.Priority = rr.Priority
clone.Target = rr.Target
clone.Value = clone.Value[:0]
for _, orig := range rr.Value {
valClone, knownKV := c.cloneKV(orig)
if !knownKV {
// This branch is only reached if there is a new SVCB key-value type
// in miekg/dns. Give up and just use their copy function.
return dns.Copy(rr).(*dns.HTTPS), false
}
clone.Value = append(clone.Value, valClone)
}
return clone, true
}
// cloneKV returns a deep clone of orig. full is true if orig was recognized.
func (c *httpsCloner) cloneKV(orig dns.SVCBKeyValue) (clone dns.SVCBKeyValue, known bool) {
switch orig := orig.(type) {
case *dns.SVCBAlpn:
v := c.alpn.Get()
v.Alpn = append(v.Alpn[:0], orig.Alpn...)
clone = v
case *dns.SVCBDoHPath:
v := c.dohpath.Get()
*v = *orig
clone = v
case *dns.SVCBECHConfig:
v := c.echconfig.Get()
v.ECH = append(v.ECH[:0], orig.ECH...)
clone = v
case *dns.SVCBIPv4Hint:
v := c.ipv4hint.Get()
v.Hint = c.appendIPs(v.Hint[:0], orig.Hint)
clone = v
case *dns.SVCBIPv6Hint:
v := c.ipv6hint.Get()
v.Hint = c.appendIPs(v.Hint[:0], orig.Hint)
clone = v
case *dns.SVCBLocal:
v := c.local.Get()
v.KeyCode = orig.KeyCode
v.Data = append(v.Data[:0], orig.Data...)
clone = v
case *dns.SVCBMandatory:
v := c.mandatory.Get()
v.Code = append(v.Code[:0], orig.Code...)
clone = v
case *dns.SVCBNoDefaultAlpn:
clone = c.noDefALPN.Get()
case *dns.SVCBPort:
v := c.port.Get()
*v = *orig
clone = v
default:
// This branch is only reached if there is a new SVCB key-value type
// in miekg/dns.
return nil, false
}
return clone, true
}
// appendIPs appends the clones of IP addresses from orig to hints and returns
// the resulting slice. clone is allocated as a single continuous slice.
func (c *httpsCloner) appendIPs(hints, orig []net.IP) (clone []net.IP) {
if len(orig) == 0 {
if orig == nil {
return nil
}
return []net.IP{}
}
// Use a single large slice and subslice it to make it easier to maintain a
// pool of these.
ips := *c.ip.Get()
ips = ips[:0]
neededCap := 0
for _, origIP := range orig {
neededCap += len(origIP)
}
ips = slices.Grow(ips, neededCap)
hints = hints[:0]
for _, origIP := range orig {
ips = append(ips, origIP...)
origLen := len(origIP)
lastIdx := len(ips)
hints = append(hints, ips[lastIdx-origLen:lastIdx])
}
return hints
}
// put returns structures from rr into c's pools.
func (c *httpsCloner) put(rr *dns.HTTPS) {
if rr == nil {
return
}
for _, kv := range rr.Value {
c.putKV(kv)
}
c.rr.Put(rr)
}
// putKV returns structures from kv into c's pools.
func (c *httpsCloner) putKV(kv dns.SVCBKeyValue) {
switch kv := kv.(type) {
case *dns.SVCBAlpn:
c.alpn.Put(kv)
case *dns.SVCBDoHPath:
c.dohpath.Put(kv)
case *dns.SVCBECHConfig:
c.echconfig.Put(kv)
case *dns.SVCBIPv4Hint:
putIPHint(c, kv)
case *dns.SVCBIPv6Hint:
putIPHint(c, kv)
case *dns.SVCBLocal:
c.local.Put(kv)
case *dns.SVCBMandatory:
c.mandatory.Put(kv)
case *dns.SVCBNoDefaultAlpn:
c.noDefALPN.Put(kv)
case *dns.SVCBPort:
c.port.Put(kv)
default:
// This branch is only reached if there is a new SVCB key-value type
// in miekg/dns. Noting to do.
}
}
// putIPHint is a generic helper that returns the structures of kv into c.
func putIPHint[T *dns.SVCBIPv4Hint | *dns.SVCBIPv6Hint](c *httpsCloner, kv T) {
switch kv := any(kv).(type) {
case *dns.SVCBIPv4Hint:
// TODO(a.garipov): Put the common code above the switch when Go learns
// about common fields between types.
if len(kv.Hint) > 0 {
// Assume that the array underlying these slices is a single and
// continuous one.
c.ip.Put(&kv.Hint[0])
}
c.ipv4hint.Put(kv)
case *dns.SVCBIPv6Hint:
// TODO(a.garipov): Put the common code above the switch when Go learns
// about common fields between types.
if len(kv.Hint) > 0 {
// Assume that the array underlying these slices is a single and
// continuous one.
c.ip.Put(&kv.Hint[0])
}
c.ipv6hint.Put(kv)
default:
// Must not happen, because there is a strict type parameter above.
panic(fmt.Errorf("bad type %T", kv))
}
}

View File

@ -0,0 +1,280 @@
package dnsmsg_test
import (
"net"
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// clonerTestCase is the type for the common test cases for the cloner tests and
// benchmarks.
type clonerTestCase struct {
msg *dns.Msg
wantFull assert.BoolAssertionFunc
name string
}
// clonerTestCases are the common test cases for the clone benchmarks.
var clonerTestCases = []clonerTestCase{{
msg: dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET),
name: "req_a",
wantFull: assert.True,
}, {
msg: dnsservertest.NewResp(
dns.RcodeSuccess,
dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET),
dnsservertest.SectionAnswer{
dnsservertest.NewA(testFQDN, 10, testIPv4),
},
),
name: "resp_a",
wantFull: assert.True,
}, {
msg: dnsservertest.NewResp(
dns.RcodeSuccess,
dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET),
dnsservertest.SectionAnswer{
dnsservertest.NewA(testFQDN, 10, testIPv4),
dnsservertest.NewA(testFQDN, 10, testIPv4.Next()),
},
),
name: "resp_a_many",
wantFull: assert.True,
}, {
msg: dnsservertest.NewResp(
dns.RcodeSuccess,
dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET),
dnsservertest.SectionAnswer{
dnsservertest.NewA(testFQDN, 10, testIPv4),
},
dnsservertest.SectionNs{
dnsservertest.NewSOA(testFQDN, 10, "ns.example.", "mbox.example."),
},
),
name: "resp_a_soa",
wantFull: assert.True,
}, {
msg: dnsservertest.NewReq(testFQDN, dns.TypeAAAA, dns.ClassINET),
name: "req_aaaa",
wantFull: assert.True,
}, {
msg: dnsservertest.NewResp(
dns.RcodeSuccess,
dnsservertest.NewReq(testFQDN, dns.TypeAAAA, dns.ClassINET),
dnsservertest.SectionAnswer{
dnsservertest.NewAAAA(testFQDN, 10, testIPv6),
},
),
name: "resp_aaaa",
wantFull: assert.True,
}, {
msg: dnsservertest.NewResp(
dns.RcodeSuccess,
dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET),
dnsservertest.SectionAnswer{
dnsservertest.NewCNAME(testFQDN, 10, "cname.example."),
dnsservertest.NewA("cname.example.", 10, testIPv4),
},
),
name: "resp_cname_a",
wantFull: assert.True,
}, {
msg: dnsservertest.NewResp(
dns.RcodeSuccess,
dnsservertest.NewReq("4.3.2.1.in-addr.arpa", dns.TypePTR, dns.ClassINET),
dnsservertest.SectionAnswer{
dnsservertest.NewPTR("4.3.2.1.in-addr.arpa", 10, "ptr.example."),
},
),
name: "resp_ptr",
wantFull: assert.True,
}, {
msg: dnsservertest.NewResp(
dns.RcodeSuccess,
dnsservertest.NewReq(testFQDN, dns.TypeTXT, dns.ClassINET),
dnsservertest.SectionAnswer{
dnsservertest.NewTXT(testFQDN, 10, "a", "b", "c"),
},
),
name: "resp_txt",
wantFull: assert.True,
}, {
msg: dnsservertest.NewResp(
dns.RcodeSuccess,
dnsservertest.NewReq(testFQDN, dns.TypeSRV, dns.ClassINET),
dnsservertest.SectionAnswer{
dnsservertest.NewSRV(testFQDN, 10, "target.example.", 1, 1, 8080),
},
),
name: "resp_srv",
wantFull: assert.True,
}, {
msg: dnsservertest.NewResp(
dns.RcodeSuccess,
dnsservertest.NewReq(testFQDN, dns.TypeDNSKEY, dns.ClassINET),
dnsservertest.SectionAnswer{
&dns.DNSKEY{},
},
),
name: "resp_not_full",
wantFull: assert.False,
}, {
msg: newHTTPSResp([]dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"http/1.1", "h2", "h3"}},
&dns.SVCBDoHPath{Template: "/dns-query"},
&dns.SVCBECHConfig{ECH: []byte{0, 1, 2, 3}},
&dns.SVCBIPv4Hint{Hint: []net.IP{
testIPv4.AsSlice(),
testIPv4.Next().AsSlice(),
}},
&dns.SVCBIPv6Hint{Hint: []net.IP{
testIPv6.AsSlice(),
testIPv6.Next().AsSlice(),
}},
&dns.SVCBLocal{KeyCode: dns.SVCBKey(1234), Data: []byte{3, 2, 1, 0}},
&dns.SVCBMandatory{Code: []dns.SVCBKey{dns.SVCB_ALPN}},
&dns.SVCBNoDefaultAlpn{},
&dns.SVCBPort{Port: 443},
}),
name: "resp_https",
wantFull: assert.True,
}, {
msg: newHTTPSResp([]dns.SVCBKeyValue{
&dns.SVCBIPv4Hint{Hint: []net.IP{}},
&dns.SVCBIPv6Hint{Hint: []net.IP{}},
}),
name: "resp_https_empty_hint",
wantFull: assert.True,
}}
// newHTTPSResp is a hepler that returns a response of type HTTPS with the given
// parameter values.
func newHTTPSResp(kv []dns.SVCBKeyValue) (resp *dns.Msg) {
ans := &dns.HTTPS{
SVCB: dns.SVCB{
Hdr: dns.RR_Header{
Name: testFQDN,
Rrtype: dns.TypeHTTPS,
Class: dns.ClassINET,
Ttl: 10,
},
Priority: 10,
Target: testFQDN,
Value: kv,
},
}
return dnsservertest.NewResp(
dns.RcodeSuccess,
dnsservertest.NewReq(testFQDN, dns.TypeHTTPS, dns.ClassINET),
dnsservertest.SectionAnswer{ans},
)
}
func TestCloner_Clone(t *testing.T) {
c := dnsmsg.NewCloner()
for _, tc := range clonerTestCases {
t.Run(tc.name, func(t *testing.T) {
clone, full := c.Clone(tc.msg)
assert.NotSame(t, tc.msg, clone)
assert.Equal(t, tc.msg, clone)
tc.wantFull(t, full)
// Check again after putting it back.
c.Put(clone)
clone, full = c.Clone(tc.msg)
assert.NotSame(t, tc.msg, clone)
assert.Equal(t, tc.msg, clone)
tc.wantFull(t, full)
})
}
}
// Sinks for benchmarks
var (
msgSink *dns.Msg
boolSink bool
)
func BenchmarkClone(b *testing.B) {
for _, tc := range clonerTestCases {
b.Run(tc.name, func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
msgSink = dnsmsg.Clone(tc.msg)
}
require.Equal(b, tc.msg, msgSink)
})
}
// Most recent results, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
//
// goos: linux
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/querylog
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
// BenchmarkClone/req_a-16 32849714 231.7 ns/op 168 B/op 2 allocs/op
// BenchmarkClone/resp_a-16 12051967 509.1 ns/op 256 B/op 5 allocs/op
// BenchmarkClone/resp_a_many-16 8579755 669.4 ns/op 344 B/op 7 allocs/op
// BenchmarkClone/resp_a_soa-16 10393932 681.9 ns/op 368 B/op 6 allocs/op
// BenchmarkClone/req_aaaa-16 25616247 232.1 ns/op 168 B/op 2 allocs/op
// BenchmarkClone/resp_aaaa-16 14519920 493.4 ns/op 264 B/op 5 allocs/op
// BenchmarkClone/resp_cname_a-16 8652282 662.2 ns/op 320 B/op 6 allocs/op
// BenchmarkClone/resp_ptr-16 13558555 370.0 ns/op 232 B/op 4 allocs/op
// BenchmarkClone/resp_txt-16 12322016 532.7 ns/op 296 B/op 5 allocs/op
// BenchmarkClone/resp_srv-16 15878784 396.3 ns/op 248 B/op 4 allocs/op
// BenchmarkClone/resp_not_full-16 15718658 384.6 ns/op 248 B/op 4 allocs/op
// BenchmarkClone/resp_https-16 2621149 2020 ns/op 880 B/op 24 allocs/op
// BenchmarkClone/resp_https_empty_hint-16 6829873 890.8 ns/op 424 B/op 8 allocs/op
}
func BenchmarkCloner_Clone(b *testing.B) {
c := dnsmsg.NewCloner()
for _, tc := range clonerTestCases {
b.Run(tc.name, func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
msgSink, boolSink = c.Clone(tc.msg)
if i < b.N-1 {
// Don't put the last one to be sure that we can compare
// that one.
c.Put(msgSink)
}
}
require.Equal(b, tc.msg, msgSink)
tc.wantFull(b, boolSink)
})
}
// Most recent results, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
//
// goos: linux
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/querylog
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
// BenchmarkCloner_Clone/req_a-16 163590546 36.33 ns/op 0 B/op 0 allocs/op
// BenchmarkCloner_Clone/resp_a-16 100000000 56.55 ns/op 0 B/op 0 allocs/op
// BenchmarkCloner_Clone/resp_a_many-16 72498543 84.52 ns/op 0 B/op 0 allocs/op
// BenchmarkCloner_Clone/resp_a_soa-16 81750753 73.07 ns/op 0 B/op 0 allocs/op
// BenchmarkCloner_Clone/req_aaaa-16 165287482 39.00 ns/op 0 B/op 0 allocs/op
// BenchmarkCloner_Clone/resp_aaaa-16 99625165 59.56 ns/op 0 B/op 0 allocs/op
// BenchmarkCloner_Clone/resp_cname_a-16 72154432 81.15 ns/op 0 B/op 0 allocs/op
// BenchmarkCloner_Clone/resp_ptr-16 100418211 60.88 ns/op 0 B/op 0 allocs/op
// BenchmarkCloner_Clone/resp_txt-16 80963180 73.66 ns/op 0 B/op 0 allocs/op
// BenchmarkCloner_Clone/resp_srv-16 89021206 69.35 ns/op 0 B/op 0 allocs/op
// BenchmarkCloner_Clone/resp_not_full-16 31277523 187.6 ns/op 64 B/op 1 allocs/op
// BenchmarkCloner_Clone/resp_https-16 14601229 396.3 ns/op 0 B/op 0 allocs/op
// BenchmarkCloner_Clone/resp_https_empty_hint-16 45725181 127.4 ns/op 0 B/op 0 allocs/op
}

View File

@ -2,10 +2,9 @@ package dnsmsg
import (
"fmt"
"net"
"net/netip"
"time"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
)
@ -38,7 +37,7 @@ func (c *Constructor) NewBlockedRespMsg(req *dns.Msg) (msg *dns.Msg, err error)
case *BlockingModeNullIP:
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA, dns.TypeAAAA:
return c.NewIPRespMsg(req, nil)
return c.NewIPRespMsg(req, netip.Addr{})
default:
return c.NewMsgNODATA(req), nil
}
@ -62,11 +61,11 @@ func (c *Constructor) newBlockedCustomIPRespMsg(
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA:
if m.IPv4.IsValid() {
return c.NewIPRespMsg(req, m.IPv4.AsSlice())
return c.NewIPRespMsg(req, m.IPv4)
}
case dns.TypeAAAA:
if m.IPv6.IsValid() {
return c.NewIPRespMsg(req, m.IPv6.AsSlice())
return c.NewIPRespMsg(req, m.IPv6)
}
default:
// Go on.
@ -78,7 +77,7 @@ func (c *Constructor) newBlockedCustomIPRespMsg(
// NewIPRespMsg returns a DNS A or AAAA response message with the given IP
// addresses. If any IP address is nil, it is replaced by an unspecified (aka
// null) IP. The TTL is also set to c.FilteredResponseTTL.
func (c *Constructor) NewIPRespMsg(req *dns.Msg, ips ...net.IP) (msg *dns.Msg, err error) {
func (c *Constructor) NewIPRespMsg(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA:
return c.newMsgA(req, ips...)
@ -205,40 +204,38 @@ func (c *Constructor) newHdrWithClass(req *dns.Msg, rrType RRType, cl dns.Class)
}
// NewAnsA returns a new resource record with an IPv4 address. ip must be an
// IPv4 address. If ip is nil, it is replaced by an unspecified (aka null) IP,
// 0.0.0.0.
func (c *Constructor) NewAnsA(req *dns.Msg, ip net.IP) (ans *dns.A, err error) {
var ip4 net.IP
if ip == nil {
ip4 = net.IP{0, 0, 0, 0}
} else if err = netutil.ValidateIP(ip); err != nil {
return nil, err
} else if ip4 = ip.To4(); ip4 == nil {
// IPv4 address. If ip is a zero netip.Addr, it is replaced by an unspecified
// (aka null) IP, 0.0.0.0.
func (c *Constructor) NewAnsA(req *dns.Msg, ip netip.Addr) (ans *dns.A, err error) {
if ip == (netip.Addr{}) {
ip = netip.IPv4Unspecified()
} else if !ip.Is4() {
return nil, fmt.Errorf("bad ipv4: %s", ip)
}
data := ip.As4()
return &dns.A{
Hdr: c.newHdr(req, dns.TypeA),
A: ip4,
A: data[:],
}, nil
}
// NewAnsAAAA returns a new resource record with an IPv6 address. ip must be an
// IPv6 address. If ip is nil, it is replaced by an unspecified (aka null) IP,
// [::].
func (c *Constructor) NewAnsAAAA(req *dns.Msg, ip net.IP) (ans *dns.AAAA, err error) {
var ip6 net.IP
if ip == nil {
ip6 = net.IPv6unspecified
} else if err = netutil.ValidateIP(ip); err != nil {
return nil, err
} else {
ip6 = ip.To16()
// IPv6 address. If ip is a zero netip.Addr, it is replaced by an unspecified
// (aka null) IP, [::].
func (c *Constructor) NewAnsAAAA(req *dns.Msg, ip netip.Addr) (ans *dns.AAAA, err error) {
if ip == (netip.Addr{}) {
ip = netip.IPv6Unspecified()
} else if !ip.Is6() {
return nil, fmt.Errorf("bad ipv6: %s", ip)
}
data := ip.As16()
return &dns.AAAA{
Hdr: c.newHdr(req, dns.TypeAAAA),
AAAA: ip6,
AAAA: data[:],
}, nil
}
@ -358,7 +355,7 @@ func (c *Constructor) NewRespMsg(req *dns.Msg) (resp *dns.Msg) {
// newMsgA returns a new DNS response with the given IPv4 addresses. If any IP
// address is nil, it is replaced by an unspecified (aka null) IP, 0.0.0.0.
func (c *Constructor) newMsgA(req *dns.Msg, ips ...net.IP) (msg *dns.Msg, err error) {
func (c *Constructor) newMsgA(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
msg = c.NewRespMsg(req)
for i, ip := range ips {
var ans dns.RR
@ -375,7 +372,7 @@ func (c *Constructor) newMsgA(req *dns.Msg, ips ...net.IP) (msg *dns.Msg, err er
// newMsgAAAA returns a new DNS response with the given IPv6 addresses. If any
// IP address is nil, it is replaced by an unspecified (aka null) IP, [::].
func (c *Constructor) newMsgAAAA(req *dns.Msg, ips ...net.IP) (msg *dns.Msg, err error) {
func (c *Constructor) newMsgAAAA(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
msg = c.NewRespMsg(req)
for i, ip := range ips {
var ans dns.RR

View File

@ -2,7 +2,6 @@ package dnsmsg_test
import (
"net"
"net/netip"
"strings"
"testing"
@ -75,9 +74,6 @@ func TestConstructor_NewBlockedRespMsg_nullIP(t *testing.T) {
}
func TestConstructor_NewBlockedRespMsg_customIP(t *testing.T) {
wantIPv4 := netip.MustParseAddr("1.2.3.4")
wantIPv6 := netip.MustParseAddr("1234::cdef")
testCases := []struct {
messages *dnsmsg.Constructor
name string
@ -85,22 +81,22 @@ func TestConstructor_NewBlockedRespMsg_customIP(t *testing.T) {
wantAAAA bool
}{{
messages: dnsmsg.NewConstructor(&dnsmsg.BlockingModeCustomIP{
IPv4: wantIPv4,
IPv6: wantIPv6,
IPv4: testIPv4,
IPv6: testIPv6,
}, testFltRespTTL),
name: "both",
wantA: true,
wantAAAA: true,
}, {
messages: dnsmsg.NewConstructor(&dnsmsg.BlockingModeCustomIP{
IPv4: wantIPv4,
IPv4: testIPv4,
}, testFltRespTTL),
name: "ipv4_only",
wantA: true,
wantAAAA: false,
}, {
messages: dnsmsg.NewConstructor(&dnsmsg.BlockingModeCustomIP{
IPv6: wantIPv6,
IPv6: testIPv6,
}, testFltRespTTL),
name: "ipv6_only",
wantA: false,
@ -120,7 +116,7 @@ func TestConstructor_NewBlockedRespMsg_customIP(t *testing.T) {
require.Len(t, respA.Answer, 1)
a := testutil.RequireTypeAssert[*dns.A](t, respA.Answer[0])
assert.Equal(t, net.IP(wantIPv4.AsSlice()), a.A)
assert.Equal(t, net.IP(testIPv4.AsSlice()), a.A)
} else {
assert.Empty(t, respA.Answer)
}
@ -136,7 +132,7 @@ func TestConstructor_NewBlockedRespMsg_customIP(t *testing.T) {
require.Len(t, respAAAA.Answer, 1)
aaaa := testutil.RequireTypeAssert[*dns.AAAA](t, respAAAA.Answer[0])
assert.Equal(t, net.IP(wantIPv6.AsSlice()), aaaa.AAAA)
assert.Equal(t, net.IP(testIPv6.AsSlice()), aaaa.AAAA)
} else {
assert.Empty(t, respAAAA.Answer)
}

View File

@ -31,6 +31,12 @@ const (
testFQDN = testDomain + "."
)
// Common IP addresses for tests.
var (
testIPv4 = netip.MustParseAddr("1.2.3.4")
testIPv6 = netip.MustParseAddr("1234::cdef")
)
// Common test constants.
const (
ipv4MaskBits = 24

View File

@ -14,14 +14,6 @@ import (
"github.com/stretchr/testify/require"
)
var (
// testIPv4 is an IPv4 for tests.
testIPv4 = netip.MustParseAddr("1.2.3.4")
// testIPv6 is an IPv6 for tests.
testIPv6 = netip.MustParseAddr("::1")
)
func TestConstructor_NewAnswerHTTPS_andSVCB(t *testing.T) {
// Preconditions.

View File

@ -3,6 +3,7 @@ package cache_test
import (
"context"
"net"
"net/netip"
"testing"
"time"
@ -26,6 +27,7 @@ func TestMiddleware_Wrap(t *testing.T) {
defaultTTL uint32 = 3600
)
reqAddr := netip.MustParseAddr("1.2.3.4")
testTTL := 60 * time.Second
aReq := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
@ -44,7 +46,7 @@ func TestMiddleware_Wrap(t *testing.T) {
}{{
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
dnsservertest.NewA(reqHostname, defaultTTL, net.IP{1, 2, 3, 4}),
dnsservertest.NewA(reqHostname, defaultTTL, reqAddr),
}),
name: "simple_a",
wantNumReq: 1,
@ -114,7 +116,7 @@ func TestMiddleware_Wrap(t *testing.T) {
}, {
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
dnsservertest.NewA(reqHostname, 0, net.IP{1, 2, 3, 4}),
dnsservertest.NewA(reqHostname, 0, reqAddr),
}),
name: "expired_one",
wantNumReq: N,
@ -123,7 +125,7 @@ func TestMiddleware_Wrap(t *testing.T) {
}, {
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
dnsservertest.NewA(reqHostname, 10, net.IP{1, 2, 3, 4}),
dnsservertest.NewA(reqHostname, 10, reqAddr),
}),
name: "override_ttl_ok",
wantNumReq: 1,
@ -132,7 +134,7 @@ func TestMiddleware_Wrap(t *testing.T) {
}, {
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
dnsservertest.NewA(reqHostname, 1000, net.IP{1, 2, 3, 4}),
dnsservertest.NewA(reqHostname, 1000, reqAddr),
}),
name: "override_ttl_max",
wantNumReq: 1,
@ -141,7 +143,7 @@ func TestMiddleware_Wrap(t *testing.T) {
}, {
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
dnsservertest.NewA(reqHostname, 0, net.IP{1, 2, 3, 4}),
dnsservertest.NewA(reqHostname, 0, reqAddr),
}),
name: "override_ttl_zero",
wantNumReq: N,
@ -150,7 +152,7 @@ func TestMiddleware_Wrap(t *testing.T) {
}, {
req: aReq,
resp: dnsservertest.NewResp(dns.RcodeServerFailure, aReq, dnsservertest.SectionAnswer{
dnsservertest.NewA(reqHostname, servFailMaxCacheTTL, net.IP{1, 2, 3, 4}),
dnsservertest.NewA(reqHostname, servFailMaxCacheTTL, reqAddr),
}),
name: "override_ttl_servfail",
wantNumReq: 1,

View File

@ -2,6 +2,7 @@ package dnsservertest
import (
"net"
"net/netip"
"testing"
"github.com/AdguardTeam/golibs/testutil"
@ -127,6 +128,38 @@ func NewResp(rcode int, req *dns.Msg, rrs ...RRSection) (resp *dns.Msg) {
return resp
}
// NewA constructs the new resource record of type A. a must be a valid 4-byte
// IPv4-address.
func NewA(name string, ttl uint32, a netip.Addr) (rr dns.RR) {
data := a.As4()
return &dns.A{
Hdr: dns.RR_Header{
Name: dns.Fqdn(name),
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: ttl,
},
A: data[:],
}
}
// NewAAAA constructs the new resource record of type AAAA. aaaa must be a
// valid 16-byte IPv6-address.
func NewAAAA(name string, ttl uint32, aaaa netip.Addr) (rr dns.RR) {
data := aaaa.As16()
return &dns.AAAA{
Hdr: dns.RR_Header{
Name: dns.Fqdn(name),
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: ttl,
},
AAAA: data[:],
}
}
// NewCNAME constructs the new resource record of type CNAME.
func NewCNAME(name string, ttl uint32, target string) (rr dns.RR) {
return &dns.CNAME{
@ -140,39 +173,20 @@ func NewCNAME(name string, ttl uint32, target string) (rr dns.RR) {
}
}
// NewA constructs the new resource record of type A. a must be a valid 4-byte
// IPv4-address.
func NewA(name string, ttl uint32, a net.IP) (rr dns.RR) {
return &dns.A{
Hdr: dns.RR_Header{
Name: dns.Fqdn(name),
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: ttl,
},
A: a,
}
}
// NewAAAA constructs the new resource record of type AAAA. aaaa must be a
// valid 16-byte IPv6-address.
func NewAAAA(name string, ttl uint32, aaaa net.IP) (rr dns.RR) {
return &dns.AAAA{
Hdr: dns.RR_Header{
Name: dns.Fqdn(name),
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: ttl,
},
AAAA: aaaa,
}
}
// NewHTTPS constructs the new resource record of type HTTPS with IPv4 and IPv6
// hint records from provided v4Hint and v6Hint parameters.
//
// TODO(d.kolyshev): Add "alpn" and other SVCB key-value pairs.
func NewHTTPS(name string, ttl uint32, v4Hint, v6Hint []net.IP) (rr dns.RR) {
func NewHTTPS(name string, ttl uint32, v4Hints, v6Hints []netip.Addr) (rr dns.RR) {
v4Hint := &dns.SVCBIPv4Hint{}
for _, ip := range v4Hints {
v4Hint.Hint = append(v4Hint.Hint, ip.AsSlice())
}
v6Hint := &dns.SVCBIPv6Hint{}
for _, ip := range v6Hints {
v6Hint.Hint = append(v6Hint.Hint, ip.AsSlice())
}
svcb := dns.SVCB{
Hdr: dns.RR_Header{
Name: dns.Fqdn(name),
@ -181,10 +195,7 @@ func NewHTTPS(name string, ttl uint32, v4Hint, v6Hint []net.IP) (rr dns.RR) {
Ttl: ttl,
},
Target: dns.Fqdn(name),
Value: []dns.SVCBKeyValue{
&dns.SVCBIPv4Hint{Hint: v4Hint},
&dns.SVCBIPv6Hint{Hint: v6Hint},
},
Value: []dns.SVCBKeyValue{v4Hint, v6Hint},
}
return &dns.HTTPS{
@ -192,6 +203,49 @@ func NewHTTPS(name string, ttl uint32, v4Hint, v6Hint []net.IP) (rr dns.RR) {
}
}
// NewPTR constructs the new resource record of type PTR.
func NewPTR(name string, ttl uint32, target string) (rr dns.RR) {
return &dns.PTR{
Hdr: dns.RR_Header{
Name: dns.Fqdn(name),
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: ttl,
},
Ptr: dns.Fqdn(target),
}
}
// NewSRV constructs the new resource record of type SRV.
func NewSRV(name string, ttl uint32, target string, prio, weight, port uint16) (rr dns.RR) {
return &dns.SRV{
Hdr: dns.RR_Header{
Name: dns.Fqdn(name),
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: ttl,
},
Priority: prio,
Weight: weight,
Port: port,
Target: target,
}
}
// NewTXT constructs the new resource record of type TXT. txts are put into the
// TXT record as is.
func NewTXT(name string, ttl uint32, txts ...string) (rr dns.RR) {
return &dns.TXT{
Hdr: dns.RR_Header{
Name: dns.Fqdn(name),
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: ttl,
},
Txt: txts,
}
}
// NewSOA constructs the new resource record of type SOA.
func NewSOA(name string, ttl uint32, ns, mbox string) (rr dns.RR) {
return &dns.SOA{

View File

@ -2,7 +2,7 @@ package dnsservertest_test
import (
"fmt"
"net"
"net/netip"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/golibs/netutil"
@ -45,7 +45,7 @@ func ExampleNewResp() {
m = dnsservertest.NewResp(dns.RcodeSuccess, m, dnsservertest.SectionAnswer{
dnsservertest.NewCNAME(testFQDN, 3600, realTestFQDN),
dnsservertest.NewA(realTestFQDN, 3600, net.IP{1, 2, 3, 4}),
dnsservertest.NewA(realTestFQDN, 3600, netip.MustParseAddr("1.2.3.4")),
}, dnsservertest.SectionNs{
dnsservertest.NewSOA(realTestFQDN, 1000, "ns."+realTestFQDN, "mbox."+realTestFQDN),
dnsservertest.NewNS(testFQDN, 1000, "ns."+testFQDN),

View File

@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardDNS/internal/dnsserver
go 1.20
require (
github.com/AdguardTeam/golibs v0.13.6
github.com/AdguardTeam/golibs v0.14.0
github.com/ameshkov/dnscrypt/v2 v2.2.7
github.com/ameshkov/dnsstamps v1.0.3
github.com/bluele/gcache v0.0.2
@ -11,11 +11,11 @@ require (
github.com/panjf2000/ants/v2 v2.7.5
github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible
github.com/prometheus/client_golang v1.15.1
github.com/quic-go/quic-go v0.35.1
github.com/quic-go/quic-go v0.38.0
github.com/stretchr/testify v1.8.4
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1
golang.org/x/net v0.12.0
golang.org/x/sys v0.10.0
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
golang.org/x/net v0.14.0
golang.org/x/sys v0.11.0
)
require (
@ -27,21 +27,20 @@ require (
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/golang/mock v1.6.0 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect
github.com/google/pprof v0.0.0-20230821062121-407c9e7a662f // indirect
github.com/kr/text v0.2.0 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/onsi/ginkgo/v2 v2.10.0 // indirect
github.com/onsi/ginkgo/v2 v2.11.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.4.0 // indirect
github.com/prometheus/common v0.44.0 // indirect
github.com/prometheus/procfs v0.10.1 // indirect
github.com/quic-go/qpack v0.4.0 // indirect
github.com/quic-go/qtls-go1-19 v0.3.2 // indirect
github.com/quic-go/qtls-go1-20 v0.2.2 // indirect
golang.org/x/crypto v0.11.0 // indirect
golang.org/x/mod v0.11.0 // indirect
golang.org/x/text v0.11.0 // indirect
golang.org/x/tools v0.10.0 // indirect
github.com/quic-go/qtls-go1-20 v0.3.3 // indirect
golang.org/x/crypto v0.12.0 // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/text v0.12.0 // indirect
golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@ -1,5 +1,5 @@
github.com/AdguardTeam/golibs v0.13.6 h1:z/0Q25pRLdaQxtoxvfSaooz5mdv8wj0R8KREj54q8yQ=
github.com/AdguardTeam/golibs v0.13.6/go.mod h1:hOtcb8dPfKcFjWTPA904hTA4dl1aWvzeebdJpE72IPk=
github.com/AdguardTeam/golibs v0.14.0 h1:/vfJshXBVaevMuBgzAIr+F64XdNqZL+j9F33GXJmgeQ=
github.com/AdguardTeam/golibs v0.14.0/go.mod h1:hOtcb8dPfKcFjWTPA904hTA4dl1aWvzeebdJpE72IPk=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw=
@ -29,8 +29,7 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs=
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
github.com/google/pprof v0.0.0-20230821062121-407c9e7a662f h1:pDhu5sgp8yJlEF/g6osliIIpF9K4F5jvkULXa4daRDQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
@ -38,9 +37,8 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zk
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo=
github.com/miekg/dns v1.1.55/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY=
github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs=
github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE=
github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU=
github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU=
github.com/onsi/gomega v1.27.8 h1:gegWiwZjBsf2DgiSbf5hpokZ98JVDMcWkUiigk6/KXc=
github.com/panjf2000/ants/v2 v2.7.5 h1:/vhh0Hza9G1vP1PdCj9hl6MUzCRbmtcTJL0OsnmytuU=
github.com/panjf2000/ants/v2 v2.7.5/go.mod h1:KIBmYG9QQX5U2qzFP/yQJaq/nSb6rahS9iEHkrCMgM8=
github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible h1:IWzUvJ72xMjmrjR9q3H1PF+jwdN0uNQiR2t1BLNalyo=
@ -57,12 +55,8 @@ github.com/prometheus/procfs v0.10.1 h1:kYK1Va/YMlutzCGazswoHKo//tZVlFpKYh+Pymzi
github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM=
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U=
github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo=
github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g=
github.com/quic-go/qtls-go1-20 v0.3.3 h1:17/glZSLI9P9fDAeyCHBFSWSqJcwx1byhLwP5eUIDCM=
github.com/quic-go/quic-go v0.38.0 h1:T45lASr5q/TrVwt+jrVccmqHhPL2XuSyoCLVCpfOSLc=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
@ -76,18 +70,14 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA=
golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw=
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=
golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50=
golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -98,18 +88,15 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg=
golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM=
golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 h1:Vve/L0v7CXXuxUmaMGIEK/dEeq7uiqb5qBgQrZzIE7E=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@ -9,7 +9,7 @@ import (
// responsePaddingMaxSize is used to calculate the EDNS padding length. We use
// the Random-Length Padding strategy from RFC 8467 as we find it more
// efficient, it requires less extra traffic while provides comparable entropy.
const responsePaddingMaxSize = 128
const responsePaddingMaxSize = 32
// respPadBuf is a fixed buffer to draw on for padding.
var respPadBuf [responsePaddingMaxSize]byte
@ -64,7 +64,7 @@ func normalize(network Network, proto Protocol, req, resp *dns.Msg) {
// In the case of encrypted protocols we should pad responses.
if proto.HasPaddingSupport() {
padAnswer(resp, reqOpt, respOpt)
padAnswer(reqOpt, respOpt)
}
}
@ -123,7 +123,7 @@ func filterUnsupportedOptions(o []dns.EDNS0) (supported []dns.EDNS0) {
// padAnswer adds padding to a DNS response before it's sent back over an
// encrypted DNS protocol according to RFC 8467. Unencrypted responses should
// not be padded. Inspired by github.com/folbricht/routedns padding.
func padAnswer(resp *dns.Msg, reqOpt, respOpt *dns.OPT) {
func padAnswer(reqOpt, respOpt *dns.OPT) {
if findOption[*dns.EDNS0_PADDING](reqOpt) == nil {
// According to the RFC, responders MAY (or may not) pad responses when
// the padding option is not included in the request. In our case, we
@ -146,22 +146,14 @@ func padAnswer(resp *dns.Msg, reqOpt, respOpt *dns.OPT) {
// TODO(ameshkov): Consider changing to crypto/rand, need to hold a vote.
// #nosec G404 -- We don't need a real random for a simple padding
// randomization, pseudo-random is enough.
padLen := rand.Intn(responsePaddingMaxSize-1) + 1
// If padding would make the packet larger than the request EDNS0 allows,
// we need to truncate it.
//
// TODO(ameshkov): Consider removing this check and not calling resp.Len().
// resp.Len() is a rather heavy function which we'd better avoid calling.
// However, we risk having a message larger than 64kB in this case.
answerLen := resp.Len()
if packetSize := int(reqOpt.UDPSize()); answerLen+padLen > packetSize {
padLen = packetSize - answerLen
if padLen < 0 {
// Still doesn't fit? Give up on padding.
padLen = 0
}
}
// Note, that we don't check for whether reqOpt.UDPSize() here is smaller
// than resp.Len() + padLen so in theory the padded response may be larger
// than 64kB. This is an acceptable risk considering the savings on
// avoiding calling resp.Len().
//
// TODO(ameshkov): Return this check if we optimize resp.Len().
padLen := rand.Intn(responsePaddingMaxSize-1) + 1
paddingOpt.Padding = respPadBuf[:padLen:padLen]
}

View File

@ -2,7 +2,6 @@ package prometheus_test
import (
"context"
"net"
"testing"
"time"
@ -28,20 +27,17 @@ func TestCacheMetricsListener_integration_cache(t *testing.T) {
cacheMiddleware,
)
// Pass 10 requests through the middleware
// This way we'll increment and set both hits and misses.
// Pass 10 requests through the middleware. This way we'll increment and
// set both hits and misses.
for i := 0; i < 10; i++ {
req := dnsservertest.CreateMessage("example.org.", dns.TypeA)
addr := &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 53}
nrw := dnsserver.NewNonWriterResponseWriter(addr, addr)
ctx := dnsserver.ContextWithServerInfo(context.Background(), dnsserver.ServerInfo{
Name: "test_server",
Addr: "127.0.0.1:0",
Proto: dnsserver.ProtoDNS,
})
ctx := dnsserver.ContextWithServerInfo(context.Background(), testServerInfo)
ctx = dnsserver.ContextWithStartTime(ctx, time.Now())
ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{})
nrw := dnsserver.NewNonWriterResponseWriter(testUDPAddr, testUDPAddr)
req := dnsservertest.CreateMessage(testReqDomain, dns.TypeA)
err := handlerWithMiddleware.ServeDNS(ctx, nrw, req)
require.NoError(t, err)
dnsservertest.RequireResponse(t, req, nrw.Msg(), 1, dns.RcodeSuccess, false)

View File

@ -29,7 +29,7 @@ func TestForwardMetricsListener_integration_request(t *testing.T) {
// Prepare a test DNS message and call the handler's ServeDNS function.
// It will then call the metrics listener and prom metrics should be
// incremented.
req := dnsservertest.CreateMessage("example.org.", dns.TypeA)
req := dnsservertest.CreateMessage(testReqDomain, dns.TypeA)
rw := dnsserver.NewNonWriterResponseWriter(srv.LocalUDPAddr(), srv.LocalUDPAddr())
err := handler.ServeDNS(context.Background(), rw, req)

View File

@ -1,95 +1,95 @@
package prometheus
import (
"context"
"net"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
)
// counterWithRequestLabels is a helper method that gets or creates a
// [prometheus.Counter] from the specified *prometheus.CounterVec. The point of
// this method is to avoid allocating [prometheus.Labels] and instead use the
// WithLabelValues function. This way extra allocations are avoided, but it is
// sensitive to the labels order.
func counterWithRequestLabels(
serverInfo dnsserver.ServerInfo,
// reqLabelMetricKey contains the information for a request label.
type reqLabelMetricKey struct {
network string
qType string
family string
srvInfo dnsserver.ServerInfo
}
// newReqLabelMetricKey returns a new metric key from the given data.
func newReqLabelMetricKey(
ctx context.Context,
req *dns.Msg,
rw dnsserver.ResponseWriter,
vec *prometheus.CounterVec,
) (c prometheus.Counter) {
ip, _ := netutil.IPAndPortFromAddr(rw.RemoteAddr())
) (k reqLabelMetricKey) {
return reqLabelMetricKey{
network: string(dnsserver.NetworkFromAddr(rw.LocalAddr())),
qType: typeToString(req),
family: raddrToFamily(rw.RemoteAddr()),
srvInfo: dnsserver.MustServerInfoFromContext(ctx),
}
}
// Address family metric.
var family string
if ip == nil {
// Unknown.
family = "0"
} else if ip.To4() != nil {
// IPv4.
family = "1"
} else {
// IPv6.
family = "2"
// withLabelValues returns a counter with the given arguments in the correct
// order.
func (k reqLabelMetricKey) withLabelValues(vec *prometheus.CounterVec) (c prometheus.Counter) {
// The labels must be in the following order:
// 1. server name;
// 2. server protocol;
// 3. server socket network ("tcp"/"udp");
// 4. server addr;
// 5. question type (see [typeToString]);
// 6. IP family (see [raddrToFamily]).
return vec.WithLabelValues(
k.srvInfo.Name,
k.srvInfo.Proto.String(),
k.network,
k.srvInfo.Addr,
k.qType,
k.family,
)
}
// prometheusVector is the interface for vectors of counters, histograms, etc.
type prometheusVector[T any] interface {
WithLabelValues(labelValues ...string) (m T)
}
// withSrvInfoLabelValues returns a metric with the server info data in the
// correct order.
func withSrvInfoLabelValues[T any](
vec prometheusVector[T],
srvInfo dnsserver.ServerInfo,
) (m T) {
// The labels must be in the following order:
// 1. server name;
// 2. server protocol;
// 3. server addr;
return vec.WithLabelValues(
srvInfo.Name,
srvInfo.Proto.String(),
srvInfo.Addr,
)
}
// raddrToFamily returns a family metric value for raddr.
// The values are:
//
// 0. Unknown.
// 1. IPv4.
// 2. IPv6.
func raddrToFamily(raddr net.Addr) (family string) {
ip := netutil.NetAddrToAddrPort(raddr).Addr()
if !ip.IsValid() {
return "0"
} else if ip.Is4() {
return "1"
}
// The metric's labels MUST be in the following order:
// "name", "proto", "network", "addr", "type", "family"
return vec.WithLabelValues(
serverInfo.Name,
serverInfo.Proto.String(),
string(dnsserver.NetworkFromAddr(rw.LocalAddr())),
serverInfo.Addr,
typeToString(req),
family,
)
}
// counterWithRequestLabels is a helper method that gets or creates a
// [prometheus.Counter] from the specified *prometheus.CounterVec. The point of
// this method is to avoid allocating [prometheus.Labels] and instead use the
// WithLabelValues function. This way extra allocations are avoided, but it is
// sensitive to the labels order.
func counterWithServerLabels(
serverInfo dnsserver.ServerInfo,
vec *prometheus.CounterVec,
) (c prometheus.Counter) {
// The metric's labels MUST be in the following order:
// "name", "proto", "addr"
return vec.WithLabelValues(
serverInfo.Name,
serverInfo.Proto.String(),
serverInfo.Addr,
)
}
// histogramWithServerLabels is a helper method that gets or creates a
// [prometheus.Observer] from the specified *prometheus.HistogramVec. The point
// of this method is to avoid allocating [prometheus.Labels] and instead use the
// WithLabelValues function. This way extra allocations are avoided, but it is
// sensitive to the labels order.
func histogramWithServerLabels(
serverInfo dnsserver.ServerInfo,
vec *prometheus.HistogramVec,
) (h prometheus.Observer) {
// The metric's labels MUST be in the following order:
// "name", "proto", "addr"
return vec.WithLabelValues(serverInfo.Name, serverInfo.Proto.String(), serverInfo.Addr)
}
// counterWithServerLabelsPlusRCode is a helper method that gets or creates a
// [prometheus.Counter] from the specified *prometheus.CounterVec. The point of
// this method is to avoid allocating [prometheus.Labels] and instead use the
// WithLabelValues function. This way extra allocations are avoided, but it is
// sensitive to the labels order.
func counterWithServerLabelsPlusRCode(
serverInfo dnsserver.ServerInfo,
rCode string,
vec *prometheus.CounterVec,
) (c prometheus.Counter) {
// The metric's labels MUST be in the following order:
// "name", "proto", "addr", "rcode"
return vec.WithLabelValues(serverInfo.Name, serverInfo.Proto.String(), serverInfo.Addr, rCode)
return "2"
}
// setBoolGauge sets gauge to the numeric value corresponding to the val.

View File

@ -0,0 +1,51 @@
package prometheus
import "sync"
// initSyncMap is a wrapper around [*sync.Map] that initializes the data if it's
// not present in an atomic way.
//
// TODO(a.garipov): Move to golibs and use more.
type initSyncMap[K, V any] struct {
inner *sync.Map
new func(k K) (v V)
}
// newInitSyncMap returns a new properly initialized *initSyncMap that uses
// newFunc to return a value for the given key.
func newInitSyncMap[K, V any](newFunc func(k K) (v V)) (m *initSyncMap[K, V]) {
return &initSyncMap[K, V]{
inner: &sync.Map{},
new: newFunc,
}
}
// get returns a value for the given key. If a value isn't available, it waits
// until it is.
func (m *initSyncMap[K, V]) get(key K) (v V) {
// Step 1. The fast track: check if there is already a value present.
loadVal, inited := m.inner.Load(key)
if inited {
return loadVal.(func() (v V))()
}
// Step 2. Allocate a done channel and create a function that waits for one
// single initialization. Use the one returned from LoadOrStore regardless
// of whether it's this one.
var cached V
done := make(chan struct{}, 1)
done <- struct{}{}
loadVal, _ = m.inner.LoadOrStore(key, func() (loaded V) {
_, ok := <-done
if ok {
// The only real receive. Initialize the cached value and close the
// channel so that other goroutines receive the same value.
cached = m.new(key)
close(done)
}
return cached
})
return loadVal.(func() (v V))()
}

View File

@ -0,0 +1,41 @@
package prometheus
import (
"sync/atomic"
"testing"
"time"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
)
func TestInitSyncMap(t *testing.T) {
numCalls := atomic.Uint32{}
m := newInitSyncMap[int, int](func(k int) (v int) {
numCalls.Add(1)
return k + 1
})
const (
n = 1_000
key = 1
want = key + 1
)
results := make(chan int, n)
for i := 0; i < n; i++ {
go func() {
results <- m.get(key)
}()
}
for i := 0; i < n; i++ {
got, _ := testutil.RequireReceive(t, results, 1*time.Second)
assert.Equal(t, want, got)
}
assert.Equal(t, uint32(1), numCalls.Load())
}

View File

@ -1,8 +1,10 @@
package prometheus_test
import (
"net"
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/golibs/testutil"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/require"
@ -12,6 +14,22 @@ func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// testReqDomain is the common request domain for tests.
const testReqDomain = "request-domain.example"
// testServerInfo is the common server information structure for tests.
var testServerInfo = dnsserver.ServerInfo{
Name: "test_server",
Addr: "127.0.0.1:80",
Proto: dnsserver.ProtoDNS,
}
// testUDPAddr is the common UDP address for tests.
var testUDPAddr = &net.UDPAddr{
IP: net.IP{1, 2, 3, 4},
Port: 53,
}
// requireMetrics accepts a list of metrics names and checks that
// they exist in the prom registry.
func requireMetrics(t testing.TB, args ...string) {
@ -34,6 +52,5 @@ func requireMetrics(t testing.TB, args ...string) {
delete(metricsToCheck, m.GetName())
}
require.Len(t, metricsToCheck, 0,
"Some metrics weren't reported: %v", metricsToCheck)
require.Len(t, metricsToCheck, 0, "Some metrics weren't reported: %v", metricsToCheck)
}

View File

@ -10,33 +10,47 @@ import (
"github.com/prometheus/client_golang/prometheus/promauto"
)
// RateLimitMetricsListener implements the ratelimit.MetricsListener interface
// RateLimitMetricsListener implements the [ratelimit.MetricsListener] interface
// and increments prom counters.
type RateLimitMetricsListener struct{}
type RateLimitMetricsListener struct {
dropCounters *initSyncMap[reqLabelMetricKey, prometheus.Counter]
allowlistedCounters *initSyncMap[reqLabelMetricKey, prometheus.Counter]
}
// NewRateLimitMetricsListener returns a new properly initialized
// *RateLimitMetricsListener.
func NewRateLimitMetricsListener() (l *RateLimitMetricsListener) {
return &RateLimitMetricsListener{
dropCounters: newInitSyncMap(func(k reqLabelMetricKey) (c prometheus.Counter) {
return k.withLabelValues(droppedTotal)
}),
allowlistedCounters: newInitSyncMap(func(k reqLabelMetricKey) (c prometheus.Counter) {
return k.withLabelValues(allowlistedTotal)
}),
}
}
// type check
var _ ratelimit.MetricsListener = (*RateLimitMetricsListener)(nil)
// OnRateLimited implements the ratelimit.MetricsListener interface for
// *RateLimitMetricsListener.
func (r *RateLimitMetricsListener) OnRateLimited(
func (l *RateLimitMetricsListener) OnRateLimited(
ctx context.Context,
req *dns.Msg,
rw dnsserver.ResponseWriter,
) {
s := dnsserver.MustServerInfoFromContext(ctx)
counterWithRequestLabels(s, req, rw, droppedTotal).Inc()
l.dropCounters.get(newReqLabelMetricKey(ctx, req, rw)).Inc()
}
// OnAllowlisted implements the ratelimit.MetricsListener interface for
// *RateLimitMetricsListener.
func (r *RateLimitMetricsListener) OnAllowlisted(
func (l *RateLimitMetricsListener) OnAllowlisted(
ctx context.Context,
req *dns.Msg,
rw dnsserver.ResponseWriter,
) {
s := dnsserver.MustServerInfoFromContext(ctx)
counterWithRequestLabels(s, req, rw, allowlistedTotal).Inc()
l.allowlistedCounters.get(newReqLabelMetricKey(ctx, req, rw)).Inc()
}
// This block contains prometheus metrics declarations for ratelimit.Middleware

View File

@ -2,7 +2,6 @@ package prometheus_test
import (
"context"
"net"
"net/netip"
"testing"
"time"
@ -33,7 +32,7 @@ func TestRateLimiterMetricsListener_integration_cache(t *testing.T) {
})
rlMw, err := ratelimit.NewMiddleware(rl, nil)
require.NoError(t, err)
rlMw.Metrics = &prometheus.RateLimitMetricsListener{}
rlMw.Metrics = prometheus.NewRateLimitMetricsListener()
handlerWithMiddleware := dnsserver.WithMiddlewares(
dnsservertest.DefaultHandler(),
@ -42,17 +41,14 @@ func TestRateLimiterMetricsListener_integration_cache(t *testing.T) {
// Pass 10 requests through the middleware.
for i := 0; i < 10; i++ {
req := dnsservertest.CreateMessage("example.org.", dns.TypeA)
addr := &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 53}
nrw := dnsserver.NewNonWriterResponseWriter(addr, addr)
ctx := dnsserver.ContextWithServerInfo(context.Background(), dnsserver.ServerInfo{
Name: "test",
Addr: "127.0.0.1",
Proto: dnsserver.ProtoDNS,
})
ctx := dnsserver.ContextWithServerInfo(context.Background(), testServerInfo)
ctx = dnsserver.ContextWithStartTime(ctx, time.Now())
ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{})
nrw := dnsserver.NewNonWriterResponseWriter(testUDPAddr, testUDPAddr)
req := dnsservertest.CreateMessage(testReqDomain, dns.TypeA)
err = handlerWithMiddleware.ServeDNS(ctx, nrw, req)
require.NoError(t, err)
if i < rps {
@ -65,3 +61,35 @@ func TestRateLimiterMetricsListener_integration_cache(t *testing.T) {
// Now make sure that prometheus metrics were incremented properly.
requireMetrics(t, "dns_ratelimit_dropped_total")
}
func BenchmarkRateLimitMetricsListener(b *testing.B) {
l := prometheus.NewRateLimitMetricsListener()
ctx := dnsserver.ContextWithServerInfo(context.Background(), testServerInfo)
req := dnsservertest.CreateMessage(testReqDomain, dns.TypeA)
rw := dnsserver.NewNonWriterResponseWriter(testUDPAddr, testUDPAddr)
b.Run("OnAllowlisted", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
l.OnAllowlisted(ctx, req, rw)
}
})
b.Run("OnRateLimited", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
l.OnRateLimited(ctx, req, rw)
}
})
// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
// goos: linux
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
// BenchmarkRateLimitMetricsListener/OnAllowlisted-16 6025423 209.5 ns/op 0 B/op 0 allocs/op
// BenchmarkRateLimitMetricsListener/OnRateLimited-16 5798031 209.4 ns/op 0 B/op 0 allocs/op
}

View File

@ -12,7 +12,80 @@ import (
// ServerMetricsListener implements the [dnsserver.MetricsListener] interface
// and increments prom counters.
type ServerMetricsListener struct{}
type ServerMetricsListener struct {
reqTotalCounters *initSyncMap[reqLabelMetricKey, prometheus.Counter]
respRCodeCounters *initSyncMap[srvInfoRCode, prometheus.Counter]
invalidMsgCounters *initSyncMap[dnsserver.ServerInfo, prometheus.Counter]
errorCounters *initSyncMap[dnsserver.ServerInfo, prometheus.Counter]
panicCounters *initSyncMap[dnsserver.ServerInfo, prometheus.Counter]
reqDurationHistograms *initSyncMap[dnsserver.ServerInfo, prometheus.Observer]
reqSizeHistograms *initSyncMap[dnsserver.ServerInfo, prometheus.Observer]
respSizeHistograms *initSyncMap[dnsserver.ServerInfo, prometheus.Observer]
}
// srvInfoRCode is a struct containing the server information along with a
// response code.
type srvInfoRCode struct {
rCode string
dnsserver.ServerInfo
}
// withLabelValues returns a counter with the server info and rcode data in the
// correct order.
func (i srvInfoRCode) withLabelValues(
vec *prometheus.CounterVec,
) (c prometheus.Counter) {
// The labels must be in the following order:
// 1. server name;
// 2. server protocol;
// 3. server addr;
// 4. response code;
return vec.WithLabelValues(
i.Name,
i.Proto.String(),
i.Addr,
i.rCode,
)
}
// NewServerMetricsListener returns a new properly initialized
// *ServerMetricsListener.
func NewServerMetricsListener() (l *ServerMetricsListener) {
return &ServerMetricsListener{
reqTotalCounters: newInitSyncMap(func(k reqLabelMetricKey) (c prometheus.Counter) {
return k.withLabelValues(requestTotal)
}),
respRCodeCounters: newInitSyncMap(func(k srvInfoRCode) (c prometheus.Counter) {
return k.withLabelValues(responseRCode)
}),
invalidMsgCounters: newInitSyncMap(func(k dnsserver.ServerInfo) (c prometheus.Counter) {
// TODO(a.garipov): Here and below, remove explicit type
// declarations in Go 1.21.
return withSrvInfoLabelValues[prometheus.Counter](invalidMsgTotal, k)
}),
errorCounters: newInitSyncMap(func(k dnsserver.ServerInfo) (c prometheus.Counter) {
return withSrvInfoLabelValues[prometheus.Counter](errorTotal, k)
}),
panicCounters: newInitSyncMap(func(k dnsserver.ServerInfo) (c prometheus.Counter) {
return withSrvInfoLabelValues[prometheus.Counter](panicTotal, k)
}),
reqDurationHistograms: newInitSyncMap(func(k dnsserver.ServerInfo) (o prometheus.Observer) {
return withSrvInfoLabelValues[prometheus.Observer](requestDuration, k)
}),
reqSizeHistograms: newInitSyncMap(func(k dnsserver.ServerInfo) (o prometheus.Observer) {
return withSrvInfoLabelValues[prometheus.Observer](requestSize, k)
}),
respSizeHistograms: newInitSyncMap(func(k dnsserver.ServerInfo) (o prometheus.Observer) {
return withSrvInfoLabelValues[prometheus.Observer](responseSize, k)
}),
}
}
// type check
var _ dnsserver.MetricsListener = (*ServerMetricsListener)(nil)
@ -21,51 +94,57 @@ var _ dnsserver.MetricsListener = (*ServerMetricsListener)(nil)
// [*ServerMetricsListener].
func (l *ServerMetricsListener) OnRequest(
ctx context.Context,
req, resp *dns.Msg,
req *dns.Msg,
resp *dns.Msg,
rw dnsserver.ResponseWriter,
) {
serverInfo := dnsserver.MustServerInfoFromContext(ctx)
startTime := dnsserver.MustStartTimeFromContext(ctx)
// Increment total requests count metrics.
counterWithRequestLabels(serverInfo, req, rw, requestTotal).Inc()
l.reqTotalCounters.get(newReqLabelMetricKey(ctx, req, rw)).Inc()
// Increment request duration histogram.
elapsed := time.Since(startTime).Seconds()
histogramWithServerLabels(serverInfo, requestDuration).Observe(elapsed)
l.reqDurationHistograms.get(serverInfo).Observe(elapsed)
// Increment request size.
ri := dnsserver.MustRequestInfoFromContext(ctx)
histogramWithServerLabels(serverInfo, requestSize).Observe(float64(ri.RequestSize))
l.reqSizeHistograms.get(serverInfo).Observe(float64(ri.RequestSize))
// If resp is not nil, increment response-related metrics.
if resp != nil {
histogramWithServerLabels(serverInfo, responseSize).Observe(float64(ri.ResponseSize))
rCode := rCodeToString(resp.Rcode)
counterWithServerLabelsPlusRCode(serverInfo, rCode, responseRCode).Inc()
l.respSizeHistograms.get(serverInfo).Observe(float64(ri.ResponseSize))
l.respRCodeCounters.get(srvInfoRCode{
ServerInfo: serverInfo,
rCode: rCodeToString(resp.Rcode),
}).Inc()
} else {
// If resp is nil, increment responseRCode with a special "rcode"
// label value ("DROPPED").
counterWithServerLabelsPlusRCode(serverInfo, "DROPPED", responseRCode).Inc()
// If resp is nil, increment responseRCode with a special "rcode" label
// value ("DROPPED").
l.respRCodeCounters.get(srvInfoRCode{
ServerInfo: serverInfo,
rCode: "DROPPED",
}).Inc()
}
}
// OnInvalidMsg implements the [dnsserver.MetricsListener] interface for
// [*ServerMetricsListener].
func (l *ServerMetricsListener) OnInvalidMsg(ctx context.Context) {
counterWithServerLabels(dnsserver.MustServerInfoFromContext(ctx), invalidMsgTotal).Inc()
l.invalidMsgCounters.get(dnsserver.MustServerInfoFromContext(ctx)).Inc()
}
// OnError implements the [dnsserver.MetricsListener] interface for
// [*ServerMetricsListener].
func (l *ServerMetricsListener) OnError(ctx context.Context, _ error) {
counterWithServerLabels(dnsserver.MustServerInfoFromContext(ctx), errorTotal).Inc()
l.errorCounters.get(dnsserver.MustServerInfoFromContext(ctx)).Inc()
}
// OnPanic implements the [dnsserver.MetricsListener] interface for
// [*ServerMetricsListener].
func (l *ServerMetricsListener) OnPanic(ctx context.Context, _ any) {
counterWithServerLabels(dnsserver.MustServerInfoFromContext(ctx), panicTotal).Inc()
l.panicCounters.get(dnsserver.MustServerInfoFromContext(ctx)).Inc()
}
// OnQUICAddressValidation implements the [dnsserver.MetricsListener] interface

View File

@ -3,10 +3,11 @@ package prometheus_test
import (
"context"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
prom "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
@ -24,7 +25,7 @@ func TestServerMetricsListener_integration_requestLifetime(t *testing.T) {
Name: "test",
Addr: "127.0.0.1:0",
Handler: dnsservertest.DefaultHandler(),
Metrics: &prom.ServerMetricsListener{},
Metrics: prometheus.NewServerMetricsListener(),
},
}
srv := dnsserver.NewServerDNS(conf)
@ -39,7 +40,7 @@ func TestServerMetricsListener_integration_requestLifetime(t *testing.T) {
})
// Create a test message.
req := dnsservertest.CreateMessage("example.org", dns.TypeA)
req := dnsservertest.CreateMessage(testReqDomain, dns.TypeA)
c := &dns.Client{Net: "tcp"}
@ -48,8 +49,8 @@ func TestServerMetricsListener_integration_requestLifetime(t *testing.T) {
// Pass 10 requests to make the test less flaky.
for i := 0; i < 10; i++ {
res, _, eerr := c.Exchange(req, addr)
require.NoError(t, eerr)
res, _, exchErr := c.Exchange(req, addr)
require.NoError(t, exchErr)
require.NotNil(t, res)
require.Equal(t, dns.RcodeSuccess, res.Rcode)
}
@ -64,3 +65,61 @@ func TestServerMetricsListener_integration_requestLifetime(t *testing.T) {
"dns_server_response_rcode_total",
)
}
func BenchmarkServerMetricsListener(b *testing.B) {
l := prometheus.NewServerMetricsListener()
ctx := dnsserver.ContextWithServerInfo(context.Background(), testServerInfo)
ctx = dnsserver.ContextWithStartTime(ctx, time.Now())
req := dnsservertest.CreateMessage(testReqDomain, dns.TypeA)
resp := (&dns.Msg{}).SetRcode(req, dns.RcodeSuccess)
ctx = dnsserver.ContextWithRequestInfo(ctx, dnsserver.RequestInfo{
RequestSize: req.Len(),
ResponseSize: resp.Len(),
})
rw := dnsserver.NewNonWriterResponseWriter(testUDPAddr, testUDPAddr)
b.Run("OnRequest", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
l.OnRequest(ctx, req, resp, rw)
}
})
b.Run("OnInvalidMsg", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
l.OnInvalidMsg(ctx)
}
})
b.Run("OnError", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
l.OnError(ctx, nil)
}
})
b.Run("OnPanic", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
l.OnPanic(ctx, nil)
}
})
// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
// goos: linux
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
// BenchmarkServerMetricsListener/OnRequest-16 1550391 716.7 ns/op 0 B/op 0 allocs/op
// BenchmarkServerMetricsListener/OnInvalidMsg-16 13041940 91.75 ns/op 0 B/op 0 allocs/op
// BenchmarkServerMetricsListener/OnError-16 12297494 97.04 ns/op 0 B/op 0 allocs/op
// BenchmarkServerMetricsListener/OnPanic-16 14029394 89.19 ns/op 0 B/op 0 allocs/op
}

View File

@ -19,8 +19,8 @@ type ConfigBase struct {
// Name is used for logging, and it may be used for perf counters reporting.
Name string
// Addr is the address the server listens to. See go doc net.Dial for
// the documentation on the address format.
// Addr is the address the server listens to. See [net.Dial] for the
// documentation on the address format.
Addr string
// Network is the network this server listens to. If empty, the server will

View File

@ -8,6 +8,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/dnssvctest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@ -47,8 +48,8 @@ func TestService_writeDebugResponse(t *testing.T) {
blockRule = "||example.com^"
)
clientIPStr := testClientIP.String()
serverIPStr := testServerAddr.String()
clientIPStr := dnssvctest.ClientIP.String()
serverIPStr := dnssvctest.ServerAddr.String()
testCases := []struct {
name string
ri *agd.RequestInfo
@ -129,26 +130,26 @@ func TestService_writeDebugResponse(t *testing.T) {
}),
}, {
name: "device",
ri: &agd.RequestInfo{Device: &agd.Device{ID: testDeviceID}},
ri: &agd.RequestInfo{Device: &agd.Device{ID: dnssvctest.DeviceID}},
reqRes: nil,
respRes: nil,
wantExtra: newTXTExtra([][2]string{
{"client-ip.adguard-dns.com.", clientIPStr},
{"server-ip.adguard-dns.com.", serverIPStr},
{"device-id.adguard-dns.com.", testDeviceID},
{"device-id.adguard-dns.com.", dnssvctest.DeviceIDStr},
{"resp.res-type.adguard-dns.com.", "normal"},
}),
}, {
name: "profile",
ri: &agd.RequestInfo{
Profile: &agd.Profile{ID: testProfileID},
Profile: &agd.Profile{ID: dnssvctest.ProfileID},
},
reqRes: nil,
respRes: nil,
wantExtra: newTXTExtra([][2]string{
{"client-ip.adguard-dns.com.", clientIPStr},
{"server-ip.adguard-dns.com.", serverIPStr},
{"profile-id.adguard-dns.com.", testProfileID},
{"profile-id.adguard-dns.com.", dnssvctest.ProfileIDStr},
{"resp.res-type.adguard-dns.com.", "normal"},
}),
}, {
@ -182,7 +183,7 @@ func TestService_writeDebugResponse(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
rw := dnsserver.NewNonWriterResponseWriter(testLocalAddr, testRAddr)
rw := dnsserver.NewNonWriterResponseWriter(dnssvctest.LocalAddr, dnssvctest.RemoteAddr)
ctx := agd.ContextWithRequestInfo(context.Background(), tc.ri)

View File

@ -10,6 +10,7 @@ import (
"net/http"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/access"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
"github.com/AdguardTeam/AdGuardDNS/internal/connlimiter"
@ -20,6 +21,8 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/accessmw"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/initial"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
@ -48,6 +51,9 @@ type Config struct {
// active stream-connections.
ConnLimiter *connlimiter.Limiter
// AccessManager is used to block requests.
AccessManager access.Interface
// SafeBrowsing is the safe browsing TXT hash matcher.
SafeBrowsing filter.HashMatcher
@ -397,7 +403,7 @@ func NewListener(
metricsListener := &errCollMetricsListener{
errColl: errColl,
baseListener: &prometheus.ServerMetricsListener{},
baseListener: prometheus.NewServerMetricsListener(),
}
confBase := dnsserver.ConfigBase{
@ -473,61 +479,47 @@ func newServers(
rlProtos := []agd.Protocol{agd.ProtoDNS}
var rlm *ratelimit.Middleware
srvName := s.Name
rlm, err = ratelimit.NewMiddleware(c.RateLimit, rlProtos)
if err != nil {
return nil, fmt.Errorf("ratelimit: %w", err)
return nil, fmt.Errorf("server %q: ratelimit: %w", srvName, err)
}
rlm.Metrics = &prometheus.RateLimitMetricsListener{}
rlm.Metrics = prometheus.NewRateLimitMetricsListener()
imw := &initMw{
messages: c.Messages,
fltGrp: fg,
srvGrp: srvGrp,
srv: s,
db: c.ProfileDB,
geoIP: c.GeoIP,
errColl: c.ErrColl,
}
amw := accessmw.New(&accessmw.Config{
AccessManager: c.AccessManager,
})
imw := initial.New(&initial.Config{
Messages: c.Messages,
FilteringGroup: fg,
ServerGroup: srvGrp,
Server: s,
ProfileDB: c.ProfileDB,
GeoIP: c.GeoIP,
ErrColl: c.ErrColl,
})
h := dnsserver.WithMiddlewares(
handler,
// Keep the rate limiting middleware as the outer one to make sure
// that the application logic isn't touched if the request is
// ratelimited.
// Keep the rate limiting and access middlewares as the outer ones
// to make sure that the application logic isn't touched if the
// request is ratelimited or blocked by access settings.
rlm,
amw,
imw,
)
listeners := make([]*listener, 0, len(s.BindData))
for _, bindData := range s.BindData {
addr := bindData.Address
if addr == "" {
addr = bindData.AddrPort.String()
}
name := listenerName(s.Name, addr, s.Protocol)
lc := bindData.ListenConfig
if lc == nil {
lc = newListenConfig(c.ControlConf, c.ConnLimiter, s.Protocol)
}
var l Listener
l, err = newListener(s, name, addr, h, c.NonDNS, c.ErrColl, lc)
var listeners []*listener
listeners, err = newListeners(c, s, h, newListener)
if err != nil {
return nil, fmt.Errorf("server %q: %w", s.Name, err)
}
listeners = append(listeners, &listener{
name: name,
Listener: l,
})
return nil, fmt.Errorf("server %q: %w", srvName, err)
}
servers[i] = &server{
name: s.Name,
name: srvName,
handler: h,
listeners: listeners,
}
@ -536,16 +528,59 @@ func newServers(
return servers, nil
}
// newServers creates a slice of listeners for a server.
func newListeners(
c *Config,
srv *agd.Server,
handler dnsserver.Handler,
newListener NewListenerFunc,
) (listeners []*listener, err error) {
listeners = make([]*listener, 0, len(srv.BindData))
for i, bindData := range srv.BindData {
addr := bindData.Address
if addr == "" {
addr = bindData.AddrPort.String()
}
proto := srv.Protocol
name := listenerName(srv.Name, addr, proto)
lc := newListenConfig(bindData.ListenConfig, c.ControlConf, c.ConnLimiter, proto)
var l Listener
l, err = newListener(srv, name, addr, handler, c.NonDNS, c.ErrColl, lc)
if err != nil {
return nil, fmt.Errorf("bind data at index %d: %w", i, err)
}
listeners = append(listeners, &listener{
name: name,
Listener: l,
})
}
return listeners, nil
}
// newListenConfig returns the netext.ListenConfig used by the plain-DNS
// servers. The resulting ListenConfig sets additional socket flags and
// processes the control messages of connections created with ListenPacket.
// Additionally, if l is not nil, it is used to limit the number of
// simultaneously active stream-connections.
func newListenConfig(
original netext.ListenConfig,
ctrlConf *netext.ControlConfig,
l *connlimiter.Limiter,
p agd.Protocol,
) (lc netext.ListenConfig) {
if original != nil {
if l == nil {
return original
}
return connlimiter.NewListenConfig(original, l)
}
if p == agd.ProtoDNS {
lc = netext.DefaultListenConfigWithOOB(ctrlConf)
} else {

View File

@ -1,30 +0,0 @@
package dnssvc
import (
"net"
"net/netip"
)
// Common addresses for tests.
var (
testClientIP = net.IP{1, 2, 3, 4}
testRAddr = &net.TCPAddr{
IP: testClientIP,
Port: 12345,
}
testClientAddrPort = testRAddr.AddrPort()
testClientAddr = testClientAddrPort.Addr()
testServerAddr = netip.MustParseAddr("5.6.7.8")
testLocalAddr = &net.TCPAddr{
IP: testServerAddr.AsSlice(),
Port: 54321,
}
)
// testDeviceID is the common device ID for tests
const testDeviceID = "dev1234"
// testProfileID is the common profile ID for tests
const testProfileID = "prof1234"

View File

@ -1,334 +0,0 @@
package dnssvc
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
// The Initial Middleware
// initMw is the outermost middleware of the AdGuard DNS server. It filters out
// the Firefox canary domain logic, sets and resets the AD bit for further
// processing, as well as puts as much information as it can into the context.
//
// This middleware must be the most outer middleware apart from the ratelimit
// one.
//
// TODO(a.garipov): Add tests.
type initMw struct {
// messages is used to build the responses specific for the request's
// context.
messages *dnsmsg.Constructor
// fltGrp is the filtering group to which srv belongs.
fltGrp *agd.FilteringGroup
// srvGrp is the server group to which srv belongs.
srvGrp *agd.ServerGroup
// srv is the current server which serves the request.
srv *agd.Server
// db is the database of user profiles and devices.
db profiledb.Interface
// geoIP detects the location of the request source.
geoIP geoip.Interface
// errColl collects and reports the errors considered non-critical.
errColl agd.ErrorCollector
}
// type check
var _ dnsserver.Middleware = (*initMw)(nil)
// Wrap implements the [dnsserver.Middleware] interface for *initMw.
func (mw *initMw) Wrap(h dnsserver.Handler) (wrapped dnsserver.Handler) {
return &initMwHandler{
mw: mw,
next: h,
}
}
// newRequestInfo returns the new request information structure using the
// middleware's configuration and values from ctx.
func (mw *initMw) newRequestInfo(
ctx context.Context,
req *dns.Msg,
laddr net.Addr,
raddr net.Addr,
fqdn string,
qt dnsmsg.RRType,
cl dnsmsg.Class,
) (ri *agd.RequestInfo, err error) {
// Put the host, server, and client IP data into the request information
// immediately.
ri = &agd.RequestInfo{
FilteringGroup: mw.fltGrp,
Messages: mw.messages,
ServerGroup: mw.srvGrp.Name,
Server: mw.srv.Name,
Host: strings.TrimSuffix(fqdn, "."),
QType: qt,
QClass: cl,
}
ri.RemoteIP = netutil.NetAddrToAddrPort(raddr).Addr()
// As an optimization, put the request ID closer to the top of the context
// stack.
ri.ID, _ = agd.RequestIDFromContext(ctx)
// Add the GeoIP information, if any.
err = mw.addLocation(ctx, ri, req)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
// Add the profile information, if any.
localIP := netutil.NetAddrToAddrPort(laddr).Addr()
err = mw.addProfile(ctx, ri, req, localIP)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
return ri, nil
}
// addLocation adds GeoIP location information about the client's remote address
// as well as the EDNS Client Subnet information, if there is one, to ri. err
// is not nil only if req contains a malformed EDNS Client Subnet option.
func (mw *initMw) addLocation(ctx context.Context, ri *agd.RequestInfo, req *dns.Msg) (err error) {
ri.Location = mw.locationData(ctx, ri.RemoteIP, "client")
ecs, scope, err := dnsmsg.ECSFromMsg(req)
if err != nil {
return fmt.Errorf("adding ecs info: %w", err)
} else if ecs != (netip.Prefix{}) {
ri.ECS = &agd.ECS{
Location: mw.locationData(ctx, ecs.Addr(), "ecs"),
Subnet: ecs,
Scope: scope,
}
}
return nil
}
// locationData returns the GeoIP location information about the IP address.
// typ is the type of data being requested for error reporting and logging.
func (mw *initMw) locationData(ctx context.Context, ip netip.Addr, typ string) (l *agd.Location) {
l, err := mw.geoIP.Data("", ip)
if err != nil {
// Consider GeoIP errors non-critical. Report and go on.
agd.Collectf(ctx, mw.errColl, "init mw: getting geoip for %s ip: %w", typ, err)
}
if l == nil {
optlog.Debug2("init mw: no geoip for %s ip %s", typ, ip)
} else {
optlog.Debug4("init mw: found country/asn %q/%d for %s ip %s", l.Country, l.ASN, typ, ip)
}
return l
}
// addProfile adds profile and device information, if any, to the request
// information.
func (mw *initMw) addProfile(
ctx context.Context,
ri *agd.RequestInfo,
req *dns.Msg,
localIP netip.Addr,
) (err error) {
defer func() { err = errors.Annotate(err, "getting profile from req: %w") }()
var id agd.DeviceID
if p := mw.srv.Protocol; p.IsStdEncrypted() {
// Assume that mw.srvGrp.TLS is non-nil if p.IsStdEncrypted() is true.
wildcards := mw.srvGrp.TLS.DeviceIDWildcards
id, err = deviceIDFromContext(ctx, mw.srv.Protocol, wildcards)
} else if p == agd.ProtoDNS {
id, err = deviceIDFromEDNS(req)
} else {
// No DeviceID for DNSCrypt yet.
return nil
}
if err != nil {
return err
}
optlog.Debug3("init mw: got device id %q, raddr %s, and laddr %s", id, ri.RemoteIP, localIP)
prof, dev, byWhat, err := mw.profile(ctx, localIP, ri.RemoteIP, id)
if err != nil {
if !errors.Is(err, profiledb.ErrDeviceNotFound) {
// Very unlikely, since those two error types are the only ones
// currently returned from the default profile DB.
return fmt.Errorf("unexpected profiledb error: %s", err)
}
optlog.Debug1("init mw: profile or device not found: %s", err)
} else if prof.Deleted {
optlog.Debug1("init mw: profile %s is deleted", prof.ID)
} else {
optlog.Debug3("init mw: found profile %s and device %s by %s", prof.ID, dev.ID, byWhat)
ri.Device, ri.Profile = dev, prof
ri.Messages = dnsmsg.NewConstructor(prof.BlockingMode.Mode, prof.FilteredResponseTTL)
}
return nil
}
// Constants for the parameter by which a device has been found.
const (
byDeviceID = "device id"
byDedicatedIP = "dedicated ip"
byLinkedIP = "linked ip"
)
// profile finds the profile by the client data.
func (mw *initMw) profile(
ctx context.Context,
localIP netip.Addr,
remoteIP netip.Addr,
id agd.DeviceID,
) (prof *agd.Profile, dev *agd.Device, byWhat string, err error) {
if id != "" {
prof, dev, err = mw.db.ProfileByDeviceID(ctx, id)
if err != nil {
return nil, nil, "", err
}
return prof, dev, byDeviceID, nil
}
if !mw.srv.LinkedIPEnabled {
optlog.Debug1("init mw: not matching by linked or dedicated ip for server %s", mw.srv.Name)
return nil, nil, "", profiledb.ErrDeviceNotFound
} else if p := mw.srv.Protocol; p != agd.ProtoDNS {
optlog.Debug1("init mw: not matching by linked or dedicated ip for proto %v", p)
return nil, nil, "", profiledb.ErrDeviceNotFound
}
byWhat = byDedicatedIP
prof, dev, err = mw.db.ProfileByDedicatedIP(ctx, localIP)
if errors.Is(err, profiledb.ErrDeviceNotFound) {
byWhat = byLinkedIP
prof, dev, err = mw.db.ProfileByLinkedIP(ctx, remoteIP)
}
if err != nil {
return nil, nil, "", err
}
return prof, dev, byWhat, nil
}
// initMwHandler implements the [dnsserver.Handler] interface and will be used
// as a [dnsserver.Handler] that the initMw middleware returns from the Wrap
// function call.
type initMwHandler struct {
mw *initMw
next dnsserver.Handler
}
// type check
var _ dnsserver.Handler = (*initMwHandler)(nil)
// ServeDNS implements the [dnsserver.Handler] interface for *initMwHandler.
func (mh *initMwHandler) ServeDNS(
ctx context.Context,
rw dnsserver.ResponseWriter,
req *dns.Msg,
) (err error) {
defer func() { err = errors.Annotate(err, "init mw: %w") }()
// Save the actual value of the request AD and DO bits and set the AD
// bit in the request to true, so that the upstream validates the data
// and caches the actual value of the response AD bit. Restore it
// later, depending on the request and response data.
reqAD := req.AuthenticatedData
reqDO := dnsmsg.IsDO(req)
req.AuthenticatedData = true
// Assume that module dnsserver has already validated that the request
// always has exactly one question for us.
q := req.Question[0]
fqdn := strings.ToLower(q.Name)
qt := q.Qtype
cl := q.Qclass
// Copy middleware to the local variable to make the code simpler.
mw := mh.mw
// Get the request's information, such as GeoIP data and user profiles.
ri, err := mw.newRequestInfo(ctx, req, rw.LocalAddr(), rw.RemoteAddr(), fqdn, qt, cl)
if err != nil {
var ecsErr dnsmsg.BadECSError
if errors.As(err, &ecsErr) {
// We've got a bad ECS option. Log and respond with a FORMERR
// immediately.
optlog.Debug1("init mw: %s", err)
err = rw.WriteMsg(ctx, req, mw.messages.NewMsgFORMERR(req))
err = errors.Annotate(err, "writing formerr resp: %w")
}
// Don't wrap the error, because this is the main flow, and there is
// already errors.Annotate here.
return err
}
if specHdlr, name := mw.reqInfoSpecialHandler(ri, cl); specHdlr != nil {
optlog.Debug1("init mw: got req-info special handler %s", name)
// Don't wrap the error, because it's informative enough as is, and
// because if handled is true, the main flow terminates here.
return specHdlr(ctx, rw, req, ri)
}
ctx = agd.ContextWithRequestInfo(ctx, ri)
// Record the response, restore the AD bit value in both the request and
// the response, and write the response.
nwrw := makeNonWriter(rw)
err = mh.next.ServeDNS(ctx, nwrw, req)
if err != nil {
// Don't wrap the error, because this is the main flow, and there is
// already errors.Annotate here.
return err
}
resp := nwrw.Msg()
// Following RFC 6840, set the AD bit in the response only when the
// response is authenticated, and the request contained either a set DO
// bit or a set AD bit.
//
// See https://datatracker.ietf.org/doc/html/rfc6840#section-5.8.
resp.AuthenticatedData = resp.AuthenticatedData && (reqAD || reqDO)
err = rw.WriteMsg(ctx, req, resp)
return errors.Annotate(err, "writing resp: %w")
}

View File

@ -1,855 +0,0 @@
package dnssvc
import (
"context"
"crypto/tls"
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
)
func TestInitMw_profile(t *testing.T) {
prof := &agd.Profile{
ID: testProfileID,
DeviceIDs: []agd.DeviceID{
testDeviceID,
},
}
dev := &agd.Device{
ID: testDeviceID,
LinkedIP: testClientAddr,
DedicatedIPs: []netip.Addr{
testServerAddr,
},
}
testCases := []struct {
wantDev *agd.Device
wantProf *agd.Profile
wantByWhat string
wantErrMsg string
name string
id agd.DeviceID
proto agd.Protocol
linkedIPEnabled bool
}{{
wantDev: nil,
wantProf: nil,
wantByWhat: "",
wantErrMsg: "device not found",
name: "no_device_id",
id: "",
proto: agd.ProtoDNS,
linkedIPEnabled: true,
}, {
wantDev: dev,
wantProf: prof,
wantByWhat: byDeviceID,
wantErrMsg: "",
name: "device_id",
id: testDeviceID,
proto: agd.ProtoDNS,
linkedIPEnabled: true,
}, {
wantDev: dev,
wantProf: prof,
wantByWhat: byLinkedIP,
wantErrMsg: "",
name: "linked_ip",
id: "",
proto: agd.ProtoDNS,
linkedIPEnabled: true,
}, {
wantDev: nil,
wantProf: nil,
wantByWhat: "",
wantErrMsg: "device not found",
name: "linked_ip_dot",
id: "",
proto: agd.ProtoDoT,
linkedIPEnabled: true,
}, {
wantDev: nil,
wantProf: nil,
wantByWhat: "",
wantErrMsg: "device not found",
name: "linked_ip_disabled",
id: "",
proto: agd.ProtoDoT,
linkedIPEnabled: false,
}, {
wantDev: dev,
wantProf: prof,
wantByWhat: byDedicatedIP,
wantErrMsg: "",
name: "dedicated_ip",
id: "",
proto: agd.ProtoDNS,
linkedIPEnabled: true,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
db := &agdtest.ProfileDB{
OnProfileByDeviceID: func(
_ context.Context,
gotID agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
assert.Equal(t, tc.id, gotID)
if tc.wantByWhat == byDeviceID {
return prof, dev, nil
}
return nil, nil, profiledb.ErrDeviceNotFound
},
OnProfileByDedicatedIP: func(
_ context.Context,
gotLocalIP netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
assert.Equal(t, testServerAddr, gotLocalIP)
if tc.wantByWhat == byDedicatedIP {
return prof, dev, nil
}
return nil, nil, profiledb.ErrDeviceNotFound
},
OnProfileByLinkedIP: func(
_ context.Context,
gotRemoteIP netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
assert.Equal(t, testClientAddr, gotRemoteIP)
if tc.wantByWhat == byLinkedIP {
return prof, dev, nil
}
return nil, nil, profiledb.ErrDeviceNotFound
},
}
mw := &initMw{
srv: &agd.Server{
Protocol: tc.proto,
LinkedIPEnabled: tc.linkedIPEnabled,
},
db: db,
}
ctx := context.Background()
gotProf, gotDev, gotByWhat, err := mw.profile(
ctx,
testServerAddr,
testClientAddr,
tc.id,
)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.wantProf, gotProf)
assert.Equal(t, tc.wantDev, gotDev)
assert.Equal(t, tc.wantByWhat, gotByWhat)
})
}
}
func TestInitMw_ServeDNS_ddr(t *testing.T) {
const (
resolverName = "dns.example.com"
resolverFQDN = resolverName + "."
targetWithID = testDeviceID + ".d." + resolverName + "."
ddrFQDN = ddrDomain + "."
dohPath = "/dns-query"
)
testDevice := &agd.Device{ID: testDeviceID}
srvs := map[agd.ServerName]*agd.Server{
"dot": {
TLS: &tls.Config{},
BindData: []*agd.ServerBindData{{
AddrPort: netip.MustParseAddrPort("1.2.3.4:12345"),
}},
Protocol: agd.ProtoDoT,
},
"doh": {
TLS: &tls.Config{},
BindData: []*agd.ServerBindData{{
AddrPort: netip.MustParseAddrPort("5.6.7.8:54321"),
}},
Protocol: agd.ProtoDoH,
},
"dns": {
BindData: []*agd.ServerBindData{{
AddrPort: netip.MustParseAddrPort("2.4.6.8:53"),
}},
Protocol: agd.ProtoDNS,
LinkedIPEnabled: true,
},
"dns_nolink": {
BindData: []*agd.ServerBindData{{
AddrPort: netip.MustParseAddrPort("2.4.6.8:53"),
}},
Protocol: agd.ProtoDNS,
},
}
srvGrp := &agd.ServerGroup{
TLS: &agd.TLS{
DeviceIDWildcards: []string{"*.d." + resolverName},
},
DDR: &agd.DDR{
DeviceTargets: stringutil.NewSet(),
PublicTargets: stringutil.NewSet(),
Enabled: true,
},
Name: agd.ServerGroupName("test_server_group"),
Servers: maps.Values(srvs),
}
srvGrp.DDR.DeviceTargets.Add("d." + resolverName)
srvGrp.DDR.PublicTargets.Add(resolverName)
var dev *agd.Device
mw := &initMw{
messages: agdtest.NewConstructor(),
fltGrp: &agd.FilteringGroup{},
srvGrp: srvGrp,
db: &agdtest.ProfileDB{
OnProfileByDeviceID: func(
_ context.Context,
_ agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
p = &agd.Profile{}
return p, dev, nil
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
return nil, nil, profiledb.ErrDeviceNotFound
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
p = &agd.Profile{}
return p, dev, nil
},
},
geoIP: &agdtest.GeoIP{
OnSubnetByLocation: func(
_ agd.Country,
_ agd.ASN,
_ netutil.AddrFamily,
) (_ netip.Prefix, _ error) {
panic("not implemented")
},
OnData: func(_ string, _ netip.Addr) (l *agd.Location, err error) {
return nil, nil
},
},
errColl: &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, _ error) { panic("not implemented") },
},
}
pubSVCBTmpls := []*dns.SVCB{
mw.messages.NewDDRTemplate(agd.ProtoDoH, resolverName, dohPath, nil, nil, 443, 1),
mw.messages.NewDDRTemplate(agd.ProtoDoT, resolverName, "", nil, nil, 853, 1),
mw.messages.NewDDRTemplate(agd.ProtoDoQ, resolverName, "", nil, nil, 853, 1),
}
devSVCBTmpls := []*dns.SVCB{
mw.messages.NewDDRTemplate(agd.ProtoDoH, "d."+resolverName, dohPath, nil, nil, 443, 1),
mw.messages.NewDDRTemplate(agd.ProtoDoT, "d."+resolverName, "", nil, nil, 853, 1),
mw.messages.NewDDRTemplate(agd.ProtoDoQ, "d."+resolverName, "", nil, nil, 853, 1),
}
srvGrp.DDR.PublicRecordTemplates = pubSVCBTmpls
srvGrp.DDR.DeviceRecordTemplates = devSVCBTmpls
var handler dnsserver.Handler = dnsserver.HandlerFunc(func(
_ context.Context,
_ dnsserver.ResponseWriter,
_ *dns.Msg,
) (_ error) {
// Make sure we haven't reached the following middleware.
panic("not implemented")
})
testCases := []struct {
device *agd.Device
name string
srv *agd.Server
host string
wantTarget string
wantNum int
qtype uint16
}{{
device: testDevice,
name: "id",
srv: srvs["dot"],
host: ddrFQDN,
wantTarget: targetWithID,
wantNum: len(pubSVCBTmpls),
qtype: dns.TypeSVCB,
}, {
device: testDevice,
name: "id_specific",
srv: srvs["dot"],
host: ddrLabel + "." + targetWithID,
wantTarget: targetWithID,
wantNum: len(devSVCBTmpls),
qtype: dns.TypeSVCB,
}, {
device: nil,
name: "no_id",
srv: srvs["dot"],
host: ddrFQDN,
wantTarget: resolverFQDN,
wantNum: len(pubSVCBTmpls),
qtype: dns.TypeSVCB,
}, {
device: testDevice,
name: "linked_ip",
srv: srvs["dns"],
host: ddrFQDN,
wantTarget: targetWithID,
wantNum: len(pubSVCBTmpls),
qtype: dns.TypeSVCB,
}, {
device: testDevice,
name: "no_linked_ip",
srv: srvs["dns_nolink"],
host: ddrFQDN,
wantTarget: resolverFQDN,
wantNum: len(pubSVCBTmpls),
qtype: dns.TypeSVCB,
}, {
device: testDevice,
name: "public_resolver_name",
srv: srvs["dot"],
host: ddrLabel + "." + resolverFQDN,
wantTarget: targetWithID,
wantNum: len(pubSVCBTmpls),
qtype: dns.TypeSVCB,
}, {
device: nil,
name: "arpa_not_ddr_svcb",
srv: srvs["dot"],
host: dns.Fqdn(ddrLabel + ".something.else." + resolverArpaDomain),
wantTarget: "",
wantNum: 0,
qtype: dns.TypeSVCB,
}, {
device: nil,
name: "arpa_ddr_not_svcb",
srv: srvs["dot"],
host: ddrFQDN,
wantTarget: "",
wantNum: 0,
qtype: dns.TypeA,
}}
for _, tc := range testCases {
mw.srv = tc.srv
dev = tc.device
var tlsServerName string
switch mw.srv.Protocol {
case agd.ProtoDoT, agd.ProtoDoQ:
tlsServerName = resolverName
if dev != nil {
tlsServerName = string(dev.ID) + ".d." + tlsServerName
}
default:
// Go on.
}
h := mw.Wrap(handler)
req := &dns.Msg{
Question: []dns.Question{{
Name: tc.host,
Qtype: tc.qtype,
Qclass: dns.ClassINET,
}},
}
t.Run(tc.name, func(t *testing.T) {
ctx := context.Background()
ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{
TLSServerName: tlsServerName,
})
rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr)
err := h.ServeDNS(ctx, rw, req)
require.NoError(t, err)
resp := rw.Msg()
require.NotNil(t, resp)
if tc.wantNum == 0 {
assert.Empty(t, resp.Answer)
return
}
assert.Len(t, resp.Answer, tc.wantNum)
for _, rr := range resp.Answer {
svcb := testutil.RequireTypeAssert[*dns.SVCB](t, rr)
assert.Equal(t, tc.wantTarget, svcb.Target)
assert.Equal(t, tc.host, svcb.Hdr.Name)
}
})
}
}
func TestInitMw_ServeDNS_specialDomain(t *testing.T) {
testCases := []struct {
name string
host string
qtype dnsmsg.RRType
fltGrpBlocked bool
hasProf bool
profBlocked bool
wantRCode dnsmsg.RCode
}{{
name: "private_relay_blocked_by_fltgrp",
host: applePrivateRelayMaskHost,
qtype: dns.TypeA,
fltGrpBlocked: true,
hasProf: false,
profBlocked: false,
wantRCode: dns.RcodeNameError,
}, {
name: "no_special_domain",
host: "www.example.com",
qtype: dns.TypeA,
fltGrpBlocked: true,
hasProf: false,
profBlocked: false,
wantRCode: dns.RcodeSuccess,
}, {
name: "no_private_relay_qtype",
host: applePrivateRelayMaskHost,
qtype: dns.TypeTXT,
fltGrpBlocked: true,
hasProf: false,
profBlocked: false,
wantRCode: dns.RcodeSuccess,
}, {
name: "private_relay_blocked_by_prof",
host: applePrivateRelayMaskHost,
qtype: dns.TypeA,
fltGrpBlocked: false,
hasProf: true,
profBlocked: true,
wantRCode: dns.RcodeNameError,
}, {
name: "private_relay_allowed_by_prof",
host: applePrivateRelayMaskHost,
qtype: dns.TypeA,
fltGrpBlocked: true,
hasProf: true,
profBlocked: false,
wantRCode: dns.RcodeSuccess,
}, {
name: "private_relay_allowed_by_both",
host: applePrivateRelayMaskHost,
qtype: dns.TypeA,
fltGrpBlocked: false,
hasProf: true,
profBlocked: false,
wantRCode: dns.RcodeSuccess,
}, {
name: "private_relay_blocked_by_both",
host: applePrivateRelayMaskHost,
qtype: dns.TypeA,
fltGrpBlocked: true,
hasProf: true,
profBlocked: true,
wantRCode: dns.RcodeNameError,
}, {
name: "firefox_canary_allowed_by_prof",
host: firefoxCanaryHost,
qtype: dns.TypeA,
fltGrpBlocked: false,
hasProf: true,
profBlocked: false,
wantRCode: dns.RcodeSuccess,
}, {
name: "firefox_canary_allowed_by_fltgrp",
host: firefoxCanaryHost,
qtype: dns.TypeA,
fltGrpBlocked: false,
hasProf: false,
profBlocked: false,
wantRCode: dns.RcodeSuccess,
}, {
name: "firefox_canary_blocked_by_prof",
host: firefoxCanaryHost,
qtype: dns.TypeA,
fltGrpBlocked: false,
hasProf: true,
profBlocked: true,
wantRCode: dns.RcodeRefused,
}, {
name: "firefox_canary_blocked_by_fltgrp",
host: firefoxCanaryHost,
qtype: dns.TypeA,
fltGrpBlocked: true,
hasProf: false,
profBlocked: false,
wantRCode: dns.RcodeRefused,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var handler dnsserver.Handler = dnsserver.HandlerFunc(func(
ctx context.Context,
rw dnsserver.ResponseWriter,
req *dns.Msg,
) (err error) {
if tc.wantRCode != dns.RcodeSuccess {
return errors.Error("unexpectedly reached handler")
}
resp := (&dns.Msg{}).SetReply(req)
return rw.WriteMsg(ctx, req, resp)
})
onProfileByLinkedIP := func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
if !tc.hasProf {
return nil, nil, profiledb.ErrDeviceNotFound
}
prof := &agd.Profile{
BlockPrivateRelay: tc.profBlocked,
BlockFirefoxCanary: tc.profBlocked,
}
return prof, &agd.Device{}, nil
}
db := &agdtest.ProfileDB{
OnProfileByDeviceID: func(
_ context.Context,
_ agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
return nil, nil, profiledb.ErrDeviceNotFound
},
OnProfileByLinkedIP: onProfileByLinkedIP,
}
geoIP := &agdtest.GeoIP{
OnSubnetByLocation: func(
_ agd.Country,
_ agd.ASN,
_ netutil.AddrFamily,
) (n netip.Prefix, err error) {
panic("not implemented")
},
OnData: func(_ string, _ netip.Addr) (l *agd.Location, err error) {
return nil, nil
},
}
errColl := &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, _ error) {
panic("not implemented")
},
}
mw := &initMw{
messages: agdtest.NewConstructor(),
fltGrp: &agd.FilteringGroup{
BlockPrivateRelay: tc.fltGrpBlocked,
BlockFirefoxCanary: tc.fltGrpBlocked,
},
srvGrp: &agd.ServerGroup{},
srv: &agd.Server{
Protocol: agd.ProtoDNS,
LinkedIPEnabled: true,
},
db: db,
geoIP: geoIP,
errColl: errColl,
}
h := mw.Wrap(handler)
ctx := context.Background()
rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr)
req := &dns.Msg{
Question: []dns.Question{{
Name: dns.Fqdn(tc.host),
Qtype: tc.qtype,
Qclass: dns.ClassINET,
}},
}
err := h.ServeDNS(ctx, rw, req)
require.NoError(t, err)
resp := rw.Msg()
require.NotNil(t, resp)
assert.Equal(t, tc.wantRCode, dnsmsg.RCode(resp.Rcode))
})
}
}
var errSink error
func BenchmarkInitMw_Wrap(b *testing.B) {
const devIDTarget = "dns.example.com"
srvGrp := &agd.ServerGroup{
TLS: &agd.TLS{
DeviceIDWildcards: []string{"*." + devIDTarget},
},
DDR: &agd.DDR{
DeviceTargets: stringutil.NewSet(),
PublicTargets: stringutil.NewSet(),
Enabled: true,
},
Name: agd.ServerGroupName("test_server_group"),
Servers: []*agd.Server{{
BindData: []*agd.ServerBindData{{
AddrPort: netip.MustParseAddrPort("1.2.3.4:12345"),
}, {
AddrPort: netip.MustParseAddrPort("4.3.2.1:12345"),
}},
Protocol: agd.ProtoDoT,
}},
}
messages := agdtest.NewConstructor()
ipv4Hints := []netip.Addr{srvGrp.Servers[0].BindData[0].AddrPort.Addr()}
ipv6Hints := []netip.Addr{netip.MustParseAddr("2001::1234")}
srvGrp.DDR.DeviceTargets.Add(devIDTarget)
srvGrp.DDR.DeviceRecordTemplates = []*dns.SVCB{
messages.NewDDRTemplate(agd.ProtoDoH, devIDTarget, "/dns", ipv4Hints, ipv6Hints, 443, 1),
messages.NewDDRTemplate(agd.ProtoDoT, devIDTarget, "", ipv4Hints, ipv6Hints, 853, 1),
messages.NewDDRTemplate(agd.ProtoDoQ, devIDTarget, "", ipv4Hints, ipv6Hints, 853, 1),
}
mw := &initMw{
messages: messages,
fltGrp: &agd.FilteringGroup{},
srvGrp: srvGrp,
srv: srvGrp.Servers[0],
geoIP: &agdtest.GeoIP{
OnSubnetByLocation: func(
_ agd.Country,
_ agd.ASN,
_ netutil.AddrFamily,
) (n netip.Prefix, err error) {
panic("not implemented")
},
OnData: func(_ string, _ netip.Addr) (l *agd.Location, err error) {
return nil, nil
},
},
errColl: &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, _ error) { panic("not implemented") },
},
}
prof := &agd.Profile{}
dev := &agd.Device{}
ctx := context.Background()
ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{
TLSServerName: testDeviceID + ".dns.example.com",
})
req := &dns.Msg{
Question: []dns.Question{{
Name: "example.net",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
}
resp := new(dns.Msg).SetReply(req)
var handler dnsserver.Handler = dnsserver.HandlerFunc(func(
ctx context.Context,
rw dnsserver.ResponseWriter,
req *dns.Msg,
) (err error) {
return rw.WriteMsg(ctx, req, resp)
})
handler = mw.Wrap(handler)
rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr)
mw.db = &agdtest.ProfileDB{
OnProfileByDeviceID: func(
_ context.Context,
_ agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
return prof, dev, nil
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
b.Run("success", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
errSink = handler.ServeDNS(ctx, rw, req)
}
assert.NoError(b, errSink)
})
mw.db = &agdtest.ProfileDB{
OnProfileByDeviceID: func(
_ context.Context,
_ agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
return nil, nil, profiledb.ErrDeviceNotFound
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
b.Run("not_found", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
errSink = handler.ServeDNS(ctx, rw, req)
}
assert.NoError(b, errSink)
})
ffReq := &dns.Msg{
Question: []dns.Question{{
Name: "use-application-dns.net.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
}
mw.db = &agdtest.ProfileDB{
OnProfileByDeviceID: func(
_ context.Context,
_ agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
return prof, dev, nil
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
b.Run("firefox_canary", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
errSink = handler.ServeDNS(ctx, rw, ffReq)
}
assert.NoError(b, errSink)
})
ddrReq := &dns.Msg{
Question: []dns.Question{{
// Check the worst case when wildcards are checked.
Name: "_dns." + testDeviceID + ".dns.example.com.",
Qtype: dns.TypeSVCB,
Qclass: dns.ClassINET,
}},
}
devWithID := &agd.Device{
ID: testDeviceID,
}
mw.db = &agdtest.ProfileDB{
OnProfileByDeviceID: func(
_ context.Context,
_ agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
return prof, devWithID, nil
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
b.Run("ddr", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
errSink = handler.ServeDNS(ctx, rw, ddrReq)
}
assert.NoError(b, errSink)
})
}

View File

@ -0,0 +1,80 @@
// Package accessmw contains the access middleware of the AdGuard DNS server.
// It filters out the domain scanners and other requests by specified AdBlock
// rules and IP subnets.
package accessmw
import (
"context"
"github.com/AdguardTeam/AdGuardDNS/internal/access"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
// type check
var _ dnsserver.Middleware = (*Middleware)(nil)
// Middleware is the access middleware of the AdGuard DNS server.
type Middleware struct {
accessManager access.Interface
}
// Config is the configuration structure for the access middleware. All fields
// must be non-nil.
type Config struct {
AccessManager access.Interface
}
// New returns a new access middleware. c must not be nil.
func New(c *Config) (mw *Middleware) {
return &Middleware{
accessManager: c.AccessManager,
}
}
// type check
var _ dnsserver.Middleware = (*Middleware)(nil)
// Wrap implements the [dnsserver.Middleware] interface for *Middleware
func (mw *Middleware) Wrap(next dnsserver.Handler) (wrapped dnsserver.Handler) {
f := func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) (err error) {
defer func() { err = errors.Annotate(err, "access mw: %w") }()
rAddr := netutil.NetAddrToAddrPort(rw.RemoteAddr()).Addr()
if blocked, _ := mw.accessManager.IsBlockedIP(rAddr); blocked {
metrics.AccessBlockedForSubnetTotal.Inc()
return nil
}
// Assume that module dnsserver has already validated that the request
// always has exactly one question for us.
q := req.Question[0]
if blocked := mw.accessManager.IsBlockedHost(normalizeDomain(q.Name), q.Qtype); blocked {
metrics.AccessBlockedForHostTotal.Inc()
return nil
}
return next.ServeDNS(ctx, rw, req)
}
return dnsserver.HandlerFunc(f)
}
// normalizeDomain returns a lowercased version of the host without the final
// dot, unless the host is ".", in which case it returns the unchanged host.
// That is the special case to allow matching queries like:
//
// dig IN NS '.'
func normalizeDomain(host string) (norm string) {
if host == "." {
return host
}
return agdnet.NormalizeDomain(host)
}

View File

@ -0,0 +1,135 @@
package accessmw_test
import (
"context"
"net"
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/access"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/accessmw"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMiddleware_Wrap(t *testing.T) {
am, accessErr := access.New([]string{
"block.test",
"UPPERCASE.test",
"||block_aaaa.test^$dnstype=AAAA",
}, []string{
"1.1.1.1",
"2.2.2.0/8",
})
require.NoError(t, accessErr)
amw := accessmw.New(&accessmw.Config{
AccessManager: am,
})
testCases := []struct {
wantResp assert.BoolAssertionFunc
name string
host string
ip net.IP
qtype uint16
}{{
ip: net.IP{1, 1, 1, 0},
name: "pass_ip",
host: "pass.test",
qtype: dns.TypeA,
wantResp: assert.True,
}, {
name: "block_ip",
ip: net.IP{1, 1, 1, 1},
host: "pass.test",
qtype: dns.TypeA,
wantResp: assert.False,
}, {
name: "pass_subnet",
ip: net.IP{1, 2, 2, 2},
host: "pass.test",
qtype: dns.TypeA,
wantResp: assert.True,
}, {
name: "block_subnet",
ip: net.IP{2, 2, 2, 2},
host: "pass.test",
qtype: dns.TypeA,
wantResp: assert.False,
}, {
wantResp: assert.True,
name: "pass_domain",
host: "pass.test",
qtype: dns.TypeA,
}, {
wantResp: assert.False,
name: "blocked_domain_A",
host: "block.test",
qtype: dns.TypeA,
}, {
wantResp: assert.False,
name: "blocked_domain_HTTPS",
host: "block.test",
qtype: dns.TypeHTTPS,
}, {
wantResp: assert.False,
name: "uppercase_domain",
host: "uppercase.test",
qtype: dns.TypeHTTPS,
}, {
wantResp: assert.True,
name: "pass_qt",
host: "block_aaaa.test",
qtype: dns.TypeA,
}, {
wantResp: assert.False,
name: "block_qt",
host: "block_aaaa.test",
qtype: dns.TypeAAAA,
}}
var handler dnsserver.Handler = dnsserver.HandlerFunc(func(
ctx context.Context,
rw dnsserver.ResponseWriter,
q *dns.Msg,
) (_ error) {
resp := dnsservertest.NewResp(
dns.RcodeSuccess,
q,
dnsservertest.SectionAnswer{
dnsservertest.NewA("test.domain", 0, netip.MustParseAddr("5.5.5.5")),
},
)
err := rw.WriteMsg(ctx, q, resp)
if err != nil {
return err
}
return nil
})
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
rw := dnsserver.NewNonWriterResponseWriter(nil, &net.TCPAddr{IP: tc.ip, Port: 5357})
req := &dns.Msg{
Question: []dns.Question{{
Name: tc.host,
Qtype: tc.qtype,
Qclass: dns.ClassINET,
}},
}
h := amw.Wrap(handler)
err := h.ServeDNS(context.Background(), rw, req)
require.NoError(t, err)
resp := rw.Msg()
tc.wantResp(t, resp != nil)
})
}
}

View File

@ -0,0 +1,44 @@
// Package dnssvctest contains common constants and utilities for the internal
// DNS-service packages.
package dnssvctest
import (
"net"
"net/netip"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
)
// Timeout is the common timeout for tests.
const Timeout time.Duration = 1 * time.Second
// String representations of the common IDs for tests.
const (
DeviceIDStr = "dev1234"
ProfileIDStr = "prof1234"
)
// DeviceID is the common device ID for tests.
const DeviceID agd.DeviceID = DeviceIDStr
// ProfileID is the common profile ID for tests.
const ProfileID agd.ProfileID = ProfileIDStr
// Common addresses for tests.
var (
ClientIP = net.IP{1, 2, 3, 4}
RemoteAddr = &net.TCPAddr{
IP: ClientIP,
Port: 12345,
}
ClientAddrPort = RemoteAddr.AddrPort()
ClientAddr = ClientAddrPort.Addr()
ServerAddr = netip.MustParseAddr("5.6.7.8")
LocalAddr = &net.TCPAddr{
IP: ServerAddr.AsSlice(),
Port: 54321,
}
)

View File

@ -1,4 +1,4 @@
package dnssvc
package initial
import (
"context"
@ -16,8 +16,6 @@ import (
"github.com/miekg/dns"
)
// Device ID Extraction
// deviceIDFromClientServerName extracts and validates a device ID. cliSrvName
// is the server name as sent by the client. wildcards are the domain wildcards
// for device ID detection.

View File

@ -1,4 +1,4 @@
package dnssvc
package initial
import (
"context"
@ -9,13 +9,14 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/dnssvctest"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestService_Wrap_deviceID(t *testing.T) {
func TestDeviceIDFromContext(t *testing.T) {
testCases := []struct {
name string
cliSrvName string
@ -46,8 +47,8 @@ func TestService_Wrap_deviceID(t *testing.T) {
proto: agd.ProtoDoT,
}, {
name: "tls_device_id",
cliSrvName: testDeviceID + ".dns.example.com",
wantDeviceID: testDeviceID,
cliSrvName: dnssvctest.DeviceIDStr + ".dns.example.com",
wantDeviceID: dnssvctest.DeviceID,
wantErrMsg: "",
wildcards: []string{"*.dns.example.com"},
proto: agd.ProtoDoT,
@ -61,7 +62,7 @@ func TestService_Wrap_deviceID(t *testing.T) {
proto: agd.ProtoDoT,
}, {
name: "tls_deep_subdomain",
cliSrvName: "abc." + testDeviceID + ".dns.example.com",
cliSrvName: "abc." + dnssvctest.DeviceIDStr + ".dns.example.com",
wantDeviceID: "",
wantErrMsg: "",
wildcards: []string{"*.dns.example.com"},
@ -79,8 +80,8 @@ func TestService_Wrap_deviceID(t *testing.T) {
proto: agd.ProtoDoT,
}, {
name: "quic_device_id",
cliSrvName: testDeviceID + ".dns.example.com",
wantDeviceID: testDeviceID,
cliSrvName: dnssvctest.DeviceIDStr + ".dns.example.com",
wantDeviceID: dnssvctest.DeviceID,
wantErrMsg: "",
wildcards: []string{"*.dns.example.com"},
proto: agd.ProtoDoQ,
@ -93,8 +94,8 @@ func TestService_Wrap_deviceID(t *testing.T) {
proto: agd.ProtoDoT,
}, {
name: "tls_device_id_subdomain_wildcard",
cliSrvName: testDeviceID + ".sub.dns.example.com",
wantDeviceID: testDeviceID,
cliSrvName: dnssvctest.DeviceIDStr + ".sub.dns.example.com",
wantDeviceID: dnssvctest.DeviceID,
wantErrMsg: "",
wildcards: []string{
"*.dns.example.com",
@ -117,7 +118,7 @@ func TestService_Wrap_deviceID(t *testing.T) {
}
}
func TestService_Wrap_deviceIDHTTPS(t *testing.T) {
func TestDeviceIDFromContext_https(t *testing.T) {
testCases := []struct {
name string
path string
@ -135,13 +136,13 @@ func TestService_Wrap_deviceIDHTTPS(t *testing.T) {
wantErrMsg: "",
}, {
name: "device_id",
path: "/dns-query/" + testDeviceID,
wantDeviceID: testDeviceID,
path: "/dns-query/" + dnssvctest.DeviceIDStr,
wantDeviceID: dnssvctest.DeviceID,
wantErrMsg: "",
}, {
name: "device_id_slash",
path: "/dns-query/" + testDeviceID + "/",
wantDeviceID: testDeviceID,
path: "/dns-query/" + dnssvctest.DeviceIDStr + "/",
wantDeviceID: dnssvctest.DeviceID,
wantErrMsg: "",
}, {
name: "bad_url",
@ -150,10 +151,10 @@ func TestService_Wrap_deviceIDHTTPS(t *testing.T) {
wantErrMsg: `http url device id check: bad path "/foo"`,
}, {
name: "extra",
path: "/dns-query/" + testDeviceID + "/foo",
path: "/dns-query/" + dnssvctest.DeviceIDStr + "/foo",
wantDeviceID: "",
wantErrMsg: `http url device id check: bad path "/dns-query/` + testDeviceID + `/foo": ` +
`extra parts`,
wantErrMsg: `http url device id check: bad path "/dns-query/` + dnssvctest.DeviceIDStr +
`/foo": extra parts`,
}, {
name: "bad_device_id",
path: "/dns-query/!!!",
@ -186,7 +187,7 @@ func TestService_Wrap_deviceIDHTTPS(t *testing.T) {
t.Run("domain_name", func(t *testing.T) {
u := &url.URL{
Scheme: "https",
Host: testDeviceID + ".dns.example.com",
Host: dnssvctest.DeviceIDStr + ".dns.example.com",
Path: "/dns-query",
}
@ -202,11 +203,11 @@ func TestService_Wrap_deviceIDHTTPS(t *testing.T) {
deviceID, err := deviceIDFromContext(ctx, proto, []string{"*.dns.example.com"})
require.NoError(t, err)
assert.Equal(t, agd.DeviceID(testDeviceID), deviceID)
assert.Equal(t, agd.DeviceID(dnssvctest.DeviceID), deviceID)
})
}
func TestService_Wrap_deviceIDFromEDNS(t *testing.T) {
func TestDeviceIDFromEDNS(t *testing.T) {
testCases := []struct {
name string
opt dns.EDNS0

View File

@ -0,0 +1,284 @@
// Package initial contains the initial, outermost (except for ratelimit)
// middleware of the AdGuard DNS server. It filters out the Firefox canary
// domain logic, sets and resets the AD bit for further processing, as well as
// puts as much information as it can into the context and request info.
package initial
import (
"context"
"fmt"
"net"
"net/netip"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
"github.com/AdguardTeam/AdGuardDNS/internal/agdsync"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
// Middleware is the initial middleware of the AdGuard DNS server. This
// middleware must be the most outer middleware apart from the ratelimit one.
type Middleware struct {
// messages is used to build the responses specific for the request's
// context.
messages *dnsmsg.Constructor
// fltGrp is the filtering group to which srv belongs.
fltGrp *agd.FilteringGroup
// srvGrp is the server group to which srv belongs.
srvGrp *agd.ServerGroup
// srv is the current server which serves the request.
srv *agd.Server
// pool is the pool of [agd.RequestInfo] values.
pool *agdsync.TypedPool[agd.RequestInfo]
// db is the database of user profiles and devices.
db profiledb.Interface
// geoIP detects the location of the request source.
geoIP geoip.Interface
// errColl collects and reports the errors considered non-critical.
errColl agd.ErrorCollector
}
// Config is the configuration structure for the initial middleware. All fields
// must be non-nil.
type Config struct {
// messages is used to build the responses specific for a request's context.
Messages *dnsmsg.Constructor
// FilteringGroup is the filtering group to which Server belongs.
FilteringGroup *agd.FilteringGroup
// ServerGroup is the server group to which Server belongs.
ServerGroup *agd.ServerGroup
// Server is the current server which serves the request.
Server *agd.Server
// DB is the database of user profiles and devices.
ProfileDB profiledb.Interface
// GeoIP detects the location of the request source.
GeoIP geoip.Interface
// ErrColl collects and reports the errors considered non-critical.
ErrColl agd.ErrorCollector
}
// New returns a new initial middleware. c must not be nil.
func New(c *Config) (mw *Middleware) {
return &Middleware{
messages: c.Messages,
fltGrp: c.FilteringGroup,
srvGrp: c.ServerGroup,
srv: c.Server,
pool: agdsync.NewTypedPool(func() (v *agd.RequestInfo) {
return &agd.RequestInfo{}
}),
db: c.ProfileDB,
geoIP: c.GeoIP,
errColl: c.ErrColl,
}
}
// type check
var _ dnsserver.Middleware = (*Middleware)(nil)
// Wrap implements the [dnsserver.Middleware] interface for *Middleware
func (mw *Middleware) Wrap(next dnsserver.Handler) (wrapped dnsserver.Handler) {
f := func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) (err error) {
defer func() { err = errors.Annotate(err, "init mw: %w") }()
// Save the actual value of the request AD and DO bits and set the AD
// bit in the request to true, so that the upstream validates the data
// and caches the actual value of the response AD bit. Restore it
// later, depending on the request and response data.
reqAD := req.AuthenticatedData
reqDO := dnsmsg.IsDO(req)
req.AuthenticatedData = true
// Assume that module dnsserver has already validated that the request
// always has exactly one question for us.
q := req.Question[0]
qt := q.Qtype
cl := q.Qclass
// Get the request's information, such as GeoIP data and user profiles.
ri, err := mw.newRequestInfo(ctx, req, rw.LocalAddr(), rw.RemoteAddr(), q.Name, qt, cl)
if err != nil {
// Don't wrap the error, because this is the main flow, and there is
// already [errors.Annotate] here.
return mw.processReqInfoErr(ctx, rw, req, err)
}
defer mw.pool.Put(ri)
if specHdlr, name := mw.reqInfoSpecialHandler(ri, cl); specHdlr != nil {
optlog.Debug1("init mw: got req-info special handler %s", name)
// Don't wrap the error, because it's informative enough as is, and
// because if handled is true, the main flow terminates here.
return specHdlr(ctx, rw, req, ri)
}
ctx = agd.ContextWithRequestInfo(ctx, ri)
// Record the response, restore the AD bit value in both the request and
// the response, and write the response.
nwrw := internal.MakeNonWriter(rw)
err = next.ServeDNS(ctx, nwrw, req)
if err != nil {
// Don't wrap the error, because this is the main flow, and there is
// already errors.Annotate here.
return err
}
resp := nwrw.Msg()
// Following RFC 6840, set the AD bit in the response only when the
// response is authenticated, and the request contained either a set DO
// bit or a set AD bit.
//
// See https://datatracker.ietf.org/doc/html/rfc6840#section-5.8.
resp.AuthenticatedData = resp.AuthenticatedData && (reqAD || reqDO)
err = rw.WriteMsg(ctx, req, resp)
return errors.Annotate(err, "writing resp: %w")
}
return dnsserver.HandlerFunc(f)
}
// newRequestInfo returns the new request information structure using the
// middleware's configuration and values from ctx.
func (mw *Middleware) newRequestInfo(
ctx context.Context,
req *dns.Msg,
laddr net.Addr,
raddr net.Addr,
fqdn string,
qt dnsmsg.RRType,
cl dnsmsg.Class,
) (ri *agd.RequestInfo, err error) {
ri = mw.pool.Get()
// Use ri as an argument here to evaluate and save the non-nil value of ri
// and prevent returns with an error from overwriting ri with nil.
defer func(fromPool *agd.RequestInfo) {
if err != nil {
mw.pool.Put(fromPool)
}
}(ri)
// Clear all fields that must be set later.
ri.Device = nil
ri.Profile = nil
ri.ECS = nil
ri.Location = nil
// Put the host, server, and client IP data into the request information
// immediately.
ri.FilteringGroup = mw.fltGrp
ri.Messages = mw.messages
ri.RemoteIP = netutil.NetAddrToAddrPort(raddr).Addr()
ri.ServerGroup = mw.srvGrp.Name
ri.Server = mw.srv.Name
ri.Host = agdnet.NormalizeDomain(fqdn)
ri.QType = qt
ri.QClass = cl
// As an optimization, put the request ID closer to the top of the context
// stack.
ri.ID, _ = agd.RequestIDFromContext(ctx)
// Add the GeoIP information, if any.
err = mw.addLocation(ctx, ri, req)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
// Add the profile information, if any.
localIP := netutil.NetAddrToAddrPort(laddr).Addr()
err = mw.addProfile(ctx, ri, req, localIP)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
return ri, nil
}
// addLocation adds GeoIP location information about the client's remote address
// as well as the EDNS Client Subnet information, if there is one, to ri. err
// is not nil only if req contains a malformed EDNS Client Subnet option.
func (mw *Middleware) addLocation(ctx context.Context, ri *agd.RequestInfo, req *dns.Msg) (err error) {
ri.Location = mw.locationData(ctx, ri.RemoteIP, "client")
ecs, scope, err := dnsmsg.ECSFromMsg(req)
if err != nil {
return fmt.Errorf("adding ecs info: %w", err)
} else if ecs != (netip.Prefix{}) {
ri.ECS = &agd.ECS{
Location: mw.locationData(ctx, ecs.Addr(), "ecs"),
Subnet: ecs,
Scope: scope,
}
}
return nil
}
// locationData returns the GeoIP location information about the IP address.
// typ is the type of data being requested for error reporting and logging.
func (mw *Middleware) locationData(ctx context.Context, ip netip.Addr, typ string) (l *agd.Location) {
l, err := mw.geoIP.Data("", ip)
if err != nil {
// Consider GeoIP errors non-critical. Report and go on.
agd.Collectf(ctx, mw.errColl, "init mw: getting geoip for %s ip: %w", typ, err)
}
if l == nil {
optlog.Debug2("init mw: no geoip for %s ip %s", typ, ip)
} else {
optlog.Debug4("init mw: found country/asn %q/%d for %s ip %s", l.Country, l.ASN, typ, ip)
}
return l
}
// processReqInfoErr processes the error returned by [Middleware.newRequestInfo]
// and returns the properly handled and/or wrapped error.
func (mw *Middleware) processReqInfoErr(
ctx context.Context,
rw dnsserver.ResponseWriter,
req *dns.Msg,
origErr error,
) (err error) {
var ecsErr dnsmsg.BadECSError
if errors.As(origErr, &ecsErr) {
// We've got a bad ECS option. Log and respond with a FORMERR
// immediately.
optlog.Debug1("init mw: %s", origErr)
writeErr := rw.WriteMsg(ctx, req, mw.messages.NewMsgFORMERR(req))
writeErr = errors.Annotate(writeErr, "writing formerr resp: %w")
return errors.WithDeferred(origErr, writeErr)
}
return origErr
}

View File

@ -0,0 +1,660 @@
package initial_test
import (
"context"
"crypto/tls"
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/dnssvctest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/initial"
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
)
func TestMiddleware_Wrap(t *testing.T) {
const (
resolverName = "dns.example.com"
resolverFQDN = resolverName + "."
targetWithID = dnssvctest.DeviceIDStr + ".d." + resolverName + "."
ddrFQDN = initial.DDRDomain + "."
dohPath = "/dns-query"
)
testDevice := &agd.Device{ID: dnssvctest.DeviceID}
srvs := map[agd.ServerName]*agd.Server{
"dot": {
TLS: &tls.Config{},
BindData: []*agd.ServerBindData{{
AddrPort: netip.MustParseAddrPort("1.2.3.4:12345"),
}},
Protocol: agd.ProtoDoT,
},
"doh": {
TLS: &tls.Config{},
BindData: []*agd.ServerBindData{{
AddrPort: netip.MustParseAddrPort("5.6.7.8:54321"),
}},
Protocol: agd.ProtoDoH,
},
"dns": {
BindData: []*agd.ServerBindData{{
AddrPort: netip.MustParseAddrPort("2.4.6.8:53"),
}},
Protocol: agd.ProtoDNS,
LinkedIPEnabled: true,
},
"dns_nolink": {
BindData: []*agd.ServerBindData{{
AddrPort: netip.MustParseAddrPort("2.4.6.8:53"),
}},
Protocol: agd.ProtoDNS,
},
}
srvGrp := &agd.ServerGroup{
TLS: &agd.TLS{
DeviceIDWildcards: []string{"*.d." + resolverName},
},
DDR: &agd.DDR{
DeviceTargets: stringutil.NewSet(),
PublicTargets: stringutil.NewSet(),
Enabled: true,
},
Name: agd.ServerGroupName("test_server_group"),
Servers: maps.Values(srvs),
}
srvGrp.DDR.DeviceTargets.Add("d." + resolverName)
srvGrp.DDR.PublicTargets.Add(resolverName)
geoIP := &agdtest.GeoIP{
OnSubnetByLocation: func(
_ agd.Country,
_ agd.ASN,
_ netutil.AddrFamily,
) (_ netip.Prefix, _ error) {
panic("not implemented")
},
OnData: func(_ string, _ netip.Addr) (l *agd.Location, err error) {
return nil, nil
},
}
messages := agdtest.NewConstructor()
pubSVCBTmpls := []*dns.SVCB{
messages.NewDDRTemplate(agd.ProtoDoH, resolverName, dohPath, nil, nil, 443, 1),
messages.NewDDRTemplate(agd.ProtoDoT, resolverName, "", nil, nil, 853, 1),
messages.NewDDRTemplate(agd.ProtoDoQ, resolverName, "", nil, nil, 853, 1),
}
devSVCBTmpls := []*dns.SVCB{
messages.NewDDRTemplate(agd.ProtoDoH, "d."+resolverName, dohPath, nil, nil, 443, 1),
messages.NewDDRTemplate(agd.ProtoDoT, "d."+resolverName, "", nil, nil, 853, 1),
messages.NewDDRTemplate(agd.ProtoDoQ, "d."+resolverName, "", nil, nil, 853, 1),
}
srvGrp.DDR.PublicRecordTemplates = pubSVCBTmpls
srvGrp.DDR.DeviceRecordTemplates = devSVCBTmpls
var handler dnsserver.Handler = dnsserver.HandlerFunc(func(
_ context.Context,
_ dnsserver.ResponseWriter,
_ *dns.Msg,
) (_ error) {
// Make sure we haven't reached the following middleware.
panic("not implemented")
})
testCases := []struct {
device *agd.Device
name string
srv *agd.Server
host string
wantTarget string
wantNum int
qtype uint16
}{{
device: testDevice,
name: "id",
srv: srvs["dot"],
host: ddrFQDN,
wantTarget: targetWithID,
wantNum: len(pubSVCBTmpls),
qtype: dns.TypeSVCB,
}, {
device: testDevice,
name: "id_specific",
srv: srvs["dot"],
host: initial.DDRLabel + "." + targetWithID,
wantTarget: targetWithID,
wantNum: len(devSVCBTmpls),
qtype: dns.TypeSVCB,
}, {
device: nil,
name: "no_id",
srv: srvs["dot"],
host: ddrFQDN,
wantTarget: resolverFQDN,
wantNum: len(pubSVCBTmpls),
qtype: dns.TypeSVCB,
}, {
device: testDevice,
name: "linked_ip",
srv: srvs["dns"],
host: ddrFQDN,
wantTarget: targetWithID,
wantNum: len(pubSVCBTmpls),
qtype: dns.TypeSVCB,
}, {
device: testDevice,
name: "no_linked_ip",
srv: srvs["dns_nolink"],
host: ddrFQDN,
wantTarget: resolverFQDN,
wantNum: len(pubSVCBTmpls),
qtype: dns.TypeSVCB,
}, {
device: testDevice,
name: "public_resolver_name",
srv: srvs["dot"],
host: initial.DDRLabel + "." + resolverFQDN,
wantTarget: targetWithID,
wantNum: len(pubSVCBTmpls),
qtype: dns.TypeSVCB,
}, {
device: nil,
name: "arpa_not_ddr_svcb",
srv: srvs["dot"],
host: dns.Fqdn(
initial.DDRLabel + ".something.else." + initial.ResolverARPADomain,
),
wantTarget: "",
wantNum: 0,
qtype: dns.TypeSVCB,
}, {
device: nil,
name: "arpa_ddr_not_svcb",
srv: srvs["dot"],
host: ddrFQDN,
wantTarget: "",
wantNum: 0,
qtype: dns.TypeA,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
db := &agdtest.ProfileDB{
OnProfileByDeviceID: func(
_ context.Context,
_ agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
return &agd.Profile{}, tc.device, nil
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
return nil, nil, profiledb.ErrDeviceNotFound
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
return &agd.Profile{}, tc.device, nil
},
}
mw := initial.New(&initial.Config{
Messages: agdtest.NewConstructor(),
FilteringGroup: &agd.FilteringGroup{},
ServerGroup: srvGrp,
Server: tc.srv,
ProfileDB: db,
GeoIP: geoIP,
ErrColl: &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, _ error) { panic("not implemented") },
},
})
ctx := context.Background()
ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{
TLSServerName: srvNameForProto(tc.device, resolverName, tc.srv.Protocol),
})
rw := dnsserver.NewNonWriterResponseWriter(nil, dnssvctest.RemoteAddr)
req := &dns.Msg{
Question: []dns.Question{{
Name: tc.host,
Qtype: tc.qtype,
Qclass: dns.ClassINET,
}},
}
h := mw.Wrap(handler)
err := h.ServeDNS(ctx, rw, req)
require.NoError(t, err)
resp := rw.Msg()
require.NotNil(t, resp)
if tc.wantNum == 0 {
assert.Empty(t, resp.Answer)
return
}
assert.Len(t, resp.Answer, tc.wantNum)
for _, rr := range resp.Answer {
svcb := testutil.RequireTypeAssert[*dns.SVCB](t, rr)
assert.Equal(t, tc.wantTarget, svcb.Target)
assert.Equal(t, tc.host, svcb.Hdr.Name)
}
})
}
}
// srvNameForProto returns a client's TLS server name based on the protocol and
// other data.
func srvNameForProto(dev *agd.Device, resolverName string, proto agd.Protocol) (srvName string) {
switch proto {
case agd.ProtoDoT, agd.ProtoDoQ:
srvName = resolverName
if dev != nil {
srvName = string(dev.ID) + ".d." + srvName
}
default:
// Go on.
}
return srvName
}
func TestMiddleware_Wrap_error(t *testing.T) {
var handler dnsserver.Handler = dnsserver.HandlerFunc(func(
_ context.Context,
_ dnsserver.ResponseWriter,
_ *dns.Msg,
) (_ error) {
// Make sure we haven't reached the following middleware.
panic("not implemented")
})
srvGrp := &agd.ServerGroup{
Name: agd.ServerGroupName("test_server_group"),
}
srv := &agd.Server{
BindData: []*agd.ServerBindData{{
AddrPort: netip.MustParseAddrPort("1.2.3.4:53"),
}},
Protocol: agd.ProtoDNS,
LinkedIPEnabled: true,
}
geoIP := &agdtest.GeoIP{
OnSubnetByLocation: func(
_ agd.Country,
_ agd.ASN,
_ netutil.AddrFamily,
) (_ netip.Prefix, _ error) {
panic("not implemented")
},
OnData: func(_ string, _ netip.Addr) (l *agd.Location, err error) {
return nil, nil
},
}
const testError errors.Error = errors.Error("test error")
db := &agdtest.ProfileDB{
OnProfileByDeviceID: func(
_ context.Context,
_ agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
return nil, nil, testError
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
mw := initial.New(&initial.Config{
Messages: agdtest.NewConstructor(),
FilteringGroup: &agd.FilteringGroup{},
ServerGroup: srvGrp,
Server: srv,
ProfileDB: db,
GeoIP: geoIP,
ErrColl: &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, _ error) { panic("not implemented") },
},
})
ctx := context.Background()
rw := dnsserver.NewNonWriterResponseWriter(nil, dnssvctest.RemoteAddr)
req := &dns.Msg{
Question: []dns.Question{{
Name: "www.example.com.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
}
h := mw.Wrap(handler)
err := h.ServeDNS(ctx, rw, req)
assert.ErrorIs(t, err, testError)
}
var errSink error
func BenchmarkMiddleware_Wrap(b *testing.B) {
const devIDTarget = "dns.example.com"
srvGrp := &agd.ServerGroup{
TLS: &agd.TLS{
DeviceIDWildcards: []string{"*." + devIDTarget},
},
DDR: &agd.DDR{
DeviceTargets: stringutil.NewSet(),
PublicTargets: stringutil.NewSet(),
Enabled: true,
},
Name: agd.ServerGroupName("test_server_group"),
Servers: []*agd.Server{{
BindData: []*agd.ServerBindData{{
AddrPort: netip.MustParseAddrPort("1.2.3.4:12345"),
}, {
AddrPort: netip.MustParseAddrPort("4.3.2.1:12345"),
}},
Protocol: agd.ProtoDoT,
}},
}
messages := agdtest.NewConstructor()
geoIP := &agdtest.GeoIP{
OnSubnetByLocation: func(
_ agd.Country,
_ agd.ASN,
_ netutil.AddrFamily,
) (n netip.Prefix, err error) {
panic("not implemented")
},
OnData: func(_ string, _ netip.Addr) (l *agd.Location, err error) {
return nil, nil
},
}
ipv4Hints := []netip.Addr{srvGrp.Servers[0].BindData[0].AddrPort.Addr()}
ipv6Hints := []netip.Addr{netip.MustParseAddr("2001::1234")}
srvGrp.DDR.DeviceTargets.Add(devIDTarget)
srvGrp.DDR.DeviceRecordTemplates = []*dns.SVCB{
messages.NewDDRTemplate(agd.ProtoDoH, devIDTarget, "/dns", ipv4Hints, ipv6Hints, 443, 1),
messages.NewDDRTemplate(agd.ProtoDoT, devIDTarget, "", ipv4Hints, ipv6Hints, 853, 1),
messages.NewDDRTemplate(agd.ProtoDoQ, devIDTarget, "", ipv4Hints, ipv6Hints, 853, 1),
}
prof := &agd.Profile{}
dev := &agd.Device{}
ctx := context.Background()
ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{
TLSServerName: dnssvctest.DeviceIDStr + ".dns.example.com",
})
req := &dns.Msg{
Question: []dns.Question{{
Name: "example.net",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
}
resp := new(dns.Msg).SetReply(req)
var handler dnsserver.Handler = dnsserver.HandlerFunc(func(
ctx context.Context,
rw dnsserver.ResponseWriter,
req *dns.Msg,
) (err error) {
return rw.WriteMsg(ctx, req, resp)
})
rw := dnsserver.NewNonWriterResponseWriter(nil, dnssvctest.RemoteAddr)
b.Run("success", func(b *testing.B) {
db := &agdtest.ProfileDB{
OnProfileByDeviceID: func(
_ context.Context,
_ agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
return prof, dev, nil
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
mw := initial.New(&initial.Config{
Messages: messages,
FilteringGroup: &agd.FilteringGroup{},
ServerGroup: srvGrp,
Server: srvGrp.Servers[0],
ProfileDB: db,
GeoIP: geoIP,
ErrColl: &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, _ error) { panic("not implemented") },
},
})
h := mw.Wrap(handler)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
errSink = h.ServeDNS(ctx, rw, req)
}
assert.NoError(b, errSink)
})
b.Run("not_found", func(b *testing.B) {
db := &agdtest.ProfileDB{
OnProfileByDeviceID: func(
_ context.Context,
_ agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
return nil, nil, profiledb.ErrDeviceNotFound
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
mw := initial.New(&initial.Config{
Messages: messages,
FilteringGroup: &agd.FilteringGroup{},
ServerGroup: srvGrp,
Server: srvGrp.Servers[0],
ProfileDB: db,
GeoIP: geoIP,
ErrColl: &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, _ error) { panic("not implemented") },
},
})
h := mw.Wrap(handler)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
errSink = h.ServeDNS(ctx, rw, req)
}
assert.NoError(b, errSink)
})
b.Run("firefox_canary", func(b *testing.B) {
db := &agdtest.ProfileDB{
OnProfileByDeviceID: func(
_ context.Context,
_ agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
return prof, dev, nil
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
ffReq := &dns.Msg{
Question: []dns.Question{{
Name: "use-application-dns.net.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
}
mw := initial.New(&initial.Config{
Messages: messages,
FilteringGroup: &agd.FilteringGroup{},
ServerGroup: srvGrp,
Server: srvGrp.Servers[0],
ProfileDB: db,
GeoIP: geoIP,
ErrColl: &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, _ error) { panic("not implemented") },
},
})
h := mw.Wrap(handler)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
errSink = h.ServeDNS(ctx, rw, ffReq)
}
assert.NoError(b, errSink)
})
b.Run("ddr", func(b *testing.B) {
devWithID := &agd.Device{
ID: dnssvctest.DeviceID,
}
db := &agdtest.ProfileDB{
OnProfileByDeviceID: func(
_ context.Context,
_ agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
return prof, devWithID, nil
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
ddrReq := &dns.Msg{
Question: []dns.Question{{
// Check the worst case when wildcards are checked.
Name: "_dns." + dnssvctest.DeviceIDStr + ".dns.example.com.",
Qtype: dns.TypeSVCB,
Qclass: dns.ClassINET,
}},
}
mw := initial.New(&initial.Config{
Messages: messages,
FilteringGroup: &agd.FilteringGroup{},
ServerGroup: srvGrp,
Server: srvGrp.Servers[0],
ProfileDB: db,
GeoIP: geoIP,
ErrColl: &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, _ error) { panic("not implemented") },
},
})
h := mw.Wrap(handler)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
errSink = h.ServeDNS(ctx, rw, ddrReq)
}
assert.NoError(b, errSink)
})
// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
// goos: linux
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/initial
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
// BenchmarkMiddleware_Wrap/success-16 1970464 735.8 ns/op 72 B/op 2 allocs/op
// BenchmarkMiddleware_Wrap/not_found-16 1469100 715.9 ns/op 48 B/op 1 allocs/op
// BenchmarkMiddleware_Wrap/firefox_canary-16 1644410 861.9 ns/op 72 B/op 2 allocs/op
// BenchmarkMiddleware_Wrap/ddr-16 252656 4810 ns/op 1408 B/op 45 allocs/op
}

View File

@ -0,0 +1,110 @@
package initial
import (
"context"
"fmt"
"net/netip"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/golibs/errors"
"github.com/miekg/dns"
)
// addProfile adds profile and device information, if any, to the request
// information.
func (mw *Middleware) addProfile(
ctx context.Context,
ri *agd.RequestInfo,
req *dns.Msg,
localIP netip.Addr,
) (err error) {
defer func() { err = errors.Annotate(err, "getting profile from req: %w") }()
var id agd.DeviceID
if p := mw.srv.Protocol; p.IsStdEncrypted() {
// Assume that mw.srvGrp.TLS is non-nil if p.IsStdEncrypted() is true.
wildcards := mw.srvGrp.TLS.DeviceIDWildcards
id, err = deviceIDFromContext(ctx, mw.srv.Protocol, wildcards)
} else if p == agd.ProtoDNS {
id, err = deviceIDFromEDNS(req)
} else {
// No DeviceID for DNSCrypt yet.
return nil
}
if err != nil {
return err
}
optlog.Debug3("init mw: got device id %q, raddr %s, and laddr %s", id, ri.RemoteIP, localIP)
prof, dev, byWhat, err := mw.profile(ctx, localIP, ri.RemoteIP, id)
if err != nil {
if !errors.Is(err, profiledb.ErrDeviceNotFound) {
// Very unlikely, since there is only one error type currently
// returned from the default profile DB.
return fmt.Errorf("unexpected profiledb error: %w", err)
}
optlog.Debug1("init mw: profile or device not found: %s", err)
} else if prof.Deleted {
optlog.Debug1("init mw: profile %s is deleted", prof.ID)
} else {
optlog.Debug3("init mw: found profile %s and device %s by %s", prof.ID, dev.ID, byWhat)
ri.Device, ri.Profile = dev, prof
ri.Messages = dnsmsg.NewConstructor(prof.BlockingMode.Mode, prof.FilteredResponseTTL)
}
return nil
}
// Constants for the parameter by which a device has been found.
const (
byDeviceID = "device id"
byDedicatedIP = "dedicated ip"
byLinkedIP = "linked ip"
)
// profile finds the profile by the client data.
func (mw *Middleware) profile(
ctx context.Context,
localIP netip.Addr,
remoteIP netip.Addr,
id agd.DeviceID,
) (prof *agd.Profile, dev *agd.Device, byWhat string, err error) {
if id != "" {
prof, dev, err = mw.db.ProfileByDeviceID(ctx, id)
if err != nil {
return nil, nil, "", err
}
return prof, dev, byDeviceID, nil
}
if !mw.srv.LinkedIPEnabled {
optlog.Debug1("init mw: not matching by linked or dedicated ip for server %s", mw.srv.Name)
return nil, nil, "", profiledb.ErrDeviceNotFound
} else if p := mw.srv.Protocol; p != agd.ProtoDNS {
optlog.Debug1("init mw: not matching by linked or dedicated ip for proto %v", p)
return nil, nil, "", profiledb.ErrDeviceNotFound
}
byWhat = byDedicatedIP
prof, dev, err = mw.db.ProfileByDedicatedIP(ctx, localIP)
if errors.Is(err, profiledb.ErrDeviceNotFound) {
byWhat = byLinkedIP
prof, dev, err = mw.db.ProfileByLinkedIP(ctx, remoteIP)
}
if err != nil {
return nil, nil, "", err
}
return prof, dev, byWhat, nil
}

Some files were not shown because too many files have changed in this diff Show More