mirror of
https://github.com/AdguardTeam/AdGuardDNS.git
synced 2025-02-20 11:23:36 +08:00
Sync v2.3
This commit is contained in:
parent
1cc340ddb1
commit
cfb4caf935
44
CHANGELOG.md
44
CHANGELOG.md
@ -11,7 +11,49 @@ The format is **not** based on [Keep a Changelog][kec], since the project
|
||||
|
||||
|
||||
|
||||
## AGDNS-1537 / Build 580
|
||||
## AGDNS-1607 / Build 617
|
||||
|
||||
* New configuration `access` has been added, it has an a list of AdBlock rules
|
||||
to block requests, and a lists of client subnets to block access from.
|
||||
Example configuration:
|
||||
|
||||
```yaml
|
||||
access:
|
||||
blocked_question_domains:
|
||||
- 'test.org'
|
||||
- '||example.org^$dnstype=AAAA'
|
||||
blocked_client_subnets:
|
||||
- '1.1.1.1'
|
||||
- '2.2.2.0/8'
|
||||
```
|
||||
|
||||
|
||||
|
||||
## AGDNS-1619 / Build 611
|
||||
|
||||
* Added a new metric `bill_stat_upload_duration` that counts the duration of
|
||||
billing statistics upload.
|
||||
* The environment variable `BILLSTAT_URL`, which describes the endpoint for
|
||||
backend billing statistics uploader API, now supports GRPC endpoints.
|
||||
|
||||
|
||||
|
||||
## AGDNS-1600 / Build 582
|
||||
|
||||
* The environment variable `PROFILES_CACHE_PATH` no longer supports JSON
|
||||
files. Use protobuf with `.pb` extension instead. The default value has
|
||||
been changed to `./profilecache.pb`.
|
||||
|
||||
|
||||
|
||||
## AGDNS-1539 / Build 581
|
||||
|
||||
* The environment variable `PROFILES_URL`, which describes the endpoint for
|
||||
profiles sync API, now supports GRPC endpoints.
|
||||
|
||||
|
||||
|
||||
## AGDNS-1579 / Build 580
|
||||
|
||||
* The optional property `bind_interfaces` of `server_groups.*.servers`
|
||||
objects has been changed, property `subnet` is now an array and has been
|
||||
|
8
Makefile
8
Makefile
@ -51,17 +51,11 @@ test: go-test
|
||||
go-bench: ; $(ENV) "$(SHELL)" ./scripts/make/go-bench.sh
|
||||
go-build: ; $(ENV) "$(SHELL)" ./scripts/make/go-build.sh
|
||||
go-deps: ; $(ENV) "$(SHELL)" ./scripts/make/go-deps.sh
|
||||
go-gen: ; $(ENV) "$(SHELL)" ./scripts/make/go-gen.sh
|
||||
go-lint: ; $(ENV) "$(SHELL)" ./scripts/make/go-lint.sh
|
||||
go-test: ; $(ENV) RACE='1' "$(SHELL)" ./scripts/make/go-test.sh
|
||||
go-tools: ; $(ENV) "$(SHELL)" ./scripts/make/go-tools.sh
|
||||
|
||||
go-gen:
|
||||
cd ./internal/agd/ && "$(GO.MACRO)" run ./country_generate.go
|
||||
cd ./internal/geoip/ && "$(GO.MACRO)" run ./asntops_generate.go
|
||||
|
||||
cd ./internal/profiledb/internal/filecachepb/ &&\
|
||||
protoc --go_opt=paths=source_relative --go_out=. ./filecache.proto
|
||||
|
||||
go-check: go-tools go-lint go-test
|
||||
|
||||
# A quick check to make sure that all operating systems relevant to the
|
||||
|
@ -52,6 +52,15 @@ ratelimit:
|
||||
stop: 1000
|
||||
resume: 800
|
||||
|
||||
# Access settings.
|
||||
access:
|
||||
# Domains to block.
|
||||
blocked_question_domains:
|
||||
- 'test.org'
|
||||
# Client subnets to block.
|
||||
blocked_client_subnets:
|
||||
- '1.2.3.0/8'
|
||||
|
||||
# DNS cache configuration.
|
||||
cache:
|
||||
# The type of cache to use. Can be 'simple' (a simple LRU cache) or 'ecs'
|
||||
|
@ -32,6 +32,7 @@ configuration file with comments.
|
||||
* [Servers](#server_groups-*-servers-*)
|
||||
* [Connectivity check](#connectivity-check)
|
||||
* [Network settings](#network)
|
||||
* [Access settings](#access)
|
||||
* [Additional metrics information](#additional_metrics_info)
|
||||
|
||||
[dist]: ../config.dist.yml
|
||||
@ -1093,6 +1094,22 @@ The `network` object has the following properties:
|
||||
|
||||
|
||||
|
||||
## <a href="#access" id="access" name="access">Access settings</a>
|
||||
|
||||
The `access` object has the following properties:
|
||||
|
||||
* <a href="#access-blocked_question_domains" id="access-blocked_question_domains" name="access-blocked_question_domains">`blocked_question_domains`</a>:
|
||||
The list of domains or AdBlock rules to block requests.
|
||||
|
||||
**Examples:** `test.org`, `||example.org^$dnstype=AAAA`.
|
||||
|
||||
* <a href="#access-blocked_client_subnets" id="access-blocked_client_subnets" name="access-blocked_client_subnets">`blocked_client_subnets`</a>:
|
||||
The list of IP addresses or CIDR-es to block.
|
||||
|
||||
**Example:** `127.0.0.1`.
|
||||
|
||||
|
||||
|
||||
## <a href="#additional_metrics_info" id="additional_metrics_info" name="additional_metrics_info">Additional metrics information</a>
|
||||
|
||||
The `additional_metrics_info` object is a map of strings with extra information
|
||||
|
@ -241,27 +241,33 @@ Examples below are for the configuration with the following changes:
|
||||
|
||||
You may also need to remove `probe_ipv6` if your network does not support IPv6.
|
||||
|
||||
If you're using an OS different from Linux, you also need to make these changes:
|
||||
|
||||
* Remove the `interface_listeners` section.
|
||||
* Remove `bind_interfaces` from the `default_dns` server configuration and
|
||||
replace it with `bind_addresses`.
|
||||
|
||||
```sh
|
||||
env \
|
||||
ADULT_BLOCKING_URL='https://raw.githubusercontent.com/ameshkov/PersonalFilters/master/adult_test.txt' \
|
||||
BILLSTAT_URL='PUT BILLSTAT API BACKEND URL HERE' \
|
||||
BLOCKED_SERVICE_INDEX_URL='https://adguardteam.github.io/HostlistsRegistry/assets/services.json'\
|
||||
CONSUL_ALLOWLIST_URL='PUT CONSUL ALLOWLIST URL HERE' \
|
||||
ADULT_BLOCKING_URL='https://raw.githubusercontent.com/ameshkov/stuff/master/DNS/adult_blocking.txt' \
|
||||
BILLSTAT_URL='https://httpbin.agrd.workers.dev/post' \
|
||||
BLOCKED_SERVICE_INDEX_URL='https://adguardteam.github.io/HostlistsRegistry/assets/services.json' \
|
||||
CONSUL_ALLOWLIST_URL='https://raw.githubusercontent.com/ameshkov/stuff/master/DNS/consul_allowlist.json' \
|
||||
CONFIG_PATH='./config.yaml' \
|
||||
FILTER_INDEX_URL='https://adguardteam.github.io/HostlistsRegistry/assets/filters.json' \
|
||||
FILTER_CACHE_PATH='./test/cache' \
|
||||
NEW_REG_DOMAINS_URL='PUT NEWLY REGISTERED DOMAINS FILTER URL HERE' \
|
||||
PROFILES_CACHE_PATH='./test/profilecache.json' \
|
||||
PROFILES_URL='PUT PROFILES API BACKEND URL HERE' \
|
||||
SAFE_BROWSING_URL='https://raw.githubusercontent.com/ameshkov/PersonalFilters/master/safebrowsing_test.txt' \
|
||||
NEW_REG_DOMAINS_URL='https://raw.githubusercontent.com/ameshkov/stuff/master/DNS/nrd.txt' \
|
||||
PROFILES_CACHE_PATH='./test/profilecache.pb' \
|
||||
PROFILES_URL='https://raw.githubusercontent.com/ameshkov/stuff/master/DNS/profiles' \
|
||||
SAFE_BROWSING_URL='https://raw.githubusercontent.com/ameshkov/stuff/master/DNS/safe_browsing.txt' \
|
||||
GENERAL_SAFE_SEARCH_URL='https://adguardteam.github.io/HostlistsRegistry/assets/engines_safe_search.txt' \
|
||||
GEOIP_ASN_PATH='./test/GeoLite2-ASN-Test.mmdb' \
|
||||
GEOIP_COUNTRY_PATH='./test/GeoIP2-City-Test.mmdb' \
|
||||
QUERYLOG_PATH='./test/cache/querylog.jsonl' \
|
||||
LINKED_IP_TARGET_URL='PUT LINKED IP TARGET URL HERE' \
|
||||
LINKED_IP_TARGET_URL='https://httpbin.agrd.workers.dev/anything' \
|
||||
LISTEN_ADDR='127.0.0.1' \
|
||||
LISTEN_PORT='8081' \
|
||||
RULESTAT_URL='https://testchrome.adtidy.org/api/1.0/rulestats.html' \
|
||||
RULESTAT_URL='https://httpbin.agrd.workers.dev/post' \
|
||||
SENTRY_DSN='https://1:1@localhost/1' \
|
||||
VERBOSE='1' \
|
||||
YOUTUBE_SAFE_SEARCH_URL='https://adguardteam.github.io/HostlistsRegistry/assets/youtube_safe_search.txt' \
|
||||
|
@ -49,9 +49,10 @@ The URL of source list of rules for adult blocking filter.
|
||||
|
||||
## <a href="#BILLSTAT_URL" id="BILLSTAT_URL" name="BILLSTAT_URL">`BILLSTAT_URL`</a>
|
||||
|
||||
The base backend URL for backend billing statistics uploader API. The backend
|
||||
endpoints must reply with a 200 status code on success. See the [external HTTP
|
||||
API requirements section][ext-billstat]
|
||||
The base backend URL for backend billing statistics uploader API. Supports HTTP
|
||||
and GRPC protocols. In case of HTTP the backend endpoint must reply with a 200
|
||||
status code on success. See the [external HTTP API requirements
|
||||
section][ext-billstat].
|
||||
|
||||
**Default:** No default value, the variable is **required.**
|
||||
|
||||
@ -227,13 +228,10 @@ The path to the profile cache file:
|
||||
< /path/to/profilecache.pb
|
||||
```
|
||||
|
||||
* A file with the extension `.json` means that the profiles are cached in the
|
||||
JSON format. This format is **deprecated** and is not recommended.
|
||||
|
||||
The profile cache is read on start and is later updated on every
|
||||
[full refresh][conf-backend-full_refresh_interval].
|
||||
|
||||
**Default:** `./profilecache.json`.
|
||||
**Default:** `./profilecache.pb`.
|
||||
|
||||
[conf-backend-full_refresh_interval]: configuration.md#backend-full_refresh_interval
|
||||
|
||||
@ -241,9 +239,10 @@ The profile cache is read on start and is later updated on every
|
||||
|
||||
## <a href="#PROFILES_URL" id="PROFILES_URL" name="PROFILES_URL">`PROFILES_URL`</a>
|
||||
|
||||
The base backend URL for profiles API. The backend endpoints must reply with a
|
||||
200 status code on success. See the [external HTTP API requirements
|
||||
section][ext-profiles].
|
||||
The base backend URL for profiles API. Supports HTTP (`http://` and `https://`)
|
||||
and GRPC (`grpc://` and `grpcs://`) URLs. In case of HTTP the backend endpoint
|
||||
must reply with a 200 status code on success. See the [external API
|
||||
requirements section][ext-profiles].
|
||||
|
||||
**Default:** No default value, the variable is **required.**
|
||||
|
||||
|
@ -29,7 +29,9 @@ document should set the `Server` header in their replies.
|
||||
## <a href="#backend-billstat" id="backend-billstat" name="backend-billstat">Backend Billing Statistics</a>
|
||||
|
||||
This is the service to which the [`BILLSTAT_URL`][env-billstat_url] environment
|
||||
variable points. This service must provide one endpoint:
|
||||
variable points. Supports `http(s):` and `grpc(s)` URLs. In case of GRPC
|
||||
protocol, the service must correspond to `./internal/backendpb/backend.proto`.
|
||||
In case of HTTP protocol this service must provide one endpoint:
|
||||
`POST /dns_api/v1/devices_activity`, it must respond with a `200 OK` response
|
||||
code and accept a JSON document in the following format:
|
||||
|
||||
@ -55,7 +57,9 @@ code and accept a JSON document in the following format:
|
||||
## <a href="#backend-profiles" id="backend-profiles" name="backend-profiles">Backend Profiles Service</a>
|
||||
|
||||
This is the service to which the [`PROFILES_URL`][env-profiles_url] environment
|
||||
variable points. This service must provide one endpoint:
|
||||
variable points. Supports `http(s):` and `grpc(s)` URLs. In case of GRPC
|
||||
protocol, the service must correspond to `./internal/backendpb/backend.proto`.
|
||||
In case of HTTP protocol this service must provide one endpoint:
|
||||
`GET /dns_api/v1/settings`, it must respond with a `200 OK` response code and
|
||||
accept a JSON document in the following format:
|
||||
|
||||
|
29
go.mod
29
go.mod
@ -4,8 +4,8 @@ go 1.20
|
||||
|
||||
require (
|
||||
github.com/AdguardTeam/AdGuardDNS/internal/dnsserver v0.100.0
|
||||
github.com/AdguardTeam/golibs v0.13.6
|
||||
github.com/AdguardTeam/urlfilter v0.16.2
|
||||
github.com/AdguardTeam/golibs v0.15.0
|
||||
github.com/AdguardTeam/urlfilter v0.17.0
|
||||
github.com/ameshkov/dnscrypt/v2 v2.2.7
|
||||
github.com/axiomhq/hyperloglog v0.0.0-20230201085229-3ddf4bad03dc
|
||||
github.com/bluele/gcache v0.0.2
|
||||
@ -19,12 +19,13 @@ require (
|
||||
github.com/prometheus/client_golang v1.15.1
|
||||
github.com/prometheus/client_model v0.4.0
|
||||
github.com/prometheus/common v0.44.0
|
||||
github.com/quic-go/quic-go v0.35.1
|
||||
github.com/quic-go/quic-go v0.38.0
|
||||
github.com/stretchr/testify v1.8.4
|
||||
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1
|
||||
golang.org/x/net v0.12.0
|
||||
golang.org/x/sys v0.10.0
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
|
||||
golang.org/x/net v0.14.0
|
||||
golang.org/x/sys v0.11.0
|
||||
golang.org/x/time v0.3.0
|
||||
google.golang.org/grpc v1.56.2
|
||||
google.golang.org/protobuf v1.30.0
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
)
|
||||
@ -40,19 +41,19 @@ require (
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||
github.com/golang/mock v1.6.0 // indirect
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect
|
||||
github.com/google/pprof v0.0.0-20230821062121-407c9e7a662f // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.10.0 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.11.0 // indirect
|
||||
github.com/panjf2000/ants/v2 v2.7.5 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/procfs v0.10.1 // indirect
|
||||
github.com/quic-go/qpack v0.4.0 // indirect
|
||||
github.com/quic-go/qtls-go1-19 v0.3.2 // indirect
|
||||
github.com/quic-go/qtls-go1-20 v0.2.2 // indirect
|
||||
golang.org/x/crypto v0.11.0 // indirect
|
||||
golang.org/x/mod v0.11.0 // indirect
|
||||
golang.org/x/text v0.11.0 // indirect
|
||||
golang.org/x/tools v0.10.0 // indirect
|
||||
github.com/quic-go/qtls-go1-20 v0.3.3 // indirect
|
||||
golang.org/x/crypto v0.12.0 // indirect
|
||||
golang.org/x/mod v0.12.0 // indirect
|
||||
golang.org/x/text v0.12.0 // indirect
|
||||
golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 // indirect
|
||||
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
|
104
go.sum
104
go.sum
@ -1,12 +1,7 @@
|
||||
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||
github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
|
||||
github.com/AdguardTeam/golibs v0.13.6 h1:z/0Q25pRLdaQxtoxvfSaooz5mdv8wj0R8KREj54q8yQ=
|
||||
github.com/AdguardTeam/golibs v0.13.6/go.mod h1:hOtcb8dPfKcFjWTPA904hTA4dl1aWvzeebdJpE72IPk=
|
||||
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
|
||||
github.com/AdguardTeam/urlfilter v0.16.2 h1:k9m9dUYVJ3sTswYa2/ukVNjicfGcz0oqFDO13hPmfHE=
|
||||
github.com/AdguardTeam/urlfilter v0.16.2/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI=
|
||||
github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDOSA=
|
||||
github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8=
|
||||
github.com/AdguardTeam/golibs v0.15.0 h1:yOv/fdVkJIOWKr0NlUXAE9RA0DK9GKiBbiGzq47vY7o=
|
||||
github.com/AdguardTeam/golibs v0.15.0/go.mod h1:66ZLs8P7nk/3IfKroQ1rqtieLk+5eXYXMBKXlVL7KeI=
|
||||
github.com/AdguardTeam/urlfilter v0.17.0 h1:tUzhtR9wMx704GIP3cibsDQJrixlMHfwoQbYJfPdFow=
|
||||
github.com/AdguardTeam/urlfilter v0.17.0/go.mod h1:bbuZjPUzm/Ip+nz5qPPbwIP+9rZyQbQad8Lt/0fCulU=
|
||||
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
|
||||
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
|
||||
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw=
|
||||
@ -27,7 +22,6 @@ github.com/caarlos0/env/v7 v7.1.0 h1:9lzTF5amyQeWHZzuZeKlCb5FWSUxpG1js43mhbY8ozg
|
||||
github.com/caarlos0/env/v7 v7.1.0/go.mod h1:LPPWniDUq4JaO6Q41vtlyikhMknqymCLBw0eX4dcH1E=
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@ -38,8 +32,7 @@ github.com/getsentry/sentry-go v0.21.0 h1:c9l5F1nPF30JIppulk4veau90PK6Smu3abgVtV
|
||||
github.com/getsentry/sentry-go v0.21.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY=
|
||||
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
|
||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||
github.com/go-ole/go-ole v1.2.5 h1:t4MGB5xEDZvXI+0rMjjsfBsD7yAgp/s9ZDkL1JndXwY=
|
||||
github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
|
||||
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
||||
@ -50,25 +43,20 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg
|
||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs=
|
||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
|
||||
github.com/google/pprof v0.0.0-20230821062121-407c9e7a662f h1:pDhu5sgp8yJlEF/g6osliIIpF9K4F5jvkULXa4daRDQ=
|
||||
github.com/google/pprof v0.0.0-20230821062121-407c9e7a662f/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik=
|
||||
github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg=
|
||||
github.com/google/renameio/v2 v2.0.0/go.mod h1:BtmJXm5YlszgC+TD4HOEEUFgkJP3nLxehU6hfe7jRt4=
|
||||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
|
||||
github.com/miekg/dns v1.1.43/go.mod h1:+evo5L0630/F6ca/Z9+GAqzhjGyn8/c+TBaOyfEl0V4=
|
||||
github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo=
|
||||
github.com/miekg/dns v1.1.55/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
||||
github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs=
|
||||
github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE=
|
||||
github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU=
|
||||
github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU=
|
||||
github.com/onsi/ginkgo/v2 v2.11.0/go.mod h1:ZhrRA5XmEE3x3rhlzamx/JJvujdZoJ2uvgI7kR0iZvM=
|
||||
github.com/onsi/gomega v1.27.8 h1:gegWiwZjBsf2DgiSbf5hpokZ98JVDMcWkUiigk6/KXc=
|
||||
github.com/oschwald/maxminddb-golang v1.10.0 h1:Xp1u0ZhqkSuopaKmk1WwHtjF0H9Hd9181uj2MQ5Vndg=
|
||||
github.com/oschwald/maxminddb-golang v1.10.0/go.mod h1:Y2ELenReaLAZ0b400URyGwvYxHV1dLIxBuyOsyYjHK0=
|
||||
github.com/panjf2000/ants/v2 v2.7.5 h1:/vhh0Hza9G1vP1PdCj9hl6MUzCRbmtcTJL0OsnmytuU=
|
||||
@ -77,9 +65,9 @@ github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible
|
||||
github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||
github.com/prometheus/client_golang v1.15.1 h1:8tXpTmJbyH5lydzFPoxSIJ0J46jdh3tylbvM1xCv0LI=
|
||||
github.com/prometheus/client_golang v1.15.1/go.mod h1:e9yaBhRPU2pPNsZwE+JdQl0KEt1N9XgF6zxWmaC0xOk=
|
||||
github.com/prometheus/client_model v0.4.0 h1:5lQXD3cAg1OXBf4Wq03gTrXHeaV0TQvGfUooCfx1yqY=
|
||||
@ -90,47 +78,40 @@ github.com/prometheus/procfs v0.10.1 h1:kYK1Va/YMlutzCGazswoHKo//tZVlFpKYh+Pymzi
|
||||
github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM=
|
||||
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
||||
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
|
||||
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U=
|
||||
github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
|
||||
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
|
||||
github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
|
||||
github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo=
|
||||
github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g=
|
||||
github.com/quic-go/qtls-go1-20 v0.3.3 h1:17/glZSLI9P9fDAeyCHBFSWSqJcwx1byhLwP5eUIDCM=
|
||||
github.com/quic-go/qtls-go1-20 v0.3.3/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
|
||||
github.com/quic-go/quic-go v0.38.0 h1:T45lASr5q/TrVwt+jrVccmqHhPL2XuSyoCLVCpfOSLc=
|
||||
github.com/quic-go/quic-go v0.38.0/go.mod h1:MPCuRq7KBK2hNcfKj/1iD1BGuN3eAYMeNxp3T42LRUg=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/shirou/gopsutil/v3 v3.21.8 h1:nKct+uP0TV8DjjNiHanKf8SAuub+GNsbrOtM9Nl9biA=
|
||||
github.com/shirou/gopsutil/v3 v3.21.8/go.mod h1:YWp/H8Qs5fVmf17v7JNZzA0mPJ+mS2e9JdiUF9LlKzQ=
|
||||
github.com/shirou/gopsutil/v3 v3.23.7 h1:C+fHO8hfIppoJ1WdsVm1RoI0RwXoNdfTK7yWXV0wVj4=
|
||||
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/tklauser/go-sysconf v0.3.9 h1:JeUVdAOWhhxVcU6Eqr/ATFHgXk/mmiItdKeJPev3vTo=
|
||||
github.com/tklauser/go-sysconf v0.3.9/go.mod h1:11DU/5sG7UexIrp/O6g35hrWzu0JxlwQ3LSFUzyeuhs=
|
||||
github.com/tklauser/numcpus v0.3.0 h1:ILuRUQBtssgnxw0XXIjKUC56fgnOrFoQQ/4+DeU2biQ=
|
||||
github.com/tklauser/numcpus v0.3.0/go.mod h1:yFGUr7TUHQRAhyqBcEg0Ge34zDBAsIvJJcyE6boqnA8=
|
||||
github.com/tklauser/go-sysconf v0.3.11 h1:89WgdJhk5SNwJfu+GKyYveZ4IaJ7xAkecBo+KdJV0CM=
|
||||
github.com/tklauser/numcpus v0.6.0 h1:kebhY2Qt+3U6RNK7UqpYNA+tJ23IBEGKkB7JQBfDYms=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA=
|
||||
golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
|
||||
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw=
|
||||
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||
golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
|
||||
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ=
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=
|
||||
golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20210929193557-e81a3d93ecf6/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50=
|
||||
golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
|
||||
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
|
||||
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@ -138,44 +119,37 @@ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210816074244-15123e1e1f71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210909193231-528a39cd75f3/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=
|
||||
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc=
|
||||
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
|
||||
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg=
|
||||
golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM=
|
||||
golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 h1:Vve/L0v7CXXuxUmaMGIEK/dEeq7uiqb5qBgQrZzIE7E=
|
||||
golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 h1:KpwkzHKEF7B9Zxg18WzOa7djJ+Ha5DzthMyZYQfEn2A=
|
||||
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU=
|
||||
google.golang.org/grpc v1.56.2 h1:fVRFRnXvU+x6C4IlHZewvJOVHoOv1TUuQyoRsYnB4bI=
|
||||
google.golang.org/grpc v1.56.2/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
||||
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
52
go.work.sum
52
go.work.sum
@ -119,7 +119,6 @@ github.com/go-kit/log v0.2.0 h1:7i2K3eKTos3Vc0enKCfnVcgHh2olr/MyfboYq7cAcFw=
|
||||
github.com/go-kit/log v0.2.1 h1:MRVx0/zhvdseW+Gza6N9rVzU/IVzaeE1SFI4raAhmBU=
|
||||
github.com/go-kit/log v0.2.1/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0=
|
||||
github.com/go-logfmt/logfmt v0.5.1 h1:otpy5pqBCBZ1ng9RQ0dPu4PN7ba75Y/aA+UpowDyNVA=
|
||||
github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
|
||||
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab h1:xveKWz2iaueeTaUgdetzel+U7exyigDYBryyVfV/rZk=
|
||||
@ -150,6 +149,7 @@ github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW
|
||||
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||
github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
|
||||
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY=
|
||||
github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ=
|
||||
@ -200,14 +200,11 @@ github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c
|
||||
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||
github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA=
|
||||
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
|
||||
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
|
||||
github.com/jstemmer/go-junit-report v0.9.1 h1:6QPYqodiu3GuPL+7mfx+NwDdp2eTkp9IfEUpgAwUN0o=
|
||||
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
|
||||
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
|
||||
github.com/kataras/blocks v0.0.7 h1:cF3RDY/vxnSRezc7vLFlQFTYXG/yAr1o7WImJuZbzC4=
|
||||
github.com/kataras/blocks v0.0.7/go.mod h1:UJIU97CluDo0f+zEjbnbkeMRlvYORtmc1304EeyXf4I=
|
||||
github.com/kataras/golog v0.1.7 h1:0TY5tHn5L5DlRIikepcaRR/6oInIr9AiWsxzt0vvlBE=
|
||||
@ -234,6 +231,7 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJ
|
||||
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515 h1:T+h1c/A9Gawja4Y9mFVWj2vyii2bbUNDw3kt9VxK2EY=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/pty v1.1.1 h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw=
|
||||
github.com/kr/pty v1.1.3 h1:/Um6a/ZmD5tF7peoOJ5oN5KMQ0DrGVQSXLNwyckutPk=
|
||||
@ -283,12 +281,9 @@ github.com/microcosm-cc/bluemonday v1.0.23/go.mod h1:mN70sk7UkkF8TUr2IGBpNN0jAgS
|
||||
github.com/miekg/dns v1.1.47/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f h1:KUppIJq7/+SVif2QVs3tOP0zanoHgBEVAwHxUSIzRqU=
|
||||
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86 h1:D6paGObi5Wud7xg83MaEFyjxQB1W5bz5d0IFppr+ymk=
|
||||
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
|
||||
github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab h1:eFXv9Nu1lGbrNbj619aWwZfVF5HBrm9Plte8aNptuTI=
|
||||
@ -296,13 +291,18 @@ github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
|
||||
github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk=
|
||||
github.com/onsi/ginkgo/v2 v2.8.1/go.mod h1:N1/NbDngAFcSLdyZ+/aYTYGSlq9qMCS/cNKGJjy+csc=
|
||||
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
|
||||
github.com/onsi/ginkgo/v2 v2.11.0/go.mod h1:ZhrRA5XmEE3x3rhlzamx/JJvujdZoJ2uvgI7kR0iZvM=
|
||||
github.com/onsi/gomega v1.20.1/go.mod h1:DtrZpjmvpn2mPm4YWQa0/ALMDj9v4YxLgojwPeREyVo=
|
||||
github.com/onsi/gomega v1.22.1/go.mod h1:x6n7VNe4hw0vkyYUM4mjIXx3JbLiPaBPNgB7PRQ1tuM=
|
||||
github.com/onsi/gomega v1.24.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2vQAg=
|
||||
github.com/onsi/gomega v1.27.1/go.mod h1:aHX5xOykVYzWOV4WqQy0sy8BQptgukenXpCXfadcIAw=
|
||||
github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
|
||||
github.com/onsi/gomega v1.27.7/go.mod h1:1p8OOlwo2iUUDsHnOrjE5UKYJ+e3W8eQ3qSlRahPmr4=
|
||||
github.com/openzipkin/zipkin-go v0.1.1 h1:A/ADD6HaPnAKj3yS7HjGHRK77qi41Hi0DirOOIQAeIw=
|
||||
github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8=
|
||||
github.com/oschwald/maxminddb-golang v1.10.0 h1:Xp1u0ZhqkSuopaKmk1WwHtjF0H9Hd9181uj2MQ5Vndg=
|
||||
github.com/oschwald/maxminddb-golang v1.10.0/go.mod h1:Y2ELenReaLAZ0b400URyGwvYxHV1dLIxBuyOsyYjHK0=
|
||||
github.com/pelletier/go-toml/v2 v2.0.5 h1:ipoSadvV8oGUjnUbMub59IDPPwfxF694nG/jwbMiyQg=
|
||||
github.com/pelletier/go-toml/v2 v2.0.5/go.mod h1:OMHamSCAODeSsVrwwvcJOaoN0LIUIaFVNZzmWyNfXas=
|
||||
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
|
||||
@ -313,6 +313,8 @@ github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:
|
||||
github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro=
|
||||
github.com/prometheus/common v0.42.0/go.mod h1:xBwqVerjNdUDjgODMpudtOMwlOwf2SaTr1yjz4b7Zbc=
|
||||
github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
||||
github.com/quic-go/qtls-go1-20 v0.3.3/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
|
||||
github.com/quic-go/quic-go v0.38.0/go.mod h1:MPCuRq7KBK2hNcfKj/1iD1BGuN3eAYMeNxp3T42LRUg=
|
||||
github.com/rogpeppe/go-internal v1.3.0 h1:RR9dF3JtopPvtkroDZuVD7qquD0bnHlKSqaQhgwt8yk=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
@ -324,6 +326,8 @@ github.com/schollz/closestmatch v2.1.0+incompatible h1:Uel2GXEpJqOWBrlyI+oY9LTiy
|
||||
github.com/schollz/closestmatch v2.1.0+incompatible/go.mod h1:RtP1ddjLong6gTkbtmuhtR2uUrrJOpYzYRvbcPAid+g=
|
||||
github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ=
|
||||
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
|
||||
github.com/shirou/gopsutil/v3 v3.23.7 h1:C+fHO8hfIppoJ1WdsVm1RoI0RwXoNdfTK7yWXV0wVj4=
|
||||
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
|
||||
github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4 h1:Fth6mevc5rX7glNLpbAMJnqKlfIkcTjZCSHEeqvKbcI=
|
||||
github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY=
|
||||
github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48 h1:vabduItPAIz9px5iryD5peyx7O3Ya8TBThapgXim98o=
|
||||
@ -388,6 +392,7 @@ github.com/tdewolff/minify/v2 v2.12.4 h1:kejsHQMM17n6/gwdw53qsi6lg0TGddZADVyQOz1
|
||||
github.com/tdewolff/minify/v2 v2.12.4/go.mod h1:h+SRvSIX3kwgwTFOpSckvSxgax3uy8kZTSF1Ojrr3bk=
|
||||
github.com/tdewolff/parse/v2 v2.6.4 h1:KCkDvNUMof10e3QExio9OPZJT8SbdKojLBumw8YZycQ=
|
||||
github.com/tdewolff/parse/v2 v2.6.4/go.mod h1:woz0cgbLwFdtbjJu8PIKxhW05KplTFQkOdX78o+Jgrs=
|
||||
github.com/tklauser/go-sysconf v0.3.11 h1:89WgdJhk5SNwJfu+GKyYveZ4IaJ7xAkecBo+KdJV0CM=
|
||||
github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0=
|
||||
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
|
||||
github.com/urfave/negroni v1.0.0 h1:kIimOitoypq34K7TG7DUaJ9kq/N4Ofuwi1sjz0KipXc=
|
||||
@ -429,10 +434,13 @@ golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPh
|
||||
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80=
|
||||
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20221019170559-20944726eadf/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
|
||||
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
|
||||
golang.org/x/exp v0.0.0-20230306221820-f0f767cdffd6/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
|
||||
golang.org/x/exp v0.0.0-20230807204917-050eac23e9de/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8=
|
||||
golang.org/x/image v0.0.0-20190802002840-cff245a6509b h1:+qEpEAPhDZ1o0x3tHzZTQDArnOixOzGD9HUJfcg0mb4=
|
||||
golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
@ -441,6 +449,9 @@ golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7
|
||||
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 h1:4+4C/Iv2U4fMZBiMCc98MG1In4gJY5YRhtpDNeDeHWs=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
|
||||
golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@ -456,6 +467,10 @@ golang.org/x/net v0.0.0-20220516155154-20f960328961/go.mod h1:CfG3xpIq0wQ8r1q4Su
|
||||
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
|
||||
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ=
|
||||
golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
|
||||
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
@ -481,9 +496,14 @@ golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 h1:v6hYoSR9T5oet+pMXwUWkbiVqx/63mlHjefrHmxwfeY=
|
||||
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
|
||||
golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw=
|
||||
golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI=
|
||||
@ -491,15 +511,27 @@ golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA=
|
||||
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
|
||||
golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c=
|
||||
golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
|
||||
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
|
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
|
||||
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
|
||||
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA=
|
||||
golang.org/x/tools v0.3.0/go.mod h1:/rWhSS2+zyEVwoJf8YAX6L2f0ntZ7Kn/mGgAWcipA5k=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
|
||||
golang.org/x/tools v0.9.3/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
|
||||
golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
||||
golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0=
|
||||
@ -519,11 +551,15 @@ google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoA
|
||||
google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg=
|
||||
google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
|
||||
google.golang.org/genproto v0.0.0-20200825200019-8632dd797987 h1:PDIOdWxZ8eRizhKa1AAvY53xsvLB1cWorMjslvY3VA8=
|
||||
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 h1:KpwkzHKEF7B9Zxg18WzOa7djJ+Ha5DzthMyZYQfEn2A=
|
||||
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU=
|
||||
google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
|
||||
google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio=
|
||||
google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
google.golang.org/grpc v1.31.0 h1:T7P4R73V3SSDPhH7WW7ATbfViLtmamH0DKrP3f9AuDI=
|
||||
google.golang.org/grpc v1.56.2 h1:fVRFRnXvU+x6C4IlHZewvJOVHoOv1TUuQyoRsYnB4bI=
|
||||
google.golang.org/grpc v1.56.2/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s=
|
||||
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6 h1:jMFz6MfLP0/4fUyZle81rXUoxOBFi19VUFKVDOQfozc=
|
||||
@ -536,6 +572,8 @@ gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
|
||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
grpc.go4.org v0.0.0-20170609214715-11d0a25b4919 h1:tmXTu+dfa+d9Evp8NpJdgOy6+rt8/x4yG7qPBrtNfLY=
|
||||
grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o=
|
||||
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
|
108
internal/access/access.go
Normal file
108
internal/access/access.go
Normal file
@ -0,0 +1,108 @@
|
||||
// Package access contains structures for access control management.
|
||||
package access
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
)
|
||||
|
||||
// unit is a convenient alias for struct{}
|
||||
type unit = struct{}
|
||||
|
||||
// Interface is the access manager interface.
|
||||
type Interface interface {
|
||||
// IsBlockedHost returns true if host should be blocked.
|
||||
IsBlockedHost(host string, qt uint16) (blocked bool)
|
||||
|
||||
// IsBlockedIP returns the status of the IP address blocking as well as the
|
||||
// rule that blocked it.
|
||||
IsBlockedIP(ip netip.Addr) (blocked bool, rule string)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ Interface = (*Manager)(nil)
|
||||
|
||||
// Manager controls IP and client blocking that takes place before all
|
||||
// other processing. An Manager is safe for concurrent use.
|
||||
type Manager struct {
|
||||
blockedIPs map[netip.Addr]unit
|
||||
blockedHostsEng *urlfilter.DNSEngine
|
||||
blockedNets []netip.Prefix
|
||||
}
|
||||
|
||||
// New create an Manager. The parameters assumed to be valid.
|
||||
func New(blockedDomains, blockedSubnets []string) (am *Manager, err error) {
|
||||
am = &Manager{
|
||||
blockedIPs: map[netip.Addr]unit{},
|
||||
}
|
||||
|
||||
processAccessList(blockedSubnets, am.blockedIPs, &am.blockedNets)
|
||||
|
||||
b := &strings.Builder{}
|
||||
for _, h := range blockedDomains {
|
||||
stringutil.WriteToBuilder(b, strings.ToLower(h), "\n")
|
||||
}
|
||||
|
||||
lists := []filterlist.RuleList{
|
||||
&filterlist.StringRuleList{
|
||||
ID: 0,
|
||||
RulesText: b.String(),
|
||||
IgnoreCosmetic: true,
|
||||
},
|
||||
}
|
||||
|
||||
rulesStrg, err := filterlist.NewRuleStorage(lists)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("adding blocked hosts: %w", err)
|
||||
}
|
||||
|
||||
am.blockedHostsEng = urlfilter.NewDNSEngine(rulesStrg)
|
||||
|
||||
return am, nil
|
||||
}
|
||||
|
||||
// processAccessList is a helper for processing a list of strings, each of them
|
||||
// assumed be a valid IP address or a valid CIDR.
|
||||
func processAccessList(strs []string, ips map[netip.Addr]unit, nets *[]netip.Prefix) {
|
||||
for _, s := range strs {
|
||||
var err error
|
||||
var ip netip.Addr
|
||||
var ipnet netip.Prefix
|
||||
if ip, err = netip.ParseAddr(s); err == nil {
|
||||
ips[ip] = unit{}
|
||||
} else if ipnet, err = netip.ParsePrefix(s); err == nil {
|
||||
*nets = append(*nets, ipnet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsBlockedHost returns true if host should be blocked.
|
||||
func (am *Manager) IsBlockedHost(host string, qt uint16) (blocked bool) {
|
||||
_, blocked = am.blockedHostsEng.MatchRequest(&urlfilter.DNSRequest{
|
||||
Hostname: host,
|
||||
DNSType: qt,
|
||||
})
|
||||
|
||||
return blocked
|
||||
}
|
||||
|
||||
// IsBlockedIP returns the status of the IP address blocking as well as the rule
|
||||
// that blocked it.
|
||||
func (am *Manager) IsBlockedIP(ip netip.Addr) (blocked bool, rule string) {
|
||||
if _, ok := am.blockedIPs[ip]; ok {
|
||||
return true, ip.String()
|
||||
}
|
||||
|
||||
for _, ipnet := range am.blockedNets {
|
||||
if ipnet.Contains(ip) {
|
||||
return true, ipnet.String()
|
||||
}
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
107
internal/access/access_test.go
Normal file
107
internal/access/access_test.go
Normal file
@ -0,0 +1,107 @@
|
||||
package access_test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/access"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccessManager_IsBlockedHost(t *testing.T) {
|
||||
am, err := access.New([]string{
|
||||
"block.test",
|
||||
"UPPERCASE.test",
|
||||
"||block_aaaa.test^$dnstype=AAAA",
|
||||
}, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
want assert.BoolAssertionFunc
|
||||
name string
|
||||
host string
|
||||
qt uint16
|
||||
}{{
|
||||
want: assert.False,
|
||||
name: "pass",
|
||||
host: "pass.test",
|
||||
qt: dns.TypeA,
|
||||
}, {
|
||||
want: assert.True,
|
||||
name: "blocked_domain_A",
|
||||
host: "block.test",
|
||||
qt: dns.TypeA,
|
||||
}, {
|
||||
want: assert.True,
|
||||
name: "blocked_domain_HTTPS",
|
||||
host: "block.test",
|
||||
qt: dns.TypeHTTPS,
|
||||
}, {
|
||||
want: assert.True,
|
||||
name: "uppercase_domain",
|
||||
host: "uppercase.test",
|
||||
qt: dns.TypeHTTPS,
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "pass_qt",
|
||||
host: "block_aaaa.test",
|
||||
qt: dns.TypeA,
|
||||
}, {
|
||||
want: assert.True,
|
||||
name: "block_qt",
|
||||
host: "block_aaaa.test",
|
||||
qt: dns.TypeAAAA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
blocked := am.IsBlockedHost(tc.host, tc.qt)
|
||||
tc.want(t, blocked)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessManager_IsBlockedIP(t *testing.T) {
|
||||
am, err := access.New([]string{}, []string{
|
||||
"1.1.1.1",
|
||||
"2.2.2.0/8",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
want assert.BoolAssertionFunc
|
||||
ip netip.Addr
|
||||
wantRule string
|
||||
name string
|
||||
}{{
|
||||
want: assert.False,
|
||||
wantRule: "",
|
||||
name: "pass",
|
||||
ip: netip.MustParseAddr("1.1.1.0"),
|
||||
}, {
|
||||
want: assert.True,
|
||||
wantRule: "1.1.1.1",
|
||||
name: "block_ip",
|
||||
ip: netip.MustParseAddr("1.1.1.1"),
|
||||
}, {
|
||||
want: assert.False,
|
||||
wantRule: "",
|
||||
name: "pass_subnet",
|
||||
ip: netip.MustParseAddr("1.2.2.2"),
|
||||
}, {
|
||||
want: assert.True,
|
||||
wantRule: "2.2.2.0/8",
|
||||
name: "block_subnet",
|
||||
ip: netip.MustParseAddr("2.2.2.2"),
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
blocked, rule := am.IsBlockedIP(tc.ip)
|
||||
tc.want(t, blocked)
|
||||
assert.Equal(t, tc.wantRule, rule)
|
||||
})
|
||||
}
|
||||
}
|
@ -6,6 +6,7 @@ import (
|
||||
"encoding/csv"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
@ -40,10 +41,10 @@ func main() {
|
||||
// Skip the first row, as it is a header.
|
||||
rows = rows[1:]
|
||||
|
||||
// Sort by the code to make the output more predictable and easier to look
|
||||
// through.
|
||||
slices.SortFunc(rows, func(a, b []string) (less bool) {
|
||||
return a[1] < b[1]
|
||||
slices.SortFunc(rows, func(a, b []string) (res int) {
|
||||
// Sort by the code to make the output more predictable and easier to
|
||||
// look through.
|
||||
return strings.Compare(a[1], b[1])
|
||||
})
|
||||
|
||||
tmpl, err := template.New("main").Parse(tmplStr)
|
||||
|
@ -74,3 +74,10 @@ func ParseSubnets(strs ...string) (subnets []netip.Prefix, err error) {
|
||||
|
||||
return subnets, nil
|
||||
}
|
||||
|
||||
// NormalizeDomain returns lowercased version of the host without the final dot.
|
||||
//
|
||||
// TODO(a.garipov): Move to golibs.
|
||||
func NormalizeDomain(fqdn string) (host string) {
|
||||
return strings.ToLower(strings.TrimSuffix(fqdn, "."))
|
||||
}
|
||||
|
19
internal/agdnet/prefixaddr.go
Normal file
19
internal/agdnet/prefixaddr.go
Normal file
@ -0,0 +1,19 @@
|
||||
package agdnet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// FormatPrefixAddr returns either a simple IP:port address or one with the
|
||||
// prefix length appended after a slash, depending on whether or not subnet is a
|
||||
// single-address subnet. This is done to make using the IP:port part easier to
|
||||
// split off using functions like [strings.Cut].
|
||||
func FormatPrefixAddr(subnet netip.Prefix, port uint16) (s string) {
|
||||
addrPort := netip.AddrPortFrom(subnet.Addr(), port)
|
||||
if subnet.IsSingleIP() {
|
||||
return addrPort.String()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/%d", addrPort, subnet.Bits())
|
||||
}
|
17
internal/agdnet/prefixaddr_example_test.go
Normal file
17
internal/agdnet/prefixaddr_example_test.go
Normal file
@ -0,0 +1,17 @@
|
||||
package agdnet_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
|
||||
)
|
||||
|
||||
func ExampeFormatPrefixAddr() {
|
||||
fmt.Println(agdnet.FormatPrefixAddr(netip.MustParsePrefix("1.2.3.4/32"), 5678))
|
||||
fmt.Println(agdnet.FormatPrefixAddr(netip.MustParsePrefix("1.2.3.0/24"), 5678))
|
||||
|
||||
// Output:
|
||||
// 1.2.3.4:5678
|
||||
// 1.2.3.0:5678/24
|
||||
}
|
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -19,7 +20,14 @@ import (
|
||||
//
|
||||
// See go doc net.Resolver.
|
||||
type Resolver interface {
|
||||
LookupIP(ctx context.Context, fam netutil.AddrFamily, host string) (ips []net.IP, err error)
|
||||
// LookupNetIP returns a slice of host's IP addresses of family specified by
|
||||
// fam, which must be either [netutil.AddrFamilyIPv4] or
|
||||
// [netutil.AddrFamilyIPv6].
|
||||
LookupNetIP(
|
||||
ctx context.Context,
|
||||
fam netutil.AddrFamily,
|
||||
host string,
|
||||
) (ips []netip.Addr, err error)
|
||||
}
|
||||
|
||||
// DefaultResolver uses [net.DefaultResolver] to resolve addresses.
|
||||
@ -28,17 +36,17 @@ type DefaultResolver struct{}
|
||||
// type check
|
||||
var _ Resolver = DefaultResolver{}
|
||||
|
||||
// LookupIP implements the [Resolver] interface for DefaultResolver.
|
||||
func (DefaultResolver) LookupIP(
|
||||
// LookupNetIP implements the [Resolver] interface for DefaultResolver.
|
||||
func (DefaultResolver) LookupNetIP(
|
||||
ctx context.Context,
|
||||
fam netutil.AddrFamily,
|
||||
host string,
|
||||
) (ips []net.IP, err error) {
|
||||
) (ips []netip.Addr, err error) {
|
||||
switch fam {
|
||||
case netutil.AddrFamilyIPv4:
|
||||
return net.DefaultResolver.LookupIP(ctx, "ip4", host)
|
||||
return net.DefaultResolver.LookupNetIP(ctx, "ip4", host)
|
||||
case netutil.AddrFamilyIPv6:
|
||||
return net.DefaultResolver.LookupIP(ctx, "ip6", host)
|
||||
return net.DefaultResolver.LookupNetIP(ctx, "ip6", host)
|
||||
default:
|
||||
return nil, net.UnknownNetworkError(fam.String())
|
||||
}
|
||||
@ -50,7 +58,7 @@ type resolveCache map[string]*resolveCacheItem
|
||||
// resolveCacheItem is an item of [resolveCache].
|
||||
type resolveCacheItem struct {
|
||||
refrTime time.Time
|
||||
ips []net.IP
|
||||
ips []netip.Addr
|
||||
}
|
||||
|
||||
// CachingResolver caches resolved results for hosts for a certain time,
|
||||
@ -83,13 +91,13 @@ func NewCachingResolver(resolver Resolver, ttl time.Duration) (c *CachingResolve
|
||||
// type check
|
||||
var _ Resolver = (*CachingResolver)(nil)
|
||||
|
||||
// LookupIP implements the [Resolver] interface for *CachingResolver. host
|
||||
// LookupNetIP implements the [Resolver] interface for *CachingResolver. host
|
||||
// should be normalized. Slice ips and its elements must not be mutated.
|
||||
func (c *CachingResolver) LookupIP(
|
||||
func (c *CachingResolver) LookupNetIP(
|
||||
ctx context.Context,
|
||||
fam netutil.AddrFamily,
|
||||
host string,
|
||||
) (ips []net.IP, err error) {
|
||||
) (ips []netip.Addr, err error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
@ -119,20 +127,20 @@ func (c *CachingResolver) resolve(
|
||||
fam netutil.AddrFamily,
|
||||
host string,
|
||||
) (item *resolveCacheItem, err error) {
|
||||
var ips []net.IP
|
||||
var ips []netip.Addr
|
||||
|
||||
refrTime := time.Now()
|
||||
|
||||
// Don't resolve IP addresses.
|
||||
ip := ipFromHost(host, fam)
|
||||
if ip != nil {
|
||||
ips = []net.IP{ip}
|
||||
if ip != (netip.Addr{}) {
|
||||
ips = []netip.Addr{ip}
|
||||
|
||||
// Set the refresh time to the maximum date that time.Duration allows to
|
||||
// prevent this item from refreshing.
|
||||
refrTime = time.Unix(0, math.MaxInt64)
|
||||
} else {
|
||||
ips, err = c.resolver.LookupIP(ctx, fam, host)
|
||||
ips, err = c.resolver.LookupNetIP(ctx, fam, host)
|
||||
if err != nil {
|
||||
if !isExpectedLookupError(fam, err) {
|
||||
return nil, fmt.Errorf("resolving %s addr for %q: %w", fam, host, err)
|
||||
@ -159,22 +167,25 @@ func (c *CachingResolver) resolve(
|
||||
return item, nil
|
||||
}
|
||||
|
||||
// ipFromHost returns a normalized IP address if host contains an IP address of
|
||||
// the given address family.
|
||||
func ipFromHost(host string, fam netutil.AddrFamily) (ip net.IP) {
|
||||
ip = net.ParseIP(host)
|
||||
if ip == nil {
|
||||
return nil
|
||||
// ipFromHost parses host as if it'd be an IP address of specified fam. It
|
||||
// returns an empty netip.
|
||||
func ipFromHost(host string, fam netutil.AddrFamily) (ip netip.Addr) {
|
||||
var famFunc func(netip.Addr) (ok bool)
|
||||
switch fam {
|
||||
case netutil.AddrFamilyIPv4:
|
||||
famFunc = netip.Addr.Is4
|
||||
case netutil.AddrFamilyIPv6:
|
||||
famFunc = netip.Addr.Is6
|
||||
default:
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
ip4 := ip.To4()
|
||||
if fam == netutil.AddrFamilyIPv4 && ip4 != nil {
|
||||
return ip4
|
||||
} else if fam == netutil.AddrFamilyIPv6 && ip4 == nil {
|
||||
return ip
|
||||
ip, err := netip.ParseAddr(host)
|
||||
if err != nil || !famFunc(ip) {
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
return nil
|
||||
return ip
|
||||
}
|
||||
|
||||
// isExpectedLookupError returns true if the error is an expected lookup error.
|
||||
|
@ -2,7 +2,7 @@ package agdnet_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
|
||||
@ -17,14 +17,14 @@ func TestCachingResolver_Resolve(t *testing.T) {
|
||||
const testHost = "addr.example"
|
||||
|
||||
var numLookups uint64
|
||||
wantIPv4 := []net.IP{{1, 2, 3, 4}}
|
||||
wantIPv6 := []net.IP{net.ParseIP("1234::5678")}
|
||||
wantIPv4 := []netip.Addr{netip.MustParseAddr("1.2.3.4")}
|
||||
wantIPv6 := []netip.Addr{netip.MustParseAddr("1234::5678")}
|
||||
r := &agdtest.Resolver{
|
||||
OnLookupIP: func(
|
||||
_ context.Context,
|
||||
OnLookupNetIP: func(
|
||||
ctx context.Context,
|
||||
fam netutil.AddrFamily,
|
||||
_ string,
|
||||
) (ips []net.IP, err error) {
|
||||
host string,
|
||||
) (ips []netip.Addr, err error) {
|
||||
numLookups++
|
||||
|
||||
if fam == netutil.AddrFamilyIPv4 {
|
||||
@ -40,7 +40,7 @@ func TestCachingResolver_Resolve(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
wantIPs []net.IP
|
||||
wantIPs []netip.Addr
|
||||
wantNum uint64
|
||||
fam netutil.AddrFamily
|
||||
}{{
|
||||
@ -78,7 +78,7 @@ func TestCachingResolver_Resolve(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := cached.LookupIP(ctx, tc.fam, tc.host)
|
||||
got, err := cached.LookupNetIP(ctx, tc.fam, tc.host)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantNum, numLookups)
|
||||
|
27
internal/agdprotobuf/pbutil.go
Normal file
27
internal/agdprotobuf/pbutil.go
Normal file
@ -0,0 +1,27 @@
|
||||
// Package agdprotobuf contains protobuf utils.
|
||||
package agdprotobuf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// ByteSlicesToIPs converts a slice of byte slices into a slice of netip.Addrs.
|
||||
func ByteSlicesToIPs(data [][]byte) (ips []netip.Addr, err error) {
|
||||
if data == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ips = make([]netip.Addr, 0, len(data))
|
||||
for i, ipData := range data {
|
||||
var ip netip.Addr
|
||||
err = ip.UnmarshalBinary(ipData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ip at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
37
internal/agdsync/agdsync.go
Normal file
37
internal/agdsync/agdsync.go
Normal file
@ -0,0 +1,37 @@
|
||||
// Package agdsync contains extensions and utilities for package sync from the
|
||||
// standard library.
|
||||
//
|
||||
// TODO(a.garipov): Move to module golibs.
|
||||
package agdsync
|
||||
|
||||
import "sync"
|
||||
|
||||
// TypedPool is the strongly typed version of [sync.Pool] that manages pointers
|
||||
// to T.
|
||||
type TypedPool[T any] struct {
|
||||
pool *sync.Pool
|
||||
}
|
||||
|
||||
// NewTypedPool returns a new strongly typed pool. newFunc must not be nil.
|
||||
func NewTypedPool[T any](newFunc func() (v *T)) (p *TypedPool[T]) {
|
||||
return &TypedPool[T]{
|
||||
pool: &sync.Pool{
|
||||
New: func() (v any) { return newFunc() },
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Get selects an arbitrary item from the pool, removes it from the pool, and
|
||||
// returns it to the caller.
|
||||
//
|
||||
// See [sync.Pool.Get].
|
||||
func (p *TypedPool[T]) Get() (v *T) {
|
||||
return p.pool.Get().(*T)
|
||||
}
|
||||
|
||||
// Put adds v to the pool.
|
||||
//
|
||||
// See [sync.Pool.Put].
|
||||
func (p *TypedPool[T]) Put(v *T) {
|
||||
p.pool.Put(v)
|
||||
}
|
@ -6,6 +6,7 @@ import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/access"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
|
||||
@ -56,6 +57,27 @@ func (r *Refresher) Refresh(ctx context.Context) (err error) {
|
||||
return r.OnRefresh(ctx)
|
||||
}
|
||||
|
||||
// Package access
|
||||
|
||||
// type check
|
||||
var _ access.Interface = (*AccessManager)(nil)
|
||||
|
||||
// AccessManager is a [access.Interface] for tests.
|
||||
type AccessManager struct {
|
||||
OnIsBlockedHost func(host string, qt uint16) (blocked bool)
|
||||
OnIsBlockedIP func(ip netip.Addr) (blocked bool, rule string)
|
||||
}
|
||||
|
||||
// IsBlockedHost implements the [access.Interface] interface for *AccessManager.
|
||||
func (a *AccessManager) IsBlockedHost(host string, qt uint16) (blocked bool) {
|
||||
return a.OnIsBlockedHost(host, qt)
|
||||
}
|
||||
|
||||
// IsBlockedIP implements the [access.Interface] interface for *AccessManager.
|
||||
func (a *AccessManager) IsBlockedIP(ip netip.Addr) (blocked bool, rule string) {
|
||||
return a.OnIsBlockedIP(ip)
|
||||
}
|
||||
|
||||
// Package agdnet
|
||||
|
||||
// type check
|
||||
@ -63,20 +85,20 @@ var _ agdnet.Resolver = (*Resolver)(nil)
|
||||
|
||||
// Resolver is an agd.Resolver for tests.
|
||||
type Resolver struct {
|
||||
OnLookupIP func(
|
||||
OnLookupNetIP func(
|
||||
ctx context.Context,
|
||||
fam netutil.AddrFamily,
|
||||
host string,
|
||||
) (ips []net.IP, err error)
|
||||
) (ips []netip.Addr, err error)
|
||||
}
|
||||
|
||||
// LookupIP implements the agd.Resolver interface for *Resolver.
|
||||
func (r *Resolver) LookupIP(
|
||||
// LookupNetIP implements the [agd.Resolver] interface for *Resolver.
|
||||
func (r *Resolver) LookupNetIP(
|
||||
ctx context.Context,
|
||||
fam netutil.AddrFamily,
|
||||
host string,
|
||||
) (ips []net.IP, err error) {
|
||||
return r.OnLookupIP(ctx, fam, host)
|
||||
) (ips []netip.Addr, err error) {
|
||||
return r.OnLookupNetIP(ctx, fam, host)
|
||||
}
|
||||
|
||||
// Package billstat
|
||||
|
1464
internal/backendpb/backend.pb.go
Normal file
1464
internal/backendpb/backend.pb.go
Normal file
File diff suppressed because it is too large
Load Diff
124
internal/backendpb/backend.proto
Normal file
124
internal/backendpb/backend.proto
Normal file
@ -0,0 +1,124 @@
|
||||
syntax = "proto3";
|
||||
|
||||
option go_package = "./backendpb";
|
||||
|
||||
import "google/protobuf/duration.proto";
|
||||
import "google/protobuf/timestamp.proto";
|
||||
import "google/protobuf/empty.proto";
|
||||
|
||||
option java_multiple_files = true;
|
||||
option java_package = "com.adguard.backend.dns.generated";
|
||||
option java_outer_classname = "DNSProfilesProto";
|
||||
option objc_class_prefix = "DNS";
|
||||
|
||||
service DNSService {
|
||||
|
||||
/*
|
||||
Gets DNS profiles.
|
||||
|
||||
Field "sync_time" in DNSProfilesRequest - pass to return the latest updates after this time moment.
|
||||
|
||||
The trailers headers will include a "sync_time", given in milliseconds,
|
||||
that should be used for subsequent incremental DNS profile synchronization requests.
|
||||
*/
|
||||
rpc getDNSProfiles(DNSProfilesRequest) returns (stream DNSProfile);
|
||||
|
||||
/*
|
||||
Stores devices activity.
|
||||
*/
|
||||
rpc saveDevicesBillingStat(stream DeviceBillingStat) returns (google.protobuf.Empty);
|
||||
}
|
||||
|
||||
message DNSProfilesRequest {
|
||||
google.protobuf.Timestamp sync_time = 1;
|
||||
}
|
||||
|
||||
message DNSProfile {
|
||||
string dns_id = 1;
|
||||
bool filtering_enabled = 2;
|
||||
bool query_log_enabled = 3;
|
||||
bool deleted = 4;
|
||||
SafeBrowsingSettings safe_browsing = 5;
|
||||
ParentalSettings parental = 6;
|
||||
RuleListsSettings rule_lists = 7;
|
||||
repeated DeviceSettings devices = 8;
|
||||
repeated string custom_rules = 9;
|
||||
google.protobuf.Duration filtered_response_ttl = 10;
|
||||
bool block_private_relay = 11;
|
||||
bool block_firefox_canary = 12;
|
||||
oneof blocking_mode {
|
||||
BlockingModeCustomIP blocking_mode_custom_ip = 13;
|
||||
BlockingModeNXDOMAIN blocking_mode_nxdomain = 14;
|
||||
BlockingModeNullIP blocking_mode_null_ip = 15;
|
||||
BlockingModeREFUSED blocking_mode_refused = 16;
|
||||
}
|
||||
}
|
||||
|
||||
message SafeBrowsingSettings {
|
||||
bool enabled = 1;
|
||||
bool block_dangerous_domains = 2;
|
||||
bool block_nrd = 3;
|
||||
}
|
||||
|
||||
message DeviceSettings {
|
||||
string id = 1;
|
||||
string name = 2;
|
||||
bool filtering_enabled = 3;
|
||||
bytes linked_ip = 4;
|
||||
repeated bytes dedicated_ips = 5;
|
||||
}
|
||||
|
||||
message ParentalSettings {
|
||||
bool enabled = 1;
|
||||
bool block_adult = 2;
|
||||
bool general_safe_search = 3;
|
||||
bool youtube_safe_search = 4;
|
||||
repeated string blocked_services = 5;
|
||||
ScheduleSettings schedule = 6;
|
||||
}
|
||||
|
||||
message ScheduleSettings {
|
||||
string tmz = 1;
|
||||
WeeklyRange weeklyRange = 2;
|
||||
}
|
||||
|
||||
message WeeklyRange {
|
||||
DayRange mon = 1;
|
||||
DayRange tue = 2;
|
||||
DayRange wed = 3;
|
||||
DayRange thu = 4;
|
||||
DayRange fri = 5;
|
||||
DayRange sat = 6;
|
||||
DayRange sun = 7;
|
||||
}
|
||||
|
||||
message DayRange {
|
||||
google.protobuf.Duration start = 1;
|
||||
google.protobuf.Duration end = 2;
|
||||
}
|
||||
|
||||
message RuleListsSettings {
|
||||
bool enabled = 1;
|
||||
repeated string ids = 2;
|
||||
}
|
||||
|
||||
message BlockingModeCustomIP {
|
||||
bytes ipv4 = 1;
|
||||
bytes ipv6 = 2;
|
||||
}
|
||||
|
||||
message BlockingModeNXDOMAIN {}
|
||||
|
||||
message BlockingModeNullIP {}
|
||||
|
||||
message BlockingModeREFUSED {}
|
||||
|
||||
message DeviceBillingStat {
|
||||
google.protobuf.Timestamp last_activity_time = 1;
|
||||
string device_id = 2;
|
||||
string client_country = 3;
|
||||
// Protocol type. Possible values see here: https://bit.adguard.com/projects/DNS/repos/dns-server/browse#ql-properties
|
||||
uint32 proto = 4;
|
||||
uint32 asn = 5;
|
||||
uint32 queries = 6;
|
||||
}
|
222
internal/backendpb/backend_grpc.pb.go
Normal file
222
internal/backendpb/backend_grpc.pb.go
Normal file
@ -0,0 +1,222 @@
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.3.0
|
||||
// - protoc v4.23.4
|
||||
// source: backend.proto
|
||||
|
||||
package backendpb
|
||||
|
||||
import (
|
||||
context "context"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
emptypb "google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.32.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion7
|
||||
|
||||
const (
|
||||
DNSService_GetDNSProfiles_FullMethodName = "/DNSService/getDNSProfiles"
|
||||
DNSService_SaveDevicesBillingStat_FullMethodName = "/DNSService/saveDevicesBillingStat"
|
||||
)
|
||||
|
||||
// DNSServiceClient is the client API for DNSService service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type DNSServiceClient interface {
|
||||
// Gets DNS profiles.
|
||||
//
|
||||
// Field "sync_time" in DNSProfilesRequest - pass to return the latest updates after this time moment.
|
||||
//
|
||||
// The trailers headers will include a "sync_time", given in milliseconds,
|
||||
// that should be used for subsequent incremental DNS profile synchronization requests.
|
||||
GetDNSProfiles(ctx context.Context, in *DNSProfilesRequest, opts ...grpc.CallOption) (DNSService_GetDNSProfilesClient, error)
|
||||
// Stores devices activity.
|
||||
SaveDevicesBillingStat(ctx context.Context, opts ...grpc.CallOption) (DNSService_SaveDevicesBillingStatClient, error)
|
||||
}
|
||||
|
||||
type dNSServiceClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewDNSServiceClient(cc grpc.ClientConnInterface) DNSServiceClient {
|
||||
return &dNSServiceClient{cc}
|
||||
}
|
||||
|
||||
func (c *dNSServiceClient) GetDNSProfiles(ctx context.Context, in *DNSProfilesRequest, opts ...grpc.CallOption) (DNSService_GetDNSProfilesClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &DNSService_ServiceDesc.Streams[0], DNSService_GetDNSProfiles_FullMethodName, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &dNSServiceGetDNSProfilesClient{stream}
|
||||
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := x.ClientStream.CloseSend(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
type DNSService_GetDNSProfilesClient interface {
|
||||
Recv() (*DNSProfile, error)
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
type dNSServiceGetDNSProfilesClient struct {
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (x *dNSServiceGetDNSProfilesClient) Recv() (*DNSProfile, error) {
|
||||
m := new(DNSProfile)
|
||||
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (c *dNSServiceClient) SaveDevicesBillingStat(ctx context.Context, opts ...grpc.CallOption) (DNSService_SaveDevicesBillingStatClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &DNSService_ServiceDesc.Streams[1], DNSService_SaveDevicesBillingStat_FullMethodName, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &dNSServiceSaveDevicesBillingStatClient{stream}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
type DNSService_SaveDevicesBillingStatClient interface {
|
||||
Send(*DeviceBillingStat) error
|
||||
CloseAndRecv() (*emptypb.Empty, error)
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
type dNSServiceSaveDevicesBillingStatClient struct {
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (x *dNSServiceSaveDevicesBillingStatClient) Send(m *DeviceBillingStat) error {
|
||||
return x.ClientStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (x *dNSServiceSaveDevicesBillingStatClient) CloseAndRecv() (*emptypb.Empty, error) {
|
||||
if err := x.ClientStream.CloseSend(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := new(emptypb.Empty)
|
||||
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// DNSServiceServer is the server API for DNSService service.
|
||||
// All implementations must embed UnimplementedDNSServiceServer
|
||||
// for forward compatibility
|
||||
type DNSServiceServer interface {
|
||||
// Gets DNS profiles.
|
||||
//
|
||||
// Field "sync_time" in DNSProfilesRequest - pass to return the latest updates after this time moment.
|
||||
//
|
||||
// The trailers headers will include a "sync_time", given in milliseconds,
|
||||
// that should be used for subsequent incremental DNS profile synchronization requests.
|
||||
GetDNSProfiles(*DNSProfilesRequest, DNSService_GetDNSProfilesServer) error
|
||||
// Stores devices activity.
|
||||
SaveDevicesBillingStat(DNSService_SaveDevicesBillingStatServer) error
|
||||
mustEmbedUnimplementedDNSServiceServer()
|
||||
}
|
||||
|
||||
// UnimplementedDNSServiceServer must be embedded to have forward compatible implementations.
|
||||
type UnimplementedDNSServiceServer struct {
|
||||
}
|
||||
|
||||
func (UnimplementedDNSServiceServer) GetDNSProfiles(*DNSProfilesRequest, DNSService_GetDNSProfilesServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method GetDNSProfiles not implemented")
|
||||
}
|
||||
func (UnimplementedDNSServiceServer) SaveDevicesBillingStat(DNSService_SaveDevicesBillingStatServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method SaveDevicesBillingStat not implemented")
|
||||
}
|
||||
func (UnimplementedDNSServiceServer) mustEmbedUnimplementedDNSServiceServer() {}
|
||||
|
||||
// UnsafeDNSServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to DNSServiceServer will
|
||||
// result in compilation errors.
|
||||
type UnsafeDNSServiceServer interface {
|
||||
mustEmbedUnimplementedDNSServiceServer()
|
||||
}
|
||||
|
||||
func RegisterDNSServiceServer(s grpc.ServiceRegistrar, srv DNSServiceServer) {
|
||||
s.RegisterService(&DNSService_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _DNSService_GetDNSProfiles_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
m := new(DNSProfilesRequest)
|
||||
if err := stream.RecvMsg(m); err != nil {
|
||||
return err
|
||||
}
|
||||
return srv.(DNSServiceServer).GetDNSProfiles(m, &dNSServiceGetDNSProfilesServer{stream})
|
||||
}
|
||||
|
||||
type DNSService_GetDNSProfilesServer interface {
|
||||
Send(*DNSProfile) error
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
type dNSServiceGetDNSProfilesServer struct {
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (x *dNSServiceGetDNSProfilesServer) Send(m *DNSProfile) error {
|
||||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func _DNSService_SaveDevicesBillingStat_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
return srv.(DNSServiceServer).SaveDevicesBillingStat(&dNSServiceSaveDevicesBillingStatServer{stream})
|
||||
}
|
||||
|
||||
type DNSService_SaveDevicesBillingStatServer interface {
|
||||
SendAndClose(*emptypb.Empty) error
|
||||
Recv() (*DeviceBillingStat, error)
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
type dNSServiceSaveDevicesBillingStatServer struct {
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (x *dNSServiceSaveDevicesBillingStatServer) SendAndClose(m *emptypb.Empty) error {
|
||||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (x *dNSServiceSaveDevicesBillingStatServer) Recv() (*DeviceBillingStat, error) {
|
||||
m := new(DeviceBillingStat)
|
||||
if err := x.ServerStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// DNSService_ServiceDesc is the grpc.ServiceDesc for DNSService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var DNSService_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "DNSService",
|
||||
HandlerType: (*DNSServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
StreamName: "getDNSProfiles",
|
||||
Handler: _DNSService_GetDNSProfiles_Handler,
|
||||
ServerStreams: true,
|
||||
},
|
||||
{
|
||||
StreamName: "saveDevicesBillingStat",
|
||||
Handler: _DNSService_SaveDevicesBillingStat_Handler,
|
||||
ClientStreams: true,
|
||||
},
|
||||
},
|
||||
Metadata: "backend.proto",
|
||||
}
|
43
internal/backendpb/backendpb.go
Normal file
43
internal/backendpb/backendpb.go
Normal file
@ -0,0 +1,43 @@
|
||||
// Package backendpb contains the protobuf structures for the backend API.
|
||||
package backendpb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
// newClient returns new properly initialized DNSServiceClient.
|
||||
func newClient(apiURL *url.URL) (client DNSServiceClient, err error) {
|
||||
var creds credentials.TransportCredentials
|
||||
switch s := apiURL.Scheme; s {
|
||||
case "grpc":
|
||||
creds = insecure.NewCredentials()
|
||||
case "grpcs":
|
||||
// Use a nil [tls.Config] to get the default TLS configuration.
|
||||
creds = credentials.NewTLS(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("bad grpc url scheme %q", s)
|
||||
}
|
||||
|
||||
conn, err := grpc.Dial(apiURL.Host, grpc.WithTransportCredentials(creds))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing: %w", err)
|
||||
}
|
||||
|
||||
// Immediately make a connection attempt, since the constructor is often
|
||||
// called right before the initial refresh.
|
||||
conn.Connect()
|
||||
|
||||
return NewDNSServiceClient(conn), nil
|
||||
}
|
||||
|
||||
// reportf is a helper method for reporting non-critical errors.
|
||||
func reportf(ctx context.Context, errColl agd.ErrorCollector, format string, args ...any) {
|
||||
agd.Collectf(ctx, errColl, "backendpb: "+format, args...)
|
||||
}
|
38
internal/backendpb/backendpb_test.go
Normal file
38
internal/backendpb/backendpb_test.go
Normal file
@ -0,0 +1,38 @@
|
||||
package backendpb_test
|
||||
|
||||
import "github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
|
||||
|
||||
// testDNSServiceServer is the [backendpb.DNSServiceServer] for tests.
|
||||
//
|
||||
// TODO(d.kolyshev): Use this to remove as much as possible from the internal
|
||||
// test.
|
||||
type testDNSServiceServer struct {
|
||||
backendpb.UnimplementedDNSServiceServer
|
||||
OnGetDNSProfiles func(
|
||||
req *backendpb.DNSProfilesRequest,
|
||||
srv backendpb.DNSService_GetDNSProfilesServer,
|
||||
) (err error)
|
||||
OnSaveDevicesBillingStat func(
|
||||
srv backendpb.DNSService_SaveDevicesBillingStatServer,
|
||||
) (err error)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ backendpb.DNSServiceServer = (*testDNSServiceServer)(nil)
|
||||
|
||||
// GetDNSProfiles implements the [backendpb.DNSServiceServer] interface for
|
||||
// *testDNSServiceServer
|
||||
func (s *testDNSServiceServer) GetDNSProfiles(
|
||||
req *backendpb.DNSProfilesRequest,
|
||||
srv backendpb.DNSService_GetDNSProfilesServer,
|
||||
) (err error) {
|
||||
return s.OnGetDNSProfiles(req, srv)
|
||||
}
|
||||
|
||||
// SaveDevicesBillingStat implements the [backendpb.DNSServiceServer] interface
|
||||
// for *testDNSServiceServer
|
||||
func (s *testDNSServiceServer) SaveDevicesBillingStat(
|
||||
srv backendpb.DNSService_SaveDevicesBillingStatServer,
|
||||
) (err error) {
|
||||
return s.OnSaveDevicesBillingStat(srv)
|
||||
}
|
101
internal/backendpb/billstat.go
Normal file
101
internal/backendpb/billstat.go
Normal file
@ -0,0 +1,101 @@
|
||||
package backendpb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
// BillStatConfig is the configuration structure for the business logic backend
|
||||
// billing statistics uploader.
|
||||
type BillStatConfig struct {
|
||||
// ErrColl is the error collector that is used to collect critical and
|
||||
// non-critical errors.
|
||||
ErrColl agd.ErrorCollector
|
||||
|
||||
// Endpoint is the backend API URL. The scheme should be either "grpc" or
|
||||
// "grpcs".
|
||||
Endpoint *url.URL
|
||||
}
|
||||
|
||||
// NewBillStat creates a new billing statistics uploader. c must not be nil.
|
||||
func NewBillStat(c *BillStatConfig) (b *BillStat, err error) {
|
||||
b = &BillStat{
|
||||
errColl: c.ErrColl,
|
||||
}
|
||||
|
||||
b.client, err = newClient(c.Endpoint)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// BillStat is the implementation of the [billstat.Uploader] interface that
|
||||
// uploads the billing statistics to the business logic backend. It is safe for
|
||||
// concurrent use.
|
||||
//
|
||||
// TODO(a.garipov): Consider uniting with [ProfileStorage] into a single
|
||||
// backendpb.Client.
|
||||
type BillStat struct {
|
||||
errColl agd.ErrorCollector
|
||||
|
||||
// client is the current GRPC client.
|
||||
client DNSServiceClient
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ billstat.Uploader = (*BillStat)(nil)
|
||||
|
||||
// Upload implements the [billstat.Uploader] interface for *BillStat.
|
||||
func (b *BillStat) Upload(ctx context.Context, records billstat.Records) (err error) {
|
||||
if len(records) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
stream, err := b.client.SaveDevicesBillingStat(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening stream: %w", err)
|
||||
}
|
||||
|
||||
for deviceID, record := range records {
|
||||
if record == nil {
|
||||
reportf(ctx, b.errColl, "device %q: null record", deviceID)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
sendErr := stream.Send(recordToProtobuf(record, deviceID))
|
||||
if sendErr != nil {
|
||||
return fmt.Errorf("uploading device %q record: %w", deviceID, sendErr)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = stream.CloseAndRecv()
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return fmt.Errorf("finishing stream: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// recordToProtobuf converts a billstat record structure into the protobuf
|
||||
// structure.
|
||||
func recordToProtobuf(r *billstat.Record, devID agd.DeviceID) (s *DeviceBillingStat) {
|
||||
return &DeviceBillingStat{
|
||||
LastActivityTime: timestamppb.New(r.Time),
|
||||
DeviceId: string(devID),
|
||||
ClientCountry: string(r.Country),
|
||||
Proto: uint32(r.Proto),
|
||||
Asn: uint32(r.ASN),
|
||||
Queries: uint32(r.Queries),
|
||||
}
|
||||
}
|
104
internal/backendpb/billstat_test.go
Normal file
104
internal/backendpb/billstat_test.go
Normal file
@ -0,0 +1,104 @@
|
||||
package backendpb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
func TestBillStat_Upload(t *testing.T) {
|
||||
const (
|
||||
wantDeviceID = "test"
|
||||
invalidDeviceID = "invalid"
|
||||
)
|
||||
|
||||
wantRecord := &billstat.Record{
|
||||
Time: time.Time{},
|
||||
Country: agd.CountryCY,
|
||||
ASN: 1221,
|
||||
Queries: 1122,
|
||||
Proto: agd.ProtoDNS,
|
||||
}
|
||||
|
||||
records := billstat.Records{
|
||||
wantDeviceID: wantRecord,
|
||||
invalidDeviceID: nil,
|
||||
}
|
||||
|
||||
srv := &testDNSServiceServer{
|
||||
OnSaveDevicesBillingStat: func(
|
||||
srv backendpb.DNSService_SaveDevicesBillingStatServer,
|
||||
) (err error) {
|
||||
pt := &testutil.PanicT{}
|
||||
|
||||
for {
|
||||
data, recvErr := srv.Recv()
|
||||
if recvErr != nil && errors.Is(recvErr, io.EOF) {
|
||||
return srv.SendAndClose(&emptypb.Empty{})
|
||||
}
|
||||
|
||||
require.NoError(t, recvErr)
|
||||
|
||||
assert.Equal(pt, wantDeviceID, data.DeviceId)
|
||||
assert.Equal(pt, uint32(wantRecord.ASN), data.Asn)
|
||||
assert.Equal(pt, string(wantRecord.Country), data.ClientCountry)
|
||||
assert.Equal(pt, timestamppb.New(wantRecord.Time), data.LastActivityTime)
|
||||
assert.Equal(pt, uint32(wantRecord.Proto), data.Proto)
|
||||
assert.Equal(pt, uint32(wantRecord.Queries), data.Queries)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
l, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
grpcSrv := grpc.NewServer(
|
||||
grpc.ConnectionTimeout(1*time.Second),
|
||||
grpc.Creds(insecure.NewCredentials()),
|
||||
)
|
||||
backendpb.RegisterDNSServiceServer(grpcSrv, srv)
|
||||
|
||||
go func() {
|
||||
pt := &testutil.PanicT{}
|
||||
|
||||
srvErr := grpcSrv.Serve(l)
|
||||
require.NoError(pt, srvErr)
|
||||
}()
|
||||
t.Cleanup(grpcSrv.GracefulStop)
|
||||
|
||||
errColl := &agdtest.ErrorCollector{
|
||||
OnCollect: func(_ context.Context, err error) {
|
||||
testutil.AssertErrorMsg(t, `backendpb: device "invalid": null record`, err)
|
||||
},
|
||||
}
|
||||
|
||||
b, err := backendpb.NewBillStat(&backendpb.BillStatConfig{
|
||||
ErrColl: errColl,
|
||||
Endpoint: &url.URL{
|
||||
Scheme: "grpc",
|
||||
Host: l.Addr().String(),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err = b.Upload(ctx, records)
|
||||
require.NoError(t, err)
|
||||
}
|
457
internal/backendpb/profiledb.go
Normal file
457
internal/backendpb/profiledb.go
Normal file
@ -0,0 +1,457 @@
|
||||
package backendpb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdprotobuf"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdtime"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
// ProfileStorageConfig is the configuration for the business logic backend
|
||||
// profile storage.
|
||||
type ProfileStorageConfig struct {
|
||||
// ErrColl is the error collector that is used to collect critical and
|
||||
// non-critical errors.
|
||||
ErrColl agd.ErrorCollector
|
||||
|
||||
// Endpoint is the backend API URL. The scheme should be either "grpc" or
|
||||
// "grpcs".
|
||||
Endpoint *url.URL
|
||||
}
|
||||
|
||||
// ProfileStorage is the implementation of the [profiledb.Storage] interface
|
||||
// that retrieves the profile and device information from the business logic
|
||||
// backend. It is safe for concurrent use.
|
||||
type ProfileStorage struct {
|
||||
errColl agd.ErrorCollector
|
||||
|
||||
// client is the current GRPC client.
|
||||
client DNSServiceClient
|
||||
}
|
||||
|
||||
// NewProfileStorage returns a new [ProfileStorage] that retrieves information
|
||||
// from the business logic backend.
|
||||
func NewProfileStorage(c *ProfileStorageConfig) (s *ProfileStorage, err error) {
|
||||
client, err := newClient(c.Endpoint)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ProfileStorage{
|
||||
client: client,
|
||||
errColl: c.ErrColl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ profiledb.Storage = (*ProfileStorage)(nil)
|
||||
|
||||
// Profiles implements the [profiledb.Storage] interface for *ProfileStorage.
|
||||
func (s *ProfileStorage) Profiles(
|
||||
ctx context.Context,
|
||||
req *profiledb.StorageRequest,
|
||||
) (resp *profiledb.StorageResponse, err error) {
|
||||
stream, err := s.client.GetDNSProfiles(ctx, toProtobuf(req))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading profiles: %w", err)
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, stream.CloseSend()) }()
|
||||
|
||||
resp = &profiledb.StorageResponse{
|
||||
Profiles: []*agd.Profile{},
|
||||
Devices: []*agd.Device{},
|
||||
}
|
||||
|
||||
stats := &profilesCallStats{
|
||||
isFullSync: req.SyncTime == time.Time{},
|
||||
}
|
||||
|
||||
for {
|
||||
stats.startRecv()
|
||||
profile, profErr := stream.Recv()
|
||||
if profErr != nil {
|
||||
if errors.Is(profErr, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("receiving profile: %w", profErr)
|
||||
}
|
||||
stats.endRecv()
|
||||
|
||||
stats.startDec()
|
||||
prof, devices, profErr := profile.toInternal(ctx, time.Now(), s.errColl)
|
||||
if profErr != nil {
|
||||
reportf(ctx, s.errColl, "loading profile: %w", profErr)
|
||||
|
||||
continue
|
||||
}
|
||||
stats.endDec()
|
||||
|
||||
resp.Profiles = append(resp.Profiles, prof)
|
||||
resp.Devices = append(resp.Devices, devices...)
|
||||
}
|
||||
|
||||
stats.report()
|
||||
|
||||
trailer := stream.Trailer()
|
||||
resp.SyncTime, err = syncTimeFromTrailer(trailer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("retrieving sync_time: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// toInternal converts the protobuf-encoded data into a profile structure.
|
||||
func (x *DNSProfile) toInternal(
|
||||
ctx context.Context,
|
||||
updTime time.Time,
|
||||
errColl agd.ErrorCollector,
|
||||
) (profile *agd.Profile, devices []*agd.Device, err error) {
|
||||
if x == nil {
|
||||
return nil, nil, fmt.Errorf("profile is nil")
|
||||
}
|
||||
|
||||
parental, err := x.Parental.toInternal(ctx, errColl)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("parental: %w", err)
|
||||
}
|
||||
|
||||
m, err := blockingModeToInternal(x.BlockingMode)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("blocking mode: %w", err)
|
||||
}
|
||||
|
||||
devices, deviceIds := devicesToInternal(ctx, x.Devices, errColl)
|
||||
listsEnabled, listIDs := x.RuleLists.toInternal(ctx, errColl)
|
||||
|
||||
profID, err := agd.NewProfileID(x.DnsId)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("id: %w", err)
|
||||
}
|
||||
|
||||
var fltRespTTL time.Duration
|
||||
if respTTL := x.FilteredResponseTtl; respTTL != nil {
|
||||
fltRespTTL = respTTL.AsDuration()
|
||||
}
|
||||
|
||||
return &agd.Profile{
|
||||
Parental: parental,
|
||||
BlockingMode: m,
|
||||
ID: profID,
|
||||
UpdateTime: updTime,
|
||||
DeviceIDs: deviceIds,
|
||||
RuleListIDs: listIDs,
|
||||
CustomRules: rulesToInternal(ctx, x.CustomRules, errColl),
|
||||
FilteredResponseTTL: fltRespTTL,
|
||||
FilteringEnabled: x.FilteringEnabled,
|
||||
SafeBrowsing: x.SafeBrowsing.toInternal(),
|
||||
RuleListsEnabled: listsEnabled,
|
||||
QueryLogEnabled: x.QueryLogEnabled,
|
||||
Deleted: x.Deleted,
|
||||
BlockPrivateRelay: x.BlockPrivateRelay,
|
||||
BlockFirefoxCanary: x.BlockFirefoxCanary,
|
||||
}, devices, nil
|
||||
}
|
||||
|
||||
// toInternal converts a protobuf parental-settings structure to an internal
|
||||
// one. If x is nil, toInternal returns nil.
|
||||
func (x *ParentalSettings) toInternal(
|
||||
ctx context.Context,
|
||||
errColl agd.ErrorCollector,
|
||||
) (s *agd.ParentalProtectionSettings, err error) {
|
||||
if x == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
schedule, err := x.Schedule.toInternal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("schedule: %w", err)
|
||||
}
|
||||
|
||||
return &agd.ParentalProtectionSettings{
|
||||
Schedule: schedule,
|
||||
BlockedServices: blockedSvcsToInternal(ctx, errColl, x.BlockedServices),
|
||||
Enabled: x.Enabled,
|
||||
BlockAdult: x.BlockAdult,
|
||||
GeneralSafeSearch: x.GeneralSafeSearch,
|
||||
YoutubeSafeSearch: x.YoutubeSafeSearch,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// toInternal converts protobuf safe-browsing settings to an internal structure.
|
||||
// If x is nil, toInternal returns nil.
|
||||
func (x *SafeBrowsingSettings) toInternal() (sb *agd.SafeBrowsingSettings) {
|
||||
if x == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &agd.SafeBrowsingSettings{
|
||||
Enabled: x.Enabled,
|
||||
BlockDangerousDomains: x.BlockDangerousDomains,
|
||||
BlockNewlyRegisteredDomains: x.BlockNrd,
|
||||
}
|
||||
}
|
||||
|
||||
// blockedSvcsToInternal is a helper that converts the blocked service IDs from
|
||||
// the backend response to AdGuard DNS blocked service IDs.
|
||||
func blockedSvcsToInternal(
|
||||
ctx context.Context,
|
||||
errColl agd.ErrorCollector,
|
||||
respSvcs []string,
|
||||
) (svcs []agd.BlockedServiceID) {
|
||||
l := len(respSvcs)
|
||||
if l == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
svcs = make([]agd.BlockedServiceID, 0, l)
|
||||
for i, s := range respSvcs {
|
||||
id, err := agd.NewBlockedServiceID(s)
|
||||
if err != nil {
|
||||
reportf(ctx, errColl, "blocked service at index %d: %w", i, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
svcs = append(svcs, id)
|
||||
}
|
||||
|
||||
return svcs
|
||||
}
|
||||
|
||||
// toInternal converts a protobuf protection-schedule structure to an internal
|
||||
// one. If x is nil, toInternal returns nil.
|
||||
func (x *ScheduleSettings) toInternal() (sch *agd.ParentalProtectionSchedule, err error) {
|
||||
if x == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
sch = &agd.ParentalProtectionSchedule{}
|
||||
|
||||
sch.TimeZone, err = agdtime.LoadLocation(x.Tmz)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading timezone: %w", err)
|
||||
}
|
||||
|
||||
sch.Week = &agd.WeeklySchedule{}
|
||||
|
||||
w := x.WeeklyRange
|
||||
days := []*DayRange{w.Sun, w.Mon, w.Tue, w.Wed, w.Thu, w.Fri, w.Sat}
|
||||
for i, d := range days {
|
||||
if d == nil {
|
||||
sch.Week[i] = agd.ZeroLengthDayRange()
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
sch.Week[i] = agd.DayRange{
|
||||
Start: uint16(d.Start.AsDuration().Minutes()),
|
||||
End: uint16(d.End.AsDuration().Minutes()),
|
||||
}
|
||||
}
|
||||
|
||||
for i, r := range sch.Week {
|
||||
err = r.Validate()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("weekday %s: %w", time.Weekday(i), err)
|
||||
}
|
||||
}
|
||||
|
||||
return sch, nil
|
||||
}
|
||||
|
||||
// blockingModeToInternal converts a protobuf blocking-mode sum-type to an
|
||||
// internal one. If pbm is nil, blockingModeToInternal returns a null-IP
|
||||
// blocking mode.
|
||||
func blockingModeToInternal(pbm isDNSProfile_BlockingMode) (m dnsmsg.BlockingModeCodec, err error) {
|
||||
switch pbm := pbm.(type) {
|
||||
case nil:
|
||||
m.Mode = &dnsmsg.BlockingModeNullIP{}
|
||||
case *DNSProfile_BlockingModeCustomIp:
|
||||
custom := &dnsmsg.BlockingModeCustomIP{}
|
||||
err = custom.IPv4.UnmarshalBinary(pbm.BlockingModeCustomIp.Ipv4)
|
||||
if err != nil {
|
||||
return dnsmsg.BlockingModeCodec{}, fmt.Errorf("bad custom ipv4: %w", err)
|
||||
}
|
||||
|
||||
err = custom.IPv6.UnmarshalBinary(pbm.BlockingModeCustomIp.Ipv6)
|
||||
if err != nil {
|
||||
return dnsmsg.BlockingModeCodec{}, fmt.Errorf("bad custom ipv6: %w", err)
|
||||
}
|
||||
|
||||
m.Mode = custom
|
||||
case *DNSProfile_BlockingModeNxdomain:
|
||||
m.Mode = &dnsmsg.BlockingModeNXDOMAIN{}
|
||||
case *DNSProfile_BlockingModeNullIp:
|
||||
m.Mode = &dnsmsg.BlockingModeNullIP{}
|
||||
case *DNSProfile_BlockingModeRefused:
|
||||
m.Mode = &dnsmsg.BlockingModeREFUSED{}
|
||||
default:
|
||||
// Consider unhandled type-switch cases programmer errors.
|
||||
panic(fmt.Errorf("bad pb blocking mode %T(%[1]v)", pbm))
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// devicesToInternal is a helper that converts the devices from protobuf to
|
||||
// AdGuard DNS devices.
|
||||
func devicesToInternal(
|
||||
ctx context.Context,
|
||||
ds []*DeviceSettings,
|
||||
errColl agd.ErrorCollector,
|
||||
) (out []*agd.Device, ids []agd.DeviceID) {
|
||||
l := len(ds)
|
||||
if l == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
out = make([]*agd.Device, 0, l)
|
||||
for _, d := range ds {
|
||||
dev, err := d.toInternal()
|
||||
if err != nil {
|
||||
reportf(ctx, errColl, "invalid device settings: %w", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
ids = append(ids, dev.ID)
|
||||
out = append(out, dev)
|
||||
}
|
||||
|
||||
return out, ids
|
||||
}
|
||||
|
||||
// toInternal is a helper that converts device settings from backend protobuf
|
||||
// response to AdGuard DNS device object.
|
||||
func (ds *DeviceSettings) toInternal() (dev *agd.Device, err error) {
|
||||
if ds == nil {
|
||||
return nil, fmt.Errorf("device is nil")
|
||||
}
|
||||
|
||||
var linkedIP netip.Addr
|
||||
err = linkedIP.UnmarshalBinary(ds.LinkedIp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("linked ip: %w", err)
|
||||
}
|
||||
|
||||
var dedicatedIPs []netip.Addr
|
||||
dedicatedIPs, err = agdprotobuf.ByteSlicesToIPs(ds.DedicatedIps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dedicated ips: %w", err)
|
||||
}
|
||||
|
||||
id, err := agd.NewDeviceID(ds.Id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("device id: %s: %w", ds.Id, err)
|
||||
}
|
||||
|
||||
name, err := agd.NewDeviceName(ds.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("device name: %s: %w", ds.Name, err)
|
||||
}
|
||||
|
||||
return &agd.Device{
|
||||
ID: id,
|
||||
Name: name,
|
||||
LinkedIP: linkedIP,
|
||||
DedicatedIPs: dedicatedIPs,
|
||||
FilteringEnabled: ds.FilteringEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// rulesToInternal is a helper that converts the filter rules from the backend
|
||||
// response to AdGuard DNS filtering rules.
|
||||
func rulesToInternal(
|
||||
ctx context.Context,
|
||||
respRules []string,
|
||||
errColl agd.ErrorCollector,
|
||||
) (rules []agd.FilterRuleText) {
|
||||
l := len(respRules)
|
||||
if l == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
rules = make([]agd.FilterRuleText, 0, l)
|
||||
for i, r := range respRules {
|
||||
text, err := agd.NewFilterRuleText(r)
|
||||
if err != nil {
|
||||
reportf(ctx, errColl, "rule at index %d: %w", i, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
rules = append(rules, text)
|
||||
}
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
// toInternal is a helper that converts the filter lists from the backend
|
||||
// response to AdGuard DNS filter list ids. If x is nil, toInternal returns
|
||||
// false and nil.
|
||||
func (x *RuleListsSettings) toInternal(
|
||||
ctx context.Context,
|
||||
errColl agd.ErrorCollector,
|
||||
) (enabled bool, filterLists []agd.FilterListID) {
|
||||
if x == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
l := len(x.Ids)
|
||||
if l == 0 {
|
||||
return x.Enabled, nil
|
||||
}
|
||||
|
||||
filterLists = make([]agd.FilterListID, 0, l)
|
||||
for _, f := range x.Ids {
|
||||
id, err := agd.NewFilterListID(f)
|
||||
if err != nil {
|
||||
reportf(ctx, errColl, "invalid filter id: %s: %w", f, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
filterLists = append(filterLists, id)
|
||||
}
|
||||
|
||||
return x.Enabled, filterLists
|
||||
}
|
||||
|
||||
// toProtobuf converts a storage request structure into the protobuf structure.
|
||||
func toProtobuf(r *profiledb.StorageRequest) (req *DNSProfilesRequest) {
|
||||
return &DNSProfilesRequest{
|
||||
SyncTime: timestamppb.New(r.SyncTime),
|
||||
}
|
||||
}
|
||||
|
||||
// syncTimeFromTrailer returns sync time from trailer metadata. Trailer
|
||||
// metadata must contain "sync_time" field with milliseconds presentation of
|
||||
// sync time.
|
||||
func syncTimeFromTrailer(trailer metadata.MD) (syncTime time.Time, err error) {
|
||||
st := trailer.Get("sync_time")
|
||||
if len(st) == 0 {
|
||||
return syncTime, fmt.Errorf("empty value")
|
||||
}
|
||||
|
||||
syncTimeMs, err := strconv.ParseInt(st[0], 10, 64)
|
||||
if err != nil {
|
||||
return syncTime, fmt.Errorf("invalid value: %w", err)
|
||||
}
|
||||
|
||||
return time.Unix(0, syncTimeMs*time.Millisecond.Nanoseconds()), nil
|
||||
}
|
459
internal/backendpb/profiledb_internal_test.go
Normal file
459
internal/backendpb/profiledb_internal_test.go
Normal file
@ -0,0 +1,459 @@
|
||||
package backendpb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdtime"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// testProfileID is the common profile ID for tests.
|
||||
const testProfileID agd.ProfileID = "prof1234"
|
||||
|
||||
// TestUpdTime is the common update time for tests.
|
||||
var TestUpdTime = time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
func TestDNSProfile_ToInternal(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
errColl := &agdtest.ErrorCollector{
|
||||
OnCollect: func(_ context.Context, err error) {
|
||||
panic(err)
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
got, gotDevices, err := NewTestDNSProfile(t).toInternal(ctx, TestUpdTime, errColl)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, newProfile(t), got)
|
||||
assert.Equal(t, newDevices(t), gotDevices)
|
||||
})
|
||||
|
||||
t.Run("success_bad_data", func(t *testing.T) {
|
||||
var errCollErr error
|
||||
savingErrColl := &agdtest.ErrorCollector{
|
||||
OnCollect: func(_ context.Context, err error) {
|
||||
errCollErr = err
|
||||
},
|
||||
}
|
||||
got, gotDevices, err := newDNSProfileWithBadData(t).toInternal(
|
||||
ctx,
|
||||
TestUpdTime,
|
||||
savingErrColl,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Error(t, errCollErr)
|
||||
|
||||
// See the TODO in [blockingModeToInternal].
|
||||
wantProf := newProfile(t)
|
||||
wantProf.BlockingMode = dnsmsg.BlockingModeCodec{
|
||||
Mode: &dnsmsg.BlockingModeNullIP{},
|
||||
}
|
||||
|
||||
assert.Equal(t, wantProf, got)
|
||||
assert.Equal(t, newDevices(t), gotDevices)
|
||||
})
|
||||
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
var emptyDNSProfile *DNSProfile
|
||||
_, _, err := emptyDNSProfile.toInternal(ctx, TestUpdTime, errColl)
|
||||
testutil.AssertErrorMsg(t, "profile is nil", err)
|
||||
})
|
||||
|
||||
t.Run("deleted", func(t *testing.T) {
|
||||
dp := &DNSProfile{
|
||||
DnsId: string(testProfileID),
|
||||
Deleted: true,
|
||||
}
|
||||
|
||||
got, gotDevices, err := dp.toInternal(ctx, TestUpdTime, errColl)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
|
||||
assert.Equal(t, got.ID, testProfileID)
|
||||
assert.True(t, got.Deleted)
|
||||
assert.Empty(t, gotDevices)
|
||||
})
|
||||
|
||||
t.Run("inv_parental_sch_tmz", func(t *testing.T) {
|
||||
dp := NewTestDNSProfile(t)
|
||||
dp.Parental.Schedule.Tmz = "invalid"
|
||||
|
||||
_, _, err := dp.toInternal(ctx, TestUpdTime, errColl)
|
||||
testutil.AssertErrorMsg(t, "parental: schedule: loading timezone: unknown time zone invalid", err)
|
||||
})
|
||||
|
||||
t.Run("inv_parental_sch_day_range", func(t *testing.T) {
|
||||
dp := NewTestDNSProfile(t)
|
||||
dp.Parental.Schedule.WeeklyRange.Sun = &DayRange{
|
||||
Start: durationpb.New(1000000000000),
|
||||
End: nil,
|
||||
}
|
||||
|
||||
_, _, err := dp.toInternal(ctx, TestUpdTime, errColl)
|
||||
testutil.AssertErrorMsg(t, "parental: schedule: weekday Sunday: bad day range: end 0 less than start 16", err)
|
||||
})
|
||||
|
||||
t.Run("inv_blocking_mode_v4", func(t *testing.T) {
|
||||
dp := NewTestDNSProfile(t)
|
||||
bm := dp.BlockingMode.(*DNSProfile_BlockingModeCustomIp)
|
||||
bm.BlockingModeCustomIp.Ipv4 = []byte("1")
|
||||
|
||||
_, _, err := dp.toInternal(ctx, TestUpdTime, errColl)
|
||||
testutil.AssertErrorMsg(t, "blocking mode: bad custom ipv4: unexpected slice size", err)
|
||||
})
|
||||
|
||||
t.Run("inv_blocking_mode_v6", func(t *testing.T) {
|
||||
dp := NewTestDNSProfile(t)
|
||||
bm := dp.BlockingMode.(*DNSProfile_BlockingModeCustomIp)
|
||||
bm.BlockingModeCustomIp.Ipv6 = []byte("1")
|
||||
|
||||
_, _, err := dp.toInternal(ctx, TestUpdTime, errColl)
|
||||
testutil.AssertErrorMsg(t, "blocking mode: bad custom ipv6: unexpected slice size", err)
|
||||
})
|
||||
}
|
||||
|
||||
// newDNSProfileWithBadData returns a new instance of *DNSProfile with bad data
|
||||
// for tests.
|
||||
func newDNSProfileWithBadData(tb testing.TB) (dp *DNSProfile) {
|
||||
tb.Helper()
|
||||
|
||||
dayRange := &DayRange{
|
||||
Start: durationpb.New(0),
|
||||
End: durationpb.New(59 * time.Minute),
|
||||
}
|
||||
|
||||
devices := []*DeviceSettings{{
|
||||
Id: "118ffe93",
|
||||
Name: "118ffe93-name",
|
||||
FilteringEnabled: false,
|
||||
LinkedIp: ipToBytes(tb, netip.MustParseAddr("1.1.1.1")),
|
||||
DedicatedIps: [][]byte{ipToBytes(tb, netip.MustParseAddr("1.1.1.2"))},
|
||||
}, {
|
||||
Id: "b9e1a762",
|
||||
Name: "b9e1a762-name",
|
||||
FilteringEnabled: true,
|
||||
LinkedIp: ipToBytes(tb, netip.MustParseAddr("2.2.2.2")),
|
||||
DedicatedIps: nil,
|
||||
}, {
|
||||
Id: "invalid-too-long-device-id",
|
||||
Name: "device_name",
|
||||
FilteringEnabled: true,
|
||||
LinkedIp: ipToBytes(tb, netip.MustParseAddr("1.1.1.1")),
|
||||
DedicatedIps: nil,
|
||||
}, {
|
||||
Id: "dev-name",
|
||||
Name: "invalid-too-long-device-name-invalid-too-long-device-name-" +
|
||||
"invalid-too-long-device-name-invalid-too-long-device-name-" +
|
||||
"invalid-too-long-device-name-invalid-too-long-device-name",
|
||||
FilteringEnabled: true,
|
||||
LinkedIp: ipToBytes(tb, netip.MustParseAddr("1.1.1.1")),
|
||||
DedicatedIps: nil,
|
||||
}, {
|
||||
Id: "inv-ip",
|
||||
Name: "test-name",
|
||||
FilteringEnabled: true,
|
||||
LinkedIp: []byte("1"),
|
||||
DedicatedIps: nil,
|
||||
}, {
|
||||
Id: "inv-d-ip",
|
||||
Name: "test-name",
|
||||
FilteringEnabled: true,
|
||||
LinkedIp: ipToBytes(tb, netip.MustParseAddr("1.1.1.1")),
|
||||
DedicatedIps: [][]byte{[]byte("1")},
|
||||
}}
|
||||
|
||||
return &DNSProfile{
|
||||
DnsId: string(testProfileID),
|
||||
FilteringEnabled: true,
|
||||
QueryLogEnabled: true,
|
||||
Deleted: false,
|
||||
SafeBrowsing: &SafeBrowsingSettings{
|
||||
Enabled: true,
|
||||
BlockDangerousDomains: true,
|
||||
BlockNrd: false,
|
||||
},
|
||||
Parental: &ParentalSettings{
|
||||
Enabled: false,
|
||||
BlockAdult: false,
|
||||
GeneralSafeSearch: false,
|
||||
YoutubeSafeSearch: false,
|
||||
BlockedServices: []string{"youtube", "inv_blocked_svc\r"},
|
||||
Schedule: &ScheduleSettings{
|
||||
Tmz: "GMT",
|
||||
WeeklyRange: &WeeklyRange{
|
||||
Sun: nil,
|
||||
Mon: dayRange,
|
||||
Tue: dayRange,
|
||||
Wed: dayRange,
|
||||
Thu: dayRange,
|
||||
Fri: dayRange,
|
||||
Sat: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
RuleLists: &RuleListsSettings{
|
||||
Enabled: true,
|
||||
Ids: []string{"1", "inv_filter_id\r"},
|
||||
},
|
||||
Devices: devices,
|
||||
CustomRules: []string{"||example.org^"},
|
||||
FilteredResponseTtl: durationpb.New(10 * time.Second),
|
||||
BlockPrivateRelay: true,
|
||||
BlockFirefoxCanary: true,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTestDNSProfile returns a new instance of *DNSProfile for tests.
|
||||
func NewTestDNSProfile(tb testing.TB) (dp *DNSProfile) {
|
||||
tb.Helper()
|
||||
|
||||
dayRange := &DayRange{
|
||||
Start: durationpb.New(0),
|
||||
End: durationpb.New(59 * time.Minute),
|
||||
}
|
||||
|
||||
devices := []*DeviceSettings{{
|
||||
Id: "118ffe93",
|
||||
Name: "118ffe93-name",
|
||||
FilteringEnabled: false,
|
||||
LinkedIp: ipToBytes(tb, netip.MustParseAddr("1.1.1.1")),
|
||||
DedicatedIps: [][]byte{ipToBytes(tb, netip.MustParseAddr("1.1.1.2"))},
|
||||
}, {
|
||||
Id: "b9e1a762",
|
||||
Name: "b9e1a762-name",
|
||||
FilteringEnabled: true,
|
||||
LinkedIp: ipToBytes(tb, netip.MustParseAddr("2.2.2.2")),
|
||||
DedicatedIps: nil,
|
||||
}}
|
||||
|
||||
return &DNSProfile{
|
||||
DnsId: string(testProfileID),
|
||||
FilteringEnabled: true,
|
||||
QueryLogEnabled: true,
|
||||
Deleted: false,
|
||||
SafeBrowsing: &SafeBrowsingSettings{
|
||||
Enabled: true,
|
||||
BlockDangerousDomains: true,
|
||||
BlockNrd: false,
|
||||
},
|
||||
Parental: &ParentalSettings{
|
||||
Enabled: false,
|
||||
BlockAdult: false,
|
||||
GeneralSafeSearch: false,
|
||||
YoutubeSafeSearch: false,
|
||||
BlockedServices: []string{"youtube"},
|
||||
Schedule: &ScheduleSettings{
|
||||
Tmz: "GMT",
|
||||
WeeklyRange: &WeeklyRange{
|
||||
Sun: nil,
|
||||
Mon: dayRange,
|
||||
Tue: dayRange,
|
||||
Wed: dayRange,
|
||||
Thu: dayRange,
|
||||
Fri: dayRange,
|
||||
Sat: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
RuleLists: &RuleListsSettings{
|
||||
Enabled: true,
|
||||
Ids: []string{"1"},
|
||||
},
|
||||
Devices: devices,
|
||||
CustomRules: []string{"||example.org^"},
|
||||
FilteredResponseTtl: durationpb.New(10 * time.Second),
|
||||
BlockPrivateRelay: true,
|
||||
BlockFirefoxCanary: true,
|
||||
BlockingMode: &DNSProfile_BlockingModeCustomIp{
|
||||
BlockingModeCustomIp: &BlockingModeCustomIP{
|
||||
Ipv4: ipToBytes(tb, netip.MustParseAddr("1.2.3.4")),
|
||||
Ipv6: ipToBytes(tb, netip.MustParseAddr("1234::cdef")),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// newProfile returns a new profile for tests.
|
||||
func newProfile(tb testing.TB) (p *agd.Profile) {
|
||||
tb.Helper()
|
||||
|
||||
wantLoc, err := agdtime.LoadLocation("GMT")
|
||||
require.NoError(tb, err)
|
||||
|
||||
dayRange := agd.DayRange{
|
||||
Start: 0,
|
||||
End: 59,
|
||||
}
|
||||
|
||||
wantParental := &agd.ParentalProtectionSettings{
|
||||
Schedule: &agd.ParentalProtectionSchedule{
|
||||
Week: &agd.WeeklySchedule{
|
||||
agd.ZeroLengthDayRange(),
|
||||
dayRange,
|
||||
dayRange,
|
||||
dayRange,
|
||||
dayRange,
|
||||
dayRange,
|
||||
agd.ZeroLengthDayRange(),
|
||||
},
|
||||
TimeZone: wantLoc,
|
||||
},
|
||||
BlockedServices: []agd.BlockedServiceID{"youtube"},
|
||||
Enabled: false,
|
||||
BlockAdult: false,
|
||||
GeneralSafeSearch: false,
|
||||
YoutubeSafeSearch: false,
|
||||
}
|
||||
|
||||
wantSafeBrowsing := &agd.SafeBrowsingSettings{
|
||||
Enabled: true,
|
||||
BlockDangerousDomains: true,
|
||||
BlockNewlyRegisteredDomains: false,
|
||||
}
|
||||
|
||||
wantBlockingMode := dnsmsg.BlockingModeCodec{
|
||||
Mode: &dnsmsg.BlockingModeCustomIP{
|
||||
IPv4: netip.MustParseAddr("1.2.3.4"),
|
||||
IPv6: netip.MustParseAddr("1234::cdef"),
|
||||
},
|
||||
}
|
||||
|
||||
return &agd.Profile{
|
||||
Parental: wantParental,
|
||||
BlockingMode: wantBlockingMode,
|
||||
ID: testProfileID,
|
||||
UpdateTime: TestUpdTime,
|
||||
DeviceIDs: []agd.DeviceID{
|
||||
"118ffe93",
|
||||
"b9e1a762",
|
||||
},
|
||||
RuleListIDs: []agd.FilterListID{"1"},
|
||||
CustomRules: []agd.FilterRuleText{"||example.org^"},
|
||||
FilteredResponseTTL: 10 * time.Second,
|
||||
SafeBrowsing: wantSafeBrowsing,
|
||||
RuleListsEnabled: true,
|
||||
FilteringEnabled: true,
|
||||
QueryLogEnabled: true,
|
||||
Deleted: false,
|
||||
BlockPrivateRelay: true,
|
||||
BlockFirefoxCanary: true,
|
||||
}
|
||||
}
|
||||
|
||||
// newDevices returns a slice of test devices.
|
||||
func newDevices(t *testing.T) (d []*agd.Device) {
|
||||
t.Helper()
|
||||
|
||||
return []*agd.Device{{
|
||||
ID: "118ffe93",
|
||||
LinkedIP: netip.MustParseAddr("1.1.1.1"),
|
||||
Name: "118ffe93-name",
|
||||
DedicatedIPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
|
||||
FilteringEnabled: false,
|
||||
}, {
|
||||
ID: "b9e1a762",
|
||||
LinkedIP: netip.MustParseAddr("2.2.2.2"),
|
||||
Name: "b9e1a762-name",
|
||||
DedicatedIPs: nil,
|
||||
FilteringEnabled: true,
|
||||
}}
|
||||
}
|
||||
|
||||
// ipToBytes is a wrapper around netip.Addr.MarshalBinary.
|
||||
func ipToBytes(tb testing.TB, ip netip.Addr) (b []byte) {
|
||||
tb.Helper()
|
||||
|
||||
b, err := ip.MarshalBinary()
|
||||
require.NoError(tb, err)
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func TestSyncTimeFromTrailer(t *testing.T) {
|
||||
milliseconds := strconv.FormatInt(TestUpdTime.UnixMilli(), 10)
|
||||
|
||||
testCases := []struct {
|
||||
wantError string
|
||||
want time.Time
|
||||
name string
|
||||
in metadata.MD
|
||||
}{{
|
||||
wantError: "empty value",
|
||||
want: time.Time{},
|
||||
name: "no_key",
|
||||
in: metadata.MD{},
|
||||
}, {
|
||||
wantError: "empty value",
|
||||
want: time.Time{},
|
||||
name: "empty_key",
|
||||
in: metadata.MD{"sync_time": []string{}},
|
||||
}, {
|
||||
wantError: `invalid value: strconv.ParseInt: parsing "": invalid syntax`,
|
||||
want: time.Time{},
|
||||
name: "empty_value",
|
||||
in: metadata.MD{"sync_time": []string{""}},
|
||||
}, {
|
||||
wantError: "",
|
||||
want: TestUpdTime,
|
||||
name: "success",
|
||||
in: metadata.MD{"sync_time": []string{milliseconds}},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
syncTime, err := syncTimeFromTrailer(tc.in)
|
||||
testutil.AssertErrorMsg(t, tc.wantError, err)
|
||||
assert.True(t, tc.want.Equal(syncTime), "want %s; got %s", tc.want, syncTime)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
errSink error
|
||||
profSink *agd.Profile
|
||||
)
|
||||
|
||||
func BenchmarkDNSProfile_ToInternal(b *testing.B) {
|
||||
dp := NewTestDNSProfile(b)
|
||||
ctx := context.Background()
|
||||
|
||||
errColl := &agdtest.ErrorCollector{
|
||||
OnCollect: func(_ context.Context, err error) {
|
||||
panic(err)
|
||||
},
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
profSink, _, errSink = dp.toInternal(ctx, TestUpdTime, errColl)
|
||||
}
|
||||
|
||||
require.NotNil(b, profSink)
|
||||
require.NoError(b, errSink)
|
||||
|
||||
// Most recent result, on a ThinkPad X13:
|
||||
// goos: linux
|
||||
// goarch: amd64
|
||||
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/backendpb
|
||||
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
|
||||
// BenchmarkDNSProfile_ToInternal
|
||||
// BenchmarkDNSProfile_ToInternal-16 157513 10340 ns/op 1148 B/op 27 allocs/op
|
||||
}
|
96
internal/backendpb/profiledb_test.go
Normal file
96
internal/backendpb/profiledb_test.go
Normal file
@ -0,0 +1,96 @@
|
||||
package backendpb_test
|
||||
|
||||
import (
|
||||
context "context"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
var (
|
||||
errSink error
|
||||
respSink *profiledb.StorageResponse
|
||||
)
|
||||
|
||||
func BenchmarkProfileStorage_Profiles(b *testing.B) {
|
||||
syncTime := strconv.FormatInt(backendpb.TestUpdTime.UnixMilli(), 10)
|
||||
srvProf := backendpb.NewTestDNSProfile(b)
|
||||
trailerMD := metadata.MD{
|
||||
"sync_time": []string{syncTime},
|
||||
}
|
||||
|
||||
srv := &testDNSServiceServer{
|
||||
OnGetDNSProfiles: func(
|
||||
req *backendpb.DNSProfilesRequest,
|
||||
srv backendpb.DNSService_GetDNSProfilesServer,
|
||||
) (err error) {
|
||||
sendErr := srv.Send(srvProf)
|
||||
srv.SetTrailer(trailerMD)
|
||||
|
||||
return sendErr
|
||||
},
|
||||
}
|
||||
|
||||
errColl := &agdtest.ErrorCollector{
|
||||
OnCollect: func(_ context.Context, err error) {
|
||||
panic(err)
|
||||
},
|
||||
}
|
||||
|
||||
l, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(b, err)
|
||||
|
||||
s, err := backendpb.NewProfileStorage(&backendpb.ProfileStorageConfig{
|
||||
ErrColl: errColl,
|
||||
Endpoint: &url.URL{
|
||||
Scheme: "grpc",
|
||||
Host: l.Addr().String(),
|
||||
},
|
||||
})
|
||||
require.NoError(b, err)
|
||||
|
||||
grpcSrv := grpc.NewServer(
|
||||
grpc.ConnectionTimeout(1*time.Second),
|
||||
grpc.Creds(insecure.NewCredentials()),
|
||||
)
|
||||
backendpb.RegisterDNSServiceServer(grpcSrv, srv)
|
||||
|
||||
go func() {
|
||||
pt := &testutil.PanicT{}
|
||||
|
||||
srvErr := grpcSrv.Serve(l)
|
||||
require.NoError(pt, srvErr)
|
||||
}()
|
||||
b.Cleanup(grpcSrv.GracefulStop)
|
||||
|
||||
ctx := context.Background()
|
||||
req := &profiledb.StorageRequest{}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
respSink, errSink = s.Profiles(ctx, req)
|
||||
}
|
||||
|
||||
require.NoError(b, errSink)
|
||||
require.NotNil(b, respSink)
|
||||
|
||||
// Most recent result, on a ThinkPad X13:
|
||||
// goos: linux
|
||||
// goarch: amd64
|
||||
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/backendpb
|
||||
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
|
||||
// BenchmarkProfileStorage_Profiles
|
||||
// BenchmarkProfileStorage_Profiles-16 5347 245341 ns/op 15129 B/op 265 allocs/op
|
||||
}
|
82
internal/backendpb/stats.go
Normal file
82
internal/backendpb/stats.go
Normal file
@ -0,0 +1,82 @@
|
||||
package backendpb
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// profilesCallStats is a stateful structure that collects and reports
|
||||
// statistics about a [ProfileStorage.Profiles] call.
|
||||
type profilesCallStats struct {
|
||||
recvStart time.Time
|
||||
decStart time.Time
|
||||
|
||||
initRecv time.Duration
|
||||
totalRecv time.Duration
|
||||
totalDec time.Duration
|
||||
|
||||
numRecv int
|
||||
|
||||
isFullSync bool
|
||||
}
|
||||
|
||||
// startRecv starts the receive timer.
|
||||
func (s *profilesCallStats) startRecv() {
|
||||
s.recvStart = time.Now()
|
||||
}
|
||||
|
||||
// endRecv ends the receive timer and records the results.
|
||||
func (s *profilesCallStats) endRecv() {
|
||||
d := time.Since(s.recvStart)
|
||||
if s.numRecv == 0 {
|
||||
// Count the initial receive separately, since it is often not
|
||||
// representative of an average receive, because this is when gRPC
|
||||
// actually performs the call.
|
||||
s.initRecv = d
|
||||
} else {
|
||||
s.totalRecv += d
|
||||
}
|
||||
|
||||
s.numRecv++
|
||||
}
|
||||
|
||||
// startDec starts the decoding timer.
|
||||
func (s *profilesCallStats) startDec() {
|
||||
s.decStart = time.Now()
|
||||
}
|
||||
|
||||
// endDec ends the decoding timer and records the results.
|
||||
func (s *profilesCallStats) endDec() {
|
||||
s.totalDec += time.Since(s.decStart)
|
||||
}
|
||||
|
||||
// report writes the statistics to the log and the metrics.
|
||||
func (s *profilesCallStats) report() {
|
||||
logFunc := log.Debug
|
||||
if s.isFullSync {
|
||||
logFunc = log.Info
|
||||
}
|
||||
|
||||
if s.numRecv == 0 {
|
||||
logFunc("backendpb: no recv")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
n := time.Duration(s.numRecv)
|
||||
avgRecv := s.totalRecv / n
|
||||
avgDec := s.totalDec / n
|
||||
|
||||
logFunc(
|
||||
"backendpb: total recv: %s; agv recv: %s; init recv: %s",
|
||||
s.totalRecv,
|
||||
avgRecv,
|
||||
s.initRecv,
|
||||
)
|
||||
logFunc("backendpb: total dec: %s; agv dec: %s", s.totalDec, avgDec)
|
||||
|
||||
metrics.GRPCAvgProfileRecvDuration.Observe(avgRecv.Seconds())
|
||||
metrics.GRPCAvgProfileDecDuration.Observe(avgDec.Seconds())
|
||||
}
|
@ -86,10 +86,17 @@ var _ agd.Refresher = (*RuntimeRecorder)(nil)
|
||||
// uploads the currently available data and resets it.
|
||||
func (r *RuntimeRecorder) Refresh(ctx context.Context) (err error) {
|
||||
records := r.resetRecords()
|
||||
|
||||
startTime := time.Now()
|
||||
defer func() {
|
||||
dur := time.Since(startTime).Seconds()
|
||||
metrics.BillStatUploadDuration.Observe(dur)
|
||||
|
||||
if err != nil {
|
||||
r.remergeRecords(records)
|
||||
log.Info("billstat: refresh failed, records remerged")
|
||||
} else {
|
||||
metrics.BillStatUploadTimestamp.SetToCurrentTime()
|
||||
}
|
||||
|
||||
metrics.SetStatusGauge(metrics.BillStatUploadStatus, err)
|
||||
@ -100,8 +107,6 @@ func (r *RuntimeRecorder) Refresh(ctx context.Context) (err error) {
|
||||
return fmt.Errorf("uploading billstat records: %w", err)
|
||||
}
|
||||
|
||||
metrics.BillStatUploadTimestamp.SetToCurrentTime()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -6,30 +6,35 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// chanListener is a [net.Listener] that returns data sent to it through a
|
||||
// channel.
|
||||
//
|
||||
// Listeners of this type are returned by [chanListenConfig.Listen] and are used
|
||||
// in module dnsserver to make the bind-to-device logic work in DNS-over-TCP.
|
||||
// Listeners of this type are returned by [ListenConfig.Listen] and are used in
|
||||
// module dnsserver to make the bind-to-device logic work in DNS-over-TCP.
|
||||
type chanListener struct {
|
||||
// mu protects conns (against closure) and isClosed.
|
||||
mu *sync.Mutex
|
||||
conns chan net.Conn
|
||||
laddr net.Addr
|
||||
subnet netip.Prefix
|
||||
isClosed bool
|
||||
mu *sync.Mutex
|
||||
conns chan net.Conn
|
||||
connsGauge prometheus.Gauge
|
||||
laddr net.Addr
|
||||
subnet netip.Prefix
|
||||
isClosed bool
|
||||
}
|
||||
|
||||
// newChanListener returns a new properly initialized *chanListener.
|
||||
func newChanListener(conns chan net.Conn, subnet netip.Prefix, laddr net.Addr) (l *chanListener) {
|
||||
return &chanListener{
|
||||
mu: &sync.Mutex{},
|
||||
conns: conns,
|
||||
laddr: laddr,
|
||||
subnet: subnet,
|
||||
isClosed: false,
|
||||
mu: &sync.Mutex{},
|
||||
conns: conns,
|
||||
connsGauge: metrics.BindToDeviceTCPConnsChanSize.WithLabelValues(subnet.String()),
|
||||
laddr: laddr,
|
||||
subnet: subnet,
|
||||
isClosed: false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -77,5 +82,7 @@ func (l *chanListener) send(conn net.Conn) (ok bool) {
|
||||
|
||||
l.conns <- conn
|
||||
|
||||
l.connsGauge.Set(float64(len(l.conns)))
|
||||
|
||||
return true
|
||||
}
|
||||
|
@ -11,13 +11,15 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// chanPacketConn is a [netext.SessionPacketConn] that returns data sent to it
|
||||
// through the channel.
|
||||
//
|
||||
// Connections of this type are returned by [chanListenConfig.ListenPacket] and
|
||||
// are used in module dnsserver to make the bind-to-device logic work in
|
||||
// Connections of this type are returned by [ListenConfig.ListenPacket] and are
|
||||
// used in module dnsserver to make the bind-to-device logic work in
|
||||
// DNS-over-UDP.
|
||||
type chanPacketConn struct {
|
||||
// mu protects sessions (against closure) and isClosed.
|
||||
@ -26,6 +28,9 @@ type chanPacketConn struct {
|
||||
|
||||
writeRequests chan *packetConnWriteReq
|
||||
|
||||
sessionsGauge prometheus.Gauge
|
||||
writeRequestsGauge prometheus.Gauge
|
||||
|
||||
// deadlineMu protects readDeadline and writeDeadline.
|
||||
deadlineMu *sync.RWMutex
|
||||
readDeadline time.Time
|
||||
@ -48,6 +53,9 @@ func newChanPacketConn(
|
||||
sessions: sessions,
|
||||
writeRequests: writeRequests,
|
||||
|
||||
sessionsGauge: metrics.BindToDeviceUDPSessionsChanSize.WithLabelValues(subnet.String()),
|
||||
writeRequestsGauge: metrics.BindToDeviceUDPWriteRequestsChanSize.WithLabelValues(subnet.String()),
|
||||
|
||||
deadlineMu: &sync.RWMutex{},
|
||||
|
||||
laddr: laddr,
|
||||
@ -257,6 +265,8 @@ func (c *chanPacketConn) writeToSession(
|
||||
return 0, wrapConnError(tnChanPConn, fnName, c.laddr, err)
|
||||
}
|
||||
|
||||
c.writeRequestsGauge.Set(float64(len(c.writeRequests)))
|
||||
|
||||
r, err := receiveWithTimer(resp, timerCh)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("receiving write response: %w", err)
|
||||
@ -309,5 +319,7 @@ func (c *chanPacketConn) send(sess *packetSession) (ok bool) {
|
||||
|
||||
c.sessions <- sess
|
||||
|
||||
c.sessionsGauge.Set(float64(len(c.sessions)))
|
||||
|
||||
return true
|
||||
}
|
||||
|
@ -20,27 +20,23 @@ type connIndex struct {
|
||||
listeners []*chanListener
|
||||
}
|
||||
|
||||
// subnetSortsBefore returns true if subnet x sorts before subnet y.
|
||||
func subnetSortsBefore(x, y netip.Prefix) (isBefore bool) {
|
||||
xAddr, xBits := x.Addr(), x.Bits()
|
||||
yAddr, yBits := y.Addr(), y.Bits()
|
||||
if xBits == yBits {
|
||||
return xAddr.Less(yAddr)
|
||||
}
|
||||
|
||||
return xBits > yBits
|
||||
}
|
||||
|
||||
// subnetCompare is a comparison function for the two subnets. It returns -1 if
|
||||
// x sorts before y, 1 if x sorts after y, and 0 if their relative sorting
|
||||
// position is the same.
|
||||
func subnetCompare(x, y netip.Prefix) (cmp int) {
|
||||
switch {
|
||||
case x == y:
|
||||
if x == y {
|
||||
return 0
|
||||
case subnetSortsBefore(x, y):
|
||||
}
|
||||
|
||||
xAddr, xBits := x.Addr(), x.Bits()
|
||||
yAddr, yBits := y.Addr(), y.Bits()
|
||||
if xBits == yBits {
|
||||
return xAddr.Compare(yAddr)
|
||||
}
|
||||
|
||||
if xBits > yBits {
|
||||
return -1
|
||||
default:
|
||||
} else {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
func TestSubnetSortsBefore(t *testing.T) {
|
||||
func TestSubnetCompare(t *testing.T) {
|
||||
want := []netip.Prefix{
|
||||
netip.MustParsePrefix("1.0.0.0/24"),
|
||||
netip.MustParsePrefix("1.2.3.0/24"),
|
||||
@ -24,6 +24,6 @@ func TestSubnetSortsBefore(t *testing.T) {
|
||||
netip.MustParsePrefix("1.2.3.0/24"),
|
||||
}
|
||||
|
||||
slices.SortFunc(got, subnetSortsBefore)
|
||||
slices.SortFunc(got, subnetCompare)
|
||||
assert.Equalf(t, want, got, "got (as strings): %q", got)
|
||||
}
|
||||
|
@ -9,18 +9,22 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdsync"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// interfaceListener contains information about a single interface listener.
|
||||
type interfaceListener struct {
|
||||
conns *connIndex
|
||||
listenConf *net.ListenConfig
|
||||
bodyPool *agdsync.TypedPool[[]byte]
|
||||
oobPool *agdsync.TypedPool[[]byte]
|
||||
writeRequests chan *packetConnWriteReq
|
||||
done chan unit
|
||||
listenConf *net.ListenConfig
|
||||
errColl agd.ErrorCollector
|
||||
ifaceName string
|
||||
port uint16
|
||||
@ -63,17 +67,30 @@ func (l *interfaceListener) listenTCP(errCh chan<- error) {
|
||||
continue
|
||||
}
|
||||
|
||||
laddr := netutil.NetAddrToAddrPort(conn.LocalAddr())
|
||||
lsnr := l.conns.listener(laddr.Addr())
|
||||
if lsnr == nil {
|
||||
log.Info("%s: no channel for laddr %s", logPrefix, laddr)
|
||||
|
||||
continue
|
||||
}
|
||||
l.processConn(conn, logPrefix)
|
||||
}
|
||||
}
|
||||
|
||||
// processConn processes a single connection. If the connection doesn't have a
|
||||
// connected channel-listener, it is closed.
|
||||
func (l *interfaceListener) processConn(conn net.Conn, logPrefix string) {
|
||||
laddr := netutil.NetAddrToAddrPort(conn.LocalAddr())
|
||||
raddr := conn.RemoteAddr()
|
||||
if lsnr := l.conns.listener(laddr.Addr()); lsnr != nil {
|
||||
if !lsnr.send(conn) {
|
||||
log.Info("%s: channel for laddr %s is closed", logPrefix, laddr)
|
||||
log.Info("%s: from raddr %s: channel for laddr %s is closed", logPrefix, raddr, laddr)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
metrics.BindToDeviceUnknownTCPRequestsTotal.Inc()
|
||||
|
||||
optlog.Debug3("%s: from raddr %s: no stream channel for laddr %s", logPrefix, raddr, laddr)
|
||||
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
log.Debug("%s: from raddr %s: closing: %s", logPrefix, raddr, err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -112,29 +129,62 @@ func (l *interfaceListener) listenUDP(errCh chan<- error) {
|
||||
// Go on.
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Consider customization of body sizes.
|
||||
var sess *packetSession
|
||||
sess, err = readPacketSession(udpConn, dns.DefaultMsgSize)
|
||||
err = l.readUDP(udpConn, logPrefix)
|
||||
if err != nil {
|
||||
agd.Collectf(ctx, l.errColl, "%s: reading session: %w", logPrefix, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
laddr := sess.laddr.AddrPort().Addr()
|
||||
chanPConn := l.conns.packetConn(laddr)
|
||||
if chanPConn == nil {
|
||||
log.Info("%s: no channel for laddr %s", logPrefix, laddr)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !chanPConn.send(sess) {
|
||||
log.Info("%s: channel for laddr %s is closed", logPrefix, laddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// readUDP reads a UDP session from c and sends it to the appropriate channel.
|
||||
func (l *interfaceListener) readUDP(c *net.UDPConn, logPrefix string) (err error) {
|
||||
bodyPtr := l.bodyPool.Get()
|
||||
body := *bodyPtr
|
||||
|
||||
// Extend body to the capacity in case it had already been used and sliced
|
||||
// by [readPacketSession].
|
||||
body = body[:cap(body)]
|
||||
|
||||
oobPtr := l.oobPool.Get()
|
||||
oob := *oobPtr
|
||||
|
||||
defer func() {
|
||||
l.oobPool.Put(oobPtr)
|
||||
|
||||
// Only return the body to the pool in case of error here. The actual
|
||||
// return is done in writeUDP.
|
||||
if err != nil {
|
||||
l.bodyPool.Put(bodyPtr)
|
||||
}
|
||||
}()
|
||||
|
||||
sess, err := readPacketSession(c, body, oob)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading session: %w", err)
|
||||
}
|
||||
|
||||
laddr := sess.laddr.AddrPort().Addr()
|
||||
chanPacketConn := l.conns.packetConn(laddr)
|
||||
if chanPacketConn == nil {
|
||||
metrics.BindToDeviceUnknownUDPRequestsTotal.Inc()
|
||||
|
||||
optlog.Debug3(
|
||||
"%s: from raddr %s: no packet channel for laddr %s",
|
||||
logPrefix,
|
||||
sess.raddr,
|
||||
laddr,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if !chanPacketConn.send(sess) {
|
||||
log.Info("%s: channel for laddr %s is closed", logPrefix, laddr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeUDP runs the UDP write loop. It is intended to be used as a goroutine.
|
||||
func (l *interfaceListener) writeUDP(c *net.UDPConn) {
|
||||
defer log.OnPanic("interfaceListener.writeUDP")
|
||||
@ -167,6 +217,8 @@ func (l *interfaceListener) writeUDP(c *net.UDPConn) {
|
||||
s.respOOB,
|
||||
req.session.raddr,
|
||||
)
|
||||
|
||||
l.bodyPool.Put(&s.readBody)
|
||||
}
|
||||
|
||||
resetDeadlineErr := c.SetWriteDeadline(time.Time{})
|
||||
|
@ -18,7 +18,7 @@ type NetInterface interface {
|
||||
// type check
|
||||
var _ NetInterface = osInterface{}
|
||||
|
||||
// osInterface is a wapper around [*net.Interface] that implements the
|
||||
// osInterface is a wrapper around [*net.Interface] that implements the
|
||||
// [NetInterface] interface.
|
||||
type osInterface struct {
|
||||
iface *net.Interface
|
||||
|
@ -9,23 +9,24 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
||||
)
|
||||
|
||||
// chanListenConfig is a [netext.ListenConfig] implementation that uses the
|
||||
// ListenConfig is a [netext.ListenConfig] implementation that uses the
|
||||
// provided channel-based packet connection and listener to implement the
|
||||
// methods of the interface.
|
||||
//
|
||||
// netext.ListenConfig instances of this type are the ones that are going to be
|
||||
// set as [dnsserver.ConfigBase.ListenConfig] to make the bind-to-device logic
|
||||
// work.
|
||||
type chanListenConfig struct {
|
||||
type ListenConfig struct {
|
||||
packetConn *chanPacketConn
|
||||
listener *chanListener
|
||||
addr string
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ netext.ListenConfig = (*chanListenConfig)(nil)
|
||||
var _ netext.ListenConfig = (*ListenConfig)(nil)
|
||||
|
||||
// Listen implements the [netext.ListenConfig] interface for *chanListenConfig.
|
||||
func (lc *chanListenConfig) Listen(
|
||||
// Listen implements the [netext.ListenConfig] interface for *ListenConfig.
|
||||
func (lc *ListenConfig) Listen(
|
||||
ctx context.Context,
|
||||
network string,
|
||||
address string,
|
||||
@ -34,11 +35,17 @@ func (lc *chanListenConfig) Listen(
|
||||
}
|
||||
|
||||
// ListenPacket implements the [netext.ListenConfig] interface for
|
||||
// *chanListenConfig.
|
||||
func (lc *chanListenConfig) ListenPacket(
|
||||
// *ListenConfig.
|
||||
func (lc *ListenConfig) ListenPacket(
|
||||
ctx context.Context,
|
||||
network string,
|
||||
address string,
|
||||
) (c net.PacketConn, err error) {
|
||||
return lc.packetConn, nil
|
||||
}
|
||||
|
||||
// Addr returns the address on which lc accepts connections. See
|
||||
// [agdnet.FormatPrefixAddr] for the format.
|
||||
func (lc *ListenConfig) Addr() (addr string) {
|
||||
return lc.addr
|
||||
}
|
@ -6,16 +6,19 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestChanListenConfig(t *testing.T) {
|
||||
func TestListenConfig(t *testing.T) {
|
||||
pc := newChanPacketConn(nil, testSubnetIPv4, nil, testLAddr)
|
||||
lsnr := newChanListener(nil, testSubnetIPv4, testLAddr)
|
||||
c := chanListenConfig{
|
||||
addr := agdnet.FormatPrefixAddr(testSubnetIPv4, 1234)
|
||||
c := &ListenConfig{
|
||||
packetConn: pc,
|
||||
listener: lsnr,
|
||||
addr: addr,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
@ -29,4 +32,7 @@ func TestChanListenConfig(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, lsnr, gotLsnr)
|
||||
|
||||
gotAddr := c.Addr()
|
||||
assert.Equal(t, addr, gotAddr)
|
||||
}
|
55
internal/bindtodevice/listenconfig_others.go
Normal file
55
internal/bindtodevice/listenconfig_others.go
Normal file
@ -0,0 +1,55 @@
|
||||
//go:build !linux
|
||||
|
||||
package bindtodevice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
||||
)
|
||||
|
||||
// ListenConfig is a [netext.ListenConfig] implementation that uses the
|
||||
// provided channel-based packet connection and listener to implement the
|
||||
// methods of the interface.
|
||||
//
|
||||
// netext.ListenConfig instances of this type are the ones that are going to be
|
||||
// set as [dnsserver.ConfigBase.ListenConfig] to make the bind-to-device logic
|
||||
// work.
|
||||
//
|
||||
// It is only supported on Linux.
|
||||
type ListenConfig struct{}
|
||||
|
||||
// type check
|
||||
var _ netext.ListenConfig = (*ListenConfig)(nil)
|
||||
|
||||
// Listen implements the [netext.ListenConfig] interface for *ListenConfig.
|
||||
//
|
||||
// It is only supported on Linux.
|
||||
func (lc *ListenConfig) Listen(
|
||||
ctx context.Context,
|
||||
network string,
|
||||
address string,
|
||||
) (l net.Listener, err error) {
|
||||
return nil, errUnsupported
|
||||
}
|
||||
|
||||
// ListenPacket implements the [netext.ListenConfig] interface for
|
||||
// *ListenConfig.
|
||||
//
|
||||
// It is only supported on Linux.
|
||||
func (lc *ListenConfig) ListenPacket(
|
||||
ctx context.Context,
|
||||
network string,
|
||||
address string,
|
||||
) (c net.PacketConn, err error) {
|
||||
return nil, errUnsupported
|
||||
}
|
||||
|
||||
// Addr returns the address on which lc accepts connections. See
|
||||
// [agdnet.FormatPrefixAddr] for the format.
|
||||
//
|
||||
// It is only supported on Linux.
|
||||
func (lc *ListenConfig) Addr() (addr string) {
|
||||
return ""
|
||||
}
|
@ -10,10 +10,13 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdsync"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/mapsutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Manager creates individual listeners and dispatches connections to them.
|
||||
@ -49,7 +52,7 @@ var defaultCtrlConf = &ControlConfig{
|
||||
// configuration is used.
|
||||
//
|
||||
// Add must not be called after Start is called.
|
||||
func (m *Manager) Add(id ID, ifaceName string, port uint16, conf *ControlConfig) (err error) {
|
||||
func (m *Manager) Add(id ID, ifaceName string, port uint16, ctrlConf *ControlConfig) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "adding interface listener with id %q: %w", id) }()
|
||||
|
||||
_, err = m.interfaces.InterfaceByName(ifaceName)
|
||||
@ -86,29 +89,51 @@ func (m *Manager) Add(id ID, ifaceName string, port uint16, conf *ControlConfig)
|
||||
return err
|
||||
}
|
||||
|
||||
if conf == nil {
|
||||
conf = defaultCtrlConf
|
||||
if ctrlConf == nil {
|
||||
ctrlConf = defaultCtrlConf
|
||||
}
|
||||
|
||||
m.ifaceListeners[id] = &interfaceListener{
|
||||
conns: &connIndex{},
|
||||
writeRequests: make(chan *packetConnWriteReq, m.chanBufSize),
|
||||
done: m.done,
|
||||
listenConf: newListenConfig(ifaceName, conf),
|
||||
errColl: m.errColl,
|
||||
ifaceName: ifaceName,
|
||||
port: port,
|
||||
}
|
||||
// TODO(a.garipov): Consider customization of body sizes.
|
||||
m.ifaceListeners[id] = m.newInterfaceListener(ctrlConf, ifaceName, dns.DefaultMsgSize, port)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListenConfig returns a new netext.ListenConfig that receives connections from
|
||||
// the interface listener with the given id and the destination addresses of
|
||||
// which fall within subnet. subnet should be masked.
|
||||
// newInterfaceListener returns a new properly initialized *interfaceListener
|
||||
// for this manager.
|
||||
func (m *Manager) newInterfaceListener(
|
||||
ctrlConf *ControlConfig,
|
||||
ifaceName string,
|
||||
bodySize int,
|
||||
port uint16,
|
||||
) (l *interfaceListener) {
|
||||
return &interfaceListener{
|
||||
conns: &connIndex{},
|
||||
listenConf: newListenConfig(ifaceName, ctrlConf),
|
||||
bodyPool: agdsync.NewTypedPool(func() (v *[]byte) {
|
||||
b := make([]byte, bodySize)
|
||||
|
||||
return &b
|
||||
}),
|
||||
oobPool: agdsync.NewTypedPool(func() (v *[]byte) {
|
||||
b := make([]byte, netext.IPDstOOBSize)
|
||||
|
||||
return &b
|
||||
}),
|
||||
writeRequests: make(chan *packetConnWriteReq, m.chanBufSize),
|
||||
done: m.done,
|
||||
errColl: m.errColl,
|
||||
ifaceName: ifaceName,
|
||||
port: port,
|
||||
}
|
||||
}
|
||||
|
||||
// ListenConfig returns a new *ListenConfig that receives connections from the
|
||||
// interface listener with the given id and the destination addresses of which
|
||||
// fall within subnet. subnet should be masked.
|
||||
//
|
||||
// ListenConfig must not be called after Start is called.
|
||||
func (m *Manager) ListenConfig(id ID, subnet netip.Prefix) (c netext.ListenConfig, err error) {
|
||||
func (m *Manager) ListenConfig(id ID, subnet netip.Prefix) (c *ListenConfig, err error) {
|
||||
defer func() {
|
||||
err = errors.Annotate(
|
||||
err,
|
||||
@ -154,9 +179,10 @@ func (m *Manager) ListenConfig(id ID, subnet netip.Prefix) (c netext.ListenConfi
|
||||
return nil, fmt.Errorf("adding udp conn: %w", err)
|
||||
}
|
||||
|
||||
return &chanListenConfig{
|
||||
return &ListenConfig{
|
||||
packetConn: pConn,
|
||||
listener: lsnr,
|
||||
addr: agdnet.FormatPrefixAddr(subnet, l.port),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -221,6 +221,8 @@ func TestManager(t *testing.T) {
|
||||
err := m.Add(testID1, ifaceName, testPort1, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// TODO(a.garipov): Add tests for addresses within ifaceNet but outside of a
|
||||
// narrower subnet.
|
||||
subnet, err := netutil.IPNetToPrefixNoMapped(&net.IPNet{
|
||||
IP: ifaceNet.IP.Mask(ifaceNet.Mask),
|
||||
Mask: ifaceNet.Mask,
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
)
|
||||
|
||||
@ -25,6 +24,8 @@ func NewManager(c *ManagerConfig) (m *Manager) {
|
||||
|
||||
// errUnsupported is returned from all [Manager] methods on OSs other than
|
||||
// Linux.
|
||||
//
|
||||
// TODO(a.garipov): Consider using [errors.ErrUnsupported] in Go 1.21.
|
||||
const errUnsupported errors.Error = "bindtodevice is only supported on linux"
|
||||
|
||||
// Add creates a new interface-listener record in m.
|
||||
@ -34,12 +35,12 @@ func (m *Manager) Add(id ID, ifaceName string, port uint16, cc *ControlConfig) (
|
||||
return errUnsupported
|
||||
}
|
||||
|
||||
// ListenConfig returns a new netext.ListenConfig that receives connections from
|
||||
// the interface listener with the given id and the destination addresses of
|
||||
// which fall within subnet. subnet should be masked.
|
||||
// ListenConfig returns a new *ListenConfig that receives connections from the
|
||||
// interface listener with the given id and the destination addresses of which
|
||||
// fall within subnet. subnet should be masked.
|
||||
//
|
||||
// It is only supported on Linux.
|
||||
func (m *Manager) ListenConfig(id ID, subnet netip.Prefix) (c netext.ListenConfig, err error) {
|
||||
func (m *Manager) ListenConfig(id ID, subnet netip.Prefix) (c *ListenConfig, err error) {
|
||||
return nil, errUnsupported
|
||||
}
|
||||
|
||||
|
@ -3,9 +3,10 @@
|
||||
package bindtodevice
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
|
||||
)
|
||||
|
||||
// prefixNetAddr is a wrapper around netip.Prefix that makes it a [net.Addr].
|
||||
@ -21,16 +22,11 @@ type prefixNetAddr struct {
|
||||
// type check
|
||||
var _ net.Addr = (*prefixNetAddr)(nil)
|
||||
|
||||
// String implements the [net.Addr] interface for *prefixNetAddr. It returns an
|
||||
// address of the form "1.2.3.0:56789/24". That is, IP:port with a subnet after
|
||||
// a slash. This is done to make using the IP:port part easier to split off
|
||||
// using something like [strings.Cut].
|
||||
// String implements the [net.Addr] interface for *prefixNetAddr.
|
||||
//
|
||||
// See [agdnet.FormatPrefixAddr] for the format.
|
||||
func (addr *prefixNetAddr) String() (n string) {
|
||||
return fmt.Sprintf(
|
||||
"%s/%d",
|
||||
netip.AddrPortFrom(addr.prefix.Addr(), addr.port),
|
||||
addr.prefix.Bits(),
|
||||
)
|
||||
return agdnet.FormatPrefixAddr(addr.prefix, addr.port)
|
||||
}
|
||||
|
||||
// Network implements the [net.Addr] interface for *prefixNetAddr.
|
||||
|
@ -16,6 +16,8 @@ func TestPrefixAddr(t *testing.T) {
|
||||
network = "tcp"
|
||||
)
|
||||
|
||||
fullPrefix := netip.MustParsePrefix("1.2.3.4/32")
|
||||
|
||||
testCases := []struct {
|
||||
in *prefixNetAddr
|
||||
want string
|
||||
@ -42,6 +44,14 @@ func TestPrefixAddr(t *testing.T) {
|
||||
netip.AddrPortFrom(testSubnetIPv6.Addr(), port), testSubnetIPv6.Bits(),
|
||||
),
|
||||
name: "ipv6",
|
||||
}, {
|
||||
in: &prefixNetAddr{
|
||||
prefix: fullPrefix,
|
||||
network: network,
|
||||
port: port,
|
||||
},
|
||||
want: netip.AddrPortFrom(fullPrefix.Addr(), port).String(),
|
||||
name: "ipv4_full",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"net"
|
||||
"syscall"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
@ -106,14 +105,16 @@ func listenControlWithSO(
|
||||
return errors.WithDeferred(opErr, err)
|
||||
}
|
||||
|
||||
// msgUDPReader is an interface for types of connections that can read UDP
|
||||
// messages. See [*net.UDPConn].
|
||||
type msgUDPReader interface {
|
||||
ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error)
|
||||
}
|
||||
|
||||
// readPacketSession is a helper that reads a packet-session data from a UDP
|
||||
// connection.
|
||||
func readPacketSession(c *net.UDPConn, bodySize int) (sess *packetSession, err error) {
|
||||
// TODO(a.garipov): Consider adding pooling.
|
||||
b := make([]byte, bodySize)
|
||||
oob := make([]byte, netext.IPDstOOBSize)
|
||||
|
||||
n, oobn, _, raddr, err := c.ReadMsgUDP(b, oob)
|
||||
func readPacketSession(c msgUDPReader, body, oob []byte) (sess *packetSession, err error) {
|
||||
n, oobn, _, raddr, err := c.ReadMsgUDP(body, oob)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading: %w", err)
|
||||
}
|
||||
@ -142,7 +143,7 @@ func readPacketSession(c *net.UDPConn, bodySize int) (sess *packetSession, err e
|
||||
sess = &packetSession{
|
||||
laddr: origDstAddr,
|
||||
raddr: raddr,
|
||||
readBody: b[:n],
|
||||
readBody: body[:n],
|
||||
respOOB: respOOB,
|
||||
}
|
||||
|
||||
|
@ -3,7 +3,9 @@
|
||||
package bindtodevice
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@ -16,6 +18,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/slices"
|
||||
@ -240,10 +243,13 @@ func testListenControlUDPQuery(t *testing.T, packetConn net.PacketConn, reqAddr
|
||||
err := packetConn.SetReadDeadline(time.Now().Add(testTimeout))
|
||||
require.NoError(t, err)
|
||||
|
||||
b := make([]byte, reqLen)
|
||||
oob := make([]byte, netext.IPDstOOBSize)
|
||||
|
||||
var sess *packetSession
|
||||
switch c := packetConn.(type) {
|
||||
case *net.UDPConn:
|
||||
sess, err = readPacketSession(c, reqLen)
|
||||
sess, err = readPacketSession(c, b, oob)
|
||||
require.NoError(t, err)
|
||||
case netext.SessionPacketConn:
|
||||
var s netext.PacketSession
|
||||
@ -460,3 +466,84 @@ func TestListenControlWithSO(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// testMsgUDPReader is a [msgUDPReader] for tests.
|
||||
type testMsgUDPReader struct {
|
||||
onReadMsgUDP func(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ msgUDPReader = (*testMsgUDPReader)(nil)
|
||||
|
||||
// ReadMsgUDP implements the [msgUDPReader] interface for *testMsgUDPReader.
|
||||
func (r *testMsgUDPReader) ReadMsgUDP(
|
||||
b []byte,
|
||||
oob []byte,
|
||||
) (n, oobn, flags int, addr *net.UDPAddr, err error) {
|
||||
return r.onReadMsgUDP(b, oob)
|
||||
}
|
||||
|
||||
// Sinks for benchmarks.
|
||||
var (
|
||||
sessSink *packetSession
|
||||
errSink error
|
||||
)
|
||||
|
||||
func BenchmarkReadPacketSession(b *testing.B) {
|
||||
bodyData := []byte("message body data")
|
||||
|
||||
// TODO(a.garipov): Find a better way to pack these control messages than
|
||||
// just [binary.Write].
|
||||
oobBuf := &bytes.Buffer{}
|
||||
ctrlMsgHdr := unix.Cmsghdr{
|
||||
Len: 24,
|
||||
Level: unix.SOL_IP,
|
||||
Type: unix.IP_ORIGDSTADDR,
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Use binary.NativeEndian in Go 1.21 here and below.
|
||||
err := binary.Write(oobBuf, binary.LittleEndian, ctrlMsgHdr)
|
||||
require.NoError(b, err)
|
||||
|
||||
pktInfo := unix.Inet4Pktinfo{
|
||||
Spec_dst: *(*[4]byte)(testRAddr.IP),
|
||||
Addr: *(*[4]byte)(testRAddr.IP),
|
||||
}
|
||||
|
||||
err = binary.Write(oobBuf, binary.LittleEndian, pktInfo)
|
||||
require.NoError(b, err)
|
||||
|
||||
oobData := oobBuf.Bytes()
|
||||
|
||||
c := &testMsgUDPReader{
|
||||
onReadMsgUDP: func(body, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) {
|
||||
copy(body, bodyData)
|
||||
copy(oob, oobData)
|
||||
|
||||
return len(bodyData), len(oobData), 0, testRAddr, nil
|
||||
},
|
||||
}
|
||||
|
||||
body := make([]byte, dns.DefaultMsgSize)
|
||||
oob := make([]byte, netext.IPDstOOBSize)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
sessSink, errSink = readPacketSession(c, body, oob)
|
||||
}
|
||||
|
||||
require.NoError(b, errSink)
|
||||
require.NotNil(b, sessSink)
|
||||
|
||||
assert.Equal(b, sessSink.raddr, testRAddr)
|
||||
assert.Equal(b, sessSink.readBody, bodyData)
|
||||
|
||||
// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
|
||||
// goos: linux
|
||||
// goarch: amd64
|
||||
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice
|
||||
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
|
||||
// BenchmarkReadPacketSession
|
||||
// BenchmarkReadPacketSession-16 3311841 458.1 ns/op 224 B/op 5 allocs/op
|
||||
}
|
||||
|
40
internal/cmd/access.go
Normal file
40
internal/cmd/access.go
Normal file
@ -0,0 +1,40 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// accessConfig is the configuration that controls IP and hosts blocking.
|
||||
type accessConfig struct {
|
||||
// BlockedQuestionDomains is a list of AdBlock rules used to block access.
|
||||
BlockedQuestionDomains []string `yaml:"blocked_question_domains"`
|
||||
|
||||
// BlockedClientSubnets is a list of IP addresses or subnets to block.
|
||||
BlockedClientSubnets []string `yaml:"blocked_client_subnets"`
|
||||
}
|
||||
|
||||
// validate returns an error if the access configuration is invalid.
|
||||
func (a *accessConfig) validate() (err error) {
|
||||
if a == nil {
|
||||
return errNilConfig
|
||||
}
|
||||
|
||||
for i, s := range a.BlockedClientSubnets {
|
||||
// TODO(a.garipov): Use [netutil.ParseSubnet] after refactoring it to
|
||||
// [netip.Addr].
|
||||
_, parseErr := netip.ParseAddr(s)
|
||||
if parseErr == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
_, parseErr = netip.ParsePrefix(s)
|
||||
if parseErr == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
return fmt.Errorf("value %q at index %d: bad ip or cidr: %w", s, i, parseErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -3,10 +3,12 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/backend"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
@ -87,12 +89,14 @@ func setupBillStat(
|
||||
sigHdlr signalHandler,
|
||||
errColl agd.ErrorCollector,
|
||||
) (rec *billstat.RuntimeRecorder, err error) {
|
||||
billStatConf := &backend.BillStatConfig{
|
||||
BaseEndpoint: netutil.CloneURL(&envs.BillStatURL.URL),
|
||||
apiURL := netutil.CloneURL(&envs.BillStatURL.URL)
|
||||
billStatUploader, err := setupBillStatUploader(apiURL, errColl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating bill stat uploader: %w", err)
|
||||
}
|
||||
|
||||
rec = billstat.NewRuntimeRecorder(&billstat.RuntimeRecorderConfig{
|
||||
Uploader: backend.NewBillStat(billStatConf),
|
||||
Uploader: billStatUploader,
|
||||
})
|
||||
|
||||
refrIvl := conf.RefreshIvl.Duration
|
||||
@ -127,13 +131,12 @@ func setupProfDB(
|
||||
sigHdlr signalHandler,
|
||||
errColl agd.ErrorCollector,
|
||||
) (profDB *profiledb.Default, err error) {
|
||||
profStrgConf := &backend.ProfileStorageConfig{
|
||||
BaseEndpoint: netutil.CloneURL(&envs.ProfilesURL.URL),
|
||||
Now: time.Now,
|
||||
ErrColl: errColl,
|
||||
apiURL := netutil.CloneURL(&envs.ProfilesURL.URL)
|
||||
profStrg, err := setupProfStorage(apiURL, errColl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating profile storage: %w", err)
|
||||
}
|
||||
|
||||
profStrg := backend.NewProfileStorage(profStrgConf)
|
||||
profDB, err = profiledb.New(profStrg, conf.FullRefreshIvl.Duration, envs.ProfilesCachePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating default profile database: %w", err)
|
||||
@ -162,3 +165,55 @@ func setupProfDB(
|
||||
|
||||
return profDB, nil
|
||||
}
|
||||
|
||||
// Backend API URL schemes.
|
||||
const (
|
||||
schemeHTTP = "http"
|
||||
schemeHTTPS = "https"
|
||||
schemeGRPC = "grpc"
|
||||
schemeGRPCS = "grpcs"
|
||||
)
|
||||
|
||||
// setupProfStorage creates and returns a profile storage depending on the
|
||||
// provided API URL.
|
||||
func setupProfStorage(
|
||||
apiURL *url.URL,
|
||||
errColl agd.ErrorCollector,
|
||||
) (s profiledb.Storage, err error) {
|
||||
switch apiURL.Scheme {
|
||||
case schemeGRPC, schemeGRPCS:
|
||||
return backendpb.NewProfileStorage(&backendpb.ProfileStorageConfig{
|
||||
Endpoint: apiURL,
|
||||
ErrColl: errColl,
|
||||
})
|
||||
case schemeHTTP, schemeHTTPS:
|
||||
return backend.NewProfileStorage(&backend.ProfileStorageConfig{
|
||||
BaseEndpoint: apiURL,
|
||||
Now: time.Now,
|
||||
ErrColl: errColl,
|
||||
}), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid backend api url: %s", apiURL)
|
||||
}
|
||||
}
|
||||
|
||||
// setupBillStatUploader creates and returns a billstat uploader depending on
|
||||
// the provided API URL.
|
||||
func setupBillStatUploader(
|
||||
apiURL *url.URL,
|
||||
errColl agd.ErrorCollector,
|
||||
) (s billstat.Uploader, err error) {
|
||||
switch apiURL.Scheme {
|
||||
case schemeGRPC, schemeGRPCS:
|
||||
return backendpb.NewBillStat(&backendpb.BillStatConfig{
|
||||
ErrColl: errColl,
|
||||
Endpoint: apiURL,
|
||||
})
|
||||
case schemeHTTP, schemeHTTPS:
|
||||
return backend.NewBillStat(&backend.BillStatConfig{
|
||||
BaseEndpoint: apiURL,
|
||||
}), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid backend api url: %s", apiURL)
|
||||
}
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
@ -53,18 +52,6 @@ func (c *checkConfig) toInternal(
|
||||
sessURL = netutil.CloneURL(&envs.ConsulDNSCheckSessionURL.URL)
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Use netip.Addrs in dnscheck, which also means using it
|
||||
// in dnsmsg.Constructor.
|
||||
ipv4 := make([]net.IP, len(c.IPv4))
|
||||
for i, ip := range c.IPv4 {
|
||||
ipv4[i] = ip.AsSlice()
|
||||
}
|
||||
|
||||
ipv6 := make([]net.IP, len(c.IPv6))
|
||||
for i, ip := range c.IPv6 {
|
||||
ipv6[i] = ip.AsSlice()
|
||||
}
|
||||
|
||||
domains := make([]string, len(c.Domains))
|
||||
for i, d := range c.Domains {
|
||||
domains[i] = strings.ToLower(d)
|
||||
@ -78,8 +65,8 @@ func (c *checkConfig) toInternal(
|
||||
Domains: domains,
|
||||
NodeLocation: c.NodeLocation,
|
||||
NodeName: c.NodeName,
|
||||
IPv4: ipv4,
|
||||
IPv6: ipv6,
|
||||
IPv4: c.IPv4,
|
||||
IPv6: c.IPv6,
|
||||
TTL: c.TTL.Duration,
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/access"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/debugsvc"
|
||||
@ -15,6 +16,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/forward"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
|
||||
@ -160,6 +162,11 @@ func Main() {
|
||||
|
||||
sigHdlr.add(btdMgr)
|
||||
|
||||
// access
|
||||
|
||||
accessManager, err := access.New(c.Access.BlockedQuestionDomains, c.Access.BlockedClientSubnets)
|
||||
check(err)
|
||||
|
||||
// Server groups
|
||||
|
||||
messages := dnsmsg.NewConstructor(&dnsmsg.BlockingModeNullIP{}, c.Filters.ResponseTTL.Duration)
|
||||
@ -239,6 +246,7 @@ func Main() {
|
||||
}
|
||||
|
||||
dnsConf := &dnssvc.Config{
|
||||
AccessManager: accessManager,
|
||||
Messages: messages,
|
||||
SafeBrowsing: hashprefix.NewMatcher(hashStorages),
|
||||
BillStat: billStatRec,
|
||||
@ -327,5 +335,10 @@ func collectPanics(errColl agd.ErrorCollector) {
|
||||
|
||||
errColl.Collect(context.Background(), err)
|
||||
|
||||
errFlushColl, ok := errColl.(errcoll.ErrorFlushCollector)
|
||||
if ok {
|
||||
errFlushColl.Flush()
|
||||
}
|
||||
|
||||
panic(v)
|
||||
}
|
||||
|
@ -71,6 +71,9 @@ type configuration struct {
|
||||
// Network is the configuration for network listeners.
|
||||
Network *network `yaml:"network"`
|
||||
|
||||
// Access is the configuration of the service managing access control.
|
||||
Access *accessConfig `yaml:"access"`
|
||||
|
||||
// AdditionalMetricsInfo is extra information, which is exposed by metrics.
|
||||
AdditionalMetricsInfo additionalInfo `yaml:"additional_metrics_info"`
|
||||
|
||||
@ -159,6 +162,9 @@ func (c *configuration) validate() (err error) {
|
||||
}, {
|
||||
validate: c.Network.validate,
|
||||
name: "network",
|
||||
}, {
|
||||
validate: c.Access.validate,
|
||||
name: "access",
|
||||
}, {
|
||||
validate: c.AdditionalMetricsInfo.validate,
|
||||
name: "additional_metrics_info",
|
||||
|
@ -36,10 +36,10 @@ func (c *connCheckConfig) validate() (err error) {
|
||||
|
||||
// connectivityCheck performs connectivity checks for bind addresses with
|
||||
// provided dialer and probe addresses. For each server group it reviews each
|
||||
// server bind addresses looking up for an IPv6 addresses. If an IPv6 address
|
||||
// is found, then additionally to a general probe to IPv4 it will perform a
|
||||
// check to IPv6 probe address.
|
||||
func connectivityCheck(c *dnssvc.Config, connCheck *connCheckConfig) error {
|
||||
// server bind addresses looking up for IPv6 addresses. If an IPv6 address is
|
||||
// found, then additionally to a general probe to IPv4 it will perform a check
|
||||
// to IPv6 probe address.
|
||||
func connectivityCheck(c *dnssvc.Config, connCheck *connCheckConfig) (err error) {
|
||||
probeIPv4 := net.TCPAddrFromAddrPort(connCheck.ProbeIPv4)
|
||||
|
||||
// General check to IPv4 probe address.
|
||||
@ -49,48 +49,60 @@ func connectivityCheck(c *dnssvc.Config, connCheck *connCheckConfig) error {
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
log.Fatalf("connectivity check: ipv4: %v", err)
|
||||
closeErr := conn.Close()
|
||||
if closeErr != nil {
|
||||
log.Fatalf("connectivity check: ipv4: %v", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
if containsIPv6BindAddress(c.ServerGroups) {
|
||||
if (connCheck.ProbeIPv6 == netip.AddrPort{}) {
|
||||
log.Fatal("connectivity check: no ipv6 probe address in config")
|
||||
}
|
||||
|
||||
probeIPv6 := net.TCPAddrFromAddrPort(connCheck.ProbeIPv6)
|
||||
|
||||
// Check to IPv6 probe address.
|
||||
connV6, errV6 := net.DialTCP("tcp6", nil, probeIPv6)
|
||||
if errV6 != nil {
|
||||
return fmt.Errorf("connectivity check: ipv6: %w", errV6)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = connV6.Close()
|
||||
if err != nil {
|
||||
log.Fatalf("connectivity check: ipv6: %v", err)
|
||||
}
|
||||
}()
|
||||
if !requireIPv6ConnCheck(c.ServerGroups) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if (connCheck.ProbeIPv6 == netip.AddrPort{}) {
|
||||
log.Fatal("connectivity check: no ipv6 probe address in config")
|
||||
}
|
||||
|
||||
probeIPv6 := net.TCPAddrFromAddrPort(connCheck.ProbeIPv6)
|
||||
|
||||
// Check to IPv6 probe address.
|
||||
connV6, err := net.DialTCP("tcp6", nil, probeIPv6)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connectivity check: ipv6: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
closeErr := connV6.Close()
|
||||
if closeErr != nil {
|
||||
log.Fatalf("connectivity check: ipv6: %v", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// containsIPv6BindAddress returns true if provided serverGroups require
|
||||
// IPv6 connectivity check.
|
||||
func containsIPv6BindAddress(serverGroups []*agd.ServerGroup) (ok bool) {
|
||||
// requireIPv6ConnCheck returns true if provided serverGroups require IPv6
|
||||
// connectivity check.
|
||||
func requireIPv6ConnCheck(serverGroups []*agd.ServerGroup) (ok bool) {
|
||||
for _, srvGrp := range serverGroups {
|
||||
for _, s := range srvGrp.Servers {
|
||||
for _, bindData := range s.BindData {
|
||||
if addr := bindData.AddrPort; addr.IsValid() && addr.Addr().Is6() {
|
||||
return true
|
||||
}
|
||||
if containsIPv6BindAddress(s.BindData) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// containsIPv6BindAddress returns true if provided bindData contains valid IPv6
|
||||
// address.
|
||||
func containsIPv6BindAddress(bindData []*agd.ServerBindData) (ok bool) {
|
||||
for _, bData := range bindData {
|
||||
if addr := bData.AddrPort; addr.Addr().Is6() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
@ -57,8 +57,16 @@ func ddrRecsToSVCBTmpls(
|
||||
tmpls = appendDDRSVCBTmpls(tmpls, msgs, r, target)
|
||||
}
|
||||
|
||||
slices.SortStableFunc(tmpls, func(a, b *dns.SVCB) (less bool) {
|
||||
return a.Priority < b.Priority
|
||||
// TODO(e.burkov): Use cmp.Compare when updated to go1.21.
|
||||
slices.SortStableFunc(tmpls, func(a, b *dns.SVCB) (res int) {
|
||||
switch x, y := a.Priority, b.Priority; {
|
||||
case x < y:
|
||||
return -1
|
||||
case x > y:
|
||||
return +1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
|
||||
return targets, tmpls
|
||||
@ -118,26 +126,16 @@ func (c *ddrConfig) validate() (err error) {
|
||||
}
|
||||
|
||||
domainSuf := wildcard[2:]
|
||||
err = netutil.ValidateHostname(domainSuf)
|
||||
err = errors.Join(netutil.ValidateHostname(domainSuf), r.validate())
|
||||
if err != nil {
|
||||
return fmt.Errorf("device_records: %w", err)
|
||||
}
|
||||
|
||||
err = r.validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("device_records: record for wildcard %q: %w", wildcard, err)
|
||||
return fmt.Errorf("device_records: wildcard %q: %w", wildcard, err)
|
||||
}
|
||||
}
|
||||
|
||||
for domain, r := range c.PublicRecords {
|
||||
err = netutil.ValidateHostname(domain)
|
||||
err = errors.Join(netutil.ValidateHostname(domain), r.validate())
|
||||
if err != nil {
|
||||
return fmt.Errorf("public_records: %w", err)
|
||||
}
|
||||
|
||||
err = r.validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("public_records: record for domain %q: %w", domain, err)
|
||||
return fmt.Errorf("public_records: domain %q: %w", domain, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -41,7 +41,7 @@ type environments struct {
|
||||
|
||||
ConfPath string `env:"CONFIG_PATH" envDefault:"./config.yaml"`
|
||||
FilterCachePath string `env:"FILTER_CACHE_PATH" envDefault:"./filters/"`
|
||||
ProfilesCachePath string `env:"PROFILES_CACHE_PATH" envDefault:"./profilecache.json"`
|
||||
ProfilesCachePath string `env:"PROFILES_CACHE_PATH" envDefault:"./profilecache.pb"`
|
||||
GeoIPASNPath string `env:"GEOIP_ASN_PATH" envDefault:"./asn.mmdb"`
|
||||
GeoIPCountryPath string `env:"GEOIP_COUNTRY_PATH" envDefault:"./country.mmdb"`
|
||||
QueryLogPath string `env:"QUERYLOG_PATH" envDefault:"./querylog.jsonl"`
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
@ -205,10 +204,8 @@ func (s *server) bindData(
|
||||
ifaces := s.BindInterfaces
|
||||
bindData = make([]*agd.ServerBindData, 0, len(ifaces))
|
||||
for i, iface := range ifaces {
|
||||
address := string(iface.ID)
|
||||
|
||||
for j, subnet := range iface.Subnets {
|
||||
var lc netext.ListenConfig
|
||||
var lc *bindtodevice.ListenConfig
|
||||
lc, err = btdMgr.ListenConfig(iface.ID, subnet)
|
||||
if err != nil {
|
||||
const errStr = "bind_interface at index %d: subnet at index %d: %w"
|
||||
@ -218,7 +215,7 @@ func (s *server) bindData(
|
||||
|
||||
bindData = append(bindData, &agd.ServerBindData{
|
||||
ListenConfig: lc,
|
||||
Address: address,
|
||||
Address: lc.Addr(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -136,14 +136,15 @@ func (certs tlsConfigCerts) toInternal() (conf *tls.Config, err error) {
|
||||
return nil, fmt.Errorf("certificate at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
tlsCerts[i] = cert
|
||||
|
||||
var leaf *x509.Certificate
|
||||
leaf, err = x509.ParseCertificate(cert.Certificate[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid leaf, certificate at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
cert.Leaf = leaf
|
||||
tlsCerts[i] = cert
|
||||
|
||||
authAlgo, subj := leaf.PublicKeyAlgorithm.String(), leaf.Subject.String()
|
||||
metrics.TLSCertificateInfo.With(prometheus.Labels{
|
||||
"auth_algo": authAlgo,
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
@ -23,6 +22,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
cache "github.com/patrickmn/go-cache"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
@ -46,8 +46,8 @@ type Consul struct {
|
||||
nodeLocation string
|
||||
nodeName string
|
||||
|
||||
ipv4 []net.IP
|
||||
ipv6 []net.IP
|
||||
ipv4 []netip.Addr
|
||||
ipv6 []netip.Addr
|
||||
}
|
||||
|
||||
// ConsulConfig is the configuration structure for Consul KV based DNS checker.
|
||||
@ -77,10 +77,10 @@ type ConsulConfig struct {
|
||||
NodeName string
|
||||
|
||||
// IPv4 are the IPv4 addresses to respond with to A requests.
|
||||
IPv4 []net.IP
|
||||
IPv4 []netip.Addr
|
||||
|
||||
// IPv6 are the IPv6 addresses to respond with to AAAA requests.
|
||||
IPv6 []net.IP
|
||||
IPv6 []netip.Addr
|
||||
|
||||
// TTL defines, for how long to keep the information about a single client.
|
||||
TTL time.Duration
|
||||
@ -108,8 +108,8 @@ func NewConsul(c *ConsulConfig) (cc *Consul, err error) {
|
||||
nodeLocation: c.NodeLocation,
|
||||
nodeName: c.NodeName,
|
||||
|
||||
ipv4: netutil.CloneIPs(c.IPv4),
|
||||
ipv6: netutil.CloneIPs(c.IPv6),
|
||||
ipv4: slices.Clone(c.IPv4),
|
||||
ipv6: slices.Clone(c.IPv6),
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Validate also c.ConsulSessionURL?
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
@ -222,8 +223,8 @@ func TestConsul_Check(t *testing.T) {
|
||||
conf := &dnscheck.ConsulConfig{
|
||||
Messages: dnsmsg.NewConstructor(&dnsmsg.BlockingModeNullIP{}, ttl*time.Second),
|
||||
Domains: []string{checkDomain},
|
||||
IPv4: []net.IP{{1, 2, 3, 4}},
|
||||
IPv6: []net.IP{net.ParseIP("1234::5678")},
|
||||
IPv4: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
||||
IPv6: []netip.Addr{netip.MustParseAddr("1234::5678")},
|
||||
}
|
||||
|
||||
dnsCk, err := dnscheck.NewConsul(conf)
|
||||
|
@ -5,9 +5,9 @@ import (
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
@ -25,7 +25,8 @@ import (
|
||||
|
||||
func TestDefault_ServeHTTP(t *testing.T) {
|
||||
const dname = "some-domain.name"
|
||||
testIP := net.IP{1, 2, 3, 4}
|
||||
|
||||
testIP := netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
successHdr := http.Header{
|
||||
httphdr.ContentType: []string{agdhttp.HdrValTextCSV},
|
||||
@ -100,7 +101,7 @@ func TestDefault_ServeHTTP(t *testing.T) {
|
||||
db.Record(ctx, m, &agd.RequestInfo{
|
||||
// Emulate the logic from init middleware.
|
||||
//
|
||||
// See [dnssvc.initMw.newRequestInfo].
|
||||
// See [initial.Middleware.newRequestInfo].
|
||||
Host: strings.TrimSuffix(m.Question[0].Name, "."),
|
||||
})
|
||||
}
|
||||
|
@ -56,6 +56,7 @@ func answerString(rr dns.RR) (s string) {
|
||||
case *dns.AAAA:
|
||||
return v.AAAA.String()
|
||||
case *dns.CNAME:
|
||||
// TODO(a.garipov): Consider lowercasing target hostname.
|
||||
return strings.TrimSuffix(v.Target, ".")
|
||||
default:
|
||||
return ""
|
||||
|
@ -21,6 +21,8 @@ type BlockingMode interface {
|
||||
|
||||
// BlockingModeCodec is a wrapper around a BlockingMode that implements the
|
||||
// [json.Marshaler] and [json.Unmarshaler] interfaces.
|
||||
//
|
||||
// TODO(s.chzhen): Remove once it's not used anymore.
|
||||
type BlockingModeCodec struct {
|
||||
Mode BlockingMode
|
||||
}
|
||||
|
510
internal/dnsmsg/cloner.go
Normal file
510
internal/dnsmsg/cloner.go
Normal file
@ -0,0 +1,510 @@
|
||||
package dnsmsg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdsync"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// Cloner is a pool that can clone common parts of DNS messages with fewer
|
||||
// allocations.
|
||||
//
|
||||
// TODO(a.garipov): Add ECS/OPT.
|
||||
//
|
||||
// TODO(a.garipov): Use.
|
||||
//
|
||||
// TODO(a.garipov): Consider merging into [Constructor].
|
||||
type Cloner struct {
|
||||
// Top-level structures.
|
||||
|
||||
msg *agdsync.TypedPool[dns.Msg]
|
||||
question *agdsync.TypedPool[[]dns.Question]
|
||||
|
||||
// Mostly-answer structures.
|
||||
|
||||
a *agdsync.TypedPool[dns.A]
|
||||
aaaa *agdsync.TypedPool[dns.AAAA]
|
||||
cname *agdsync.TypedPool[dns.CNAME]
|
||||
ptr *agdsync.TypedPool[dns.PTR]
|
||||
srv *agdsync.TypedPool[dns.SRV]
|
||||
txt *agdsync.TypedPool[dns.TXT]
|
||||
|
||||
// Mostly-answer custom cloners.
|
||||
|
||||
https *httpsCloner
|
||||
|
||||
// Mostly-NS structures.
|
||||
|
||||
soa *agdsync.TypedPool[dns.SOA]
|
||||
}
|
||||
|
||||
// NewCloner returns a new properly initialized *Cloner.
|
||||
func NewCloner() (c *Cloner) {
|
||||
return &Cloner{
|
||||
msg: agdsync.NewTypedPool(func() (v *dns.Msg) {
|
||||
return &dns.Msg{}
|
||||
}),
|
||||
question: agdsync.NewTypedPool(func() (v *[]dns.Question) {
|
||||
q := make([]dns.Question, 1)
|
||||
|
||||
return &q
|
||||
}),
|
||||
|
||||
a: agdsync.NewTypedPool(func() (v *dns.A) {
|
||||
return &dns.A{}
|
||||
}),
|
||||
aaaa: agdsync.NewTypedPool(func() (v *dns.AAAA) {
|
||||
return &dns.AAAA{}
|
||||
}),
|
||||
cname: agdsync.NewTypedPool(func() (v *dns.CNAME) {
|
||||
return &dns.CNAME{}
|
||||
}),
|
||||
ptr: agdsync.NewTypedPool(func() (v *dns.PTR) {
|
||||
return &dns.PTR{}
|
||||
}),
|
||||
srv: agdsync.NewTypedPool(func() (v *dns.SRV) {
|
||||
return &dns.SRV{}
|
||||
}),
|
||||
txt: agdsync.NewTypedPool(func() (v *dns.TXT) {
|
||||
return &dns.TXT{}
|
||||
}),
|
||||
|
||||
https: newHTTPSCloner(),
|
||||
|
||||
soa: agdsync.NewTypedPool(func() (v *dns.SOA) {
|
||||
return &dns.SOA{}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// Clone returns a deep clone of msg. full is true if msg was cloned entirely
|
||||
// without the use of [dns.Copy].
|
||||
//
|
||||
// msg must have exactly one question.
|
||||
//
|
||||
// TODO(a.garipov): Don't require one question?
|
||||
func (c *Cloner) Clone(msg *dns.Msg) (clone *dns.Msg, full bool) {
|
||||
if msg == nil {
|
||||
return nil, true
|
||||
}
|
||||
|
||||
clone = c.msg.Get()
|
||||
|
||||
clone.MsgHdr = msg.MsgHdr
|
||||
clone.Compress = msg.Compress
|
||||
|
||||
clone.Question = *c.question.Get()
|
||||
clone.Question[0] = msg.Question[0]
|
||||
|
||||
clone.Answer, full = c.appendAnswer(clone.Answer[:0], msg.Answer)
|
||||
|
||||
clone.Ns = clone.Ns[:0]
|
||||
for _, orig := range msg.Ns {
|
||||
var nsClone dns.RR
|
||||
switch orig := orig.(type) {
|
||||
case *dns.SOA:
|
||||
ns := c.soa.Get()
|
||||
*ns = *orig
|
||||
|
||||
nsClone = ns
|
||||
// TODO(a.garipov): Add more if necessary.
|
||||
default:
|
||||
nsClone = dns.Copy(orig)
|
||||
full = false
|
||||
}
|
||||
|
||||
clone.Ns = append(clone.Ns, nsClone)
|
||||
}
|
||||
|
||||
clone.Extra = clone.Extra[:0]
|
||||
for _, orig := range msg.Extra {
|
||||
var exClone dns.RR
|
||||
switch orig := orig.(type) {
|
||||
// TODO(a.garipov): Add more if necessary.
|
||||
default:
|
||||
exClone = dns.Copy(orig)
|
||||
full = false
|
||||
}
|
||||
|
||||
clone.Extra = append(clone.Extra, exClone)
|
||||
}
|
||||
|
||||
return clone, full
|
||||
}
|
||||
|
||||
// appendAnswer appends deep clones of all resource recornds from original to
|
||||
// clones and returns it.
|
||||
func (c *Cloner) appendAnswer(clones, original []dns.RR) (res []dns.RR, full bool) {
|
||||
full = true
|
||||
for _, orig := range original {
|
||||
var ansClone dns.RR
|
||||
switch orig := orig.(type) {
|
||||
case *dns.A:
|
||||
ans := c.a.Get()
|
||||
ans.Hdr = orig.Hdr
|
||||
|
||||
ans.A = append(ans.A[:0], orig.A...)
|
||||
|
||||
ansClone = ans
|
||||
case *dns.AAAA:
|
||||
ans := c.aaaa.Get()
|
||||
ans.Hdr = orig.Hdr
|
||||
|
||||
ans.AAAA = append(ans.AAAA[:0], orig.AAAA...)
|
||||
|
||||
ansClone = ans
|
||||
case *dns.CNAME:
|
||||
ans := c.cname.Get()
|
||||
*ans = *orig
|
||||
|
||||
ansClone = ans
|
||||
case *dns.HTTPS:
|
||||
var httpsFull bool
|
||||
ansClone, httpsFull = c.https.clone(orig)
|
||||
full = full && httpsFull
|
||||
case *dns.PTR:
|
||||
ans := c.ptr.Get()
|
||||
*ans = *orig
|
||||
|
||||
ansClone = ans
|
||||
case *dns.SRV:
|
||||
ans := c.srv.Get()
|
||||
*ans = *orig
|
||||
|
||||
ansClone = ans
|
||||
case *dns.TXT:
|
||||
ans := c.txt.Get()
|
||||
ans.Hdr = orig.Hdr
|
||||
|
||||
ans.Txt = append(ans.Txt[:0], orig.Txt...)
|
||||
|
||||
ansClone = ans
|
||||
default:
|
||||
ansClone = dns.Copy(orig)
|
||||
full = false
|
||||
}
|
||||
|
||||
clones = append(clones, ansClone)
|
||||
}
|
||||
|
||||
return clones, full
|
||||
}
|
||||
|
||||
// Put returns structures from msg into c's pools. Neither msg nor any of its
|
||||
// parts must not be used after this.
|
||||
//
|
||||
// msg must have exactly one question.
|
||||
//
|
||||
// TODO(a.garipov): Don't require one question?
|
||||
func (c *Cloner) Put(msg *dns.Msg) {
|
||||
if msg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.putAnswers(msg.Answer)
|
||||
|
||||
for _, ns := range msg.Ns {
|
||||
switch ns := ns.(type) {
|
||||
case *dns.SOA:
|
||||
c.soa.Put(ns)
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
}
|
||||
|
||||
for _, ex := range msg.Extra {
|
||||
// TODO(a.garipov): Add OPT.
|
||||
_ = ex
|
||||
}
|
||||
|
||||
c.question.Put(&msg.Question)
|
||||
|
||||
c.msg.Put(msg)
|
||||
}
|
||||
|
||||
// putAnswers returns answers into c's pools.
|
||||
func (c *Cloner) putAnswers(answers []dns.RR) {
|
||||
for _, ans := range answers {
|
||||
switch ans := ans.(type) {
|
||||
case *dns.A:
|
||||
c.a.Put(ans)
|
||||
case *dns.AAAA:
|
||||
c.aaaa.Put(ans)
|
||||
case *dns.CNAME:
|
||||
c.cname.Put(ans)
|
||||
case *dns.HTTPS:
|
||||
c.https.put(ans)
|
||||
case *dns.PTR:
|
||||
c.ptr.Put(ans)
|
||||
case *dns.SRV:
|
||||
c.srv.Put(ans)
|
||||
case *dns.TXT:
|
||||
c.txt.Put(ans)
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// httpsCloner is a pool that can clone common parts of DNS messages of type
|
||||
// HTTPS with fewer allocations.
|
||||
type httpsCloner struct {
|
||||
// Top-level structures.
|
||||
|
||||
rr *agdsync.TypedPool[dns.HTTPS]
|
||||
|
||||
// Values.
|
||||
|
||||
alpn *agdsync.TypedPool[dns.SVCBAlpn]
|
||||
dohpath *agdsync.TypedPool[dns.SVCBDoHPath]
|
||||
echconfig *agdsync.TypedPool[dns.SVCBECHConfig]
|
||||
ipv4hint *agdsync.TypedPool[dns.SVCBIPv4Hint]
|
||||
ipv6hint *agdsync.TypedPool[dns.SVCBIPv6Hint]
|
||||
local *agdsync.TypedPool[dns.SVCBLocal]
|
||||
mandatory *agdsync.TypedPool[dns.SVCBMandatory]
|
||||
noDefALPN *agdsync.TypedPool[dns.SVCBNoDefaultAlpn]
|
||||
port *agdsync.TypedPool[dns.SVCBPort]
|
||||
|
||||
// Miscellaneous.
|
||||
|
||||
ip *agdsync.TypedPool[net.IP]
|
||||
}
|
||||
|
||||
// newHTTPSCloner returns a new properly initialized *httpsCloner.
|
||||
func newHTTPSCloner() (c *httpsCloner) {
|
||||
return &httpsCloner{
|
||||
rr: agdsync.NewTypedPool(func() (v *dns.HTTPS) {
|
||||
return &dns.HTTPS{}
|
||||
}),
|
||||
|
||||
alpn: agdsync.NewTypedPool(func() (v *dns.SVCBAlpn) {
|
||||
return &dns.SVCBAlpn{}
|
||||
}),
|
||||
dohpath: agdsync.NewTypedPool(func() (v *dns.SVCBDoHPath) {
|
||||
return &dns.SVCBDoHPath{}
|
||||
}),
|
||||
echconfig: agdsync.NewTypedPool(func() (v *dns.SVCBECHConfig) {
|
||||
return &dns.SVCBECHConfig{}
|
||||
}),
|
||||
ipv4hint: agdsync.NewTypedPool(func() (v *dns.SVCBIPv4Hint) {
|
||||
return &dns.SVCBIPv4Hint{}
|
||||
}),
|
||||
ipv6hint: agdsync.NewTypedPool(func() (v *dns.SVCBIPv6Hint) {
|
||||
return &dns.SVCBIPv6Hint{}
|
||||
}),
|
||||
local: agdsync.NewTypedPool(func() (v *dns.SVCBLocal) {
|
||||
return &dns.SVCBLocal{}
|
||||
}),
|
||||
mandatory: agdsync.NewTypedPool(func() (v *dns.SVCBMandatory) {
|
||||
return &dns.SVCBMandatory{}
|
||||
}),
|
||||
noDefALPN: agdsync.NewTypedPool(func() (v *dns.SVCBNoDefaultAlpn) {
|
||||
return &dns.SVCBNoDefaultAlpn{}
|
||||
}),
|
||||
port: agdsync.NewTypedPool(func() (v *dns.SVCBPort) {
|
||||
return &dns.SVCBPort{}
|
||||
}),
|
||||
|
||||
ip: agdsync.NewTypedPool(func() (v *net.IP) {
|
||||
// Use the IPv6 length to increase the effectiveness of the pool.
|
||||
ip := make(net.IP, 16)
|
||||
|
||||
return &ip
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// clone returns a deep clone of rr. full is true if rr was cloned entirely
|
||||
// without the use of [dns.Copy].
|
||||
func (c *httpsCloner) clone(rr *dns.HTTPS) (clone *dns.HTTPS, full bool) {
|
||||
if rr == nil {
|
||||
return nil, true
|
||||
}
|
||||
|
||||
clone = c.rr.Get()
|
||||
|
||||
clone.Hdr = rr.Hdr
|
||||
clone.Priority = rr.Priority
|
||||
clone.Target = rr.Target
|
||||
|
||||
clone.Value = clone.Value[:0]
|
||||
for _, orig := range rr.Value {
|
||||
valClone, knownKV := c.cloneKV(orig)
|
||||
if !knownKV {
|
||||
// This branch is only reached if there is a new SVCB key-value type
|
||||
// in miekg/dns. Give up and just use their copy function.
|
||||
return dns.Copy(rr).(*dns.HTTPS), false
|
||||
}
|
||||
|
||||
clone.Value = append(clone.Value, valClone)
|
||||
}
|
||||
|
||||
return clone, true
|
||||
}
|
||||
|
||||
// cloneKV returns a deep clone of orig. full is true if orig was recognized.
|
||||
func (c *httpsCloner) cloneKV(orig dns.SVCBKeyValue) (clone dns.SVCBKeyValue, known bool) {
|
||||
switch orig := orig.(type) {
|
||||
case *dns.SVCBAlpn:
|
||||
v := c.alpn.Get()
|
||||
|
||||
v.Alpn = append(v.Alpn[:0], orig.Alpn...)
|
||||
|
||||
clone = v
|
||||
case *dns.SVCBDoHPath:
|
||||
v := c.dohpath.Get()
|
||||
*v = *orig
|
||||
|
||||
clone = v
|
||||
case *dns.SVCBECHConfig:
|
||||
v := c.echconfig.Get()
|
||||
|
||||
v.ECH = append(v.ECH[:0], orig.ECH...)
|
||||
|
||||
clone = v
|
||||
case *dns.SVCBIPv4Hint:
|
||||
v := c.ipv4hint.Get()
|
||||
|
||||
v.Hint = c.appendIPs(v.Hint[:0], orig.Hint)
|
||||
|
||||
clone = v
|
||||
case *dns.SVCBIPv6Hint:
|
||||
v := c.ipv6hint.Get()
|
||||
|
||||
v.Hint = c.appendIPs(v.Hint[:0], orig.Hint)
|
||||
|
||||
clone = v
|
||||
case *dns.SVCBLocal:
|
||||
v := c.local.Get()
|
||||
v.KeyCode = orig.KeyCode
|
||||
|
||||
v.Data = append(v.Data[:0], orig.Data...)
|
||||
|
||||
clone = v
|
||||
case *dns.SVCBMandatory:
|
||||
v := c.mandatory.Get()
|
||||
|
||||
v.Code = append(v.Code[:0], orig.Code...)
|
||||
|
||||
clone = v
|
||||
case *dns.SVCBNoDefaultAlpn:
|
||||
clone = c.noDefALPN.Get()
|
||||
case *dns.SVCBPort:
|
||||
v := c.port.Get()
|
||||
*v = *orig
|
||||
|
||||
clone = v
|
||||
default:
|
||||
// This branch is only reached if there is a new SVCB key-value type
|
||||
// in miekg/dns.
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return clone, true
|
||||
}
|
||||
|
||||
// appendIPs appends the clones of IP addresses from orig to hints and returns
|
||||
// the resulting slice. clone is allocated as a single continuous slice.
|
||||
func (c *httpsCloner) appendIPs(hints, orig []net.IP) (clone []net.IP) {
|
||||
if len(orig) == 0 {
|
||||
if orig == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []net.IP{}
|
||||
}
|
||||
|
||||
// Use a single large slice and subslice it to make it easier to maintain a
|
||||
// pool of these.
|
||||
ips := *c.ip.Get()
|
||||
ips = ips[:0]
|
||||
|
||||
neededCap := 0
|
||||
for _, origIP := range orig {
|
||||
neededCap += len(origIP)
|
||||
}
|
||||
|
||||
ips = slices.Grow(ips, neededCap)
|
||||
|
||||
hints = hints[:0]
|
||||
for _, origIP := range orig {
|
||||
ips = append(ips, origIP...)
|
||||
origLen := len(origIP)
|
||||
lastIdx := len(ips)
|
||||
hints = append(hints, ips[lastIdx-origLen:lastIdx])
|
||||
}
|
||||
|
||||
return hints
|
||||
}
|
||||
|
||||
// put returns structures from rr into c's pools.
|
||||
func (c *httpsCloner) put(rr *dns.HTTPS) {
|
||||
if rr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, kv := range rr.Value {
|
||||
c.putKV(kv)
|
||||
}
|
||||
|
||||
c.rr.Put(rr)
|
||||
}
|
||||
|
||||
// putKV returns structures from kv into c's pools.
|
||||
func (c *httpsCloner) putKV(kv dns.SVCBKeyValue) {
|
||||
switch kv := kv.(type) {
|
||||
case *dns.SVCBAlpn:
|
||||
c.alpn.Put(kv)
|
||||
case *dns.SVCBDoHPath:
|
||||
c.dohpath.Put(kv)
|
||||
case *dns.SVCBECHConfig:
|
||||
c.echconfig.Put(kv)
|
||||
case *dns.SVCBIPv4Hint:
|
||||
putIPHint(c, kv)
|
||||
case *dns.SVCBIPv6Hint:
|
||||
putIPHint(c, kv)
|
||||
case *dns.SVCBLocal:
|
||||
c.local.Put(kv)
|
||||
case *dns.SVCBMandatory:
|
||||
c.mandatory.Put(kv)
|
||||
case *dns.SVCBNoDefaultAlpn:
|
||||
c.noDefALPN.Put(kv)
|
||||
case *dns.SVCBPort:
|
||||
c.port.Put(kv)
|
||||
default:
|
||||
// This branch is only reached if there is a new SVCB key-value type
|
||||
// in miekg/dns. Noting to do.
|
||||
}
|
||||
}
|
||||
|
||||
// putIPHint is a generic helper that returns the structures of kv into c.
|
||||
func putIPHint[T *dns.SVCBIPv4Hint | *dns.SVCBIPv6Hint](c *httpsCloner, kv T) {
|
||||
switch kv := any(kv).(type) {
|
||||
case *dns.SVCBIPv4Hint:
|
||||
// TODO(a.garipov): Put the common code above the switch when Go learns
|
||||
// about common fields between types.
|
||||
if len(kv.Hint) > 0 {
|
||||
// Assume that the array underlying these slices is a single and
|
||||
// continuous one.
|
||||
c.ip.Put(&kv.Hint[0])
|
||||
}
|
||||
|
||||
c.ipv4hint.Put(kv)
|
||||
case *dns.SVCBIPv6Hint:
|
||||
// TODO(a.garipov): Put the common code above the switch when Go learns
|
||||
// about common fields between types.
|
||||
if len(kv.Hint) > 0 {
|
||||
// Assume that the array underlying these slices is a single and
|
||||
// continuous one.
|
||||
c.ip.Put(&kv.Hint[0])
|
||||
}
|
||||
|
||||
c.ipv6hint.Put(kv)
|
||||
default:
|
||||
// Must not happen, because there is a strict type parameter above.
|
||||
panic(fmt.Errorf("bad type %T", kv))
|
||||
}
|
||||
}
|
280
internal/dnsmsg/cloner_test.go
Normal file
280
internal/dnsmsg/cloner_test.go
Normal file
@ -0,0 +1,280 @@
|
||||
package dnsmsg_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// clonerTestCase is the type for the common test cases for the cloner tests and
|
||||
// benchmarks.
|
||||
type clonerTestCase struct {
|
||||
msg *dns.Msg
|
||||
wantFull assert.BoolAssertionFunc
|
||||
name string
|
||||
}
|
||||
|
||||
// clonerTestCases are the common test cases for the clone benchmarks.
|
||||
var clonerTestCases = []clonerTestCase{{
|
||||
msg: dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET),
|
||||
name: "req_a",
|
||||
wantFull: assert.True,
|
||||
}, {
|
||||
msg: dnsservertest.NewResp(
|
||||
dns.RcodeSuccess,
|
||||
dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET),
|
||||
dnsservertest.SectionAnswer{
|
||||
dnsservertest.NewA(testFQDN, 10, testIPv4),
|
||||
},
|
||||
),
|
||||
name: "resp_a",
|
||||
wantFull: assert.True,
|
||||
}, {
|
||||
msg: dnsservertest.NewResp(
|
||||
dns.RcodeSuccess,
|
||||
dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET),
|
||||
dnsservertest.SectionAnswer{
|
||||
dnsservertest.NewA(testFQDN, 10, testIPv4),
|
||||
dnsservertest.NewA(testFQDN, 10, testIPv4.Next()),
|
||||
},
|
||||
),
|
||||
name: "resp_a_many",
|
||||
wantFull: assert.True,
|
||||
}, {
|
||||
msg: dnsservertest.NewResp(
|
||||
dns.RcodeSuccess,
|
||||
dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET),
|
||||
dnsservertest.SectionAnswer{
|
||||
dnsservertest.NewA(testFQDN, 10, testIPv4),
|
||||
},
|
||||
dnsservertest.SectionNs{
|
||||
dnsservertest.NewSOA(testFQDN, 10, "ns.example.", "mbox.example."),
|
||||
},
|
||||
),
|
||||
name: "resp_a_soa",
|
||||
wantFull: assert.True,
|
||||
}, {
|
||||
msg: dnsservertest.NewReq(testFQDN, dns.TypeAAAA, dns.ClassINET),
|
||||
name: "req_aaaa",
|
||||
wantFull: assert.True,
|
||||
}, {
|
||||
msg: dnsservertest.NewResp(
|
||||
dns.RcodeSuccess,
|
||||
dnsservertest.NewReq(testFQDN, dns.TypeAAAA, dns.ClassINET),
|
||||
dnsservertest.SectionAnswer{
|
||||
dnsservertest.NewAAAA(testFQDN, 10, testIPv6),
|
||||
},
|
||||
),
|
||||
name: "resp_aaaa",
|
||||
wantFull: assert.True,
|
||||
}, {
|
||||
msg: dnsservertest.NewResp(
|
||||
dns.RcodeSuccess,
|
||||
dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET),
|
||||
dnsservertest.SectionAnswer{
|
||||
dnsservertest.NewCNAME(testFQDN, 10, "cname.example."),
|
||||
dnsservertest.NewA("cname.example.", 10, testIPv4),
|
||||
},
|
||||
),
|
||||
name: "resp_cname_a",
|
||||
wantFull: assert.True,
|
||||
}, {
|
||||
msg: dnsservertest.NewResp(
|
||||
dns.RcodeSuccess,
|
||||
dnsservertest.NewReq("4.3.2.1.in-addr.arpa", dns.TypePTR, dns.ClassINET),
|
||||
dnsservertest.SectionAnswer{
|
||||
dnsservertest.NewPTR("4.3.2.1.in-addr.arpa", 10, "ptr.example."),
|
||||
},
|
||||
),
|
||||
name: "resp_ptr",
|
||||
wantFull: assert.True,
|
||||
}, {
|
||||
msg: dnsservertest.NewResp(
|
||||
dns.RcodeSuccess,
|
||||
dnsservertest.NewReq(testFQDN, dns.TypeTXT, dns.ClassINET),
|
||||
dnsservertest.SectionAnswer{
|
||||
dnsservertest.NewTXT(testFQDN, 10, "a", "b", "c"),
|
||||
},
|
||||
),
|
||||
name: "resp_txt",
|
||||
wantFull: assert.True,
|
||||
}, {
|
||||
msg: dnsservertest.NewResp(
|
||||
dns.RcodeSuccess,
|
||||
dnsservertest.NewReq(testFQDN, dns.TypeSRV, dns.ClassINET),
|
||||
dnsservertest.SectionAnswer{
|
||||
dnsservertest.NewSRV(testFQDN, 10, "target.example.", 1, 1, 8080),
|
||||
},
|
||||
),
|
||||
name: "resp_srv",
|
||||
wantFull: assert.True,
|
||||
}, {
|
||||
msg: dnsservertest.NewResp(
|
||||
dns.RcodeSuccess,
|
||||
dnsservertest.NewReq(testFQDN, dns.TypeDNSKEY, dns.ClassINET),
|
||||
dnsservertest.SectionAnswer{
|
||||
&dns.DNSKEY{},
|
||||
},
|
||||
),
|
||||
name: "resp_not_full",
|
||||
wantFull: assert.False,
|
||||
}, {
|
||||
msg: newHTTPSResp([]dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"http/1.1", "h2", "h3"}},
|
||||
&dns.SVCBDoHPath{Template: "/dns-query"},
|
||||
&dns.SVCBECHConfig{ECH: []byte{0, 1, 2, 3}},
|
||||
&dns.SVCBIPv4Hint{Hint: []net.IP{
|
||||
testIPv4.AsSlice(),
|
||||
testIPv4.Next().AsSlice(),
|
||||
}},
|
||||
&dns.SVCBIPv6Hint{Hint: []net.IP{
|
||||
testIPv6.AsSlice(),
|
||||
testIPv6.Next().AsSlice(),
|
||||
}},
|
||||
&dns.SVCBLocal{KeyCode: dns.SVCBKey(1234), Data: []byte{3, 2, 1, 0}},
|
||||
&dns.SVCBMandatory{Code: []dns.SVCBKey{dns.SVCB_ALPN}},
|
||||
&dns.SVCBNoDefaultAlpn{},
|
||||
&dns.SVCBPort{Port: 443},
|
||||
}),
|
||||
name: "resp_https",
|
||||
wantFull: assert.True,
|
||||
}, {
|
||||
msg: newHTTPSResp([]dns.SVCBKeyValue{
|
||||
&dns.SVCBIPv4Hint{Hint: []net.IP{}},
|
||||
&dns.SVCBIPv6Hint{Hint: []net.IP{}},
|
||||
}),
|
||||
name: "resp_https_empty_hint",
|
||||
wantFull: assert.True,
|
||||
}}
|
||||
|
||||
// newHTTPSResp is a hepler that returns a response of type HTTPS with the given
|
||||
// parameter values.
|
||||
func newHTTPSResp(kv []dns.SVCBKeyValue) (resp *dns.Msg) {
|
||||
ans := &dns.HTTPS{
|
||||
SVCB: dns.SVCB{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: testFQDN,
|
||||
Rrtype: dns.TypeHTTPS,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 10,
|
||||
},
|
||||
Priority: 10,
|
||||
Target: testFQDN,
|
||||
Value: kv,
|
||||
},
|
||||
}
|
||||
|
||||
return dnsservertest.NewResp(
|
||||
dns.RcodeSuccess,
|
||||
dnsservertest.NewReq(testFQDN, dns.TypeHTTPS, dns.ClassINET),
|
||||
dnsservertest.SectionAnswer{ans},
|
||||
)
|
||||
}
|
||||
|
||||
func TestCloner_Clone(t *testing.T) {
|
||||
c := dnsmsg.NewCloner()
|
||||
|
||||
for _, tc := range clonerTestCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
clone, full := c.Clone(tc.msg)
|
||||
assert.NotSame(t, tc.msg, clone)
|
||||
assert.Equal(t, tc.msg, clone)
|
||||
tc.wantFull(t, full)
|
||||
|
||||
// Check again after putting it back.
|
||||
c.Put(clone)
|
||||
|
||||
clone, full = c.Clone(tc.msg)
|
||||
assert.NotSame(t, tc.msg, clone)
|
||||
assert.Equal(t, tc.msg, clone)
|
||||
tc.wantFull(t, full)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Sinks for benchmarks
|
||||
var (
|
||||
msgSink *dns.Msg
|
||||
boolSink bool
|
||||
)
|
||||
|
||||
func BenchmarkClone(b *testing.B) {
|
||||
for _, tc := range clonerTestCases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
msgSink = dnsmsg.Clone(tc.msg)
|
||||
}
|
||||
|
||||
require.Equal(b, tc.msg, msgSink)
|
||||
})
|
||||
}
|
||||
|
||||
// Most recent results, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
|
||||
//
|
||||
// goos: linux
|
||||
// goarch: amd64
|
||||
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/querylog
|
||||
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
|
||||
// BenchmarkClone/req_a-16 32849714 231.7 ns/op 168 B/op 2 allocs/op
|
||||
// BenchmarkClone/resp_a-16 12051967 509.1 ns/op 256 B/op 5 allocs/op
|
||||
// BenchmarkClone/resp_a_many-16 8579755 669.4 ns/op 344 B/op 7 allocs/op
|
||||
// BenchmarkClone/resp_a_soa-16 10393932 681.9 ns/op 368 B/op 6 allocs/op
|
||||
// BenchmarkClone/req_aaaa-16 25616247 232.1 ns/op 168 B/op 2 allocs/op
|
||||
// BenchmarkClone/resp_aaaa-16 14519920 493.4 ns/op 264 B/op 5 allocs/op
|
||||
// BenchmarkClone/resp_cname_a-16 8652282 662.2 ns/op 320 B/op 6 allocs/op
|
||||
// BenchmarkClone/resp_ptr-16 13558555 370.0 ns/op 232 B/op 4 allocs/op
|
||||
// BenchmarkClone/resp_txt-16 12322016 532.7 ns/op 296 B/op 5 allocs/op
|
||||
// BenchmarkClone/resp_srv-16 15878784 396.3 ns/op 248 B/op 4 allocs/op
|
||||
// BenchmarkClone/resp_not_full-16 15718658 384.6 ns/op 248 B/op 4 allocs/op
|
||||
// BenchmarkClone/resp_https-16 2621149 2020 ns/op 880 B/op 24 allocs/op
|
||||
// BenchmarkClone/resp_https_empty_hint-16 6829873 890.8 ns/op 424 B/op 8 allocs/op
|
||||
}
|
||||
|
||||
func BenchmarkCloner_Clone(b *testing.B) {
|
||||
c := dnsmsg.NewCloner()
|
||||
|
||||
for _, tc := range clonerTestCases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
msgSink, boolSink = c.Clone(tc.msg)
|
||||
if i < b.N-1 {
|
||||
// Don't put the last one to be sure that we can compare
|
||||
// that one.
|
||||
c.Put(msgSink)
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(b, tc.msg, msgSink)
|
||||
tc.wantFull(b, boolSink)
|
||||
})
|
||||
}
|
||||
|
||||
// Most recent results, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
|
||||
//
|
||||
// goos: linux
|
||||
// goarch: amd64
|
||||
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/querylog
|
||||
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
|
||||
// BenchmarkCloner_Clone/req_a-16 163590546 36.33 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkCloner_Clone/resp_a-16 100000000 56.55 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkCloner_Clone/resp_a_many-16 72498543 84.52 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkCloner_Clone/resp_a_soa-16 81750753 73.07 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkCloner_Clone/req_aaaa-16 165287482 39.00 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkCloner_Clone/resp_aaaa-16 99625165 59.56 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkCloner_Clone/resp_cname_a-16 72154432 81.15 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkCloner_Clone/resp_ptr-16 100418211 60.88 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkCloner_Clone/resp_txt-16 80963180 73.66 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkCloner_Clone/resp_srv-16 89021206 69.35 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkCloner_Clone/resp_not_full-16 31277523 187.6 ns/op 64 B/op 1 allocs/op
|
||||
// BenchmarkCloner_Clone/resp_https-16 14601229 396.3 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkCloner_Clone/resp_https_empty_hint-16 45725181 127.4 ns/op 0 B/op 0 allocs/op
|
||||
}
|
@ -2,10 +2,9 @@ package dnsmsg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@ -38,7 +37,7 @@ func (c *Constructor) NewBlockedRespMsg(req *dns.Msg) (msg *dns.Msg, err error)
|
||||
case *BlockingModeNullIP:
|
||||
switch qt := req.Question[0].Qtype; qt {
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
return c.NewIPRespMsg(req, nil)
|
||||
return c.NewIPRespMsg(req, netip.Addr{})
|
||||
default:
|
||||
return c.NewMsgNODATA(req), nil
|
||||
}
|
||||
@ -62,11 +61,11 @@ func (c *Constructor) newBlockedCustomIPRespMsg(
|
||||
switch qt := req.Question[0].Qtype; qt {
|
||||
case dns.TypeA:
|
||||
if m.IPv4.IsValid() {
|
||||
return c.NewIPRespMsg(req, m.IPv4.AsSlice())
|
||||
return c.NewIPRespMsg(req, m.IPv4)
|
||||
}
|
||||
case dns.TypeAAAA:
|
||||
if m.IPv6.IsValid() {
|
||||
return c.NewIPRespMsg(req, m.IPv6.AsSlice())
|
||||
return c.NewIPRespMsg(req, m.IPv6)
|
||||
}
|
||||
default:
|
||||
// Go on.
|
||||
@ -78,7 +77,7 @@ func (c *Constructor) newBlockedCustomIPRespMsg(
|
||||
// NewIPRespMsg returns a DNS A or AAAA response message with the given IP
|
||||
// addresses. If any IP address is nil, it is replaced by an unspecified (aka
|
||||
// null) IP. The TTL is also set to c.FilteredResponseTTL.
|
||||
func (c *Constructor) NewIPRespMsg(req *dns.Msg, ips ...net.IP) (msg *dns.Msg, err error) {
|
||||
func (c *Constructor) NewIPRespMsg(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
|
||||
switch qt := req.Question[0].Qtype; qt {
|
||||
case dns.TypeA:
|
||||
return c.newMsgA(req, ips...)
|
||||
@ -205,40 +204,38 @@ func (c *Constructor) newHdrWithClass(req *dns.Msg, rrType RRType, cl dns.Class)
|
||||
}
|
||||
|
||||
// NewAnsA returns a new resource record with an IPv4 address. ip must be an
|
||||
// IPv4 address. If ip is nil, it is replaced by an unspecified (aka null) IP,
|
||||
// 0.0.0.0.
|
||||
func (c *Constructor) NewAnsA(req *dns.Msg, ip net.IP) (ans *dns.A, err error) {
|
||||
var ip4 net.IP
|
||||
if ip == nil {
|
||||
ip4 = net.IP{0, 0, 0, 0}
|
||||
} else if err = netutil.ValidateIP(ip); err != nil {
|
||||
return nil, err
|
||||
} else if ip4 = ip.To4(); ip4 == nil {
|
||||
// IPv4 address. If ip is a zero netip.Addr, it is replaced by an unspecified
|
||||
// (aka null) IP, 0.0.0.0.
|
||||
func (c *Constructor) NewAnsA(req *dns.Msg, ip netip.Addr) (ans *dns.A, err error) {
|
||||
if ip == (netip.Addr{}) {
|
||||
ip = netip.IPv4Unspecified()
|
||||
} else if !ip.Is4() {
|
||||
return nil, fmt.Errorf("bad ipv4: %s", ip)
|
||||
}
|
||||
|
||||
data := ip.As4()
|
||||
|
||||
return &dns.A{
|
||||
Hdr: c.newHdr(req, dns.TypeA),
|
||||
A: ip4,
|
||||
A: data[:],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewAnsAAAA returns a new resource record with an IPv6 address. ip must be an
|
||||
// IPv6 address. If ip is nil, it is replaced by an unspecified (aka null) IP,
|
||||
// [::].
|
||||
func (c *Constructor) NewAnsAAAA(req *dns.Msg, ip net.IP) (ans *dns.AAAA, err error) {
|
||||
var ip6 net.IP
|
||||
if ip == nil {
|
||||
ip6 = net.IPv6unspecified
|
||||
} else if err = netutil.ValidateIP(ip); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
ip6 = ip.To16()
|
||||
// IPv6 address. If ip is a zero netip.Addr, it is replaced by an unspecified
|
||||
// (aka null) IP, [::].
|
||||
func (c *Constructor) NewAnsAAAA(req *dns.Msg, ip netip.Addr) (ans *dns.AAAA, err error) {
|
||||
if ip == (netip.Addr{}) {
|
||||
ip = netip.IPv6Unspecified()
|
||||
} else if !ip.Is6() {
|
||||
return nil, fmt.Errorf("bad ipv6: %s", ip)
|
||||
}
|
||||
|
||||
data := ip.As16()
|
||||
|
||||
return &dns.AAAA{
|
||||
Hdr: c.newHdr(req, dns.TypeAAAA),
|
||||
AAAA: ip6,
|
||||
AAAA: data[:],
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -358,7 +355,7 @@ func (c *Constructor) NewRespMsg(req *dns.Msg) (resp *dns.Msg) {
|
||||
|
||||
// newMsgA returns a new DNS response with the given IPv4 addresses. If any IP
|
||||
// address is nil, it is replaced by an unspecified (aka null) IP, 0.0.0.0.
|
||||
func (c *Constructor) newMsgA(req *dns.Msg, ips ...net.IP) (msg *dns.Msg, err error) {
|
||||
func (c *Constructor) newMsgA(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
|
||||
msg = c.NewRespMsg(req)
|
||||
for i, ip := range ips {
|
||||
var ans dns.RR
|
||||
@ -375,7 +372,7 @@ func (c *Constructor) newMsgA(req *dns.Msg, ips ...net.IP) (msg *dns.Msg, err er
|
||||
|
||||
// newMsgAAAA returns a new DNS response with the given IPv6 addresses. If any
|
||||
// IP address is nil, it is replaced by an unspecified (aka null) IP, [::].
|
||||
func (c *Constructor) newMsgAAAA(req *dns.Msg, ips ...net.IP) (msg *dns.Msg, err error) {
|
||||
func (c *Constructor) newMsgAAAA(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
|
||||
msg = c.NewRespMsg(req)
|
||||
for i, ip := range ips {
|
||||
var ans dns.RR
|
||||
|
@ -2,7 +2,6 @@ package dnsmsg_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@ -75,9 +74,6 @@ func TestConstructor_NewBlockedRespMsg_nullIP(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConstructor_NewBlockedRespMsg_customIP(t *testing.T) {
|
||||
wantIPv4 := netip.MustParseAddr("1.2.3.4")
|
||||
wantIPv6 := netip.MustParseAddr("1234::cdef")
|
||||
|
||||
testCases := []struct {
|
||||
messages *dnsmsg.Constructor
|
||||
name string
|
||||
@ -85,22 +81,22 @@ func TestConstructor_NewBlockedRespMsg_customIP(t *testing.T) {
|
||||
wantAAAA bool
|
||||
}{{
|
||||
messages: dnsmsg.NewConstructor(&dnsmsg.BlockingModeCustomIP{
|
||||
IPv4: wantIPv4,
|
||||
IPv6: wantIPv6,
|
||||
IPv4: testIPv4,
|
||||
IPv6: testIPv6,
|
||||
}, testFltRespTTL),
|
||||
name: "both",
|
||||
wantA: true,
|
||||
wantAAAA: true,
|
||||
}, {
|
||||
messages: dnsmsg.NewConstructor(&dnsmsg.BlockingModeCustomIP{
|
||||
IPv4: wantIPv4,
|
||||
IPv4: testIPv4,
|
||||
}, testFltRespTTL),
|
||||
name: "ipv4_only",
|
||||
wantA: true,
|
||||
wantAAAA: false,
|
||||
}, {
|
||||
messages: dnsmsg.NewConstructor(&dnsmsg.BlockingModeCustomIP{
|
||||
IPv6: wantIPv6,
|
||||
IPv6: testIPv6,
|
||||
}, testFltRespTTL),
|
||||
name: "ipv6_only",
|
||||
wantA: false,
|
||||
@ -120,7 +116,7 @@ func TestConstructor_NewBlockedRespMsg_customIP(t *testing.T) {
|
||||
require.Len(t, respA.Answer, 1)
|
||||
|
||||
a := testutil.RequireTypeAssert[*dns.A](t, respA.Answer[0])
|
||||
assert.Equal(t, net.IP(wantIPv4.AsSlice()), a.A)
|
||||
assert.Equal(t, net.IP(testIPv4.AsSlice()), a.A)
|
||||
} else {
|
||||
assert.Empty(t, respA.Answer)
|
||||
}
|
||||
@ -136,7 +132,7 @@ func TestConstructor_NewBlockedRespMsg_customIP(t *testing.T) {
|
||||
require.Len(t, respAAAA.Answer, 1)
|
||||
|
||||
aaaa := testutil.RequireTypeAssert[*dns.AAAA](t, respAAAA.Answer[0])
|
||||
assert.Equal(t, net.IP(wantIPv6.AsSlice()), aaaa.AAAA)
|
||||
assert.Equal(t, net.IP(testIPv6.AsSlice()), aaaa.AAAA)
|
||||
} else {
|
||||
assert.Empty(t, respAAAA.Answer)
|
||||
}
|
||||
|
@ -31,6 +31,12 @@ const (
|
||||
testFQDN = testDomain + "."
|
||||
)
|
||||
|
||||
// Common IP addresses for tests.
|
||||
var (
|
||||
testIPv4 = netip.MustParseAddr("1.2.3.4")
|
||||
testIPv6 = netip.MustParseAddr("1234::cdef")
|
||||
)
|
||||
|
||||
// Common test constants.
|
||||
const (
|
||||
ipv4MaskBits = 24
|
||||
|
@ -14,14 +14,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
// testIPv4 is an IPv4 for tests.
|
||||
testIPv4 = netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
// testIPv6 is an IPv6 for tests.
|
||||
testIPv6 = netip.MustParseAddr("::1")
|
||||
)
|
||||
|
||||
func TestConstructor_NewAnswerHTTPS_andSVCB(t *testing.T) {
|
||||
// Preconditions.
|
||||
|
||||
|
14
internal/dnsserver/cache/cache_test.go
vendored
14
internal/dnsserver/cache/cache_test.go
vendored
@ -3,6 +3,7 @@ package cache_test
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -26,6 +27,7 @@ func TestMiddleware_Wrap(t *testing.T) {
|
||||
defaultTTL uint32 = 3600
|
||||
)
|
||||
|
||||
reqAddr := netip.MustParseAddr("1.2.3.4")
|
||||
testTTL := 60 * time.Second
|
||||
|
||||
aReq := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
|
||||
@ -44,7 +46,7 @@ func TestMiddleware_Wrap(t *testing.T) {
|
||||
}{{
|
||||
req: aReq,
|
||||
resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
|
||||
dnsservertest.NewA(reqHostname, defaultTTL, net.IP{1, 2, 3, 4}),
|
||||
dnsservertest.NewA(reqHostname, defaultTTL, reqAddr),
|
||||
}),
|
||||
name: "simple_a",
|
||||
wantNumReq: 1,
|
||||
@ -114,7 +116,7 @@ func TestMiddleware_Wrap(t *testing.T) {
|
||||
}, {
|
||||
req: aReq,
|
||||
resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
|
||||
dnsservertest.NewA(reqHostname, 0, net.IP{1, 2, 3, 4}),
|
||||
dnsservertest.NewA(reqHostname, 0, reqAddr),
|
||||
}),
|
||||
name: "expired_one",
|
||||
wantNumReq: N,
|
||||
@ -123,7 +125,7 @@ func TestMiddleware_Wrap(t *testing.T) {
|
||||
}, {
|
||||
req: aReq,
|
||||
resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
|
||||
dnsservertest.NewA(reqHostname, 10, net.IP{1, 2, 3, 4}),
|
||||
dnsservertest.NewA(reqHostname, 10, reqAddr),
|
||||
}),
|
||||
name: "override_ttl_ok",
|
||||
wantNumReq: 1,
|
||||
@ -132,7 +134,7 @@ func TestMiddleware_Wrap(t *testing.T) {
|
||||
}, {
|
||||
req: aReq,
|
||||
resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
|
||||
dnsservertest.NewA(reqHostname, 1000, net.IP{1, 2, 3, 4}),
|
||||
dnsservertest.NewA(reqHostname, 1000, reqAddr),
|
||||
}),
|
||||
name: "override_ttl_max",
|
||||
wantNumReq: 1,
|
||||
@ -141,7 +143,7 @@ func TestMiddleware_Wrap(t *testing.T) {
|
||||
}, {
|
||||
req: aReq,
|
||||
resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
|
||||
dnsservertest.NewA(reqHostname, 0, net.IP{1, 2, 3, 4}),
|
||||
dnsservertest.NewA(reqHostname, 0, reqAddr),
|
||||
}),
|
||||
name: "override_ttl_zero",
|
||||
wantNumReq: N,
|
||||
@ -150,7 +152,7 @@ func TestMiddleware_Wrap(t *testing.T) {
|
||||
}, {
|
||||
req: aReq,
|
||||
resp: dnsservertest.NewResp(dns.RcodeServerFailure, aReq, dnsservertest.SectionAnswer{
|
||||
dnsservertest.NewA(reqHostname, servFailMaxCacheTTL, net.IP{1, 2, 3, 4}),
|
||||
dnsservertest.NewA(reqHostname, servFailMaxCacheTTL, reqAddr),
|
||||
}),
|
||||
name: "override_ttl_servfail",
|
||||
wantNumReq: 1,
|
||||
|
@ -2,6 +2,7 @@ package dnsservertest
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
@ -127,6 +128,38 @@ func NewResp(rcode int, req *dns.Msg, rrs ...RRSection) (resp *dns.Msg) {
|
||||
return resp
|
||||
}
|
||||
|
||||
// NewA constructs the new resource record of type A. a must be a valid 4-byte
|
||||
// IPv4-address.
|
||||
func NewA(name string, ttl uint32, a netip.Addr) (rr dns.RR) {
|
||||
data := a.As4()
|
||||
|
||||
return &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(name),
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: ttl,
|
||||
},
|
||||
A: data[:],
|
||||
}
|
||||
}
|
||||
|
||||
// NewAAAA constructs the new resource record of type AAAA. aaaa must be a
|
||||
// valid 16-byte IPv6-address.
|
||||
func NewAAAA(name string, ttl uint32, aaaa netip.Addr) (rr dns.RR) {
|
||||
data := aaaa.As16()
|
||||
|
||||
return &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(name),
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: ttl,
|
||||
},
|
||||
AAAA: data[:],
|
||||
}
|
||||
}
|
||||
|
||||
// NewCNAME constructs the new resource record of type CNAME.
|
||||
func NewCNAME(name string, ttl uint32, target string) (rr dns.RR) {
|
||||
return &dns.CNAME{
|
||||
@ -140,39 +173,20 @@ func NewCNAME(name string, ttl uint32, target string) (rr dns.RR) {
|
||||
}
|
||||
}
|
||||
|
||||
// NewA constructs the new resource record of type A. a must be a valid 4-byte
|
||||
// IPv4-address.
|
||||
func NewA(name string, ttl uint32, a net.IP) (rr dns.RR) {
|
||||
return &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(name),
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: ttl,
|
||||
},
|
||||
A: a,
|
||||
}
|
||||
}
|
||||
|
||||
// NewAAAA constructs the new resource record of type AAAA. aaaa must be a
|
||||
// valid 16-byte IPv6-address.
|
||||
func NewAAAA(name string, ttl uint32, aaaa net.IP) (rr dns.RR) {
|
||||
return &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(name),
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: ttl,
|
||||
},
|
||||
AAAA: aaaa,
|
||||
}
|
||||
}
|
||||
|
||||
// NewHTTPS constructs the new resource record of type HTTPS with IPv4 and IPv6
|
||||
// hint records from provided v4Hint and v6Hint parameters.
|
||||
//
|
||||
// TODO(d.kolyshev): Add "alpn" and other SVCB key-value pairs.
|
||||
func NewHTTPS(name string, ttl uint32, v4Hint, v6Hint []net.IP) (rr dns.RR) {
|
||||
func NewHTTPS(name string, ttl uint32, v4Hints, v6Hints []netip.Addr) (rr dns.RR) {
|
||||
v4Hint := &dns.SVCBIPv4Hint{}
|
||||
for _, ip := range v4Hints {
|
||||
v4Hint.Hint = append(v4Hint.Hint, ip.AsSlice())
|
||||
}
|
||||
v6Hint := &dns.SVCBIPv6Hint{}
|
||||
for _, ip := range v6Hints {
|
||||
v6Hint.Hint = append(v6Hint.Hint, ip.AsSlice())
|
||||
}
|
||||
|
||||
svcb := dns.SVCB{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(name),
|
||||
@ -181,10 +195,7 @@ func NewHTTPS(name string, ttl uint32, v4Hint, v6Hint []net.IP) (rr dns.RR) {
|
||||
Ttl: ttl,
|
||||
},
|
||||
Target: dns.Fqdn(name),
|
||||
Value: []dns.SVCBKeyValue{
|
||||
&dns.SVCBIPv4Hint{Hint: v4Hint},
|
||||
&dns.SVCBIPv6Hint{Hint: v6Hint},
|
||||
},
|
||||
Value: []dns.SVCBKeyValue{v4Hint, v6Hint},
|
||||
}
|
||||
|
||||
return &dns.HTTPS{
|
||||
@ -192,6 +203,49 @@ func NewHTTPS(name string, ttl uint32, v4Hint, v6Hint []net.IP) (rr dns.RR) {
|
||||
}
|
||||
}
|
||||
|
||||
// NewPTR constructs the new resource record of type PTR.
|
||||
func NewPTR(name string, ttl uint32, target string) (rr dns.RR) {
|
||||
return &dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(name),
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: ttl,
|
||||
},
|
||||
Ptr: dns.Fqdn(target),
|
||||
}
|
||||
}
|
||||
|
||||
// NewSRV constructs the new resource record of type SRV.
|
||||
func NewSRV(name string, ttl uint32, target string, prio, weight, port uint16) (rr dns.RR) {
|
||||
return &dns.SRV{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(name),
|
||||
Rrtype: dns.TypeSRV,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: ttl,
|
||||
},
|
||||
Priority: prio,
|
||||
Weight: weight,
|
||||
Port: port,
|
||||
Target: target,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTXT constructs the new resource record of type TXT. txts are put into the
|
||||
// TXT record as is.
|
||||
func NewTXT(name string, ttl uint32, txts ...string) (rr dns.RR) {
|
||||
return &dns.TXT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(name),
|
||||
Rrtype: dns.TypeTXT,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: ttl,
|
||||
},
|
||||
Txt: txts,
|
||||
}
|
||||
}
|
||||
|
||||
// NewSOA constructs the new resource record of type SOA.
|
||||
func NewSOA(name string, ttl uint32, ns, mbox string) (rr dns.RR) {
|
||||
return &dns.SOA{
|
||||
|
@ -2,7 +2,7 @@ package dnsservertest_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
@ -45,7 +45,7 @@ func ExampleNewResp() {
|
||||
|
||||
m = dnsservertest.NewResp(dns.RcodeSuccess, m, dnsservertest.SectionAnswer{
|
||||
dnsservertest.NewCNAME(testFQDN, 3600, realTestFQDN),
|
||||
dnsservertest.NewA(realTestFQDN, 3600, net.IP{1, 2, 3, 4}),
|
||||
dnsservertest.NewA(realTestFQDN, 3600, netip.MustParseAddr("1.2.3.4")),
|
||||
}, dnsservertest.SectionNs{
|
||||
dnsservertest.NewSOA(realTestFQDN, 1000, "ns."+realTestFQDN, "mbox."+realTestFQDN),
|
||||
dnsservertest.NewNS(testFQDN, 1000, "ns."+testFQDN),
|
||||
|
@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardDNS/internal/dnsserver
|
||||
go 1.20
|
||||
|
||||
require (
|
||||
github.com/AdguardTeam/golibs v0.13.6
|
||||
github.com/AdguardTeam/golibs v0.14.0
|
||||
github.com/ameshkov/dnscrypt/v2 v2.2.7
|
||||
github.com/ameshkov/dnsstamps v1.0.3
|
||||
github.com/bluele/gcache v0.0.2
|
||||
@ -11,11 +11,11 @@ require (
|
||||
github.com/panjf2000/ants/v2 v2.7.5
|
||||
github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible
|
||||
github.com/prometheus/client_golang v1.15.1
|
||||
github.com/quic-go/quic-go v0.35.1
|
||||
github.com/quic-go/quic-go v0.38.0
|
||||
github.com/stretchr/testify v1.8.4
|
||||
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1
|
||||
golang.org/x/net v0.12.0
|
||||
golang.org/x/sys v0.10.0
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
|
||||
golang.org/x/net v0.14.0
|
||||
golang.org/x/sys v0.11.0
|
||||
)
|
||||
|
||||
require (
|
||||
@ -27,21 +27,20 @@ require (
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||
github.com/golang/mock v1.6.0 // indirect
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect
|
||||
github.com/google/pprof v0.0.0-20230821062121-407c9e7a662f // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.10.0 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.11.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/client_model v0.4.0 // indirect
|
||||
github.com/prometheus/common v0.44.0 // indirect
|
||||
github.com/prometheus/procfs v0.10.1 // indirect
|
||||
github.com/quic-go/qpack v0.4.0 // indirect
|
||||
github.com/quic-go/qtls-go1-19 v0.3.2 // indirect
|
||||
github.com/quic-go/qtls-go1-20 v0.2.2 // indirect
|
||||
golang.org/x/crypto v0.11.0 // indirect
|
||||
golang.org/x/mod v0.11.0 // indirect
|
||||
golang.org/x/text v0.11.0 // indirect
|
||||
golang.org/x/tools v0.10.0 // indirect
|
||||
github.com/quic-go/qtls-go1-20 v0.3.3 // indirect
|
||||
golang.org/x/crypto v0.12.0 // indirect
|
||||
golang.org/x/mod v0.12.0 // indirect
|
||||
golang.org/x/text v0.12.0 // indirect
|
||||
golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
@ -1,5 +1,5 @@
|
||||
github.com/AdguardTeam/golibs v0.13.6 h1:z/0Q25pRLdaQxtoxvfSaooz5mdv8wj0R8KREj54q8yQ=
|
||||
github.com/AdguardTeam/golibs v0.13.6/go.mod h1:hOtcb8dPfKcFjWTPA904hTA4dl1aWvzeebdJpE72IPk=
|
||||
github.com/AdguardTeam/golibs v0.14.0 h1:/vfJshXBVaevMuBgzAIr+F64XdNqZL+j9F33GXJmgeQ=
|
||||
github.com/AdguardTeam/golibs v0.14.0/go.mod h1:hOtcb8dPfKcFjWTPA904hTA4dl1aWvzeebdJpE72IPk=
|
||||
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
|
||||
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
|
||||
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw=
|
||||
@ -29,8 +29,7 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg
|
||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs=
|
||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
|
||||
github.com/google/pprof v0.0.0-20230821062121-407c9e7a662f h1:pDhu5sgp8yJlEF/g6osliIIpF9K4F5jvkULXa4daRDQ=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
@ -38,9 +37,8 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zk
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
|
||||
github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo=
|
||||
github.com/miekg/dns v1.1.55/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY=
|
||||
github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs=
|
||||
github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE=
|
||||
github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU=
|
||||
github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU=
|
||||
github.com/onsi/gomega v1.27.8 h1:gegWiwZjBsf2DgiSbf5hpokZ98JVDMcWkUiigk6/KXc=
|
||||
github.com/panjf2000/ants/v2 v2.7.5 h1:/vhh0Hza9G1vP1PdCj9hl6MUzCRbmtcTJL0OsnmytuU=
|
||||
github.com/panjf2000/ants/v2 v2.7.5/go.mod h1:KIBmYG9QQX5U2qzFP/yQJaq/nSb6rahS9iEHkrCMgM8=
|
||||
github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible h1:IWzUvJ72xMjmrjR9q3H1PF+jwdN0uNQiR2t1BLNalyo=
|
||||
@ -57,12 +55,8 @@ github.com/prometheus/procfs v0.10.1 h1:kYK1Va/YMlutzCGazswoHKo//tZVlFpKYh+Pymzi
|
||||
github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM=
|
||||
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
||||
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
|
||||
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U=
|
||||
github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
|
||||
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
|
||||
github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
|
||||
github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo=
|
||||
github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g=
|
||||
github.com/quic-go/qtls-go1-20 v0.3.3 h1:17/glZSLI9P9fDAeyCHBFSWSqJcwx1byhLwP5eUIDCM=
|
||||
github.com/quic-go/quic-go v0.38.0 h1:T45lASr5q/TrVwt+jrVccmqHhPL2XuSyoCLVCpfOSLc=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
@ -76,18 +70,14 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA=
|
||||
golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
|
||||
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw=
|
||||
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||
golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=
|
||||
golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50=
|
||||
golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
|
||||
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@ -98,18 +88,15 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=
|
||||
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg=
|
||||
golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM=
|
||||
golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 h1:Vve/L0v7CXXuxUmaMGIEK/dEeq7uiqb5qBgQrZzIE7E=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
// responsePaddingMaxSize is used to calculate the EDNS padding length. We use
|
||||
// the Random-Length Padding strategy from RFC 8467 as we find it more
|
||||
// efficient, it requires less extra traffic while provides comparable entropy.
|
||||
const responsePaddingMaxSize = 128
|
||||
const responsePaddingMaxSize = 32
|
||||
|
||||
// respPadBuf is a fixed buffer to draw on for padding.
|
||||
var respPadBuf [responsePaddingMaxSize]byte
|
||||
@ -64,7 +64,7 @@ func normalize(network Network, proto Protocol, req, resp *dns.Msg) {
|
||||
|
||||
// In the case of encrypted protocols we should pad responses.
|
||||
if proto.HasPaddingSupport() {
|
||||
padAnswer(resp, reqOpt, respOpt)
|
||||
padAnswer(reqOpt, respOpt)
|
||||
}
|
||||
}
|
||||
|
||||
@ -123,7 +123,7 @@ func filterUnsupportedOptions(o []dns.EDNS0) (supported []dns.EDNS0) {
|
||||
// padAnswer adds padding to a DNS response before it's sent back over an
|
||||
// encrypted DNS protocol according to RFC 8467. Unencrypted responses should
|
||||
// not be padded. Inspired by github.com/folbricht/routedns padding.
|
||||
func padAnswer(resp *dns.Msg, reqOpt, respOpt *dns.OPT) {
|
||||
func padAnswer(reqOpt, respOpt *dns.OPT) {
|
||||
if findOption[*dns.EDNS0_PADDING](reqOpt) == nil {
|
||||
// According to the RFC, responders MAY (or may not) pad responses when
|
||||
// the padding option is not included in the request. In our case, we
|
||||
@ -146,22 +146,14 @@ func padAnswer(resp *dns.Msg, reqOpt, respOpt *dns.OPT) {
|
||||
// TODO(ameshkov): Consider changing to crypto/rand, need to hold a vote.
|
||||
// #nosec G404 -- We don't need a real random for a simple padding
|
||||
// randomization, pseudo-random is enough.
|
||||
padLen := rand.Intn(responsePaddingMaxSize-1) + 1
|
||||
|
||||
// If padding would make the packet larger than the request EDNS0 allows,
|
||||
// we need to truncate it.
|
||||
//
|
||||
// TODO(ameshkov): Consider removing this check and not calling resp.Len().
|
||||
// resp.Len() is a rather heavy function which we'd better avoid calling.
|
||||
// However, we risk having a message larger than 64kB in this case.
|
||||
answerLen := resp.Len()
|
||||
if packetSize := int(reqOpt.UDPSize()); answerLen+padLen > packetSize {
|
||||
padLen = packetSize - answerLen
|
||||
if padLen < 0 {
|
||||
// Still doesn't fit? Give up on padding.
|
||||
padLen = 0
|
||||
}
|
||||
}
|
||||
// Note, that we don't check for whether reqOpt.UDPSize() here is smaller
|
||||
// than resp.Len() + padLen so in theory the padded response may be larger
|
||||
// than 64kB. This is an acceptable risk considering the savings on
|
||||
// avoiding calling resp.Len().
|
||||
//
|
||||
// TODO(ameshkov): Return this check if we optimize resp.Len().
|
||||
padLen := rand.Intn(responsePaddingMaxSize-1) + 1
|
||||
|
||||
paddingOpt.Padding = respPadBuf[:padLen:padLen]
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package prometheus_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -28,20 +27,17 @@ func TestCacheMetricsListener_integration_cache(t *testing.T) {
|
||||
cacheMiddleware,
|
||||
)
|
||||
|
||||
// Pass 10 requests through the middleware
|
||||
// This way we'll increment and set both hits and misses.
|
||||
// Pass 10 requests through the middleware. This way we'll increment and
|
||||
// set both hits and misses.
|
||||
for i := 0; i < 10; i++ {
|
||||
req := dnsservertest.CreateMessage("example.org.", dns.TypeA)
|
||||
addr := &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 53}
|
||||
nrw := dnsserver.NewNonWriterResponseWriter(addr, addr)
|
||||
ctx := dnsserver.ContextWithServerInfo(context.Background(), dnsserver.ServerInfo{
|
||||
Name: "test_server",
|
||||
Addr: "127.0.0.1:0",
|
||||
Proto: dnsserver.ProtoDNS,
|
||||
})
|
||||
ctx := dnsserver.ContextWithServerInfo(context.Background(), testServerInfo)
|
||||
ctx = dnsserver.ContextWithStartTime(ctx, time.Now())
|
||||
ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{})
|
||||
|
||||
nrw := dnsserver.NewNonWriterResponseWriter(testUDPAddr, testUDPAddr)
|
||||
|
||||
req := dnsservertest.CreateMessage(testReqDomain, dns.TypeA)
|
||||
|
||||
err := handlerWithMiddleware.ServeDNS(ctx, nrw, req)
|
||||
require.NoError(t, err)
|
||||
dnsservertest.RequireResponse(t, req, nrw.Msg(), 1, dns.RcodeSuccess, false)
|
||||
|
@ -29,7 +29,7 @@ func TestForwardMetricsListener_integration_request(t *testing.T) {
|
||||
// Prepare a test DNS message and call the handler's ServeDNS function.
|
||||
// It will then call the metrics listener and prom metrics should be
|
||||
// incremented.
|
||||
req := dnsservertest.CreateMessage("example.org.", dns.TypeA)
|
||||
req := dnsservertest.CreateMessage(testReqDomain, dns.TypeA)
|
||||
rw := dnsserver.NewNonWriterResponseWriter(srv.LocalUDPAddr(), srv.LocalUDPAddr())
|
||||
|
||||
err := handler.ServeDNS(context.Background(), rw, req)
|
||||
|
@ -1,95 +1,95 @@
|
||||
package prometheus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// counterWithRequestLabels is a helper method that gets or creates a
|
||||
// [prometheus.Counter] from the specified *prometheus.CounterVec. The point of
|
||||
// this method is to avoid allocating [prometheus.Labels] and instead use the
|
||||
// WithLabelValues function. This way extra allocations are avoided, but it is
|
||||
// sensitive to the labels order.
|
||||
func counterWithRequestLabels(
|
||||
serverInfo dnsserver.ServerInfo,
|
||||
// reqLabelMetricKey contains the information for a request label.
|
||||
type reqLabelMetricKey struct {
|
||||
network string
|
||||
qType string
|
||||
family string
|
||||
srvInfo dnsserver.ServerInfo
|
||||
}
|
||||
|
||||
// newReqLabelMetricKey returns a new metric key from the given data.
|
||||
func newReqLabelMetricKey(
|
||||
ctx context.Context,
|
||||
req *dns.Msg,
|
||||
rw dnsserver.ResponseWriter,
|
||||
vec *prometheus.CounterVec,
|
||||
) (c prometheus.Counter) {
|
||||
ip, _ := netutil.IPAndPortFromAddr(rw.RemoteAddr())
|
||||
) (k reqLabelMetricKey) {
|
||||
return reqLabelMetricKey{
|
||||
network: string(dnsserver.NetworkFromAddr(rw.LocalAddr())),
|
||||
qType: typeToString(req),
|
||||
family: raddrToFamily(rw.RemoteAddr()),
|
||||
srvInfo: dnsserver.MustServerInfoFromContext(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
// Address family metric.
|
||||
var family string
|
||||
if ip == nil {
|
||||
// Unknown.
|
||||
family = "0"
|
||||
} else if ip.To4() != nil {
|
||||
// IPv4.
|
||||
family = "1"
|
||||
} else {
|
||||
// IPv6.
|
||||
family = "2"
|
||||
// withLabelValues returns a counter with the given arguments in the correct
|
||||
// order.
|
||||
func (k reqLabelMetricKey) withLabelValues(vec *prometheus.CounterVec) (c prometheus.Counter) {
|
||||
// The labels must be in the following order:
|
||||
// 1. server name;
|
||||
// 2. server protocol;
|
||||
// 3. server socket network ("tcp"/"udp");
|
||||
// 4. server addr;
|
||||
// 5. question type (see [typeToString]);
|
||||
// 6. IP family (see [raddrToFamily]).
|
||||
return vec.WithLabelValues(
|
||||
k.srvInfo.Name,
|
||||
k.srvInfo.Proto.String(),
|
||||
k.network,
|
||||
k.srvInfo.Addr,
|
||||
k.qType,
|
||||
k.family,
|
||||
)
|
||||
}
|
||||
|
||||
// prometheusVector is the interface for vectors of counters, histograms, etc.
|
||||
type prometheusVector[T any] interface {
|
||||
WithLabelValues(labelValues ...string) (m T)
|
||||
}
|
||||
|
||||
// withSrvInfoLabelValues returns a metric with the server info data in the
|
||||
// correct order.
|
||||
func withSrvInfoLabelValues[T any](
|
||||
vec prometheusVector[T],
|
||||
srvInfo dnsserver.ServerInfo,
|
||||
) (m T) {
|
||||
// The labels must be in the following order:
|
||||
// 1. server name;
|
||||
// 2. server protocol;
|
||||
// 3. server addr;
|
||||
return vec.WithLabelValues(
|
||||
srvInfo.Name,
|
||||
srvInfo.Proto.String(),
|
||||
srvInfo.Addr,
|
||||
)
|
||||
}
|
||||
|
||||
// raddrToFamily returns a family metric value for raddr.
|
||||
// The values are:
|
||||
//
|
||||
// 0. Unknown.
|
||||
// 1. IPv4.
|
||||
// 2. IPv6.
|
||||
func raddrToFamily(raddr net.Addr) (family string) {
|
||||
ip := netutil.NetAddrToAddrPort(raddr).Addr()
|
||||
|
||||
if !ip.IsValid() {
|
||||
return "0"
|
||||
} else if ip.Is4() {
|
||||
return "1"
|
||||
}
|
||||
|
||||
// The metric's labels MUST be in the following order:
|
||||
// "name", "proto", "network", "addr", "type", "family"
|
||||
return vec.WithLabelValues(
|
||||
serverInfo.Name,
|
||||
serverInfo.Proto.String(),
|
||||
string(dnsserver.NetworkFromAddr(rw.LocalAddr())),
|
||||
serverInfo.Addr,
|
||||
typeToString(req),
|
||||
family,
|
||||
)
|
||||
}
|
||||
|
||||
// counterWithRequestLabels is a helper method that gets or creates a
|
||||
// [prometheus.Counter] from the specified *prometheus.CounterVec. The point of
|
||||
// this method is to avoid allocating [prometheus.Labels] and instead use the
|
||||
// WithLabelValues function. This way extra allocations are avoided, but it is
|
||||
// sensitive to the labels order.
|
||||
func counterWithServerLabels(
|
||||
serverInfo dnsserver.ServerInfo,
|
||||
vec *prometheus.CounterVec,
|
||||
) (c prometheus.Counter) {
|
||||
// The metric's labels MUST be in the following order:
|
||||
// "name", "proto", "addr"
|
||||
return vec.WithLabelValues(
|
||||
serverInfo.Name,
|
||||
serverInfo.Proto.String(),
|
||||
serverInfo.Addr,
|
||||
)
|
||||
}
|
||||
|
||||
// histogramWithServerLabels is a helper method that gets or creates a
|
||||
// [prometheus.Observer] from the specified *prometheus.HistogramVec. The point
|
||||
// of this method is to avoid allocating [prometheus.Labels] and instead use the
|
||||
// WithLabelValues function. This way extra allocations are avoided, but it is
|
||||
// sensitive to the labels order.
|
||||
func histogramWithServerLabels(
|
||||
serverInfo dnsserver.ServerInfo,
|
||||
vec *prometheus.HistogramVec,
|
||||
) (h prometheus.Observer) {
|
||||
// The metric's labels MUST be in the following order:
|
||||
// "name", "proto", "addr"
|
||||
return vec.WithLabelValues(serverInfo.Name, serverInfo.Proto.String(), serverInfo.Addr)
|
||||
}
|
||||
|
||||
// counterWithServerLabelsPlusRCode is a helper method that gets or creates a
|
||||
// [prometheus.Counter] from the specified *prometheus.CounterVec. The point of
|
||||
// this method is to avoid allocating [prometheus.Labels] and instead use the
|
||||
// WithLabelValues function. This way extra allocations are avoided, but it is
|
||||
// sensitive to the labels order.
|
||||
func counterWithServerLabelsPlusRCode(
|
||||
serverInfo dnsserver.ServerInfo,
|
||||
rCode string,
|
||||
vec *prometheus.CounterVec,
|
||||
) (c prometheus.Counter) {
|
||||
// The metric's labels MUST be in the following order:
|
||||
// "name", "proto", "addr", "rcode"
|
||||
return vec.WithLabelValues(serverInfo.Name, serverInfo.Proto.String(), serverInfo.Addr, rCode)
|
||||
return "2"
|
||||
}
|
||||
|
||||
// setBoolGauge sets gauge to the numeric value corresponding to the val.
|
||||
|
51
internal/dnsserver/prometheus/initsyncmap.go
Normal file
51
internal/dnsserver/prometheus/initsyncmap.go
Normal file
@ -0,0 +1,51 @@
|
||||
package prometheus
|
||||
|
||||
import "sync"
|
||||
|
||||
// initSyncMap is a wrapper around [*sync.Map] that initializes the data if it's
|
||||
// not present in an atomic way.
|
||||
//
|
||||
// TODO(a.garipov): Move to golibs and use more.
|
||||
type initSyncMap[K, V any] struct {
|
||||
inner *sync.Map
|
||||
new func(k K) (v V)
|
||||
}
|
||||
|
||||
// newInitSyncMap returns a new properly initialized *initSyncMap that uses
|
||||
// newFunc to return a value for the given key.
|
||||
func newInitSyncMap[K, V any](newFunc func(k K) (v V)) (m *initSyncMap[K, V]) {
|
||||
return &initSyncMap[K, V]{
|
||||
inner: &sync.Map{},
|
||||
new: newFunc,
|
||||
}
|
||||
}
|
||||
|
||||
// get returns a value for the given key. If a value isn't available, it waits
|
||||
// until it is.
|
||||
func (m *initSyncMap[K, V]) get(key K) (v V) {
|
||||
// Step 1. The fast track: check if there is already a value present.
|
||||
loadVal, inited := m.inner.Load(key)
|
||||
if inited {
|
||||
return loadVal.(func() (v V))()
|
||||
}
|
||||
|
||||
// Step 2. Allocate a done channel and create a function that waits for one
|
||||
// single initialization. Use the one returned from LoadOrStore regardless
|
||||
// of whether it's this one.
|
||||
var cached V
|
||||
done := make(chan struct{}, 1)
|
||||
done <- struct{}{}
|
||||
loadVal, _ = m.inner.LoadOrStore(key, func() (loaded V) {
|
||||
_, ok := <-done
|
||||
if ok {
|
||||
// The only real receive. Initialize the cached value and close the
|
||||
// channel so that other goroutines receive the same value.
|
||||
cached = m.new(key)
|
||||
close(done)
|
||||
}
|
||||
|
||||
return cached
|
||||
})
|
||||
|
||||
return loadVal.(func() (v V))()
|
||||
}
|
41
internal/dnsserver/prometheus/initsyncmap_internal_test.go
Normal file
41
internal/dnsserver/prometheus/initsyncmap_internal_test.go
Normal file
@ -0,0 +1,41 @@
|
||||
package prometheus
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestInitSyncMap(t *testing.T) {
|
||||
numCalls := atomic.Uint32{}
|
||||
m := newInitSyncMap[int, int](func(k int) (v int) {
|
||||
numCalls.Add(1)
|
||||
|
||||
return k + 1
|
||||
})
|
||||
|
||||
const (
|
||||
n = 1_000
|
||||
|
||||
key = 1
|
||||
want = key + 1
|
||||
)
|
||||
|
||||
results := make(chan int, n)
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
results <- m.get(key)
|
||||
}()
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
got, _ := testutil.RequireReceive(t, results, 1*time.Second)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
assert.Equal(t, uint32(1), numCalls.Load())
|
||||
}
|
@ -1,8 +1,10 @@
|
||||
package prometheus_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -12,6 +14,22 @@ func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// testReqDomain is the common request domain for tests.
|
||||
const testReqDomain = "request-domain.example"
|
||||
|
||||
// testServerInfo is the common server information structure for tests.
|
||||
var testServerInfo = dnsserver.ServerInfo{
|
||||
Name: "test_server",
|
||||
Addr: "127.0.0.1:80",
|
||||
Proto: dnsserver.ProtoDNS,
|
||||
}
|
||||
|
||||
// testUDPAddr is the common UDP address for tests.
|
||||
var testUDPAddr = &net.UDPAddr{
|
||||
IP: net.IP{1, 2, 3, 4},
|
||||
Port: 53,
|
||||
}
|
||||
|
||||
// requireMetrics accepts a list of metrics names and checks that
|
||||
// they exist in the prom registry.
|
||||
func requireMetrics(t testing.TB, args ...string) {
|
||||
@ -34,6 +52,5 @@ func requireMetrics(t testing.TB, args ...string) {
|
||||
delete(metricsToCheck, m.GetName())
|
||||
}
|
||||
|
||||
require.Len(t, metricsToCheck, 0,
|
||||
"Some metrics weren't reported: %v", metricsToCheck)
|
||||
require.Len(t, metricsToCheck, 0, "Some metrics weren't reported: %v", metricsToCheck)
|
||||
}
|
||||
|
@ -10,33 +10,47 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
)
|
||||
|
||||
// RateLimitMetricsListener implements the ratelimit.MetricsListener interface
|
||||
// RateLimitMetricsListener implements the [ratelimit.MetricsListener] interface
|
||||
// and increments prom counters.
|
||||
type RateLimitMetricsListener struct{}
|
||||
type RateLimitMetricsListener struct {
|
||||
dropCounters *initSyncMap[reqLabelMetricKey, prometheus.Counter]
|
||||
allowlistedCounters *initSyncMap[reqLabelMetricKey, prometheus.Counter]
|
||||
}
|
||||
|
||||
// NewRateLimitMetricsListener returns a new properly initialized
|
||||
// *RateLimitMetricsListener.
|
||||
func NewRateLimitMetricsListener() (l *RateLimitMetricsListener) {
|
||||
return &RateLimitMetricsListener{
|
||||
dropCounters: newInitSyncMap(func(k reqLabelMetricKey) (c prometheus.Counter) {
|
||||
return k.withLabelValues(droppedTotal)
|
||||
}),
|
||||
allowlistedCounters: newInitSyncMap(func(k reqLabelMetricKey) (c prometheus.Counter) {
|
||||
return k.withLabelValues(allowlistedTotal)
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ ratelimit.MetricsListener = (*RateLimitMetricsListener)(nil)
|
||||
|
||||
// OnRateLimited implements the ratelimit.MetricsListener interface for
|
||||
// *RateLimitMetricsListener.
|
||||
func (r *RateLimitMetricsListener) OnRateLimited(
|
||||
func (l *RateLimitMetricsListener) OnRateLimited(
|
||||
ctx context.Context,
|
||||
req *dns.Msg,
|
||||
rw dnsserver.ResponseWriter,
|
||||
) {
|
||||
s := dnsserver.MustServerInfoFromContext(ctx)
|
||||
counterWithRequestLabels(s, req, rw, droppedTotal).Inc()
|
||||
l.dropCounters.get(newReqLabelMetricKey(ctx, req, rw)).Inc()
|
||||
}
|
||||
|
||||
// OnAllowlisted implements the ratelimit.MetricsListener interface for
|
||||
// *RateLimitMetricsListener.
|
||||
func (r *RateLimitMetricsListener) OnAllowlisted(
|
||||
func (l *RateLimitMetricsListener) OnAllowlisted(
|
||||
ctx context.Context,
|
||||
req *dns.Msg,
|
||||
rw dnsserver.ResponseWriter,
|
||||
) {
|
||||
s := dnsserver.MustServerInfoFromContext(ctx)
|
||||
counterWithRequestLabels(s, req, rw, allowlistedTotal).Inc()
|
||||
l.allowlistedCounters.get(newReqLabelMetricKey(ctx, req, rw)).Inc()
|
||||
}
|
||||
|
||||
// This block contains prometheus metrics declarations for ratelimit.Middleware
|
||||
|
@ -2,7 +2,6 @@ package prometheus_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
@ -33,7 +32,7 @@ func TestRateLimiterMetricsListener_integration_cache(t *testing.T) {
|
||||
})
|
||||
rlMw, err := ratelimit.NewMiddleware(rl, nil)
|
||||
require.NoError(t, err)
|
||||
rlMw.Metrics = &prometheus.RateLimitMetricsListener{}
|
||||
rlMw.Metrics = prometheus.NewRateLimitMetricsListener()
|
||||
|
||||
handlerWithMiddleware := dnsserver.WithMiddlewares(
|
||||
dnsservertest.DefaultHandler(),
|
||||
@ -42,17 +41,14 @@ func TestRateLimiterMetricsListener_integration_cache(t *testing.T) {
|
||||
|
||||
// Pass 10 requests through the middleware.
|
||||
for i := 0; i < 10; i++ {
|
||||
req := dnsservertest.CreateMessage("example.org.", dns.TypeA)
|
||||
addr := &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 53}
|
||||
nrw := dnsserver.NewNonWriterResponseWriter(addr, addr)
|
||||
ctx := dnsserver.ContextWithServerInfo(context.Background(), dnsserver.ServerInfo{
|
||||
Name: "test",
|
||||
Addr: "127.0.0.1",
|
||||
Proto: dnsserver.ProtoDNS,
|
||||
})
|
||||
ctx := dnsserver.ContextWithServerInfo(context.Background(), testServerInfo)
|
||||
ctx = dnsserver.ContextWithStartTime(ctx, time.Now())
|
||||
ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{})
|
||||
|
||||
nrw := dnsserver.NewNonWriterResponseWriter(testUDPAddr, testUDPAddr)
|
||||
|
||||
req := dnsservertest.CreateMessage(testReqDomain, dns.TypeA)
|
||||
|
||||
err = handlerWithMiddleware.ServeDNS(ctx, nrw, req)
|
||||
require.NoError(t, err)
|
||||
if i < rps {
|
||||
@ -65,3 +61,35 @@ func TestRateLimiterMetricsListener_integration_cache(t *testing.T) {
|
||||
// Now make sure that prometheus metrics were incremented properly.
|
||||
requireMetrics(t, "dns_ratelimit_dropped_total")
|
||||
}
|
||||
|
||||
func BenchmarkRateLimitMetricsListener(b *testing.B) {
|
||||
l := prometheus.NewRateLimitMetricsListener()
|
||||
|
||||
ctx := dnsserver.ContextWithServerInfo(context.Background(), testServerInfo)
|
||||
req := dnsservertest.CreateMessage(testReqDomain, dns.TypeA)
|
||||
rw := dnsserver.NewNonWriterResponseWriter(testUDPAddr, testUDPAddr)
|
||||
|
||||
b.Run("OnAllowlisted", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
l.OnAllowlisted(ctx, req, rw)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("OnRateLimited", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
l.OnRateLimited(ctx, req, rw)
|
||||
}
|
||||
})
|
||||
|
||||
// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
|
||||
// goos: linux
|
||||
// goarch: amd64
|
||||
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus
|
||||
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
|
||||
// BenchmarkRateLimitMetricsListener/OnAllowlisted-16 6025423 209.5 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkRateLimitMetricsListener/OnRateLimited-16 5798031 209.4 ns/op 0 B/op 0 allocs/op
|
||||
}
|
||||
|
@ -12,7 +12,80 @@ import (
|
||||
|
||||
// ServerMetricsListener implements the [dnsserver.MetricsListener] interface
|
||||
// and increments prom counters.
|
||||
type ServerMetricsListener struct{}
|
||||
type ServerMetricsListener struct {
|
||||
reqTotalCounters *initSyncMap[reqLabelMetricKey, prometheus.Counter]
|
||||
|
||||
respRCodeCounters *initSyncMap[srvInfoRCode, prometheus.Counter]
|
||||
|
||||
invalidMsgCounters *initSyncMap[dnsserver.ServerInfo, prometheus.Counter]
|
||||
errorCounters *initSyncMap[dnsserver.ServerInfo, prometheus.Counter]
|
||||
panicCounters *initSyncMap[dnsserver.ServerInfo, prometheus.Counter]
|
||||
|
||||
reqDurationHistograms *initSyncMap[dnsserver.ServerInfo, prometheus.Observer]
|
||||
reqSizeHistograms *initSyncMap[dnsserver.ServerInfo, prometheus.Observer]
|
||||
respSizeHistograms *initSyncMap[dnsserver.ServerInfo, prometheus.Observer]
|
||||
}
|
||||
|
||||
// srvInfoRCode is a struct containing the server information along with a
|
||||
// response code.
|
||||
type srvInfoRCode struct {
|
||||
rCode string
|
||||
dnsserver.ServerInfo
|
||||
}
|
||||
|
||||
// withLabelValues returns a counter with the server info and rcode data in the
|
||||
// correct order.
|
||||
func (i srvInfoRCode) withLabelValues(
|
||||
vec *prometheus.CounterVec,
|
||||
) (c prometheus.Counter) {
|
||||
// The labels must be in the following order:
|
||||
// 1. server name;
|
||||
// 2. server protocol;
|
||||
// 3. server addr;
|
||||
// 4. response code;
|
||||
return vec.WithLabelValues(
|
||||
i.Name,
|
||||
i.Proto.String(),
|
||||
i.Addr,
|
||||
i.rCode,
|
||||
)
|
||||
}
|
||||
|
||||
// NewServerMetricsListener returns a new properly initialized
|
||||
// *ServerMetricsListener.
|
||||
func NewServerMetricsListener() (l *ServerMetricsListener) {
|
||||
return &ServerMetricsListener{
|
||||
reqTotalCounters: newInitSyncMap(func(k reqLabelMetricKey) (c prometheus.Counter) {
|
||||
return k.withLabelValues(requestTotal)
|
||||
}),
|
||||
|
||||
respRCodeCounters: newInitSyncMap(func(k srvInfoRCode) (c prometheus.Counter) {
|
||||
return k.withLabelValues(responseRCode)
|
||||
}),
|
||||
|
||||
invalidMsgCounters: newInitSyncMap(func(k dnsserver.ServerInfo) (c prometheus.Counter) {
|
||||
// TODO(a.garipov): Here and below, remove explicit type
|
||||
// declarations in Go 1.21.
|
||||
return withSrvInfoLabelValues[prometheus.Counter](invalidMsgTotal, k)
|
||||
}),
|
||||
errorCounters: newInitSyncMap(func(k dnsserver.ServerInfo) (c prometheus.Counter) {
|
||||
return withSrvInfoLabelValues[prometheus.Counter](errorTotal, k)
|
||||
}),
|
||||
panicCounters: newInitSyncMap(func(k dnsserver.ServerInfo) (c prometheus.Counter) {
|
||||
return withSrvInfoLabelValues[prometheus.Counter](panicTotal, k)
|
||||
}),
|
||||
|
||||
reqDurationHistograms: newInitSyncMap(func(k dnsserver.ServerInfo) (o prometheus.Observer) {
|
||||
return withSrvInfoLabelValues[prometheus.Observer](requestDuration, k)
|
||||
}),
|
||||
reqSizeHistograms: newInitSyncMap(func(k dnsserver.ServerInfo) (o prometheus.Observer) {
|
||||
return withSrvInfoLabelValues[prometheus.Observer](requestSize, k)
|
||||
}),
|
||||
respSizeHistograms: newInitSyncMap(func(k dnsserver.ServerInfo) (o prometheus.Observer) {
|
||||
return withSrvInfoLabelValues[prometheus.Observer](responseSize, k)
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ dnsserver.MetricsListener = (*ServerMetricsListener)(nil)
|
||||
@ -21,51 +94,57 @@ var _ dnsserver.MetricsListener = (*ServerMetricsListener)(nil)
|
||||
// [*ServerMetricsListener].
|
||||
func (l *ServerMetricsListener) OnRequest(
|
||||
ctx context.Context,
|
||||
req, resp *dns.Msg,
|
||||
req *dns.Msg,
|
||||
resp *dns.Msg,
|
||||
rw dnsserver.ResponseWriter,
|
||||
) {
|
||||
serverInfo := dnsserver.MustServerInfoFromContext(ctx)
|
||||
startTime := dnsserver.MustStartTimeFromContext(ctx)
|
||||
|
||||
// Increment total requests count metrics.
|
||||
counterWithRequestLabels(serverInfo, req, rw, requestTotal).Inc()
|
||||
l.reqTotalCounters.get(newReqLabelMetricKey(ctx, req, rw)).Inc()
|
||||
|
||||
// Increment request duration histogram.
|
||||
elapsed := time.Since(startTime).Seconds()
|
||||
histogramWithServerLabels(serverInfo, requestDuration).Observe(elapsed)
|
||||
l.reqDurationHistograms.get(serverInfo).Observe(elapsed)
|
||||
|
||||
// Increment request size.
|
||||
ri := dnsserver.MustRequestInfoFromContext(ctx)
|
||||
histogramWithServerLabels(serverInfo, requestSize).Observe(float64(ri.RequestSize))
|
||||
l.reqSizeHistograms.get(serverInfo).Observe(float64(ri.RequestSize))
|
||||
|
||||
// If resp is not nil, increment response-related metrics.
|
||||
if resp != nil {
|
||||
histogramWithServerLabels(serverInfo, responseSize).Observe(float64(ri.ResponseSize))
|
||||
rCode := rCodeToString(resp.Rcode)
|
||||
counterWithServerLabelsPlusRCode(serverInfo, rCode, responseRCode).Inc()
|
||||
l.respSizeHistograms.get(serverInfo).Observe(float64(ri.ResponseSize))
|
||||
l.respRCodeCounters.get(srvInfoRCode{
|
||||
ServerInfo: serverInfo,
|
||||
rCode: rCodeToString(resp.Rcode),
|
||||
}).Inc()
|
||||
} else {
|
||||
// If resp is nil, increment responseRCode with a special "rcode"
|
||||
// label value ("DROPPED").
|
||||
counterWithServerLabelsPlusRCode(serverInfo, "DROPPED", responseRCode).Inc()
|
||||
// If resp is nil, increment responseRCode with a special "rcode" label
|
||||
// value ("DROPPED").
|
||||
l.respRCodeCounters.get(srvInfoRCode{
|
||||
ServerInfo: serverInfo,
|
||||
rCode: "DROPPED",
|
||||
}).Inc()
|
||||
}
|
||||
}
|
||||
|
||||
// OnInvalidMsg implements the [dnsserver.MetricsListener] interface for
|
||||
// [*ServerMetricsListener].
|
||||
func (l *ServerMetricsListener) OnInvalidMsg(ctx context.Context) {
|
||||
counterWithServerLabels(dnsserver.MustServerInfoFromContext(ctx), invalidMsgTotal).Inc()
|
||||
l.invalidMsgCounters.get(dnsserver.MustServerInfoFromContext(ctx)).Inc()
|
||||
}
|
||||
|
||||
// OnError implements the [dnsserver.MetricsListener] interface for
|
||||
// [*ServerMetricsListener].
|
||||
func (l *ServerMetricsListener) OnError(ctx context.Context, _ error) {
|
||||
counterWithServerLabels(dnsserver.MustServerInfoFromContext(ctx), errorTotal).Inc()
|
||||
l.errorCounters.get(dnsserver.MustServerInfoFromContext(ctx)).Inc()
|
||||
}
|
||||
|
||||
// OnPanic implements the [dnsserver.MetricsListener] interface for
|
||||
// [*ServerMetricsListener].
|
||||
func (l *ServerMetricsListener) OnPanic(ctx context.Context, _ any) {
|
||||
counterWithServerLabels(dnsserver.MustServerInfoFromContext(ctx), panicTotal).Inc()
|
||||
l.panicCounters.get(dnsserver.MustServerInfoFromContext(ctx)).Inc()
|
||||
}
|
||||
|
||||
// OnQUICAddressValidation implements the [dnsserver.MetricsListener] interface
|
||||
|
@ -3,10 +3,11 @@ package prometheus_test
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
|
||||
prom "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -24,7 +25,7 @@ func TestServerMetricsListener_integration_requestLifetime(t *testing.T) {
|
||||
Name: "test",
|
||||
Addr: "127.0.0.1:0",
|
||||
Handler: dnsservertest.DefaultHandler(),
|
||||
Metrics: &prom.ServerMetricsListener{},
|
||||
Metrics: prometheus.NewServerMetricsListener(),
|
||||
},
|
||||
}
|
||||
srv := dnsserver.NewServerDNS(conf)
|
||||
@ -39,7 +40,7 @@ func TestServerMetricsListener_integration_requestLifetime(t *testing.T) {
|
||||
})
|
||||
|
||||
// Create a test message.
|
||||
req := dnsservertest.CreateMessage("example.org", dns.TypeA)
|
||||
req := dnsservertest.CreateMessage(testReqDomain, dns.TypeA)
|
||||
|
||||
c := &dns.Client{Net: "tcp"}
|
||||
|
||||
@ -48,8 +49,8 @@ func TestServerMetricsListener_integration_requestLifetime(t *testing.T) {
|
||||
|
||||
// Pass 10 requests to make the test less flaky.
|
||||
for i := 0; i < 10; i++ {
|
||||
res, _, eerr := c.Exchange(req, addr)
|
||||
require.NoError(t, eerr)
|
||||
res, _, exchErr := c.Exchange(req, addr)
|
||||
require.NoError(t, exchErr)
|
||||
require.NotNil(t, res)
|
||||
require.Equal(t, dns.RcodeSuccess, res.Rcode)
|
||||
}
|
||||
@ -64,3 +65,61 @@ func TestServerMetricsListener_integration_requestLifetime(t *testing.T) {
|
||||
"dns_server_response_rcode_total",
|
||||
)
|
||||
}
|
||||
|
||||
func BenchmarkServerMetricsListener(b *testing.B) {
|
||||
l := prometheus.NewServerMetricsListener()
|
||||
|
||||
ctx := dnsserver.ContextWithServerInfo(context.Background(), testServerInfo)
|
||||
ctx = dnsserver.ContextWithStartTime(ctx, time.Now())
|
||||
|
||||
req := dnsservertest.CreateMessage(testReqDomain, dns.TypeA)
|
||||
resp := (&dns.Msg{}).SetRcode(req, dns.RcodeSuccess)
|
||||
ctx = dnsserver.ContextWithRequestInfo(ctx, dnsserver.RequestInfo{
|
||||
RequestSize: req.Len(),
|
||||
ResponseSize: resp.Len(),
|
||||
})
|
||||
|
||||
rw := dnsserver.NewNonWriterResponseWriter(testUDPAddr, testUDPAddr)
|
||||
|
||||
b.Run("OnRequest", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
l.OnRequest(ctx, req, resp, rw)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("OnInvalidMsg", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
l.OnInvalidMsg(ctx)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("OnError", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
l.OnError(ctx, nil)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("OnPanic", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
l.OnPanic(ctx, nil)
|
||||
}
|
||||
})
|
||||
|
||||
// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
|
||||
// goos: linux
|
||||
// goarch: amd64
|
||||
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus
|
||||
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
|
||||
// BenchmarkServerMetricsListener/OnRequest-16 1550391 716.7 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkServerMetricsListener/OnInvalidMsg-16 13041940 91.75 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkServerMetricsListener/OnError-16 12297494 97.04 ns/op 0 B/op 0 allocs/op
|
||||
// BenchmarkServerMetricsListener/OnPanic-16 14029394 89.19 ns/op 0 B/op 0 allocs/op
|
||||
}
|
||||
|
@ -19,8 +19,8 @@ type ConfigBase struct {
|
||||
// Name is used for logging, and it may be used for perf counters reporting.
|
||||
Name string
|
||||
|
||||
// Addr is the address the server listens to. See go doc net.Dial for
|
||||
// the documentation on the address format.
|
||||
// Addr is the address the server listens to. See [net.Dial] for the
|
||||
// documentation on the address format.
|
||||
Addr string
|
||||
|
||||
// Network is the network this server listens to. If empty, the server will
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/dnssvctest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -47,8 +48,8 @@ func TestService_writeDebugResponse(t *testing.T) {
|
||||
blockRule = "||example.com^"
|
||||
)
|
||||
|
||||
clientIPStr := testClientIP.String()
|
||||
serverIPStr := testServerAddr.String()
|
||||
clientIPStr := dnssvctest.ClientIP.String()
|
||||
serverIPStr := dnssvctest.ServerAddr.String()
|
||||
testCases := []struct {
|
||||
name string
|
||||
ri *agd.RequestInfo
|
||||
@ -129,26 +130,26 @@ func TestService_writeDebugResponse(t *testing.T) {
|
||||
}),
|
||||
}, {
|
||||
name: "device",
|
||||
ri: &agd.RequestInfo{Device: &agd.Device{ID: testDeviceID}},
|
||||
ri: &agd.RequestInfo{Device: &agd.Device{ID: dnssvctest.DeviceID}},
|
||||
reqRes: nil,
|
||||
respRes: nil,
|
||||
wantExtra: newTXTExtra([][2]string{
|
||||
{"client-ip.adguard-dns.com.", clientIPStr},
|
||||
{"server-ip.adguard-dns.com.", serverIPStr},
|
||||
{"device-id.adguard-dns.com.", testDeviceID},
|
||||
{"device-id.adguard-dns.com.", dnssvctest.DeviceIDStr},
|
||||
{"resp.res-type.adguard-dns.com.", "normal"},
|
||||
}),
|
||||
}, {
|
||||
name: "profile",
|
||||
ri: &agd.RequestInfo{
|
||||
Profile: &agd.Profile{ID: testProfileID},
|
||||
Profile: &agd.Profile{ID: dnssvctest.ProfileID},
|
||||
},
|
||||
reqRes: nil,
|
||||
respRes: nil,
|
||||
wantExtra: newTXTExtra([][2]string{
|
||||
{"client-ip.adguard-dns.com.", clientIPStr},
|
||||
{"server-ip.adguard-dns.com.", serverIPStr},
|
||||
{"profile-id.adguard-dns.com.", testProfileID},
|
||||
{"profile-id.adguard-dns.com.", dnssvctest.ProfileIDStr},
|
||||
{"resp.res-type.adguard-dns.com.", "normal"},
|
||||
}),
|
||||
}, {
|
||||
@ -182,7 +183,7 @@ func TestService_writeDebugResponse(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rw := dnsserver.NewNonWriterResponseWriter(testLocalAddr, testRAddr)
|
||||
rw := dnsserver.NewNonWriterResponseWriter(dnssvctest.LocalAddr, dnssvctest.RemoteAddr)
|
||||
|
||||
ctx := agd.ContextWithRequestInfo(context.Background(), tc.ri)
|
||||
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/access"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/connlimiter"
|
||||
@ -20,6 +21,8 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/accessmw"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/initial"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
|
||||
@ -48,6 +51,9 @@ type Config struct {
|
||||
// active stream-connections.
|
||||
ConnLimiter *connlimiter.Limiter
|
||||
|
||||
// AccessManager is used to block requests.
|
||||
AccessManager access.Interface
|
||||
|
||||
// SafeBrowsing is the safe browsing TXT hash matcher.
|
||||
SafeBrowsing filter.HashMatcher
|
||||
|
||||
@ -397,7 +403,7 @@ func NewListener(
|
||||
|
||||
metricsListener := &errCollMetricsListener{
|
||||
errColl: errColl,
|
||||
baseListener: &prometheus.ServerMetricsListener{},
|
||||
baseListener: prometheus.NewServerMetricsListener(),
|
||||
}
|
||||
|
||||
confBase := dnsserver.ConfigBase{
|
||||
@ -473,61 +479,47 @@ func newServers(
|
||||
rlProtos := []agd.Protocol{agd.ProtoDNS}
|
||||
|
||||
var rlm *ratelimit.Middleware
|
||||
srvName := s.Name
|
||||
rlm, err = ratelimit.NewMiddleware(c.RateLimit, rlProtos)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ratelimit: %w", err)
|
||||
return nil, fmt.Errorf("server %q: ratelimit: %w", srvName, err)
|
||||
}
|
||||
|
||||
rlm.Metrics = &prometheus.RateLimitMetricsListener{}
|
||||
rlm.Metrics = prometheus.NewRateLimitMetricsListener()
|
||||
|
||||
imw := &initMw{
|
||||
messages: c.Messages,
|
||||
fltGrp: fg,
|
||||
srvGrp: srvGrp,
|
||||
srv: s,
|
||||
db: c.ProfileDB,
|
||||
geoIP: c.GeoIP,
|
||||
errColl: c.ErrColl,
|
||||
}
|
||||
amw := accessmw.New(&accessmw.Config{
|
||||
AccessManager: c.AccessManager,
|
||||
})
|
||||
|
||||
imw := initial.New(&initial.Config{
|
||||
Messages: c.Messages,
|
||||
FilteringGroup: fg,
|
||||
ServerGroup: srvGrp,
|
||||
Server: s,
|
||||
ProfileDB: c.ProfileDB,
|
||||
GeoIP: c.GeoIP,
|
||||
ErrColl: c.ErrColl,
|
||||
})
|
||||
|
||||
h := dnsserver.WithMiddlewares(
|
||||
handler,
|
||||
|
||||
// Keep the rate limiting middleware as the outer one to make sure
|
||||
// that the application logic isn't touched if the request is
|
||||
// ratelimited.
|
||||
// Keep the rate limiting and access middlewares as the outer ones
|
||||
// to make sure that the application logic isn't touched if the
|
||||
// request is ratelimited or blocked by access settings.
|
||||
rlm,
|
||||
amw,
|
||||
imw,
|
||||
)
|
||||
|
||||
listeners := make([]*listener, 0, len(s.BindData))
|
||||
for _, bindData := range s.BindData {
|
||||
addr := bindData.Address
|
||||
if addr == "" {
|
||||
addr = bindData.AddrPort.String()
|
||||
}
|
||||
|
||||
name := listenerName(s.Name, addr, s.Protocol)
|
||||
|
||||
lc := bindData.ListenConfig
|
||||
if lc == nil {
|
||||
lc = newListenConfig(c.ControlConf, c.ConnLimiter, s.Protocol)
|
||||
}
|
||||
|
||||
var l Listener
|
||||
l, err = newListener(s, name, addr, h, c.NonDNS, c.ErrColl, lc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("server %q: %w", s.Name, err)
|
||||
}
|
||||
|
||||
listeners = append(listeners, &listener{
|
||||
name: name,
|
||||
Listener: l,
|
||||
})
|
||||
var listeners []*listener
|
||||
listeners, err = newListeners(c, s, h, newListener)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("server %q: %w", srvName, err)
|
||||
}
|
||||
|
||||
servers[i] = &server{
|
||||
name: s.Name,
|
||||
name: srvName,
|
||||
handler: h,
|
||||
listeners: listeners,
|
||||
}
|
||||
@ -536,16 +528,59 @@ func newServers(
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
// newServers creates a slice of listeners for a server.
|
||||
func newListeners(
|
||||
c *Config,
|
||||
srv *agd.Server,
|
||||
handler dnsserver.Handler,
|
||||
newListener NewListenerFunc,
|
||||
) (listeners []*listener, err error) {
|
||||
listeners = make([]*listener, 0, len(srv.BindData))
|
||||
for i, bindData := range srv.BindData {
|
||||
addr := bindData.Address
|
||||
if addr == "" {
|
||||
addr = bindData.AddrPort.String()
|
||||
}
|
||||
|
||||
proto := srv.Protocol
|
||||
name := listenerName(srv.Name, addr, proto)
|
||||
|
||||
lc := newListenConfig(bindData.ListenConfig, c.ControlConf, c.ConnLimiter, proto)
|
||||
|
||||
var l Listener
|
||||
l, err = newListener(srv, name, addr, handler, c.NonDNS, c.ErrColl, lc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bind data at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
listeners = append(listeners, &listener{
|
||||
name: name,
|
||||
Listener: l,
|
||||
})
|
||||
}
|
||||
|
||||
return listeners, nil
|
||||
}
|
||||
|
||||
// newListenConfig returns the netext.ListenConfig used by the plain-DNS
|
||||
// servers. The resulting ListenConfig sets additional socket flags and
|
||||
// processes the control messages of connections created with ListenPacket.
|
||||
// Additionally, if l is not nil, it is used to limit the number of
|
||||
// simultaneously active stream-connections.
|
||||
func newListenConfig(
|
||||
original netext.ListenConfig,
|
||||
ctrlConf *netext.ControlConfig,
|
||||
l *connlimiter.Limiter,
|
||||
p agd.Protocol,
|
||||
) (lc netext.ListenConfig) {
|
||||
if original != nil {
|
||||
if l == nil {
|
||||
return original
|
||||
}
|
||||
|
||||
return connlimiter.NewListenConfig(original, l)
|
||||
}
|
||||
|
||||
if p == agd.ProtoDNS {
|
||||
lc = netext.DefaultListenConfigWithOOB(ctrlConf)
|
||||
} else {
|
||||
|
@ -1,30 +0,0 @@
|
||||
package dnssvc
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// Common addresses for tests.
|
||||
var (
|
||||
testClientIP = net.IP{1, 2, 3, 4}
|
||||
testRAddr = &net.TCPAddr{
|
||||
IP: testClientIP,
|
||||
Port: 12345,
|
||||
}
|
||||
|
||||
testClientAddrPort = testRAddr.AddrPort()
|
||||
testClientAddr = testClientAddrPort.Addr()
|
||||
|
||||
testServerAddr = netip.MustParseAddr("5.6.7.8")
|
||||
testLocalAddr = &net.TCPAddr{
|
||||
IP: testServerAddr.AsSlice(),
|
||||
Port: 54321,
|
||||
}
|
||||
)
|
||||
|
||||
// testDeviceID is the common device ID for tests
|
||||
const testDeviceID = "dev1234"
|
||||
|
||||
// testProfileID is the common profile ID for tests
|
||||
const testProfileID = "prof1234"
|
@ -1,334 +0,0 @@
|
||||
package dnssvc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// The Initial Middleware
|
||||
|
||||
// initMw is the outermost middleware of the AdGuard DNS server. It filters out
|
||||
// the Firefox canary domain logic, sets and resets the AD bit for further
|
||||
// processing, as well as puts as much information as it can into the context.
|
||||
//
|
||||
// This middleware must be the most outer middleware apart from the ratelimit
|
||||
// one.
|
||||
//
|
||||
// TODO(a.garipov): Add tests.
|
||||
type initMw struct {
|
||||
// messages is used to build the responses specific for the request's
|
||||
// context.
|
||||
messages *dnsmsg.Constructor
|
||||
|
||||
// fltGrp is the filtering group to which srv belongs.
|
||||
fltGrp *agd.FilteringGroup
|
||||
|
||||
// srvGrp is the server group to which srv belongs.
|
||||
srvGrp *agd.ServerGroup
|
||||
|
||||
// srv is the current server which serves the request.
|
||||
srv *agd.Server
|
||||
|
||||
// db is the database of user profiles and devices.
|
||||
db profiledb.Interface
|
||||
|
||||
// geoIP detects the location of the request source.
|
||||
geoIP geoip.Interface
|
||||
|
||||
// errColl collects and reports the errors considered non-critical.
|
||||
errColl agd.ErrorCollector
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ dnsserver.Middleware = (*initMw)(nil)
|
||||
|
||||
// Wrap implements the [dnsserver.Middleware] interface for *initMw.
|
||||
func (mw *initMw) Wrap(h dnsserver.Handler) (wrapped dnsserver.Handler) {
|
||||
return &initMwHandler{
|
||||
mw: mw,
|
||||
next: h,
|
||||
}
|
||||
}
|
||||
|
||||
// newRequestInfo returns the new request information structure using the
|
||||
// middleware's configuration and values from ctx.
|
||||
func (mw *initMw) newRequestInfo(
|
||||
ctx context.Context,
|
||||
req *dns.Msg,
|
||||
laddr net.Addr,
|
||||
raddr net.Addr,
|
||||
fqdn string,
|
||||
qt dnsmsg.RRType,
|
||||
cl dnsmsg.Class,
|
||||
) (ri *agd.RequestInfo, err error) {
|
||||
// Put the host, server, and client IP data into the request information
|
||||
// immediately.
|
||||
ri = &agd.RequestInfo{
|
||||
FilteringGroup: mw.fltGrp,
|
||||
Messages: mw.messages,
|
||||
ServerGroup: mw.srvGrp.Name,
|
||||
Server: mw.srv.Name,
|
||||
Host: strings.TrimSuffix(fqdn, "."),
|
||||
QType: qt,
|
||||
QClass: cl,
|
||||
}
|
||||
|
||||
ri.RemoteIP = netutil.NetAddrToAddrPort(raddr).Addr()
|
||||
|
||||
// As an optimization, put the request ID closer to the top of the context
|
||||
// stack.
|
||||
ri.ID, _ = agd.RequestIDFromContext(ctx)
|
||||
|
||||
// Add the GeoIP information, if any.
|
||||
err = mw.addLocation(ctx, ri, req)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add the profile information, if any.
|
||||
localIP := netutil.NetAddrToAddrPort(laddr).Addr()
|
||||
err = mw.addProfile(ctx, ri, req, localIP)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ri, nil
|
||||
}
|
||||
|
||||
// addLocation adds GeoIP location information about the client's remote address
|
||||
// as well as the EDNS Client Subnet information, if there is one, to ri. err
|
||||
// is not nil only if req contains a malformed EDNS Client Subnet option.
|
||||
func (mw *initMw) addLocation(ctx context.Context, ri *agd.RequestInfo, req *dns.Msg) (err error) {
|
||||
ri.Location = mw.locationData(ctx, ri.RemoteIP, "client")
|
||||
|
||||
ecs, scope, err := dnsmsg.ECSFromMsg(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding ecs info: %w", err)
|
||||
} else if ecs != (netip.Prefix{}) {
|
||||
ri.ECS = &agd.ECS{
|
||||
Location: mw.locationData(ctx, ecs.Addr(), "ecs"),
|
||||
Subnet: ecs,
|
||||
Scope: scope,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// locationData returns the GeoIP location information about the IP address.
|
||||
// typ is the type of data being requested for error reporting and logging.
|
||||
func (mw *initMw) locationData(ctx context.Context, ip netip.Addr, typ string) (l *agd.Location) {
|
||||
l, err := mw.geoIP.Data("", ip)
|
||||
if err != nil {
|
||||
// Consider GeoIP errors non-critical. Report and go on.
|
||||
agd.Collectf(ctx, mw.errColl, "init mw: getting geoip for %s ip: %w", typ, err)
|
||||
}
|
||||
|
||||
if l == nil {
|
||||
optlog.Debug2("init mw: no geoip for %s ip %s", typ, ip)
|
||||
} else {
|
||||
optlog.Debug4("init mw: found country/asn %q/%d for %s ip %s", l.Country, l.ASN, typ, ip)
|
||||
}
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
// addProfile adds profile and device information, if any, to the request
|
||||
// information.
|
||||
func (mw *initMw) addProfile(
|
||||
ctx context.Context,
|
||||
ri *agd.RequestInfo,
|
||||
req *dns.Msg,
|
||||
localIP netip.Addr,
|
||||
) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "getting profile from req: %w") }()
|
||||
|
||||
var id agd.DeviceID
|
||||
if p := mw.srv.Protocol; p.IsStdEncrypted() {
|
||||
// Assume that mw.srvGrp.TLS is non-nil if p.IsStdEncrypted() is true.
|
||||
wildcards := mw.srvGrp.TLS.DeviceIDWildcards
|
||||
id, err = deviceIDFromContext(ctx, mw.srv.Protocol, wildcards)
|
||||
} else if p == agd.ProtoDNS {
|
||||
id, err = deviceIDFromEDNS(req)
|
||||
} else {
|
||||
// No DeviceID for DNSCrypt yet.
|
||||
return nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
optlog.Debug3("init mw: got device id %q, raddr %s, and laddr %s", id, ri.RemoteIP, localIP)
|
||||
|
||||
prof, dev, byWhat, err := mw.profile(ctx, localIP, ri.RemoteIP, id)
|
||||
if err != nil {
|
||||
if !errors.Is(err, profiledb.ErrDeviceNotFound) {
|
||||
// Very unlikely, since those two error types are the only ones
|
||||
// currently returned from the default profile DB.
|
||||
return fmt.Errorf("unexpected profiledb error: %s", err)
|
||||
}
|
||||
|
||||
optlog.Debug1("init mw: profile or device not found: %s", err)
|
||||
} else if prof.Deleted {
|
||||
optlog.Debug1("init mw: profile %s is deleted", prof.ID)
|
||||
} else {
|
||||
optlog.Debug3("init mw: found profile %s and device %s by %s", prof.ID, dev.ID, byWhat)
|
||||
|
||||
ri.Device, ri.Profile = dev, prof
|
||||
ri.Messages = dnsmsg.NewConstructor(prof.BlockingMode.Mode, prof.FilteredResponseTTL)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Constants for the parameter by which a device has been found.
|
||||
const (
|
||||
byDeviceID = "device id"
|
||||
byDedicatedIP = "dedicated ip"
|
||||
byLinkedIP = "linked ip"
|
||||
)
|
||||
|
||||
// profile finds the profile by the client data.
|
||||
func (mw *initMw) profile(
|
||||
ctx context.Context,
|
||||
localIP netip.Addr,
|
||||
remoteIP netip.Addr,
|
||||
id agd.DeviceID,
|
||||
) (prof *agd.Profile, dev *agd.Device, byWhat string, err error) {
|
||||
if id != "" {
|
||||
prof, dev, err = mw.db.ProfileByDeviceID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
|
||||
return prof, dev, byDeviceID, nil
|
||||
}
|
||||
|
||||
if !mw.srv.LinkedIPEnabled {
|
||||
optlog.Debug1("init mw: not matching by linked or dedicated ip for server %s", mw.srv.Name)
|
||||
|
||||
return nil, nil, "", profiledb.ErrDeviceNotFound
|
||||
} else if p := mw.srv.Protocol; p != agd.ProtoDNS {
|
||||
optlog.Debug1("init mw: not matching by linked or dedicated ip for proto %v", p)
|
||||
|
||||
return nil, nil, "", profiledb.ErrDeviceNotFound
|
||||
}
|
||||
|
||||
byWhat = byDedicatedIP
|
||||
prof, dev, err = mw.db.ProfileByDedicatedIP(ctx, localIP)
|
||||
if errors.Is(err, profiledb.ErrDeviceNotFound) {
|
||||
byWhat = byLinkedIP
|
||||
prof, dev, err = mw.db.ProfileByLinkedIP(ctx, remoteIP)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
|
||||
return prof, dev, byWhat, nil
|
||||
}
|
||||
|
||||
// initMwHandler implements the [dnsserver.Handler] interface and will be used
|
||||
// as a [dnsserver.Handler] that the initMw middleware returns from the Wrap
|
||||
// function call.
|
||||
type initMwHandler struct {
|
||||
mw *initMw
|
||||
next dnsserver.Handler
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ dnsserver.Handler = (*initMwHandler)(nil)
|
||||
|
||||
// ServeDNS implements the [dnsserver.Handler] interface for *initMwHandler.
|
||||
func (mh *initMwHandler) ServeDNS(
|
||||
ctx context.Context,
|
||||
rw dnsserver.ResponseWriter,
|
||||
req *dns.Msg,
|
||||
) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "init mw: %w") }()
|
||||
|
||||
// Save the actual value of the request AD and DO bits and set the AD
|
||||
// bit in the request to true, so that the upstream validates the data
|
||||
// and caches the actual value of the response AD bit. Restore it
|
||||
// later, depending on the request and response data.
|
||||
reqAD := req.AuthenticatedData
|
||||
reqDO := dnsmsg.IsDO(req)
|
||||
req.AuthenticatedData = true
|
||||
|
||||
// Assume that module dnsserver has already validated that the request
|
||||
// always has exactly one question for us.
|
||||
q := req.Question[0]
|
||||
fqdn := strings.ToLower(q.Name)
|
||||
qt := q.Qtype
|
||||
cl := q.Qclass
|
||||
|
||||
// Copy middleware to the local variable to make the code simpler.
|
||||
mw := mh.mw
|
||||
|
||||
// Get the request's information, such as GeoIP data and user profiles.
|
||||
ri, err := mw.newRequestInfo(ctx, req, rw.LocalAddr(), rw.RemoteAddr(), fqdn, qt, cl)
|
||||
if err != nil {
|
||||
var ecsErr dnsmsg.BadECSError
|
||||
if errors.As(err, &ecsErr) {
|
||||
// We've got a bad ECS option. Log and respond with a FORMERR
|
||||
// immediately.
|
||||
optlog.Debug1("init mw: %s", err)
|
||||
|
||||
err = rw.WriteMsg(ctx, req, mw.messages.NewMsgFORMERR(req))
|
||||
err = errors.Annotate(err, "writing formerr resp: %w")
|
||||
}
|
||||
|
||||
// Don't wrap the error, because this is the main flow, and there is
|
||||
// already errors.Annotate here.
|
||||
return err
|
||||
}
|
||||
|
||||
if specHdlr, name := mw.reqInfoSpecialHandler(ri, cl); specHdlr != nil {
|
||||
optlog.Debug1("init mw: got req-info special handler %s", name)
|
||||
|
||||
// Don't wrap the error, because it's informative enough as is, and
|
||||
// because if handled is true, the main flow terminates here.
|
||||
return specHdlr(ctx, rw, req, ri)
|
||||
}
|
||||
|
||||
ctx = agd.ContextWithRequestInfo(ctx, ri)
|
||||
|
||||
// Record the response, restore the AD bit value in both the request and
|
||||
// the response, and write the response.
|
||||
nwrw := makeNonWriter(rw)
|
||||
err = mh.next.ServeDNS(ctx, nwrw, req)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because this is the main flow, and there is
|
||||
// already errors.Annotate here.
|
||||
return err
|
||||
}
|
||||
|
||||
resp := nwrw.Msg()
|
||||
|
||||
// Following RFC 6840, set the AD bit in the response only when the
|
||||
// response is authenticated, and the request contained either a set DO
|
||||
// bit or a set AD bit.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc6840#section-5.8.
|
||||
resp.AuthenticatedData = resp.AuthenticatedData && (reqAD || reqDO)
|
||||
|
||||
err = rw.WriteMsg(ctx, req, resp)
|
||||
|
||||
return errors.Annotate(err, "writing resp: %w")
|
||||
}
|
@ -1,855 +0,0 @@
|
||||
package dnssvc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
func TestInitMw_profile(t *testing.T) {
|
||||
prof := &agd.Profile{
|
||||
ID: testProfileID,
|
||||
DeviceIDs: []agd.DeviceID{
|
||||
testDeviceID,
|
||||
},
|
||||
}
|
||||
dev := &agd.Device{
|
||||
ID: testDeviceID,
|
||||
LinkedIP: testClientAddr,
|
||||
DedicatedIPs: []netip.Addr{
|
||||
testServerAddr,
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
wantDev *agd.Device
|
||||
wantProf *agd.Profile
|
||||
wantByWhat string
|
||||
wantErrMsg string
|
||||
name string
|
||||
id agd.DeviceID
|
||||
proto agd.Protocol
|
||||
linkedIPEnabled bool
|
||||
}{{
|
||||
wantDev: nil,
|
||||
wantProf: nil,
|
||||
wantByWhat: "",
|
||||
wantErrMsg: "device not found",
|
||||
name: "no_device_id",
|
||||
id: "",
|
||||
proto: agd.ProtoDNS,
|
||||
linkedIPEnabled: true,
|
||||
}, {
|
||||
wantDev: dev,
|
||||
wantProf: prof,
|
||||
wantByWhat: byDeviceID,
|
||||
wantErrMsg: "",
|
||||
name: "device_id",
|
||||
id: testDeviceID,
|
||||
proto: agd.ProtoDNS,
|
||||
linkedIPEnabled: true,
|
||||
}, {
|
||||
wantDev: dev,
|
||||
wantProf: prof,
|
||||
wantByWhat: byLinkedIP,
|
||||
wantErrMsg: "",
|
||||
name: "linked_ip",
|
||||
id: "",
|
||||
proto: agd.ProtoDNS,
|
||||
linkedIPEnabled: true,
|
||||
}, {
|
||||
wantDev: nil,
|
||||
wantProf: nil,
|
||||
wantByWhat: "",
|
||||
wantErrMsg: "device not found",
|
||||
name: "linked_ip_dot",
|
||||
id: "",
|
||||
proto: agd.ProtoDoT,
|
||||
linkedIPEnabled: true,
|
||||
}, {
|
||||
wantDev: nil,
|
||||
wantProf: nil,
|
||||
wantByWhat: "",
|
||||
wantErrMsg: "device not found",
|
||||
name: "linked_ip_disabled",
|
||||
id: "",
|
||||
proto: agd.ProtoDoT,
|
||||
linkedIPEnabled: false,
|
||||
}, {
|
||||
wantDev: dev,
|
||||
wantProf: prof,
|
||||
wantByWhat: byDedicatedIP,
|
||||
wantErrMsg: "",
|
||||
name: "dedicated_ip",
|
||||
id: "",
|
||||
proto: agd.ProtoDNS,
|
||||
linkedIPEnabled: true,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
db := &agdtest.ProfileDB{
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
gotID agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
assert.Equal(t, tc.id, gotID)
|
||||
|
||||
if tc.wantByWhat == byDeviceID {
|
||||
return prof, dev, nil
|
||||
}
|
||||
|
||||
return nil, nil, profiledb.ErrDeviceNotFound
|
||||
},
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
gotLocalIP netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
assert.Equal(t, testServerAddr, gotLocalIP)
|
||||
|
||||
if tc.wantByWhat == byDedicatedIP {
|
||||
return prof, dev, nil
|
||||
}
|
||||
|
||||
return nil, nil, profiledb.ErrDeviceNotFound
|
||||
},
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
gotRemoteIP netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
assert.Equal(t, testClientAddr, gotRemoteIP)
|
||||
|
||||
if tc.wantByWhat == byLinkedIP {
|
||||
return prof, dev, nil
|
||||
}
|
||||
|
||||
return nil, nil, profiledb.ErrDeviceNotFound
|
||||
},
|
||||
}
|
||||
|
||||
mw := &initMw{
|
||||
srv: &agd.Server{
|
||||
Protocol: tc.proto,
|
||||
LinkedIPEnabled: tc.linkedIPEnabled,
|
||||
},
|
||||
db: db,
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
gotProf, gotDev, gotByWhat, err := mw.profile(
|
||||
ctx,
|
||||
testServerAddr,
|
||||
testClientAddr,
|
||||
tc.id,
|
||||
)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
assert.Equal(t, tc.wantProf, gotProf)
|
||||
assert.Equal(t, tc.wantDev, gotDev)
|
||||
assert.Equal(t, tc.wantByWhat, gotByWhat)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitMw_ServeDNS_ddr(t *testing.T) {
|
||||
const (
|
||||
resolverName = "dns.example.com"
|
||||
resolverFQDN = resolverName + "."
|
||||
|
||||
targetWithID = testDeviceID + ".d." + resolverName + "."
|
||||
|
||||
ddrFQDN = ddrDomain + "."
|
||||
|
||||
dohPath = "/dns-query"
|
||||
)
|
||||
|
||||
testDevice := &agd.Device{ID: testDeviceID}
|
||||
|
||||
srvs := map[agd.ServerName]*agd.Server{
|
||||
"dot": {
|
||||
TLS: &tls.Config{},
|
||||
BindData: []*agd.ServerBindData{{
|
||||
AddrPort: netip.MustParseAddrPort("1.2.3.4:12345"),
|
||||
}},
|
||||
Protocol: agd.ProtoDoT,
|
||||
},
|
||||
"doh": {
|
||||
TLS: &tls.Config{},
|
||||
BindData: []*agd.ServerBindData{{
|
||||
AddrPort: netip.MustParseAddrPort("5.6.7.8:54321"),
|
||||
}},
|
||||
Protocol: agd.ProtoDoH,
|
||||
},
|
||||
"dns": {
|
||||
BindData: []*agd.ServerBindData{{
|
||||
AddrPort: netip.MustParseAddrPort("2.4.6.8:53"),
|
||||
}},
|
||||
Protocol: agd.ProtoDNS,
|
||||
LinkedIPEnabled: true,
|
||||
},
|
||||
"dns_nolink": {
|
||||
BindData: []*agd.ServerBindData{{
|
||||
AddrPort: netip.MustParseAddrPort("2.4.6.8:53"),
|
||||
}},
|
||||
Protocol: agd.ProtoDNS,
|
||||
},
|
||||
}
|
||||
|
||||
srvGrp := &agd.ServerGroup{
|
||||
TLS: &agd.TLS{
|
||||
DeviceIDWildcards: []string{"*.d." + resolverName},
|
||||
},
|
||||
DDR: &agd.DDR{
|
||||
DeviceTargets: stringutil.NewSet(),
|
||||
PublicTargets: stringutil.NewSet(),
|
||||
Enabled: true,
|
||||
},
|
||||
Name: agd.ServerGroupName("test_server_group"),
|
||||
Servers: maps.Values(srvs),
|
||||
}
|
||||
|
||||
srvGrp.DDR.DeviceTargets.Add("d." + resolverName)
|
||||
srvGrp.DDR.PublicTargets.Add(resolverName)
|
||||
|
||||
var dev *agd.Device
|
||||
mw := &initMw{
|
||||
messages: agdtest.NewConstructor(),
|
||||
fltGrp: &agd.FilteringGroup{},
|
||||
srvGrp: srvGrp,
|
||||
db: &agdtest.ProfileDB{
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
_ agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
p = &agd.Profile{}
|
||||
|
||||
return p, dev, nil
|
||||
},
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return nil, nil, profiledb.ErrDeviceNotFound
|
||||
},
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
p = &agd.Profile{}
|
||||
|
||||
return p, dev, nil
|
||||
},
|
||||
},
|
||||
geoIP: &agdtest.GeoIP{
|
||||
OnSubnetByLocation: func(
|
||||
_ agd.Country,
|
||||
_ agd.ASN,
|
||||
_ netutil.AddrFamily,
|
||||
) (_ netip.Prefix, _ error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnData: func(_ string, _ netip.Addr) (l *agd.Location, err error) {
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
errColl: &agdtest.ErrorCollector{
|
||||
OnCollect: func(_ context.Context, _ error) { panic("not implemented") },
|
||||
},
|
||||
}
|
||||
|
||||
pubSVCBTmpls := []*dns.SVCB{
|
||||
mw.messages.NewDDRTemplate(agd.ProtoDoH, resolverName, dohPath, nil, nil, 443, 1),
|
||||
mw.messages.NewDDRTemplate(agd.ProtoDoT, resolverName, "", nil, nil, 853, 1),
|
||||
mw.messages.NewDDRTemplate(agd.ProtoDoQ, resolverName, "", nil, nil, 853, 1),
|
||||
}
|
||||
|
||||
devSVCBTmpls := []*dns.SVCB{
|
||||
mw.messages.NewDDRTemplate(agd.ProtoDoH, "d."+resolverName, dohPath, nil, nil, 443, 1),
|
||||
mw.messages.NewDDRTemplate(agd.ProtoDoT, "d."+resolverName, "", nil, nil, 853, 1),
|
||||
mw.messages.NewDDRTemplate(agd.ProtoDoQ, "d."+resolverName, "", nil, nil, 853, 1),
|
||||
}
|
||||
|
||||
srvGrp.DDR.PublicRecordTemplates = pubSVCBTmpls
|
||||
srvGrp.DDR.DeviceRecordTemplates = devSVCBTmpls
|
||||
|
||||
var handler dnsserver.Handler = dnsserver.HandlerFunc(func(
|
||||
_ context.Context,
|
||||
_ dnsserver.ResponseWriter,
|
||||
_ *dns.Msg,
|
||||
) (_ error) {
|
||||
// Make sure we haven't reached the following middleware.
|
||||
panic("not implemented")
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
device *agd.Device
|
||||
name string
|
||||
srv *agd.Server
|
||||
host string
|
||||
wantTarget string
|
||||
wantNum int
|
||||
qtype uint16
|
||||
}{{
|
||||
device: testDevice,
|
||||
name: "id",
|
||||
srv: srvs["dot"],
|
||||
host: ddrFQDN,
|
||||
wantTarget: targetWithID,
|
||||
wantNum: len(pubSVCBTmpls),
|
||||
qtype: dns.TypeSVCB,
|
||||
}, {
|
||||
device: testDevice,
|
||||
name: "id_specific",
|
||||
srv: srvs["dot"],
|
||||
host: ddrLabel + "." + targetWithID,
|
||||
wantTarget: targetWithID,
|
||||
wantNum: len(devSVCBTmpls),
|
||||
qtype: dns.TypeSVCB,
|
||||
}, {
|
||||
device: nil,
|
||||
name: "no_id",
|
||||
srv: srvs["dot"],
|
||||
host: ddrFQDN,
|
||||
wantTarget: resolverFQDN,
|
||||
wantNum: len(pubSVCBTmpls),
|
||||
qtype: dns.TypeSVCB,
|
||||
}, {
|
||||
device: testDevice,
|
||||
name: "linked_ip",
|
||||
srv: srvs["dns"],
|
||||
host: ddrFQDN,
|
||||
wantTarget: targetWithID,
|
||||
wantNum: len(pubSVCBTmpls),
|
||||
qtype: dns.TypeSVCB,
|
||||
}, {
|
||||
device: testDevice,
|
||||
name: "no_linked_ip",
|
||||
srv: srvs["dns_nolink"],
|
||||
host: ddrFQDN,
|
||||
wantTarget: resolverFQDN,
|
||||
wantNum: len(pubSVCBTmpls),
|
||||
qtype: dns.TypeSVCB,
|
||||
}, {
|
||||
device: testDevice,
|
||||
name: "public_resolver_name",
|
||||
srv: srvs["dot"],
|
||||
host: ddrLabel + "." + resolverFQDN,
|
||||
wantTarget: targetWithID,
|
||||
wantNum: len(pubSVCBTmpls),
|
||||
qtype: dns.TypeSVCB,
|
||||
}, {
|
||||
device: nil,
|
||||
name: "arpa_not_ddr_svcb",
|
||||
srv: srvs["dot"],
|
||||
host: dns.Fqdn(ddrLabel + ".something.else." + resolverArpaDomain),
|
||||
wantTarget: "",
|
||||
wantNum: 0,
|
||||
qtype: dns.TypeSVCB,
|
||||
}, {
|
||||
device: nil,
|
||||
name: "arpa_ddr_not_svcb",
|
||||
srv: srvs["dot"],
|
||||
host: ddrFQDN,
|
||||
wantTarget: "",
|
||||
wantNum: 0,
|
||||
qtype: dns.TypeA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
mw.srv = tc.srv
|
||||
dev = tc.device
|
||||
|
||||
var tlsServerName string
|
||||
switch mw.srv.Protocol {
|
||||
case agd.ProtoDoT, agd.ProtoDoQ:
|
||||
tlsServerName = resolverName
|
||||
if dev != nil {
|
||||
tlsServerName = string(dev.ID) + ".d." + tlsServerName
|
||||
}
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
h := mw.Wrap(handler)
|
||||
|
||||
req := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Name: tc.host,
|
||||
Qtype: tc.qtype,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{
|
||||
TLSServerName: tlsServerName,
|
||||
})
|
||||
|
||||
rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr)
|
||||
|
||||
err := h.ServeDNS(ctx, rw, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := rw.Msg()
|
||||
require.NotNil(t, resp)
|
||||
|
||||
if tc.wantNum == 0 {
|
||||
assert.Empty(t, resp.Answer)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
assert.Len(t, resp.Answer, tc.wantNum)
|
||||
for _, rr := range resp.Answer {
|
||||
svcb := testutil.RequireTypeAssert[*dns.SVCB](t, rr)
|
||||
|
||||
assert.Equal(t, tc.wantTarget, svcb.Target)
|
||||
assert.Equal(t, tc.host, svcb.Hdr.Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitMw_ServeDNS_specialDomain(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
qtype dnsmsg.RRType
|
||||
fltGrpBlocked bool
|
||||
hasProf bool
|
||||
profBlocked bool
|
||||
wantRCode dnsmsg.RCode
|
||||
}{{
|
||||
name: "private_relay_blocked_by_fltgrp",
|
||||
host: applePrivateRelayMaskHost,
|
||||
qtype: dns.TypeA,
|
||||
fltGrpBlocked: true,
|
||||
hasProf: false,
|
||||
profBlocked: false,
|
||||
wantRCode: dns.RcodeNameError,
|
||||
}, {
|
||||
name: "no_special_domain",
|
||||
host: "www.example.com",
|
||||
qtype: dns.TypeA,
|
||||
fltGrpBlocked: true,
|
||||
hasProf: false,
|
||||
profBlocked: false,
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
}, {
|
||||
name: "no_private_relay_qtype",
|
||||
host: applePrivateRelayMaskHost,
|
||||
qtype: dns.TypeTXT,
|
||||
fltGrpBlocked: true,
|
||||
hasProf: false,
|
||||
profBlocked: false,
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
}, {
|
||||
name: "private_relay_blocked_by_prof",
|
||||
host: applePrivateRelayMaskHost,
|
||||
qtype: dns.TypeA,
|
||||
fltGrpBlocked: false,
|
||||
hasProf: true,
|
||||
profBlocked: true,
|
||||
wantRCode: dns.RcodeNameError,
|
||||
}, {
|
||||
name: "private_relay_allowed_by_prof",
|
||||
host: applePrivateRelayMaskHost,
|
||||
qtype: dns.TypeA,
|
||||
fltGrpBlocked: true,
|
||||
hasProf: true,
|
||||
profBlocked: false,
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
}, {
|
||||
name: "private_relay_allowed_by_both",
|
||||
host: applePrivateRelayMaskHost,
|
||||
qtype: dns.TypeA,
|
||||
fltGrpBlocked: false,
|
||||
hasProf: true,
|
||||
profBlocked: false,
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
}, {
|
||||
name: "private_relay_blocked_by_both",
|
||||
host: applePrivateRelayMaskHost,
|
||||
qtype: dns.TypeA,
|
||||
fltGrpBlocked: true,
|
||||
hasProf: true,
|
||||
profBlocked: true,
|
||||
wantRCode: dns.RcodeNameError,
|
||||
}, {
|
||||
name: "firefox_canary_allowed_by_prof",
|
||||
host: firefoxCanaryHost,
|
||||
qtype: dns.TypeA,
|
||||
fltGrpBlocked: false,
|
||||
hasProf: true,
|
||||
profBlocked: false,
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
}, {
|
||||
name: "firefox_canary_allowed_by_fltgrp",
|
||||
host: firefoxCanaryHost,
|
||||
qtype: dns.TypeA,
|
||||
fltGrpBlocked: false,
|
||||
hasProf: false,
|
||||
profBlocked: false,
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
}, {
|
||||
name: "firefox_canary_blocked_by_prof",
|
||||
host: firefoxCanaryHost,
|
||||
qtype: dns.TypeA,
|
||||
fltGrpBlocked: false,
|
||||
hasProf: true,
|
||||
profBlocked: true,
|
||||
wantRCode: dns.RcodeRefused,
|
||||
}, {
|
||||
name: "firefox_canary_blocked_by_fltgrp",
|
||||
host: firefoxCanaryHost,
|
||||
qtype: dns.TypeA,
|
||||
fltGrpBlocked: true,
|
||||
hasProf: false,
|
||||
profBlocked: false,
|
||||
wantRCode: dns.RcodeRefused,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var handler dnsserver.Handler = dnsserver.HandlerFunc(func(
|
||||
ctx context.Context,
|
||||
rw dnsserver.ResponseWriter,
|
||||
req *dns.Msg,
|
||||
) (err error) {
|
||||
if tc.wantRCode != dns.RcodeSuccess {
|
||||
return errors.Error("unexpectedly reached handler")
|
||||
}
|
||||
|
||||
resp := (&dns.Msg{}).SetReply(req)
|
||||
|
||||
return rw.WriteMsg(ctx, req, resp)
|
||||
})
|
||||
|
||||
onProfileByLinkedIP := func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
if !tc.hasProf {
|
||||
return nil, nil, profiledb.ErrDeviceNotFound
|
||||
}
|
||||
|
||||
prof := &agd.Profile{
|
||||
BlockPrivateRelay: tc.profBlocked,
|
||||
BlockFirefoxCanary: tc.profBlocked,
|
||||
}
|
||||
|
||||
return prof, &agd.Device{}, nil
|
||||
}
|
||||
db := &agdtest.ProfileDB{
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
_ agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return nil, nil, profiledb.ErrDeviceNotFound
|
||||
},
|
||||
OnProfileByLinkedIP: onProfileByLinkedIP,
|
||||
}
|
||||
|
||||
geoIP := &agdtest.GeoIP{
|
||||
OnSubnetByLocation: func(
|
||||
_ agd.Country,
|
||||
_ agd.ASN,
|
||||
_ netutil.AddrFamily,
|
||||
) (n netip.Prefix, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnData: func(_ string, _ netip.Addr) (l *agd.Location, err error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
errColl := &agdtest.ErrorCollector{
|
||||
OnCollect: func(_ context.Context, _ error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
|
||||
mw := &initMw{
|
||||
messages: agdtest.NewConstructor(),
|
||||
fltGrp: &agd.FilteringGroup{
|
||||
BlockPrivateRelay: tc.fltGrpBlocked,
|
||||
BlockFirefoxCanary: tc.fltGrpBlocked,
|
||||
},
|
||||
srvGrp: &agd.ServerGroup{},
|
||||
srv: &agd.Server{
|
||||
Protocol: agd.ProtoDNS,
|
||||
LinkedIPEnabled: true,
|
||||
},
|
||||
db: db,
|
||||
geoIP: geoIP,
|
||||
errColl: errColl,
|
||||
}
|
||||
|
||||
h := mw.Wrap(handler)
|
||||
|
||||
ctx := context.Background()
|
||||
rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr)
|
||||
req := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Name: dns.Fqdn(tc.host),
|
||||
Qtype: tc.qtype,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
err := h.ServeDNS(ctx, rw, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := rw.Msg()
|
||||
require.NotNil(t, resp)
|
||||
|
||||
assert.Equal(t, tc.wantRCode, dnsmsg.RCode(resp.Rcode))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var errSink error
|
||||
|
||||
func BenchmarkInitMw_Wrap(b *testing.B) {
|
||||
const devIDTarget = "dns.example.com"
|
||||
srvGrp := &agd.ServerGroup{
|
||||
TLS: &agd.TLS{
|
||||
DeviceIDWildcards: []string{"*." + devIDTarget},
|
||||
},
|
||||
DDR: &agd.DDR{
|
||||
DeviceTargets: stringutil.NewSet(),
|
||||
PublicTargets: stringutil.NewSet(),
|
||||
Enabled: true,
|
||||
},
|
||||
Name: agd.ServerGroupName("test_server_group"),
|
||||
Servers: []*agd.Server{{
|
||||
BindData: []*agd.ServerBindData{{
|
||||
AddrPort: netip.MustParseAddrPort("1.2.3.4:12345"),
|
||||
}, {
|
||||
AddrPort: netip.MustParseAddrPort("4.3.2.1:12345"),
|
||||
}},
|
||||
Protocol: agd.ProtoDoT,
|
||||
}},
|
||||
}
|
||||
|
||||
messages := agdtest.NewConstructor()
|
||||
|
||||
ipv4Hints := []netip.Addr{srvGrp.Servers[0].BindData[0].AddrPort.Addr()}
|
||||
ipv6Hints := []netip.Addr{netip.MustParseAddr("2001::1234")}
|
||||
|
||||
srvGrp.DDR.DeviceTargets.Add(devIDTarget)
|
||||
srvGrp.DDR.DeviceRecordTemplates = []*dns.SVCB{
|
||||
messages.NewDDRTemplate(agd.ProtoDoH, devIDTarget, "/dns", ipv4Hints, ipv6Hints, 443, 1),
|
||||
messages.NewDDRTemplate(agd.ProtoDoT, devIDTarget, "", ipv4Hints, ipv6Hints, 853, 1),
|
||||
messages.NewDDRTemplate(agd.ProtoDoQ, devIDTarget, "", ipv4Hints, ipv6Hints, 853, 1),
|
||||
}
|
||||
|
||||
mw := &initMw{
|
||||
messages: messages,
|
||||
fltGrp: &agd.FilteringGroup{},
|
||||
srvGrp: srvGrp,
|
||||
srv: srvGrp.Servers[0],
|
||||
geoIP: &agdtest.GeoIP{
|
||||
OnSubnetByLocation: func(
|
||||
_ agd.Country,
|
||||
_ agd.ASN,
|
||||
_ netutil.AddrFamily,
|
||||
) (n netip.Prefix, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnData: func(_ string, _ netip.Addr) (l *agd.Location, err error) {
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
errColl: &agdtest.ErrorCollector{
|
||||
OnCollect: func(_ context.Context, _ error) { panic("not implemented") },
|
||||
},
|
||||
}
|
||||
|
||||
prof := &agd.Profile{}
|
||||
dev := &agd.Device{}
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{
|
||||
TLSServerName: testDeviceID + ".dns.example.com",
|
||||
})
|
||||
|
||||
req := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Name: "example.net",
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
resp := new(dns.Msg).SetReply(req)
|
||||
|
||||
var handler dnsserver.Handler = dnsserver.HandlerFunc(func(
|
||||
ctx context.Context,
|
||||
rw dnsserver.ResponseWriter,
|
||||
req *dns.Msg,
|
||||
) (err error) {
|
||||
return rw.WriteMsg(ctx, req, resp)
|
||||
})
|
||||
|
||||
handler = mw.Wrap(handler)
|
||||
rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr)
|
||||
|
||||
mw.db = &agdtest.ProfileDB{
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
_ agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return prof, dev, nil
|
||||
},
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
b.Run("success", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
errSink = handler.ServeDNS(ctx, rw, req)
|
||||
}
|
||||
|
||||
assert.NoError(b, errSink)
|
||||
})
|
||||
|
||||
mw.db = &agdtest.ProfileDB{
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
_ agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return nil, nil, profiledb.ErrDeviceNotFound
|
||||
},
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
b.Run("not_found", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
errSink = handler.ServeDNS(ctx, rw, req)
|
||||
}
|
||||
|
||||
assert.NoError(b, errSink)
|
||||
})
|
||||
|
||||
ffReq := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Name: "use-application-dns.net.",
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
mw.db = &agdtest.ProfileDB{
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
_ agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return prof, dev, nil
|
||||
},
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
b.Run("firefox_canary", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
errSink = handler.ServeDNS(ctx, rw, ffReq)
|
||||
}
|
||||
|
||||
assert.NoError(b, errSink)
|
||||
})
|
||||
|
||||
ddrReq := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
// Check the worst case when wildcards are checked.
|
||||
Name: "_dns." + testDeviceID + ".dns.example.com.",
|
||||
Qtype: dns.TypeSVCB,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
devWithID := &agd.Device{
|
||||
ID: testDeviceID,
|
||||
}
|
||||
mw.db = &agdtest.ProfileDB{
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
_ agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return prof, devWithID, nil
|
||||
},
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
b.Run("ddr", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
errSink = handler.ServeDNS(ctx, rw, ddrReq)
|
||||
}
|
||||
|
||||
assert.NoError(b, errSink)
|
||||
})
|
||||
}
|
80
internal/dnssvc/internal/accessmw/access.go
Normal file
80
internal/dnssvc/internal/accessmw/access.go
Normal file
@ -0,0 +1,80 @@
|
||||
// Package accessmw contains the access middleware of the AdGuard DNS server.
|
||||
// It filters out the domain scanners and other requests by specified AdBlock
|
||||
// rules and IP subnets.
|
||||
package accessmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/access"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// type check
|
||||
var _ dnsserver.Middleware = (*Middleware)(nil)
|
||||
|
||||
// Middleware is the access middleware of the AdGuard DNS server.
|
||||
type Middleware struct {
|
||||
accessManager access.Interface
|
||||
}
|
||||
|
||||
// Config is the configuration structure for the access middleware. All fields
|
||||
// must be non-nil.
|
||||
type Config struct {
|
||||
AccessManager access.Interface
|
||||
}
|
||||
|
||||
// New returns a new access middleware. c must not be nil.
|
||||
func New(c *Config) (mw *Middleware) {
|
||||
return &Middleware{
|
||||
accessManager: c.AccessManager,
|
||||
}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ dnsserver.Middleware = (*Middleware)(nil)
|
||||
|
||||
// Wrap implements the [dnsserver.Middleware] interface for *Middleware
|
||||
func (mw *Middleware) Wrap(next dnsserver.Handler) (wrapped dnsserver.Handler) {
|
||||
f := func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "access mw: %w") }()
|
||||
|
||||
rAddr := netutil.NetAddrToAddrPort(rw.RemoteAddr()).Addr()
|
||||
if blocked, _ := mw.accessManager.IsBlockedIP(rAddr); blocked {
|
||||
metrics.AccessBlockedForSubnetTotal.Inc()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Assume that module dnsserver has already validated that the request
|
||||
// always has exactly one question for us.
|
||||
q := req.Question[0]
|
||||
if blocked := mw.accessManager.IsBlockedHost(normalizeDomain(q.Name), q.Qtype); blocked {
|
||||
metrics.AccessBlockedForHostTotal.Inc()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return next.ServeDNS(ctx, rw, req)
|
||||
}
|
||||
|
||||
return dnsserver.HandlerFunc(f)
|
||||
}
|
||||
|
||||
// normalizeDomain returns a lowercased version of the host without the final
|
||||
// dot, unless the host is ".", in which case it returns the unchanged host.
|
||||
// That is the special case to allow matching queries like:
|
||||
//
|
||||
// dig IN NS '.'
|
||||
func normalizeDomain(host string) (norm string) {
|
||||
if host == "." {
|
||||
return host
|
||||
}
|
||||
|
||||
return agdnet.NormalizeDomain(host)
|
||||
}
|
135
internal/dnssvc/internal/accessmw/access_test.go
Normal file
135
internal/dnssvc/internal/accessmw/access_test.go
Normal file
@ -0,0 +1,135 @@
|
||||
package accessmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/access"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/accessmw"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMiddleware_Wrap(t *testing.T) {
|
||||
am, accessErr := access.New([]string{
|
||||
"block.test",
|
||||
"UPPERCASE.test",
|
||||
"||block_aaaa.test^$dnstype=AAAA",
|
||||
}, []string{
|
||||
"1.1.1.1",
|
||||
"2.2.2.0/8",
|
||||
})
|
||||
require.NoError(t, accessErr)
|
||||
|
||||
amw := accessmw.New(&accessmw.Config{
|
||||
AccessManager: am,
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
wantResp assert.BoolAssertionFunc
|
||||
name string
|
||||
host string
|
||||
ip net.IP
|
||||
qtype uint16
|
||||
}{{
|
||||
ip: net.IP{1, 1, 1, 0},
|
||||
name: "pass_ip",
|
||||
host: "pass.test",
|
||||
qtype: dns.TypeA,
|
||||
wantResp: assert.True,
|
||||
}, {
|
||||
name: "block_ip",
|
||||
ip: net.IP{1, 1, 1, 1},
|
||||
host: "pass.test",
|
||||
qtype: dns.TypeA,
|
||||
wantResp: assert.False,
|
||||
}, {
|
||||
name: "pass_subnet",
|
||||
ip: net.IP{1, 2, 2, 2},
|
||||
host: "pass.test",
|
||||
qtype: dns.TypeA,
|
||||
wantResp: assert.True,
|
||||
}, {
|
||||
name: "block_subnet",
|
||||
ip: net.IP{2, 2, 2, 2},
|
||||
host: "pass.test",
|
||||
qtype: dns.TypeA,
|
||||
wantResp: assert.False,
|
||||
}, {
|
||||
wantResp: assert.True,
|
||||
name: "pass_domain",
|
||||
host: "pass.test",
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
wantResp: assert.False,
|
||||
name: "blocked_domain_A",
|
||||
host: "block.test",
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
wantResp: assert.False,
|
||||
name: "blocked_domain_HTTPS",
|
||||
host: "block.test",
|
||||
qtype: dns.TypeHTTPS,
|
||||
}, {
|
||||
wantResp: assert.False,
|
||||
name: "uppercase_domain",
|
||||
host: "uppercase.test",
|
||||
qtype: dns.TypeHTTPS,
|
||||
}, {
|
||||
wantResp: assert.True,
|
||||
name: "pass_qt",
|
||||
host: "block_aaaa.test",
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
wantResp: assert.False,
|
||||
name: "block_qt",
|
||||
host: "block_aaaa.test",
|
||||
qtype: dns.TypeAAAA,
|
||||
}}
|
||||
|
||||
var handler dnsserver.Handler = dnsserver.HandlerFunc(func(
|
||||
ctx context.Context,
|
||||
rw dnsserver.ResponseWriter,
|
||||
q *dns.Msg,
|
||||
) (_ error) {
|
||||
resp := dnsservertest.NewResp(
|
||||
dns.RcodeSuccess,
|
||||
q,
|
||||
dnsservertest.SectionAnswer{
|
||||
dnsservertest.NewA("test.domain", 0, netip.MustParseAddr("5.5.5.5")),
|
||||
},
|
||||
)
|
||||
|
||||
err := rw.WriteMsg(ctx, q, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rw := dnsserver.NewNonWriterResponseWriter(nil, &net.TCPAddr{IP: tc.ip, Port: 5357})
|
||||
req := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Name: tc.host,
|
||||
Qtype: tc.qtype,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
h := amw.Wrap(handler)
|
||||
err := h.ServeDNS(context.Background(), rw, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := rw.Msg()
|
||||
tc.wantResp(t, resp != nil)
|
||||
})
|
||||
}
|
||||
}
|
44
internal/dnssvc/internal/dnssvctest/dnssvctest.go
Normal file
44
internal/dnssvc/internal/dnssvctest/dnssvctest.go
Normal file
@ -0,0 +1,44 @@
|
||||
// Package dnssvctest contains common constants and utilities for the internal
|
||||
// DNS-service packages.
|
||||
package dnssvctest
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
)
|
||||
|
||||
// Timeout is the common timeout for tests.
|
||||
const Timeout time.Duration = 1 * time.Second
|
||||
|
||||
// String representations of the common IDs for tests.
|
||||
const (
|
||||
DeviceIDStr = "dev1234"
|
||||
ProfileIDStr = "prof1234"
|
||||
)
|
||||
|
||||
// DeviceID is the common device ID for tests.
|
||||
const DeviceID agd.DeviceID = DeviceIDStr
|
||||
|
||||
// ProfileID is the common profile ID for tests.
|
||||
const ProfileID agd.ProfileID = ProfileIDStr
|
||||
|
||||
// Common addresses for tests.
|
||||
var (
|
||||
ClientIP = net.IP{1, 2, 3, 4}
|
||||
RemoteAddr = &net.TCPAddr{
|
||||
IP: ClientIP,
|
||||
Port: 12345,
|
||||
}
|
||||
|
||||
ClientAddrPort = RemoteAddr.AddrPort()
|
||||
ClientAddr = ClientAddrPort.Addr()
|
||||
|
||||
ServerAddr = netip.MustParseAddr("5.6.7.8")
|
||||
LocalAddr = &net.TCPAddr{
|
||||
IP: ServerAddr.AsSlice(),
|
||||
Port: 54321,
|
||||
}
|
||||
)
|
@ -1,4 +1,4 @@
|
||||
package dnssvc
|
||||
package initial
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -16,8 +16,6 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Device ID Extraction
|
||||
|
||||
// deviceIDFromClientServerName extracts and validates a device ID. cliSrvName
|
||||
// is the server name as sent by the client. wildcards are the domain wildcards
|
||||
// for device ID detection.
|
@ -1,4 +1,4 @@
|
||||
package dnssvc
|
||||
package initial
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -9,13 +9,14 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/dnssvctest"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestService_Wrap_deviceID(t *testing.T) {
|
||||
func TestDeviceIDFromContext(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
cliSrvName string
|
||||
@ -46,8 +47,8 @@ func TestService_Wrap_deviceID(t *testing.T) {
|
||||
proto: agd.ProtoDoT,
|
||||
}, {
|
||||
name: "tls_device_id",
|
||||
cliSrvName: testDeviceID + ".dns.example.com",
|
||||
wantDeviceID: testDeviceID,
|
||||
cliSrvName: dnssvctest.DeviceIDStr + ".dns.example.com",
|
||||
wantDeviceID: dnssvctest.DeviceID,
|
||||
wantErrMsg: "",
|
||||
wildcards: []string{"*.dns.example.com"},
|
||||
proto: agd.ProtoDoT,
|
||||
@ -61,7 +62,7 @@ func TestService_Wrap_deviceID(t *testing.T) {
|
||||
proto: agd.ProtoDoT,
|
||||
}, {
|
||||
name: "tls_deep_subdomain",
|
||||
cliSrvName: "abc." + testDeviceID + ".dns.example.com",
|
||||
cliSrvName: "abc." + dnssvctest.DeviceIDStr + ".dns.example.com",
|
||||
wantDeviceID: "",
|
||||
wantErrMsg: "",
|
||||
wildcards: []string{"*.dns.example.com"},
|
||||
@ -79,8 +80,8 @@ func TestService_Wrap_deviceID(t *testing.T) {
|
||||
proto: agd.ProtoDoT,
|
||||
}, {
|
||||
name: "quic_device_id",
|
||||
cliSrvName: testDeviceID + ".dns.example.com",
|
||||
wantDeviceID: testDeviceID,
|
||||
cliSrvName: dnssvctest.DeviceIDStr + ".dns.example.com",
|
||||
wantDeviceID: dnssvctest.DeviceID,
|
||||
wantErrMsg: "",
|
||||
wildcards: []string{"*.dns.example.com"},
|
||||
proto: agd.ProtoDoQ,
|
||||
@ -93,8 +94,8 @@ func TestService_Wrap_deviceID(t *testing.T) {
|
||||
proto: agd.ProtoDoT,
|
||||
}, {
|
||||
name: "tls_device_id_subdomain_wildcard",
|
||||
cliSrvName: testDeviceID + ".sub.dns.example.com",
|
||||
wantDeviceID: testDeviceID,
|
||||
cliSrvName: dnssvctest.DeviceIDStr + ".sub.dns.example.com",
|
||||
wantDeviceID: dnssvctest.DeviceID,
|
||||
wantErrMsg: "",
|
||||
wildcards: []string{
|
||||
"*.dns.example.com",
|
||||
@ -117,7 +118,7 @@ func TestService_Wrap_deviceID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestService_Wrap_deviceIDHTTPS(t *testing.T) {
|
||||
func TestDeviceIDFromContext_https(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
path string
|
||||
@ -135,13 +136,13 @@ func TestService_Wrap_deviceIDHTTPS(t *testing.T) {
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "device_id",
|
||||
path: "/dns-query/" + testDeviceID,
|
||||
wantDeviceID: testDeviceID,
|
||||
path: "/dns-query/" + dnssvctest.DeviceIDStr,
|
||||
wantDeviceID: dnssvctest.DeviceID,
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "device_id_slash",
|
||||
path: "/dns-query/" + testDeviceID + "/",
|
||||
wantDeviceID: testDeviceID,
|
||||
path: "/dns-query/" + dnssvctest.DeviceIDStr + "/",
|
||||
wantDeviceID: dnssvctest.DeviceID,
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "bad_url",
|
||||
@ -150,10 +151,10 @@ func TestService_Wrap_deviceIDHTTPS(t *testing.T) {
|
||||
wantErrMsg: `http url device id check: bad path "/foo"`,
|
||||
}, {
|
||||
name: "extra",
|
||||
path: "/dns-query/" + testDeviceID + "/foo",
|
||||
path: "/dns-query/" + dnssvctest.DeviceIDStr + "/foo",
|
||||
wantDeviceID: "",
|
||||
wantErrMsg: `http url device id check: bad path "/dns-query/` + testDeviceID + `/foo": ` +
|
||||
`extra parts`,
|
||||
wantErrMsg: `http url device id check: bad path "/dns-query/` + dnssvctest.DeviceIDStr +
|
||||
`/foo": extra parts`,
|
||||
}, {
|
||||
name: "bad_device_id",
|
||||
path: "/dns-query/!!!",
|
||||
@ -186,7 +187,7 @@ func TestService_Wrap_deviceIDHTTPS(t *testing.T) {
|
||||
t.Run("domain_name", func(t *testing.T) {
|
||||
u := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: testDeviceID + ".dns.example.com",
|
||||
Host: dnssvctest.DeviceIDStr + ".dns.example.com",
|
||||
Path: "/dns-query",
|
||||
}
|
||||
|
||||
@ -202,11 +203,11 @@ func TestService_Wrap_deviceIDHTTPS(t *testing.T) {
|
||||
deviceID, err := deviceIDFromContext(ctx, proto, []string{"*.dns.example.com"})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, agd.DeviceID(testDeviceID), deviceID)
|
||||
assert.Equal(t, agd.DeviceID(dnssvctest.DeviceID), deviceID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestService_Wrap_deviceIDFromEDNS(t *testing.T) {
|
||||
func TestDeviceIDFromEDNS(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
opt dns.EDNS0
|
284
internal/dnssvc/internal/initial/initial.go
Normal file
284
internal/dnssvc/internal/initial/initial.go
Normal file
@ -0,0 +1,284 @@
|
||||
// Package initial contains the initial, outermost (except for ratelimit)
|
||||
// middleware of the AdGuard DNS server. It filters out the Firefox canary
|
||||
// domain logic, sets and resets the AD bit for further processing, as well as
|
||||
// puts as much information as it can into the context and request info.
|
||||
package initial
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdsync"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Middleware is the initial middleware of the AdGuard DNS server. This
|
||||
// middleware must be the most outer middleware apart from the ratelimit one.
|
||||
type Middleware struct {
|
||||
// messages is used to build the responses specific for the request's
|
||||
// context.
|
||||
messages *dnsmsg.Constructor
|
||||
|
||||
// fltGrp is the filtering group to which srv belongs.
|
||||
fltGrp *agd.FilteringGroup
|
||||
|
||||
// srvGrp is the server group to which srv belongs.
|
||||
srvGrp *agd.ServerGroup
|
||||
|
||||
// srv is the current server which serves the request.
|
||||
srv *agd.Server
|
||||
|
||||
// pool is the pool of [agd.RequestInfo] values.
|
||||
pool *agdsync.TypedPool[agd.RequestInfo]
|
||||
|
||||
// db is the database of user profiles and devices.
|
||||
db profiledb.Interface
|
||||
|
||||
// geoIP detects the location of the request source.
|
||||
geoIP geoip.Interface
|
||||
|
||||
// errColl collects and reports the errors considered non-critical.
|
||||
errColl agd.ErrorCollector
|
||||
}
|
||||
|
||||
// Config is the configuration structure for the initial middleware. All fields
|
||||
// must be non-nil.
|
||||
type Config struct {
|
||||
// messages is used to build the responses specific for a request's context.
|
||||
Messages *dnsmsg.Constructor
|
||||
|
||||
// FilteringGroup is the filtering group to which Server belongs.
|
||||
FilteringGroup *agd.FilteringGroup
|
||||
|
||||
// ServerGroup is the server group to which Server belongs.
|
||||
ServerGroup *agd.ServerGroup
|
||||
|
||||
// Server is the current server which serves the request.
|
||||
Server *agd.Server
|
||||
|
||||
// DB is the database of user profiles and devices.
|
||||
ProfileDB profiledb.Interface
|
||||
|
||||
// GeoIP detects the location of the request source.
|
||||
GeoIP geoip.Interface
|
||||
|
||||
// ErrColl collects and reports the errors considered non-critical.
|
||||
ErrColl agd.ErrorCollector
|
||||
}
|
||||
|
||||
// New returns a new initial middleware. c must not be nil.
|
||||
func New(c *Config) (mw *Middleware) {
|
||||
return &Middleware{
|
||||
messages: c.Messages,
|
||||
fltGrp: c.FilteringGroup,
|
||||
srvGrp: c.ServerGroup,
|
||||
srv: c.Server,
|
||||
pool: agdsync.NewTypedPool(func() (v *agd.RequestInfo) {
|
||||
return &agd.RequestInfo{}
|
||||
}),
|
||||
db: c.ProfileDB,
|
||||
geoIP: c.GeoIP,
|
||||
errColl: c.ErrColl,
|
||||
}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ dnsserver.Middleware = (*Middleware)(nil)
|
||||
|
||||
// Wrap implements the [dnsserver.Middleware] interface for *Middleware
|
||||
func (mw *Middleware) Wrap(next dnsserver.Handler) (wrapped dnsserver.Handler) {
|
||||
f := func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "init mw: %w") }()
|
||||
|
||||
// Save the actual value of the request AD and DO bits and set the AD
|
||||
// bit in the request to true, so that the upstream validates the data
|
||||
// and caches the actual value of the response AD bit. Restore it
|
||||
// later, depending on the request and response data.
|
||||
reqAD := req.AuthenticatedData
|
||||
reqDO := dnsmsg.IsDO(req)
|
||||
req.AuthenticatedData = true
|
||||
|
||||
// Assume that module dnsserver has already validated that the request
|
||||
// always has exactly one question for us.
|
||||
q := req.Question[0]
|
||||
qt := q.Qtype
|
||||
cl := q.Qclass
|
||||
|
||||
// Get the request's information, such as GeoIP data and user profiles.
|
||||
ri, err := mw.newRequestInfo(ctx, req, rw.LocalAddr(), rw.RemoteAddr(), q.Name, qt, cl)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because this is the main flow, and there is
|
||||
// already [errors.Annotate] here.
|
||||
return mw.processReqInfoErr(ctx, rw, req, err)
|
||||
}
|
||||
defer mw.pool.Put(ri)
|
||||
|
||||
if specHdlr, name := mw.reqInfoSpecialHandler(ri, cl); specHdlr != nil {
|
||||
optlog.Debug1("init mw: got req-info special handler %s", name)
|
||||
|
||||
// Don't wrap the error, because it's informative enough as is, and
|
||||
// because if handled is true, the main flow terminates here.
|
||||
return specHdlr(ctx, rw, req, ri)
|
||||
}
|
||||
|
||||
ctx = agd.ContextWithRequestInfo(ctx, ri)
|
||||
|
||||
// Record the response, restore the AD bit value in both the request and
|
||||
// the response, and write the response.
|
||||
nwrw := internal.MakeNonWriter(rw)
|
||||
err = next.ServeDNS(ctx, nwrw, req)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because this is the main flow, and there is
|
||||
// already errors.Annotate here.
|
||||
return err
|
||||
}
|
||||
|
||||
resp := nwrw.Msg()
|
||||
|
||||
// Following RFC 6840, set the AD bit in the response only when the
|
||||
// response is authenticated, and the request contained either a set DO
|
||||
// bit or a set AD bit.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc6840#section-5.8.
|
||||
resp.AuthenticatedData = resp.AuthenticatedData && (reqAD || reqDO)
|
||||
|
||||
err = rw.WriteMsg(ctx, req, resp)
|
||||
|
||||
return errors.Annotate(err, "writing resp: %w")
|
||||
}
|
||||
|
||||
return dnsserver.HandlerFunc(f)
|
||||
}
|
||||
|
||||
// newRequestInfo returns the new request information structure using the
|
||||
// middleware's configuration and values from ctx.
|
||||
func (mw *Middleware) newRequestInfo(
|
||||
ctx context.Context,
|
||||
req *dns.Msg,
|
||||
laddr net.Addr,
|
||||
raddr net.Addr,
|
||||
fqdn string,
|
||||
qt dnsmsg.RRType,
|
||||
cl dnsmsg.Class,
|
||||
) (ri *agd.RequestInfo, err error) {
|
||||
ri = mw.pool.Get()
|
||||
|
||||
// Use ri as an argument here to evaluate and save the non-nil value of ri
|
||||
// and prevent returns with an error from overwriting ri with nil.
|
||||
defer func(fromPool *agd.RequestInfo) {
|
||||
if err != nil {
|
||||
mw.pool.Put(fromPool)
|
||||
}
|
||||
}(ri)
|
||||
|
||||
// Clear all fields that must be set later.
|
||||
ri.Device = nil
|
||||
ri.Profile = nil
|
||||
ri.ECS = nil
|
||||
ri.Location = nil
|
||||
|
||||
// Put the host, server, and client IP data into the request information
|
||||
// immediately.
|
||||
ri.FilteringGroup = mw.fltGrp
|
||||
ri.Messages = mw.messages
|
||||
ri.RemoteIP = netutil.NetAddrToAddrPort(raddr).Addr()
|
||||
ri.ServerGroup = mw.srvGrp.Name
|
||||
ri.Server = mw.srv.Name
|
||||
ri.Host = agdnet.NormalizeDomain(fqdn)
|
||||
ri.QType = qt
|
||||
ri.QClass = cl
|
||||
|
||||
// As an optimization, put the request ID closer to the top of the context
|
||||
// stack.
|
||||
ri.ID, _ = agd.RequestIDFromContext(ctx)
|
||||
|
||||
// Add the GeoIP information, if any.
|
||||
err = mw.addLocation(ctx, ri, req)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add the profile information, if any.
|
||||
localIP := netutil.NetAddrToAddrPort(laddr).Addr()
|
||||
err = mw.addProfile(ctx, ri, req, localIP)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ri, nil
|
||||
}
|
||||
|
||||
// addLocation adds GeoIP location information about the client's remote address
|
||||
// as well as the EDNS Client Subnet information, if there is one, to ri. err
|
||||
// is not nil only if req contains a malformed EDNS Client Subnet option.
|
||||
func (mw *Middleware) addLocation(ctx context.Context, ri *agd.RequestInfo, req *dns.Msg) (err error) {
|
||||
ri.Location = mw.locationData(ctx, ri.RemoteIP, "client")
|
||||
|
||||
ecs, scope, err := dnsmsg.ECSFromMsg(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding ecs info: %w", err)
|
||||
} else if ecs != (netip.Prefix{}) {
|
||||
ri.ECS = &agd.ECS{
|
||||
Location: mw.locationData(ctx, ecs.Addr(), "ecs"),
|
||||
Subnet: ecs,
|
||||
Scope: scope,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// locationData returns the GeoIP location information about the IP address.
|
||||
// typ is the type of data being requested for error reporting and logging.
|
||||
func (mw *Middleware) locationData(ctx context.Context, ip netip.Addr, typ string) (l *agd.Location) {
|
||||
l, err := mw.geoIP.Data("", ip)
|
||||
if err != nil {
|
||||
// Consider GeoIP errors non-critical. Report and go on.
|
||||
agd.Collectf(ctx, mw.errColl, "init mw: getting geoip for %s ip: %w", typ, err)
|
||||
}
|
||||
|
||||
if l == nil {
|
||||
optlog.Debug2("init mw: no geoip for %s ip %s", typ, ip)
|
||||
} else {
|
||||
optlog.Debug4("init mw: found country/asn %q/%d for %s ip %s", l.Country, l.ASN, typ, ip)
|
||||
}
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
// processReqInfoErr processes the error returned by [Middleware.newRequestInfo]
|
||||
// and returns the properly handled and/or wrapped error.
|
||||
func (mw *Middleware) processReqInfoErr(
|
||||
ctx context.Context,
|
||||
rw dnsserver.ResponseWriter,
|
||||
req *dns.Msg,
|
||||
origErr error,
|
||||
) (err error) {
|
||||
var ecsErr dnsmsg.BadECSError
|
||||
if errors.As(origErr, &ecsErr) {
|
||||
// We've got a bad ECS option. Log and respond with a FORMERR
|
||||
// immediately.
|
||||
optlog.Debug1("init mw: %s", origErr)
|
||||
|
||||
writeErr := rw.WriteMsg(ctx, req, mw.messages.NewMsgFORMERR(req))
|
||||
writeErr = errors.Annotate(writeErr, "writing formerr resp: %w")
|
||||
|
||||
return errors.WithDeferred(origErr, writeErr)
|
||||
}
|
||||
|
||||
return origErr
|
||||
}
|
660
internal/dnssvc/internal/initial/initial_test.go
Normal file
660
internal/dnssvc/internal/initial/initial_test.go
Normal file
@ -0,0 +1,660 @@
|
||||
package initial_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/dnssvctest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/initial"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
func TestMiddleware_Wrap(t *testing.T) {
|
||||
const (
|
||||
resolverName = "dns.example.com"
|
||||
resolverFQDN = resolverName + "."
|
||||
|
||||
targetWithID = dnssvctest.DeviceIDStr + ".d." + resolverName + "."
|
||||
|
||||
ddrFQDN = initial.DDRDomain + "."
|
||||
|
||||
dohPath = "/dns-query"
|
||||
)
|
||||
|
||||
testDevice := &agd.Device{ID: dnssvctest.DeviceID}
|
||||
|
||||
srvs := map[agd.ServerName]*agd.Server{
|
||||
"dot": {
|
||||
TLS: &tls.Config{},
|
||||
BindData: []*agd.ServerBindData{{
|
||||
AddrPort: netip.MustParseAddrPort("1.2.3.4:12345"),
|
||||
}},
|
||||
Protocol: agd.ProtoDoT,
|
||||
},
|
||||
"doh": {
|
||||
TLS: &tls.Config{},
|
||||
BindData: []*agd.ServerBindData{{
|
||||
AddrPort: netip.MustParseAddrPort("5.6.7.8:54321"),
|
||||
}},
|
||||
Protocol: agd.ProtoDoH,
|
||||
},
|
||||
"dns": {
|
||||
BindData: []*agd.ServerBindData{{
|
||||
AddrPort: netip.MustParseAddrPort("2.4.6.8:53"),
|
||||
}},
|
||||
Protocol: agd.ProtoDNS,
|
||||
LinkedIPEnabled: true,
|
||||
},
|
||||
"dns_nolink": {
|
||||
BindData: []*agd.ServerBindData{{
|
||||
AddrPort: netip.MustParseAddrPort("2.4.6.8:53"),
|
||||
}},
|
||||
Protocol: agd.ProtoDNS,
|
||||
},
|
||||
}
|
||||
|
||||
srvGrp := &agd.ServerGroup{
|
||||
TLS: &agd.TLS{
|
||||
DeviceIDWildcards: []string{"*.d." + resolverName},
|
||||
},
|
||||
DDR: &agd.DDR{
|
||||
DeviceTargets: stringutil.NewSet(),
|
||||
PublicTargets: stringutil.NewSet(),
|
||||
Enabled: true,
|
||||
},
|
||||
Name: agd.ServerGroupName("test_server_group"),
|
||||
Servers: maps.Values(srvs),
|
||||
}
|
||||
|
||||
srvGrp.DDR.DeviceTargets.Add("d." + resolverName)
|
||||
srvGrp.DDR.PublicTargets.Add(resolverName)
|
||||
|
||||
geoIP := &agdtest.GeoIP{
|
||||
OnSubnetByLocation: func(
|
||||
_ agd.Country,
|
||||
_ agd.ASN,
|
||||
_ netutil.AddrFamily,
|
||||
) (_ netip.Prefix, _ error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnData: func(_ string, _ netip.Addr) (l *agd.Location, err error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
messages := agdtest.NewConstructor()
|
||||
|
||||
pubSVCBTmpls := []*dns.SVCB{
|
||||
messages.NewDDRTemplate(agd.ProtoDoH, resolverName, dohPath, nil, nil, 443, 1),
|
||||
messages.NewDDRTemplate(agd.ProtoDoT, resolverName, "", nil, nil, 853, 1),
|
||||
messages.NewDDRTemplate(agd.ProtoDoQ, resolverName, "", nil, nil, 853, 1),
|
||||
}
|
||||
|
||||
devSVCBTmpls := []*dns.SVCB{
|
||||
messages.NewDDRTemplate(agd.ProtoDoH, "d."+resolverName, dohPath, nil, nil, 443, 1),
|
||||
messages.NewDDRTemplate(agd.ProtoDoT, "d."+resolverName, "", nil, nil, 853, 1),
|
||||
messages.NewDDRTemplate(agd.ProtoDoQ, "d."+resolverName, "", nil, nil, 853, 1),
|
||||
}
|
||||
|
||||
srvGrp.DDR.PublicRecordTemplates = pubSVCBTmpls
|
||||
srvGrp.DDR.DeviceRecordTemplates = devSVCBTmpls
|
||||
|
||||
var handler dnsserver.Handler = dnsserver.HandlerFunc(func(
|
||||
_ context.Context,
|
||||
_ dnsserver.ResponseWriter,
|
||||
_ *dns.Msg,
|
||||
) (_ error) {
|
||||
// Make sure we haven't reached the following middleware.
|
||||
panic("not implemented")
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
device *agd.Device
|
||||
name string
|
||||
srv *agd.Server
|
||||
host string
|
||||
wantTarget string
|
||||
wantNum int
|
||||
qtype uint16
|
||||
}{{
|
||||
device: testDevice,
|
||||
name: "id",
|
||||
srv: srvs["dot"],
|
||||
host: ddrFQDN,
|
||||
wantTarget: targetWithID,
|
||||
wantNum: len(pubSVCBTmpls),
|
||||
qtype: dns.TypeSVCB,
|
||||
}, {
|
||||
device: testDevice,
|
||||
name: "id_specific",
|
||||
srv: srvs["dot"],
|
||||
host: initial.DDRLabel + "." + targetWithID,
|
||||
wantTarget: targetWithID,
|
||||
wantNum: len(devSVCBTmpls),
|
||||
qtype: dns.TypeSVCB,
|
||||
}, {
|
||||
device: nil,
|
||||
name: "no_id",
|
||||
srv: srvs["dot"],
|
||||
host: ddrFQDN,
|
||||
wantTarget: resolverFQDN,
|
||||
wantNum: len(pubSVCBTmpls),
|
||||
qtype: dns.TypeSVCB,
|
||||
}, {
|
||||
device: testDevice,
|
||||
name: "linked_ip",
|
||||
srv: srvs["dns"],
|
||||
host: ddrFQDN,
|
||||
wantTarget: targetWithID,
|
||||
wantNum: len(pubSVCBTmpls),
|
||||
qtype: dns.TypeSVCB,
|
||||
}, {
|
||||
device: testDevice,
|
||||
name: "no_linked_ip",
|
||||
srv: srvs["dns_nolink"],
|
||||
host: ddrFQDN,
|
||||
wantTarget: resolverFQDN,
|
||||
wantNum: len(pubSVCBTmpls),
|
||||
qtype: dns.TypeSVCB,
|
||||
}, {
|
||||
device: testDevice,
|
||||
name: "public_resolver_name",
|
||||
srv: srvs["dot"],
|
||||
host: initial.DDRLabel + "." + resolverFQDN,
|
||||
wantTarget: targetWithID,
|
||||
wantNum: len(pubSVCBTmpls),
|
||||
qtype: dns.TypeSVCB,
|
||||
}, {
|
||||
device: nil,
|
||||
name: "arpa_not_ddr_svcb",
|
||||
srv: srvs["dot"],
|
||||
host: dns.Fqdn(
|
||||
initial.DDRLabel + ".something.else." + initial.ResolverARPADomain,
|
||||
),
|
||||
wantTarget: "",
|
||||
wantNum: 0,
|
||||
qtype: dns.TypeSVCB,
|
||||
}, {
|
||||
device: nil,
|
||||
name: "arpa_ddr_not_svcb",
|
||||
srv: srvs["dot"],
|
||||
host: ddrFQDN,
|
||||
wantTarget: "",
|
||||
wantNum: 0,
|
||||
qtype: dns.TypeA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
db := &agdtest.ProfileDB{
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
_ agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return &agd.Profile{}, tc.device, nil
|
||||
},
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return nil, nil, profiledb.ErrDeviceNotFound
|
||||
},
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return &agd.Profile{}, tc.device, nil
|
||||
},
|
||||
}
|
||||
|
||||
mw := initial.New(&initial.Config{
|
||||
Messages: agdtest.NewConstructor(),
|
||||
FilteringGroup: &agd.FilteringGroup{},
|
||||
ServerGroup: srvGrp,
|
||||
Server: tc.srv,
|
||||
ProfileDB: db,
|
||||
GeoIP: geoIP,
|
||||
ErrColl: &agdtest.ErrorCollector{
|
||||
OnCollect: func(_ context.Context, _ error) { panic("not implemented") },
|
||||
},
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{
|
||||
TLSServerName: srvNameForProto(tc.device, resolverName, tc.srv.Protocol),
|
||||
})
|
||||
|
||||
rw := dnsserver.NewNonWriterResponseWriter(nil, dnssvctest.RemoteAddr)
|
||||
req := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Name: tc.host,
|
||||
Qtype: tc.qtype,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
h := mw.Wrap(handler)
|
||||
err := h.ServeDNS(ctx, rw, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := rw.Msg()
|
||||
require.NotNil(t, resp)
|
||||
|
||||
if tc.wantNum == 0 {
|
||||
assert.Empty(t, resp.Answer)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
assert.Len(t, resp.Answer, tc.wantNum)
|
||||
for _, rr := range resp.Answer {
|
||||
svcb := testutil.RequireTypeAssert[*dns.SVCB](t, rr)
|
||||
|
||||
assert.Equal(t, tc.wantTarget, svcb.Target)
|
||||
assert.Equal(t, tc.host, svcb.Hdr.Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// srvNameForProto returns a client's TLS server name based on the protocol and
|
||||
// other data.
|
||||
func srvNameForProto(dev *agd.Device, resolverName string, proto agd.Protocol) (srvName string) {
|
||||
switch proto {
|
||||
case agd.ProtoDoT, agd.ProtoDoQ:
|
||||
srvName = resolverName
|
||||
if dev != nil {
|
||||
srvName = string(dev.ID) + ".d." + srvName
|
||||
}
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
return srvName
|
||||
}
|
||||
|
||||
func TestMiddleware_Wrap_error(t *testing.T) {
|
||||
var handler dnsserver.Handler = dnsserver.HandlerFunc(func(
|
||||
_ context.Context,
|
||||
_ dnsserver.ResponseWriter,
|
||||
_ *dns.Msg,
|
||||
) (_ error) {
|
||||
// Make sure we haven't reached the following middleware.
|
||||
panic("not implemented")
|
||||
})
|
||||
|
||||
srvGrp := &agd.ServerGroup{
|
||||
Name: agd.ServerGroupName("test_server_group"),
|
||||
}
|
||||
|
||||
srv := &agd.Server{
|
||||
BindData: []*agd.ServerBindData{{
|
||||
AddrPort: netip.MustParseAddrPort("1.2.3.4:53"),
|
||||
}},
|
||||
Protocol: agd.ProtoDNS,
|
||||
LinkedIPEnabled: true,
|
||||
}
|
||||
|
||||
geoIP := &agdtest.GeoIP{
|
||||
OnSubnetByLocation: func(
|
||||
_ agd.Country,
|
||||
_ agd.ASN,
|
||||
_ netutil.AddrFamily,
|
||||
) (_ netip.Prefix, _ error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnData: func(_ string, _ netip.Addr) (l *agd.Location, err error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
const testError errors.Error = errors.Error("test error")
|
||||
|
||||
db := &agdtest.ProfileDB{
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
_ agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return nil, nil, testError
|
||||
},
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
|
||||
mw := initial.New(&initial.Config{
|
||||
Messages: agdtest.NewConstructor(),
|
||||
FilteringGroup: &agd.FilteringGroup{},
|
||||
ServerGroup: srvGrp,
|
||||
Server: srv,
|
||||
ProfileDB: db,
|
||||
GeoIP: geoIP,
|
||||
ErrColl: &agdtest.ErrorCollector{
|
||||
OnCollect: func(_ context.Context, _ error) { panic("not implemented") },
|
||||
},
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
rw := dnsserver.NewNonWriterResponseWriter(nil, dnssvctest.RemoteAddr)
|
||||
req := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Name: "www.example.com.",
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
h := mw.Wrap(handler)
|
||||
err := h.ServeDNS(ctx, rw, req)
|
||||
assert.ErrorIs(t, err, testError)
|
||||
}
|
||||
|
||||
var errSink error
|
||||
|
||||
func BenchmarkMiddleware_Wrap(b *testing.B) {
|
||||
const devIDTarget = "dns.example.com"
|
||||
srvGrp := &agd.ServerGroup{
|
||||
TLS: &agd.TLS{
|
||||
DeviceIDWildcards: []string{"*." + devIDTarget},
|
||||
},
|
||||
DDR: &agd.DDR{
|
||||
DeviceTargets: stringutil.NewSet(),
|
||||
PublicTargets: stringutil.NewSet(),
|
||||
Enabled: true,
|
||||
},
|
||||
Name: agd.ServerGroupName("test_server_group"),
|
||||
Servers: []*agd.Server{{
|
||||
BindData: []*agd.ServerBindData{{
|
||||
AddrPort: netip.MustParseAddrPort("1.2.3.4:12345"),
|
||||
}, {
|
||||
AddrPort: netip.MustParseAddrPort("4.3.2.1:12345"),
|
||||
}},
|
||||
Protocol: agd.ProtoDoT,
|
||||
}},
|
||||
}
|
||||
|
||||
messages := agdtest.NewConstructor()
|
||||
|
||||
geoIP := &agdtest.GeoIP{
|
||||
OnSubnetByLocation: func(
|
||||
_ agd.Country,
|
||||
_ agd.ASN,
|
||||
_ netutil.AddrFamily,
|
||||
) (n netip.Prefix, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnData: func(_ string, _ netip.Addr) (l *agd.Location, err error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
ipv4Hints := []netip.Addr{srvGrp.Servers[0].BindData[0].AddrPort.Addr()}
|
||||
ipv6Hints := []netip.Addr{netip.MustParseAddr("2001::1234")}
|
||||
|
||||
srvGrp.DDR.DeviceTargets.Add(devIDTarget)
|
||||
srvGrp.DDR.DeviceRecordTemplates = []*dns.SVCB{
|
||||
messages.NewDDRTemplate(agd.ProtoDoH, devIDTarget, "/dns", ipv4Hints, ipv6Hints, 443, 1),
|
||||
messages.NewDDRTemplate(agd.ProtoDoT, devIDTarget, "", ipv4Hints, ipv6Hints, 853, 1),
|
||||
messages.NewDDRTemplate(agd.ProtoDoQ, devIDTarget, "", ipv4Hints, ipv6Hints, 853, 1),
|
||||
}
|
||||
|
||||
prof := &agd.Profile{}
|
||||
dev := &agd.Device{}
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{
|
||||
TLSServerName: dnssvctest.DeviceIDStr + ".dns.example.com",
|
||||
})
|
||||
|
||||
req := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Name: "example.net",
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
resp := new(dns.Msg).SetReply(req)
|
||||
|
||||
var handler dnsserver.Handler = dnsserver.HandlerFunc(func(
|
||||
ctx context.Context,
|
||||
rw dnsserver.ResponseWriter,
|
||||
req *dns.Msg,
|
||||
) (err error) {
|
||||
return rw.WriteMsg(ctx, req, resp)
|
||||
})
|
||||
|
||||
rw := dnsserver.NewNonWriterResponseWriter(nil, dnssvctest.RemoteAddr)
|
||||
|
||||
b.Run("success", func(b *testing.B) {
|
||||
db := &agdtest.ProfileDB{
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
_ agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return prof, dev, nil
|
||||
},
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
|
||||
mw := initial.New(&initial.Config{
|
||||
Messages: messages,
|
||||
FilteringGroup: &agd.FilteringGroup{},
|
||||
ServerGroup: srvGrp,
|
||||
Server: srvGrp.Servers[0],
|
||||
ProfileDB: db,
|
||||
GeoIP: geoIP,
|
||||
ErrColl: &agdtest.ErrorCollector{
|
||||
OnCollect: func(_ context.Context, _ error) { panic("not implemented") },
|
||||
},
|
||||
})
|
||||
|
||||
h := mw.Wrap(handler)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
errSink = h.ServeDNS(ctx, rw, req)
|
||||
}
|
||||
|
||||
assert.NoError(b, errSink)
|
||||
})
|
||||
|
||||
b.Run("not_found", func(b *testing.B) {
|
||||
db := &agdtest.ProfileDB{
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
_ agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return nil, nil, profiledb.ErrDeviceNotFound
|
||||
},
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
|
||||
mw := initial.New(&initial.Config{
|
||||
Messages: messages,
|
||||
FilteringGroup: &agd.FilteringGroup{},
|
||||
ServerGroup: srvGrp,
|
||||
Server: srvGrp.Servers[0],
|
||||
ProfileDB: db,
|
||||
GeoIP: geoIP,
|
||||
ErrColl: &agdtest.ErrorCollector{
|
||||
OnCollect: func(_ context.Context, _ error) { panic("not implemented") },
|
||||
},
|
||||
})
|
||||
|
||||
h := mw.Wrap(handler)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
errSink = h.ServeDNS(ctx, rw, req)
|
||||
}
|
||||
|
||||
assert.NoError(b, errSink)
|
||||
})
|
||||
|
||||
b.Run("firefox_canary", func(b *testing.B) {
|
||||
db := &agdtest.ProfileDB{
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
_ agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return prof, dev, nil
|
||||
},
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
|
||||
ffReq := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Name: "use-application-dns.net.",
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
mw := initial.New(&initial.Config{
|
||||
Messages: messages,
|
||||
FilteringGroup: &agd.FilteringGroup{},
|
||||
ServerGroup: srvGrp,
|
||||
Server: srvGrp.Servers[0],
|
||||
ProfileDB: db,
|
||||
GeoIP: geoIP,
|
||||
ErrColl: &agdtest.ErrorCollector{
|
||||
OnCollect: func(_ context.Context, _ error) { panic("not implemented") },
|
||||
},
|
||||
})
|
||||
|
||||
h := mw.Wrap(handler)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
errSink = h.ServeDNS(ctx, rw, ffReq)
|
||||
}
|
||||
|
||||
assert.NoError(b, errSink)
|
||||
})
|
||||
|
||||
b.Run("ddr", func(b *testing.B) {
|
||||
devWithID := &agd.Device{
|
||||
ID: dnssvctest.DeviceID,
|
||||
}
|
||||
|
||||
db := &agdtest.ProfileDB{
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
_ agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return prof, devWithID, nil
|
||||
},
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
|
||||
ddrReq := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
// Check the worst case when wildcards are checked.
|
||||
Name: "_dns." + dnssvctest.DeviceIDStr + ".dns.example.com.",
|
||||
Qtype: dns.TypeSVCB,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
mw := initial.New(&initial.Config{
|
||||
Messages: messages,
|
||||
FilteringGroup: &agd.FilteringGroup{},
|
||||
ServerGroup: srvGrp,
|
||||
Server: srvGrp.Servers[0],
|
||||
ProfileDB: db,
|
||||
GeoIP: geoIP,
|
||||
ErrColl: &agdtest.ErrorCollector{
|
||||
OnCollect: func(_ context.Context, _ error) { panic("not implemented") },
|
||||
},
|
||||
})
|
||||
|
||||
h := mw.Wrap(handler)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
errSink = h.ServeDNS(ctx, rw, ddrReq)
|
||||
}
|
||||
|
||||
assert.NoError(b, errSink)
|
||||
})
|
||||
|
||||
// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
|
||||
// goos: linux
|
||||
// goarch: amd64
|
||||
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/initial
|
||||
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
|
||||
// BenchmarkMiddleware_Wrap/success-16 1970464 735.8 ns/op 72 B/op 2 allocs/op
|
||||
// BenchmarkMiddleware_Wrap/not_found-16 1469100 715.9 ns/op 48 B/op 1 allocs/op
|
||||
// BenchmarkMiddleware_Wrap/firefox_canary-16 1644410 861.9 ns/op 72 B/op 2 allocs/op
|
||||
// BenchmarkMiddleware_Wrap/ddr-16 252656 4810 ns/op 1408 B/op 45 allocs/op
|
||||
}
|
110
internal/dnssvc/internal/initial/profile.go
Normal file
110
internal/dnssvc/internal/initial/profile.go
Normal file
@ -0,0 +1,110 @@
|
||||
package initial
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// addProfile adds profile and device information, if any, to the request
|
||||
// information.
|
||||
func (mw *Middleware) addProfile(
|
||||
ctx context.Context,
|
||||
ri *agd.RequestInfo,
|
||||
req *dns.Msg,
|
||||
localIP netip.Addr,
|
||||
) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "getting profile from req: %w") }()
|
||||
|
||||
var id agd.DeviceID
|
||||
if p := mw.srv.Protocol; p.IsStdEncrypted() {
|
||||
// Assume that mw.srvGrp.TLS is non-nil if p.IsStdEncrypted() is true.
|
||||
wildcards := mw.srvGrp.TLS.DeviceIDWildcards
|
||||
id, err = deviceIDFromContext(ctx, mw.srv.Protocol, wildcards)
|
||||
} else if p == agd.ProtoDNS {
|
||||
id, err = deviceIDFromEDNS(req)
|
||||
} else {
|
||||
// No DeviceID for DNSCrypt yet.
|
||||
return nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
optlog.Debug3("init mw: got device id %q, raddr %s, and laddr %s", id, ri.RemoteIP, localIP)
|
||||
|
||||
prof, dev, byWhat, err := mw.profile(ctx, localIP, ri.RemoteIP, id)
|
||||
if err != nil {
|
||||
if !errors.Is(err, profiledb.ErrDeviceNotFound) {
|
||||
// Very unlikely, since there is only one error type currently
|
||||
// returned from the default profile DB.
|
||||
return fmt.Errorf("unexpected profiledb error: %w", err)
|
||||
}
|
||||
|
||||
optlog.Debug1("init mw: profile or device not found: %s", err)
|
||||
} else if prof.Deleted {
|
||||
optlog.Debug1("init mw: profile %s is deleted", prof.ID)
|
||||
} else {
|
||||
optlog.Debug3("init mw: found profile %s and device %s by %s", prof.ID, dev.ID, byWhat)
|
||||
|
||||
ri.Device, ri.Profile = dev, prof
|
||||
ri.Messages = dnsmsg.NewConstructor(prof.BlockingMode.Mode, prof.FilteredResponseTTL)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Constants for the parameter by which a device has been found.
|
||||
const (
|
||||
byDeviceID = "device id"
|
||||
byDedicatedIP = "dedicated ip"
|
||||
byLinkedIP = "linked ip"
|
||||
)
|
||||
|
||||
// profile finds the profile by the client data.
|
||||
func (mw *Middleware) profile(
|
||||
ctx context.Context,
|
||||
localIP netip.Addr,
|
||||
remoteIP netip.Addr,
|
||||
id agd.DeviceID,
|
||||
) (prof *agd.Profile, dev *agd.Device, byWhat string, err error) {
|
||||
if id != "" {
|
||||
prof, dev, err = mw.db.ProfileByDeviceID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
|
||||
return prof, dev, byDeviceID, nil
|
||||
}
|
||||
|
||||
if !mw.srv.LinkedIPEnabled {
|
||||
optlog.Debug1("init mw: not matching by linked or dedicated ip for server %s", mw.srv.Name)
|
||||
|
||||
return nil, nil, "", profiledb.ErrDeviceNotFound
|
||||
} else if p := mw.srv.Protocol; p != agd.ProtoDNS {
|
||||
optlog.Debug1("init mw: not matching by linked or dedicated ip for proto %v", p)
|
||||
|
||||
return nil, nil, "", profiledb.ErrDeviceNotFound
|
||||
}
|
||||
|
||||
byWhat = byDedicatedIP
|
||||
prof, dev, err = mw.db.ProfileByDedicatedIP(ctx, localIP)
|
||||
if errors.Is(err, profiledb.ErrDeviceNotFound) {
|
||||
byWhat = byLinkedIP
|
||||
prof, dev, err = mw.db.ProfileByLinkedIP(ctx, remoteIP)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
|
||||
return prof, dev, byWhat, nil
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user