Sync v2.10.0

This commit is contained in:
Andrey Meshkov 2024-11-08 16:26:22 +03:00
parent da0cb6fd0e
commit 87137bddcf
150 changed files with 8610 additions and 6094 deletions

View File

@ -7,6 +7,94 @@ The format is **not** based on [Keep a Changelog][kec], since the project **does
[kec]: https://keepachangelog.com/en/1.0.0/
[sem]: https://semver.org/spec/v2.0.0.html
## AGDNS-2484/ Build 886
- Property `type` of the `ratelimit` object has been moved to the underlying `allowlist` object. So replace this:
```yaml
ratelimit:
type: 'consul'
# …
allowlist:
# …
```
with this:
```yaml
ratelimit:
# …
allowlist:
type: 'consul'
# …
```
## AGDNS-2443 / Build 877
- The object `filters` has new properties: `ede_enabled`, and `sde_enabled`. So replace this:
```yaml
filters:
# …
```
with this:
```yaml
filters:
# …
ede_enabled: true
sde_enabled: true
```
## AGDNS-2456 / Build 873
- The environment variables `BACKEND_RATELIMIT_URL` and `BACKEND_RATELIMIT_API_KEY` have been added.
- Added the `type` property within the `ratelimit` object. So add it:
```yaml
ratelimit:
type: 'consul'
# …
```
## AGDNS-2431 / Build 872
- The objects `ratelimit.ipv4` and `ratelimit.ipv6` have been modified. Its `rps` properties have been replaced with the new properties `count` and `interval`. So replace this:
```yaml
ratelimit:
# …
ipv4:
rps: 30
ipv6:
rps: 300
```
with this:
```yaml
ratelimit:
# …
ipv4:
# …
count: 300
interval: 10s
ipv6:
# …
count: 3000
interval: 10s
```
Adjust the value and add new ones, if necessary.
## AGDNS-2457 / Build 871
- The environment variables `DNSCHECK_REMOTEKV_URL` and `DNSCHECK_REMOTEKV_API_KEY` have been added.
- The property `kv.type` within the `check` object now supports the `backend` value.
## AGDNS-2468 / Build 869
- The environment variable `PROFILES_MAX_RESP_SIZE` has been added. It sets the maximum size of the response from the profiles endpoint of the backend API. The default value is `8MB`.

View File

@ -24,7 +24,7 @@ BRANCH = $${BRANCH:-$$(git rev-parse --abbrev-ref HEAD)}
GOAMD64 = v1
GOPROXY = https://proxy.golang.org|direct
GOTELEMETRY = off
GOTOOLCHAIN = go1.23.1
GOTOOLCHAIN = go1.23.2
RACE = 0
REVISION = $${REVISION:-$$(git rev-parse --short HEAD)}
VERSION = 0

View File

@ -10,15 +10,19 @@ ratelimit:
response_size_estimate: 1KB
# Rate limit options for IPv4 addresses.
ipv4:
# Rate of requests per second for one subnet for IPv4 addresses.
rps: 30
# Requests per configured interval for one subnet for IPv4 addresses.
count: 300
# The time during which to count the number of requests.
interval: 10s
# The lengths of the subnet prefixes used to calculate rate limiter
# bucket keys for IPv4 addresses.
subnet_key_len: 24
# Rate limit options for IPv6 addresses.
ipv6:
# Rate of requests per second for one subnet for IPv6 addresses.
rps: 300
# Requests per configured interval for one subnet for IPv6 addresses.
count: 3000
# The time during which to count the number of requests.
interval: 10s
# The lengths of the subnet prefixes used to calculate rate limiter
# bucket keys for IPv6 addresses.
subnet_key_len: 48
@ -40,6 +44,9 @@ ratelimit:
- '127.0.0.1/24'
# Time between two updates of allow list.
refresh_interval: 1h
# Defines where the rate limiting settings are received from. Allowed
# values are "backend" and "consul".
type: 'consul'
# Configuration for the stream connection limiting.
connection_limit:
@ -167,8 +174,8 @@ geoip:
check:
# Domains to use for DNS checking.
kv:
# Defines the type of remote kay-value storage. Allowed values are
# "consul" and "redis".
# Defines the type of remote key-value storage. Allowed values are
# "backend", "consul", and "redis".
type: 'consul'
# For how long to keep the information about the client.
ttl: 30s
@ -313,6 +320,10 @@ filters:
enabled: true
# The size of the LRU cache of rule-list filtering results.
size: 10000
# Enable the Extended DNS Errors feature.
ede_enabled: true
# Enable the Structured DNS Errors feature. Requires ede_enabled: true.
sde_enabled: true
# Filtering groups are a set of different filtering configurations. These
# filtering configurations are then used by server_groups.

View File

@ -104,9 +104,13 @@ The `ratelimit` object has the following properties:
- <a href="#ratelimit-ipv4" id="ratelimit-ipv4" name="ratelimit-ipv4">`ipv4`</a>: The ipv4 configuration object. It has the following fields:
- <a href="#ratelimit-ipv4-rps" id="ratelimit-ipv4-rps" name="ratelimit-ipv4-rps">`rps`</a>: The rate of requests per second for one subnet. Requests above this are counted in the backoff count.
- <a href="#ratelimit-ipv4-count" id="ratelimit-ipv4-count" name="ratelimit-ipv4-count">`count`</a>: Requests per configured interval for one subnet for IPv4 addresses. Requests above this are counted in the backoff count.
**Example:** `30`.
**Example:** `300`.
- <a href="#ratelimit-ipv4-interval" id="ratelimit-ipv4-interval" name="ratelimit-ipv4-interval">`interval`</a>: The time during which to count the number of requests.
**Example:** `10s`.
- <a href="#ratelimit-ipv4-subnet_key_len" id="ratelimit-ipv4-subnet_key_len" name="ratelimit-ipv4-subnet_key_len">`ipv4-subnet_key_len`</a>: The length of the subnet prefix used to calculate rate limiter bucket keys.
@ -134,7 +138,11 @@ The `ratelimit` object has the following properties:
**Example:** `30s`.
For example, if `backoff_period` is `1m`, `backoff_count` is `10`, and `ipv4-rps` is `5`, a client (meaning all IP addresses within the subnet defined by `ipv4-subnet_key_len`) that made 15 requests in one second or 6 requests (one above `rps`) every second for 10 seconds within one minute, the client is blocked for `backoff_duration`.
- <a href="#ratelimit-allowlist-type" id="ratelimit-allowlist-type" name="ratelimit-allowlist-type">`type`</a>: Defines where the rate limit settings are received from. Allowed values are `backend` and `consul`.
**Example:** `consul`.
For example, if `backoff_period` is `1m`, `backoff_count` is `10`, `ipv4-count` is `5`, and `ipv4-interval` is `1s`, a client (meaning all IP addresses within the subnet defined by `ipv4-subnet_key_len`) that made 15 requests in one second or 6 requests (one above `rps`) every second for 10 seconds within one minute, the client is blocked for `backoff_duration`.
### <a href="#ratelimit-connection_limit" id="ratelimit-connection_limit" name="ratelimit-connection_limit">Stream connection limit</a>
@ -370,12 +378,14 @@ The `check` object has the following properties:
- <a href="#check_kv" id="check_kv" name="check_kv">`kv`</a>: Remote key-value storage settings. It has the following properties:
- <a href="#check-kv-type" id="check-kv-type" name="check-kv-type">`type`</a>: Type of the remote KV storage. Allowed values are `consul` and `redis`.
- <a href="#check-kv-type" id="check-kv-type" name="check-kv-type">`type`</a>: Type of the remote KV storage. Allowed values are `backend`, `consul`, and `redis`.
**Example:** `consul`.
- <a href="#check-kv-ttl" id="check-kv-ttl" name="check-kv-ttl">`ttl`</a>: For how long to keep the information about a single user in remote KV, as a human-readable duration.
For `backend`, the TTL must be greater than `0s`.
For `consul`, the TTL must be between `10s` and `1d`. Note that the actual TTL can be up to twice as long.
For `redis`, the TTL must be greater than or equal to `1ms`.
@ -592,6 +602,14 @@ The `filters` object has the following properties:
**Example:** `10000`.
- <a href="#filters-ede_enabled" id="filters-ede_enabled" name="filters-ede_enabled">`ede_enabled`</a>: Shows if Extended DNS Error codes should be added.
**Example:** `true`.
- <a href="#filters-sde_enabled" id="filters-sde_enabled" name="filters-sde_enabled">`sde_enabled`</a>: Shows if the experimental Structured DNS Errors feature should be enabled. `ede_enabled` must be `true` to enable SDE.
**Example:** `true`.
[env-blocked_services]: environment.md#BLOCKED_SERVICE_INDEX_URL
## <a href="#filtering_groups" id="filtering_groups" name="filtering_groups">Filtering groups</a>

View File

@ -159,7 +159,7 @@ You'll need to supply the following:
See the [external HTTP API documentation][externalhttp].
You may use `go run ./scripts/backend` to start mock GRPC server for `BILLSTAT_URL` and `PROFILES_URL` endpoints.
You may use `go run ./scripts/backend` to start mock GRPC server for `BACKEND_PROFILES_URL`, `BILLSTAT_URL`, `DNSCHECK_REMOTEKV_URL`, and `PROFILES_URL` endpoints.
You may need to change the listen ports in `config.yaml` which are less than 1024 to some other ports. Otherwise, `sudo` or `doas` is required to run `AdGuardDNS`.

View File

@ -6,6 +6,8 @@ AdGuard DNS uses [environment variables][wiki-env] to store some of the more sen
- [`ADULT_BLOCKING_ENABLED`](#ADULT_BLOCKING_ENABLED)
- [`ADULT_BLOCKING_URL`](#ADULT_BLOCKING_URL)
- [`BACKEND_RATELIMIT_API_KEY`](#BACKEND_RATELIMIT_API_KEY)
- [`BACKEND_RATELIMIT_URL`](#BACKEND_RATELIMIT_URL)
- [`BILLSTAT_API_KEY`](#BILLSTAT_API_KEY)
- [`BILLSTAT_URL`](#BILLSTAT_URL)
- [`BLOCKED_SERVICE_ENABLED`](#BLOCKED_SERVICE_ENABLED)
@ -14,6 +16,8 @@ AdGuard DNS uses [environment variables][wiki-env] to store some of the more sen
- [`CONSUL_ALLOWLIST_URL`](#CONSUL_ALLOWLIST_URL)
- [`CONSUL_DNSCHECK_KV_URL`](#CONSUL_DNSCHECK_KV_URL)
- [`CONSUL_DNSCHECK_SESSION_URL`](#CONSUL_DNSCHECK_SESSION_URL)
- [`DNSCHECK_REMOTEKV_API_KEY`](#DNSCHECK_REMOTEKV_API_KEY)
- [`DNSCHECK_REMOTEKV_URL`](#DNSCHECK_REMOTEKV_URL)
- [`FILTER_CACHE_PATH`](#FILTER_CACHE_PATH)
- [`FILTER_INDEX_URL`](#FILTER_INDEX_URL)
- [`GENERAL_SAFE_ENABLED`](#GENERAL_SAFE_SEARCH_ENABLED)
@ -62,6 +66,21 @@ The HTTP(S) URL of source list of rules for adult blocking filter.
**Default:** No default value, the variable is required if `ADULT_BLOCKING_ENABLED` is set to `1`.
## <a href="#BACKEND_RATELIMIT_API_KEY" id="BACKEND_RATELIMIT_API_KEY" name="BACKEND_RATELIMIT_API_KEY">`BACKEND_RATELIMIT_API_KEY`</a>
The API key to use when authenticating requests to the backend rate limiter API, if any. The API key should be valid as defined by [RFC 6750].
**Default:** **Unset.**
## <a href="#BACKEND_RATELIMIT_URL" id="BACKEND_RATELIMIT_URL" name="BACKEND_RATELIMIT_URL">`BACKEND_RATELIMIT_URL`</a>
The base backend URL for backend rate limiter. Supports gRPC(S) (`grpc://` and `grpcs://`) URLs. See the [external API requirements section][ext-backend-ratelimit].
**Default:** No default value, the variable is required if the [type][conf-ratelimit-type] of rate limiter is `backend` in the configuration file.
[conf-ratelimit-type]: configuration.md#ratelimit-type
[ext-backend-ratelimit]: externalhttp.md#backend-ratelimit
## <a href="#BILLSTAT_API_KEY" id="BILLSTAT_API_KEY" name="BILLSTAT_API_KEY">`BILLSTAT_API_KEY`</a>
The API key to use when authenticating queries to the billing statistics API, if any. The API key should be valid as defined by [RFC 6750].
@ -72,7 +91,7 @@ The API key to use when authenticating queries to the billing statistics API, if
## <a href="#BILLSTAT_URL" id="BILLSTAT_URL" name="BILLSTAT_URL">`BILLSTAT_URL`</a>
The base backend URL for backend billing statistics uploader API. Supports gRPC(S) (`grpc://` and`grpcs://`) URLs. See the [external HTTP API requirements section][ext-billstat].
The base backend URL for backend billing statistics uploader API. Supports gRPC(S) (`grpc://` and `grpcs://`) URLs. See the [external HTTP API requirements section][ext-billstat].
**Default:** No default value, the variable is required if there is at least one [server group][conf-sg] with profiles enabled.
@ -103,7 +122,7 @@ The path to the configuration file.
The HTTP(S) URL of the Consul instance serving the dynamic part of the rate-limit allowlist. See the [external HTTP API requirements section][ext-consul] on the expected format of the response.
**Default:** No default value, the variable is **required.**
**Default:** No default value, the variable is required if the [type][conf-ratelimit-type] of rate limiter is `consul` in the configuration file.
[ext-consul]: externalhttp.md#consul
@ -123,6 +142,20 @@ The HTTP(S) URL of the session API of the Consul instance used as a key-value da
**Example:** `http://localhost:8500/v1/session/create`
## <a href="#DNSCHECK_REMOTEKV_API_KEY" id="DNSCHECK_REMOTEKV_API_KEY" name="DNSCHECK_REMOTEKV_API_KEY">`DNSCHECK_REMOTEKV_API_KEY`</a>
The API key to use when authenticating queries to the backend key-value database API, if any. The API key should be valid as defined by [RFC 6750].
**Default:** **Unset.**
## <a href="#DNSCHECK_REMOTEKV_URL" id="DNSCHECK_REMOTEKV_URL" name="DNSCHECK_REMOTEKV_URL">`DNSCHECK_REMOTEKV_URL`</a>
The base backend URL used as a key-value database for the DNS server checking. Supports gRPC(S) (`grpc://` and`grpcs://`) URLs. See the [external API requirements section][ext-backend-dnscheck].
**Default:** **Unset.**
[ext-backend-dnscheck]: externalhttp.md#backend-dnscheck
## <a href="#FILTER_CACHE_PATH" id="FILTER_CACHE_PATH" name="FILTER_CACHE_PATH">`FILTER_CACHE_PATH`</a>
The path to the directory used to store the cached version of all filters and filter indexes.
@ -238,11 +271,11 @@ The profile cache is read on start and is later updated on every [full refresh][
The maximum size of the response from the profiles API in a human-readable format.
**Default:** `8MB`.
**Default:** `64MB`.
## <a href="#PROFILES_URL" id="PROFILES_URL" name="PROFILES_URL">`PROFILES_URL`</a>
The base backend URL for profiles API. Supports gRPC(S) (`grpc://` and`grpcs://`) URLs. See the [external API requirements section][ext-profiles].
The base backend URL for profiles API. Supports gRPC(S) (`grpc://` and `grpcs://`) URLs. See the [external API requirements section][ext-profiles].
**Default:** No default value, the variable is required if there is at least one [server group][conf-sg] with profiles enabled.
@ -252,7 +285,7 @@ The base backend URL for profiles API. Supports gRPC(S) (`grpc://` and`grpcs://`
Redis server address. Can be an IP address or a hostname.
**Default:** No default value, the variable if required if the [type][conf-check-kv-type] of remote KV storage for DNS server checking is `redis` in the configuration file.
**Default:** No default value, the variable is required if the [type][conf-check-kv-type] of remote KV storage for DNS server checking is `redis` in the configuration file.
[conf-check-kv-type]: configuration.md#check-kv-type

View File

@ -10,7 +10,9 @@ AdGuard DNS uses information from external HTTP APIs for filtering and other pie
## Contents
- [Backend billing statistics](#backend-billstat)
- [Backend DNSCheck service](#backend-dnscheck)
- [Backend profiles service](#backend-profiles)
- [Backend ratelimit service](#backend-ratelimit)
- [Consul key-value storage](#consul)
- [Filtering](#filters)
- [Blocked services](#filters-blocked-services)
@ -28,6 +30,15 @@ This service is disabled when all server groups have property [`profiles_enabled
[env-billstat_url]: environment.md#BILLSTAT_URL
[conf-srvgrp-prof]: configuration.md#sg-*-profiles_enabled
## <a href="#backend-dnscheck" id="backend-dnscheck" name="backend-dnscheck">Backend DNSCheck service</a>
This is the service to which the [`DNSCHECK_REMOTEKV_URL`][env-dnscheck_remotekv_url] environment variable points. Supports gRPC(s) URLs. The service must correspond to `./internal/backendpb/dns.proto`.
This service is only enabled when the `check.kv` object has the [`type`][conf-check-kv-type] property set to `backend`.
[env-dnscheck_remotekv_url]: environment.md#DNSCHECK_REMOTEKV_URL
[conf-check-kv-type]: configuration.md#check-kv-type
## <a href="#backend-profiles" id="backend-profiles" name="backend-profiles">Backend profiles service</a>
This is the service to which the [`PROFILES_URL`][env-profiles_url] environment variable points. Supports gRPC(s) URLs. The service must correspond to `./internal/backendpb/dns.proto`.
@ -36,6 +47,15 @@ This service is disabled when all server groups have property [`profiles_enabled
[env-profiles_url]: environment.md#PROFILES_URL
## <a href="#backend-ratelimit" id="backend-ratelimit" name="backend-ratelimit">Backend ratelimit_service</a>
This is the service to which the [`BACKEND_RATELIMIT_URL`][env-backend_ratelimit_url] environment variable points. Supports gRPC(s) URLs. The service must correspond to `./internal/backendpb/dns.proto`.
This service is only enabled when the `ratelimit` object has the [`type`][conf-ratelimit-type] property set to `backend`.
[conf-ratelimit-type]: configuration.md#ratelimit-type
[env-backend_ratelimit_url]: environment.md#BACKEND_RATELIMIT_URL
## <a href="#consul" id="consul" name="consul">Consul key-value storage</a>
A [Consul][consul-io] service can be used for the DNS server check and dynamic rate-limit allowlist features. Currently used endpoints can be seen in the documentation of the [`CONSUL_ALLOWLIST_URL`][env-consul-allowlist], [`CONSUL_DNSCHECK_KV_URL`][env-consul-dnscheck-kv], and [`CONSUL_DNSCHECK_SESSION_URL`][env-consul-dnscheck-session] environment variables.

43
go.mod
View File

@ -1,34 +1,34 @@
module github.com/AdguardTeam/AdGuardDNS
go 1.23.1
go 1.23.2
require (
github.com/AdguardTeam/AdGuardDNS/internal/dnsserver v0.0.0-20240607112746-5690301129fe
github.com/AdguardTeam/golibs v0.28.0
github.com/AdguardTeam/urlfilter v0.19.0
github.com/AdguardTeam/golibs v0.30.1
github.com/AdguardTeam/urlfilter v0.20.0
github.com/ameshkov/dnscrypt/v2 v2.3.0
github.com/axiomhq/hyperloglog v0.2.0
github.com/bluele/gcache v0.0.2
github.com/c2h5oh/datasize v0.0.0-20231215233829-aa82cc1e6500
github.com/caarlos0/env/v7 v7.1.0
github.com/getsentry/sentry-go v0.28.1
github.com/getsentry/sentry-go v0.29.1
github.com/gomodule/redigo v1.9.2
github.com/google/renameio/v2 v2.0.0
github.com/miekg/dns v1.1.62
github.com/oschwald/maxminddb-golang v1.13.1
github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible
github.com/prometheus/client_golang v1.20.1
github.com/prometheus/client_golang v1.20.5
github.com/prometheus/client_model v0.6.1
github.com/prometheus/common v0.55.0
github.com/quic-go/quic-go v0.47.0
github.com/prometheus/common v0.60.0
github.com/quic-go/quic-go v0.48.1
github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.27.0
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0
golang.org/x/net v0.29.0
golang.org/x/sys v0.25.0
golang.org/x/time v0.6.0
google.golang.org/grpc v1.65.0
google.golang.org/protobuf v1.34.2
golang.org/x/crypto v0.28.0
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
golang.org/x/net v0.30.0
golang.org/x/sys v0.26.0
golang.org/x/time v0.7.0
google.golang.org/grpc v1.67.1
google.golang.org/protobuf v1.35.1
gopkg.in/yaml.v2 v2.4.0
)
@ -41,24 +41,21 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140 // indirect
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
github.com/google/pprof v0.0.0-20240929191954-255acd752d31 // indirect
github.com/klauspost/compress v1.17.9 // indirect
github.com/google/pprof v0.0.0-20241023014458-598669927662 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/onsi/ginkgo/v2 v2.20.2 // indirect
github.com/panjf2000/ants/v2 v2.10.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/quic-go/qpack v0.5.1 // indirect
go.uber.org/mock v0.4.0 // indirect
go.uber.org/mock v0.5.0 // indirect
golang.org/x/mod v0.21.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/text v0.18.0 // indirect
golang.org/x/tools v0.25.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240822170219-fc7c04adadcd // indirect
golang.org/x/text v0.19.0 // indirect
golang.org/x/tools v0.26.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
replace github.com/AdguardTeam/AdGuardDNS/internal/dnsserver => ./internal/dnsserver
// TODO(a.garipov): Remove once https://github.com/quic-go/quic-go/pull/4685 is merged.
replace github.com/quic-go/quic-go => github.com/ainar-g/quic-go v0.0.0-20240930125330-446bd86056fd

76
go.sum
View File

@ -1,13 +1,11 @@
github.com/AdguardTeam/golibs v0.28.0 h1:SK1q8SqkkJ/61pp2abTmio90S4QpteYK9rtgROfnrb4=
github.com/AdguardTeam/golibs v0.28.0/go.mod h1:iWdjXPCwmK2g2FKIb/OwEPnovSXeMqRhI8FWLxF5oxE=
github.com/AdguardTeam/urlfilter v0.19.0 h1:q7eH13+yNETlpD/VD3u5rLQOripcUdEktqZFy+KiQLk=
github.com/AdguardTeam/urlfilter v0.19.0/go.mod h1:+N54ZvxqXYLnXuvpaUhK2exDQW+djZBRSb6F6j0rkBY=
github.com/AdguardTeam/golibs v0.30.1 h1:/yv7dq2h7WXw/jTDxkE3FP9zHerRT+i03PZRHJX4fPU=
github.com/AdguardTeam/golibs v0.30.1/go.mod h1:FkwcNQEJoGsgDGXcalrVa/4gWbE68KsmE2guXWtBQUE=
github.com/AdguardTeam/urlfilter v0.20.0 h1:X32qiuVCVd8WDYCEsbdZKfXMzwdVqrdulamtUi4rmzs=
github.com/AdguardTeam/urlfilter v0.20.0/go.mod h1:gjrywLTxfJh6JOkwi9SU+frhP7kVVEZ5exFGkR99qpk=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw=
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635/go.mod h1:lmLxL+FV291OopO93Bwf9fQLQeLyt33VJRUg5VJ30us=
github.com/ainar-g/quic-go v0.0.0-20240930125330-446bd86056fd h1:mw4LqrCiv3vcKuCxBRg7kA17xfHKM+9hZgFWmyhe/AY=
github.com/ainar-g/quic-go v0.0.0-20240930125330-446bd86056fd/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
github.com/ameshkov/dnscrypt/v2 v2.3.0 h1:pDXDF7eFa6Lw+04C0hoMh8kCAQM8NwUdFEllSP2zNLs=
github.com/ameshkov/dnscrypt/v2 v2.3.0/go.mod h1:N5hDwgx2cNb4Ay7AhvOSKst+eUiOZ/vbKRO9qMpQttE=
github.com/ameshkov/dnsstamps v1.0.3 h1:Srzik+J9mivH1alRACTbys2xOxs0lRH9qnTA7Y1OYVo=
@ -29,8 +27,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140 h1:y7y0Oa6UawqTFPCDw9JG6pdKt4F9pAhHv0B7FMGaGD0=
github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw=
github.com/getsentry/sentry-go v0.28.1 h1:zzaSm/vHmGllRM6Tpx1492r0YDzauArdBfkJRtY6P5k=
github.com/getsentry/sentry-go v0.28.1/go.mod h1:1fQZ+7l7eeJ3wYi82q5Hg8GqAPgefRq+FP/QhafYVgg=
github.com/getsentry/sentry-go v0.29.1 h1:DyZuChN8Hz3ARxGVV8ePaNXh1dQ7d76AiB117xcREwA=
github.com/getsentry/sentry-go v0.29.1/go.mod h1:x3AtIzN01d6SiWkderzaH28Tm0lgkafpJ5Bm3li39O0=
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
@ -43,12 +41,12 @@ github.com/gomodule/redigo v1.9.2 h1:HrutZBLhSIU8abiSfW8pj8mPhOyMYjZT/wcA4/L9L9s
github.com/gomodule/redigo v1.9.2/go.mod h1:KsU3hiK/Ay8U42qpaJk+kuNa3C+spxapWpM+ywhcgtw=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/pprof v0.0.0-20240929191954-255acd752d31 h1:LcRdQWywSgfi5jPsYZ1r2avbbs5IQ5wtyhMBCcokyo4=
github.com/google/pprof v0.0.0-20240929191954-255acd752d31/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
github.com/google/pprof v0.0.0-20241023014458-598669927662 h1:SKMkD83p7FwUqKmBsPdLHF5dNyxq3jOWwu9w9UyH5vA=
github.com/google/pprof v0.0.0-20241023014458-598669927662/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg=
github.com/google/renameio/v2 v2.0.0/go.mod h1:BtmJXm5YlszgC+TD4HOEEUFgkJP3nLxehU6hfe7jRt4=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
@ -79,16 +77,18 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/prometheus/client_golang v1.20.1 h1:IMJXHOD6eARkQpxo8KkhgEVFlBNm+nkrFUyGlIu7Na8=
github.com/prometheus/client_golang v1.20.1/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y=
github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc=
github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8=
github.com/prometheus/common v0.60.0 h1:+V9PAREWNvJMAuJ1x1BaWl9dewMW4YrHZQbx0sJNllA=
github.com/prometheus/common v0.60.0/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
github.com/quic-go/quic-go v0.48.1 h1:y/8xmfWI9qmGTc+lBr4jKRUWLGSlSigv847ULJ4hYXA=
github.com/quic-go/quic-go v0.48.1/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI=
@ -109,33 +109,33 @@ github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+F
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY=
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224=
golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE=
golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240822170219-fc7c04adadcd h1:6TEm2ZxXoQmFWFlt1vNxvVOa1Q0dXFQD1m/rYjXmS0E=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240822170219-fc7c04adadcd/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU=
google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc=
google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ=
golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0=
google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 h1:zciRKQ4kBpFgpfC5QQCVtnnNAcLIqweL7plyZRQHVpI=
google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI=
google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E=
google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA=
google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

View File

@ -1,4 +1,4 @@
go 1.23.1
go 1.23.2
use (
.

View File

@ -1,5 +1,6 @@
cel.dev/expr v0.15.0 h1:O1jzfJCQBfL5BFoYktaxwIhuttaQPsVWerH9/EEKx0w=
cel.dev/expr v0.15.0/go.mod h1:TRSuuV7DlVCE/uwv5QbAiW/v8l5O8C4eEPHeu7gf7Sg=
cel.dev/expr v0.16.0/go.mod h1:TRSuuV7DlVCE/uwv5QbAiW/v8l5O8C4eEPHeu7gf7Sg=
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
@ -45,6 +46,7 @@ cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGB
cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA=
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY=
cloud.google.com/go/contactcenterinsights v1.13.0/go.mod h1:ieq5d5EtHsu8vhe2y3amtZ+BE+AQwX5qAy7cpo0POsI=
cloud.google.com/go/container v1.31.0/go.mod h1:7yABn5s3Iv3lmw7oMmyGbeV6tQj86njcTijkkGuvdZA=
cloud.google.com/go/containeranalysis v0.11.4/go.mod h1:cVZT7rXYBS9NG1rhQbWL9pWbXCKHWJPYraE8/FTSYPE=
@ -156,6 +158,9 @@ github.com/AdguardTeam/golibs v0.19.0/go.mod h1:3WunclLLfrVAq7fYQRhd6f168FHOEMss
github.com/AdguardTeam/golibs v0.25.2/go.mod h1:HaTyS2wCbxFudjht9N/+/Qf1b5cMad2BAYSwe7DPCXI=
github.com/AdguardTeam/golibs v0.25.3 h1:A06JZGSuAhAC0uq/s7IlNsv/V8TyNJfLalB0vhkd1vA=
github.com/AdguardTeam/golibs v0.25.3/go.mod h1:HaTyS2wCbxFudjht9N/+/Qf1b5cMad2BAYSwe7DPCXI=
github.com/AdguardTeam/golibs v0.30.0/go.mod h1:vjw1OVZG6BYyoqGRY88U4LCJLOMfhBFhU0UJBdaSAuQ=
github.com/AdguardTeam/golibs v0.30.1 h1:/yv7dq2h7WXw/jTDxkE3FP9zHerRT+i03PZRHJX4fPU=
github.com/AdguardTeam/golibs v0.30.1/go.mod h1:FkwcNQEJoGsgDGXcalrVa/4gWbE68KsmE2guXWtBQUE=
github.com/AdguardTeam/gomitmproxy v0.2.0 h1:rvCOf17pd1/CnMyMQW891zrEiIQBpQ8cIGjKN9pinUU=
github.com/AdguardTeam/gomitmproxy v0.2.1 h1:p9gr8Er1TYvf+7ic81Ax1sZ62UNCsMTZNbm7tC59S9o=
github.com/AdguardTeam/gomitmproxy v0.2.1/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
@ -245,6 +250,7 @@ github.com/cncf/xds/go v0.0.0-20240318125728-8a4994d93e50 h1:DBmgJDC9dTfkVyGgipa
github.com/cncf/xds/go v0.0.0-20240318125728-8a4994d93e50/go.mod h1:5e1+Vvlzido69INQaVO6d87Qn543Xr6nooe9Kz7oBFM=
github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b h1:ga8SEFjZ60pxLcmhnThWgvH2wg8376yUJmPhEH4H3kw=
github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8=
github.com/cncf/xds/go v0.0.0-20240723142845-024c85f92f20/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8=
github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0 h1:sDMmm+q/3+BukdIpxwO365v/Rbspp2Nt5XntgQRXq8Q=
github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0/go.mod h1:4Zcjuz89kmFXt9morQgcfYZAYZ5n8WHjt81YYWIwtTM=
github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d h1:t5Wuyh53qYyg9eqn4BbnlIT+vmhyww0TatL+zT3uWgI=
@ -262,12 +268,14 @@ github.com/envoyproxy/go-control-plane v0.11.1 h1:wSUXTlLfiAQRWs2F+p+EKOY9rUyis1
github.com/envoyproxy/go-control-plane v0.11.1/go.mod h1:uhMcXKCQMEJHiAb0w+YGefQLaTEw+YhGluxZkrTmD0g=
github.com/envoyproxy/go-control-plane v0.12.0 h1:4X+VP1GHd1Mhj6IB5mMeGbLCleqxjletLK6K0rbxyZI=
github.com/envoyproxy/go-control-plane v0.12.0/go.mod h1:ZBTaoJ23lqITozF0M6G4/IragXCQKCnYbmlmtHvwRG0=
github.com/envoyproxy/go-control-plane v0.13.0/go.mod h1:GRaKG3dwvFoTg4nj7aXdZnvMg4d7nvT/wl9WgVXn3Q8=
github.com/envoyproxy/protoc-gen-validate v0.1.0 h1:EQciDnbrYxy13PgWoY8AqoxGiPrpgBZ1R8UNe3ddc+A=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/envoyproxy/protoc-gen-validate v1.0.2 h1:QkIBuU5k+x7/QXPvPPnWXWlCdaBFApVqftFV6k087DA=
github.com/envoyproxy/protoc-gen-validate v1.0.2/go.mod h1:GpiZQP3dDbg4JouG/NNS7QWXpgx6x8QiMKdmN72jogE=
github.com/envoyproxy/protoc-gen-validate v1.0.4 h1:gVPz/FMfvh57HdSJQyvBtF00j8JU4zdyUgIUNhlgg0A=
github.com/envoyproxy/protoc-gen-validate v1.0.4/go.mod h1:qys6tmnRsYrQqIhm2bvKZH4Blx/1gTIZ2UKVY1M+Yew=
github.com/envoyproxy/protoc-gen-validate v1.1.0/go.mod h1:sXRDRVmzEbkM7CVcM06s9shE/m23dg3wzjl0UWqJ2q4=
github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo=
github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
@ -338,6 +346,7 @@ github.com/golang/glog v1.2.0 h1:uCdmnmatrKCgMBlM4rMuJZWOkPDqdbZPnrMXDY4gI68=
github.com/golang/glog v1.2.0/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
github.com/golang/glog v1.2.1 h1:OptwRhECazUx5ix5TTWC3EZhsZEHWcYWY4FQHTIubm4=
github.com/golang/glog v1.2.1/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
github.com/golang/glog v1.2.2/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
@ -576,6 +585,7 @@ github.com/pelletier/go-toml/v2 v2.0.5 h1:ipoSadvV8oGUjnUbMub59IDPPwfxF694nG/jwb
github.com/pelletier/go-toml/v2 v2.0.5/go.mod h1:OMHamSCAODeSsVrwwvcJOaoN0LIUIaFVNZzmWyNfXas=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8=
github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k=
github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho=
@ -677,6 +687,7 @@ github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
github.com/urfave/negroni v1.0.0 h1:kIimOitoypq34K7TG7DUaJ9kq/N4Ofuwi1sjz0KipXc=
github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4=
github.com/urfave/negroni/v3 v3.1.1/go.mod h1:jWvnX03kcSjDBl/ShB0iHvx5uOs7mAzZXW+JvJ5XYAs=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.40.0 h1:CRq/00MfruPGFLTQKY8b+8SfdK60TxNztjRMnH0t1Yc=
@ -828,6 +839,7 @@ golang.org/x/oauth2 v0.19.0 h1:9+E/EZBCbTLNrbN35fHv/a/d/mOBatymz1zbtQrXpIg=
golang.org/x/oauth2 v0.19.0/go.mod h1:vYi7skDa1x015PmRRYZ7+s1cWyPgrPiSYRe4rnsexc8=
golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs=
golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852 h1:xYq6+9AtI+xP3M4r0N1hCkHrInHDBohhquRgx9Kk6gI=
golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -912,6 +924,7 @@ golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4=
golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM=
golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8=
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
@ -1000,6 +1013,7 @@ google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 h1:
google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237/go.mod h1:Z5Iiy3jtmioajWHDGFk7CeugTyHtPvMHA4UTmUkyalE=
google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157 h1:7whR9kGa5LUwFtpLm2ArCEejtnxlGeLbAyjFY8sGNFw=
google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157/go.mod h1:99sLkeliLXfdj2J75X3Ho+rrVCaJze0uwN7zDDkjPVU=
google.golang.org/genproto/googleapis/api v0.0.0-20240814211410-ddb44dafa142/go.mod h1:d6be+8HhtEtucleCbxpPW9PA9XwISACu8nvpPqF0BVo=
google.golang.org/genproto/googleapis/bytestream v0.0.0-20240304161311-37d4d3c04a78/go.mod h1:vh/N7795ftP0AkN1w8XKqN4w1OdUKXW5Eummda+ofv8=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240304161311-37d4d3c04a78/go.mod h1:UCOku4NytXMJuLQE5VuqA5lX3PcHCBo8pxNyvkf4xBs=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240314234333-6e1732d8331c/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY=

View File

@ -10,6 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/netutil/urlutil"
)
// Client is a wrapper around http.Client.
@ -93,6 +94,7 @@ func (c *Client) do(
req.Header.Set(httphdr.UserAgent, c.userAgent)
resp, err = c.http.Do(req)
urlutil.RedactUserinfoInURLError(u, err)
if err != nil && resp != nil && resp.Header != nil {
// A non-nil Response with a non-nil error only occurs when
// CheckRedirect fails.

View File

@ -5,47 +5,9 @@ import (
"net/url"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil/urlutil"
)
// Known scheme constants.
//
// TODO(a.garipov): Move to agdurl or golibs.
//
// TODO(a.garipov): Use more.
const (
SchemeFile = "file"
SchemeGRPC = "grpc"
SchemeGRPCS = "grpcs"
SchemeHTTP = "http"
SchemeHTTPS = "https"
)
// CheckGRPCURLScheme returns true if s is a valid gRPC URL scheme. That is,
// [SchemeGRPC] or [SchemeGRPCS]
//
// TODO(a.garipov): Move to golibs?
func CheckGRPCURLScheme(s string) (ok bool) {
switch s {
case SchemeGRPC, SchemeGRPCS:
return true
default:
return false
}
}
// CheckHTTPURLScheme returns true if s is a valid HTTP URL scheme. That is,
// [SchemeHTTP] or [SchemeHTTPS]
//
// TODO(a.garipov): Move to golibs?
func CheckHTTPURLScheme(s string) (ok bool) {
switch s {
case SchemeHTTP, SchemeHTTPS:
return true
default:
return false
}
}
// ParseHTTPURL parses an absolute URL and makes sure that it is a valid HTTP(S)
// URL. All returned errors will have the underlying type [*url.Error].
//
@ -63,7 +25,7 @@ func ParseHTTPURL(s string) (u *url.URL, err error) {
URL: s,
Err: errors.Error("empty host"),
}
case !CheckHTTPURLScheme(u.Scheme):
case !urlutil.IsValidHTTPURLScheme(u.Scheme):
return nil, &url.Error{
Op: "parse",
URL: s,

View File

@ -6,12 +6,19 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
)
// Common user credentials for tests.
const (
testUsername = "user"
testPassword = "pass"
)
func TestParseHTTPURL(t *testing.T) {
goodURL := testURL()
goodURL := testURL(url.UserPassword(testUsername, testPassword))
badSchemeURL := netutil.CloneURL(goodURL)
badSchemeURL.Scheme = "ftp"
@ -61,10 +68,11 @@ func TestParseHTTPURL(t *testing.T) {
}
}
func testURL() (u *url.URL) {
// testURL is a helper function that returns an url with dummy values.
func testURL(info *url.Userinfo) (u *url.URL) {
return &url.URL{
Scheme: agdhttp.SchemeHTTP,
User: url.UserPassword("user", "pass"),
Scheme: urlutil.SchemeHTTP,
User: info,
Host: "example.com",
Path: "/a/b/c/",
RawQuery: "d=e",

View File

@ -3,6 +3,7 @@
package agdtest
import (
"net/url"
"testing"
"time"
@ -18,7 +19,7 @@ const FilteredResponseTTL = FilteredResponseTTLSec * time.Second
// number to simplify message creation.
const FilteredResponseTTLSec = 10
// NewConstructorWithTTL returns a standard dnsmsg.Constructor for tests, using
// NewConstructorWithTTL returns a standard *dnsmsg.Constructor for tests, using
// ttl as the TTL for filtered responses.
func NewConstructorWithTTL(tb testing.TB, ttl time.Duration) (c *dnsmsg.Constructor) {
tb.Helper()
@ -26,28 +27,57 @@ func NewConstructorWithTTL(tb testing.TB, ttl time.Duration) (c *dnsmsg.Construc
c, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: NewCloner(),
BlockingMode: &dnsmsg.BlockingModeNullIP{},
StructuredErrors: NewSDEConfig(true),
FilteredResponseTTL: ttl,
EDEEnabled: true,
})
require.NoError(tb, err)
return c
}
// NewConstructor returns a standard dnsmsg.Constructor for tests, using
// [FilteredResponseTTL] as the TTL for filtered responses.
// NewConstructor returns a standard *dnsmsg.Constructor for tests, using
// [FilteredResponseTTL] as the TTL for filtered responses. The returned
// constructor also has the Structured DNS Errors feature enabled.
func NewConstructor(tb testing.TB) (c *dnsmsg.Constructor) {
tb.Helper()
c, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: NewCloner(),
BlockingMode: &dnsmsg.BlockingModeNullIP{},
StructuredErrors: NewSDEConfig(true),
FilteredResponseTTL: FilteredResponseTTL,
EDEEnabled: true,
})
require.NoError(tb, err)
return c
}
// SDEText is a test Structured DNS Error text.
//
// NOTE: Keep in sync with [NewSDEConfig].
//
// TODO(e.burkov): Add some helper when this message becomes configurable.
const SDEText = `{` +
`"j":"Filtering",` +
`"o":"Test Org",` +
`"c":["mailto:support@dns.example"]` +
`}`
// NewSDEConfig returns a standard *dnsmsg.StructuredDNSErrorsConfig for tests.
func NewSDEConfig(enabled bool) (c *dnsmsg.StructuredDNSErrorsConfig) {
return &dnsmsg.StructuredDNSErrorsConfig{
Contact: []*url.URL{{
Scheme: "mailto",
Opaque: "support@dns.example",
}},
Justification: "Filtering",
Organization: "Test Org",
Enabled: enabled,
}
}
// NewCloner returns a standard dnsmsg.Cloner for tests.
func NewCloner() (c *dnsmsg.Cloner) {
return dnsmsg.NewCloner(dnsmsg.EmptyClonerStat{})

View File

@ -116,7 +116,7 @@ func (r *Refresher) Refresh(ctx context.Context) (err error) {
// type check
var _ billstat.Recorder = (*BillStatRecorder)(nil)
// BillStatRecorder is a billstat.Recorder for tests.
// BillStatRecorder is a [billstat.Recorder] for tests.
type BillStatRecorder struct {
OnRecord func(
ctx context.Context,
@ -128,7 +128,7 @@ type BillStatRecorder struct {
)
}
// Record implements the billstat.Recorder interface for *BillStatRecorder.
// Record implements the [billstat.Recorder] interface for *BillStatRecorder.
func (r *BillStatRecorder) Record(
ctx context.Context,
id agd.DeviceID,
@ -143,12 +143,12 @@ func (r *BillStatRecorder) Record(
// type check
var _ billstat.Uploader = (*BillStatUploader)(nil)
// BillStatUploader is a billstat.Uploader for tests.
// BillStatUploader is a [billstat.Uploader] for tests.
type BillStatUploader struct {
OnUpload func(ctx context.Context, records billstat.Records) (err error)
}
// Upload implements the billstat.Uploader interface for *BillStatUploader.
// Upload implements the [billstat.Uploader] interface for *BillStatUploader.
func (b *BillStatUploader) Upload(ctx context.Context, records billstat.Records) (err error) {
return b.OnUpload(ctx, records)
}
@ -158,7 +158,7 @@ func (b *BillStatUploader) Upload(ctx context.Context, records billstat.Records)
// type check
var _ dnscheck.Interface = (*DNSCheck)(nil)
// DNSCheck is a dnscheck.Interface for tests.
// DNSCheck is a [dnscheck.Interface] for tests.
type DNSCheck struct {
OnCheck func(ctx context.Context, req *dns.Msg, ri *agd.RequestInfo) (reqp *dns.Msg, err error)
}
@ -177,12 +177,12 @@ func (db *DNSCheck) Check(
// type check
var _ dnsdb.Interface = (*DNSDB)(nil)
// DNSDB is a dnsdb.Interface for tests.
// DNSDB is a [dnsdb.Interface] for tests.
type DNSDB struct {
OnRecord func(ctx context.Context, resp *dns.Msg, ri *agd.RequestInfo)
}
// Record implements the dnsdb.Interface interface for *DNSDB.
// Record implements the [dnsdb.Interface] interface for *DNSDB.
func (db *DNSDB) Record(ctx context.Context, resp *dns.Msg, ri *agd.RequestInfo) {
db.OnRecord(ctx, resp, ri)
}
@ -204,7 +204,7 @@ func (c *ErrorCollector) Collect(ctx context.Context, err error) {
c.OnCollect(ctx, err)
}
// NewErrorCollector returns a new [ErrorCollector] all methods of which panic.
// NewErrorCollector returns a new *ErrorCollector all methods of which panic.
func NewErrorCollector() (c *ErrorCollector) {
return &ErrorCollector{
OnCollect: func(_ context.Context, err error) {
@ -297,21 +297,38 @@ func (s *FilterStorage) HasListID(id agd.FilterListID) (ok bool) {
// type check
var _ geoip.Interface = (*GeoIP)(nil)
// GeoIP is a geoip.Interface for tests.
// GeoIP is a [geoip.Interface] for tests.
type GeoIP struct {
OnSubnetByLocation func(l *geoip.Location, fam netutil.AddrFamily) (n netip.Prefix, err error)
OnData func(host string, ip netip.Addr) (l *geoip.Location, err error)
OnSubnetByLocation func(l *geoip.Location, fam netutil.AddrFamily) (n netip.Prefix, err error)
}
// SubnetByLocation implements the geoip.Interface interface for *GeoIP.
func (g *GeoIP) SubnetByLocation(l *geoip.Location, fam netutil.AddrFamily,
// Data implements the [geoip.Interface] interface for *GeoIP.
func (g *GeoIP) Data(host string, ip netip.Addr) (l *geoip.Location, err error) {
return g.OnData(host, ip)
}
// SubnetByLocation implements the [geoip.Interface] interface for *GeoIP.
func (g *GeoIP) SubnetByLocation(
l *geoip.Location,
fam netutil.AddrFamily,
) (n netip.Prefix, err error) {
return g.OnSubnetByLocation(l, fam)
}
// Data implements the geoip.Interface interface for *GeoIP.
func (g *GeoIP) Data(host string, ip netip.Addr) (l *geoip.Location, err error) {
return g.OnData(host, ip)
// NewGeoIP returns a new *GeoIP all methods of which panic.
func NewGeoIP() (c *GeoIP) {
return &GeoIP{
OnData: func(host string, ip netip.Addr) (l *geoip.Location, err error) {
panic(fmt.Errorf("unexpected call to GeoIP.Data(%v, %v)", host, ip))
},
OnSubnetByLocation: func(
l *geoip.Location,
fam netutil.AddrFamily,
) (n netip.Prefix, err error) {
panic(fmt.Errorf("unexpected call to GeoIP.SubnetByLocation(%v, %v)", l, fam))
},
}
}
// Package profiledb
@ -398,6 +415,58 @@ func (db *ProfileDB) ProfileByLinkedIP(
return db.OnProfileByLinkedIP(ctx, ip)
}
// NewProfileDB returns a new *ProfileDB all methods of which panic.
func NewProfileDB() (db *ProfileDB) {
return &ProfileDB{
OnCreateAutoDevice: func(
_ context.Context,
id agd.ProfileID,
humanID agd.HumanID,
devType agd.DeviceType,
) (p *agd.Profile, d *agd.Device, err error) {
panic(fmt.Errorf(
"unexpected call to ProfileDB.CreateAutoDevice(%v, %v, %v)",
id,
humanID,
devType,
))
},
OnProfileByDedicatedIP: func(
_ context.Context,
ip netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic(fmt.Errorf("unexpected call to ProfileDB.ProfileByDedicatedIP(%v)", ip))
},
OnProfileByDeviceID: func(
_ context.Context,
id agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
panic(fmt.Errorf("unexpected call to ProfileDB.ProfileByDeviceID(%v)", id))
},
OnProfileByHumanID: func(
_ context.Context,
profID agd.ProfileID,
humanID agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) {
panic(fmt.Errorf(
"unexpected call to ProfileDB.ProfileByHumanID(%v, %v)",
profID,
humanID,
))
},
OnProfileByLinkedIP: func(
_ context.Context,
ip netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic(fmt.Errorf("unexpected call to ProfileDB.ProfileByLinkedIP(%v)", ip))
},
}
}
// type check
var _ profiledb.Storage = (*ProfileStorage)(nil)
@ -436,12 +505,12 @@ func (s *ProfileStorage) Profiles(
// type check
var _ querylog.Interface = (*QueryLog)(nil)
// QueryLog is a querylog.Interface for tests.
// QueryLog is a [querylog.Interface] for tests.
type QueryLog struct {
OnWrite func(ctx context.Context, e *querylog.Entry) (err error)
}
// Write implements the querylog.Interface interface for *QueryLog.
// Write implements the [querylog.Interface] interface for *QueryLog.
func (ql *QueryLog) Write(ctx context.Context, e *querylog.Entry) (err error) {
return ql.OnWrite(ctx, e)
}
@ -451,12 +520,12 @@ func (ql *QueryLog) Write(ctx context.Context, e *querylog.Entry) (err error) {
// type check
var _ rulestat.Interface = (*RuleStat)(nil)
// RuleStat is a rulestat.Interface for tests.
// RuleStat is a [rulestat.Interface] for tests.
type RuleStat struct {
OnCollect func(ctx context.Context, id agd.FilterListID, text agd.FilterRuleText)
}
// Collect implements the rulestat.Interface interface for *RuleStat.
// Collect implements the [rulestat.Interface] interface for *RuleStat.
func (s *RuleStat) Collect(ctx context.Context, id agd.FilterListID, text agd.FilterRuleText) {
s.OnCollect(ctx, id, text)
}
@ -501,7 +570,7 @@ func (c *ListenConfig) ListenPacket(
// type check
var _ ratelimit.Interface = (*RateLimit)(nil)
// RateLimit is a ratelimit.Interface for tests.
// RateLimit is a [ratelimit.Interface] for tests.
type RateLimit struct {
OnIsRateLimited func(
ctx context.Context,
@ -511,7 +580,7 @@ type RateLimit struct {
OnCountResponses func(ctx context.Context, resp *dns.Msg, ip netip.Addr)
}
// IsRateLimited implements the ratelimit.Interface interface for *RateLimit.
// IsRateLimited implements the [ratelimit.Interface] interface for *RateLimit.
func (l *RateLimit) IsRateLimited(
ctx context.Context,
req *dns.Msg,
@ -520,12 +589,27 @@ func (l *RateLimit) IsRateLimited(
return l.OnIsRateLimited(ctx, req, ip)
}
// CountResponses implements the ratelimit.Interface interface for
// *RateLimit.
// CountResponses implements the [ratelimit.Interface] interface for *RateLimit.
func (l *RateLimit) CountResponses(ctx context.Context, req *dns.Msg, ip netip.Addr) {
l.OnCountResponses(ctx, req, ip)
}
// NewRateLimit returns a new *RateLimit all methods of which panic.
func NewRateLimit() (c *RateLimit) {
return &RateLimit{
OnIsRateLimited: func(
_ context.Context,
req *dns.Msg,
addr netip.Addr,
) (shouldDrop, isAllowlisted bool, err error) {
panic(fmt.Errorf("unexpected call to RateLimit.IsRateLimited(%v, %v)", req, addr))
},
OnCountResponses: func(_ context.Context, resp *dns.Msg, addr netip.Addr) {
panic(fmt.Errorf("unexpected call to RateLimit.CountResponses(%v, %v)", resp, addr))
},
}
}
// RemoteKV is an [remotekv.Interface] implementation for tests.
type RemoteKV struct {
OnGet func(ctx context.Context, key string) (val []byte, ok bool, err error)

View File

@ -16,8 +16,8 @@ import (
"google.golang.org/grpc/metadata"
)
// newClient returns new properly initialized DNSServiceClient.
func newClient(apiURL *url.URL) (client DNSServiceClient, err error) {
// newClient returns new properly initialized gRPC connection to the API server.
func newClient(apiURL *url.URL) (client *grpc.ClientConn, err error) {
var creds credentials.TransportCredentials
switch s := apiURL.Scheme; s {
case "grpc":
@ -38,7 +38,7 @@ func newClient(apiURL *url.URL) (client DNSServiceClient, err error) {
// called right before the initial refresh.
conn.Connect()
return NewDNSServiceClient(conn), nil
return conn, nil
}
// reportf is a helper method for reporting non-critical errors.

View File

@ -65,3 +65,39 @@ func (s *testDNSServiceServer) SaveDevicesBillingStat(
) (err error) {
return s.OnSaveDevicesBillingStat(srv)
}
// testRemoteKVServiceServer is the [backendpb.RemoteKVServiceServer] for tests.
type testRemoteKVServiceServer struct {
backendpb.UnimplementedRemoteKVServiceServer
OnGet func(
ctx context.Context,
req *backendpb.RemoteKVGetRequest,
) (resp *backendpb.RemoteKVGetResponse, err error)
OnSet func(
ctx context.Context,
req *backendpb.RemoteKVSetRequest,
) (resp *backendpb.RemoteKVSetResponse, err error)
}
// type check
var _ backendpb.RemoteKVServiceServer = (*testRemoteKVServiceServer)(nil)
// Get implements the [backendpb.RemoteKVServiceServer] interface for
// *testRemoteKVServiceServer.
func (s *testRemoteKVServiceServer) Get(
ctx context.Context,
req *backendpb.RemoteKVGetRequest,
) (resp *backendpb.RemoteKVGetResponse, err error) {
return s.OnGet(ctx, req)
}
// Set implements the [backendpb.RemoteKVServiceServer] interface for
// *testRemoteKVServiceServer.
func (s *testRemoteKVServiceServer) Set(
ctx context.Context,
req *backendpb.RemoteKVSetRequest,
) (resp *backendpb.RemoteKVSetResponse, err error) {
return s.OnSet(ctx, req)
}

View File

@ -42,7 +42,7 @@ func NewBillStat(c *BillStatConfig) (b *BillStat, err error) {
return &BillStat{
errColl: c.ErrColl,
metrics: c.Metrics,
client: client,
client: NewDNSServiceClient(client),
apiKey: c.APIKey,
}, nil
}

File diff suppressed because it is too large Load Diff

View File

@ -45,6 +45,42 @@ service DNSService {
rpc createDeviceByHumanId(CreateDeviceRequest) returns (CreateDeviceResponse);
}
service RateLimitService {
/*
Gets rate limit settings.
*/
rpc getRateLimitSettings(RateLimitSettingsRequest) returns (RateLimitSettingsResponse);
}
service RemoteKVService {
/**
Get the value for the specified key.
This method may return the following errors:
- AuthenticationFailedError: If the authentication failed.
*/
rpc get(RemoteKVGetRequest) returns (RemoteKVGetResponse);
/**
Set the value for the specified key.
This method may return the following errors:
- AuthenticationFailedError: If the authentication failed.
- BadRequestError: If the request is invalid: value size exceeds the 512kb.
*/
rpc set(RemoteKVSetRequest) returns (RemoteKVSetResponse);
}
message RateLimitSettingsRequest {
}
message RateLimitSettingsResponse {
repeated CidrRange allowed_subnets = 1;
}
message DNSProfilesRequest {
google.protobuf.Timestamp sync_time = 1;
}
@ -212,3 +248,24 @@ message RateLimitSettings {
uint32 rps = 2;
repeated CidrRange client_cidr = 3;
}
message RemoteKVGetRequest {
string key = 1;
}
message RemoteKVGetResponse {
oneof value {
bytes data = 1;
google.protobuf.Empty empty = 2;
}
}
message RemoteKVSetRequest {
string key = 1;
bytes data = 2;
google.protobuf.Duration ttl = 3;
}
message RemoteKVSetResponse {
}

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.5.1
// - protoc v5.27.1
// - protoc v5.28.3
// source: dns.proto
package backendpb
@ -235,3 +235,269 @@ var DNSService_ServiceDesc = grpc.ServiceDesc{
},
Metadata: "dns.proto",
}
const (
RateLimitService_GetRateLimitSettings_FullMethodName = "/RateLimitService/getRateLimitSettings"
)
// RateLimitServiceClient is the client API for RateLimitService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type RateLimitServiceClient interface {
// Gets rate limit settings.
GetRateLimitSettings(ctx context.Context, in *RateLimitSettingsRequest, opts ...grpc.CallOption) (*RateLimitSettingsResponse, error)
}
type rateLimitServiceClient struct {
cc grpc.ClientConnInterface
}
func NewRateLimitServiceClient(cc grpc.ClientConnInterface) RateLimitServiceClient {
return &rateLimitServiceClient{cc}
}
func (c *rateLimitServiceClient) GetRateLimitSettings(ctx context.Context, in *RateLimitSettingsRequest, opts ...grpc.CallOption) (*RateLimitSettingsResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RateLimitSettingsResponse)
err := c.cc.Invoke(ctx, RateLimitService_GetRateLimitSettings_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// RateLimitServiceServer is the server API for RateLimitService service.
// All implementations must embed UnimplementedRateLimitServiceServer
// for forward compatibility.
type RateLimitServiceServer interface {
// Gets rate limit settings.
GetRateLimitSettings(context.Context, *RateLimitSettingsRequest) (*RateLimitSettingsResponse, error)
mustEmbedUnimplementedRateLimitServiceServer()
}
// UnimplementedRateLimitServiceServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedRateLimitServiceServer struct{}
func (UnimplementedRateLimitServiceServer) GetRateLimitSettings(context.Context, *RateLimitSettingsRequest) (*RateLimitSettingsResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetRateLimitSettings not implemented")
}
func (UnimplementedRateLimitServiceServer) mustEmbedUnimplementedRateLimitServiceServer() {}
func (UnimplementedRateLimitServiceServer) testEmbeddedByValue() {}
// UnsafeRateLimitServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to RateLimitServiceServer will
// result in compilation errors.
type UnsafeRateLimitServiceServer interface {
mustEmbedUnimplementedRateLimitServiceServer()
}
func RegisterRateLimitServiceServer(s grpc.ServiceRegistrar, srv RateLimitServiceServer) {
// If the following call pancis, it indicates UnimplementedRateLimitServiceServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&RateLimitService_ServiceDesc, srv)
}
func _RateLimitService_GetRateLimitSettings_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RateLimitSettingsRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(RateLimitServiceServer).GetRateLimitSettings(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: RateLimitService_GetRateLimitSettings_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(RateLimitServiceServer).GetRateLimitSettings(ctx, req.(*RateLimitSettingsRequest))
}
return interceptor(ctx, in, info, handler)
}
// RateLimitService_ServiceDesc is the grpc.ServiceDesc for RateLimitService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var RateLimitService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "RateLimitService",
HandlerType: (*RateLimitServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "getRateLimitSettings",
Handler: _RateLimitService_GetRateLimitSettings_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "dns.proto",
}
const (
RemoteKVService_Get_FullMethodName = "/RemoteKVService/get"
RemoteKVService_Set_FullMethodName = "/RemoteKVService/set"
)
// RemoteKVServiceClient is the client API for RemoteKVService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type RemoteKVServiceClient interface {
// *
// Get the value for the specified key.
//
// This method may return the following errors:
// - AuthenticationFailedError: If the authentication failed.
Get(ctx context.Context, in *RemoteKVGetRequest, opts ...grpc.CallOption) (*RemoteKVGetResponse, error)
// *
// Set the value for the specified key.
//
// This method may return the following errors:
// - AuthenticationFailedError: If the authentication failed.
// - BadRequestError: If the request is invalid: value size exceeds the 512kb.
Set(ctx context.Context, in *RemoteKVSetRequest, opts ...grpc.CallOption) (*RemoteKVSetResponse, error)
}
type remoteKVServiceClient struct {
cc grpc.ClientConnInterface
}
func NewRemoteKVServiceClient(cc grpc.ClientConnInterface) RemoteKVServiceClient {
return &remoteKVServiceClient{cc}
}
func (c *remoteKVServiceClient) Get(ctx context.Context, in *RemoteKVGetRequest, opts ...grpc.CallOption) (*RemoteKVGetResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RemoteKVGetResponse)
err := c.cc.Invoke(ctx, RemoteKVService_Get_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *remoteKVServiceClient) Set(ctx context.Context, in *RemoteKVSetRequest, opts ...grpc.CallOption) (*RemoteKVSetResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RemoteKVSetResponse)
err := c.cc.Invoke(ctx, RemoteKVService_Set_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// RemoteKVServiceServer is the server API for RemoteKVService service.
// All implementations must embed UnimplementedRemoteKVServiceServer
// for forward compatibility.
type RemoteKVServiceServer interface {
// *
// Get the value for the specified key.
//
// This method may return the following errors:
// - AuthenticationFailedError: If the authentication failed.
Get(context.Context, *RemoteKVGetRequest) (*RemoteKVGetResponse, error)
// *
// Set the value for the specified key.
//
// This method may return the following errors:
// - AuthenticationFailedError: If the authentication failed.
// - BadRequestError: If the request is invalid: value size exceeds the 512kb.
Set(context.Context, *RemoteKVSetRequest) (*RemoteKVSetResponse, error)
mustEmbedUnimplementedRemoteKVServiceServer()
}
// UnimplementedRemoteKVServiceServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedRemoteKVServiceServer struct{}
func (UnimplementedRemoteKVServiceServer) Get(context.Context, *RemoteKVGetRequest) (*RemoteKVGetResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Get not implemented")
}
func (UnimplementedRemoteKVServiceServer) Set(context.Context, *RemoteKVSetRequest) (*RemoteKVSetResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Set not implemented")
}
func (UnimplementedRemoteKVServiceServer) mustEmbedUnimplementedRemoteKVServiceServer() {}
func (UnimplementedRemoteKVServiceServer) testEmbeddedByValue() {}
// UnsafeRemoteKVServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to RemoteKVServiceServer will
// result in compilation errors.
type UnsafeRemoteKVServiceServer interface {
mustEmbedUnimplementedRemoteKVServiceServer()
}
func RegisterRemoteKVServiceServer(s grpc.ServiceRegistrar, srv RemoteKVServiceServer) {
// If the following call pancis, it indicates UnimplementedRemoteKVServiceServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&RemoteKVService_ServiceDesc, srv)
}
func _RemoteKVService_Get_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RemoteKVGetRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(RemoteKVServiceServer).Get(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: RemoteKVService_Get_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(RemoteKVServiceServer).Get(ctx, req.(*RemoteKVGetRequest))
}
return interceptor(ctx, in, info, handler)
}
func _RemoteKVService_Set_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RemoteKVSetRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(RemoteKVServiceServer).Set(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: RemoteKVService_Set_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(RemoteKVServiceServer).Set(ctx, req.(*RemoteKVSetRequest))
}
return interceptor(ctx, in, info, handler)
}
// RemoteKVService_ServiceDesc is the grpc.ServiceDesc for RemoteKVService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var RemoteKVService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "RemoteKVService",
HandlerType: (*RemoteKVServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "get",
Handler: _RemoteKVService_Get_Handler,
},
{
MethodName: "set",
Handler: _RemoteKVService_Set_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "dns.proto",
}

View File

@ -81,7 +81,7 @@ func NewProfileStorage(c *ProfileStorageConfig) (s *ProfileStorage, err error) {
return &ProfileStorage{
bindSet: c.BindSet,
errColl: c.ErrColl,
client: client,
client: NewDNSServiceClient(client),
logger: c.Logger,
metrics: c.Metrics,
apiKey: c.APIKey,

View File

@ -0,0 +1,103 @@
package backendpb
import (
"context"
"fmt"
"log/slog"
"net/url"
"github.com/AdguardTeam/AdGuardDNS/internal/agdservice"
"github.com/AdguardTeam/AdGuardDNS/internal/consul"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
)
// RateLimiterConfig is the configuration structure for the business logic
// backend rate limiter.
type RateLimiterConfig struct {
// Logger is used for logging the operation of the rate limiter. It must
// not be nil.
Logger *slog.Logger
// GRPCMetrics is used for the collection of the protobuf errors.
GRPCMetrics Metrics
// Metrics is used to collect allowlist statistics.
Metrics consul.Metrics
// Allowlist is the allowlist to update.
Allowlist *ratelimit.DynamicAllowlist
// ErrColl is used to collect errors during refreshes.
ErrColl errcoll.Interface
// Endpoint is the backend API URL. The scheme should be either "grpc" or
// "grpcs". It must not be nil.
Endpoint *url.URL
// APIKey is the API key used for authentication, if any. If empty, no
// authentication is performed.
APIKey string
}
// RateLimiter is the implementation of the [agdservice.Refresher] interface
// that retrieves the rate limit settings from the business logic backend.
type RateLimiter struct {
logger *slog.Logger
grpcMetrics Metrics
metrics consul.Metrics
allowlist *ratelimit.DynamicAllowlist
errColl errcoll.Interface
client RateLimitServiceClient
apiKey string
}
// NewRateLimiter creates a new properly initialized rate limiter. c must not
// be nil.
func NewRateLimiter(c *RateLimiterConfig) (l *RateLimiter, err error) {
client, err := newClient(c.Endpoint)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
return &RateLimiter{
logger: c.Logger,
grpcMetrics: c.GRPCMetrics,
metrics: c.Metrics,
allowlist: c.Allowlist,
errColl: c.ErrColl,
client: NewRateLimitServiceClient(client),
apiKey: c.APIKey,
}, nil
}
// type check
var _ agdservice.Refresher = (*RateLimiter)(nil)
// Refresh implements the [agdservice.Refresher] interface for *RateLimiter.
func (l *RateLimiter) Refresh(ctx context.Context) (err error) {
l.logger.InfoContext(ctx, "refresh started")
defer l.logger.InfoContext(ctx, "refresh finished")
defer func() { l.metrics.SetStatus(ctx, err) }()
ctx = ctxWithAuthentication(ctx, l.apiKey)
backendResp, err := l.client.GetRateLimitSettings(ctx, &RateLimitSettingsRequest{})
if err != nil {
return fmt.Errorf(
"loading backend rate limit settings: %w",
fixGRPCError(ctx, l.grpcMetrics, err),
)
}
allowedSubnets := backendResp.AllowedSubnets
prefixes := cidrRangeToInternal(ctx, l.errColl, allowedSubnets)
l.allowlist.Update(prefixes)
l.logger.InfoContext(ctx, "refresh successful", "num_records", len(prefixes))
l.metrics.SetSize(ctx, len(prefixes))
return nil
}

View File

@ -0,0 +1,110 @@
package backendpb_test
import (
"context"
"net"
"net/netip"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
"github.com/AdguardTeam/AdGuardDNS/internal/consul"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
// testRateLimitServiceServer is the [backendpb.RateLimitServiceServer] for
// tests.
type testRateLimitServiceServer struct {
backendpb.UnimplementedRateLimitServiceServer
OnGetRateLimitSettings func(
ctx context.Context,
req *backendpb.RateLimitSettingsRequest,
) (resp *backendpb.RateLimitSettingsResponse, err error)
}
// type check
var _ backendpb.DNSServiceServer = (*testDNSServiceServer)(nil)
// GetRateLimitSettings implements the [backendpb.RateLimitServiceServer]
// interface for *testRateLimitServiceServer.
func (s *testRateLimitServiceServer) GetRateLimitSettings(
ctx context.Context,
req *backendpb.RateLimitSettingsRequest,
) (resp *backendpb.RateLimitSettingsResponse, err error) {
return s.OnGetRateLimitSettings(ctx, req)
}
func TestRateLimiter_Refresh(t *testing.T) {
var (
allowedIP = netip.MustParseAddr("1.2.3.4")
notAllowedIP = netip.MustParseAddr("4.3.2.1")
cidr = &backendpb.CidrRange{
Address: allowedIP.AsSlice(),
Prefix: 32,
}
)
srv := &testRateLimitServiceServer{
OnGetRateLimitSettings: func(
ctx context.Context,
req *backendpb.RateLimitSettingsRequest,
) (resp *backendpb.RateLimitSettingsResponse, err error) {
return &backendpb.RateLimitSettingsResponse{
AllowedSubnets: []*backendpb.CidrRange{cidr},
}, nil
},
}
ln, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
grpcSrv := grpc.NewServer(
grpc.ConnectionTimeout(1*time.Second),
grpc.Creds(insecure.NewCredentials()),
)
backendpb.RegisterRateLimitServiceServer(grpcSrv, srv)
go func() {
pt := testutil.PanicT{}
srvErr := grpcSrv.Serve(ln)
require.NoError(pt, srvErr)
}()
t.Cleanup(grpcSrv.GracefulStop)
allowlist := ratelimit.NewDynamicAllowlist(nil, nil)
l, err := backendpb.NewRateLimiter(&backendpb.RateLimiterConfig{
Logger: slogutil.NewDiscardLogger(),
Metrics: consul.EmptyMetrics{},
GRPCMetrics: backendpb.EmptyMetrics{},
Allowlist: allowlist,
Endpoint: &url.URL{
Scheme: "grpc",
Host: ln.Addr().String(),
},
})
require.NoError(t, err)
ctx := testutil.ContextWithTimeout(t, testTimeout)
err = l.Refresh(ctx)
require.NoError(t, err)
ok, err := allowlist.IsAllowed(ctx, allowedIP)
require.NoError(t, err)
assert.True(t, ok)
ok, err = allowlist.IsAllowed(ctx, notAllowedIP)
require.NoError(t, err)
assert.False(t, ok)
}

View File

@ -0,0 +1,96 @@
package backendpb
import (
"context"
"fmt"
"net/url"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/remotekv"
"google.golang.org/protobuf/types/known/durationpb"
)
// RemoteKVConfig is the configuration for the business logic backend key-value
// storage.
type RemoteKVConfig struct {
// Metrics is used for the collection of the remote key-value storage
// statistics.
//
// TODO(e.burkov): Perhaps, it worths of a separate metrics interface,
// since it's only used for the collection of the protobuf errors.
Metrics Metrics
// Endpoint is the backend API URL. The scheme should be either "grpc" or
// "grpcs".
Endpoint *url.URL
// APIKey is the API key used for authentication, if any.
APIKey string
// TTL is the TTL of the values in the storage.
TTL time.Duration
}
// RemoteKV is the implementation of the [remotekv.Interface] interface that
// uses the business logic backend as the key-value storage. It is safe for
// concurrent use.
type RemoteKV struct {
metrics Metrics
client RemoteKVServiceClient
apiKey string
ttl time.Duration
}
// NewRemoteKV returns a new [RemoteKV] that retrieves information from the
// business logic backend.
func NewRemoteKV(c *RemoteKVConfig) (kv *RemoteKV, err error) {
client, err := newClient(c.Endpoint)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
return &RemoteKV{
metrics: c.Metrics,
client: NewRemoteKVServiceClient(client),
apiKey: c.APIKey,
ttl: c.TTL,
}, nil
}
// type check
var _ remotekv.Interface = (*RemoteKV)(nil)
// Get implements the [remotekv.Interface] interface for *RemoteKV.
func (kv *RemoteKV) Get(ctx context.Context, key string) (val []byte, ok bool, err error) {
req := &RemoteKVGetRequest{
Key: key,
}
ctx = ctxWithAuthentication(ctx, kv.apiKey)
resp, err := kv.client.Get(ctx, req)
if err != nil {
return nil, false, fmt.Errorf("getting %q key: %w", key, fixGRPCError(ctx, kv.metrics, err))
}
val = resp.GetData()
return val, val != nil, nil
}
// Set implements the [remotekv.Interface] interface for *RemoteKV.
func (kv *RemoteKV) Set(ctx context.Context, key string, val []byte) (err error) {
req := &RemoteKVSetRequest{
Key: key,
Data: val,
Ttl: durationpb.New(kv.ttl),
}
ctx = ctxWithAuthentication(ctx, kv.apiKey)
_, err = kv.client.Set(ctx, req)
if err != nil {
return fmt.Errorf("setting %q key: %w", key, fixGRPCError(ctx, kv.metrics, err))
}
return nil
}

View File

@ -0,0 +1,106 @@
package backendpb_test
import (
"context"
"net"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
func TestRemoteKV_Get(t *testing.T) {
const testTTL = 10 * time.Second
pt := &testutil.PanicT{}
strg := map[string][]byte{}
srv := &testRemoteKVServiceServer{
OnGet: func(
ctx context.Context,
req *backendpb.RemoteKVGetRequest,
) (resp *backendpb.RemoteKVGetResponse, err error) {
resp = &backendpb.RemoteKVGetResponse{
Value: &backendpb.RemoteKVGetResponse_Empty{},
}
if val, ok := strg[req.Key]; ok {
resp.Value = &backendpb.RemoteKVGetResponse_Data{Data: val}
}
return resp, nil
},
OnSet: func(
ctx context.Context,
req *backendpb.RemoteKVSetRequest,
) (resp *backendpb.RemoteKVSetResponse, err error) {
require.Equal(pt, testTTL, req.Ttl.AsDuration())
strg[req.Key] = req.Data
return &backendpb.RemoteKVSetResponse{}, nil
},
}
l, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
grpcSrv := grpc.NewServer(
grpc.ConnectionTimeout(1*time.Second),
grpc.Creds(insecure.NewCredentials()),
)
backendpb.RegisterRemoteKVServiceServer(grpcSrv, srv)
go func() {
srvErr := grpcSrv.Serve(l)
require.NoError(pt, srvErr)
}()
t.Cleanup(grpcSrv.GracefulStop)
kv, err := backendpb.NewRemoteKV(&backendpb.RemoteKVConfig{
Metrics: backendpb.EmptyMetrics{},
Endpoint: &url.URL{
Scheme: "grpc",
Host: l.Addr().String(),
},
APIKey: "apikey",
TTL: testTTL,
})
require.NoError(t, err)
const (
keyWithData = "key"
keyNoData = "unknown"
)
t.Run("success", func(t *testing.T) {
val := []byte("value")
ctx := testutil.ContextWithTimeout(t, testTimeout)
setErr := kv.Set(ctx, keyWithData, val)
require.NoError(t, setErr)
gotVal, ok, getErr := kv.Get(ctx, keyWithData)
require.NoError(t, getErr)
require.True(t, ok)
assert.Equal(t, val, gotVal)
})
t.Run("not_found", func(t *testing.T) {
ctx := testutil.ContextWithTimeout(t, testTimeout)
val, ok, getErr := kv.Get(ctx, keyNoData)
require.NoError(t, getErr)
require.False(t, ok)
assert.Nil(t, val)
})
}

View File

@ -20,23 +20,19 @@ type connIndex struct {
}
// subnetCompare is a comparison function for the two subnets. It returns -1 if
// x sorts before y, 1 if x sorts after y, and 0 if their relative sorting
// a sorts before b, 1 if a sorts after b, and 0 if their relative sorting
// position is the same.
func subnetCompare(x, y netip.Prefix) (cmp int) {
if x == y {
return 0
}
func subnetCompare(a, b netip.Prefix) (cmp int) {
aAddr, aBits := a.Addr(), a.Bits()
bAddr, bBits := b.Addr(), b.Bits()
xAddr, xBits := x.Addr(), x.Bits()
yAddr, yBits := y.Addr(), y.Bits()
if xBits == yBits {
return xAddr.Compare(yAddr)
}
if xBits > yBits {
switch {
case aBits > bBits:
return -1
} else {
case aBits < bBits:
return 1
default:
return aAddr.Compare(bAddr)
}
}
@ -45,8 +41,8 @@ func subnetCompare(x, y netip.Prefix) (cmp int) {
//
// TODO(a.garipov): Merge with [addListenerChannel].
func (idx *connIndex) addPacketConn(c *chanPacketConn) (err error) {
cmpFunc := func(x, y *chanPacketConn) (cmp int) {
return subnetCompare(x.subnet, y.subnet)
cmpFunc := func(a, b *chanPacketConn) (cmp int) {
return subnetCompare(a.subnet, b.subnet)
}
newIdx, ok := slices.BinarySearchFunc(idx.packetConns, c, cmpFunc)
@ -67,8 +63,8 @@ func (idx *connIndex) addPacketConn(c *chanPacketConn) (err error) {
//
// TODO(a.garipov): Merge with [addPacketConnChannel].
func (idx *connIndex) addListener(l *chanListener) (err error) {
cmpFunc := func(x, y *chanListener) (cmp int) {
return subnetCompare(x.subnet, y.subnet)
cmpFunc := func(a, b *chanListener) (cmp int) {
return subnetCompare(a.subnet, b.subnet)
}
newIdx, ok := slices.BinarySearchFunc(idx.listeners, l, cmpFunc)

View File

@ -6,7 +6,6 @@ import (
"log/slog"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
@ -14,6 +13,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/timeutil"
)
@ -100,8 +100,9 @@ func newBillStatUploader(
mtrc backendpb.Metrics,
) (s billstat.Uploader, err error) {
apiURL := netutil.CloneURL(&envs.BillStatURL.URL)
if !agdhttp.CheckGRPCURLScheme(apiURL.Scheme) {
return nil, fmt.Errorf("invalid backend api url: %s", apiURL)
err = urlutil.ValidateGRPCURL(apiURL)
if err != nil {
return nil, fmt.Errorf("billstat api url: %w", err)
}
return backendpb.NewBillStat(&backendpb.BillStatConfig{

View File

@ -6,6 +6,7 @@ import (
"log/slog"
"maps"
"net/netip"
"net/url"
"path"
"path/filepath"
"slices"
@ -14,7 +15,6 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/access"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdcache"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/agdservice"
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
@ -38,9 +38,11 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/AdGuardDNS/internal/rulestat"
"github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig"
"github.com/AdguardTeam/AdGuardDNS/internal/websvc"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/service"
"github.com/c2h5oh/datasize"
@ -112,6 +114,8 @@ type builder struct {
ruleStat rulestat.Interface
safeBrowsing *hashprefix.Filter
safeBrowsingHashes *hashprefix.Storage
sdeConf *dnsmsg.StructuredDNSErrorsConfig
tlsMtrc tlsconfig.Metrics
webSvc *websvc.Service
// The fields below are initialized later, just like with the fields above,
@ -556,12 +560,42 @@ func (b *builder) initBindToDevice(ctx context.Context) (err error) {
return nil
}
// Constants for the experimental Structured DNS Errors feature.
//
// TODO(a.garipov): Make configurable.
const (
sdeJustification = "Filtered by AdGuard DNS"
sdeOrganization = "AdGuard DNS"
)
// Variables for the experimental Structured DNS Errors feature.
//
// TODO(a.garipov): Make configurable.
var (
sdeContactURL = &url.URL{
Scheme: "mailto",
Opaque: "support@adguard-dns.io",
}
)
// initMsgConstructor initializes the common DNS message constructor.
func (b *builder) initMsgConstructor(ctx context.Context) (err error) {
fltConf := b.conf.Filters
b.sdeConf = &dnsmsg.StructuredDNSErrorsConfig{
Contact: []*url.URL{
sdeContactURL,
},
Justification: sdeJustification,
Organization: sdeOrganization,
Enabled: fltConf.SDEEnabled,
}
b.messages, err = dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: b.cloner,
BlockingMode: &dnsmsg.BlockingModeNullIP{},
FilteredResponseTTL: b.conf.Filters.ResponseTTL.Duration,
StructuredErrors: b.sdeConf,
FilteredResponseTTL: fltConf.ResponseTTL.Duration,
EDEEnabled: fltConf.EDEEnabled,
})
if err != nil {
return fmt.Errorf("creating dns message constructor: %w", err)
@ -579,8 +613,17 @@ func (b *builder) initMsgConstructor(ctx context.Context) (err error) {
// - [builder.initFilteringGroups]
// - [builder.initMsgConstructor]
func (b *builder) initServerGroups(ctx context.Context) (err error) {
mtrc, err := metrics.NewTLSConfig(b.mtrcNamespace, b.promRegisterer)
if err != nil {
return fmt.Errorf("registering tls metrics: %w", err)
}
b.tlsMtrc = mtrc
c := b.conf
b.serverGroups, err = c.ServerGroups.toInternal(
ctx,
mtrc,
b.messages,
b.btdManager,
b.filteringGroups,
@ -671,7 +714,7 @@ func (b *builder) initTLS(ctx context.Context) (err error) {
}
}
tickRot := newTicketRotator(b.baseLogger, b.errColl, b.serverGroups)
tickRot := newTicketRotator(b.baseLogger, b.errColl, b.tlsMtrc, b.serverGroups)
err = tickRot.Refresh(ctx)
if err != nil {
return fmt.Errorf("initial session ticket refresh: %w", err)
@ -700,6 +743,29 @@ func (b *builder) initTLS(ctx context.Context) (err error) {
return nil
}
// initGRPCMetrics initializes the gRPC metrics if necessary.
func (b *builder) initGRPCMetrics(ctx context.Context) (err error) {
switch {
case
b.profilesEnabled,
b.conf.Check.RemoteKV.Type == kvModeBackend,
b.conf.RateLimit.Allowlist.Type == rlAllowlistTypeBackend:
// Go on.
default:
// Don't initialize the metrics if no protobuf backend is used.
return nil
}
b.backendGRPCMtrc, err = metrics.NewBackendPB(b.mtrcNamespace, b.promRegisterer)
if err != nil {
return fmt.Errorf("registering backendbp metrics: %w", err)
}
b.logger.DebugContext(ctx, "initialized grpc metrics")
return nil
}
// initBillStat initializes the billing-statistics recorder if necessary. It
// also adds the refresher with ID [debugIDBillStat] to the debug refreshers.
func (b *builder) initBillStat(ctx context.Context) (err error) {
@ -709,11 +775,6 @@ func (b *builder) initBillStat(ctx context.Context) (err error) {
return nil
}
b.backendGRPCMtrc, err = metrics.NewBackendPB(b.mtrcNamespace, b.promRegisterer)
if err != nil {
return fmt.Errorf("registering backendbp metrics: %w", err)
}
upl, err := newBillStatUploader(b.env, b.errColl, b.backendGRPCMtrc)
if err != nil {
return fmt.Errorf("creating billstat uploader: %w", err)
@ -760,9 +821,9 @@ func (b *builder) initBillStat(ctx context.Context) (err error) {
// initProfileDB initializes the profile database if necessary.
//
// [builder.initBillStat] and [builder.initServerGroups] must be called before
// this method. It also adds the refresher with ID [debugIDProfileDB] to the
// debug refreshers.
// [builder.initGRPCMetrics] and [builder.initServerGroups] must be called
// before this method. It also adds the refresher with ID [debugIDProfileDB] to
// the debug refreshers.
func (b *builder) initProfileDB(ctx context.Context) (err error) {
if !b.profilesEnabled {
b.profileDB = &profiledb.Disabled{}
@ -771,8 +832,9 @@ func (b *builder) initProfileDB(ctx context.Context) (err error) {
}
apiURL := netutil.CloneURL(&b.env.ProfilesURL.URL)
if !agdhttp.CheckGRPCURLScheme(apiURL.Scheme) {
return fmt.Errorf("invalid backend api url: %s", apiURL)
err = urlutil.ValidateGRPCURL(apiURL)
if err != nil {
return fmt.Errorf("profile api url: %w", err)
}
respSzEst := b.conf.RateLimit.ResponseSizeEstimate
@ -842,7 +904,8 @@ func (b *builder) initProfileDB(ctx context.Context) (err error) {
// initDNSCheck initializes the DNS checker.
//
// [builder.initMsgConstructor] must be called before this method.
// [builder.initGRPCMetrics] and [builder.initMsgConstructor] must be called
// before this method.
func (b *builder) initDNSCheck(ctx context.Context) (err error) {
b.dnsCheck = b.plugins.DNSCheck()
if b.dnsCheck != nil {
@ -853,7 +916,14 @@ func (b *builder) initDNSCheck(ctx context.Context) (err error) {
c := b.conf.Check
checkConf, err := c.toInternal(b.env, b.messages, b.errColl, b.mtrcNamespace, b.promRegisterer)
checkConf, err := c.toInternal(
b.env,
b.messages,
b.errColl,
b.mtrcNamespace,
b.promRegisterer,
b.backendGRPCMtrc,
)
if err != nil {
return fmt.Errorf("initializing dnscheck: %w", err)
}
@ -911,18 +981,44 @@ func (b *builder) initRuleStat(ctx context.Context) (err error) {
// well as starts and registers the rate-limiter refresher in the signal
// handler. It also adds the refresher with ID [debugIDAllowlist] to the debug
// refreshers.
//
// [builder.initGRPCMetrics] must be called before this method.
func (b *builder) initRateLimiter(ctx context.Context) (err error) {
c := b.conf.RateLimit
allowSubnets := netutil.UnembedPrefixes(c.Allowlist.List)
allowlist := ratelimit.NewDynamicAllowlist(allowSubnets, nil)
updater := consul.NewAllowlistUpdater(&consul.AllowlistUpdaterConfig{
typ := b.conf.RateLimit.Allowlist.Type
mtrc, err := metrics.NewAllowlist(b.mtrcNamespace, b.promRegisterer, typ)
if err != nil {
return fmt.Errorf("ratelimit metrics: %w", err)
}
var updater agdservice.Refresher
if typ == rlAllowlistTypeBackend {
updater, err = backendpb.NewRateLimiter(&backendpb.RateLimiterConfig{
Logger: b.baseLogger.With(slogutil.KeyPrefix, "backend_ratelimiter"),
Metrics: mtrc,
GRPCMetrics: b.backendGRPCMtrc,
Allowlist: allowlist,
Endpoint: &b.env.BackendRateLimitURL.URL,
ErrColl: b.errColl,
APIKey: b.env.BackendRateLimitAPIKey,
})
if err != nil {
return fmt.Errorf("ratelimit: %w", err)
}
} else {
updater = consul.NewAllowlistUpdater(&consul.AllowlistUpdaterConfig{
Logger: b.baseLogger.With(slogutil.KeyPrefix, "ratelimit_allowlist_updater"),
Allowlist: allowlist,
ConsulURL: &b.env.ConsulAllowlistURL.URL,
ErrColl: b.errColl,
Metrics: mtrc,
// TODO(a.garipov): Make configurable.
Timeout: 15 * time.Second,
})
}
err = updater.Refresh(ctx)
if err != nil {
@ -956,9 +1052,11 @@ func (b *builder) initRateLimiter(ctx context.Context) (err error) {
// initWeb initializes the web service, starts it, and registers it in the
// signal handler.
//
// [builder.initServerGroups] must be called before this method.
func (b *builder) initWeb(ctx context.Context) (err error) {
c := b.conf.Web
webConf, err := c.toInternal(b.env, b.dnsCheck, b.errColl)
webConf, err := c.toInternal(ctx, b.env, b.dnsCheck, b.errColl, b.tlsMtrc)
if err != nil {
return fmt.Errorf("converting web configuration: %w", err)
}
@ -1057,45 +1155,55 @@ func (b *builder) initDNS(ctx context.Context) (err error) {
b.fwdHandler = forward.NewHandler(b.conf.Upstream.toInternal(b.baseLogger))
b.dnsDB = b.conf.DNSDB.toInternal(b.errColl)
cacheConf := b.conf.Cache
dnsConf := &dnssvc.Config{
dnsHdlrsConf := &dnssvc.HandlersConfig{
BaseLogger: b.baseLogger,
Messages: b.messages,
Cache: b.conf.Cache.toInternal(),
Cloner: b.cloner,
ControlConf: b.controlConf,
ConnLimiter: b.connLimit,
HumanIDParser: agd.NewHumanIDParser(),
Messages: b.messages,
PluginRegistry: b.plugins,
StructuredErrors: b.sdeConf,
AccessManager: b.access,
SafeBrowsing: b.hashMatcher,
BillStat: b.billStat,
CacheManager: b.cacheManager,
ProfileDB: b.profileDB,
PrometheusRegisterer: b.promRegisterer,
DNSCheck: b.dnsCheck,
NonDNS: b.webSvc,
DNSDB: b.dnsDB,
ErrColl: b.errColl,
FilterStorage: b.filterStorage,
GeoIP: b.geoIP,
Handler: b.fwdHandler,
HashMatcher: b.hashMatcher,
ProfileDB: b.profileDB,
PrometheusRegisterer: b.promRegisterer,
QueryLog: b.queryLog(),
RuleStat: b.ruleStat,
RateLimit: b.rateLimit,
RuleStat: b.ruleStat,
MetricsNamespace: b.mtrcNamespace,
FilteringGroups: b.filteringGroups,
ServerGroups: b.serverGroups,
EDEEnabled: b.conf.Filters.EDEEnabled,
}
dnsHdlrs, err := dnssvc.NewHandlers(ctx, dnsHdlrsConf)
if err != nil {
return fmt.Errorf("dns handlers: %w", err)
}
dnsConf := &dnssvc.Config{
Handlers: dnsHdlrs,
Cloner: b.cloner,
ControlConf: b.controlConf,
ConnLimiter: b.connLimit,
NonDNS: b.webSvc,
ErrColl: b.errColl,
MetricsNamespace: b.mtrcNamespace,
ServerGroups: b.serverGroups,
HandleTimeout: b.conf.DNS.HandleTimeout.Duration,
CacheSize: cacheConf.Size,
ECSCacheSize: cacheConf.ECSSize,
CacheMinTTL: cacheConf.TTLOverride.Min.Duration,
UseCacheTTLOverride: cacheConf.TTLOverride.Enabled,
UseECSCache: cacheConf.Type == cacheTypeECS,
}
b.dnsSvc, err = dnssvc.New(dnsConf)
if err != nil {
return fmt.Errorf("initializing dns: %w", err)
return fmt.Errorf("dns service: %w", err)
}
b.logger.DebugContext(ctx, "initialized dns")

View File

@ -3,6 +3,7 @@ package cmd
import (
"fmt"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/timeutil"
)
@ -43,6 +44,28 @@ const (
cacheTypeSimple = "simple"
)
// toInternal converts c to the cache configuration for the DNS server. c must
// be valid.
func (c *cacheConfig) toInternal() (cacheConf *dnssvc.CacheConfig) {
var typ dnssvc.CacheType
if c.Size == 0 {
// TODO(a.garipov): Add as a type in the configuration file.
typ = dnssvc.CacheTypeNone
} else if c.Type == cacheTypeSimple {
typ = dnssvc.CacheTypeSimple
} else {
typ = dnssvc.CacheTypeECS
}
return &dnssvc.CacheConfig{
MinTTL: c.TTLOverride.Min.Duration,
ECSCount: c.ECSSize,
NoECSCount: c.Size,
Type: typ,
OverrideCacheTTL: c.TTLOverride.Enabled,
}
}
// type check
var _ validator = (*cacheConfig)(nil)

View File

@ -7,6 +7,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
@ -54,8 +55,9 @@ func (c *checkConfig) toInternal(
errColl errcoll.Interface,
namespace string,
reg prometheus.Registerer,
backendMtrc backendpb.Metrics,
) (conf *dnscheck.RemoteKVConfig, err error) {
kv, err := newDNSCheckKV(c, envs, namespace, reg)
kv, err := newRemoteKV(c.RemoteKV, envs, namespace, reg, backendMtrc)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
@ -85,21 +87,33 @@ const maxRespSize = 1 * datasize.MB
// [remotekv.KeyNamespace].
const keyNamespaceCheck = "check"
// newDNSCheckKV returns a new properly initialized remote key-value storage.
func newDNSCheckKV(
conf *checkConfig,
// newRemoteKV returns a new properly initialized remote key-value storage.
func newRemoteKV(
c *remoteKVConfig,
envs *environment,
namespace string,
reg prometheus.Registerer,
backendMtrc backendpb.Metrics,
) (kv remotekv.Interface, err error) {
if conf.RemoteKV.Type == kvModeRedis {
switch c.Type {
case kvModeBackend:
kv, err = backendpb.NewRemoteKV(&backendpb.RemoteKVConfig{
Metrics: backendMtrc,
Endpoint: &envs.DNSCheckRemoteKVURL.URL,
APIKey: envs.DNSCheckRemoteKVAPIKey,
TTL: c.TTL.Duration,
})
if err != nil {
return nil, fmt.Errorf("initializing backend dnscheck kv: %w", err)
}
case kvModeRedis:
var redisKVMtrc rediskv.Metrics
redisKVMtrc, err = metrics.NewRedisKV(namespace, reg)
if err != nil {
return nil, fmt.Errorf("registering redis kv metrics: %w", err)
}
kv := rediskv.NewRedisKV(&rediskv.RedisKVConfig{
redisKV := rediskv.NewRedisKV(&rediskv.RedisKVConfig{
Metrics: redisKVMtrc,
Addr: &netutil.HostPort{
Host: envs.RedisAddr,
@ -108,18 +122,20 @@ func newDNSCheckKV(
MaxActive: envs.RedisMaxActive,
MaxIdle: envs.RedisMaxIdle,
IdleTimeout: envs.RedisIdleTimeout.Duration,
TTL: conf.RemoteKV.TTL.Duration,
TTL: c.TTL.Duration,
})
return remotekv.NewKeyNamespace(&remotekv.KeyNamespaceConfig{
KV: kv,
kv = remotekv.NewKeyNamespace(&remotekv.KeyNamespaceConfig{
KV: redisKV,
Prefix: fmt.Sprintf("%s:%s:", envs.RedisKeyPrefix, keyNamespaceCheck),
}), nil
}
})
case kvModeConsul:
consulKVURL := envs.ConsulDNSCheckKVURL
consulSessionURL := envs.ConsulDNSCheckSessionURL
if consulKVURL != nil && consulSessionURL != nil {
if consulKVURL == nil || consulSessionURL == nil {
return remotekv.Empty{}, nil
}
kv, err = consulkv.NewKV(&consulkv.Config{
URL: &consulKVURL.URL,
SessionURL: &consulSessionURL.URL,
@ -129,14 +145,14 @@ func newDNSCheckKV(
}),
// TODO(ameshkov): Consider making configurable.
Limiter: rate.NewLimiter(rate.Limit(200)/60, 1),
TTL: conf.RemoteKV.TTL.Duration,
TTL: c.TTL.Duration,
MaxRespSize: maxRespSize,
})
if err != nil {
return nil, fmt.Errorf("initializing consul dnscheck: %w", err)
return nil, fmt.Errorf("initializing consul dnscheck kv: %w", err)
}
} else {
kv = remotekv.Empty{}
default:
return remotekv.Empty{}, nil
}
return kv, nil
@ -221,6 +237,7 @@ func validateNonNilIPs(ips []netip.Addr, fam netutil.AddrFamily) (err error) {
// DNSCheck key-value database modes.
const (
kvModeBackend = "backend"
kvModeConsul = "consul"
kvModeRedis = "redis"
)
@ -229,7 +246,7 @@ const (
// checking.
type remoteKVConfig struct {
// Type defines the type of remote key-value store. Allowed values are
// [kvModeConsul] and [kvModeRedis].
// [kvModeBackend], [kvModeConsul] and [kvModeRedis].
Type string `yaml:"type"`
// TTL defines, for how long to keep the information about a single client.
@ -248,6 +265,10 @@ func (c *remoteKVConfig) validate() (err error) {
ttl := c.TTL
switch c.Type {
case kvModeBackend:
if ttl.Duration <= 0 {
return newNotPositiveError("ttl", ttl)
}
case kvModeConsul:
if ttl.Duration < consulkv.MinTTL || ttl.Duration > consulkv.MaxTTL {
return fmt.Errorf(
@ -270,7 +291,7 @@ func (c *remoteKVConfig) validate() (err error) {
case "":
return fmt.Errorf("type: %w", errors.ErrEmptyValue)
default:
return fmt.Errorf("type: %q: %w", c.Type, errors.ErrBadEnumValue)
return fmt.Errorf("type: %w: %q", errors.ErrBadEnumValue, c.Type)
}
return nil

View File

@ -100,6 +100,8 @@ func Main(plugins *plugin.Registry) {
errors.Check(b.initTLS(ctx))
errors.Check(b.initGRPCMetrics(ctx))
errors.Check(b.initBillStat(ctx))
errors.Check(b.initProfileDB(ctx))

View File

@ -6,9 +6,10 @@ import (
"math"
"net"
"net/http"
"net/url"
"os"
"strings"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/debugsvc"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
@ -25,13 +26,17 @@ import (
)
// environment represents the configuration that is kept in the environment.
//
// TODO(e.burkov, a.garipov): Name variables more consistently.
type environment struct {
AdultBlockingURL *urlutil.URL `env:"ADULT_BLOCKING_URL"`
BackendRateLimitURL *urlutil.URL `env:"BACKEND_RATELIMIT_URL"`
BillStatURL *urlutil.URL `env:"BILLSTAT_URL"`
BlockedServiceIndexURL *urlutil.URL `env:"BLOCKED_SERVICE_INDEX_URL"`
ConsulAllowlistURL *urlutil.URL `env:"CONSUL_ALLOWLIST_URL,notEmpty"`
ConsulAllowlistURL *urlutil.URL `env:"CONSUL_ALLOWLIST_URL"`
ConsulDNSCheckKVURL *urlutil.URL `env:"CONSUL_DNSCHECK_KV_URL"`
ConsulDNSCheckSessionURL *urlutil.URL `env:"CONSUL_DNSCHECK_SESSION_URL"`
DNSCheckRemoteKVURL *urlutil.URL `env:"DNSCHECK_REMOTEKV_URL"`
FilterIndexURL *urlutil.URL `env:"FILTER_INDEX_URL,notEmpty"`
GeneralSafeSearchURL *urlutil.URL `env:"GENERAL_SAFE_SEARCH_URL"`
LinkedIPTargetURL *urlutil.URL `env:"LINKED_IP_TARGET_URL"`
@ -41,8 +46,10 @@ type environment struct {
SafeBrowsingURL *urlutil.URL `env:"SAFE_BROWSING_URL"`
YoutubeSafeSearchURL *urlutil.URL `env:"YOUTUBE_SAFE_SEARCH_URL"`
BackendRateLimitAPIKey string `env:"BACKEND_RATELIMIT_API_KEY"`
BillStatAPIKey string `env:"BILLSTAT_API_KEY"`
ConfPath string `env:"CONFIG_PATH" envDefault:"./config.yaml"`
DNSCheckRemoteKVAPIKey string `env:"DNSCHECK_REMOTEKV_API_KEY"`
FilterCachePath string `env:"FILTER_CACHE_PATH" envDefault:"./filters/"`
GeoIPASNPath string `env:"GEOIP_ASN_PATH" envDefault:"./asn.mmdb"`
GeoIPCountryPath string `env:"GEOIP_COUNTRY_PATH" envDefault:"./country.mmdb"`
@ -57,7 +64,7 @@ type environment struct {
ListenAddr net.IP `env:"LISTEN_ADDR" envDefault:"127.0.0.1"`
ProfilesMaxRespSize datasize.ByteSize `env:"PROFILES_MAX_RESP_SIZE" envDefault:"8MB"`
ProfilesMaxRespSize datasize.ByteSize `env:"PROFILES_MAX_RESP_SIZE" envDefault:"64MB"`
RedisIdleTimeout timeutil.Duration `env:"REDIS_IDLE_TIMEOUT" envDefault:"30s"`
@ -100,7 +107,8 @@ func (envs *environment) validate() (err error) {
errs = envs.validateHTTPURLs(errs)
if s := envs.FilterIndexURL.Scheme; s != agdhttp.SchemeFile && !agdhttp.CheckHTTPURLScheme(s) {
if s := envs.FilterIndexURL.Scheme; !strings.EqualFold(s, urlutil.SchemeFile) &&
!urlutil.IsValidHTTPURLScheme(s) {
errs = append(errs, fmt.Errorf(
"env %s: not a valid http(s) url or file uri",
"FILTER_INDEX_URL",
@ -140,10 +148,6 @@ func (envs *environment) validateHTTPURLs(errs []error) (res []error) {
url: envs.BlockedServiceIndexURL,
name: "BLOCKED_SERVICE_INDEX_URL",
isRequired: bool(envs.BlockedServiceEnabled),
}, {
url: envs.ConsulAllowlistURL,
name: "CONSUL_ALLOWLIST_URL",
isRequired: true,
}, {
url: envs.ConsulDNSCheckKVURL,
name: "CONSUL_DNSCHECK_KV_URL",
@ -184,15 +188,14 @@ func (envs *environment) validateHTTPURLs(errs []error) (res []error) {
continue
}
u := urlData.url
if u == nil {
res = append(res, fmt.Errorf("env %s: %w", urlData.name, errors.ErrEmptyValue))
continue
var u *url.URL
if urlData.url != nil {
u = &urlData.url.URL
}
if !agdhttp.CheckHTTPURLScheme(u.Scheme) {
res = append(res, fmt.Errorf("env %s: not a valid http(s) url", urlData.name))
err := urlutil.ValidateHTTPURL(u)
if err != nil {
res = append(res, fmt.Errorf("env %s: %w", urlData.name, err))
}
}
@ -227,59 +230,86 @@ func (envs *environment) validateWebStaticDir() (err error) {
// validateFromValidConfig returns an error if environment variables that depend
// on configuration properties contain errors. conf is expected to be valid.
func (envs *environment) validateFromValidConfig(conf *configuration) (err error) {
err = envs.validateRedis(conf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
var errs []error
switch typ := conf.Check.RemoteKV.Type; typ {
case kvModeRedis:
errs = envs.validateRedis(errs)
case kvModeBackend:
errs = envs.validateBackendKV(errs)
default:
// Probably consul.
}
if !conf.isProfilesEnabled() {
return nil
}
if conf.isProfilesEnabled() {
errs = envs.validateProfilesURLs(errs)
if envs.ProfilesMaxRespSize > math.MaxInt {
return fmt.Errorf(
errs = append(errs, fmt.Errorf(
"PROFILES_MAX_RESP_SIZE: %w: must be less than or equal to %s, got %s",
errors.ErrOutOfRange,
datasize.ByteSize(math.MaxInt),
envs.ProfilesMaxRespSize,
)
))
}
}
return envs.validateProfilesURLs()
}
// validateRedis returns an error if environment variables for Redis as a remote
// key-value store for DNS server checking contain errors.
func (envs *environment) validateRedis(conf *configuration) (err error) {
if conf.Check.RemoteKV.Type != kvModeRedis {
return nil
}
var errs []error
if envs.RedisAddr == "" {
errs = append(errs, fmt.Errorf("REDIS_ADDR: %q", errors.ErrEmptyValue))
}
if envs.RedisIdleTimeout.Duration <= 0 {
errs = append(errs, newNotPositiveError("REDIS_IDLE_TIMEOUT", envs.RedisIdleTimeout))
}
if envs.RedisMaxActive < 0 {
errs = append(errs, newNegativeError("REDIS_MAX_ACTIVE", envs.RedisMaxActive))
}
if envs.RedisMaxIdle < 0 {
errs = append(errs, newNegativeError("REDIS_MAX_IDLE", envs.RedisMaxIdle))
}
errs = envs.validateRateLimitURLs(conf, errs)
return errors.Join(errs...)
}
// validateRedis appends validation errors to the given errs if environment
// variables for Redis contain errors.
func (envs *environment) validateRedis(errs []error) (withRedis []error) {
withRedis = errs
if envs.RedisAddr == "" {
err := fmt.Errorf("REDIS_ADDR: %q", errors.ErrEmptyValue)
withRedis = append(withRedis, err)
}
if envs.RedisIdleTimeout.Duration <= 0 {
err := newNotPositiveError("REDIS_IDLE_TIMEOUT", envs.RedisIdleTimeout)
withRedis = append(withRedis, err)
}
if envs.RedisMaxActive < 0 {
err := newNegativeError("REDIS_MAX_ACTIVE", envs.RedisMaxActive)
withRedis = append(withRedis, err)
}
if envs.RedisMaxIdle < 0 {
err := newNegativeError("REDIS_MAX_IDLE", envs.RedisMaxIdle)
withRedis = append(withRedis, err)
}
return withRedis
}
// validateBackendKV appends validation errors to the given errs if environment
// variables for a backend key-value store contain errors.
func (envs *environment) validateBackendKV(errs []error) (withKV []error) {
withKV = errs
var u *url.URL
if envs.DNSCheckRemoteKVURL != nil {
u = &envs.DNSCheckRemoteKVURL.URL
}
err := urlutil.ValidateGRPCURL(u)
if err != nil {
withKV = append(withKV, fmt.Errorf("env DNSCHECK_REMOTEKV_URL: %w", err))
}
return withKV
}
// validateProfilesURLs appends validation errors to the given errs if profiles
// URLs in environment variables are invalid. All errors are appended to errs
// and returned as res.
func (envs *environment) validateProfilesURLs() (err error) {
// URLs in environment variables are invalid.
func (envs *environment) validateProfilesURLs(errs []error) (withURLs []error) {
withURLs = errs
grpcOnlyURLs := []*urlEnvData{{
url: envs.BillStatURL,
name: "BILLSTAT_URL",
@ -290,24 +320,52 @@ func (envs *environment) validateProfilesURLs() (err error) {
isRequired: true,
}}
var res []error
for _, urlData := range grpcOnlyURLs {
if !urlData.isRequired {
continue
}
if urlData.url == nil {
res = append(res, fmt.Errorf("env %s: %w", urlData.name, errors.ErrEmptyValue))
continue
var u *url.URL
if urlData.url != nil {
u = &urlData.url.URL
}
if !agdhttp.CheckGRPCURLScheme(urlData.url.Scheme) {
res = append(res, fmt.Errorf("env %s: not a valid grpc(s) url", urlData.name))
err := urlutil.ValidateGRPCURL(u)
if err != nil {
withURLs = append(withURLs, fmt.Errorf("env %s: %w", urlData.name, err))
}
}
return errors.Join(res...)
return withURLs
}
// validateRateLimitURLs appends validation errors to the given errs if rate
// limit URLs in environment variables are invalid.
func (envs *environment) validateRateLimitURLs(
conf *configuration,
errs []error,
) (withURLs []error) {
rlURL := envs.BackendRateLimitURL
rlEnv := "BACKEND_RATELIMIT_URL"
validateFunc := urlutil.ValidateGRPCURL
if conf.RateLimit.Allowlist.Type == rlAllowlistTypeConsul {
rlURL = envs.ConsulAllowlistURL
rlEnv = "CONSUL_ALLOWLIST_URL"
validateFunc = urlutil.ValidateHTTPURL
}
var u *url.URL
if rlURL != nil {
u = &rlURL.URL
}
err := validateFunc(u)
if err != nil {
return append(errs, fmt.Errorf("env %s: %w", rlEnv, err))
}
return errs
}
// configureLogs sets the configuration for the plain text logs. It also

View File

@ -58,6 +58,12 @@ type filtersConfig struct {
// MaxSize is the maximum size of the downloadable filtering rule-list.
MaxSize datasize.ByteSize `yaml:"max_size"`
// EDEEnabled enables the Extended DNS Errors feature.
EDEEnabled bool `yaml:"ede_enabled"`
// SDEEnabled enables the experimental Structured DNS Errors feature.
SDEEnabled bool `yaml:"sde_enabled"`
}
// toInternal converts c to the filter storage configuration for the DNS server.
@ -117,33 +123,31 @@ var _ validator = (*filtersConfig)(nil)
// validate implements the [validator] interface for *filtersConfig.
func (c *filtersConfig) validate() (err error) {
switch {
case c == nil:
if c == nil {
return errors.ErrNoValue
case c.SafeSearchCacheSize <= 0:
return newNotPositiveError("safe_search_cache_size", c.SafeSearchCacheSize)
case c.ResponseTTL.Duration <= 0:
return newNotPositiveError("response_ttl", c.ResponseTTL)
case c.RefreshIvl.Duration <= 0:
return newNotPositiveError("refresh_interval", c.RefreshIvl)
case c.RefreshTimeout.Duration <= 0:
return newNotPositiveError("refresh_timeout", c.RefreshTimeout)
case c.IndexRefreshTimeout.Duration <= 0:
return newNotPositiveError("index_refresh_timeout", c.IndexRefreshTimeout)
case c.RuleListRefreshTimeout.Duration <= 0:
return newNotPositiveError("rule_list_refresh_timeout", c.RuleListRefreshTimeout)
case c.MaxSize <= 0:
return newNotPositiveError("max_size", c.MaxSize)
default:
// Go on.
}
errs := []error{
validatePositive("custom_filter_cache_size", c.CustomFilterCacheSize),
validatePositive("safe_search_cache_size", c.SafeSearchCacheSize),
validatePositive("response_ttl", c.ResponseTTL),
validatePositive("refresh_interval", c.RefreshIvl),
validatePositive("refresh_timeout", c.RefreshTimeout),
validatePositive("index_refresh_timeout", c.IndexRefreshTimeout),
validatePositive("rule_list_refresh_timeout", c.RuleListRefreshTimeout),
validatePositive("max_size", c.MaxSize),
}
if !c.EDEEnabled && c.SDEEnabled {
errs = append(errs, errors.Error("ede must be enabled to enable sde"))
}
err = c.RuleListCache.validate()
if err != nil {
return fmt.Errorf("rule_list_cache: %w", err)
errs = append(errs, fmt.Errorf("rule_list_cache: %w", err))
}
return nil
return errors.Join(errs...)
}
// fltRuleListCache contains filtering rule-list cache configuration.

View File

@ -4,6 +4,7 @@ package plugin
import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
)
@ -13,16 +14,19 @@ import (
type Registry struct {
dnscheck dnscheck.Interface
mainMwMtrc metrics.MainMiddleware
postInitMw dnsserver.Middleware
}
// NewRegistry returns a new registry with the given custom implementations.
func NewRegistry(
dnsCk dnscheck.Interface,
mainMwMtrc metrics.MainMiddleware,
postInitMw dnsserver.Middleware,
) (r *Registry) {
return &Registry{
dnscheck: dnsCk,
mainMwMtrc: mainMwMtrc,
postInitMw: postInitMw,
}
}
@ -44,3 +48,13 @@ func (r *Registry) MainMiddlewareMetrics() (mainMwMtrc metrics.MainMiddleware) {
return r.mainMwMtrc
}
// PostInitialMiddleware returns a custom implementation of the post-initial
// middleware, if any.
func (r *Registry) PostInitialMiddleware() (postInitMw dnsserver.Middleware) {
if r == nil {
return nil
}
return r.postInitMw
}

View File

@ -15,6 +15,12 @@ import (
"github.com/c2h5oh/datasize"
)
// Constants for rate limit settings endpoints.
const (
rlAllowlistTypeBackend = "backend"
rlAllowlistTypeConsul = "consul"
)
// rateLimitConfig is the configuration of the instance's rate limiting.
type rateLimitConfig struct {
// AllowList is the allowlist of clients.
@ -57,20 +63,14 @@ type rateLimitConfig struct {
RefuseANY bool `yaml:"refuse_any"`
}
// allowListConfig is the consul allow list configuration.
type allowListConfig struct {
// List contains IPs and CIDRs.
List []netutil.Prefix `yaml:"list"`
// RefreshIvl time between two updates of allow list from the Consul URL.
RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
}
// rateLimitOptions allows define maximum number of requests for IPv4 or IPv6
// addresses.
type rateLimitOptions struct {
// RPS is the maximum number of requests per second.
RPS uint `yaml:"rps"`
// Count is the maximum number of requests per interval.
Count uint `yaml:"count"`
// Interval is the time during which to count the number of requests.
Interval timeutil.Duration `yaml:"interval"`
// SubnetKeyLen is the length of the subnet prefix used to calculate
// rate limiter bucket keys.
@ -87,7 +87,8 @@ func (o *rateLimitOptions) validate() (err error) {
}
return cmp.Or(
validatePositive("rps", o.RPS),
validatePositive("count", o.Count),
validatePositive("interval", o.Interval),
validatePositive("subnet_key_len", o.SubnetKeyLen),
)
}
@ -100,9 +101,11 @@ func (c *rateLimitConfig) toInternal(al ratelimit.Allowlist) (conf *ratelimit.Ba
ResponseSizeEstimate: c.ResponseSizeEstimate,
Duration: c.BackoffDuration.Duration,
Period: c.BackoffPeriod.Duration,
IPv4RPS: c.IPv4.RPS,
IPv4Count: c.IPv4.Count,
IPv4Interval: c.IPv4.Interval.Duration,
IPv4SubnetKeyLen: c.IPv4.SubnetKeyLen,
IPv6RPS: c.IPv6.RPS,
IPv6Count: c.IPv6.Count,
IPv6Interval: c.IPv6.Interval.Duration,
IPv6SubnetKeyLen: c.IPv6.SubnetKeyLen,
Count: c.BackoffCount,
RefuseANY: c.RefuseANY,
@ -114,14 +117,12 @@ var _ validator = (*rateLimitConfig)(nil)
// validate implements the [validator] interface for *rateLimitConfig.
func (c *rateLimitConfig) validate() (err error) {
switch {
case c == nil:
if c == nil {
return errors.ErrNoValue
case c.Allowlist == nil:
return fmt.Errorf("allowlist: %w", errors.ErrNoValue)
}
return cmp.Or(
validateProp("allowlist", c.Allowlist.validate),
validateProp("connection_limit", c.ConnectionLimit.validate),
validateProp("ipv4", c.IPv4.validate),
validateProp("ipv6", c.IPv6.validate),
@ -131,10 +132,41 @@ func (c *rateLimitConfig) validate() (err error) {
validatePositive("backoff_duration", c.BackoffDuration),
validatePositive("backoff_period", c.BackoffPeriod),
validatePositive("response_size_estimate", c.ResponseSizeEstimate),
validatePositive("allowlist.refresh_interval", c.Allowlist.RefreshIvl),
)
}
// allowListConfig is the consul allow list configuration.
type allowListConfig struct {
// Type defines where the rate limit settings are received from. Allowed
// values are [rlAllowlistTypeBackend] and [rlAllowlistTypeConsul].
Type string `yaml:"type"`
// List contains IPs and CIDRs.
List []netutil.Prefix `yaml:"list"`
// RefreshIvl time between two updates of allow list from the Consul URL.
RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
}
// type check
var _ validator = (*allowListConfig)(nil)
// validate implements the [validator] interface for *allowListConfig.
func (c *allowListConfig) validate() (err error) {
if c == nil {
return errors.ErrNoValue
}
switch c.Type {
case rlAllowlistTypeBackend, rlAllowlistTypeConsul:
// Go on.
default:
return fmt.Errorf("type: %w: %q", errors.ErrBadEnumValue, c.Type)
}
return validatePositive("refresh_interval", c.RefreshIvl)
}
// connLimitConfig is the configuration structure for the stream-connection
// limiter.
type connLimitConfig struct {

View File

@ -6,7 +6,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors"
)
@ -14,6 +14,7 @@ import (
// toInternal returns the configuration of DNS servers for a single server
// group. srvs and other parts of the configuration must be valid.
func (srvs servers) toInternal(
mtrc tlsconfig.Metrics,
tlsConfig *agd.TLS,
btdMgr *bindtodevice.Manager,
ratelimitConf *rateLimitConfig,
@ -68,8 +69,8 @@ func (srvs servers) toInternal(
tlsConf := tlsConfig.Conf.Clone()
// Attach the functions that will count TLS handshake metrics.
tlsConf.GetConfigForClient = metrics.TLSMetricsBeforeHandshake(string(srv.Protocol))
tlsConf.VerifyConnection = metrics.TLSMetricsAfterHandshake(
tlsConf.GetConfigForClient = mtrc.BeforeHandshake(string(srv.Protocol))
tlsConf.VerifyConnection = mtrc.AfterHandshake(
string(srv.Protocol),
srv.Name,
tlsConfig.DeviceDomains,

View File

@ -1,11 +1,13 @@
package cmd
import (
"context"
"fmt"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors"
)
@ -17,6 +19,8 @@ type serverGroups []*serverGroup
// toInternal returns the configuration for all server groups in the DNS
// service. srvGrps and other parts of the configuration must be valid.
func (srvGrps serverGroups) toInternal(
ctx context.Context,
mtrc tlsconfig.Metrics,
messages *dnsmsg.Constructor,
btdMgr *bindtodevice.Manager,
fltGrps map[agd.FilteringGroupID]*agd.FilteringGroup,
@ -32,7 +36,7 @@ func (srvGrps serverGroups) toInternal(
}
var tlsConf *agd.TLS
tlsConf, err = g.TLS.toInternal()
tlsConf, err = g.TLS.toInternal(ctx, mtrc)
if err != nil {
return nil, fmt.Errorf("tls: %w", err)
}
@ -45,7 +49,13 @@ func (srvGrps serverGroups) toInternal(
ProfilesEnabled: g.ProfilesEnabled,
}
svcSrvGrps[i].Servers, err = g.Servers.toInternal(tlsConf, btdMgr, ratelimitConf, dnsConf)
svcSrvGrps[i].Servers, err = g.Servers.toInternal(
mtrc,
tlsConf,
btdMgr,
ratelimitConf,
dnsConf,
)
if err != nil {
return nil, fmt.Errorf("server group %q: %w", g.Name, err)
}

View File

@ -10,7 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdservice"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
@ -19,6 +19,7 @@ import (
type ticketRotator struct {
logger *slog.Logger
errColl errcoll.Interface
mtrc tlsconfig.Metrics
confs map[*tls.Config][]string
}
@ -29,6 +30,7 @@ type ticketRotator struct {
func newTicketRotator(
logger *slog.Logger,
errColl errcoll.Interface,
mtrc tlsconfig.Metrics,
grps []*agd.ServerGroup,
) (tr *ticketRotator) {
confs := map[*tls.Config][]string{}
@ -49,6 +51,7 @@ func newTicketRotator(
return &ticketRotator{
logger: logger.With(slogutil.KeyPrefix, "tickrot"),
errColl: errColl,
mtrc: mtrc,
confs: confs,
}
}
@ -81,7 +84,7 @@ func (r *ticketRotator) Refresh(ctx context.Context) (err error) {
var key [sessTickLen]byte
key, err = readSessionTicketKey(fileName)
if err != nil {
metrics.TLSSessionTicketsRotateStatus.Set(0)
r.mtrc.SetSessionTicketRotationStatus(ctx, false)
return fmt.Errorf("session ticket for srv %s: %w", conf.ServerName, err)
}
@ -96,8 +99,7 @@ func (r *ticketRotator) Refresh(ctx context.Context) (err error) {
conf.SetSessionTicketKeys(keys)
}
metrics.TLSSessionTicketsRotateStatus.Set(1)
metrics.TLSSessionTicketsRotateTime.SetToCurrentTime()
r.mtrc.SetSessionTicketRotationStatus(ctx, true)
return nil
}

View File

@ -1,6 +1,7 @@
package cmd
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
@ -9,10 +10,9 @@ import (
"strings"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors"
"github.com/prometheus/client_golang/prometheus"
)
// tlsConfig are the TLS settings of a DNS server, if any.
@ -37,12 +37,15 @@ type tlsConfig struct {
// toInternal converts c to the TLS configuration for a DNS server. c must be
// valid.
func (c *tlsConfig) toInternal() (conf *agd.TLS, err error) {
func (c *tlsConfig) toInternal(
ctx context.Context,
mtrc tlsconfig.Metrics,
) (conf *agd.TLS, err error) {
if c == nil {
return nil, nil
}
tlsConf, err := c.Certificates.toInternal()
tlsConf, err := c.Certificates.toInternal(ctx, mtrc)
if err != nil {
return nil, fmt.Errorf("certificates: %w", err)
}
@ -123,7 +126,10 @@ type tlsConfigCert struct {
type tlsConfigCerts []*tlsConfigCert
// toInternal converts certs to a TLS configuration. certs must be valid.
func (certs tlsConfigCerts) toInternal() (conf *tls.Config, err error) {
func (certs tlsConfigCerts) toInternal(
ctx context.Context,
mtrc tlsconfig.Metrics,
) (conf *tls.Config, err error) {
if len(certs) == 0 {
return nil, nil
}
@ -146,13 +152,8 @@ func (certs tlsConfigCerts) toInternal() (conf *tls.Config, err error) {
tlsCerts[i] = cert
authAlgo, subj := leaf.PublicKeyAlgorithm.String(), leaf.Subject.String()
metrics.TLSCertificateInfo.With(prometheus.Labels{
"auth_algo": authAlgo,
"subject": subj,
}).Set(1)
metrics.TLSCertificateNotAfter.With(prometheus.Labels{
"subject": subj,
}).Set(float64(leaf.NotAfter.Unix()))
mtrc.SetCertificateInfo(ctx, authAlgo, subj, leaf.NotAfter)
}
return &tls.Config{

View File

@ -1,6 +1,7 @@
package cmd
import (
"context"
"encoding/base64"
"fmt"
"maps"
@ -13,6 +14,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
"github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig"
"github.com/AdguardTeam/AdGuardDNS/internal/websvc"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/httphdr"
@ -62,9 +64,11 @@ type webConfig struct {
// toInternal converts c to the AdGuardDNS web service configuration. c must be
// valid.
func (c *webConfig) toInternal(
ctx context.Context,
envs *environment,
dnsCk dnscheck.Interface,
errColl errcoll.Interface,
mtrc tlsconfig.Metrics,
) (conf *websvc.Config, err error) {
if c == nil {
return nil, nil
@ -83,7 +87,7 @@ func (c *webConfig) toInternal(
conf.RootRedirectURL = netutil.CloneURL(&c.RootRedirectURL.URL)
}
conf.LinkedIP, err = c.LinkedIP.toInternal(envs.LinkedIPTargetURL)
conf.LinkedIP, err = c.LinkedIP.toInternal(ctx, mtrc, envs.LinkedIPTargetURL)
if err != nil {
return nil, fmt.Errorf("converting linked_ip: %w", err)
}
@ -107,7 +111,7 @@ func (c *webConfig) toInternal(
}}
for _, bp := range blockPages {
*bp.webConfPtr, err = bp.conf.toInternal()
*bp.webConfPtr, err = bp.conf.toInternal(ctx, mtrc)
if err != nil {
return nil, fmt.Errorf("%s: %w", bp.name, err)
}
@ -119,7 +123,7 @@ func (c *webConfig) toInternal(
return nil, err
}
conf.NonDoHBind, err = c.NonDoHBind.toInternal()
conf.NonDoHBind, err = c.NonDoHBind.toInternal(ctx, mtrc)
if err != nil {
return nil, fmt.Errorf("converting non_doh_bind: %w", err)
}
@ -225,6 +229,8 @@ type linkedIPServer struct {
// toInternal converts s to a linkedIP server configuration. s must be valid.
func (s *linkedIPServer) toInternal(
ctx context.Context,
mtrc tlsconfig.Metrics,
targetURL *urlutil.URL,
) (srv *websvc.LinkedIPServer, err error) {
if s == nil {
@ -232,7 +238,7 @@ func (s *linkedIPServer) toInternal(
}
srv = &websvc.LinkedIPServer{}
srv.Bind, err = s.Bind.toInternal()
srv.Bind, err = s.Bind.toInternal(ctx, mtrc)
if err != nil {
return nil, fmt.Errorf("converting bind: %w", err)
}
@ -280,7 +286,10 @@ type blockPageServer struct {
}
// toInternal converts s to a block page server configuration. s must be valid.
func (s *blockPageServer) toInternal() (conf *websvc.BlockPageServerConfig, err error) {
func (s *blockPageServer) toInternal(
ctx context.Context,
mtrc tlsconfig.Metrics,
) (conf *websvc.BlockPageServerConfig, err error) {
if s == nil {
return nil, nil
}
@ -289,7 +298,7 @@ func (s *blockPageServer) toInternal() (conf *websvc.BlockPageServerConfig, err
ContentFilePath: s.BlockPage,
}
conf.Bind, err = s.Bind.toInternal()
conf.Bind, err = s.Bind.toInternal(ctx, mtrc)
if err != nil {
return nil, fmt.Errorf("converting bind: %w", err)
}
@ -326,11 +335,14 @@ type bindData []*bindItem
// toInternal converts bd to bind data for the AdGuard DNS web service. bd must
// be valid.
func (bd bindData) toInternal() (data []*websvc.BindData, err error) {
func (bd bindData) toInternal(
ctx context.Context,
mtrc tlsconfig.Metrics,
) (data []*websvc.BindData, err error) {
data = make([]*websvc.BindData, len(bd))
for i, d := range bd {
data[i], err = d.toInternal()
data[i], err = d.toInternal(ctx, mtrc)
if err != nil {
return nil, fmt.Errorf("bind data at index %d: %w", i, err)
}
@ -369,8 +381,11 @@ type bindItem struct {
// toInternal converts i to bind data for the AdGuard DNS web service. i must
// be valid.
func (i *bindItem) toInternal() (data *websvc.BindData, err error) {
tlsConf, err := i.Certificates.toInternal()
func (i *bindItem) toInternal(
ctx context.Context,
mtrc tlsconfig.Metrics,
) (data *websvc.BindData, err error) {
tlsConf, err := i.Certificates.toInternal(ctx, mtrc)
if err != nil {
return nil, fmt.Errorf("certificates: %w", err)
}

View File

@ -15,7 +15,6 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agdservice"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil/urlutil"
)
@ -28,6 +27,7 @@ type AllowlistUpdater struct {
http *agdhttp.Client
url *url.URL
errColl errcoll.Interface
metrics Metrics
}
// AllowlistUpdaterConfig is the configuration structure for the allowlist
@ -45,6 +45,9 @@ type AllowlistUpdaterConfig struct {
// ErrColl is used to collect errors during refreshes.
ErrColl errcoll.Interface
// Metrics is used to collect allowlist statistics.
Metrics Metrics
// Timeout is the timeout for Consul queries.
Timeout time.Duration
}
@ -60,6 +63,7 @@ func NewAllowlistUpdater(c *AllowlistUpdaterConfig) (upd *AllowlistUpdater) {
}),
url: c.ConsulURL,
errColl: c.ErrColl,
metrics: c.Metrics,
}
}
@ -72,10 +76,7 @@ func (upd *AllowlistUpdater) Refresh(ctx context.Context) (err error) {
upd.logger.InfoContext(ctx, "refresh started")
defer upd.logger.InfoContext(ctx, "refresh finished")
defer func() {
metrics.ConsulAllowlistUpdateTime.SetToCurrentTime()
metrics.SetStatusGauge(metrics.ConsulAllowlistUpdateStatus, err)
}()
defer func() { upd.metrics.SetStatus(ctx, err) }()
consulNets, err := upd.loadConsul(ctx)
if err != nil {
@ -89,13 +90,11 @@ func (upd *AllowlistUpdater) Refresh(ctx context.Context) (err error) {
ctx,
"refresh successful",
"num_records", len(consulNets),
"url", &urlutil.URL{
URL: *upd.url,
},
"url", urlutil.RedactUserinfo(upd.url),
)
upd.allowlist.Update(consulNets)
metrics.ConsulAllowlistSize.Set(float64(len(consulNets)))
upd.metrics.SetSize(ctx, len(consulNets))
return nil
}
@ -107,7 +106,7 @@ type consulRecord struct {
// loadConsul fetches, decodes, and returns the list of IP networks from consul.
func (upd *AllowlistUpdater) loadConsul(ctx context.Context) (nets []netip.Prefix, err error) {
defer func() { err = errors.Annotate(err, "loading allowlist nets from %s: %w", upd.url) }()
defer func() { err = errors.Annotate(err, "loading allowlist nets: %w") }()
httpResp, err := upd.http.Get(ctx, upd.url)
if err != nil {

View File

@ -83,6 +83,7 @@ func TestNewAllowlistUpdater(t *testing.T) {
Allowlist: al,
ConsulURL: u,
ErrColl: agdtest.NewErrorCollector(),
Metrics: consul.EmptyMetrics{},
Timeout: testTimeout,
})
@ -128,6 +129,7 @@ func TestNewAllowlistUpdater(t *testing.T) {
Allowlist: al,
ConsulURL: u,
ErrColl: errColl,
Metrics: consul.EmptyMetrics{},
Timeout: testTimeout,
})
@ -161,6 +163,7 @@ func TestAllowlistUpdater_Refresh_deadline(t *testing.T) {
Allowlist: al,
ConsulURL: u,
ErrColl: errColl,
Metrics: consul.EmptyMetrics{},
Timeout: testTimeout,
})

View File

@ -0,0 +1,26 @@
package consul
import "context"
// Metrics is an interface that is used for the collection of the allowlist
// statistics.
type Metrics interface {
// SetSize sets the number of received subnets.
SetSize(ctx context.Context, n int)
// SetStatus sets the status and time of the allowlist refresh attempt.
SetStatus(ctx context.Context, err error)
}
// EmptyMetrics is the implementation of the [Metrics] interface that does
// nothing.
type EmptyMetrics struct{}
// type check
var _ Metrics = EmptyMetrics{}
// SetSize implements the [Metrics] interface for EmptyMetrics.
func (EmptyMetrics) SetSize(_ context.Context, _ int) {}
// SetStatus plements the [Metrics] interface for EmptyMetrics.
func (EmptyMetrics) SetStatus(_ context.Context, _ error) {}

View File

@ -15,6 +15,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/debugsvc"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil/httputil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -92,7 +93,7 @@ func TestService_Start(t *testing.T) {
})
srvURL := &url.URL{
Scheme: agdhttp.SchemeHTTP,
Scheme: urlutil.SchemeHTTP,
Host: addr,
}

View File

@ -201,18 +201,21 @@ func (cc *RemoteKV) newInfo(ri *agd.RequestInfo) (inf *info) {
}
// resp returns the corresponding response.
//
// TODO(e.burkov): Inspect the reason for using different message constructors
// for different DNS types, and consider using only one of them.
func (cc *RemoteKV) resp(ri *agd.RequestInfo, req *dns.Msg) (resp *dns.Msg, err error) {
qt := ri.QType
if qt != dns.TypeA && qt != dns.TypeAAAA {
return ri.Messages.NewMsgNODATA(req), nil
return ri.Messages.NewRespRCode(req, dns.RcodeSuccess), nil
}
if qt == dns.TypeA {
return cc.messages.NewIPRespMsg(req, cc.ipv4...)
return cc.messages.NewRespIP(req, cc.ipv4...)
}
return cc.messages.NewIPRespMsg(req, cc.ipv6...)
return cc.messages.NewRespIP(req, cc.ipv6...)
}
// type check

View File

@ -17,6 +17,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/remotekv"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -86,7 +87,7 @@ func TestConsul_ServeHTTP(t *testing.T) {
t.Run("hit", func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, (&url.URL{
Scheme: "http",
Scheme: urlutil.SchemeHTTP,
Host: randomid + "-" + checkDomain,
Path: "/dnscheck/test",
}).String(), strings.NewReader(""))
@ -103,7 +104,7 @@ func TestConsul_ServeHTTP(t *testing.T) {
t.Run("miss", func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, (&url.URL{
Scheme: "http",
Scheme: urlutil.SchemeHTTP,
Host: "non" + randomid + "-" + checkDomain,
Path: "/dnscheck/test",
}).String(), strings.NewReader(""))

View File

@ -18,13 +18,14 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDefault_ServeHTTP(t *testing.T) {
const dname = "some-domain.name"
const domain = "domain.example"
testIP := netip.MustParseAddr("1.2.3.4")
@ -39,7 +40,7 @@ func TestDefault_ServeHTTP(t *testing.T) {
rcode,
dnsservertest.NewReq(name, qtype, dns.ClassINET),
dnsservertest.SectionAnswer{
dnsservertest.NewA(dname, 0, testIP),
dnsservertest.NewA(domain, 0, testIP),
},
)
}
@ -52,38 +53,38 @@ func TestDefault_ServeHTTP(t *testing.T) {
}{{
name: "single",
msgs: []*dns.Msg{
newMsg(dns.RcodeSuccess, dname, dns.TypeA),
newMsg(dns.RcodeSuccess, domain, dns.TypeA),
},
wantHdr: successHdr,
wantResp: [][]byte{[]byte(dname + `,A,NOERROR,` + testIP.String() + `,1`)},
wantResp: [][]byte{[]byte(domain + `,A,NOERROR,` + testIP.String() + `,1`)},
}, {
name: "existing",
msgs: []*dns.Msg{
newMsg(dns.RcodeSuccess, dname, dns.TypeA),
newMsg(dns.RcodeSuccess, dname, dns.TypeA),
newMsg(dns.RcodeSuccess, domain, dns.TypeA),
newMsg(dns.RcodeSuccess, domain, dns.TypeA),
},
wantHdr: successHdr,
wantResp: [][]byte{[]byte(dname + `,A,NOERROR,` + testIP.String() + `,2`)},
wantResp: [][]byte{[]byte(domain + `,A,NOERROR,` + testIP.String() + `,2`)},
}, {
name: "different",
msgs: []*dns.Msg{
newMsg(dns.RcodeSuccess, dname, dns.TypeA),
newMsg(dns.RcodeSuccess, "sub."+dname, dns.TypeA),
newMsg(dns.RcodeSuccess, domain, dns.TypeA),
newMsg(dns.RcodeSuccess, "sub."+domain, dns.TypeA),
},
wantHdr: successHdr,
wantResp: [][]byte{
[]byte("sub." + dname + `,A,NOERROR,` + testIP.String() + `,1`),
[]byte(dname + `,A,NOERROR,` + testIP.String() + `,1`),
[]byte("sub." + domain + `,A,NOERROR,` + testIP.String() + `,1`),
[]byte(domain + `,A,NOERROR,` + testIP.String() + `,1`),
},
}, {
name: "non-recordable",
msgs: []*dns.Msg{
// Not NOERROR.
newMsg(dns.RcodeBadName, dname, dns.TypeA),
newMsg(dns.RcodeBadName, domain, dns.TypeA),
// Not A/AAAA.
newMsg(dns.RcodeSuccess, dname, dns.TypeSRV),
newMsg(dns.RcodeSuccess, domain, dns.TypeSRV),
// Android metrics.
newMsg(dns.RcodeSuccess, dname+"-dnsotls-ds.metric.gstatic.com.", dns.TypeA),
newMsg(dns.RcodeSuccess, domain+"-dnsotls-ds.metric.gstatic.com.", dns.TypeA),
},
wantHdr: successHdr,
wantResp: [][]byte{},
@ -109,7 +110,10 @@ func TestDefault_ServeHTTP(t *testing.T) {
r := httptest.NewRequest(
http.MethodGet,
(&url.URL{Scheme: "http", Host: "example.com"}).String(),
(&url.URL{
Scheme: urlutil.SchemeHTTP,
Host: "dnsdb.example",
}).String(),
nil,
)
r.Header.Add(httphdr.AcceptEncoding, agdhttp.HdrValGzip)

View File

@ -15,6 +15,11 @@ type ConstructorConfig struct {
// Cloner used to clone DNS messages. It must not be nil.
Cloner *Cloner
// StructuredErrors is the configuration for the experimental Structured DNS
// Errors feature. It must not be nil. If enabled,
// [ConstructorConfig.Enabled] should also be true.
StructuredErrors *StructuredDNSErrorsConfig
// BlockingMode is the blocking mode to use in
// [Constructor.NewBlockedRespMsg]. It must not be nil.
BlockingMode BlockingMode
@ -22,6 +27,9 @@ type ConstructorConfig struct {
// FilteredResponseTTL is the time-to-live value used for responses created
// by this message constructor. It must be non-negative.
FilteredResponseTTL time.Duration
// EDEEnabled enables the addition of the Extended DNS Error (EDE) codes.
EDEEnabled bool
}
// validate checks the configuration for errors.
@ -33,13 +41,19 @@ func (conf *ConstructorConfig) validate() (err error) {
errs = append(errs, err)
}
err = conf.StructuredErrors.validate(conf.EDEEnabled)
if err != nil {
err = fmt.Errorf("structured errors: %w", err)
errs = append(errs, err)
}
if conf.BlockingMode == nil {
err = fmt.Errorf("blocking mode: %w", errors.ErrNoValue)
errs = append(errs, err)
}
if conf.FilteredResponseTTL < 0 {
err = fmt.Errorf("filtered response TTL: %w", errors.ErrNegative)
err = fmt.Errorf("filtered response ttl: %w", errors.ErrNegative)
errs = append(errs, err)
}
@ -51,7 +65,9 @@ func (conf *ConstructorConfig) validate() (err error) {
type Constructor struct {
cloner *Cloner
blockingMode BlockingMode
sde string
fltRespTTL time.Duration
edeEnabled bool
}
// NewConstructor returns a properly initialized constructor using conf.
@ -60,10 +76,17 @@ func NewConstructor(conf *ConstructorConfig) (c *Constructor, err error) {
return nil, fmt.Errorf("configuration: %w", err)
}
var sde string
if sdeConf := conf.StructuredErrors; sdeConf.Enabled {
sde = sdeConf.iJSON()
}
return &Constructor{
cloner: conf.Cloner,
blockingMode: conf.BlockingMode,
sde: sde,
fltRespTTL: conf.FilteredResponseTTL,
edeEnabled: conf.EDEEnabled,
}, nil
}
@ -72,156 +95,6 @@ func (c *Constructor) Cloner() (cloner *Cloner) {
return c.cloner
}
// FilteredResponseTTL returns the TTL that the constructor uses to build
// blocked responses.
func (c *Constructor) FilteredResponseTTL() (ttl time.Duration) {
return c.fltRespTTL
}
// NewBlockedRespMsg returns a blocked DNS response message based on the
// constructor's blocking mode.
func (c *Constructor) NewBlockedRespMsg(req *dns.Msg) (msg *dns.Msg, err error) {
switch m := c.blockingMode.(type) {
case *BlockingModeCustomIP:
return c.newBlockedCustomIPRespMsg(req, m)
case *BlockingModeNullIP:
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA, dns.TypeAAAA:
return c.NewIPRespMsg(req, netip.Addr{})
default:
return c.NewMsgNODATA(req), nil
}
case *BlockingModeNXDOMAIN:
return c.NewMsgNXDOMAIN(req), nil
case *BlockingModeREFUSED:
return c.NewMsgREFUSED(req), nil
default:
// Consider unhandled sum type members as unrecoverable programmer
// errors.
panic(fmt.Errorf("unexpected type %T", c.blockingMode))
}
}
// newBlockedCustomIPRespMsg returns a blocked DNS response message with either
// the custom IPs from the blocking mode options or a NODATA one.
func (c *Constructor) newBlockedCustomIPRespMsg(
req *dns.Msg,
m *BlockingModeCustomIP,
) (msg *dns.Msg, err error) {
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA:
if len(m.IPv4) > 0 {
return c.NewIPRespMsg(req, m.IPv4...)
}
case dns.TypeAAAA:
if len(m.IPv6) > 0 {
return c.NewIPRespMsg(req, m.IPv6...)
}
default:
// Go on.
}
return c.NewMsgNODATA(req), nil
}
// NewIPRespMsg returns a DNS A or AAAA response message with the given IP
// addresses. If any IP address is nil, it is replaced by an unspecified (aka
// null) IP. The TTL is also set to c.FilteredResponseTTL.
func (c *Constructor) NewIPRespMsg(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA:
return c.newMsgA(req, ips...)
case dns.TypeAAAA:
return c.newMsgAAAA(req, ips...)
default:
return nil, fmt.Errorf("bad qtype for a or aaaa resp: %d", qt)
}
}
// NewCNAMEWithIPs generates a filtered response to req with CNAME record and
// provided ips. cname is the fully-qualified name and must not be empty, ips
// must be of the same family.
func (c *Constructor) NewCNAMEWithIPs(
req *dns.Msg,
cname string,
ips ...netip.Addr,
) (resp *dns.Msg, err error) {
resp = c.NewRespMsg(req)
resp.Answer = make([]dns.RR, 0, len(ips)+1)
resp.Answer = append(resp.Answer, c.NewAnswerCNAME(req, cname))
var ans dns.RR
for i, ip := range ips {
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA:
ans, err = c.NewAnswerA(cname, ip)
case dns.TypeAAAA:
ans, err = c.NewAnswerAAAA(cname, ip)
default:
return nil, fmt.Errorf("bad qtype for a or aaaa resp: %d", qt)
}
if err != nil {
return nil, fmt.Errorf("bad ip at idx %d: %w", i, err)
}
resp.Answer = append(resp.Answer, ans)
}
return resp, err
}
// NewMsgFORMERR returns a properly initialized FORMERR response.
func (c *Constructor) NewMsgFORMERR(req *dns.Msg) (resp *dns.Msg) {
return c.newMsgRCode(req, dns.RcodeFormatError)
}
// NewMsgNXDOMAIN returns a properly initialized NXDOMAIN response.
func (c *Constructor) NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) {
return c.newMsgRCode(req, dns.RcodeNameError)
}
// NewMsgREFUSED returns a properly initialized REFUSED response.
func (c *Constructor) NewMsgREFUSED(req *dns.Msg) (resp *dns.Msg) {
return c.newMsgRCode(req, dns.RcodeRefused)
}
// NewMsgSERVFAIL returns a properly initialized SERVFAIL response.
func (c *Constructor) NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) {
return c.newMsgRCode(req, dns.RcodeServerFailure)
}
// NewMsgNODATA returns a properly initialized NODATA response.
//
// See https://www.rfc-editor.org/rfc/rfc2308#section-2.2.
func (c *Constructor) NewMsgNODATA(req *dns.Msg) (resp *dns.Msg) {
return c.newMsgRCode(req, dns.RcodeSuccess)
}
// newMsgRCode returns a properly initialized response with the given RCode.
func (c *Constructor) newMsgRCode(req *dns.Msg, rc RCode) (resp *dns.Msg) {
resp = (&dns.Msg{}).SetRcode(req, int(rc))
resp.Ns = c.newSOARecords(req)
resp.RecursionAvailable = true
return resp
}
// NewTXTRespMsg returns a DNS TXT response message with the given strings as
// content. The TTL is also set to c.FilteredResponseTTL.
func (c *Constructor) NewTXTRespMsg(req *dns.Msg, strs ...string) (msg *dns.Msg, err error) {
ans, err := c.NewAnswerTXT(req, strs)
if err != nil {
return nil, err
}
msg = c.NewRespMsg(req)
msg.Answer = append(msg.Answer, ans)
return msg, nil
}
// AppendDebugExtra appends to response message a DNS TXT extra with CHAOS
// class.
func (c *Constructor) AppendDebugExtra(req, resp *dns.Msg, str string) (err error) {
@ -386,26 +259,23 @@ func (c *Constructor) newSOARecords(req *dns.Msg) (soaRecs []dns.RR) {
zone = req.Question[0].Name
}
// TODO(a.garipov): A lot of this is copied from AdGuard Home and needs
// to be inspected and refactored.
// TODO(a.garipov): A lot of this is copied from AdGuard Home and needs to
// be inspected and refactored.
soa := &dns.SOA{
// values copied from verisign's nonexistent .com domain
// their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers
// Use values from verisign's nonexistent.com domain. Their exact
// values are not important in our use case because they are used for
// domain transfers between primary/secondary DNS servers.
Refresh: 1800,
Retry: 900,
Expire: 604800,
Minttl: 86400,
// copied from AdGuard DNS
// Copied from AdGuard DNS.
Ns: "fake-for-negative-caching.adguard.com.",
Serial: 100500,
// rest is request-specific
Hdr: dns.RR_Header{
Name: zone,
Rrtype: dns.TypeSOA,
Ttl: uint32(c.fltRespTTL.Seconds()),
Class: dns.ClassINET,
},
Mbox: "hostmaster.", // zone will be appended later if it's not empty or "."
// Rest is request-specific.
Hdr: c.newHdrWithClass(zone, dns.TypeSOA, dns.ClassINET),
// Zone will be appended later if it's not empty or ".".
Mbox: "hostmaster.",
}
if len(zone) > 0 && zone[0] != '.' {
@ -415,25 +285,10 @@ func (c *Constructor) newSOARecords(req *dns.Msg) (soaRecs []dns.RR) {
return []dns.RR{soa}
}
// NewRespMsg creates a DNS response for req and sets all necessary flags and
// fields. It also guarantees that req.Question will be not empty.
func (c *Constructor) NewRespMsg(req *dns.Msg) (resp *dns.Msg) {
resp = &dns.Msg{
MsgHdr: dns.MsgHdr{
RecursionAvailable: true,
},
Compress: true,
}
resp.SetReply(req)
return resp
}
// newMsgA returns a new DNS response with the given IPv4 addresses. If any IP
// address is nil, it is replaced by an unspecified (aka null) IP, 0.0.0.0.
func (c *Constructor) newMsgA(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
msg = c.NewRespMsg(req)
msg = c.NewResp(req)
for i, ip := range ips {
var ans dns.RR
ans, err = c.NewAnswerA(req.Question[0].Name, ip)
@ -450,7 +305,7 @@ func (c *Constructor) newMsgA(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, er
// newMsgAAAA returns a new DNS response with the given IPv6 addresses. If any
// IP address is nil, it is replaced by an unspecified (aka null) IP, [::].
func (c *Constructor) newMsgAAAA(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
msg = c.NewRespMsg(req)
msg = c.NewResp(req)
for i, ip := range ips {
var ans dns.RR
ans, err = c.NewAnswerAAAA(req.Question[0].Name, ip)

View File

@ -1,18 +1,16 @@
package dnsmsg_test
import (
"net"
"net/netip"
"net/url"
"strings"
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newTXTExtra is a helper constructor of the expected extra data.
@ -27,311 +25,88 @@ func newTXTExtra(ttl uint32, strs ...string) (extra []dns.RR) {
}}
}
// newConstructor returns a new dnsmsg.Constructor with [testFltRespTTL].
func newConstructor(tb testing.TB) (c *dnsmsg.Constructor) {
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: dnsmsg.NewCloner(dnsmsg.EmptyClonerStat{}),
func TestNewConstructor(t *testing.T) {
t.Parallel()
cloner := agdtest.NewCloner()
badContactURL := errors.Must(url.Parse("invalid-scheme://devteam@adguard.com"))
testCases := []struct {
name string
conf *dnsmsg.ConstructorConfig
wantErrMsg string
}{{
name: "good",
conf: &dnsmsg.ConstructorConfig{
Cloner: cloner,
StructuredErrors: agdtest.NewSDEConfig(true),
BlockingMode: &dnsmsg.BlockingModeNullIP{},
FilteredResponseTTL: agdtest.FilteredResponseTTL,
})
require.NoError(tb, err)
return msgs
}
func TestConstructor_NewBlockedRespMsg_nullIP(t *testing.T) {
t.Parallel()
msgs := newConstructor(t)
testCases := []struct {
name string
wantAnsNum int
qt dnsmsg.RRType
}{{
name: "a",
wantAnsNum: 1,
qt: dns.TypeA,
}, {
name: "aaaa",
wantAnsNum: 1,
qt: dns.TypeAAAA,
}, {
name: "txt",
wantAnsNum: 0,
qt: dns.TypeTXT,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
req := dnsservertest.NewReq(testFQDN, tc.qt, dns.ClassINET)
resp, respErr := msgs.NewBlockedRespMsg(req)
require.NoError(t, respErr)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
if tc.wantAnsNum == 0 {
assert.Empty(t, resp.Answer)
require.Len(t, resp.Ns, 1)
nsTTL := resp.Ns[0].Header().Ttl
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), nsTTL)
} else {
require.Len(t, resp.Answer, 1)
ansTTL := resp.Answer[0].Header().Ttl
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), ansTTL)
}
})
}
}
func TestConstructor_NewBlockedRespMsg_customIP(t *testing.T) {
t.Parallel()
cloner := agdtest.NewCloner()
testCases := []struct {
blockingMode dnsmsg.BlockingMode
name string
wantA bool
wantAAAA bool
}{{
blockingMode: &dnsmsg.BlockingModeCustomIP{
IPv4: []netip.Addr{testIPv4},
IPv6: []netip.Addr{testIPv6},
EDEEnabled: true,
},
name: "both",
wantA: true,
wantAAAA: true,
wantErrMsg: "",
}, {
blockingMode: &dnsmsg.BlockingModeCustomIP{
IPv4: []netip.Addr{testIPv4},
name: "all_bad",
conf: &dnsmsg.ConstructorConfig{
FilteredResponseTTL: -1,
},
name: "ipv4_only",
wantA: true,
wantAAAA: false,
wantErrMsg: "configuration: " +
"cloner: no value\n" +
"structured errors: no value\n" +
"blocking mode: no value\n" +
"filtered response ttl: negative value",
}, {
blockingMode: &dnsmsg.BlockingModeCustomIP{
IPv6: []netip.Addr{testIPv6},
},
name: "ipv6_only",
wantA: false,
wantAAAA: true,
}, {
blockingMode: &dnsmsg.BlockingModeCustomIP{
IPv4: []netip.Addr{},
IPv6: []netip.Addr{},
},
name: "empty",
wantA: false,
wantAAAA: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
name: "sde_enabled",
conf: &dnsmsg.ConstructorConfig{
Cloner: cloner,
BlockingMode: tc.blockingMode,
StructuredErrors: agdtest.NewSDEConfig(true),
BlockingMode: &dnsmsg.BlockingModeNullIP{},
FilteredResponseTTL: agdtest.FilteredResponseTTL,
})
require.NoError(t, err)
reqA := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET)
respA, err := msgs.NewBlockedRespMsg(reqA)
require.NoError(t, err)
require.NotNil(t, respA)
assert.Equal(t, dns.RcodeSuccess, respA.Rcode)
if tc.wantA {
require.Len(t, respA.Answer, 1)
a := testutil.RequireTypeAssert[*dns.A](t, respA.Answer[0])
assert.Equal(t, net.IP(testIPv4.AsSlice()), a.A)
} else {
assert.Empty(t, respA.Answer)
}
reqAAAA := dnsservertest.NewReq(testFQDN, dns.TypeAAAA, dns.ClassINET)
respAAAA, err := msgs.NewBlockedRespMsg(reqAAAA)
require.NoError(t, err)
require.NotNil(t, respAAAA)
assert.Equal(t, dns.RcodeSuccess, respAAAA.Rcode)
if tc.wantAAAA {
require.Len(t, respAAAA.Answer, 1)
aaaa := testutil.RequireTypeAssert[*dns.AAAA](t, respAAAA.Answer[0])
assert.Equal(t, net.IP(testIPv6.AsSlice()), aaaa.AAAA)
} else {
assert.Empty(t, respAAAA.Answer)
}
})
}
}
func TestConstructor_NewBlockedRespMsg_noAnswer(t *testing.T) {
t.Parallel()
req := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET)
cloner := agdtest.NewCloner()
testCases := []struct {
blockingMode dnsmsg.BlockingMode
name string
rcode dnsmsg.RCode
}{{
blockingMode: &dnsmsg.BlockingModeNXDOMAIN{},
name: "nxdomain",
rcode: dns.RcodeNameError,
EDEEnabled: false,
},
wantErrMsg: "configuration: structured errors: " +
"ede must be enabled to enable sde",
}, {
blockingMode: &dnsmsg.BlockingModeREFUSED{},
name: "refused",
rcode: dns.RcodeRefused,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
name: "sde_empty",
conf: &dnsmsg.ConstructorConfig{
Cloner: cloner,
BlockingMode: tc.blockingMode,
StructuredErrors: &dnsmsg.StructuredDNSErrorsConfig{
Enabled: true,
},
BlockingMode: &dnsmsg.BlockingModeNullIP{},
FilteredResponseTTL: agdtest.FilteredResponseTTL,
})
require.NoError(t, err)
resp, err := msgs.NewBlockedRespMsg(req)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, tc.rcode, dnsmsg.RCode(resp.Rcode))
assert.Empty(t, resp.Answer)
require.Len(t, resp.Ns, 1)
nsTTL := resp.Ns[0].Header().Ttl
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), nsTTL)
})
}
}
func TestConstructor_noAnswerMethods(t *testing.T) {
t.Parallel()
msgs := newConstructor(t)
req := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET)
testCases := []struct {
method func(req *dns.Msg) (resp *dns.Msg)
name string
want dnsmsg.RCode
}{{
method: msgs.NewMsgFORMERR,
name: "formerr",
want: dns.RcodeFormatError,
EDEEnabled: true,
},
wantErrMsg: "configuration: structured errors: " +
"contact data: empty value\n" +
"justification: empty value",
}, {
method: msgs.NewMsgNXDOMAIN,
name: "nxdomain",
want: dns.RcodeNameError,
}, {
method: msgs.NewMsgREFUSED,
name: "refused",
want: dns.RcodeRefused,
}, {
method: msgs.NewMsgSERVFAIL,
name: "servfail",
want: dns.RcodeServerFailure,
}, {
method: msgs.NewMsgNODATA,
name: "nodata",
want: dns.RcodeSuccess,
name: "sde_bad",
conf: &dnsmsg.ConstructorConfig{
Cloner: cloner,
StructuredErrors: &dnsmsg.StructuredDNSErrorsConfig{
Enabled: true,
Contact: []*url.URL{badContactURL, nil},
Justification: "\uFFFE",
Organization: "\uFFFE",
},
BlockingMode: &dnsmsg.BlockingModeNullIP{},
FilteredResponseTTL: agdtest.FilteredResponseTTL,
EDEEnabled: true,
},
wantErrMsg: "configuration: structured errors: " +
`contact data: at index 0: scheme: bad enum value: "invalid-scheme"` + "\n" +
"contact data: at index 1: no value\n" +
"justification: bad code point at index 0\n" +
"organization: bad code point at index 0",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
resp := tc.method(req)
require.NotNil(t, resp)
require.Len(t, resp.Ns, 1)
assert.Empty(t, resp.Answer)
assert.Equal(t, tc.want, dnsmsg.RCode(resp.Rcode))
nsTTL := resp.Ns[0].Header().Ttl
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), nsTTL)
})
}
}
func TestConstructor_NewTXTRespMsg(t *testing.T) {
t.Parallel()
msgs := newConstructor(t)
req := dnsservertest.NewReq(testFQDN, dns.TypeTXT, dns.ClassINET)
tooLong := strings.Repeat("1", dnsmsg.MaxTXTStringLen+1)
testCases := []struct {
name string
wantErrMsg string
strs []string
}{{
name: "success",
wantErrMsg: "",
strs: []string{"111"},
}, {
name: "success_many",
wantErrMsg: "",
strs: []string{"111", "222"},
}, {
name: "success_nil",
wantErrMsg: "",
strs: nil,
}, {
name: "success_empty",
wantErrMsg: "",
strs: []string{},
}, {
name: "too_long",
wantErrMsg: "txt string at index 0: too long: got 256 bytes, max 255",
strs: []string{tooLong},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
resp, respErr := msgs.NewTXTRespMsg(req, tc.strs...)
testutil.AssertErrorMsg(t, tc.wantErrMsg, respErr)
if tc.wantErrMsg != "" {
return
}
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
ans := resp.Answer[0]
ansTTL := ans.Header().Ttl
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), ansTTL)
txt := testutil.RequireTypeAssert[*dns.TXT](t, ans)
assert.Equal(t, tc.strs, txt.Txt)
_, err := dnsmsg.NewConstructor(tc.conf)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}
@ -339,7 +114,7 @@ func TestConstructor_NewTXTRespMsg(t *testing.T) {
func TestConstructor_AppendDebugExtra(t *testing.T) {
t.Parallel()
msgs := newConstructor(t)
msgs := agdtest.NewConstructor(t)
shortText := "This is a short test text"
longText := strings.Repeat("a", 2*dnsmsg.MaxTXTStringLen)

View File

@ -78,7 +78,9 @@ func TestClone(t *testing.T) {
},
name: "empty_slice_ans",
}, {
msg: dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET),
msg: dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET, dnsservertest.SectionExtra{
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
}),
name: "a",
}}

View File

@ -72,8 +72,7 @@ func (c *optCloner) clone(rr *dns.OPT) (clone *dns.OPT, full bool) {
optClone = opt
case *dns.EDNS0_EDE:
opt := c.ede.Get()
opt.InfoCode = orig.InfoCode
opt.ExtraText = orig.ExtraText
*opt = *orig
optClone = opt
case *dns.EDNS0_SUBNET:

228
internal/dnsmsg/response.go Normal file
View File

@ -0,0 +1,228 @@
package dnsmsg
import (
"fmt"
"net/netip"
"github.com/miekg/dns"
)
// NewResp creates a response DNS message for req and sets all necessary flags
// and fields. resp contains no resource records.
func (c *Constructor) NewResp(req *dns.Msg) (resp *dns.Msg) {
return (&dns.Msg{
MsgHdr: dns.MsgHdr{
RecursionAvailable: true,
},
Compress: true,
}).SetReply(req)
}
// NewBlockedResp returns a blocked response DNS message based on the
// constructor's blocking mode.
func (c *Constructor) NewBlockedResp(req *dns.Msg) (msg *dns.Msg, err error) {
switch m := c.blockingMode.(type) {
case *BlockingModeCustomIP:
return c.newBlockedCustomIPResp(req, m)
case *BlockingModeNullIP:
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA, dns.TypeAAAA:
return c.NewBlockedNullIPResp(req)
default:
msg = c.NewBlockedRespRCode(req, dns.RcodeSuccess)
msg.Ns = c.newSOARecords(req)
}
case *BlockingModeNXDOMAIN:
msg = c.NewBlockedRespRCode(req, dns.RcodeNameError)
msg.Ns = c.newSOARecords(req)
case *BlockingModeREFUSED:
msg = c.NewBlockedRespRCode(req, dns.RcodeRefused)
msg.Ns = c.newSOARecords(req)
default:
// Consider unhandled sum type members as unrecoverable programmer
// errors.
panic(fmt.Errorf("unexpected type %T", c.blockingMode))
}
return msg, nil
}
// NewRespRCode returns a response DNS message with given response code and a
// predefined authority section.
//
// Use [dns.RcodeSuccess] for a proper NODATA response, see
// https://www.rfc-editor.org/rfc/rfc2308#section-2.2.
func (c *Constructor) NewRespRCode(req *dns.Msg, rc RCode) (resp *dns.Msg) {
resp = c.NewResp(req)
resp.Rcode = int(rc)
resp.Ns = c.newSOARecords(req)
return resp
}
// NewBlockedRespRCode returns a blocked response DNS message with given
// response code.
//
// TODO(e.burkov): Add SOA records to the response, like in
// [Constructor.NewRespRCode].
func (c *Constructor) NewBlockedRespRCode(req *dns.Msg, rc RCode) (resp *dns.Msg) {
resp = c.NewResp(req)
resp.Rcode = int(rc)
c.AddEDE(req, resp, dns.ExtendedErrorCodeFiltered)
return resp
}
// NewRespTXT returns a DNS TXT response message with the given strings as
// content. The TTL of the TXT answer is set to c.FilteredResponseTTL.
func (c *Constructor) NewRespTXT(req *dns.Msg, strs ...string) (msg *dns.Msg, err error) {
ans, err := c.NewAnswerTXT(req, strs)
if err != nil {
return nil, err
}
msg = c.NewResp(req)
msg.Answer = append(msg.Answer, ans)
return msg, nil
}
// NewRespIP returns an A or AAAA DNS response message with the given IP
// addresses. If any IP address is nil, it is replaced by an unspecified (aka
// null) IP. The TTL is also set to c.FilteredResponseTTL.
func (c *Constructor) NewRespIP(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA:
return c.newMsgA(req, ips...)
case dns.TypeAAAA:
return c.newMsgAAAA(req, ips...)
default:
return nil, fmt.Errorf("bad qtype for a or aaaa resp: %d", qt)
}
}
// NewBlockedRespIP returns an A or AAAA DNS response message with the given IP
// addresses. The TTL of each record is set to c.FilteredResponseTTL. ips
// should not contain zero values due to the extended error code semantics, use
// [NewBlockedNullIPResp] for this case.
//
// TODO(a.garipov): Consider merging with [NewRespIP] if AddEDE with the Forged
// Answer code isn't used again.
func (c *Constructor) NewBlockedRespIP(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA:
msg, err = c.newMsgA(req, ips...)
case dns.TypeAAAA:
msg, err = c.newMsgAAAA(req, ips...)
default:
return nil, fmt.Errorf("bad qtype for an ip resp: %d", qt)
}
if err != nil {
return nil, err
}
return msg, nil
}
// NewBlockedNullIPResp returns a blocked A or AAAA DNS response message with an
// unspecified (aka null) IP address. The TTL of the record is set to the
// constructor's FilteredResponseTTL.
func (c *Constructor) NewBlockedNullIPResp(req *dns.Msg) (resp *dns.Msg, err error) {
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA:
resp, err = c.newMsgA(req, netip.Addr{})
case dns.TypeAAAA:
resp, err = c.newMsgAAAA(req, netip.Addr{})
default:
err = fmt.Errorf("bad qtype for an ip resp: %d", qt)
}
if err != nil {
return nil, err
}
c.AddEDE(req, resp, dns.ExtendedErrorCodeFiltered)
return resp, nil
}
// AddEDE adds an Extended DNS Error (EDE) option to the blocked response
// message, if the feature is enabled in the Constructor and the request
// indicates EDNS support. It does not overwrite EDE if there already is one.
// req and resp must not be nil.
func (c *Constructor) AddEDE(req, resp *dns.Msg, code uint16) {
if !c.edeEnabled {
return
}
reqOpt := req.IsEdns0()
if reqOpt == nil {
// Requestor doesn't implement EDNS, see
// https://datatracker.ietf.org/doc/html/rfc6891#section-7.
return
}
respOpt := resp.IsEdns0()
if respOpt == nil {
respOpt = newOPT(c.cloner, reqOpt.UDPSize(), reqOpt.Do())
resp.Extra = append(resp.Extra, respOpt)
} else if findEDE(respOpt) != nil {
// Do not add an EDE option if there already is one.
return
}
sdeText := c.sdeForReqOpt(reqOpt)
respOpt.Option = append(respOpt.Option, newEDNS0EDE(c.cloner, code, sdeText))
}
// findEDE returns the EDE option if there is one. opt must not be nil.
func findEDE(opt *dns.OPT) (ede *dns.EDNS0_EDE) {
for _, o := range opt.Option {
var ok bool
if ede, ok = o.(*dns.EDNS0_EDE); ok {
return ede
}
}
return nil
}
// sdeForReqOpt returns either the configured SDE text or empty string depending
// on the request's EDNS options.
func (c *Constructor) sdeForReqOpt(reqOpt *dns.OPT) (sde string) {
ede := findEDE(reqOpt)
if ede != nil && ede.InfoCode == 0 && ede.ExtraText == "" {
return c.sde
}
return ""
}
// newBlockedCustomIPResp returns a blocked DNS response message with either the
// custom IPs from the blocking mode options or a NODATA one.
func (c *Constructor) newBlockedCustomIPResp(
req *dns.Msg,
m *BlockingModeCustomIP,
) (msg *dns.Msg, err error) {
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA:
if len(m.IPv4) > 0 {
return c.NewBlockedRespIP(req, m.IPv4...)
}
case dns.TypeAAAA:
if len(m.IPv6) > 0 {
return c.NewBlockedRespIP(req, m.IPv6...)
}
default:
// Go on.
}
msg = c.NewBlockedRespRCode(req, dns.RcodeSuccess)
msg.Ns = c.newSOARecords(req)
return msg, nil
}

View File

@ -0,0 +1,416 @@
package dnsmsg_test
import (
"net/netip"
"strings"
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestConstructor_NewBlockedResp_nullIP(t *testing.T) {
t.Parallel()
msgs := agdtest.NewConstructor(t)
reqExtra := dnsservertest.SectionExtra{
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
}
filteredSDE := dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{
InfoCode: dns.ExtendedErrorCodeFiltered,
ExtraText: agdtest.SDEText,
})
testCases := []struct {
name string
wantAns []dns.RR
wantExtra []dns.RR
qt dnsmsg.RRType
}{{
name: "a",
wantAns: []dns.RR{dnsservertest.NewA(
testFQDN, agdtest.FilteredResponseTTLSec, netip.IPv4Unspecified(),
)},
wantExtra: []dns.RR{filteredSDE},
qt: dns.TypeA,
}, {
name: "aaaa",
wantAns: []dns.RR{dnsservertest.NewAAAA(
testFQDN, agdtest.FilteredResponseTTLSec, netip.IPv6Unspecified(),
)},
wantExtra: []dns.RR{filteredSDE},
qt: dns.TypeAAAA,
}, {
name: "txt",
wantAns: nil,
wantExtra: []dns.RR{filteredSDE},
qt: dns.TypeTXT,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
req := dnsservertest.NewReq(testFQDN, tc.qt, dns.ClassINET, reqExtra)
resp, respErr := msgs.NewBlockedResp(req)
require.NoError(t, respErr)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Equal(t, tc.wantAns, resp.Answer)
assert.Equal(t, tc.wantExtra, resp.Extra)
})
}
}
func TestConstructor_NewBlockedResp_customIP(t *testing.T) {
t.Parallel()
cloner := agdtest.NewCloner()
// TODO(a.garipov): Test the forged extra as well if the EDE with that code
// is used again.
reqExtra := dnsservertest.SectionExtra{
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
}
filteredExtra := dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{
InfoCode: dns.ExtendedErrorCodeFiltered,
ExtraText: agdtest.SDEText,
})
ansA := dnsservertest.NewA(testFQDN, agdtest.FilteredResponseTTLSec, testIPv4)
ansAAAA := dnsservertest.NewAAAA(testFQDN, agdtest.FilteredResponseTTLSec, testIPv6)
testCases := []struct {
blockingMode dnsmsg.BlockingMode
name string
wantAnsA []dns.RR
wantAnsAAAA []dns.RR
wantExtraA []dns.RR
wantExtraAAAA []dns.RR
}{{
blockingMode: &dnsmsg.BlockingModeCustomIP{
IPv4: []netip.Addr{testIPv4},
IPv6: []netip.Addr{testIPv6},
},
name: "both",
wantAnsA: []dns.RR{ansA},
wantAnsAAAA: []dns.RR{ansAAAA},
wantExtraA: nil,
wantExtraAAAA: nil,
}, {
blockingMode: &dnsmsg.BlockingModeCustomIP{
IPv4: []netip.Addr{testIPv4},
},
name: "ipv4_only",
wantAnsA: []dns.RR{ansA},
wantAnsAAAA: nil,
wantExtraA: nil,
wantExtraAAAA: []dns.RR{filteredExtra},
}, {
blockingMode: &dnsmsg.BlockingModeCustomIP{
IPv6: []netip.Addr{testIPv6},
},
name: "ipv6_only",
wantAnsA: nil,
wantAnsAAAA: []dns.RR{ansAAAA},
wantExtraA: []dns.RR{filteredExtra},
wantExtraAAAA: nil,
}, {
blockingMode: &dnsmsg.BlockingModeCustomIP{
IPv4: []netip.Addr{},
IPv6: []netip.Addr{},
},
name: "empty",
wantAnsA: nil,
wantAnsAAAA: nil,
wantExtraA: []dns.RR{filteredExtra},
wantExtraAAAA: []dns.RR{filteredExtra},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: cloner,
BlockingMode: tc.blockingMode,
StructuredErrors: agdtest.NewSDEConfig(true),
FilteredResponseTTL: agdtest.FilteredResponseTTL,
EDEEnabled: true,
})
require.NoError(t, err)
t.Run("a", func(t *testing.T) {
t.Parallel()
req := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET, reqExtra)
resp, respErr := msgs.NewBlockedResp(req)
require.NoError(t, respErr)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Equal(t, tc.wantAnsA, resp.Answer)
assert.Equal(t, tc.wantExtraA, resp.Extra)
})
t.Run("aaaa", func(t *testing.T) {
t.Parallel()
req := dnsservertest.NewReq(testFQDN, dns.TypeAAAA, dns.ClassINET, reqExtra)
resp, respErr := msgs.NewBlockedResp(req)
require.NoError(t, respErr)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Equal(t, tc.wantAnsAAAA, resp.Answer)
assert.Equal(t, tc.wantExtraAAAA, resp.Extra)
})
})
}
}
func TestConstructor_NewBlockedResp_nodata(t *testing.T) {
t.Parallel()
req := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET, dnsservertest.SectionExtra{
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
})
cloner := agdtest.NewCloner()
wantExtra := []dns.RR{dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{
InfoCode: dns.ExtendedErrorCodeFiltered,
ExtraText: agdtest.SDEText,
})}
testCases := []struct {
blockingMode dnsmsg.BlockingMode
name string
rcode dnsmsg.RCode
}{{
blockingMode: &dnsmsg.BlockingModeNXDOMAIN{},
name: "nxdomain",
rcode: dns.RcodeNameError,
}, {
blockingMode: &dnsmsg.BlockingModeREFUSED{},
name: "refused",
rcode: dns.RcodeRefused,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: cloner,
BlockingMode: tc.blockingMode,
StructuredErrors: agdtest.NewSDEConfig(true),
FilteredResponseTTL: agdtest.FilteredResponseTTL,
EDEEnabled: true,
})
require.NoError(t, err)
resp, err := msgs.NewBlockedResp(req)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, tc.rcode, dnsmsg.RCode(resp.Rcode))
assert.Empty(t, resp.Answer)
require.Len(t, resp.Ns, 1)
nsTTL := resp.Ns[0].Header().Ttl
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), nsTTL)
assert.Equal(t, wantExtra, resp.Extra)
})
}
}
func TestConstructor_NewBlockedResp_sde(t *testing.T) {
t.Parallel()
reqEDNS := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET, dnsservertest.SectionExtra{
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
})
reqNoEDNS := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET)
wantAns := []dns.RR{
dnsservertest.NewA(testFQDN, agdtest.FilteredResponseTTLSec, netip.IPv4Unspecified()),
}
testCases := []struct {
req *dns.Msg
sde *dnsmsg.StructuredDNSErrorsConfig
name string
wantExtra []dns.RR
ede bool
}{{
req: reqEDNS,
sde: agdtest.NewSDEConfig(true),
name: "ede_sde",
wantExtra: []dns.RR{
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{
InfoCode: dns.ExtendedErrorCodeFiltered,
ExtraText: agdtest.SDEText,
}),
},
ede: true,
}, {
req: reqEDNS,
sde: agdtest.NewSDEConfig(false),
name: "ede_no_sde",
wantExtra: []dns.RR{
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{
InfoCode: dns.ExtendedErrorCodeFiltered,
}),
},
ede: true,
}, {
req: reqEDNS,
sde: agdtest.NewSDEConfig(false),
name: "no_ede",
wantExtra: nil,
ede: false,
}, {
req: reqNoEDNS,
sde: agdtest.NewSDEConfig(true),
name: "unsupported_ede_sde",
wantExtra: nil,
ede: true,
}, {
req: reqNoEDNS,
sde: agdtest.NewSDEConfig(false),
name: "unsupported_ede_no_sde",
wantExtra: nil,
ede: true,
}, {
req: reqNoEDNS,
sde: agdtest.NewSDEConfig(false),
name: "unsupported_no_ede",
wantExtra: nil,
ede: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: agdtest.NewCloner(),
BlockingMode: &dnsmsg.BlockingModeNullIP{},
StructuredErrors: tc.sde,
FilteredResponseTTL: agdtest.FilteredResponseTTL,
EDEEnabled: tc.ede,
})
require.NoError(t, err)
resp, err := msgs.NewBlockedResp(tc.req)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Equal(t, wantAns, resp.Answer)
assert.Equal(t, tc.wantExtra, resp.Extra)
})
}
}
func TestConstructor_NewRespRCode(t *testing.T) {
t.Parallel()
msgs := agdtest.NewConstructor(t)
req := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET, dnsservertest.SectionExtra{
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
})
for rcode, name := range dns.RcodeToString {
t.Run(name, func(t *testing.T) {
t.Parallel()
resp := msgs.NewRespRCode(req, dnsmsg.RCode(rcode))
require.NotNil(t, resp)
require.Empty(t, resp.Answer)
assert.Equal(t, rcode, resp.Rcode)
require.Len(t, resp.Ns, 1)
nsTTL := resp.Ns[0].Header().Ttl
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), nsTTL)
assert.Empty(t, resp.Extra)
})
}
}
func TestConstructor_NewRespTXT(t *testing.T) {
t.Parallel()
msgs := agdtest.NewConstructor(t)
req := dnsservertest.NewReq(testFQDN, dns.TypeTXT, dns.ClassINET, dnsservertest.SectionExtra{
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
})
tooLong := strings.Repeat("1", dnsmsg.MaxTXTStringLen+1)
testCases := []struct {
name string
wantErrMsg string
strs []string
}{{
name: "success",
wantErrMsg: "",
strs: []string{"111"},
}, {
name: "success_many",
wantErrMsg: "",
strs: []string{"111", "222"},
}, {
name: "success_nil",
wantErrMsg: "",
strs: nil,
}, {
name: "success_empty",
wantErrMsg: "",
strs: []string{},
}, {
name: "too_long",
wantErrMsg: "txt string at index 0: too long: got 256 bytes, max 255",
strs: []string{tooLong},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
resp, respErr := msgs.NewRespTXT(req, tc.strs...)
testutil.AssertErrorMsg(t, tc.wantErrMsg, respErr)
if tc.wantErrMsg != "" {
return
}
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
ans := resp.Answer[0]
txt := testutil.RequireTypeAssert[*dns.TXT](t, ans)
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), txt.Hdr.Ttl)
assert.Equal(t, tc.strs, txt.Txt)
assert.Empty(t, resp.Extra)
})
}
}

View File

@ -140,3 +140,36 @@ func newTXT(c *Cloner, txt []string) (rr *dns.TXT) {
return rr
}
// newOPT constructs a new resource record of type OPT, optionally using c to
// allocate the structure.
func newOPT(c *Cloner, udpSize uint16, doBit bool) (opt *dns.OPT) {
if c == nil {
opt = &dns.OPT{}
} else {
opt = c.opt.rr.Get()
opt.Option = opt.Option[:0]
}
opt.Hdr.Name = "."
opt.Hdr.Rrtype = dns.TypeOPT
opt.SetUDPSize(udpSize)
opt.SetDo(doBit)
return opt
}
// newEDNS0EDE constructs a new resource record of type EDNS0_EDE, optionally
// using c to allocate the structure.
func newEDNS0EDE(c *Cloner, infoCode uint16, extraText string) (opt *dns.EDNS0_EDE) {
if c == nil {
opt = &dns.EDNS0_EDE{}
} else {
opt = c.opt.ede.Get()
}
opt.InfoCode = infoCode
opt.ExtraText = extraText
return opt
}

View File

@ -0,0 +1,143 @@
package dnsmsg
import (
"encoding/json"
"fmt"
"net/url"
"strings"
"unicode"
"github.com/AdguardTeam/golibs/errors"
)
// StructuredDNSErrorsConfig is the configuration structure for the experimental
// Structured DNS Errors feature.
//
// See https://www.ietf.org/archive/id/draft-ietf-dnsop-structured-dns-error-09.html.
//
// TODO(a.garipov): Add sub-error?
type StructuredDNSErrorsConfig struct {
// Justification for this particular DNS filtering. It must not be empty.
Justification string
// Organization is an optional description of the organization.
Organization string
// Contact information for the DNS service. It must not be empty. All
// items must not be nil and must be valid mailto, sips, or tel URLs.
Contact []*url.URL
// Enabled, if true, enables the experimental Structured DNS Errors feature.
Enabled bool
}
// iJSON returns the I-JSON representation of this configuration. c must be
// valid.
func (c *StructuredDNSErrorsConfig) iJSON() (s string) {
data := &structuredDNSErrorData{
Justification: c.Justification,
Organization: c.Organization,
}
for _, cont := range c.Contact {
data.Contact = append(data.Contact, cont.String())
}
// The only error that could be returned here is a type error from JSON
// encoding, and these should never happen.
b := errors.Must(json.Marshal(data))
return string(b)
}
// structuredDNSErrorData is the structure for the JSON representation of the
// SDE data.
//
// TODO(a.garipov): Add sub-error?
type structuredDNSErrorData struct {
Justification string `json:"j"`
Organization string `json:"o,omitempty"`
Contact []string `json:"c"`
}
// forbiddenRanges contains the ranges of forbidden code points for structured
// DNS errors according to the I-JSON specification.
//
// See https://datatracker.ietf.org/doc/html/rfc7493#section-2.1.
var forbiddenRanges = []*unicode.RangeTable{unicode.Cs, unicode.Noncharacter_Code_Point}
// isSurrogateOrNonCharacter returns true if r is a surrogate or a non-character
// code point.
func isSurrogateOrNonCharacter(r rune) (ok bool) {
return unicode.IsOneOf(forbiddenRanges, r)
}
// validateSDEString returns an error if s contains a surrogate or a
// non-character code point. It always returns nil for an empty string.
func validateSDEString(s string) (err error) {
if i := strings.IndexFunc(s, isSurrogateOrNonCharacter); i >= 0 {
return fmt.Errorf("bad code point at index %d", i)
}
return nil
}
// validate checks the configuration for errors.
func (c *StructuredDNSErrorsConfig) validate(edeEnabled bool) (err error) {
if c == nil {
return errors.ErrNoValue
}
if !c.Enabled {
return nil
} else if !edeEnabled {
return errors.Error("ede must be enabled to enable sde")
}
var errs []error
if len(c.Contact) == 0 {
err = fmt.Errorf("contact data: %w", errors.ErrEmptyValue)
errs = append(errs, err)
}
for i, cont := range c.Contact {
err = validateSDEContactURL(cont)
if err != nil {
err = fmt.Errorf("contact data: at index %d: %w", i, err)
errs = append(errs, err)
}
}
if c.Justification == "" {
err = fmt.Errorf("justification: %w", errors.ErrEmptyValue)
errs = append(errs, err)
} else if err = validateSDEString(c.Justification); err != nil {
err = fmt.Errorf("justification: %w", err)
errs = append(errs, err)
}
if err = validateSDEString(c.Organization); err != nil {
err = fmt.Errorf("organization: %w", err)
errs = append(errs, err)
}
return errors.Join(errs...)
}
// validateSDEContactURL returns an error if u is not a valid SDE contact URL.
// It doesn't check for bad code points in the URL since [url.URL.String]
// escapes them.
func validateSDEContactURL(u *url.URL) (err error) {
if u == nil {
return errors.ErrNoValue
}
switch strings.ToLower(u.Scheme) {
case "mailto", "sips", "tel":
// TODO(a.garipov): Consider more thorough validations for each scheme.
default:
return fmt.Errorf("scheme: %w: %q", errors.ErrBadEnumValue, u.Scheme)
}
return nil
}

View File

@ -31,8 +31,8 @@ type Middleware struct {
// cacheMinTTL is the minimum supported TTL for cache items.
cacheMinTTL time.Duration
// useTTLOverride shows if the TTL overrides logic should be used.
useTTLOverride bool
// overrideTTL shows if the TTL overrides logic should be used.
overrideTTL bool
}
// MiddlewareConfig is the configuration structure for NewMiddleware.
@ -49,8 +49,8 @@ type MiddlewareConfig struct {
// MinTTL is the minimum supported TTL for cache items.
MinTTL time.Duration
// UseTTLOverride shows if the TTL overrides logic should be used.
UseTTLOverride bool
// OverrideTTL shows if the TTL overrides logic should be used.
OverrideTTL bool
}
// NewMiddleware initializes a new LRU caching middleware. c must not be nil.
@ -66,7 +66,7 @@ func NewMiddleware(c *MiddlewareConfig) (m *Middleware) {
metrics: metrics,
cache: gcache.New(c.Size).LRU().Build(),
cacheMinTTL: c.MinTTL,
useTTLOverride: c.UseTTLOverride,
overrideTTL: c.OverrideTTL,
}
}
@ -150,7 +150,7 @@ func (m *Middleware) set(msg *dns.Msg) (err error) {
}
exp := time.Duration(ttl) * time.Second
if m.useTTLOverride && msg.Rcode != dns.RcodeServerFailure {
if m.overrideTTL && msg.Rcode != dns.RcodeServerFailure {
exp = max(exp, m.cacheMinTTL)
setMinTTL(msg, uint32(exp.Seconds()))
}

View File

@ -188,7 +188,7 @@ func TestMiddleware_Wrap(t *testing.T) {
cache.NewMiddleware(&cache.MiddlewareConfig{
Size: 100,
MinTTL: minTTL,
UseTTLOverride: tc.minTTL != nil,
OverrideTTL: tc.minTTL != nil,
}),
)

View File

@ -2,6 +2,7 @@ package dnsservertest
import (
"context"
"fmt"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
@ -13,9 +14,15 @@ import (
// AnswerTTL is the default TTL of the test handler's answers.
const AnswerTTL time.Duration = 100 * time.Second
// CreateTestHandler creates a [dnsserver.Handler] with the specified
// NewDefaultHandler returns a simple handler that always returns a response
// with a single A record.
func NewDefaultHandler() (handler dnsserver.Handler) {
return NewDefaultHandlerWithCount(1)
}
// NewDefaultHandlerWithCount creates a [dnsserver.Handler] with the specified
// parameters. All responses will have the [TestAnsTTL] TTL.
func CreateTestHandler(recordsCount int) (h dnsserver.Handler) {
func NewDefaultHandlerWithCount(recordsCount int) (h dnsserver.Handler) {
f := func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) (err error) {
// Check that necessary context keys are set.
si := dnsserver.MustServerInfoFromContext(ctx)
@ -49,8 +56,12 @@ func CreateTestHandler(recordsCount int) (h dnsserver.Handler) {
return dnsserver.HandlerFunc(f)
}
// DefaultHandler returns a simple handler that always returns a response with
// a single A record.
func DefaultHandler() (handler dnsserver.Handler) {
return CreateTestHandler(1)
// NewPanicHandler returns a DNS handler that panics with an error.
func NewPanicHandler() (handler dnsserver.Handler) {
f := func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) (err error) {
// TODO(a.garipov): Add a helper for these kinds of errors to golibs.
panic(fmt.Errorf("unexpected call to ServeDNS(%v, %v)", rw, req))
}
return dnsserver.HandlerFunc(f)
}

View File

@ -273,6 +273,22 @@ func NewNS(name string, ttl uint32, ns string) (rr dns.RR) {
}
}
// NewOPT constructs the new resource record of type OPT.
func NewOPT(do bool, udpSize uint16, opts ...dns.EDNS0) (rr dns.RR) {
opt := &dns.OPT{
Hdr: dns.RR_Header{
Name: ".",
Rrtype: dns.TypeOPT,
},
Option: opts,
}
opt.SetDo(do)
opt.SetUDPSize(udpSize)
return opt
}
// NewECSExtra constructs a new OPT RR for the extra section.
func NewECSExtra(ip net.IP, fam uint16, mask, scope uint8) (extra dns.RR) {
return &dns.OPT{
@ -300,11 +316,8 @@ func NewEDNS0Padding(msgLen int, UDPBufferSize uint16) (extra dns.RR) {
padLen := requestPaddingBlockSize - msgLen%requestPaddingBlockSize
// Truncate padding to fit in UDP buffer.
if msgLen+padLen > int(UDPBufferSize) {
padLen = int(UDPBufferSize) - msgLen
if padLen < 0 {
padLen = 0
}
if bufSzInt := int(UDPBufferSize); msgLen+padLen > bufSzInt {
padLen = max(bufSzInt-msgLen, 0)
}
return &dns.OPT{

View File

@ -23,7 +23,7 @@ func TestMain(m *testing.M) {
const testTimeout = 1 * time.Second
func TestHandler_ServeDNS(t *testing.T) {
srv, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
srv, addr := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
// No-fallbacks handler.
handler := forward.NewHandler(&forward.HandlerConfig{
@ -47,7 +47,7 @@ func TestHandler_ServeDNS(t *testing.T) {
}
func TestHandler_ServeDNS_fallbackNetError(t *testing.T) {
srv, _ := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
srv, _ := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
handler := forward.NewHandler(&forward.HandlerConfig{
UpstreamsAddresses: []*forward.UpstreamPlainConfig{{
Network: forward.NetworkAny,

View File

@ -20,7 +20,7 @@ func TestHandler_Refresh(t *testing.T) {
var upstreamIsUp atomic.Bool
var upstreamRequestsCount atomic.Int64
defaultHandler := dnsservertest.DefaultHandler()
defaultHandler := dnsservertest.NewDefaultHandler()
// This handler writes an empty message if upstreamUp flag is false.
handlerFunc := dnsserver.HandlerFunc(func(

View File

@ -36,7 +36,7 @@ func TestUpstreamPlain_Exchange(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
ups := forward.NewUpstreamPlain(&forward.UpstreamPlainConfig{
Network: tc.network,
Address: netip.MustParseAddrPort(addr),
@ -65,7 +65,7 @@ func TestUpstreamPlain_Exchange_truncated(t *testing.T) {
rw.LocalAddr(),
rw.RemoteAddr(),
)
handler := dnsservertest.DefaultHandler()
handler := dnsservertest.NewDefaultHandler()
err = handler.ServeDNS(ctx, nrw, req)
if err != nil {
return err

View File

@ -1,9 +1,9 @@
module github.com/AdguardTeam/AdGuardDNS/internal/dnsserver
go 1.23.1
go 1.23.2
require (
github.com/AdguardTeam/golibs v0.28.0
github.com/AdguardTeam/golibs v0.30.1
github.com/ameshkov/dnscrypt/v2 v2.3.0
github.com/ameshkov/dnsstamps v1.0.3
github.com/bluele/gcache v0.0.2
@ -11,12 +11,12 @@ require (
github.com/miekg/dns v1.1.62
github.com/panjf2000/ants/v2 v2.10.0
github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible
github.com/prometheus/client_golang v1.20.1
github.com/quic-go/quic-go v0.47.0
github.com/prometheus/client_golang v1.20.5
github.com/quic-go/quic-go v0.48.1
github.com/stretchr/testify v1.9.0
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0
golang.org/x/net v0.29.0
golang.org/x/sys v0.25.0
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
golang.org/x/net v0.30.0
golang.org/x/sys v0.26.0
)
require (
@ -26,22 +26,23 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
github.com/google/pprof v0.0.0-20240929191954-255acd752d31 // indirect
github.com/google/pprof v0.0.0-20241023014458-598669927662 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/onsi/ginkgo/v2 v2.20.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.55.0 // indirect
github.com/prometheus/common v0.60.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/quic-go/qpack v0.5.1 // indirect
go.uber.org/mock v0.4.0 // indirect
golang.org/x/crypto v0.27.0 // indirect
go.uber.org/mock v0.5.0 // indirect
golang.org/x/crypto v0.28.0 // indirect
golang.org/x/mod v0.21.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/text v0.18.0 // indirect
golang.org/x/time v0.6.0 // indirect
golang.org/x/tools v0.25.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect
golang.org/x/text v0.19.0 // indirect
golang.org/x/time v0.7.0 // indirect
golang.org/x/tools v0.26.0 // indirect
google.golang.org/protobuf v1.35.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@ -1,4 +1,5 @@
github.com/AdguardTeam/golibs v0.28.0 h1:SK1q8SqkkJ/61pp2abTmio90S4QpteYK9rtgROfnrb4=
github.com/AdguardTeam/golibs v0.30.1 h1:/yv7dq2h7WXw/jTDxkE3FP9zHerRT+i03PZRHJX4fPU=
github.com/AdguardTeam/golibs v0.30.1/go.mod h1:FkwcNQEJoGsgDGXcalrVa/4gWbE68KsmE2guXWtBQUE=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw=
@ -25,10 +26,10 @@ github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1v
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/pprof v0.0.0-20240929191954-255acd752d31 h1:LcRdQWywSgfi5jPsYZ1r2avbbs5IQ5wtyhMBCcokyo4=
github.com/google/pprof v0.0.0-20240929191954-255acd752d31/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/google/pprof v0.0.0-20241023014458-598669927662 h1:SKMkD83p7FwUqKmBsPdLHF5dNyxq3jOWwu9w9UyH5vA=
github.com/google/pprof v0.0.0-20241023014458-598669927662/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
@ -47,18 +48,18 @@ github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible
github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.20.1 h1:IMJXHOD6eARkQpxo8KkhgEVFlBNm+nkrFUyGlIu7Na8=
github.com/prometheus/client_golang v1.20.1/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y=
github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc=
github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8=
github.com/prometheus/common v0.60.0 h1:+V9PAREWNvJMAuJ1x1BaWl9dewMW4YrHZQbx0sJNllA=
github.com/prometheus/common v0.60.0/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
github.com/quic-go/quic-go v0.47.0 h1:yXs3v7r2bm1wmPTYNLKAAJTHMYkPEsfYJmTazXrCZ7Y=
github.com/quic-go/quic-go v0.47.0/go.mod h1:3bCapYsJvXGZcipOHuu7plYtaV6tnF+z7wIFsU0WK9E=
github.com/quic-go/quic-go v0.48.1 h1:y/8xmfWI9qmGTc+lBr4jKRUWLGSlSigv847ULJ4hYXA=
github.com/quic-go/quic-go v0.48.1/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@ -69,29 +70,29 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY=
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224=
golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE=
golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ=
golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0=
google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

View File

@ -23,7 +23,7 @@ func TestCacheMetricsListener_integration_cache(t *testing.T) {
})
handlerWithMiddleware := dnsserver.WithMiddlewares(
dnsservertest.DefaultHandler(),
dnsservertest.NewDefaultHandler(),
cacheMiddleware,
)

View File

@ -18,7 +18,7 @@ import (
// normal unit test, we create a forward handler, emulate a query and then
// check if prom metrics were incremented.
func TestForwardMetricsListener_integration_request(t *testing.T) {
srv, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
srv, addr := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
// Initialize a new forward.Handler and set the metrics listener.
handler := forward.NewHandler(&forward.HandlerConfig{

View File

@ -19,16 +19,21 @@ import (
// normal unit test, we create a cache middleware, emulate a query and then
// check if prom metrics were incremented.
func TestRateLimiterMetricsListener_integration_cache(t *testing.T) {
rps := 5
const (
count = 5
ivl = time.Second
)
rl := ratelimit.NewBackoff(&ratelimit.BackoffConfig{
Allowlist: ratelimit.NewDynamicAllowlist([]netip.Prefix{}, []netip.Prefix{}),
Period: time.Minute,
Duration: time.Minute,
Count: uint(rps),
Count: count,
ResponseSizeEstimate: 1 * datasize.KB,
IPv4RPS: uint(rps),
IPv6RPS: uint(rps),
IPv4Count: count,
IPv4Interval: ivl,
IPv6Count: count,
IPv6Interval: ivl,
RefuseANY: true,
})
rlMw, err := ratelimit.NewMiddleware(&ratelimit.MiddlewareConfig{
@ -38,7 +43,7 @@ func TestRateLimiterMetricsListener_integration_cache(t *testing.T) {
require.NoError(t, err)
handlerWithMiddleware := dnsserver.WithMiddlewares(
dnsservertest.DefaultHandler(),
dnsservertest.NewDefaultHandler(),
rlMw,
)
@ -55,7 +60,7 @@ func TestRateLimiterMetricsListener_integration_cache(t *testing.T) {
err = handlerWithMiddleware.ServeDNS(ctx, nrw, req)
require.NoError(t, err)
if i < rps {
if i < count {
dnsservertest.RequireResponse(t, req, nrw.Msg(), 1, dns.RcodeSuccess, false)
} else {
require.Nil(t, nrw.Msg())

View File

@ -24,7 +24,7 @@ func TestServerMetricsListener_integration_requestLifetime(t *testing.T) {
ConfigBase: dnsserver.ConfigBase{
Name: "test",
Addr: "127.0.0.1:0",
Handler: dnsservertest.DefaultHandler(),
Handler: dnsservertest.NewDefaultHandler(),
Metrics: prometheus.NewServerMetricsListener(testNamespace),
},
}

View File

@ -36,19 +36,29 @@ type BackoffConfig struct {
// as several responses.
ResponseSizeEstimate datasize.ByteSize
// IPv4RPS is the maximum number of requests per second allowed from a
// single subnet for IPv4 addresses. Any requests above this rate are
// counted as the client's backoff count. RPS must be greater than zero.
IPv4RPS uint
// IPv4Count is the maximum number of requests per a specified interval
// allowed from a single subnet for IPv4 addresses. Any requests above this
// rate are counted as the client's backoff count. It must be greater than
// zero.
IPv4Count uint
// IPv4Interval is the time during which to count the number of requests
// for IPv4 addresses.
IPv4Interval time.Duration
// IPv4SubnetKeyLen is the length of the subnet prefix used to calculate
// rate limiter bucket keys for IPv4 addresses. Must be greater than zero.
IPv4SubnetKeyLen int
// IPv6RPS is the maximum number of requests per second allowed from a
// single subnet for IPv6 addresses. Any requests above this rate are
// counted as the client's backoff count. RPS must be greater than zero.
IPv6RPS uint
// IPv6Count is the maximum number of requests per a specified interval
// allowed from a single subnet for IPv6 addresses. Any requests above this
// rate are counted as the client's backoff count. It must be greater than
// zero.
IPv6Count uint
// IPv6Interval is the time during which to count the number of requests
// for IPv6 addresses.
IPv6Interval time.Duration
// IPv6SubnetKeyLen is the length of the subnet prefix used to calculate
// rate limiter bucket keys for IPv6 addresses. Must be greater than zero.
@ -75,9 +85,11 @@ type Backoff struct {
allowlist Allowlist
respSzEst datasize.ByteSize
count uint
ipv4rps uint
ipv4Count uint
ipv4Interval time.Duration
ipv4SubnetKeyLen int
ipv6rps uint
ipv6Count uint
ipv6Interval time.Duration
ipv6SubnetKeyLen int
refuseANY bool
}
@ -93,9 +105,11 @@ func NewBackoff(c *BackoffConfig) (l *Backoff) {
allowlist: c.Allowlist,
respSzEst: c.ResponseSizeEstimate,
count: c.Count,
ipv4rps: c.IPv4RPS,
ipv4Count: c.IPv4Count,
ipv4Interval: c.IPv4Interval,
ipv4SubnetKeyLen: c.IPv4SubnetKeyLen,
ipv6rps: c.IPv6RPS,
ipv6Count: c.IPv6Count,
ipv6Interval: c.IPv6Interval,
ipv6SubnetKeyLen: c.IPv6SubnetKeyLen,
refuseANY: c.RefuseANY,
}
@ -133,12 +147,12 @@ func (l *Backoff) IsRateLimited(
return true, false, nil
}
rps := l.ipv4rps
count, ivl := l.ipv4Count, l.ipv4Interval
if ip.Is6() {
rps = l.ipv6rps
count, ivl = l.ipv6Count, l.ipv6Interval
}
return l.hasHitRateLimit(key, rps), false, nil
return l.hasHitRateLimit(key, count, ivl), false, nil
}
// validateAddr returns an error if addr is not a valid IPv4 or IPv6 address.
@ -198,15 +212,15 @@ func (l *Backoff) incBackoff(key string) {
l.hitCounters.SetDefault(key, counter)
}
// hasHitRateLimit checks value for a subnet with rps as a maximum number
// requests per second.
func (l *Backoff) hasHitRateLimit(subnetIPStr string, rps uint) (ok bool) {
// hasHitRateLimit checks if the value of requests for given subnet hit the
// maximum count of requests per given interval.
func (l *Backoff) hasHitRateLimit(subnetIPStr string, count uint, ivl time.Duration) (ok bool) {
var r *RequestCounter
rVal, ok := l.reqCounters.Get(subnetIPStr)
if ok {
r = rVal.(*RequestCounter)
} else {
r = NewRequestCounter(rps, time.Second)
r = NewRequestCounter(count, ivl)
l.reqCounters.SetDefault(subnetIPStr, r)
}

View File

@ -22,7 +22,10 @@ func TestMain(m *testing.M) {
}
func TestRatelimitMiddleware(t *testing.T) {
const rps = 10
const (
rps = 10
ivl = time.Second
)
persistent := []netip.Prefix{
netip.MustParsePrefix("4.3.2.1/8"),
@ -99,9 +102,11 @@ func TestRatelimitMiddleware(t *testing.T) {
Duration: time.Minute,
Count: rps,
ResponseSizeEstimate: 128 * datasize.B,
IPv4RPS: rps,
IPv4Count: rps,
IPv4Interval: ivl,
IPv4SubnetKeyLen: 24,
IPv6RPS: rps,
IPv6Count: rps,
IPv6Interval: ivl,
IPv6SubnetKeyLen: 48,
RefuseANY: true,
})
@ -112,7 +117,7 @@ func TestRatelimitMiddleware(t *testing.T) {
require.NoError(t, err)
withMw := dnsserver.WithMiddlewares(
dnsservertest.CreateTestHandler(tc.respCount),
dnsservertest.NewDefaultHandlerWithCount(tc.respCount),
rlMw,
)

View File

@ -309,6 +309,10 @@ func (s *ServerBase) serveDNSMsgInternal(
s.metrics.OnError(ctx, err)
resp = genErrorResponse(req, dns.RcodeServerFailure)
if isNonCriticalNetError(err) {
addEDE(req, resp, dns.ExtendedErrorCodeNetworkError, "")
}
err = rw.WriteMsg(ctx, req, resp)
if err != nil {
log.Debug("[%d]: error writing a response: %s", req.Id, err)
@ -316,6 +320,28 @@ func (s *ServerBase) serveDNSMsgInternal(
}
}
// addEDE adds an Extended DNS Error (EDE) option to the blocked response
// message, if the request indicates EDNS support.
func addEDE(req, resp *dns.Msg, code uint16, text string) {
reqOpt := req.IsEdns0()
if reqOpt == nil {
// Requestor doesn't implement EDNS, see
// https://datatracker.ietf.org/doc/html/rfc6891#section-7.
return
}
respOpt := resp.IsEdns0()
if respOpt == nil {
resp.SetEdns0(reqOpt.UDPSize(), reqOpt.Do())
respOpt = resp.Extra[len(resp.Extra)-1].(*dns.OPT)
}
respOpt.Option = append(respOpt.Option, &dns.EDNS0_EDE{
InfoCode: code,
ExtraText: text,
})
}
// acceptMsg checks if we should process the incoming DNS query.
func (s *ServerBase) acceptMsg(m *dns.Msg) (action dns.MsgAcceptAction) {
if m.Response {

View File

@ -37,7 +37,7 @@ func BenchmarkServeDNS(b *testing.B) {
for _, tc := range testCases {
b.Run(tc.name, func(b *testing.B) {
_, addr := dnsservertest.RunDNSServer(b, dnsservertest.DefaultHandler())
_, addr := dnsservertest.RunDNSServer(b, dnsservertest.NewDefaultHandler())
// Prepare a test message.
m := new(dns.Msg)
@ -105,7 +105,7 @@ func readMsg(resBuf []byte, network dnsserver.Network, conn net.Conn) (err error
func BenchmarkServeTLS(b *testing.B) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
addr := dnsservertest.RunTLSServer(b, dnsservertest.DefaultHandler(), tlsConfig)
addr := dnsservertest.RunTLSServer(b, dnsservertest.NewDefaultHandler(), tlsConfig)
// Prepare a test message
m := new(dns.Msg)
@ -171,7 +171,7 @@ func BenchmarkServeDoH(b *testing.B) {
for _, tc := range testCases {
b.Run(tc.name, func(b *testing.B) {
srv, err := dnsservertest.RunLocalHTTPSServer(
dnsservertest.DefaultHandler(),
dnsservertest.NewDefaultHandler(),
tc.tlsConfig,
nil,
)
@ -250,7 +250,7 @@ func BenchmarkServeDNSCrypt(b *testing.B) {
Net: string(tc.network),
}
s := dnsservertest.RunDNSCryptServer(b, dnsservertest.DefaultHandler())
s := dnsservertest.RunDNSCryptServer(b, dnsservertest.NewDefaultHandler())
stamp := dnsstamps.ServerStamp{
ServerAddrStr: s.ServerAddr,
ServerPk: s.ResolverPk,
@ -282,7 +282,7 @@ func BenchmarkServeDNSCrypt(b *testing.B) {
func BenchmarkServeQUIC(b *testing.B) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
srv, addr, err := dnsservertest.RunLocalQUICServer(
dnsservertest.DefaultHandler(),
dnsservertest.NewDefaultHandler(),
tlsConfig,
)
require.NoError(b, err)

View File

@ -20,7 +20,7 @@ import (
)
func TestServerDNS_StartShutdown(t *testing.T) {
_, _ = dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
_, _ = dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
}
func TestServerDNS_integration_query(t *testing.T) {
@ -42,7 +42,7 @@ func TestServerDNS_integration_query(t *testing.T) {
{Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
},
},
handler: dnsservertest.DefaultHandler(),
handler: dnsservertest.NewDefaultHandler(),
wantRecordsCount: 1,
wantRCode: dns.RcodeSuccess,
}, {
@ -54,7 +54,7 @@ func TestServerDNS_integration_query(t *testing.T) {
{Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
},
},
handler: dnsservertest.DefaultHandler(),
handler: dnsservertest.NewDefaultHandler(),
wantRecordsCount: 1,
wantRCode: dns.RcodeSuccess,
}, {
@ -89,7 +89,7 @@ func TestServerDNS_integration_query(t *testing.T) {
},
},
},
handler: dnsservertest.DefaultHandler(),
handler: dnsservertest.NewDefaultHandler(),
wantMsg: func(t *testing.T, m *dns.Msg) {
opt := m.IsEdns0()
require.NotNil(t, opt)
@ -109,7 +109,7 @@ func TestServerDNS_integration_query(t *testing.T) {
{Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
},
},
handler: dnsservertest.DefaultHandler(),
handler: dnsservertest.NewDefaultHandler(),
wantRecordsCount: 0,
wantRCode: dns.RcodeFormatError,
}, {
@ -122,7 +122,7 @@ func TestServerDNS_integration_query(t *testing.T) {
{Name: "eXaMplE.oRg.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
},
},
handler: dnsservertest.DefaultHandler(),
handler: dnsservertest.NewDefaultHandler(),
wantRecordsCount: 1,
wantRCode: dns.RcodeSuccess,
}, {
@ -136,7 +136,7 @@ func TestServerDNS_integration_query(t *testing.T) {
{Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
},
},
handler: dnsservertest.DefaultHandler(),
handler: dnsservertest.NewDefaultHandler(),
wantRecordsCount: 0,
wantRCode: dns.RcodeNotImplemented,
}, {
@ -170,7 +170,7 @@ func TestServerDNS_integration_query(t *testing.T) {
{Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
},
},
handler: dnsservertest.DefaultHandler(),
handler: dnsservertest.NewDefaultHandler(),
wantRecordsCount: 1,
wantRCode: dns.RcodeSuccess,
}, {
@ -185,7 +185,7 @@ func TestServerDNS_integration_query(t *testing.T) {
},
},
// Set a handler that generates a large response
handler: dnsservertest.CreateTestHandler(64),
handler: dnsservertest.NewDefaultHandlerWithCount(64),
wantRecordsCount: 0,
wantRCode: dns.RcodeSuccess,
wantTruncated: true,
@ -210,7 +210,7 @@ func TestServerDNS_integration_query(t *testing.T) {
},
},
// Set a handler that generates a large response
handler: dnsservertest.CreateTestHandler(64),
handler: dnsservertest.NewDefaultHandlerWithCount(64),
wantRecordsCount: 64,
wantRCode: dns.RcodeSuccess,
wantTruncated: false,
@ -226,7 +226,7 @@ func TestServerDNS_integration_query(t *testing.T) {
},
},
// Set a handler that generates a large response
handler: dnsservertest.CreateTestHandler(64),
handler: dnsservertest.NewDefaultHandlerWithCount(64),
// No truncate
wantRecordsCount: 64,
wantRCode: dns.RcodeSuccess,
@ -257,7 +257,7 @@ func TestServerDNS_integration_query(t *testing.T) {
},
},
},
handler: dnsservertest.DefaultHandler(),
handler: dnsservertest.NewDefaultHandler(),
wantRecordsCount: 1,
wantRCode: dns.RcodeSuccess,
}}
@ -304,7 +304,7 @@ func TestServerDNS_integration_tcpQueriesPipelining(t *testing.T) {
// As per RFC 7766 we should support queries pipelining for TCP, that is
// server must be able to process incoming queries in parallel and write
// responses possibly out of order within the same connection.
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
// Establish a connection.
conn, err := net.Dial("tcp", addr)
@ -372,7 +372,7 @@ func TestServerDNS_integration_tcpQueriesPipelining(t *testing.T) {
}
func TestServerDNS_integration_udpMsgIgnore(t *testing.T) {
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
conn, err := net.Dial("udp", addr)
require.Nil(t, err)
@ -469,7 +469,7 @@ func TestServerDNS_integration_tcpMsgIgnore(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
conn, err := net.Dial("tcp", addr)
require.Nil(t, err)

View File

@ -49,7 +49,7 @@ func TestServerDNSCrypt_integration_query(t *testing.T) {
name: "udp_truncate_response",
network: dnsserver.NetworkUDP,
// Set a handler that generates a large response
handler: dnsservertest.CreateTestHandler(64),
handler: dnsservertest.NewDefaultHandlerWithCount(64),
// DNSCrypt server removes all records from a truncated response
expectedRecordsCount: 0,
expectedRCode: dns.RcodeSuccess,
@ -66,7 +66,7 @@ func TestServerDNSCrypt_integration_query(t *testing.T) {
name: "udp_edns0_no_truncate",
network: dnsserver.NetworkUDP,
// Set a handler that generates a large response
handler: dnsservertest.CreateTestHandler(64),
handler: dnsservertest.NewDefaultHandlerWithCount(64),
expectedRecordsCount: 64,
expectedRCode: dns.RcodeSuccess,
expectedTruncated: false,
@ -89,7 +89,7 @@ func TestServerDNSCrypt_integration_query(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
handler := tc.handler
if tc.handler == nil {
handler = dnsservertest.DefaultHandler()
handler = dnsservertest.NewDefaultHandler()
}
s := dnsservertest.RunDNSCryptServer(t, handler)

View File

@ -19,6 +19,7 @@ import (
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
@ -304,9 +305,9 @@ func (s *ServerHTTPS) serveHTTPS(ctx context.Context, hs *http.Server, l net.Lis
// application won't be able to continue listening to DoH.
defer s.handlePanicAndExit(ctx)
scheme := "https"
scheme := urlutil.SchemeHTTPS
if s.conf.TLSConfig == nil {
scheme = "http"
scheme = urlutil.SchemeHTTP
}
u := &url.URL{

View File

@ -109,7 +109,7 @@ func TestServerHTTPS_integration_serveRequests(t *testing.T) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
srv, err := dnsservertest.RunLocalHTTPSServer(
dnsservertest.DefaultHandler(),
dnsservertest.NewDefaultHandler(),
tlsConfig,
nil,
)
@ -146,7 +146,7 @@ func TestServerHTTPS_integration_nonDNSHandler(t *testing.T) {
})
srv, err := dnsservertest.RunLocalHTTPSServer(
dnsservertest.DefaultHandler(),
dnsservertest.NewDefaultHandler(),
nil,
testHandler,
)
@ -321,7 +321,7 @@ func TestDNSMsgToJSONMsg(t *testing.T) {
func TestServerHTTPS_integration_ENDS0Padding(t *testing.T) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
srv, err := dnsservertest.RunLocalHTTPSServer(
dnsservertest.DefaultHandler(),
dnsservertest.NewDefaultHandler(),
tlsConfig,
nil,
)
@ -346,7 +346,7 @@ func TestServerHTTPS_integration_ENDS0Padding(t *testing.T) {
func TestServerHTTPS_0RTT(t *testing.T) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
srv, err := dnsservertest.RunLocalHTTPSServer(
dnsservertest.DefaultHandler(),
dnsservertest.NewDefaultHandler(),
tlsConfig,
nil,
)
@ -514,7 +514,7 @@ func createDoH3Client(
tlsConfig = tlsConfig.Clone()
tlsConfig.NextProtos = []string{http3.NextProtoH3}
transport := &http3.RoundTripper{
transport := &http3.Transport{
DisableCompression: true,
Dial: func(
ctx context.Context,

View File

@ -7,13 +7,18 @@ import (
"strconv"
"strings"
"github.com/AdguardTeam/golibs/errors"
"github.com/miekg/dns"
)
// JSONMsg represents a *dns.Msg in the JSON format defined here:
// https://developers.google.com/speed/public-dns/docs/doh/json#dns_response_in_json
// Note, that we do not implement some parts of it. There is no "Comment" field
// and there's no "edns_client_subnet".
//
// NOTE: This API differs from the Google one in the following ways:
// 1. The "Comment" field is not implemented.
// 2. The "edns_client_subnet" query parameter is not supported.
// 3. The "sde" query parameter is added and supported for the experimental
// Structured DNS Errors feature.
type JSONMsg struct {
Question []JSONQuestion `json:"Question"`
Answer []JSONAnswer `json:"Answer"`
@ -26,13 +31,13 @@ type JSONMsg struct {
Status int `json:"Status"`
}
// JSONQuestion is a part of JSONMsg definition.
// JSONQuestion is a part of [JSONMsg] definition.
type JSONQuestion struct {
Name string `json:"name"`
Type uint16 `json:"type"`
}
// JSONAnswer is a part of JSONMsg definition.
// JSONAnswer is a part of [JSONMsg] definition.
type JSONAnswer struct {
Name string `json:"name"`
Data string `json:"data"`
@ -41,7 +46,7 @@ type JSONAnswer struct {
Class uint16 `json:"class"`
}
// DNSMsgToJSONMsg converts the *dns.Msg to the JSON format (*JSONMsg).
// DNSMsgToJSONMsg converts the *dns.Msg to the JSON format.
func DNSMsgToJSONMsg(m *dns.Msg) (msg *JSONMsg) {
msg = &JSONMsg{
Status: m.Rcode,
@ -74,9 +79,9 @@ func DNSMsgToJSONMsg(m *dns.Msg) (msg *JSONMsg) {
func rrToJSON(rr dns.RR) (j JSONAnswer) {
hdr := rr.Header()
// Extracting the RR value is a bit tricky since miekg/dns does not
// expose the necessary methods. This way we can benefit from the
// proper string serialization code that's used inside miekg/dns.
// Extracting the RR value is a bit tricky since miekg/dns does not expose
// the necessary methods. This way we can benefit from the proper string
// serialization code that's used inside miekg/dns.
hdrStr := hdr.String()
valStr := rr.String()
data := strings.TrimLeft(strings.TrimPrefix(valStr, hdrStr), " ")
@ -90,19 +95,17 @@ func rrToJSON(rr dns.RR) (j JSONAnswer) {
}
}
// dnsMsgToJSON converts the *dns.Msg to the JSON format (JSONMsg) and returns
// it in the serialized form.
// dnsMsgToJSON converts the *dns.Msg to the JSON format and returns it in the
// serialized form.
func dnsMsgToJSON(m *dns.Msg) (b []byte, err error) {
msg := DNSMsgToJSONMsg(m)
return json.Marshal(msg)
return json.Marshal(DNSMsgToJSONMsg(m))
}
// httpRequestToMsgJSON builds a DNS message from the request parameters.
// We use the same parameters as the ones defined here:
// https://developers.google.com/speed/public-dns/docs/doh/json#supported_parameters
// Some parameters are not supported: "ct", "edns_client_subnet".
func httpRequestToMsgJSON(req *http.Request) (b []byte, err error) {
q := req.URL.Query()
//
// See [JSONMsg].
func httpRequestToMsgJSON(httpReq *http.Request) (b []byte, err error) {
q := httpReq.URL.Query()
// Query name, the only required parameter.
name := q.Get("name")
@ -111,71 +114,91 @@ func httpRequestToMsgJSON(req *http.Request) (b []byte, err error) {
return nil, ErrInvalidArgument
}
// RR type can be represented as a number in [1, 65535] or a
// canonical string (case-insensitive, such as A or AAAA).
var t uint16
t, err = urlQueryParameterToUint16(q, "type", dns.TypeA, dns.StringToType)
// RR type can be represented as a number in [1, 65535] or a canonical
// string (case-insensitive, such as A or AAAA).
qt, err := urlQueryParameterToUint16(q, "type", dns.TypeA, dns.StringToType)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
// Query class can be represented as a number in [1, 65535] or a
// canonical string (case-insensitive).
var qc uint16
qc, err = urlQueryParameterToUint16(q, "qc", dns.ClassINET, dns.StringToClass)
// Query class can be represented as a number in [1, 65535] or a canonical
// string (case-insensitive).
qc, err := urlQueryParameterToUint16(q, "qc", dns.ClassINET, dns.StringToClass)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
// The CD (Checking Disabled) flag. Use cd=1, or cd=true to disable
// DNSSEC validation; use cd=0, cd=false, or no cd parameter to
// enable DNSSEC validation.
var cd bool
cd, err = urlQueryParameterToBoolean(q, "cd", false)
// The CD (Checking Disabled) flag. Use cd=1, or cd=true to disable DNSSEC
// validation; use cd=0, cd=false, or no cd parameter to enable DNSSEC
// validation.
cd, err := urlQueryParameterToBoolean(q, "cd", false)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
// The DO (DNSSEC OK) flag. Use do=1, or do=true to include DNSSEC
// records (RRSIG, NSEC, NSEC3); use do=0, do=false, or no do parameter
// to omit DNSSEC records.
var do bool
do, err = urlQueryParameterToBoolean(q, "do", false)
// The DO (DNSSEC OK) flag. Use do=1 (or do=true) to include DNSSEC records
// (RRSIG, NSEC, NSEC3); use do=0 (do=false) or no do parameter to omit
// DNSSEC records.
do, err := urlQueryParameterToBoolean(q, "do", false)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
// The experimental Structured DNS Errors feature.
sde, err := urlQueryParameterToBoolean(q, "sde", false)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
// Now build a DNS message with all those parameters
r := &dns.Msg{
req := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
CheckingDisabled: cd,
RecursionDesired: true,
},
Question: []dns.Question{
{
Question: []dns.Question{{
Name: dns.Fqdn(name),
Qtype: t,
Qtype: qt,
Qclass: qc,
},
},
}},
}
if do {
r.SetEdns0(dns.MaxMsgSize, do)
setEDNSFromQuery(req, do, sde)
return req.Pack()
}
// setEDNSFromQuery sets the EDNS parameters on the request depending on the
// query parameters.
func setEDNSFromQuery(req *dns.Msg, do, sde bool) {
if !do && !sde {
return
}
return r.Pack()
req.SetEdns0(dns.MaxMsgSize, do)
if sde {
opt := req.Extra[0].(*dns.OPT)
opt.Option = append(opt.Option, &dns.EDNS0_EDE{})
}
}
// urlQueryParameterToUint16 is a helper function that extracts a uint16 value
// from a query parameter. See httpRequestToMsgJSON to see how it's used.
// from a query parameter.
func urlQueryParameterToUint16(
q url.Values,
name string,
defaultValue uint16,
strValuesMap map[string]uint16,
) (v uint16, err error) {
defer func() { err = errors.Annotate(err, "parameter %q: %w", name) }()
strValue := q.Get(name)
uintValue, convErr := strconv.ParseUint(strValue, 10, 16)
switch {
@ -199,12 +222,8 @@ func urlQueryParameterToUint16(
}
// urlQueryParameterToBoolean is a helper function that extracts a boolean value
// from a query parameter. See httpRequestToMsgJSON to see how it's used.
func urlQueryParameterToBoolean(
q url.Values,
name string,
defaultValue bool,
) (v bool, err error) {
// from a query parameter.
func urlQueryParameterToBoolean(q url.Values, name string, defaultValue bool) (v bool, err error) {
strValue := q.Get(name)
switch strValue {
case "1", "true", "True":

View File

@ -25,7 +25,7 @@ import (
func TestServerQUIC_integration_query(t *testing.T) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
srv, addr, err := dnsservertest.RunLocalQUICServer(
dnsservertest.DefaultHandler(),
dnsservertest.NewDefaultHandler(),
tlsConfig,
)
require.NoError(t, err)
@ -74,7 +74,7 @@ func TestServerQUIC_integration_query(t *testing.T) {
func TestServerQUIC_integration_ENDS0Padding(t *testing.T) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
srv, addr, err := dnsservertest.RunLocalQUICServer(
dnsservertest.DefaultHandler(),
dnsservertest.NewDefaultHandler(),
tlsConfig,
)
require.NoError(t, err)
@ -108,7 +108,7 @@ func TestServerQUIC_integration_ENDS0Padding(t *testing.T) {
func TestServerQUIC_integration_0RTT(t *testing.T) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
srv, addr, err := dnsservertest.RunLocalQUICServer(
dnsservertest.DefaultHandler(),
dnsservertest.NewDefaultHandler(),
tlsConfig,
)
require.NoError(t, err)
@ -146,7 +146,7 @@ func TestServerQUIC_integration_0RTT(t *testing.T) {
func TestServerQUIC_integration_largeQuery(t *testing.T) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
srv, addr, err := dnsservertest.RunLocalQUICServer(
dnsservertest.DefaultHandler(),
dnsservertest.NewDefaultHandler(),
tlsConfig,
)
require.NoError(t, err)

View File

@ -17,7 +17,7 @@ import (
func TestServerTLS_integration_queryTLS(t *testing.T) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
addr := dnsservertest.RunTLSServer(t, dnsservertest.DefaultHandler(), tlsConfig)
addr := dnsservertest.RunTLSServer(t, dnsservertest.NewDefaultHandler(), tlsConfig)
// Create a test message.
req := new(dns.Msg)
@ -94,7 +94,7 @@ func TestServerTLS_integration_msgIgnore(t *testing.T) {
t.Parallel()
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
h := dnsservertest.DefaultHandler()
h := dnsservertest.NewDefaultHandler()
addr := dnsservertest.RunTLSServer(t, h, tlsConfig)
conn, err := tls.Dial("tcp", addr.String(), tlsConfig)
@ -120,7 +120,7 @@ func TestServerTLS_integration_msgIgnore(t *testing.T) {
func TestServerTLS_integration_noTruncateQuery(t *testing.T) {
// Handler that writes a huge response which would not fit
// into a UDP response, but it should fit a TCP response just okay.
handler := dnsservertest.CreateTestHandler(64)
handler := dnsservertest.NewDefaultHandlerWithCount(64)
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
addr := dnsservertest.RunTLSServer(t, handler, tlsConfig)
@ -155,7 +155,7 @@ func TestServerTLS_integration_queriesPipelining(t *testing.T) {
// i.e. we should be able to process incoming queries in parallel and
// write responses out of order.
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
addr := dnsservertest.RunTLSServer(t, dnsservertest.DefaultHandler(), tlsConfig)
addr := dnsservertest.RunTLSServer(t, dnsservertest.NewDefaultHandler(), tlsConfig)
// First - establish a connection
conn, err := tls.Dial("tcp", addr.String(), tlsConfig)
@ -221,7 +221,7 @@ func TestServerTLS_integration_queriesPipelining(t *testing.T) {
func TestServerTLS_integration_ENDS0Padding(t *testing.T) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
addr := dnsservertest.RunTLSServer(t, dnsservertest.DefaultHandler(), tlsConfig)
addr := dnsservertest.RunTLSServer(t, dnsservertest.NewDefaultHandler(), tlsConfig)
req := dnsservertest.CreateMessage("example.org.", dns.TypeA)
req.Extra = []dns.RR{dnsservertest.NewEDNS0Padding(req.Len(), dns.DefaultMsgSize)}

228
internal/dnssvc/config.go Normal file
View File

@ -0,0 +1,228 @@
package dnssvc
import (
"log/slog"
"net/http"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/access"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdcache"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
"github.com/AdguardTeam/AdGuardDNS/internal/cmd/plugin"
"github.com/AdguardTeam/AdGuardDNS/internal/connlimiter"
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/AdGuardDNS/internal/rulestat"
"github.com/prometheus/client_golang/prometheus"
)
// Config is the configuration of the AdGuard DNS service.
type Config struct {
// Handlers are the handlers to use in this DNS service.
Handlers Handlers
// NewListener, when set, is used instead of the package-level function
// [NewListener] when creating a DNS listener.
//
// TODO(a.garipov): This is only used for tests. Replace with a
// [netext.ListenConfig].
NewListener NewListenerFunc
// Cloner is used to clone messages more efficiently by disposing of parts
// of DNS responses for later reuse. It must not be nil.
Cloner *dnsmsg.Cloner
// ControlConf is the configuration of socket options.
ControlConf *netext.ControlConfig
// ConnLimiter, if not nil, is used to limit the number of simultaneously
// active stream-connections.
ConnLimiter *connlimiter.Limiter
// ErrColl is the error collector that is used to collect critical and
// non-critical errors. It must not be nil.
ErrColl errcoll.Interface
// NonDNS is the handler for non-DNS HTTP requests. It must not be nil.
NonDNS http.Handler
// MetricsNamespace is a namespace for Prometheus metrics. It must be a
// valid Prometheus metric label.
MetricsNamespace string
// ServerGroups are the DNS server groups. Each element must be non-nil.
ServerGroups []*agd.ServerGroup
// HandleTimeout defines the timeout for the entire handling of a single
// query. It must be greater than zero.
HandleTimeout time.Duration
}
// NewListenerFunc is the type for DNS listener constructors.
type NewListenerFunc func(
srv *agd.Server,
baseConf dnsserver.ConfigBase,
nonDNS http.Handler,
) (l Listener, err error)
// Listener is a type alias for dnsserver.Server to make internal naming more
// consistent.
type Listener = dnsserver.Server
// HandlersConfig is the configuration necessary to create or wrap the main DNS
// handler.
//
// TODO(a.garipov): Consider adding validation functions.
type HandlersConfig struct {
// BaseLogger is used to create loggers with custom prefixes for middlewares
// and the service itself. It must not be nil.
BaseLogger *slog.Logger
// Cloner is used to clone messages more efficiently by disposing of parts
// of DNS responses for later reuse. It must not be nil.
Cloner *dnsmsg.Cloner
// Cache is the configuration for the DNS cache.
Cache *CacheConfig
// HumanIDParser is used to normalize and parse human-readable device
// identifiers. It must not be nil if at least one server group has
// profiles enabled.
HumanIDParser *agd.HumanIDParser
// Messages is the message constructor used to create blocked and other
// messages for this DNS service. It must not be nil.
Messages *dnsmsg.Constructor
// PluginRegistry is used to override configuration parameters.
PluginRegistry *plugin.Registry
// StructuredErrors is the configuration for the experimental Structured DNS
// Errors feature in the profiles' message constructors. It must not be
// nil.
StructuredErrors *dnsmsg.StructuredDNSErrorsConfig
// AccessManager is used to block requests. It must not be nil.
AccessManager access.Interface
// BillStat is used to collect billing statistics. It must not be nil.
BillStat billstat.Recorder
// CacheManager is the global cache manager. It must not be nil.
CacheManager agdcache.Manager
// DNSCheck is used by clients to check if they use AdGuard DNS. It must
// not be nil.
DNSCheck dnscheck.Interface
// DNSDB is used to update anonymous statistics about DNS queries. It must
// not be nil.
DNSDB dnsdb.Interface
// ErrColl is the error collector that is used to collect critical and
// non-critical errors. It must not be nil.
ErrColl errcoll.Interface
// FilterStorage is the storage of all filters. It must not be nil.
FilterStorage filter.Storage
// GeoIP is the GeoIP database used to detect geographic data about IP
// addresses in requests and responses. It must not be nil.
GeoIP geoip.Interface
// Handler is the ultimate handler of the DNS query to be wrapped by
// middlewares. It must not be nil.
Handler dnsserver.Handler
// HashMatcher is the safe-browsing hash matcher for TXT queries. It must
// not be nil.
HashMatcher filter.HashMatcher
// ProfileDB is the AdGuard DNS profile database used to fetch data about
// profiles, devices, and so on. It must not be nil if at least one server
// group has profiles enabled.
ProfileDB profiledb.Interface
// PrometheusRegisterer is used to register Prometheus metrics. It must not
// be nil.
PrometheusRegisterer prometheus.Registerer
// QueryLog is used to write the logs into. It must not be nil.
QueryLog querylog.Interface
// RateLimit is used for allow or decline requests. It must not be nil.
RateLimit ratelimit.Interface
// RuleStat is used to collect statistics about matched filtering rules and
// rule lists. It must not be nil.
RuleStat rulestat.Interface
// MetricsNamespace is a namespace for Prometheus metrics. It must be a
// valid Prometheus metric label.
MetricsNamespace string
// FilteringGroups are the DNS filtering groups. Each element must be
// non-nil.
FilteringGroups map[agd.FilteringGroupID]*agd.FilteringGroup
// ServerGroups are the DNS server groups for which to build handlers. Each
// element and its servers must be non-nil.
ServerGroups []*agd.ServerGroup
// EDEEnabled enables the addition of the Extended DNS Error (EDE) codes in
// the profiles' message constructors.
EDEEnabled bool
}
// Handlers contains the map of handlers for each server of each server group.
// The pointers are the same as those passed in a [HandlersConfig] to
// [NewHandlers].
type Handlers map[HandlerKey]dnsserver.Handler
// HandlerKey is a key for the [Handlers] map.
type HandlerKey struct {
Server *agd.Server
ServerGroup *agd.ServerGroup
}
// CacheConfig is the configuration for the DNS cache.
type CacheConfig struct {
// MinTTL is the minimum supported TTL for cache items.
MinTTL time.Duration
// ECSCount is the size of the DNS cache for domain names that support
// ECS, in entries. It must be greater than zero if [CacheConfig.CacheType]
// is [CacheTypeECS].
ECSCount int
// NoECSCount is the size of the DNS cache for domain names that don't
// support ECS, in entries. It must be greater than zero if
// [CacheConfig.CacheType] is [CacheTypeSimple] or [CacheTypeECS].
NoECSCount int
// Type is the cache type. It must be valid.
Type CacheType
// OverrideCacheTTL shows if the TTL overriding logic should be used.
OverrideCacheTTL bool
}
// CacheType is the type of the cache to use.
type CacheType uint8
// CacheType constants.
const (
CacheTypeNone CacheType = iota + 1
CacheTypeSimple
CacheTypeECS
)

View File

@ -0,0 +1,35 @@
package dnssvc
import (
"context"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
)
// contextConstructor is a [dnsserver.ContextConstructor] implementation that
// returns a context with the given timeout as well as a new [agd.RequestID].
type contextConstructor struct {
timeout time.Duration
}
// newContextConstructor returns a new properly initialized *contextConstructor.
func newContextConstructor(timeout time.Duration) (c *contextConstructor) {
return &contextConstructor{
timeout: timeout,
}
}
// type check
var _ dnsserver.ContextConstructor = (*contextConstructor)(nil)
// New implements the [dnsserver.ContextConstructor] interface for
// *contextConstructor. It returns a context with a new [agd.RequestID] as well
// as its timeout and the corresponding cancelation function.
func (c *contextConstructor) New() (ctx context.Context, cancel context.CancelFunc) {
ctx, cancel = context.WithTimeout(context.Background(), c.timeout)
ctx = agd.WithRequestID(ctx, agd.NewRequestID())
return ctx, cancel
}

View File

@ -7,289 +7,27 @@ package dnssvc
import (
"context"
"fmt"
"log/slog"
"net/http"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/access"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdcache"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
"github.com/AdguardTeam/AdGuardDNS/internal/cmd/plugin"
"github.com/AdguardTeam/AdGuardDNS/internal/connlimiter"
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
dnssrvprom "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/devicefinder"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/initial"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/mainmw"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/preservice"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/preupstream"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/ratelimitmw"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/AdGuardDNS/internal/rulestat"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/service"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
)
// Config is the configuration of the AdGuard DNS service.
type Config struct {
// BaseLogger is used to create loggers with custom prefixes for middlewares
// and the service itself.
BaseLogger *slog.Logger
// Messages is the message constructor used to create blocked and other
// messages for this DNS service.
Messages *dnsmsg.Constructor
// Cloner is used to clone messages more efficiently by disposing of parts
// of DNS responses for later reuse.
Cloner *dnsmsg.Cloner
// ControlConf is the configuration of socket options.
ControlConf *netext.ControlConfig
// ConnLimiter, if not nil, is used to limit the number of simultaneously
// active stream-connections.
ConnLimiter *connlimiter.Limiter
// HumanIDParser is used to normalize and parse human-readable device
// identifiers.
HumanIDParser *agd.HumanIDParser
// PluginRegistry is used to override configuration parameters.
PluginRegistry *plugin.Registry
// AccessManager is used to block requests.
AccessManager access.Interface
// SafeBrowsing is the safe browsing TXT hash matcher.
SafeBrowsing filter.HashMatcher
// BillStat is used to collect billing statistics.
BillStat billstat.Recorder
// CacheManager is the global cache manager. CacheManager must not be nil.
CacheManager agdcache.Manager
// ProfileDB is the AdGuard DNS profile database used to fetch data about
// profiles, devices, and so on.
ProfileDB profiledb.Interface
// PrometheusRegisterer is used to register Prometheus metrics.
PrometheusRegisterer prometheus.Registerer
// DNSCheck is used by clients to check if they use AdGuard DNS.
DNSCheck dnscheck.Interface
// NonDNS is the handler for non-DNS HTTP requests.
NonDNS http.Handler
// DNSDB is used to update anonymous statistics about DNS queries.
DNSDB dnsdb.Interface
// ErrColl is the error collector that is used to collect critical and
// non-critical errors.
ErrColl errcoll.Interface
// FilterStorage is the storage of all filters.
FilterStorage filter.Storage
// GeoIP is the GeoIP database used to detect geographic data about IP
// addresses in requests and responses.
GeoIP geoip.Interface
// QueryLog is used to write the logs into.
QueryLog querylog.Interface
// RuleStat is used to collect statistics about matched filtering rules and
// rule lists.
RuleStat rulestat.Interface
// NewListener, when set, is used instead of the package-level function
// NewListener when creating a DNS listener.
//
// TODO(a.garipov): The handler and service logic should really not be
// intertwined in this way. See AGDNS-1327.
NewListener NewListenerFunc
// Handler is used as the main DNS handler instead of a simple forwarder.
// It must not be nil.
//
// TODO(a.garipov): Think of a better way to make the DNS server logic more
// testable.
Handler dnsserver.Handler
// RateLimit is used for allow or decline requests.
RateLimit ratelimit.Interface
// MetricsNamespace is a namespace for Prometheus metrics. It must be a
// valid Prometheus metric label.
MetricsNamespace string
// FilteringGroups are the DNS filtering groups. Each element must be
// non-nil.
FilteringGroups map[agd.FilteringGroupID]*agd.FilteringGroup
// ServerGroups are the DNS server groups. Each element must be non-nil.
ServerGroups []*agd.ServerGroup
// HandleTimeout defines the timeout for the entire handling of a single
// query.
HandleTimeout time.Duration
// CacheSize is the size of the DNS cache for domain names that don't
// support ECS.
//
// TODO(a.garipov): Extract this and following fields to cache configuration
// struct.
CacheSize int
// ECSCacheSize is the size of the DNS cache for domain names that support
// ECS.
ECSCacheSize int
// CacheMinTTL is the minimum supported TTL for cache items. This setting
// is used when UseCacheTTLOverride set to true.
CacheMinTTL time.Duration
// UseCacheTTLOverride shows if the TTL overrides logic should be used.
UseCacheTTLOverride bool
// UseECSCache shows if the EDNS Client Subnet (ECS) aware cache should be
// used.
UseECSCache bool
// Service is the main DNS service of AdGuard DNS.
type Service struct {
groups []*serverGroup
}
type (
// MainMiddlewareMetrics is a re-export of the internal filtering-middleware
// metrics interface.
MainMiddlewareMetrics = mainmw.Metrics
// RatelimitMiddlewareMetrics is a re-export of the metrics interface of the
// internal access and ratelimiting middleware.
RatelimitMiddlewareMetrics = ratelimitmw.Metrics
)
// New returns a new DNS service.
func New(c *Config) (svc *Service, err error) {
// Use either the configured listener initializer or the default one.
newListener := c.NewListener
if newListener == nil {
newListener = NewListener
}
// Configure the end of the request handling pipeline.
handler := c.Handler
if handler == nil {
return nil, errors.Error("handler in config must not be nil")
}
// Configure the pre-upstream middleware common for all servers of all
// groups.
preUps := preupstream.New(&preupstream.Config{
Cloner: c.Cloner,
CacheManager: c.CacheManager,
DB: c.DNSDB,
GeoIP: c.GeoIP,
CacheSize: c.CacheSize,
ECSCacheSize: c.ECSCacheSize,
UseECSCache: c.UseECSCache,
CacheMinTTL: c.CacheMinTTL,
UseCacheTTLOverride: c.UseCacheTTLOverride,
})
handler = preUps.Wrap(handler)
errCollListener := &errCollMetricsListener{
errColl: c.ErrColl,
baseListener: dnssrvprom.NewServerMetricsListener(c.MetricsNamespace),
}
// Configure the service itself.
groups := make([]*serverGroup, len(c.ServerGroups))
svc = &Service{
groups: groups,
}
mainMwMtrc, err := newMainMiddlewareMetrics(c)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
rlMwMtrc, err := metrics.NewDefaultRatelimitMiddleware(c.MetricsNamespace, c.PrometheusRegisterer)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
for i, srvGrp := range c.ServerGroups {
// The Filtering Middlewares
//
// These are middlewares common to all filtering and server groups.
// They change the flow of request handling, so they are separated.
dnsHdlr := dnsserver.WithMiddlewares(
handler,
preservice.New(&preservice.Config{
Messages: c.Messages,
HashMatcher: c.SafeBrowsing,
Checker: c.DNSCheck,
}),
mainmw.New(&mainmw.Config{
Metrics: mainMwMtrc,
Messages: c.Messages,
Cloner: c.Cloner,
BillStat: c.BillStat,
ErrColl: c.ErrColl,
FilterStorage: c.FilterStorage,
GeoIP: c.GeoIP,
QueryLog: c.QueryLog,
RuleStat: c.RuleStat,
}),
)
var servers []*server
servers, err = newServers(c, srvGrp, dnsHdlr, rlMwMtrc, errCollListener, newListener)
if err != nil {
return nil, fmt.Errorf("group %q: %w", srvGrp.Name, err)
}
groups[i] = &serverGroup{
name: srvGrp.Name,
servers: servers,
}
}
return svc, nil
}
// newMainMiddlewareMetrics returns a filtering-middleware metrics
// implementation from the config.
func newMainMiddlewareMetrics(c *Config) (mainMwMtrc MainMiddlewareMetrics, err error) {
mainMwMtrc = c.PluginRegistry.MainMiddlewareMetrics()
if mainMwMtrc != nil {
return mainMwMtrc, nil
}
mainMwMtrc, err = metrics.NewDefaultMainMiddleware(c.MetricsNamespace, c.PrometheusRegisterer)
if err != nil {
return nil, fmt.Errorf("mainmw metrics: %w", err)
}
return mainMwMtrc, nil
// serverGroup is a group of servers.
type serverGroup struct {
name agd.ServerGroupName
servers []*server
}
// server is a group of listeners.
@ -303,27 +41,171 @@ type server struct {
listeners []*listener
}
// serverGroup is a group of servers.
type serverGroup struct {
name agd.ServerGroupName
servers []*server
// listener is a Listener along with some of its associated data.
type listener struct {
Listener
name string
}
// Service is the main DNS service of AdGuard DNS.
type Service struct {
groups []*serverGroup
}
// mustStartListener starts l and panics on any error.
func mustStartListener(
grp agd.ServerGroupName,
srv agd.ServerName,
l *listener,
) {
err := l.Start(context.Background())
if err != nil {
panic(fmt.Errorf("group %q: server %q: starting %q: %w", grp, srv, l.name, err))
// New returns a new DNS service.
func New(c *Config) (svc *Service, err error) {
// Use either the configured listener initializer or the default one.
newListener := c.NewListener
if newListener == nil {
newListener = NewListener
}
errCollListener := &errCollMetricsListener{
errColl: c.ErrColl,
baseListener: dnssrvprom.NewServerMetricsListener(c.MetricsNamespace),
}
// Configure the service itself.
groups := make([]*serverGroup, 0, len(c.ServerGroups))
for _, srvGrp := range c.ServerGroups {
g := &serverGroup{
name: srvGrp.Name,
}
g.servers, err = newServers(c, srvGrp, errCollListener, newListener)
if err != nil {
return nil, fmt.Errorf("group %q: %w", srvGrp.Name, err)
}
groups = append(groups, g)
}
svc = &Service{
groups: groups,
}
return svc, nil
}
// newServers creates a slice of servers.
func newServers(
c *Config,
srvGrp *agd.ServerGroup,
errCollListener *errCollMetricsListener,
newListener NewListenerFunc,
) (servers []*server, err error) {
servers = make([]*server, 0, len(srvGrp.Servers))
for _, srv := range srvGrp.Servers {
k := HandlerKey{
Server: srv,
ServerGroup: srvGrp,
}
handler, ok := c.Handlers[k]
if !ok {
return nil, fmt.Errorf("no handler for server %q of group %q", srv.Name, srvGrp.Name)
}
s := &server{
name: srv.Name,
handler: handler,
}
s.listeners, err = newListeners(c, srv, handler, errCollListener, newListener)
if err != nil {
return nil, fmt.Errorf("server %q: %w", s.name, err)
}
servers = append(servers, s)
}
return servers, nil
}
// newListeners creates a slice of listeners for a server.
func newListeners(
c *Config,
srv *agd.Server,
handler dnsserver.Handler,
errCollListener *errCollMetricsListener,
newListener NewListenerFunc,
) (listeners []*listener, err error) {
bindData := srv.BindData()
listeners = make([]*listener, 0, len(bindData))
for i, bindData := range bindData {
var addr string
if bindData.PrefixAddr == nil {
addr = bindData.AddrPort.String()
} else {
addr = bindData.PrefixAddr.String()
}
proto := srv.Protocol
name := listenerName(srv.Name, addr, proto)
baseConf := dnsserver.ConfigBase{
Network: dnsserver.NetworkAny,
Handler: handler,
Metrics: errCollListener,
Disposer: c.Cloner,
RequestContext: newContextConstructor(c.HandleTimeout),
ListenConfig: newListenConfig(
bindData.ListenConfig,
c.ControlConf,
c.ConnLimiter,
proto,
),
Name: name,
Addr: addr,
}
l := &listener{
name: name,
}
l.Listener, err = newListener(srv, baseConf, c.NonDNS)
if err != nil {
return nil, fmt.Errorf("bind data at index %d: %w", i, err)
}
listeners = append(listeners, l)
}
return listeners, nil
}
// listenerName returns a standard name for a listener.
func listenerName(srvName agd.ServerName, addr string, proto agd.Protocol) (name string) {
return fmt.Sprintf("%s/%s/%s", srvName, proto, addr)
}
// newListenConfig returns the netext.ListenConfig used by the plain-DNS
// servers. The resulting ListenConfig sets additional socket flags and
// processes the control messages of connections created with ListenPacket.
// Additionally, if l is not nil, it is used to limit the number of
// simultaneously active stream-connections.
func newListenConfig(
original netext.ListenConfig,
ctrlConf *netext.ControlConfig,
l *connlimiter.Limiter,
p agd.Protocol,
) (lc netext.ListenConfig) {
if original != nil {
if l == nil {
return original
}
return connlimiter.NewListenConfig(original, l)
}
if p == agd.ProtoDNS {
lc = netext.DefaultListenConfigWithOOB(ctrlConf)
} else {
lc = netext.DefaultListenConfig(ctrlConf)
}
if l != nil {
lc = connlimiter.NewListenConfig(lc, l)
}
return lc
}
// type check
@ -331,13 +213,13 @@ var _ service.Interface = (*Service)(nil)
// Start implements the [service.Interface] interface for *Service. It panics
// if one of the listeners could not start.
func (svc *Service) Start(_ context.Context) (err error) {
func (svc *Service) Start(ctx context.Context) (err error) {
for _, g := range svc.groups {
for _, s := range g.servers {
for _, l := range s.listeners {
// Consider inability to start any one DNS listener a fatal
// error.
mustStartListener(g.name, s.name, l)
mustStartListener(ctx, g.name, s.name, l)
}
}
}
@ -345,17 +227,17 @@ func (svc *Service) Start(_ context.Context) (err error) {
return nil
}
// shutdownListeners is a helper function that shuts down all listeners of a
// server.
func shutdownListeners(ctx context.Context, listeners []*listener) (err error) {
for _, l := range listeners {
err = l.Shutdown(ctx)
// mustStartListener starts l and panics on any error.
func mustStartListener(
ctx context.Context,
srvGrp agd.ServerGroupName,
srv agd.ServerName,
l *listener,
) {
err := l.Start(ctx)
if err != nil {
return fmt.Errorf("shutting down listener %q: %w", l.name, err)
panic(fmt.Errorf("group %q: server %q: starting %q: %w", srvGrp, srv, l.name, err))
}
}
return nil
}
// Shutdown implements the [service.Interface] interface for *Service.
@ -378,9 +260,22 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) {
return nil
}
// shutdownListeners is a helper function that shuts down all listeners of a
// server.
func shutdownListeners(ctx context.Context, listeners []*listener) (err error) {
for _, l := range listeners {
err = l.Shutdown(ctx)
if err != nil {
return fmt.Errorf("shutting down listener %q: %w", l.name, err)
}
}
return nil
}
// Handle is a simple helper to test the handling of DNS requests.
//
// TODO(a.garipov): Remove once the mainmw refactoring is complete.
// TODO(a.garipov): Remove once the refactoring is complete.
func (svc *Service) Handle(
ctx context.Context,
grpName agd.ServerGroupName,
@ -417,29 +312,6 @@ func (svc *Service) Handle(
return srv.handler.ServeDNS(ctx, rw, r)
}
// Listener is a type alias for dnsserver.Server to make internal naming more
// consistent.
type Listener = dnsserver.Server
// NewListenerFunc is the type for DNS listener constructors.
type NewListenerFunc func(
s *agd.Server,
baseConf dnsserver.ConfigBase,
nonDNS http.Handler,
) (l Listener, err error)
// listener is a Listener along with some of its associated data.
type listener struct {
Listener
name string
}
// listenerName returns a standard name for a listener.
func listenerName(srvName agd.ServerName, addr string, proto agd.Protocol) (name string) {
return fmt.Sprintf("%s/%s/%s", srvName, proto, addr)
}
// NewListener returns a new Listener. It is the default DNS listener
// constructor.
//
@ -500,213 +372,8 @@ func NewListener(
TLSConfig: s.TLS,
})
default:
return nil, fmt.Errorf("bad protocol %v", p)
return nil, fmt.Errorf("protocol: %w: %d", errors.ErrBadEnumValue, p)
}
return l, nil
}
// contextConstructor is a [dnsserver.ContextConstructor] implementation that
// that returns a context with the given timeout as well as a new
// [agd.RequestID].
type contextConstructor struct {
timeout time.Duration
}
// newContextConstructor returns a new properly initialized *contextConstructor.
func newContextConstructor(timeout time.Duration) (c *contextConstructor) {
return &contextConstructor{
timeout: timeout,
}
}
// type check
var _ dnsserver.ContextConstructor = (*contextConstructor)(nil)
// New implements the [dnsserver.ContextConstructor] interface for
// *contextConstructor. It returns a context with a new [agd.RequestID] as well
// as its timeout and the corresponding cancelation function.
func (c *contextConstructor) New() (ctx context.Context, cancel context.CancelFunc) {
ctx, cancel = context.WithTimeout(context.Background(), c.timeout)
ctx = agd.WithRequestID(ctx, agd.NewRequestID())
return ctx, cancel
}
// newServers creates a slice of servers.
//
// TODO(a.garipov): Refactor this into a builder pattern.
func newServers(
c *Config,
srvGrp *agd.ServerGroup,
handler dnsserver.Handler,
rlMwMtrc ratelimitmw.Metrics,
errCollListener *errCollMetricsListener,
newListener NewListenerFunc,
) (servers []*server, err error) {
servers = make([]*server, len(srvGrp.Servers))
for i, s := range srvGrp.Servers {
// The Initial Middlewares
//
// These middlewares are either specific to the server or must be the
// furthest away from the handler and thus are the first to process
// a request.
// Assume that all the validations have been made during the
// configuration validation step back in package cmd. If we ever get
// new ways of receiving configuration, remove this assumption and
// validate fg.
fg := c.FilteringGroups[srvGrp.FilteringGroup]
df := newDeviceFinder(c, srvGrp, s)
rlm := ratelimitmw.New(&ratelimitmw.Config{
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "ratelimitmw"),
Messages: c.Messages,
FilteringGroup: fg,
ServerGroup: srvGrp,
Server: s,
AccessManager: c.AccessManager,
DeviceFinder: df,
ErrColl: c.ErrColl,
GeoIP: c.GeoIP,
Metrics: rlMwMtrc,
Limiter: c.RateLimit,
// Only apply rate-limiting logic to plain DNS.
Protocols: []agd.Protocol{agd.ProtoDNS},
})
if err != nil {
return nil, fmt.Errorf("ratelimit: %w", err)
}
imw := initial.New(&initial.Config{
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "initmw"),
})
h := dnsserver.WithMiddlewares(
handler,
// Keep the rate limiting and access middlewares as the outer ones
// to make sure that the application logic isn't touched if the
// request is ratelimited or blocked by access settings.
rlm,
imw,
)
srvName := s.Name
var listeners []*listener
listeners, err = newListeners(c, s, h, errCollListener, newListener)
if err != nil {
return nil, fmt.Errorf("server %q: %w", srvName, err)
}
servers[i] = &server{
name: srvName,
handler: h,
listeners: listeners,
}
}
return servers, nil
}
// newDeviceFinder returns a new [agd.DeviceFinder] for a server based on the
// configuration.
func newDeviceFinder(c *Config, g *agd.ServerGroup, s *agd.Server) (df agd.DeviceFinder) {
if !g.ProfilesEnabled {
return agd.EmptyDeviceFinder{}
}
return devicefinder.NewDefault(&devicefinder.Config{
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "devicefinder"),
ProfileDB: c.ProfileDB,
HumanIDParser: c.HumanIDParser,
Server: s,
DeviceDomains: g.TLS.DeviceDomains,
})
}
// newServers creates a slice of listeners for a server.
func newListeners(
c *Config,
srv *agd.Server,
handler dnsserver.Handler,
errCollListener *errCollMetricsListener,
newListener NewListenerFunc,
) (listeners []*listener, err error) {
bindData := srv.BindData()
listeners = make([]*listener, 0, len(bindData))
for i, bindData := range bindData {
var addr string
if bindData.PrefixAddr == nil {
addr = bindData.AddrPort.String()
} else {
addr = bindData.PrefixAddr.String()
}
proto := srv.Protocol
name := listenerName(srv.Name, addr, proto)
baseConf := dnsserver.ConfigBase{
Network: dnsserver.NetworkAny,
Handler: handler,
Metrics: errCollListener,
Disposer: c.Cloner,
RequestContext: newContextConstructor(c.HandleTimeout),
ListenConfig: newListenConfig(
bindData.ListenConfig,
c.ControlConf,
c.ConnLimiter,
proto,
),
Name: name,
Addr: addr,
}
var l Listener
l, err = newListener(srv, baseConf, c.NonDNS)
if err != nil {
return nil, fmt.Errorf("bind data at index %d: %w", i, err)
}
listeners = append(listeners, &listener{
name: name,
Listener: l,
})
}
return listeners, nil
}
// newListenConfig returns the netext.ListenConfig used by the plain-DNS
// servers. The resulting ListenConfig sets additional socket flags and
// processes the control messages of connections created with ListenPacket.
// Additionally, if l is not nil, it is used to limit the number of
// simultaneously active stream-connections.
func newListenConfig(
original netext.ListenConfig,
ctrlConf *netext.ControlConfig,
l *connlimiter.Limiter,
p agd.Protocol,
) (lc netext.ListenConfig) {
if original != nil {
if l == nil {
return original
}
return connlimiter.NewListenConfig(original, l)
}
if p == agd.ProtoDNS {
lc = netext.DefaultListenConfigWithOOB(ctrlConf)
} else {
lc = netext.DefaultListenConfig(ctrlConf)
}
if l != nil {
lc = connlimiter.NewListenConfig(lc, l)
}
return lc
}

View File

@ -10,26 +10,17 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdservice"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/forward"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/dnssvctest"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// testSrvGrpName is the [agd.ServerGroupName] for tests.
const testSrvGrpName agd.ServerGroupName = "test_group"
// type check
var _ agdservice.Refresher = (*forward.Handler)(nil)
@ -159,16 +150,23 @@ func TestService_Start(t *testing.T) {
AddrPort: netip.MustParseAddrPort("127.0.0.1:53"),
})
c := &dnssvc.Config{
BaseLogger: slogutil.NewDiscardLogger(),
NewListener: newTestListenerFunc(tl),
PrometheusRegisterer: agdtest.NewTestPrometheusRegisterer(),
Handler: dnsservertest.DefaultHandler(),
MetricsNamespace: "test_start",
ServerGroups: []*agd.ServerGroup{{
Name: "test_group",
srvGrp := &agd.ServerGroup{
Name: dnssvctest.ServerGroupName,
Servers: []*agd.Server{srv},
}},
}
k := dnssvc.HandlerKey{
Server: srv,
ServerGroup: srvGrp,
}
c := &dnssvc.Config{
NewListener: newTestListenerFunc(tl),
Handlers: dnssvc.Handlers{
k: dnsservertest.NewDefaultHandler(),
},
MetricsNamespace: "test_start",
ServerGroups: []*agd.ServerGroup{srvGrp},
}
svc, err := dnssvc.New(c)
@ -206,15 +204,25 @@ func TestNew(t *testing.T) {
}),
}
c := &dnssvc.Config{
BaseLogger: slogutil.NewDiscardLogger(),
Handler: dnsservertest.DefaultHandler(),
PrometheusRegisterer: agdtest.NewTestPrometheusRegisterer(),
MetricsNamespace: "test_new",
ServerGroups: []*agd.ServerGroup{{
Name: "test_group",
srvGrp := &agd.ServerGroup{
Name: dnssvctest.ServerGroupName,
Servers: srvs,
}},
}
handlers := dnssvc.Handlers{}
for _, srv := range srvs {
k := dnssvc.HandlerKey{
Server: srv,
ServerGroup: srvGrp,
}
handlers[k] = dnsservertest.NewDefaultHandler()
}
c := &dnssvc.Config{
Handlers: handlers,
MetricsNamespace: "test_new",
ServerGroups: []*agd.ServerGroup{srvGrp},
}
svc, err := dnssvc.New(c)

211
internal/dnssvc/handler.go Normal file
View File

@ -0,0 +1,211 @@
package dnssvc
import (
"context"
"fmt"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/cache"
dnssrvprom "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/devicefinder"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/initial"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/mainmw"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/preservice"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/preupstream"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/ratelimitmw"
"github.com/AdguardTeam/AdGuardDNS/internal/ecscache"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// NewHandlers returns the main DNS handlers wrapped in all necessary
// middlewares. c must not be nil.
func NewHandlers(ctx context.Context, c *HandlersConfig) (handlers Handlers, err error) {
handler := wrapPreUpstreamMw(ctx, c)
mainMwMtrc, err := newMainMiddlewareMetrics(c)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
mainMw := mainmw.New(&mainmw.Config{
Cloner: c.Cloner,
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "mainmw"),
Messages: c.Messages,
BillStat: c.BillStat,
ErrColl: c.ErrColl,
FilterStorage: c.FilterStorage,
GeoIP: c.GeoIP,
QueryLog: c.QueryLog,
Metrics: mainMwMtrc,
RuleStat: c.RuleStat,
})
handler = mainMw.Wrap(handler)
preSvcMw := preservice.New(&preservice.Config{
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "presvcmw"),
Messages: c.Messages,
HashMatcher: c.HashMatcher,
Checker: c.DNSCheck,
})
handler = preSvcMw.Wrap(handler)
postInitMw := c.PluginRegistry.PostInitialMiddleware()
if postInitMw != nil {
handler = postInitMw.Wrap(handler)
}
initMw := initial.New(&initial.Config{
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "initmw"),
})
handler = initMw.Wrap(handler)
return newHandlersForServers(c, handler)
}
// wrapPreUpstreamMw returns the handler wrapped into the pre-upstream
// middlewares.
//
// TODO(a.garipov): Adapt the cache tests that previously were in package
// preupstream.
func wrapPreUpstreamMw(ctx context.Context, c *HandlersConfig) (wrapped dnsserver.Handler) {
// TODO(a.garipov): Use in other places if necessary.
l := c.BaseLogger.With(slogutil.KeyPrefix, "dnssvc")
wrapped = c.Handler
switch conf := c.Cache; conf.Type {
case CacheTypeNone:
l.WarnContext(ctx, "cache disabled")
case CacheTypeSimple:
l.InfoContext(ctx, "plain cache enabled", "count", conf.NoECSCount)
cacheMw := cache.NewMiddleware(&cache.MiddlewareConfig{
// TODO(a.garipov): Do not use promauto and refactor.
MetricsListener: dnssrvprom.NewCacheMetricsListener(metrics.Namespace()),
Size: conf.NoECSCount,
MinTTL: conf.MinTTL,
OverrideTTL: conf.OverrideCacheTTL,
})
wrapped = cacheMw.Wrap(wrapped)
case CacheTypeECS:
l.InfoContext(
ctx,
"ecs cache enabled",
"ecs_count", conf.ECSCount,
"no_ecs_count", conf.NoECSCount,
)
cacheMw := ecscache.NewMiddleware(&ecscache.MiddlewareConfig{
Cloner: c.Cloner,
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "ecscache"),
CacheManager: c.CacheManager,
GeoIP: c.GeoIP,
NoECSCount: conf.NoECSCount,
ECSCount: conf.ECSCount,
MinTTL: conf.MinTTL,
OverrideTTL: conf.OverrideCacheTTL,
})
wrapped = cacheMw.Wrap(wrapped)
default:
panic(fmt.Errorf("cache type: %w: %d", errors.ErrBadEnumValue, conf.Type))
}
preUps := preupstream.New(ctx, &preupstream.Config{
DB: c.DNSDB,
})
wrapped = preUps.Wrap(wrapped)
return wrapped
}
// newMainMiddlewareMetrics returns a filtering-middleware metrics
// implementation from the config.
func newMainMiddlewareMetrics(c *HandlersConfig) (mainMwMtrc MainMiddlewareMetrics, err error) {
mainMwMtrc = c.PluginRegistry.MainMiddlewareMetrics()
if mainMwMtrc != nil {
return mainMwMtrc, nil
}
mainMwMtrc, err = metrics.NewDefaultMainMiddleware(c.MetricsNamespace, c.PrometheusRegisterer)
if err != nil {
return nil, fmt.Errorf("mainmw metrics: %w", err)
}
return mainMwMtrc, nil
}
// newHandlersForServers returns a handler map for each server group and each
// server.
func newHandlersForServers(c *HandlersConfig, h dnsserver.Handler) (handlers Handlers, err error) {
rlMwMtrc, err := metrics.NewDefaultRatelimitMiddleware(c.MetricsNamespace, c.PrometheusRegisterer)
if err != nil {
return nil, fmt.Errorf("ratelimit middleware metrics: %w", err)
}
handlers = Handlers{}
rlMwLogger := c.BaseLogger.With(slogutil.KeyPrefix, "ratelimitmw")
for _, srvGrp := range c.ServerGroups {
fltGrp, ok := c.FilteringGroups[srvGrp.FilteringGroup]
if !ok {
return nil, fmt.Errorf(
"no filtering group %q for server group %q",
srvGrp.FilteringGroup,
srvGrp.Name,
)
}
for _, srv := range srvGrp.Servers {
rlMw := ratelimitmw.New(&ratelimitmw.Config{
Logger: rlMwLogger,
Messages: c.Messages,
FilteringGroup: fltGrp,
ServerGroup: srvGrp,
Server: srv,
StructuredErrors: c.StructuredErrors,
AccessManager: c.AccessManager,
DeviceFinder: newDeviceFinder(c, srvGrp, srv),
ErrColl: c.ErrColl,
GeoIP: c.GeoIP,
Metrics: rlMwMtrc,
Limiter: c.RateLimit,
Protocols: []agd.Protocol{agd.ProtoDNS},
EDEEnabled: c.EDEEnabled,
})
k := HandlerKey{
Server: srv,
ServerGroup: srvGrp,
}
handlers[k] = rlMw.Wrap(h)
}
}
return handlers, nil
}
// newDeviceFinder returns a new agd.DeviceFinder for a server based on the
// configuration. All arguments must not be nil.
func newDeviceFinder(c *HandlersConfig, g *agd.ServerGroup, s *agd.Server) (df agd.DeviceFinder) {
if !g.ProfilesEnabled {
return agd.EmptyDeviceFinder{}
}
return devicefinder.NewDefault(&devicefinder.Config{
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "devicefinder"),
ProfileDB: c.ProfileDB,
HumanIDParser: c.HumanIDParser,
Server: s,
DeviceDomains: g.TLS.DeviceDomains,
})
}

View File

@ -0,0 +1,188 @@
package dnssvc_test
import (
"context"
"net/netip"
"path"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdcache"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/dnssvctest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewHandlers(t *testing.T) {
t.Parallel()
accessMgr := &agdtest.AccessManager{
OnIsBlockedHost: func(host string, qt uint16) (blocked bool) { panic("not implemented") },
OnIsBlockedIP: func(ip netip.Addr) (blocked bool) { panic("not implemented") },
}
billStat := &agdtest.BillStatRecorder{
OnRecord: func(
_ context.Context,
_ agd.DeviceID,
_ geoip.Country,
_ geoip.ASN,
_ time.Time,
_ agd.Protocol,
) {
panic("not implemented")
},
}
dnsCk := &agdtest.DNSCheck{
OnCheck: func(
_ context.Context,
_ *dns.Msg,
_ *agd.RequestInfo,
) (resp *dns.Msg, err error) {
panic("not implemented")
},
}
dnsDB := &agdtest.DNSDB{
OnRecord: func(_ context.Context, _ *dns.Msg, _ *agd.RequestInfo) {
panic("not implemented")
},
}
fltGrps := map[agd.FilteringGroupID]*agd.FilteringGroup{
dnssvctest.FilteringGroupID: {
ID: dnssvctest.FilteringGroupID,
RuleListIDs: []agd.FilterListID{dnssvctest.FilterListID1},
RuleListsEnabled: true,
},
}
fltStrg := &agdtest.FilterStorage{
OnFilterFromContext: func(_ context.Context, _ *agd.RequestInfo) (f filter.Interface) {
panic("not implemented")
},
OnHasListID: func(_ agd.FilterListID) (ok bool) { panic("not implemented") },
}
hashMatcher := &agdtest.HashMatcher{
OnMatchByPrefix: func(
_ context.Context,
_ string,
) (hashes []string, matched bool, err error) {
panic("not implemented")
},
}
queryLog := &agdtest.QueryLog{
OnWrite: func(_ context.Context, _ *querylog.Entry) (err error) {
panic("not implemented")
},
}
ruleStat := &agdtest.RuleStat{
OnCollect: func(_ context.Context, _ agd.FilterListID, _ agd.FilterRuleText) {
panic("not implemented")
},
}
srv := dnssvctest.NewServer(dnssvctest.ServerName, agd.ProtoDoT, &agd.ServerBindData{
AddrPort: dnssvctest.ServerAddrPort,
})
srvGrp := &agd.ServerGroup{
DDR: &agd.DDR{
Enabled: true,
},
TLS: &agd.TLS{},
Name: dnssvctest.ServerGroupName,
FilteringGroup: dnssvctest.FilteringGroupID,
Servers: []*agd.Server{srv},
ProfilesEnabled: true,
}
testCases := []struct {
cacheConf *dnssvc.CacheConfig
name string
}{{
cacheConf: &dnssvc.CacheConfig{
Type: dnssvc.CacheTypeNone,
},
name: "no_cache",
}, {
cacheConf: &dnssvc.CacheConfig{
MinTTL: 10 * time.Second,
NoECSCount: 100,
Type: dnssvc.CacheTypeSimple,
OverrideCacheTTL: true,
},
name: "cache_simple",
}, {
cacheConf: &dnssvc.CacheConfig{
MinTTL: 10 * time.Second,
ECSCount: 100,
NoECSCount: 100,
Type: dnssvc.CacheTypeECS,
OverrideCacheTTL: true,
},
name: "cache_ecs",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := testutil.ContextWithTimeout(t, dnssvctest.Timeout)
handlers, err := dnssvc.NewHandlers(ctx, &dnssvc.HandlersConfig{
BaseLogger: slogutil.NewDiscardLogger(),
Cloner: agdtest.NewCloner(),
Cache: tc.cacheConf,
HumanIDParser: agd.NewHumanIDParser(),
Messages: agdtest.NewConstructor(t),
PluginRegistry: nil,
StructuredErrors: agdtest.NewSDEConfig(true),
AccessManager: accessMgr,
BillStat: billStat,
// TODO(a.garipov): Create a test implementation?
CacheManager: agdcache.EmptyManager{},
DNSCheck: dnsCk,
DNSDB: dnsDB,
ErrColl: agdtest.NewErrorCollector(),
FilterStorage: fltStrg,
GeoIP: agdtest.NewGeoIP(),
Handler: dnsservertest.NewPanicHandler(),
HashMatcher: hashMatcher,
ProfileDB: agdtest.NewProfileDB(),
PrometheusRegisterer: agdtest.NewTestPrometheusRegisterer(),
QueryLog: queryLog,
RateLimit: agdtest.NewRateLimit(),
RuleStat: ruleStat,
MetricsNamespace: path.Base(t.Name()),
FilteringGroups: fltGrps,
ServerGroups: []*agd.ServerGroup{srvGrp},
EDEEnabled: true,
})
require.NoError(t, err)
assert.Len(t, handlers, 1)
for k, v := range handlers {
assert.Same(t, srv, k.Server)
assert.Same(t, srvGrp, k.ServerGroup)
assert.NotNil(t, v)
break
}
})
}
}

View File

@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/access"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdcache"
"github.com/AdguardTeam/AdGuardDNS/internal/agdpasswd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
@ -19,6 +20,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/dnssvctest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/golibs/logutil/slogutil"
@ -46,7 +48,7 @@ import (
// from the service. The channels must not be nil. Each sending to a channel
// wrapped with [testutil.RequireSend] using [dnssvctest.Timeout].
//
// It also uses the [dnsservertest.DefaultHandler] to create the DNS handler.
// It also uses the [dnsservertest.NewDefaultHandler] to create the DNS handler.
func newTestService(
t testing.TB,
flt filter.Interface,
@ -81,46 +83,14 @@ func newTestService(
QueryLogEnabled: true,
}
db := &agdtest.ProfileDB{
OnCreateAutoDevice: func(
ctx context.Context,
id agd.ProfileID,
humanID agd.HumanID,
devType agd.DeviceType,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDeviceID: func(
profDB := agdtest.NewProfileDB()
profDB.OnProfileByDeviceID = func(
_ context.Context,
id agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
testutil.RequireSend(pt, profileDBCh, id, dnssvctest.Timeout)
return prof, dev, nil
},
OnProfileByHumanID: func(
_ context.Context,
_ agd.ProfileID,
_ agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
ctx context.Context,
ip netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
accessManager := &agdtest.AccessManager{
@ -145,18 +115,12 @@ func newTestService(
Continent: geoip.ContinentEU,
ASN: 42,
}
geoIP := &agdtest.GeoIP{
OnSubnetByLocation: func(
_ *geoip.Location,
_ netutil.AddrFamily,
) (n netip.Prefix, err error) {
panic("not implemented")
},
OnData: func(host string, _ netip.Addr) (l *geoip.Location, err error) {
geoIP := agdtest.NewGeoIP()
geoIP.OnData = func(host string, _ netip.Addr) (l *geoip.Location, err error) {
testutil.RequireSend(pt, geoIPCh, host, dnssvctest.Timeout)
return loc, nil
},
}
fltStrg := &agdtest.FilterStorage{
@ -205,25 +169,38 @@ func newTestService(
},
}
rl := &agdtest.RateLimit{
OnIsRateLimited: func(
rl := agdtest.NewRateLimit()
rl.OnIsRateLimited = func(
_ context.Context,
_ *dns.Msg,
_ netip.Addr,
) (drop, allowlisted bool, err error) {
) (shouldDrop, isAllowlisted bool, err error) {
return true, false, nil
},
OnCountResponses: func(_ context.Context, _ *dns.Msg, _ netip.Addr) {
panic("not implemented")
},
}
testFltGrpID := agd.FilteringGroupID("1234")
srvGrps := []*agd.ServerGroup{{
DDR: &agd.DDR{
Enabled: true,
},
TLS: &agd.TLS{
DeviceDomains: []string{dnssvctest.DomainForDevices},
},
Name: dnssvctest.ServerGroupName,
FilteringGroup: dnssvctest.FilteringGroupID,
Servers: []*agd.Server{srv},
ProfilesEnabled: true,
}}
c := &dnssvc.Config{
hdlrConf := &dnssvc.HandlersConfig{
BaseLogger: slogutil.NewDiscardLogger(),
AccessManager: accessManager,
Cache: &dnssvc.CacheConfig{
Type: dnssvc.CacheTypeNone,
},
StructuredErrors: agdtest.NewSDEConfig(true),
Cloner: agdtest.NewCloner(),
HumanIDParser: agd.NewHumanIDParser(),
Messages: agdtest.NewConstructor(t),
AccessManager: accessManager,
BillStat: &agdtest.BillStatRecorder{
OnRecord: func(
_ context.Context,
@ -235,42 +212,47 @@ func newTestService(
) {
},
},
ProfileDB: db,
PrometheusRegisterer: agdtest.NewTestPrometheusRegisterer(),
CacheManager: agdcache.EmptyManager{},
DNSCheck: dnsCk,
NonDNS: http.NotFoundHandler(),
DNSDB: dnsDB,
ErrColl: errColl,
FilterStorage: fltStrg,
GeoIP: geoIP,
Handler: dnsservertest.NewDefaultHandler(),
HashMatcher: hashprefix.NewMatcher(nil),
ProfileDB: profDB,
PrometheusRegisterer: agdtest.NewTestPrometheusRegisterer(),
QueryLog: ql,
RuleStat: ruleStat,
NewListener: newTestListenerFunc(tl),
Handler: dnsservertest.DefaultHandler(),
RateLimit: rl,
RuleStat: ruleStat,
MetricsNamespace: path.Base(t.Name()),
FilteringGroups: map[agd.FilteringGroupID]*agd.FilteringGroup{
testFltGrpID: {
ID: testFltGrpID,
dnssvctest.FilteringGroupID: {
ID: dnssvctest.FilteringGroupID,
RuleListIDs: []agd.FilterListID{dnssvctest.FilterListID1},
RuleListsEnabled: true,
},
},
ServerGroups: []*agd.ServerGroup{{
DDR: &agd.DDR{
Enabled: true,
},
TLS: &agd.TLS{
DeviceDomains: []string{dnssvctest.DomainForDevices},
},
Name: testSrvGrpName,
FilteringGroup: testFltGrpID,
Servers: []*agd.Server{srv},
ProfilesEnabled: true,
}},
ServerGroups: srvGrps,
EDEEnabled: true,
}
svc, err := dnssvc.New(c)
ctx := context.Background()
handlers, err := dnssvc.NewHandlers(ctx, hdlrConf)
require.NoError(t, err)
c := &dnssvc.Config{
Handlers: handlers,
NewListener: newTestListenerFunc(tl),
Cloner: agdtest.NewCloner(),
ErrColl: errColl,
NonDNS: http.NotFoundHandler(),
MetricsNamespace: path.Base(t.Name()),
ServerGroups: srvGrps,
HandleTimeout: dnssvctest.Timeout,
}
svc, err = dnssvc.New(c)
require.NoError(t, err)
require.NotNil(t, svc)
@ -283,6 +265,7 @@ func newTestService(
return svc, srvAddr
}
// TODO(a.garipov): Refactor to test handlers separately from the service.
func TestService_Wrap(t *testing.T) {
profileDBCh := make(chan agd.DeviceID, 1)
querylogCh := make(chan *querylog.Entry, 1)
@ -346,7 +329,7 @@ func TestService_Wrap(t *testing.T) {
TLSServerName: dnssvctest.DeviceIDSrvName,
})
err := svc.Handle(ctx, testSrvGrpName, dnssvctest.ServerName, rw, req)
err := svc.Handle(ctx, dnssvctest.ServerGroupName, dnssvctest.ServerName, rw, req)
require.NoError(t, err)
resp := rw.Msg()
@ -420,7 +403,7 @@ func TestService_Wrap(t *testing.T) {
TLSServerName: dnssvctest.DeviceIDSrvName,
})
err := svc.Handle(ctx, testSrvGrpName, dnssvctest.ServerName, rw, req)
err := svc.Handle(ctx, dnssvctest.ServerGroupName, dnssvctest.ServerName, rw, req)
require.NoError(t, err)
resp := rw.Msg()

View File

@ -80,35 +80,9 @@ func TestDefault_Find_plainAddrs(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
profDB := &agdtest.ProfileDB{
OnCreateAutoDevice: func(
ctx context.Context,
id agd.ProfileID,
humanID agd.HumanID,
devType agd.DeviceType,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDedicatedIP: newOnProfileByDedicatedIP(dnssvctest.DedicatedAddr),
OnProfileByDeviceID: func(
_ context.Context,
_ agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByHumanID: func(
_ context.Context,
_ agd.ProfileID,
_ agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: newOnProfileByLinkedIP(dnssvctest.LinkedAddr),
}
profDB := agdtest.NewProfileDB()
profDB.OnProfileByDedicatedIP = newOnProfileByDedicatedIP(dnssvctest.DedicatedAddr)
profDB.OnProfileByLinkedIP = newOnProfileByLinkedIP(dnssvctest.LinkedAddr)
df := devicefinder.NewDefault(&devicefinder.Config{
Logger: slogutil.NewDiscardLogger(),
@ -177,40 +151,8 @@ func TestDefault_Find_plainEDNS(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
profDB := &agdtest.ProfileDB{
OnCreateAutoDevice: func(
ctx context.Context,
id agd.ProfileID,
humanID agd.HumanID,
devType agd.DeviceType,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDeviceID: newOnProfileByDeviceID(dnssvctest.DeviceID),
OnProfileByHumanID: func(
_ context.Context,
_ agd.ProfileID,
_ agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
profDB := agdtest.NewProfileDB()
profDB.OnProfileByDeviceID = newOnProfileByDeviceID(dnssvctest.DeviceID)
df := devicefinder.NewDefault(&devicefinder.Config{
Logger: slogutil.NewDiscardLogger(),
@ -231,44 +173,12 @@ func TestDefault_Find_plainEDNS(t *testing.T) {
func TestDefault_Find_deleted(t *testing.T) {
t.Parallel()
profDB := &agdtest.ProfileDB{
OnCreateAutoDevice: func(
ctx context.Context,
id agd.ProfileID,
humanID agd.HumanID,
devType agd.DeviceType,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDeviceID: func(
_ context.Context,
_ agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByHumanID: func(
_ context.Context,
_ agd.ProfileID,
_ agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
profDB := agdtest.NewProfileDB()
profDB.OnProfileByLinkedIP = func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
return profDeleted, devNormal, nil
},
}
df := devicefinder.NewDefault(&devicefinder.Config{
@ -291,44 +201,21 @@ func TestDefault_Find_byHumanID(t *testing.T) {
// device-type and profile data regardless of the case.
extIDStr := "OTR-" + strings.ToUpper(dnssvctest.ProfileIDStr) + "-" + dnssvctest.HumanIDStr + "-!!!"
profDB := &agdtest.ProfileDB{
OnCreateAutoDevice: func(
profDB := agdtest.NewProfileDB()
profDB.OnCreateAutoDevice = func(
ctx context.Context,
id agd.ProfileID,
humanID agd.HumanID,
devType agd.DeviceType,
) (p *agd.Profile, d *agd.Device, err error) {
return profNormal, devAuto, nil
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDeviceID: func(
_ context.Context,
devID agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByHumanID: func(
}
profDB.OnProfileByHumanID = func(
_ context.Context,
_ agd.ProfileID,
_ agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) {
return nil, nil, profiledb.ErrDeviceNotFound
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
df := devicefinder.NewDefault(&devicefinder.Config{

View File

@ -2,7 +2,6 @@ package devicefinder_test
import (
"context"
"net/netip"
"net/url"
"path"
"testing"
@ -88,24 +87,8 @@ func TestDefault_Find_DoHAuth(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
profDB := &agdtest.ProfileDB{
OnCreateAutoDevice: func(
ctx context.Context,
id agd.ProfileID,
humanID agd.HumanID,
devType agd.DeviceType,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDeviceID: func(
profDB := agdtest.NewProfileDB()
profDB.OnProfileByDeviceID = func(
_ context.Context,
devID agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
@ -114,22 +97,6 @@ func TestDefault_Find_DoHAuth(t *testing.T) {
}
return nil, nil, profiledb.ErrDeviceNotFound
},
OnProfileByHumanID: func(
_ context.Context,
_ agd.ProfileID,
_ agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
df := devicefinder.NewDefault(&devicefinder.Config{
@ -220,44 +187,12 @@ func TestDefault_Find_DoHAuthOnly(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
profDB := &agdtest.ProfileDB{
OnCreateAutoDevice: func(
ctx context.Context,
id agd.ProfileID,
humanID agd.HumanID,
devType agd.DeviceType,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDeviceID: func(
profDB := agdtest.NewProfileDB()
profDB.OnProfileByDeviceID = func(
_ context.Context,
devID agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
return profNormal, tc.profDBDev, nil
},
OnProfileByHumanID: func(
_ context.Context,
_ agd.ProfileID,
_ agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
df := devicefinder.NewDefault(&devicefinder.Config{
@ -351,34 +286,9 @@ func TestDefault_Find_DoH(t *testing.T) {
name: "human_id_path_match",
}}
profDB := &agdtest.ProfileDB{
OnCreateAutoDevice: func(
ctx context.Context,
id agd.ProfileID,
humanID agd.HumanID,
devType agd.DeviceType,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDeviceID: newOnProfileByDeviceID(dnssvctest.DeviceID),
OnProfileByHumanID: newOnProfileByHumanID(dnssvctest.ProfileID, dnssvctest.HumanIDLower),
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
profDB := agdtest.NewProfileDB()
profDB.OnProfileByDeviceID = newOnProfileByDeviceID(dnssvctest.DeviceID)
profDB.OnProfileByHumanID = newOnProfileByHumanID(dnssvctest.ProfileID, dnssvctest.HumanIDLower)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
@ -450,34 +360,9 @@ func TestDefault_Find_stdEncrypted(t *testing.T) {
deviceDomains: []string{dnssvctest.DomainForDevices},
}}
profDB := &agdtest.ProfileDB{
OnCreateAutoDevice: func(
ctx context.Context,
id agd.ProfileID,
humanID agd.HumanID,
devType agd.DeviceType,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDeviceID: newOnProfileByDeviceID(dnssvctest.DeviceID),
OnProfileByHumanID: newOnProfileByHumanID(dnssvctest.ProfileID, dnssvctest.HumanIDLower),
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
profDB := agdtest.NewProfileDB()
profDB.OnProfileByDeviceID = newOnProfileByDeviceID(dnssvctest.DeviceID)
profDB.OnProfileByHumanID = newOnProfileByHumanID(dnssvctest.ProfileID, dnssvctest.HumanIDLower)
srvData := []struct {
srv *agd.Server

View File

@ -1,8 +1,6 @@
package devicefinder_test
import (
"context"
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
@ -49,53 +47,13 @@ func TestDefault_Find_humanID(t *testing.T) {
in: "otr-abcd1234-!!!",
}}
profDB := &agdtest.ProfileDB{
OnCreateAutoDevice: func(
ctx context.Context,
id agd.ProfileID,
humanID agd.HumanID,
devType agd.DeviceType,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDeviceID: func(
_ context.Context,
devID agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByHumanID: func(
_ context.Context,
_ agd.ProfileID,
_ agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
df := devicefinder.NewDefault(&devicefinder.Config{
Logger: slogutil.NewDiscardLogger(),
ProfileDB: profDB,
ProfileDB: agdtest.NewProfileDB(),
HumanIDParser: agd.NewHumanIDParser(),
Server: srvDoT,
DeviceDomains: []string{dnssvctest.DomainForDevices},

View File

@ -55,8 +55,16 @@ const (
DomainRewrittenCNAMEFQDN = DomainRewrittenCNAME + "."
)
// ServerName is the common server name for tests.
const ServerName agd.ServerName = "test_server_dns_tls"
const (
// FilteringGroupID is the common filtering-group ID for tests.
FilteringGroupID agd.FilteringGroupID = "test_filtering_group"
// ServerName is the common server name for tests.
ServerName agd.ServerName = "test_server_dns_tls"
// ServerGroupName is the common server-group name for tests.
ServerGroupName agd.ServerGroupName = "test_server_group"
)
const (
// DomainForDevices is the upper-level domain name for requests with device

View File

@ -29,8 +29,8 @@ const (
// Resolvers for querying the resolver with unknown or absent name.
DDRDomain = DDRLabel + "." + ResolverARPADomain
// FirefoxCanaryHost is the hostname that Firefox uses to check if it
// should use its own DNS-over-HTTPS settings.
// FirefoxCanaryHost is the hostname that Firefox uses to check if it should
// use its own DNS-over-HTTPS settings.
//
// See https://support.mozilla.org/en-US/kb/configuring-networks-disable-dns-over-https.
FirefoxCanaryHost = "use-application-dns.net"
@ -56,6 +56,11 @@ func (mw *Middleware) reqInfoSpecialHandler(
return nil, ""
}
// As per RFC-9462 section 6.4, resolvers SHOULD respond to queries of any
// type other than SVCB for _dns.resolver.arpa. with NODATA and queries of
// any type for any domain name under resolver.arpa with NODATA.
//
// TODO(e.burkov): Consider adding SOA records for these NODATA responses.
if mw.isDDRRequest(ri) {
if _, ok := ri.DeviceResult.(*agd.DeviceResultAuthenticationFailure); ok {
return mw.handleDDRNoData, "ddr_doh"
@ -84,8 +89,6 @@ type reqInfoHandlerFunc func(
ri *agd.RequestInfo,
) (err error)
// DDR And Resolver ARPA Domain
// isDDRRequest determines if the message is the request for Discovery of
// Designated Resolvers as defined by the RFC draft. The request is considered
// ARPA if the requested host is a subdomain of resolver.arpa SUDN.
@ -141,8 +144,8 @@ func isDDRDomain(ri *agd.RequestInfo, host string) (ok bool) {
return false
}
// handleDDR checks if the request is for the Discovery of Designated Resolvers
// and writes a response if needed.
// handleDDR responds to Discovery of Designated Resolvers (DDR) queries with a
// response containing the designated resolvers.
func (mw *Middleware) handleDDR(
ctx context.Context,
rw dnsserver.ResponseWriter,
@ -157,11 +160,11 @@ func (mw *Middleware) handleDDR(
return rw.WriteMsg(ctx, req, mw.newRespDDR(req, ri))
}
return rw.WriteMsg(ctx, req, ri.Messages.NewMsgNXDOMAIN(req))
return rw.WriteMsg(ctx, req, ri.Messages.NewRespRCode(req, dns.RcodeNameError))
}
// handleDDRNoData processes DDR (Discovery of Designated Resolvers) requests
// for devices which need NODATA response and writes the response if needed.
// handleDDRNoData responds to Discovery of Designated Resolvers (DDR) queries
// with a NODATA response.
func (mw *Middleware) handleDDRNoData(
ctx context.Context,
rw dnsserver.ResponseWriter,
@ -173,17 +176,17 @@ func (mw *Middleware) handleDDRNoData(
metrics.DNSSvcDDRRequestsTotal.Inc()
if ri.ServerGroup.DDR.Enabled {
return rw.WriteMsg(ctx, req, ri.Messages.NewMsgNODATA(req))
return rw.WriteMsg(ctx, req, ri.Messages.NewRespRCode(req, dns.RcodeSuccess))
}
return rw.WriteMsg(ctx, req, ri.Messages.NewMsgNXDOMAIN(req))
return rw.WriteMsg(ctx, req, ri.Messages.NewRespRCode(req, dns.RcodeNameError))
}
// newRespDDR returns a new Discovery of Designated Resolvers response copying
// it from the prebuilt templates in srvGrp and modifying it in accordance with
// the request data. req must not be nil.
func (mw *Middleware) newRespDDR(req *dns.Msg, ri *agd.RequestInfo) (resp *dns.Msg) {
resp = ri.Messages.NewRespMsg(req)
resp = ri.Messages.NewResp(req)
name := req.Question[0].Name
ddr := ri.ServerGroup.DDR
@ -210,7 +213,8 @@ func (mw *Middleware) newRespDDR(req *dns.Msg, ri *agd.RequestInfo) (resp *dns.M
return resp
}
// handleBadResolverARPA writes a NODATA response.
// handleBadResolverARPA responds to badly formed resolver.arpa queries with a
// NODATA response.
func (mw *Middleware) handleBadResolverARPA(
ctx context.Context,
rw dnsserver.ResponseWriter,
@ -219,7 +223,8 @@ func (mw *Middleware) handleBadResolverARPA(
) (err error) {
metrics.DNSSvcBadResolverARPA.Inc()
err = rw.WriteMsg(ctx, req, ri.Messages.NewRespMsg(req))
resp := ri.Messages.NewRespRCode(req, dns.RcodeSuccess)
err = rw.WriteMsg(ctx, req, resp)
return errors.Annotate(err, "writing nodata resp for %q: %w", ri.Host)
}
@ -257,8 +262,6 @@ func (mw *Middleware) specialDomainHandler(
return nil, ""
}
// Apple Private Relay
// shouldBlockPrivateRelay returns true if the query is for an Apple Private
// Relay check domain and the request information or profile indicates that
// Apple Private Relay should be blocked.
@ -280,13 +283,12 @@ func (mw *Middleware) handlePrivateRelay(
) (err error) {
metrics.DNSSvcApplePrivateRelayRequestsTotal.Inc()
err = rw.WriteMsg(ctx, req, ri.Messages.NewMsgNXDOMAIN(req))
resp := ri.Messages.NewRespRCode(req, dns.RcodeNameError)
err = rw.WriteMsg(ctx, req, resp)
return errors.Annotate(err, "writing private relay resp: %w")
}
// Firefox canary domain
// shouldBlockFirefoxCanary returns true if the query is for a Firefox canary
// domain and the request information or profile indicates that Firefox canary
// domain should be blocked.
@ -298,8 +300,8 @@ func shouldBlockFirefoxCanary(ri *agd.RequestInfo, prof *agd.Profile) (ok bool)
return ri.FilteringGroup.BlockFirefoxCanary
}
// handleFirefoxCanary checks if the request is for the fully-qualified domain
// name that Firefox uses to check DoH settings and writes a response if needed.
// handleFirefoxCanary responds to Firefox canary domain queries with a REFUSED
// response.
func (mw *Middleware) handleFirefoxCanary(
ctx context.Context,
rw dnsserver.ResponseWriter,
@ -308,7 +310,7 @@ func (mw *Middleware) handleFirefoxCanary(
) (err error) {
metrics.DNSSvcFirefoxRequestsTotal.Inc()
resp := ri.Messages.NewMsgREFUSED(req)
resp := ri.Messages.NewRespRCode(req, dns.RcodeRefused)
err = rw.WriteMsg(ctx, req, resp)
return errors.Annotate(err, "writing firefox canary resp: %w")

View File

@ -44,7 +44,9 @@ func TestMiddleware_writeDebugResponse(t *testing.T) {
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: cloner,
BlockingMode: &dnsmsg.BlockingModeNullIP{},
StructuredErrors: agdtest.NewSDEConfig(true),
FilteredResponseTTL: agdtest.FilteredResponseTTL,
EDEEnabled: true,
})
require.NoError(t, err)

View File

@ -162,7 +162,7 @@ func (mw *Middleware) setFilteredResponse(
mw.setFilteredResponseNoReq(ctx, fctx, ri)
case *filter.ResultBlocked:
var err error
fctx.filteredResponse, err = ri.Messages.NewBlockedRespMsg(fctx.originalRequest)
fctx.filteredResponse, err = ri.Messages.NewBlockedResp(fctx.originalRequest)
if err != nil {
mw.reportf(ctx, "creating blocked resp for filtered req: %w", err)
fctx.filteredResponse = fctx.originalResponse
@ -199,7 +199,7 @@ func (mw *Middleware) setFilteredResponseNoReq(
fctx.filteredResponse = fctx.originalResponse
case *filter.ResultBlocked:
var err error
fctx.filteredResponse, err = ri.Messages.NewBlockedRespMsg(fctx.originalRequest)
fctx.filteredResponse, err = ri.Messages.NewBlockedResp(fctx.originalRequest)
if err != nil {
mw.reportf(ctx, "creating blocked resp for filtered resp: %w", err)
fctx.filteredResponse = fctx.originalResponse

View File

@ -4,6 +4,7 @@ package mainmw
import (
"context"
"log/slog"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
@ -14,7 +15,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/AdGuardDNS/internal/optslog"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/AdGuardDNS/internal/rulestat"
"github.com/AdguardTeam/golibs/errors"
@ -24,14 +25,15 @@ import (
// Middleware is the main middleware of AdGuard DNS.
type Middleware struct {
messages *dnsmsg.Constructor
cloner *dnsmsg.Cloner
fltCtxPool *syncutil.Pool[filteringContext]
metrics Metrics
logger *slog.Logger
messages *dnsmsg.Constructor
billStat billstat.Recorder
errColl errcoll.Interface
fltStrg filter.Storage
geoIP geoip.Interface
metrics Metrics
queryLog querylog.Interface
ruleStat rulestat.Interface
}
@ -39,19 +41,17 @@ type Middleware struct {
// Config is the configuration structure for the main middleware. All fields
// must be non-nil.
type Config struct {
// Metrics is used to collect the statistics.
Metrics Metrics
// Cloner is used to clone messages more efficiently by disposing of parts
// of DNS responses for later reuse.
Cloner *dnsmsg.Cloner
// Logger is used to log the operation of the middleware.
Logger *slog.Logger
// Messages is the message constructor used to create blocked and other
// messages for this middleware.
Messages *dnsmsg.Constructor
// Cloner is used to clone messages more efficiently by disposing of parts
// of DNS responses for later reuse.
//
// TODO(a.garipov): Use.
Cloner *dnsmsg.Cloner
// BillStat is used to collect billing statistics.
BillStat billstat.Recorder
@ -66,6 +66,9 @@ type Config struct {
// addresses in requests and responses.
GeoIP geoip.Interface
// Metrics is used to collect the statistics.
Metrics Metrics
// QueryLog is used to write the logs into.
QueryLog querylog.Interface
@ -77,16 +80,17 @@ type Config struct {
// New returns a new main middleware. c must not be nil.
func New(c *Config) (mw *Middleware) {
return &Middleware{
metrics: c.Metrics,
messages: c.Messages,
cloner: c.Cloner,
fltCtxPool: syncutil.NewPool(func() (v *filteringContext) {
return &filteringContext{}
}),
logger: c.Logger,
messages: c.Messages,
billStat: c.BillStat,
errColl: c.ErrColl,
fltStrg: c.FilterStorage,
geoIP: c.GeoIP,
metrics: c.Metrics,
queryLog: c.QueryLog,
ruleStat: c.RuleStat,
}
@ -106,8 +110,20 @@ func (mw *Middleware) Wrap(next dnsserver.Handler) (wrapped dnsserver.Handler) {
defer mw.fltCtxPool.Put(fctx)
ri := agd.MustRequestInfoFromContext(ctx)
optlog.Debug2("processing request %q from %s", ri.ID, ri.RemoteIP)
defer optlog.Debug2("finished processing request %q from %s", ri.ID, ri.RemoteIP)
optslog.Debug2(
ctx,
mw.logger,
"processing request",
"req_id", ri.ID,
"remote_ip", ri.RemoteIP,
)
defer optslog.Debug2(
ctx,
mw.logger,
"finished processing request",
"req_id", ri.ID,
"remote_ip", ri.RemoteIP,
)
flt := mw.fltStrg.FilterFromContext(ctx, ri)
mw.filterRequest(ctx, fctx, flt, ri)
@ -184,10 +200,12 @@ func (mw *Middleware) nextParams(
ctx = agd.ContextWithRequestInfo(ctx, modReqInfo)
optlog.Debug2(
"mainmw: request for %q rewritten to %q by CNAME rewrite rule",
ri.Host,
modReqInfo.Host,
optslog.Debug2(
ctx,
mw.logger,
"request rewritten by cname rewrite rule",
"orig_host", ri.Host,
"mod_host", modReqInfo.Host,
)
return ctx, origRW, modReq

View File

@ -17,17 +17,13 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// Common constants for tests.
const (
testASN geoip.ASN = 12345
@ -121,14 +117,8 @@ func TestMiddleware_Wrap(t *testing.T) {
OnHasListID: func(_ agd.FilterListID) (ok bool) { panic("not implemented") },
}
geoIP := &agdtest.GeoIP{
OnSubnetByLocation: func(
_ *geoip.Location,
_ netutil.AddrFamily,
) (n netip.Prefix, err error) {
panic("not implemented")
},
OnData: func(host string, addr netip.Addr) (l *geoip.Location, err error) {
geoIP := agdtest.NewGeoIP()
geoIP.OnData = func(host string, addr netip.Addr) (l *geoip.Location, err error) {
pt := testutil.PanicT{}
require.Equal(pt, dnssvctest.Domain, host)
if addr.Is4() {
@ -138,7 +128,6 @@ func TestMiddleware_Wrap(t *testing.T) {
}
return nil, nil
},
}
ruleStat := &agdtest.RuleStat{
@ -153,7 +142,9 @@ func TestMiddleware_Wrap(t *testing.T) {
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: cloner,
BlockingMode: &dnsmsg.BlockingModeNullIP{},
StructuredErrors: agdtest.NewSDEConfig(true),
FilteredResponseTTL: agdtest.FilteredResponseTTL,
EDEEnabled: true,
})
require.NoError(t, err)
@ -227,13 +218,14 @@ func TestMiddleware_Wrap(t *testing.T) {
}
c := &mainmw.Config{
Metrics: mainmw.EmptyMetrics{},
Messages: msgs,
Cloner: cloner,
Logger: slogutil.NewDiscardLogger(),
Messages: msgs,
BillStat: tc.billStat,
ErrColl: agdtest.NewErrorCollector(),
FilterStorage: fltStrg,
GeoIP: geoIP,
Metrics: mainmw.EmptyMetrics{},
QueryLog: queryLog,
RuleStat: ruleStat,
}
@ -411,16 +403,9 @@ func TestMiddleware_Wrap_filtering(t *testing.T) {
}
)
geoIP := &agdtest.GeoIP{
OnSubnetByLocation: func(
_ *geoip.Location,
_ netutil.AddrFamily,
) (n netip.Prefix, err error) {
panic("not implemented")
},
OnData: func(host string, addr netip.Addr) (l *geoip.Location, err error) {
geoIP := agdtest.NewGeoIP()
geoIP.OnData = func(_ string, _ netip.Addr) (l *geoip.Location, err error) {
return nil, nil
},
}
var (
@ -537,7 +522,9 @@ func TestMiddleware_Wrap_filtering(t *testing.T) {
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: cloner,
BlockingMode: &dnsmsg.BlockingModeNullIP{},
StructuredErrors: agdtest.NewSDEConfig(true),
FilteredResponseTTL: agdtest.FilteredResponseTTL,
EDEEnabled: true,
})
require.NoError(t, err)
@ -701,13 +688,14 @@ func TestMiddleware_Wrap_filtering(t *testing.T) {
}
c := &mainmw.Config{
Metrics: mainmw.EmptyMetrics{},
Messages: msgs,
Cloner: cloner,
Logger: slogutil.NewDiscardLogger(),
Messages: msgs,
BillStat: tc.billStat,
ErrColl: agdtest.NewErrorCollector(),
FilterStorage: fltStrg,
GeoIP: geoIP,
Metrics: mainmw.EmptyMetrics{},
QueryLog: queryLog,
RuleStat: ruleStat,
}

View File

@ -12,7 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/AdGuardDNS/internal/optslog"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
@ -112,8 +112,7 @@ func (mw *Middleware) responseCountry(
}
ctry = mw.country(ctx, host, respIP)
// TODO(a.garipov): Use optslog.Trace2.
optlog.Debug2("mainmw: got ctry %q for resp ip %v", ctry, respIP)
optslog.Trace2(ctx, mw.logger, "geoip for resp", "ctry", ctry, "resp_ip", respIP)
return ctry
}

View File

@ -16,7 +16,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/AdGuardDNS/internal/rulestat"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -134,20 +134,13 @@ func TestMiddleware_recordQueryInfo_respCtry(t *testing.T) {
Country: testCtry,
}
geoIP := &agdtest.GeoIP{
OnSubnetByLocation: func(
_ *geoip.Location,
_ netutil.AddrFamily,
) (n netip.Prefix, err error) {
panic("not implemented")
},
OnData: func(_ string, _ netip.Addr) (l *geoip.Location, err error) {
geoIP := agdtest.NewGeoIP()
geoIP.OnData = func(_ string, _ netip.Addr) (l *geoip.Location, err error) {
if !tc.wantGeoIP {
t.Error("unexpected call to geoip")
}
return loc, nil
},
}
queryLogCalled := false
@ -164,6 +157,7 @@ func TestMiddleware_recordQueryInfo_respCtry(t *testing.T) {
}
mw := &Middleware{
logger: slogutil.NewDiscardLogger(),
billStat: billstat.EmptyRecorder{},
geoIP: geoIP,
queryLog: queryLog,

View File

@ -5,15 +5,16 @@ package preservice
import (
"context"
"fmt"
"log/slog"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/AdGuardDNS/internal/optslog"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns"
)
@ -22,19 +23,18 @@ import (
// names that may be filtered by safe browsing or parental control filters as
// well as handling of the DNS-server check queries.
type Middleware struct {
// messages is used to construct TXT responses.
logger *slog.Logger
messages *dnsmsg.Constructor
// hashMatcher is the safe browsing DNS hashMatcher.
hashMatcher filter.HashMatcher
// checker is used to detect and process DNS-check requests.
checker dnscheck.Interface
}
// Config is the configurational structure for the preservice middleware. All
// fields must be non-nil.
type Config struct {
// Logger is used to log the operation of the middleware.
Logger *slog.Logger
// Messages is used to construct TXT responses.
Messages *dnsmsg.Constructor
@ -48,6 +48,7 @@ type Config struct {
// New returns a new preservice middleware. c must not be nil.
func New(c *Config) (mw *Middleware) {
return &Middleware{
logger: c.Logger,
messages: c.Messages,
hashMatcher: c.HashMatcher,
checker: c.Checker,
@ -60,7 +61,7 @@ var _ dnsserver.Middleware = (*Middleware)(nil)
// Wrap implements the [dnsserver.Middleware] interface for *Middleware.
func (mw *Middleware) Wrap(next dnsserver.Handler) (wrapped dnsserver.Handler) {
f := func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) (err error) {
defer func() { err = errors.Annotate(err, "preservice mw: %w") }()
defer func() { err = errors.Annotate(err, "presvcmw: %w") }()
ri := agd.MustRequestInfoFromContext(ctx)
if ri.QType == dns.TypeTXT {
@ -71,15 +72,22 @@ func (mw *Middleware) Wrap(next dnsserver.Handler) (wrapped dnsserver.Handler) {
resp, err := mw.checker.Check(ctx, req, ri)
if err != nil {
return fmt.Errorf("calling dnscheck: %w", err)
} else if resp != nil {
return errors.Annotate(rw.WriteMsg(ctx, req, resp), "writing dnscheck response: %w")
}
if resp == nil {
// Don't wrap the error, because this is the main flow, and there is
// already [errors.Annotate] here.
return next.ServeDNS(ctx, rw, req)
}
err = rw.WriteMsg(ctx, req, resp)
if err != nil {
return fmt.Errorf("writing dnscheck response: %w", err)
}
return nil
}
return dnsserver.HandlerFunc(f)
}
@ -92,15 +100,19 @@ func (mw *Middleware) respondWithHashes(
req *dns.Msg,
ri *agd.RequestInfo,
) (err error) {
optlog.Debug1("preservice mw: safe browsing: got txt req for %q", ri.Host)
optslog.Debug1(ctx, mw.logger, "got txt req", "host", ri.Host)
hashes, matched, err := mw.hashMatcher.MatchByPrefix(ctx, ri.Host)
if err != nil {
// Don't return or collect this error to prevent DDoS of the error
// collector by sending bad requests.
log.Error("preservice mw: safe browsing: matching hashes: %s", err)
mw.logger.ErrorContext(
ctx,
"matching hashes",
slogutil.KeyError, err,
)
resp := mw.messages.NewMsgREFUSED(req)
resp := mw.messages.NewRespRCode(req, dns.RcodeRefused)
err = rw.WriteMsg(ctx, req, resp)
return errors.Annotate(err, "writing refused response: %w")
@ -110,7 +122,7 @@ func (mw *Middleware) respondWithHashes(
return next.ServeDNS(ctx, rw, req)
}
resp, err := mw.messages.NewTXTRespMsg(req, hashes...)
resp, err := mw.messages.NewRespTXT(req, hashes...)
if err != nil {
// Technically should never happen since the only error that could arise
// in [dnsmsg.Constructor.NewTXTRespMsg] is the one about request type
@ -118,7 +130,7 @@ func (mw *Middleware) respondWithHashes(
return fmt.Errorf("creating safe browsing result: %w", err)
}
optlog.Debug1("preservice mw: safe browsing: writing hashes %q", hashes)
optslog.Debug1(ctx, mw.logger, "writing hashes", "hashes", hashes)
err = rw.WriteMsg(ctx, req, resp)
if err != nil {

Some files were not shown because too many files have changed in this diff Show More