Sync v2.10.0

This commit is contained in:
Andrey Meshkov 2024-11-08 16:26:22 +03:00
parent da0cb6fd0e
commit 87137bddcf
150 changed files with 8610 additions and 6094 deletions

View File

@ -7,6 +7,94 @@ The format is **not** based on [Keep a Changelog][kec], since the project **does
[kec]: https://keepachangelog.com/en/1.0.0/ [kec]: https://keepachangelog.com/en/1.0.0/
[sem]: https://semver.org/spec/v2.0.0.html [sem]: https://semver.org/spec/v2.0.0.html
## AGDNS-2484/ Build 886
- Property `type` of the `ratelimit` object has been moved to the underlying `allowlist` object. So replace this:
```yaml
ratelimit:
type: 'consul'
# …
allowlist:
# …
```
with this:
```yaml
ratelimit:
# …
allowlist:
type: 'consul'
# …
```
## AGDNS-2443 / Build 877
- The object `filters` has new properties: `ede_enabled`, and `sde_enabled`. So replace this:
```yaml
filters:
# …
```
with this:
```yaml
filters:
# …
ede_enabled: true
sde_enabled: true
```
## AGDNS-2456 / Build 873
- The environment variables `BACKEND_RATELIMIT_URL` and `BACKEND_RATELIMIT_API_KEY` have been added.
- Added the `type` property within the `ratelimit` object. So add it:
```yaml
ratelimit:
type: 'consul'
# …
```
## AGDNS-2431 / Build 872
- The objects `ratelimit.ipv4` and `ratelimit.ipv6` have been modified. Its `rps` properties have been replaced with the new properties `count` and `interval`. So replace this:
```yaml
ratelimit:
# …
ipv4:
rps: 30
ipv6:
rps: 300
```
with this:
```yaml
ratelimit:
# …
ipv4:
# …
count: 300
interval: 10s
ipv6:
# …
count: 3000
interval: 10s
```
Adjust the value and add new ones, if necessary.
## AGDNS-2457 / Build 871
- The environment variables `DNSCHECK_REMOTEKV_URL` and `DNSCHECK_REMOTEKV_API_KEY` have been added.
- The property `kv.type` within the `check` object now supports the `backend` value.
## AGDNS-2468 / Build 869 ## AGDNS-2468 / Build 869
- The environment variable `PROFILES_MAX_RESP_SIZE` has been added. It sets the maximum size of the response from the profiles endpoint of the backend API. The default value is `8MB`. - The environment variable `PROFILES_MAX_RESP_SIZE` has been added. It sets the maximum size of the response from the profiles endpoint of the backend API. The default value is `8MB`.

View File

@ -24,7 +24,7 @@ BRANCH = $${BRANCH:-$$(git rev-parse --abbrev-ref HEAD)}
GOAMD64 = v1 GOAMD64 = v1
GOPROXY = https://proxy.golang.org|direct GOPROXY = https://proxy.golang.org|direct
GOTELEMETRY = off GOTELEMETRY = off
GOTOOLCHAIN = go1.23.1 GOTOOLCHAIN = go1.23.2
RACE = 0 RACE = 0
REVISION = $${REVISION:-$$(git rev-parse --short HEAD)} REVISION = $${REVISION:-$$(git rev-parse --short HEAD)}
VERSION = 0 VERSION = 0

View File

@ -10,15 +10,19 @@ ratelimit:
response_size_estimate: 1KB response_size_estimate: 1KB
# Rate limit options for IPv4 addresses. # Rate limit options for IPv4 addresses.
ipv4: ipv4:
# Rate of requests per second for one subnet for IPv4 addresses. # Requests per configured interval for one subnet for IPv4 addresses.
rps: 30 count: 300
# The time during which to count the number of requests.
interval: 10s
# The lengths of the subnet prefixes used to calculate rate limiter # The lengths of the subnet prefixes used to calculate rate limiter
# bucket keys for IPv4 addresses. # bucket keys for IPv4 addresses.
subnet_key_len: 24 subnet_key_len: 24
# Rate limit options for IPv6 addresses. # Rate limit options for IPv6 addresses.
ipv6: ipv6:
# Rate of requests per second for one subnet for IPv6 addresses. # Requests per configured interval for one subnet for IPv6 addresses.
rps: 300 count: 3000
# The time during which to count the number of requests.
interval: 10s
# The lengths of the subnet prefixes used to calculate rate limiter # The lengths of the subnet prefixes used to calculate rate limiter
# bucket keys for IPv6 addresses. # bucket keys for IPv6 addresses.
subnet_key_len: 48 subnet_key_len: 48
@ -40,6 +44,9 @@ ratelimit:
- '127.0.0.1/24' - '127.0.0.1/24'
# Time between two updates of allow list. # Time between two updates of allow list.
refresh_interval: 1h refresh_interval: 1h
# Defines where the rate limiting settings are received from. Allowed
# values are "backend" and "consul".
type: 'consul'
# Configuration for the stream connection limiting. # Configuration for the stream connection limiting.
connection_limit: connection_limit:
@ -167,8 +174,8 @@ geoip:
check: check:
# Domains to use for DNS checking. # Domains to use for DNS checking.
kv: kv:
# Defines the type of remote kay-value storage. Allowed values are # Defines the type of remote key-value storage. Allowed values are
# "consul" and "redis". # "backend", "consul", and "redis".
type: 'consul' type: 'consul'
# For how long to keep the information about the client. # For how long to keep the information about the client.
ttl: 30s ttl: 30s
@ -313,6 +320,10 @@ filters:
enabled: true enabled: true
# The size of the LRU cache of rule-list filtering results. # The size of the LRU cache of rule-list filtering results.
size: 10000 size: 10000
# Enable the Extended DNS Errors feature.
ede_enabled: true
# Enable the Structured DNS Errors feature. Requires ede_enabled: true.
sde_enabled: true
# Filtering groups are a set of different filtering configurations. These # Filtering groups are a set of different filtering configurations. These
# filtering configurations are then used by server_groups. # filtering configurations are then used by server_groups.

View File

@ -104,9 +104,13 @@ The `ratelimit` object has the following properties:
- <a href="#ratelimit-ipv4" id="ratelimit-ipv4" name="ratelimit-ipv4">`ipv4`</a>: The ipv4 configuration object. It has the following fields: - <a href="#ratelimit-ipv4" id="ratelimit-ipv4" name="ratelimit-ipv4">`ipv4`</a>: The ipv4 configuration object. It has the following fields:
- <a href="#ratelimit-ipv4-rps" id="ratelimit-ipv4-rps" name="ratelimit-ipv4-rps">`rps`</a>: The rate of requests per second for one subnet. Requests above this are counted in the backoff count. - <a href="#ratelimit-ipv4-count" id="ratelimit-ipv4-count" name="ratelimit-ipv4-count">`count`</a>: Requests per configured interval for one subnet for IPv4 addresses. Requests above this are counted in the backoff count.
**Example:** `30`. **Example:** `300`.
- <a href="#ratelimit-ipv4-interval" id="ratelimit-ipv4-interval" name="ratelimit-ipv4-interval">`interval`</a>: The time during which to count the number of requests.
**Example:** `10s`.
- <a href="#ratelimit-ipv4-subnet_key_len" id="ratelimit-ipv4-subnet_key_len" name="ratelimit-ipv4-subnet_key_len">`ipv4-subnet_key_len`</a>: The length of the subnet prefix used to calculate rate limiter bucket keys. - <a href="#ratelimit-ipv4-subnet_key_len" id="ratelimit-ipv4-subnet_key_len" name="ratelimit-ipv4-subnet_key_len">`ipv4-subnet_key_len`</a>: The length of the subnet prefix used to calculate rate limiter bucket keys.
@ -134,7 +138,11 @@ The `ratelimit` object has the following properties:
**Example:** `30s`. **Example:** `30s`.
For example, if `backoff_period` is `1m`, `backoff_count` is `10`, and `ipv4-rps` is `5`, a client (meaning all IP addresses within the subnet defined by `ipv4-subnet_key_len`) that made 15 requests in one second or 6 requests (one above `rps`) every second for 10 seconds within one minute, the client is blocked for `backoff_duration`. - <a href="#ratelimit-allowlist-type" id="ratelimit-allowlist-type" name="ratelimit-allowlist-type">`type`</a>: Defines where the rate limit settings are received from. Allowed values are `backend` and `consul`.
**Example:** `consul`.
For example, if `backoff_period` is `1m`, `backoff_count` is `10`, `ipv4-count` is `5`, and `ipv4-interval` is `1s`, a client (meaning all IP addresses within the subnet defined by `ipv4-subnet_key_len`) that made 15 requests in one second or 6 requests (one above `rps`) every second for 10 seconds within one minute, the client is blocked for `backoff_duration`.
### <a href="#ratelimit-connection_limit" id="ratelimit-connection_limit" name="ratelimit-connection_limit">Stream connection limit</a> ### <a href="#ratelimit-connection_limit" id="ratelimit-connection_limit" name="ratelimit-connection_limit">Stream connection limit</a>
@ -370,12 +378,14 @@ The `check` object has the following properties:
- <a href="#check_kv" id="check_kv" name="check_kv">`kv`</a>: Remote key-value storage settings. It has the following properties: - <a href="#check_kv" id="check_kv" name="check_kv">`kv`</a>: Remote key-value storage settings. It has the following properties:
- <a href="#check-kv-type" id="check-kv-type" name="check-kv-type">`type`</a>: Type of the remote KV storage. Allowed values are `consul` and `redis`. - <a href="#check-kv-type" id="check-kv-type" name="check-kv-type">`type`</a>: Type of the remote KV storage. Allowed values are `backend`, `consul`, and `redis`.
**Example:** `consul`. **Example:** `consul`.
- <a href="#check-kv-ttl" id="check-kv-ttl" name="check-kv-ttl">`ttl`</a>: For how long to keep the information about a single user in remote KV, as a human-readable duration. - <a href="#check-kv-ttl" id="check-kv-ttl" name="check-kv-ttl">`ttl`</a>: For how long to keep the information about a single user in remote KV, as a human-readable duration.
For `backend`, the TTL must be greater than `0s`.
For `consul`, the TTL must be between `10s` and `1d`. Note that the actual TTL can be up to twice as long. For `consul`, the TTL must be between `10s` and `1d`. Note that the actual TTL can be up to twice as long.
For `redis`, the TTL must be greater than or equal to `1ms`. For `redis`, the TTL must be greater than or equal to `1ms`.
@ -592,6 +602,14 @@ The `filters` object has the following properties:
**Example:** `10000`. **Example:** `10000`.
- <a href="#filters-ede_enabled" id="filters-ede_enabled" name="filters-ede_enabled">`ede_enabled`</a>: Shows if Extended DNS Error codes should be added.
**Example:** `true`.
- <a href="#filters-sde_enabled" id="filters-sde_enabled" name="filters-sde_enabled">`sde_enabled`</a>: Shows if the experimental Structured DNS Errors feature should be enabled. `ede_enabled` must be `true` to enable SDE.
**Example:** `true`.
[env-blocked_services]: environment.md#BLOCKED_SERVICE_INDEX_URL [env-blocked_services]: environment.md#BLOCKED_SERVICE_INDEX_URL
## <a href="#filtering_groups" id="filtering_groups" name="filtering_groups">Filtering groups</a> ## <a href="#filtering_groups" id="filtering_groups" name="filtering_groups">Filtering groups</a>

View File

@ -159,7 +159,7 @@ You'll need to supply the following:
See the [external HTTP API documentation][externalhttp]. See the [external HTTP API documentation][externalhttp].
You may use `go run ./scripts/backend` to start mock GRPC server for `BILLSTAT_URL` and `PROFILES_URL` endpoints. You may use `go run ./scripts/backend` to start mock GRPC server for `BACKEND_PROFILES_URL`, `BILLSTAT_URL`, `DNSCHECK_REMOTEKV_URL`, and `PROFILES_URL` endpoints.
You may need to change the listen ports in `config.yaml` which are less than 1024 to some other ports. Otherwise, `sudo` or `doas` is required to run `AdGuardDNS`. You may need to change the listen ports in `config.yaml` which are less than 1024 to some other ports. Otherwise, `sudo` or `doas` is required to run `AdGuardDNS`.

View File

@ -6,6 +6,8 @@ AdGuard DNS uses [environment variables][wiki-env] to store some of the more sen
- [`ADULT_BLOCKING_ENABLED`](#ADULT_BLOCKING_ENABLED) - [`ADULT_BLOCKING_ENABLED`](#ADULT_BLOCKING_ENABLED)
- [`ADULT_BLOCKING_URL`](#ADULT_BLOCKING_URL) - [`ADULT_BLOCKING_URL`](#ADULT_BLOCKING_URL)
- [`BACKEND_RATELIMIT_API_KEY`](#BACKEND_RATELIMIT_API_KEY)
- [`BACKEND_RATELIMIT_URL`](#BACKEND_RATELIMIT_URL)
- [`BILLSTAT_API_KEY`](#BILLSTAT_API_KEY) - [`BILLSTAT_API_KEY`](#BILLSTAT_API_KEY)
- [`BILLSTAT_URL`](#BILLSTAT_URL) - [`BILLSTAT_URL`](#BILLSTAT_URL)
- [`BLOCKED_SERVICE_ENABLED`](#BLOCKED_SERVICE_ENABLED) - [`BLOCKED_SERVICE_ENABLED`](#BLOCKED_SERVICE_ENABLED)
@ -14,6 +16,8 @@ AdGuard DNS uses [environment variables][wiki-env] to store some of the more sen
- [`CONSUL_ALLOWLIST_URL`](#CONSUL_ALLOWLIST_URL) - [`CONSUL_ALLOWLIST_URL`](#CONSUL_ALLOWLIST_URL)
- [`CONSUL_DNSCHECK_KV_URL`](#CONSUL_DNSCHECK_KV_URL) - [`CONSUL_DNSCHECK_KV_URL`](#CONSUL_DNSCHECK_KV_URL)
- [`CONSUL_DNSCHECK_SESSION_URL`](#CONSUL_DNSCHECK_SESSION_URL) - [`CONSUL_DNSCHECK_SESSION_URL`](#CONSUL_DNSCHECK_SESSION_URL)
- [`DNSCHECK_REMOTEKV_API_KEY`](#DNSCHECK_REMOTEKV_API_KEY)
- [`DNSCHECK_REMOTEKV_URL`](#DNSCHECK_REMOTEKV_URL)
- [`FILTER_CACHE_PATH`](#FILTER_CACHE_PATH) - [`FILTER_CACHE_PATH`](#FILTER_CACHE_PATH)
- [`FILTER_INDEX_URL`](#FILTER_INDEX_URL) - [`FILTER_INDEX_URL`](#FILTER_INDEX_URL)
- [`GENERAL_SAFE_ENABLED`](#GENERAL_SAFE_SEARCH_ENABLED) - [`GENERAL_SAFE_ENABLED`](#GENERAL_SAFE_SEARCH_ENABLED)
@ -62,6 +66,21 @@ The HTTP(S) URL of source list of rules for adult blocking filter.
**Default:** No default value, the variable is required if `ADULT_BLOCKING_ENABLED` is set to `1`. **Default:** No default value, the variable is required if `ADULT_BLOCKING_ENABLED` is set to `1`.
## <a href="#BACKEND_RATELIMIT_API_KEY" id="BACKEND_RATELIMIT_API_KEY" name="BACKEND_RATELIMIT_API_KEY">`BACKEND_RATELIMIT_API_KEY`</a>
The API key to use when authenticating requests to the backend rate limiter API, if any. The API key should be valid as defined by [RFC 6750].
**Default:** **Unset.**
## <a href="#BACKEND_RATELIMIT_URL" id="BACKEND_RATELIMIT_URL" name="BACKEND_RATELIMIT_URL">`BACKEND_RATELIMIT_URL`</a>
The base backend URL for backend rate limiter. Supports gRPC(S) (`grpc://` and `grpcs://`) URLs. See the [external API requirements section][ext-backend-ratelimit].
**Default:** No default value, the variable is required if the [type][conf-ratelimit-type] of rate limiter is `backend` in the configuration file.
[conf-ratelimit-type]: configuration.md#ratelimit-type
[ext-backend-ratelimit]: externalhttp.md#backend-ratelimit
## <a href="#BILLSTAT_API_KEY" id="BILLSTAT_API_KEY" name="BILLSTAT_API_KEY">`BILLSTAT_API_KEY`</a> ## <a href="#BILLSTAT_API_KEY" id="BILLSTAT_API_KEY" name="BILLSTAT_API_KEY">`BILLSTAT_API_KEY`</a>
The API key to use when authenticating queries to the billing statistics API, if any. The API key should be valid as defined by [RFC 6750]. The API key to use when authenticating queries to the billing statistics API, if any. The API key should be valid as defined by [RFC 6750].
@ -72,7 +91,7 @@ The API key to use when authenticating queries to the billing statistics API, if
## <a href="#BILLSTAT_URL" id="BILLSTAT_URL" name="BILLSTAT_URL">`BILLSTAT_URL`</a> ## <a href="#BILLSTAT_URL" id="BILLSTAT_URL" name="BILLSTAT_URL">`BILLSTAT_URL`</a>
The base backend URL for backend billing statistics uploader API. Supports gRPC(S) (`grpc://` and`grpcs://`) URLs. See the [external HTTP API requirements section][ext-billstat]. The base backend URL for backend billing statistics uploader API. Supports gRPC(S) (`grpc://` and `grpcs://`) URLs. See the [external HTTP API requirements section][ext-billstat].
**Default:** No default value, the variable is required if there is at least one [server group][conf-sg] with profiles enabled. **Default:** No default value, the variable is required if there is at least one [server group][conf-sg] with profiles enabled.
@ -103,7 +122,7 @@ The path to the configuration file.
The HTTP(S) URL of the Consul instance serving the dynamic part of the rate-limit allowlist. See the [external HTTP API requirements section][ext-consul] on the expected format of the response. The HTTP(S) URL of the Consul instance serving the dynamic part of the rate-limit allowlist. See the [external HTTP API requirements section][ext-consul] on the expected format of the response.
**Default:** No default value, the variable is **required.** **Default:** No default value, the variable is required if the [type][conf-ratelimit-type] of rate limiter is `consul` in the configuration file.
[ext-consul]: externalhttp.md#consul [ext-consul]: externalhttp.md#consul
@ -123,6 +142,20 @@ The HTTP(S) URL of the session API of the Consul instance used as a key-value da
**Example:** `http://localhost:8500/v1/session/create` **Example:** `http://localhost:8500/v1/session/create`
## <a href="#DNSCHECK_REMOTEKV_API_KEY" id="DNSCHECK_REMOTEKV_API_KEY" name="DNSCHECK_REMOTEKV_API_KEY">`DNSCHECK_REMOTEKV_API_KEY`</a>
The API key to use when authenticating queries to the backend key-value database API, if any. The API key should be valid as defined by [RFC 6750].
**Default:** **Unset.**
## <a href="#DNSCHECK_REMOTEKV_URL" id="DNSCHECK_REMOTEKV_URL" name="DNSCHECK_REMOTEKV_URL">`DNSCHECK_REMOTEKV_URL`</a>
The base backend URL used as a key-value database for the DNS server checking. Supports gRPC(S) (`grpc://` and`grpcs://`) URLs. See the [external API requirements section][ext-backend-dnscheck].
**Default:** **Unset.**
[ext-backend-dnscheck]: externalhttp.md#backend-dnscheck
## <a href="#FILTER_CACHE_PATH" id="FILTER_CACHE_PATH" name="FILTER_CACHE_PATH">`FILTER_CACHE_PATH`</a> ## <a href="#FILTER_CACHE_PATH" id="FILTER_CACHE_PATH" name="FILTER_CACHE_PATH">`FILTER_CACHE_PATH`</a>
The path to the directory used to store the cached version of all filters and filter indexes. The path to the directory used to store the cached version of all filters and filter indexes.
@ -238,11 +271,11 @@ The profile cache is read on start and is later updated on every [full refresh][
The maximum size of the response from the profiles API in a human-readable format. The maximum size of the response from the profiles API in a human-readable format.
**Default:** `8MB`. **Default:** `64MB`.
## <a href="#PROFILES_URL" id="PROFILES_URL" name="PROFILES_URL">`PROFILES_URL`</a> ## <a href="#PROFILES_URL" id="PROFILES_URL" name="PROFILES_URL">`PROFILES_URL`</a>
The base backend URL for profiles API. Supports gRPC(S) (`grpc://` and`grpcs://`) URLs. See the [external API requirements section][ext-profiles]. The base backend URL for profiles API. Supports gRPC(S) (`grpc://` and `grpcs://`) URLs. See the [external API requirements section][ext-profiles].
**Default:** No default value, the variable is required if there is at least one [server group][conf-sg] with profiles enabled. **Default:** No default value, the variable is required if there is at least one [server group][conf-sg] with profiles enabled.
@ -252,7 +285,7 @@ The base backend URL for profiles API. Supports gRPC(S) (`grpc://` and`grpcs://`
Redis server address. Can be an IP address or a hostname. Redis server address. Can be an IP address or a hostname.
**Default:** No default value, the variable if required if the [type][conf-check-kv-type] of remote KV storage for DNS server checking is `redis` in the configuration file. **Default:** No default value, the variable is required if the [type][conf-check-kv-type] of remote KV storage for DNS server checking is `redis` in the configuration file.
[conf-check-kv-type]: configuration.md#check-kv-type [conf-check-kv-type]: configuration.md#check-kv-type

View File

@ -10,7 +10,9 @@ AdGuard DNS uses information from external HTTP APIs for filtering and other pie
## Contents ## Contents
- [Backend billing statistics](#backend-billstat) - [Backend billing statistics](#backend-billstat)
- [Backend DNSCheck service](#backend-dnscheck)
- [Backend profiles service](#backend-profiles) - [Backend profiles service](#backend-profiles)
- [Backend ratelimit service](#backend-ratelimit)
- [Consul key-value storage](#consul) - [Consul key-value storage](#consul)
- [Filtering](#filters) - [Filtering](#filters)
- [Blocked services](#filters-blocked-services) - [Blocked services](#filters-blocked-services)
@ -28,6 +30,15 @@ This service is disabled when all server groups have property [`profiles_enabled
[env-billstat_url]: environment.md#BILLSTAT_URL [env-billstat_url]: environment.md#BILLSTAT_URL
[conf-srvgrp-prof]: configuration.md#sg-*-profiles_enabled [conf-srvgrp-prof]: configuration.md#sg-*-profiles_enabled
## <a href="#backend-dnscheck" id="backend-dnscheck" name="backend-dnscheck">Backend DNSCheck service</a>
This is the service to which the [`DNSCHECK_REMOTEKV_URL`][env-dnscheck_remotekv_url] environment variable points. Supports gRPC(s) URLs. The service must correspond to `./internal/backendpb/dns.proto`.
This service is only enabled when the `check.kv` object has the [`type`][conf-check-kv-type] property set to `backend`.
[env-dnscheck_remotekv_url]: environment.md#DNSCHECK_REMOTEKV_URL
[conf-check-kv-type]: configuration.md#check-kv-type
## <a href="#backend-profiles" id="backend-profiles" name="backend-profiles">Backend profiles service</a> ## <a href="#backend-profiles" id="backend-profiles" name="backend-profiles">Backend profiles service</a>
This is the service to which the [`PROFILES_URL`][env-profiles_url] environment variable points. Supports gRPC(s) URLs. The service must correspond to `./internal/backendpb/dns.proto`. This is the service to which the [`PROFILES_URL`][env-profiles_url] environment variable points. Supports gRPC(s) URLs. The service must correspond to `./internal/backendpb/dns.proto`.
@ -36,6 +47,15 @@ This service is disabled when all server groups have property [`profiles_enabled
[env-profiles_url]: environment.md#PROFILES_URL [env-profiles_url]: environment.md#PROFILES_URL
## <a href="#backend-ratelimit" id="backend-ratelimit" name="backend-ratelimit">Backend ratelimit_service</a>
This is the service to which the [`BACKEND_RATELIMIT_URL`][env-backend_ratelimit_url] environment variable points. Supports gRPC(s) URLs. The service must correspond to `./internal/backendpb/dns.proto`.
This service is only enabled when the `ratelimit` object has the [`type`][conf-ratelimit-type] property set to `backend`.
[conf-ratelimit-type]: configuration.md#ratelimit-type
[env-backend_ratelimit_url]: environment.md#BACKEND_RATELIMIT_URL
## <a href="#consul" id="consul" name="consul">Consul key-value storage</a> ## <a href="#consul" id="consul" name="consul">Consul key-value storage</a>
A [Consul][consul-io] service can be used for the DNS server check and dynamic rate-limit allowlist features. Currently used endpoints can be seen in the documentation of the [`CONSUL_ALLOWLIST_URL`][env-consul-allowlist], [`CONSUL_DNSCHECK_KV_URL`][env-consul-dnscheck-kv], and [`CONSUL_DNSCHECK_SESSION_URL`][env-consul-dnscheck-session] environment variables. A [Consul][consul-io] service can be used for the DNS server check and dynamic rate-limit allowlist features. Currently used endpoints can be seen in the documentation of the [`CONSUL_ALLOWLIST_URL`][env-consul-allowlist], [`CONSUL_DNSCHECK_KV_URL`][env-consul-dnscheck-kv], and [`CONSUL_DNSCHECK_SESSION_URL`][env-consul-dnscheck-session] environment variables.

43
go.mod
View File

@ -1,34 +1,34 @@
module github.com/AdguardTeam/AdGuardDNS module github.com/AdguardTeam/AdGuardDNS
go 1.23.1 go 1.23.2
require ( require (
github.com/AdguardTeam/AdGuardDNS/internal/dnsserver v0.0.0-20240607112746-5690301129fe github.com/AdguardTeam/AdGuardDNS/internal/dnsserver v0.0.0-20240607112746-5690301129fe
github.com/AdguardTeam/golibs v0.28.0 github.com/AdguardTeam/golibs v0.30.1
github.com/AdguardTeam/urlfilter v0.19.0 github.com/AdguardTeam/urlfilter v0.20.0
github.com/ameshkov/dnscrypt/v2 v2.3.0 github.com/ameshkov/dnscrypt/v2 v2.3.0
github.com/axiomhq/hyperloglog v0.2.0 github.com/axiomhq/hyperloglog v0.2.0
github.com/bluele/gcache v0.0.2 github.com/bluele/gcache v0.0.2
github.com/c2h5oh/datasize v0.0.0-20231215233829-aa82cc1e6500 github.com/c2h5oh/datasize v0.0.0-20231215233829-aa82cc1e6500
github.com/caarlos0/env/v7 v7.1.0 github.com/caarlos0/env/v7 v7.1.0
github.com/getsentry/sentry-go v0.28.1 github.com/getsentry/sentry-go v0.29.1
github.com/gomodule/redigo v1.9.2 github.com/gomodule/redigo v1.9.2
github.com/google/renameio/v2 v2.0.0 github.com/google/renameio/v2 v2.0.0
github.com/miekg/dns v1.1.62 github.com/miekg/dns v1.1.62
github.com/oschwald/maxminddb-golang v1.13.1 github.com/oschwald/maxminddb-golang v1.13.1
github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible
github.com/prometheus/client_golang v1.20.1 github.com/prometheus/client_golang v1.20.5
github.com/prometheus/client_model v0.6.1 github.com/prometheus/client_model v0.6.1
github.com/prometheus/common v0.55.0 github.com/prometheus/common v0.60.0
github.com/quic-go/quic-go v0.47.0 github.com/quic-go/quic-go v0.48.1
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.27.0 golang.org/x/crypto v0.28.0
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
golang.org/x/net v0.29.0 golang.org/x/net v0.30.0
golang.org/x/sys v0.25.0 golang.org/x/sys v0.26.0
golang.org/x/time v0.6.0 golang.org/x/time v0.7.0
google.golang.org/grpc v1.65.0 google.golang.org/grpc v1.67.1
google.golang.org/protobuf v1.34.2 google.golang.org/protobuf v1.35.1
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0
) )
@ -41,24 +41,21 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140 // indirect github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140 // indirect
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
github.com/google/pprof v0.0.0-20240929191954-255acd752d31 // indirect github.com/google/pprof v0.0.0-20241023014458-598669927662 // indirect
github.com/klauspost/compress v1.17.9 // indirect github.com/klauspost/compress v1.17.11 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/onsi/ginkgo/v2 v2.20.2 // indirect github.com/onsi/ginkgo/v2 v2.20.2 // indirect
github.com/panjf2000/ants/v2 v2.10.0 // indirect github.com/panjf2000/ants/v2 v2.10.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect github.com/prometheus/procfs v0.15.1 // indirect
github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/qpack v0.5.1 // indirect
go.uber.org/mock v0.4.0 // indirect go.uber.org/mock v0.5.0 // indirect
golang.org/x/mod v0.21.0 // indirect golang.org/x/mod v0.21.0 // indirect
golang.org/x/sync v0.8.0 // indirect golang.org/x/sync v0.8.0 // indirect
golang.org/x/text v0.18.0 // indirect golang.org/x/text v0.19.0 // indirect
golang.org/x/tools v0.25.0 // indirect golang.org/x/tools v0.26.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240822170219-fc7c04adadcd // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )
replace github.com/AdguardTeam/AdGuardDNS/internal/dnsserver => ./internal/dnsserver replace github.com/AdguardTeam/AdGuardDNS/internal/dnsserver => ./internal/dnsserver
// TODO(a.garipov): Remove once https://github.com/quic-go/quic-go/pull/4685 is merged.
replace github.com/quic-go/quic-go => github.com/ainar-g/quic-go v0.0.0-20240930125330-446bd86056fd

76
go.sum
View File

@ -1,13 +1,11 @@
github.com/AdguardTeam/golibs v0.28.0 h1:SK1q8SqkkJ/61pp2abTmio90S4QpteYK9rtgROfnrb4= github.com/AdguardTeam/golibs v0.30.1 h1:/yv7dq2h7WXw/jTDxkE3FP9zHerRT+i03PZRHJX4fPU=
github.com/AdguardTeam/golibs v0.28.0/go.mod h1:iWdjXPCwmK2g2FKIb/OwEPnovSXeMqRhI8FWLxF5oxE= github.com/AdguardTeam/golibs v0.30.1/go.mod h1:FkwcNQEJoGsgDGXcalrVa/4gWbE68KsmE2guXWtBQUE=
github.com/AdguardTeam/urlfilter v0.19.0 h1:q7eH13+yNETlpD/VD3u5rLQOripcUdEktqZFy+KiQLk= github.com/AdguardTeam/urlfilter v0.20.0 h1:X32qiuVCVd8WDYCEsbdZKfXMzwdVqrdulamtUi4rmzs=
github.com/AdguardTeam/urlfilter v0.19.0/go.mod h1:+N54ZvxqXYLnXuvpaUhK2exDQW+djZBRSb6F6j0rkBY= github.com/AdguardTeam/urlfilter v0.20.0/go.mod h1:gjrywLTxfJh6JOkwi9SU+frhP7kVVEZ5exFGkR99qpk=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw= github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw=
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635/go.mod h1:lmLxL+FV291OopO93Bwf9fQLQeLyt33VJRUg5VJ30us= github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635/go.mod h1:lmLxL+FV291OopO93Bwf9fQLQeLyt33VJRUg5VJ30us=
github.com/ainar-g/quic-go v0.0.0-20240930125330-446bd86056fd h1:mw4LqrCiv3vcKuCxBRg7kA17xfHKM+9hZgFWmyhe/AY=
github.com/ainar-g/quic-go v0.0.0-20240930125330-446bd86056fd/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
github.com/ameshkov/dnscrypt/v2 v2.3.0 h1:pDXDF7eFa6Lw+04C0hoMh8kCAQM8NwUdFEllSP2zNLs= github.com/ameshkov/dnscrypt/v2 v2.3.0 h1:pDXDF7eFa6Lw+04C0hoMh8kCAQM8NwUdFEllSP2zNLs=
github.com/ameshkov/dnscrypt/v2 v2.3.0/go.mod h1:N5hDwgx2cNb4Ay7AhvOSKst+eUiOZ/vbKRO9qMpQttE= github.com/ameshkov/dnscrypt/v2 v2.3.0/go.mod h1:N5hDwgx2cNb4Ay7AhvOSKst+eUiOZ/vbKRO9qMpQttE=
github.com/ameshkov/dnsstamps v1.0.3 h1:Srzik+J9mivH1alRACTbys2xOxs0lRH9qnTA7Y1OYVo= github.com/ameshkov/dnsstamps v1.0.3 h1:Srzik+J9mivH1alRACTbys2xOxs0lRH9qnTA7Y1OYVo=
@ -29,8 +27,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140 h1:y7y0Oa6UawqTFPCDw9JG6pdKt4F9pAhHv0B7FMGaGD0= github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140 h1:y7y0Oa6UawqTFPCDw9JG6pdKt4F9pAhHv0B7FMGaGD0=
github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw= github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw=
github.com/getsentry/sentry-go v0.28.1 h1:zzaSm/vHmGllRM6Tpx1492r0YDzauArdBfkJRtY6P5k= github.com/getsentry/sentry-go v0.29.1 h1:DyZuChN8Hz3ARxGVV8ePaNXh1dQ7d76AiB117xcREwA=
github.com/getsentry/sentry-go v0.28.1/go.mod h1:1fQZ+7l7eeJ3wYi82q5Hg8GqAPgefRq+FP/QhafYVgg= github.com/getsentry/sentry-go v0.29.1/go.mod h1:x3AtIzN01d6SiWkderzaH28Tm0lgkafpJ5Bm3li39O0=
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
@ -43,12 +41,12 @@ github.com/gomodule/redigo v1.9.2 h1:HrutZBLhSIU8abiSfW8pj8mPhOyMYjZT/wcA4/L9L9s
github.com/gomodule/redigo v1.9.2/go.mod h1:KsU3hiK/Ay8U42qpaJk+kuNa3C+spxapWpM+ywhcgtw= github.com/gomodule/redigo v1.9.2/go.mod h1:KsU3hiK/Ay8U42qpaJk+kuNa3C+spxapWpM+ywhcgtw=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/pprof v0.0.0-20240929191954-255acd752d31 h1:LcRdQWywSgfi5jPsYZ1r2avbbs5IQ5wtyhMBCcokyo4= github.com/google/pprof v0.0.0-20241023014458-598669927662 h1:SKMkD83p7FwUqKmBsPdLHF5dNyxq3jOWwu9w9UyH5vA=
github.com/google/pprof v0.0.0-20240929191954-255acd752d31/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/google/pprof v0.0.0-20241023014458-598669927662/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg= github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg=
github.com/google/renameio/v2 v2.0.0/go.mod h1:BtmJXm5YlszgC+TD4HOEEUFgkJP3nLxehU6hfe7jRt4= github.com/google/renameio/v2 v2.0.0/go.mod h1:BtmJXm5YlszgC+TD4HOEEUFgkJP3nLxehU6hfe7jRt4=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
@ -79,16 +77,18 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/prometheus/client_golang v1.20.1 h1:IMJXHOD6eARkQpxo8KkhgEVFlBNm+nkrFUyGlIu7Na8= github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y=
github.com/prometheus/client_golang v1.20.1/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc= github.com/prometheus/common v0.60.0 h1:+V9PAREWNvJMAuJ1x1BaWl9dewMW4YrHZQbx0sJNllA=
github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= github.com/prometheus/common v0.60.0/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
github.com/quic-go/quic-go v0.48.1 h1:y/8xmfWI9qmGTc+lBr4jKRUWLGSlSigv847ULJ4hYXA=
github.com/quic-go/quic-go v0.48.1/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI= github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI=
@ -109,33 +109,33 @@ github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+F
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk= golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY= golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE= golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ=
golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg= golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240822170219-fc7c04adadcd h1:6TEm2ZxXoQmFWFlt1vNxvVOa1Q0dXFQD1m/rYjXmS0E= google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 h1:zciRKQ4kBpFgpfC5QQCVtnnNAcLIqweL7plyZRQHVpI=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240822170219-fc7c04adadcd/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI=
google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc= google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E=
google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ= google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

View File

@ -1,4 +1,4 @@
go 1.23.1 go 1.23.2
use ( use (
. .

View File

@ -1,5 +1,6 @@
cel.dev/expr v0.15.0 h1:O1jzfJCQBfL5BFoYktaxwIhuttaQPsVWerH9/EEKx0w= cel.dev/expr v0.15.0 h1:O1jzfJCQBfL5BFoYktaxwIhuttaQPsVWerH9/EEKx0w=
cel.dev/expr v0.15.0/go.mod h1:TRSuuV7DlVCE/uwv5QbAiW/v8l5O8C4eEPHeu7gf7Sg= cel.dev/expr v0.15.0/go.mod h1:TRSuuV7DlVCE/uwv5QbAiW/v8l5O8C4eEPHeu7gf7Sg=
cel.dev/expr v0.16.0/go.mod h1:TRSuuV7DlVCE/uwv5QbAiW/v8l5O8C4eEPHeu7gf7Sg=
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
@ -45,6 +46,7 @@ cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGB
cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA=
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY=
cloud.google.com/go/contactcenterinsights v1.13.0/go.mod h1:ieq5d5EtHsu8vhe2y3amtZ+BE+AQwX5qAy7cpo0POsI= cloud.google.com/go/contactcenterinsights v1.13.0/go.mod h1:ieq5d5EtHsu8vhe2y3amtZ+BE+AQwX5qAy7cpo0POsI=
cloud.google.com/go/container v1.31.0/go.mod h1:7yABn5s3Iv3lmw7oMmyGbeV6tQj86njcTijkkGuvdZA= cloud.google.com/go/container v1.31.0/go.mod h1:7yABn5s3Iv3lmw7oMmyGbeV6tQj86njcTijkkGuvdZA=
cloud.google.com/go/containeranalysis v0.11.4/go.mod h1:cVZT7rXYBS9NG1rhQbWL9pWbXCKHWJPYraE8/FTSYPE= cloud.google.com/go/containeranalysis v0.11.4/go.mod h1:cVZT7rXYBS9NG1rhQbWL9pWbXCKHWJPYraE8/FTSYPE=
@ -156,6 +158,9 @@ github.com/AdguardTeam/golibs v0.19.0/go.mod h1:3WunclLLfrVAq7fYQRhd6f168FHOEMss
github.com/AdguardTeam/golibs v0.25.2/go.mod h1:HaTyS2wCbxFudjht9N/+/Qf1b5cMad2BAYSwe7DPCXI= github.com/AdguardTeam/golibs v0.25.2/go.mod h1:HaTyS2wCbxFudjht9N/+/Qf1b5cMad2BAYSwe7DPCXI=
github.com/AdguardTeam/golibs v0.25.3 h1:A06JZGSuAhAC0uq/s7IlNsv/V8TyNJfLalB0vhkd1vA= github.com/AdguardTeam/golibs v0.25.3 h1:A06JZGSuAhAC0uq/s7IlNsv/V8TyNJfLalB0vhkd1vA=
github.com/AdguardTeam/golibs v0.25.3/go.mod h1:HaTyS2wCbxFudjht9N/+/Qf1b5cMad2BAYSwe7DPCXI= github.com/AdguardTeam/golibs v0.25.3/go.mod h1:HaTyS2wCbxFudjht9N/+/Qf1b5cMad2BAYSwe7DPCXI=
github.com/AdguardTeam/golibs v0.30.0/go.mod h1:vjw1OVZG6BYyoqGRY88U4LCJLOMfhBFhU0UJBdaSAuQ=
github.com/AdguardTeam/golibs v0.30.1 h1:/yv7dq2h7WXw/jTDxkE3FP9zHerRT+i03PZRHJX4fPU=
github.com/AdguardTeam/golibs v0.30.1/go.mod h1:FkwcNQEJoGsgDGXcalrVa/4gWbE68KsmE2guXWtBQUE=
github.com/AdguardTeam/gomitmproxy v0.2.0 h1:rvCOf17pd1/CnMyMQW891zrEiIQBpQ8cIGjKN9pinUU= github.com/AdguardTeam/gomitmproxy v0.2.0 h1:rvCOf17pd1/CnMyMQW891zrEiIQBpQ8cIGjKN9pinUU=
github.com/AdguardTeam/gomitmproxy v0.2.1 h1:p9gr8Er1TYvf+7ic81Ax1sZ62UNCsMTZNbm7tC59S9o= github.com/AdguardTeam/gomitmproxy v0.2.1 h1:p9gr8Er1TYvf+7ic81Ax1sZ62UNCsMTZNbm7tC59S9o=
github.com/AdguardTeam/gomitmproxy v0.2.1/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU= github.com/AdguardTeam/gomitmproxy v0.2.1/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
@ -245,6 +250,7 @@ github.com/cncf/xds/go v0.0.0-20240318125728-8a4994d93e50 h1:DBmgJDC9dTfkVyGgipa
github.com/cncf/xds/go v0.0.0-20240318125728-8a4994d93e50/go.mod h1:5e1+Vvlzido69INQaVO6d87Qn543Xr6nooe9Kz7oBFM= github.com/cncf/xds/go v0.0.0-20240318125728-8a4994d93e50/go.mod h1:5e1+Vvlzido69INQaVO6d87Qn543Xr6nooe9Kz7oBFM=
github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b h1:ga8SEFjZ60pxLcmhnThWgvH2wg8376yUJmPhEH4H3kw= github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b h1:ga8SEFjZ60pxLcmhnThWgvH2wg8376yUJmPhEH4H3kw=
github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8=
github.com/cncf/xds/go v0.0.0-20240723142845-024c85f92f20/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8=
github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0 h1:sDMmm+q/3+BukdIpxwO365v/Rbspp2Nt5XntgQRXq8Q= github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0 h1:sDMmm+q/3+BukdIpxwO365v/Rbspp2Nt5XntgQRXq8Q=
github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0/go.mod h1:4Zcjuz89kmFXt9morQgcfYZAYZ5n8WHjt81YYWIwtTM= github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0/go.mod h1:4Zcjuz89kmFXt9morQgcfYZAYZ5n8WHjt81YYWIwtTM=
github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d h1:t5Wuyh53qYyg9eqn4BbnlIT+vmhyww0TatL+zT3uWgI= github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d h1:t5Wuyh53qYyg9eqn4BbnlIT+vmhyww0TatL+zT3uWgI=
@ -262,12 +268,14 @@ github.com/envoyproxy/go-control-plane v0.11.1 h1:wSUXTlLfiAQRWs2F+p+EKOY9rUyis1
github.com/envoyproxy/go-control-plane v0.11.1/go.mod h1:uhMcXKCQMEJHiAb0w+YGefQLaTEw+YhGluxZkrTmD0g= github.com/envoyproxy/go-control-plane v0.11.1/go.mod h1:uhMcXKCQMEJHiAb0w+YGefQLaTEw+YhGluxZkrTmD0g=
github.com/envoyproxy/go-control-plane v0.12.0 h1:4X+VP1GHd1Mhj6IB5mMeGbLCleqxjletLK6K0rbxyZI= github.com/envoyproxy/go-control-plane v0.12.0 h1:4X+VP1GHd1Mhj6IB5mMeGbLCleqxjletLK6K0rbxyZI=
github.com/envoyproxy/go-control-plane v0.12.0/go.mod h1:ZBTaoJ23lqITozF0M6G4/IragXCQKCnYbmlmtHvwRG0= github.com/envoyproxy/go-control-plane v0.12.0/go.mod h1:ZBTaoJ23lqITozF0M6G4/IragXCQKCnYbmlmtHvwRG0=
github.com/envoyproxy/go-control-plane v0.13.0/go.mod h1:GRaKG3dwvFoTg4nj7aXdZnvMg4d7nvT/wl9WgVXn3Q8=
github.com/envoyproxy/protoc-gen-validate v0.1.0 h1:EQciDnbrYxy13PgWoY8AqoxGiPrpgBZ1R8UNe3ddc+A= github.com/envoyproxy/protoc-gen-validate v0.1.0 h1:EQciDnbrYxy13PgWoY8AqoxGiPrpgBZ1R8UNe3ddc+A=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/envoyproxy/protoc-gen-validate v1.0.2 h1:QkIBuU5k+x7/QXPvPPnWXWlCdaBFApVqftFV6k087DA= github.com/envoyproxy/protoc-gen-validate v1.0.2 h1:QkIBuU5k+x7/QXPvPPnWXWlCdaBFApVqftFV6k087DA=
github.com/envoyproxy/protoc-gen-validate v1.0.2/go.mod h1:GpiZQP3dDbg4JouG/NNS7QWXpgx6x8QiMKdmN72jogE= github.com/envoyproxy/protoc-gen-validate v1.0.2/go.mod h1:GpiZQP3dDbg4JouG/NNS7QWXpgx6x8QiMKdmN72jogE=
github.com/envoyproxy/protoc-gen-validate v1.0.4 h1:gVPz/FMfvh57HdSJQyvBtF00j8JU4zdyUgIUNhlgg0A= github.com/envoyproxy/protoc-gen-validate v1.0.4 h1:gVPz/FMfvh57HdSJQyvBtF00j8JU4zdyUgIUNhlgg0A=
github.com/envoyproxy/protoc-gen-validate v1.0.4/go.mod h1:qys6tmnRsYrQqIhm2bvKZH4Blx/1gTIZ2UKVY1M+Yew= github.com/envoyproxy/protoc-gen-validate v1.0.4/go.mod h1:qys6tmnRsYrQqIhm2bvKZH4Blx/1gTIZ2UKVY1M+Yew=
github.com/envoyproxy/protoc-gen-validate v1.1.0/go.mod h1:sXRDRVmzEbkM7CVcM06s9shE/m23dg3wzjl0UWqJ2q4=
github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo=
github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
@ -338,6 +346,7 @@ github.com/golang/glog v1.2.0 h1:uCdmnmatrKCgMBlM4rMuJZWOkPDqdbZPnrMXDY4gI68=
github.com/golang/glog v1.2.0/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= github.com/golang/glog v1.2.0/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
github.com/golang/glog v1.2.1 h1:OptwRhECazUx5ix5TTWC3EZhsZEHWcYWY4FQHTIubm4= github.com/golang/glog v1.2.1 h1:OptwRhECazUx5ix5TTWC3EZhsZEHWcYWY4FQHTIubm4=
github.com/golang/glog v1.2.1/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= github.com/golang/glog v1.2.1/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
github.com/golang/glog v1.2.2/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
@ -576,6 +585,7 @@ github.com/pelletier/go-toml/v2 v2.0.5 h1:ipoSadvV8oGUjnUbMub59IDPPwfxF694nG/jwb
github.com/pelletier/go-toml/v2 v2.0.5/go.mod h1:OMHamSCAODeSsVrwwvcJOaoN0LIUIaFVNZzmWyNfXas= github.com/pelletier/go-toml/v2 v2.0.5/go.mod h1:OMHamSCAODeSsVrwwvcJOaoN0LIUIaFVNZzmWyNfXas=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8=
github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k= github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k=
github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho= github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho=
@ -677,6 +687,7 @@ github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
github.com/urfave/negroni v1.0.0 h1:kIimOitoypq34K7TG7DUaJ9kq/N4Ofuwi1sjz0KipXc= github.com/urfave/negroni v1.0.0 h1:kIimOitoypq34K7TG7DUaJ9kq/N4Ofuwi1sjz0KipXc=
github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4=
github.com/urfave/negroni/v3 v3.1.1/go.mod h1:jWvnX03kcSjDBl/ShB0iHvx5uOs7mAzZXW+JvJ5XYAs=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.40.0 h1:CRq/00MfruPGFLTQKY8b+8SfdK60TxNztjRMnH0t1Yc= github.com/valyala/fasthttp v1.40.0 h1:CRq/00MfruPGFLTQKY8b+8SfdK60TxNztjRMnH0t1Yc=
@ -828,6 +839,7 @@ golang.org/x/oauth2 v0.19.0 h1:9+E/EZBCbTLNrbN35fHv/a/d/mOBatymz1zbtQrXpIg=
golang.org/x/oauth2 v0.19.0/go.mod h1:vYi7skDa1x015PmRRYZ7+s1cWyPgrPiSYRe4rnsexc8= golang.org/x/oauth2 v0.19.0/go.mod h1:vYi7skDa1x015PmRRYZ7+s1cWyPgrPiSYRe4rnsexc8=
golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs=
golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852 h1:xYq6+9AtI+xP3M4r0N1hCkHrInHDBohhquRgx9Kk6gI= golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852 h1:xYq6+9AtI+xP3M4r0N1hCkHrInHDBohhquRgx9Kk6gI=
golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -912,6 +924,7 @@ golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4=
golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM= golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM=
golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8= golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8=
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
@ -1000,6 +1013,7 @@ google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 h1:
google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237/go.mod h1:Z5Iiy3jtmioajWHDGFk7CeugTyHtPvMHA4UTmUkyalE= google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237/go.mod h1:Z5Iiy3jtmioajWHDGFk7CeugTyHtPvMHA4UTmUkyalE=
google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157 h1:7whR9kGa5LUwFtpLm2ArCEejtnxlGeLbAyjFY8sGNFw= google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157 h1:7whR9kGa5LUwFtpLm2ArCEejtnxlGeLbAyjFY8sGNFw=
google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157/go.mod h1:99sLkeliLXfdj2J75X3Ho+rrVCaJze0uwN7zDDkjPVU= google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157/go.mod h1:99sLkeliLXfdj2J75X3Ho+rrVCaJze0uwN7zDDkjPVU=
google.golang.org/genproto/googleapis/api v0.0.0-20240814211410-ddb44dafa142/go.mod h1:d6be+8HhtEtucleCbxpPW9PA9XwISACu8nvpPqF0BVo=
google.golang.org/genproto/googleapis/bytestream v0.0.0-20240304161311-37d4d3c04a78/go.mod h1:vh/N7795ftP0AkN1w8XKqN4w1OdUKXW5Eummda+ofv8= google.golang.org/genproto/googleapis/bytestream v0.0.0-20240304161311-37d4d3c04a78/go.mod h1:vh/N7795ftP0AkN1w8XKqN4w1OdUKXW5Eummda+ofv8=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240304161311-37d4d3c04a78/go.mod h1:UCOku4NytXMJuLQE5VuqA5lX3PcHCBo8pxNyvkf4xBs= google.golang.org/genproto/googleapis/rpc v0.0.0-20240304161311-37d4d3c04a78/go.mod h1:UCOku4NytXMJuLQE5VuqA5lX3PcHCBo8pxNyvkf4xBs=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240314234333-6e1732d8331c/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY= google.golang.org/genproto/googleapis/rpc v0.0.0-20240314234333-6e1732d8331c/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY=

View File

@ -10,6 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/netutil/urlutil"
) )
// Client is a wrapper around http.Client. // Client is a wrapper around http.Client.
@ -93,6 +94,7 @@ func (c *Client) do(
req.Header.Set(httphdr.UserAgent, c.userAgent) req.Header.Set(httphdr.UserAgent, c.userAgent)
resp, err = c.http.Do(req) resp, err = c.http.Do(req)
urlutil.RedactUserinfoInURLError(u, err)
if err != nil && resp != nil && resp.Header != nil { if err != nil && resp != nil && resp.Header != nil {
// A non-nil Response with a non-nil error only occurs when // A non-nil Response with a non-nil error only occurs when
// CheckRedirect fails. // CheckRedirect fails.

View File

@ -5,51 +5,13 @@ import (
"net/url" "net/url"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil/urlutil"
) )
// Known scheme constants.
//
// TODO(a.garipov): Move to agdurl or golibs.
//
// TODO(a.garipov): Use more.
const (
SchemeFile = "file"
SchemeGRPC = "grpc"
SchemeGRPCS = "grpcs"
SchemeHTTP = "http"
SchemeHTTPS = "https"
)
// CheckGRPCURLScheme returns true if s is a valid gRPC URL scheme. That is,
// [SchemeGRPC] or [SchemeGRPCS]
//
// TODO(a.garipov): Move to golibs?
func CheckGRPCURLScheme(s string) (ok bool) {
switch s {
case SchemeGRPC, SchemeGRPCS:
return true
default:
return false
}
}
// CheckHTTPURLScheme returns true if s is a valid HTTP URL scheme. That is,
// [SchemeHTTP] or [SchemeHTTPS]
//
// TODO(a.garipov): Move to golibs?
func CheckHTTPURLScheme(s string) (ok bool) {
switch s {
case SchemeHTTP, SchemeHTTPS:
return true
default:
return false
}
}
// ParseHTTPURL parses an absolute URL and makes sure that it is a valid HTTP(S) // ParseHTTPURL parses an absolute URL and makes sure that it is a valid HTTP(S)
// URL. All returned errors will have the underlying type [*url.Error]. // URL. All returned errors will have the underlying type [*url.Error].
// //
// TODO(a.garipov): Define as a type? // TODO(a.garipov): Define as a type?
func ParseHTTPURL(s string) (u *url.URL, err error) { func ParseHTTPURL(s string) (u *url.URL, err error) {
u, err = url.Parse(s) u, err = url.Parse(s)
if err != nil { if err != nil {
@ -63,7 +25,7 @@ func ParseHTTPURL(s string) (u *url.URL, err error) {
URL: s, URL: s,
Err: errors.Error("empty host"), Err: errors.Error("empty host"),
} }
case !CheckHTTPURLScheme(u.Scheme): case !urlutil.IsValidHTTPURLScheme(u.Scheme):
return nil, &url.Error{ return nil, &url.Error{
Op: "parse", Op: "parse",
URL: s, URL: s,

View File

@ -6,12 +6,19 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
// Common user credentials for tests.
const (
testUsername = "user"
testPassword = "pass"
)
func TestParseHTTPURL(t *testing.T) { func TestParseHTTPURL(t *testing.T) {
goodURL := testURL() goodURL := testURL(url.UserPassword(testUsername, testPassword))
badSchemeURL := netutil.CloneURL(goodURL) badSchemeURL := netutil.CloneURL(goodURL)
badSchemeURL.Scheme = "ftp" badSchemeURL.Scheme = "ftp"
@ -61,10 +68,11 @@ func TestParseHTTPURL(t *testing.T) {
} }
} }
func testURL() (u *url.URL) { // testURL is a helper function that returns an url with dummy values.
func testURL(info *url.Userinfo) (u *url.URL) {
return &url.URL{ return &url.URL{
Scheme: agdhttp.SchemeHTTP, Scheme: urlutil.SchemeHTTP,
User: url.UserPassword("user", "pass"), User: info,
Host: "example.com", Host: "example.com",
Path: "/a/b/c/", Path: "/a/b/c/",
RawQuery: "d=e", RawQuery: "d=e",

View File

@ -3,6 +3,7 @@
package agdtest package agdtest
import ( import (
"net/url"
"testing" "testing"
"time" "time"
@ -18,7 +19,7 @@ const FilteredResponseTTL = FilteredResponseTTLSec * time.Second
// number to simplify message creation. // number to simplify message creation.
const FilteredResponseTTLSec = 10 const FilteredResponseTTLSec = 10
// NewConstructorWithTTL returns a standard dnsmsg.Constructor for tests, using // NewConstructorWithTTL returns a standard *dnsmsg.Constructor for tests, using
// ttl as the TTL for filtered responses. // ttl as the TTL for filtered responses.
func NewConstructorWithTTL(tb testing.TB, ttl time.Duration) (c *dnsmsg.Constructor) { func NewConstructorWithTTL(tb testing.TB, ttl time.Duration) (c *dnsmsg.Constructor) {
tb.Helper() tb.Helper()
@ -26,28 +27,57 @@ func NewConstructorWithTTL(tb testing.TB, ttl time.Duration) (c *dnsmsg.Construc
c, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{ c, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: NewCloner(), Cloner: NewCloner(),
BlockingMode: &dnsmsg.BlockingModeNullIP{}, BlockingMode: &dnsmsg.BlockingModeNullIP{},
StructuredErrors: NewSDEConfig(true),
FilteredResponseTTL: ttl, FilteredResponseTTL: ttl,
EDEEnabled: true,
}) })
require.NoError(tb, err) require.NoError(tb, err)
return c return c
} }
// NewConstructor returns a standard dnsmsg.Constructor for tests, using // NewConstructor returns a standard *dnsmsg.Constructor for tests, using
// [FilteredResponseTTL] as the TTL for filtered responses. // [FilteredResponseTTL] as the TTL for filtered responses. The returned
// constructor also has the Structured DNS Errors feature enabled.
func NewConstructor(tb testing.TB) (c *dnsmsg.Constructor) { func NewConstructor(tb testing.TB) (c *dnsmsg.Constructor) {
tb.Helper() tb.Helper()
c, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{ c, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: NewCloner(), Cloner: NewCloner(),
BlockingMode: &dnsmsg.BlockingModeNullIP{}, BlockingMode: &dnsmsg.BlockingModeNullIP{},
StructuredErrors: NewSDEConfig(true),
FilteredResponseTTL: FilteredResponseTTL, FilteredResponseTTL: FilteredResponseTTL,
EDEEnabled: true,
}) })
require.NoError(tb, err) require.NoError(tb, err)
return c return c
} }
// SDEText is a test Structured DNS Error text.
//
// NOTE: Keep in sync with [NewSDEConfig].
//
// TODO(e.burkov): Add some helper when this message becomes configurable.
const SDEText = `{` +
`"j":"Filtering",` +
`"o":"Test Org",` +
`"c":["mailto:support@dns.example"]` +
`}`
// NewSDEConfig returns a standard *dnsmsg.StructuredDNSErrorsConfig for tests.
func NewSDEConfig(enabled bool) (c *dnsmsg.StructuredDNSErrorsConfig) {
return &dnsmsg.StructuredDNSErrorsConfig{
Contact: []*url.URL{{
Scheme: "mailto",
Opaque: "support@dns.example",
}},
Justification: "Filtering",
Organization: "Test Org",
Enabled: enabled,
}
}
// NewCloner returns a standard dnsmsg.Cloner for tests. // NewCloner returns a standard dnsmsg.Cloner for tests.
func NewCloner() (c *dnsmsg.Cloner) { func NewCloner() (c *dnsmsg.Cloner) {
return dnsmsg.NewCloner(dnsmsg.EmptyClonerStat{}) return dnsmsg.NewCloner(dnsmsg.EmptyClonerStat{})

View File

@ -116,7 +116,7 @@ func (r *Refresher) Refresh(ctx context.Context) (err error) {
// type check // type check
var _ billstat.Recorder = (*BillStatRecorder)(nil) var _ billstat.Recorder = (*BillStatRecorder)(nil)
// BillStatRecorder is a billstat.Recorder for tests. // BillStatRecorder is a [billstat.Recorder] for tests.
type BillStatRecorder struct { type BillStatRecorder struct {
OnRecord func( OnRecord func(
ctx context.Context, ctx context.Context,
@ -128,7 +128,7 @@ type BillStatRecorder struct {
) )
} }
// Record implements the billstat.Recorder interface for *BillStatRecorder. // Record implements the [billstat.Recorder] interface for *BillStatRecorder.
func (r *BillStatRecorder) Record( func (r *BillStatRecorder) Record(
ctx context.Context, ctx context.Context,
id agd.DeviceID, id agd.DeviceID,
@ -143,12 +143,12 @@ func (r *BillStatRecorder) Record(
// type check // type check
var _ billstat.Uploader = (*BillStatUploader)(nil) var _ billstat.Uploader = (*BillStatUploader)(nil)
// BillStatUploader is a billstat.Uploader for tests. // BillStatUploader is a [billstat.Uploader] for tests.
type BillStatUploader struct { type BillStatUploader struct {
OnUpload func(ctx context.Context, records billstat.Records) (err error) OnUpload func(ctx context.Context, records billstat.Records) (err error)
} }
// Upload implements the billstat.Uploader interface for *BillStatUploader. // Upload implements the [billstat.Uploader] interface for *BillStatUploader.
func (b *BillStatUploader) Upload(ctx context.Context, records billstat.Records) (err error) { func (b *BillStatUploader) Upload(ctx context.Context, records billstat.Records) (err error) {
return b.OnUpload(ctx, records) return b.OnUpload(ctx, records)
} }
@ -158,7 +158,7 @@ func (b *BillStatUploader) Upload(ctx context.Context, records billstat.Records)
// type check // type check
var _ dnscheck.Interface = (*DNSCheck)(nil) var _ dnscheck.Interface = (*DNSCheck)(nil)
// DNSCheck is a dnscheck.Interface for tests. // DNSCheck is a [dnscheck.Interface] for tests.
type DNSCheck struct { type DNSCheck struct {
OnCheck func(ctx context.Context, req *dns.Msg, ri *agd.RequestInfo) (reqp *dns.Msg, err error) OnCheck func(ctx context.Context, req *dns.Msg, ri *agd.RequestInfo) (reqp *dns.Msg, err error)
} }
@ -177,12 +177,12 @@ func (db *DNSCheck) Check(
// type check // type check
var _ dnsdb.Interface = (*DNSDB)(nil) var _ dnsdb.Interface = (*DNSDB)(nil)
// DNSDB is a dnsdb.Interface for tests. // DNSDB is a [dnsdb.Interface] for tests.
type DNSDB struct { type DNSDB struct {
OnRecord func(ctx context.Context, resp *dns.Msg, ri *agd.RequestInfo) OnRecord func(ctx context.Context, resp *dns.Msg, ri *agd.RequestInfo)
} }
// Record implements the dnsdb.Interface interface for *DNSDB. // Record implements the [dnsdb.Interface] interface for *DNSDB.
func (db *DNSDB) Record(ctx context.Context, resp *dns.Msg, ri *agd.RequestInfo) { func (db *DNSDB) Record(ctx context.Context, resp *dns.Msg, ri *agd.RequestInfo) {
db.OnRecord(ctx, resp, ri) db.OnRecord(ctx, resp, ri)
} }
@ -204,7 +204,7 @@ func (c *ErrorCollector) Collect(ctx context.Context, err error) {
c.OnCollect(ctx, err) c.OnCollect(ctx, err)
} }
// NewErrorCollector returns a new [ErrorCollector] all methods of which panic. // NewErrorCollector returns a new *ErrorCollector all methods of which panic.
func NewErrorCollector() (c *ErrorCollector) { func NewErrorCollector() (c *ErrorCollector) {
return &ErrorCollector{ return &ErrorCollector{
OnCollect: func(_ context.Context, err error) { OnCollect: func(_ context.Context, err error) {
@ -297,21 +297,38 @@ func (s *FilterStorage) HasListID(id agd.FilterListID) (ok bool) {
// type check // type check
var _ geoip.Interface = (*GeoIP)(nil) var _ geoip.Interface = (*GeoIP)(nil)
// GeoIP is a geoip.Interface for tests. // GeoIP is a [geoip.Interface] for tests.
type GeoIP struct { type GeoIP struct {
OnSubnetByLocation func(l *geoip.Location, fam netutil.AddrFamily) (n netip.Prefix, err error)
OnData func(host string, ip netip.Addr) (l *geoip.Location, err error) OnData func(host string, ip netip.Addr) (l *geoip.Location, err error)
OnSubnetByLocation func(l *geoip.Location, fam netutil.AddrFamily) (n netip.Prefix, err error)
} }
// SubnetByLocation implements the geoip.Interface interface for *GeoIP. // Data implements the [geoip.Interface] interface for *GeoIP.
func (g *GeoIP) SubnetByLocation(l *geoip.Location, fam netutil.AddrFamily, func (g *GeoIP) Data(host string, ip netip.Addr) (l *geoip.Location, err error) {
return g.OnData(host, ip)
}
// SubnetByLocation implements the [geoip.Interface] interface for *GeoIP.
func (g *GeoIP) SubnetByLocation(
l *geoip.Location,
fam netutil.AddrFamily,
) (n netip.Prefix, err error) { ) (n netip.Prefix, err error) {
return g.OnSubnetByLocation(l, fam) return g.OnSubnetByLocation(l, fam)
} }
// Data implements the geoip.Interface interface for *GeoIP. // NewGeoIP returns a new *GeoIP all methods of which panic.
func (g *GeoIP) Data(host string, ip netip.Addr) (l *geoip.Location, err error) { func NewGeoIP() (c *GeoIP) {
return g.OnData(host, ip) return &GeoIP{
OnData: func(host string, ip netip.Addr) (l *geoip.Location, err error) {
panic(fmt.Errorf("unexpected call to GeoIP.Data(%v, %v)", host, ip))
},
OnSubnetByLocation: func(
l *geoip.Location,
fam netutil.AddrFamily,
) (n netip.Prefix, err error) {
panic(fmt.Errorf("unexpected call to GeoIP.SubnetByLocation(%v, %v)", l, fam))
},
}
} }
// Package profiledb // Package profiledb
@ -398,6 +415,58 @@ func (db *ProfileDB) ProfileByLinkedIP(
return db.OnProfileByLinkedIP(ctx, ip) return db.OnProfileByLinkedIP(ctx, ip)
} }
// NewProfileDB returns a new *ProfileDB all methods of which panic.
func NewProfileDB() (db *ProfileDB) {
return &ProfileDB{
OnCreateAutoDevice: func(
_ context.Context,
id agd.ProfileID,
humanID agd.HumanID,
devType agd.DeviceType,
) (p *agd.Profile, d *agd.Device, err error) {
panic(fmt.Errorf(
"unexpected call to ProfileDB.CreateAutoDevice(%v, %v, %v)",
id,
humanID,
devType,
))
},
OnProfileByDedicatedIP: func(
_ context.Context,
ip netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic(fmt.Errorf("unexpected call to ProfileDB.ProfileByDedicatedIP(%v)", ip))
},
OnProfileByDeviceID: func(
_ context.Context,
id agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
panic(fmt.Errorf("unexpected call to ProfileDB.ProfileByDeviceID(%v)", id))
},
OnProfileByHumanID: func(
_ context.Context,
profID agd.ProfileID,
humanID agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) {
panic(fmt.Errorf(
"unexpected call to ProfileDB.ProfileByHumanID(%v, %v)",
profID,
humanID,
))
},
OnProfileByLinkedIP: func(
_ context.Context,
ip netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic(fmt.Errorf("unexpected call to ProfileDB.ProfileByLinkedIP(%v)", ip))
},
}
}
// type check // type check
var _ profiledb.Storage = (*ProfileStorage)(nil) var _ profiledb.Storage = (*ProfileStorage)(nil)
@ -436,12 +505,12 @@ func (s *ProfileStorage) Profiles(
// type check // type check
var _ querylog.Interface = (*QueryLog)(nil) var _ querylog.Interface = (*QueryLog)(nil)
// QueryLog is a querylog.Interface for tests. // QueryLog is a [querylog.Interface] for tests.
type QueryLog struct { type QueryLog struct {
OnWrite func(ctx context.Context, e *querylog.Entry) (err error) OnWrite func(ctx context.Context, e *querylog.Entry) (err error)
} }
// Write implements the querylog.Interface interface for *QueryLog. // Write implements the [querylog.Interface] interface for *QueryLog.
func (ql *QueryLog) Write(ctx context.Context, e *querylog.Entry) (err error) { func (ql *QueryLog) Write(ctx context.Context, e *querylog.Entry) (err error) {
return ql.OnWrite(ctx, e) return ql.OnWrite(ctx, e)
} }
@ -451,12 +520,12 @@ func (ql *QueryLog) Write(ctx context.Context, e *querylog.Entry) (err error) {
// type check // type check
var _ rulestat.Interface = (*RuleStat)(nil) var _ rulestat.Interface = (*RuleStat)(nil)
// RuleStat is a rulestat.Interface for tests. // RuleStat is a [rulestat.Interface] for tests.
type RuleStat struct { type RuleStat struct {
OnCollect func(ctx context.Context, id agd.FilterListID, text agd.FilterRuleText) OnCollect func(ctx context.Context, id agd.FilterListID, text agd.FilterRuleText)
} }
// Collect implements the rulestat.Interface interface for *RuleStat. // Collect implements the [rulestat.Interface] interface for *RuleStat.
func (s *RuleStat) Collect(ctx context.Context, id agd.FilterListID, text agd.FilterRuleText) { func (s *RuleStat) Collect(ctx context.Context, id agd.FilterListID, text agd.FilterRuleText) {
s.OnCollect(ctx, id, text) s.OnCollect(ctx, id, text)
} }
@ -501,7 +570,7 @@ func (c *ListenConfig) ListenPacket(
// type check // type check
var _ ratelimit.Interface = (*RateLimit)(nil) var _ ratelimit.Interface = (*RateLimit)(nil)
// RateLimit is a ratelimit.Interface for tests. // RateLimit is a [ratelimit.Interface] for tests.
type RateLimit struct { type RateLimit struct {
OnIsRateLimited func( OnIsRateLimited func(
ctx context.Context, ctx context.Context,
@ -511,7 +580,7 @@ type RateLimit struct {
OnCountResponses func(ctx context.Context, resp *dns.Msg, ip netip.Addr) OnCountResponses func(ctx context.Context, resp *dns.Msg, ip netip.Addr)
} }
// IsRateLimited implements the ratelimit.Interface interface for *RateLimit. // IsRateLimited implements the [ratelimit.Interface] interface for *RateLimit.
func (l *RateLimit) IsRateLimited( func (l *RateLimit) IsRateLimited(
ctx context.Context, ctx context.Context,
req *dns.Msg, req *dns.Msg,
@ -520,12 +589,27 @@ func (l *RateLimit) IsRateLimited(
return l.OnIsRateLimited(ctx, req, ip) return l.OnIsRateLimited(ctx, req, ip)
} }
// CountResponses implements the ratelimit.Interface interface for // CountResponses implements the [ratelimit.Interface] interface for *RateLimit.
// *RateLimit.
func (l *RateLimit) CountResponses(ctx context.Context, req *dns.Msg, ip netip.Addr) { func (l *RateLimit) CountResponses(ctx context.Context, req *dns.Msg, ip netip.Addr) {
l.OnCountResponses(ctx, req, ip) l.OnCountResponses(ctx, req, ip)
} }
// NewRateLimit returns a new *RateLimit all methods of which panic.
func NewRateLimit() (c *RateLimit) {
return &RateLimit{
OnIsRateLimited: func(
_ context.Context,
req *dns.Msg,
addr netip.Addr,
) (shouldDrop, isAllowlisted bool, err error) {
panic(fmt.Errorf("unexpected call to RateLimit.IsRateLimited(%v, %v)", req, addr))
},
OnCountResponses: func(_ context.Context, resp *dns.Msg, addr netip.Addr) {
panic(fmt.Errorf("unexpected call to RateLimit.CountResponses(%v, %v)", resp, addr))
},
}
}
// RemoteKV is an [remotekv.Interface] implementation for tests. // RemoteKV is an [remotekv.Interface] implementation for tests.
type RemoteKV struct { type RemoteKV struct {
OnGet func(ctx context.Context, key string) (val []byte, ok bool, err error) OnGet func(ctx context.Context, key string) (val []byte, ok bool, err error)

View File

@ -16,8 +16,8 @@ import (
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
) )
// newClient returns new properly initialized DNSServiceClient. // newClient returns new properly initialized gRPC connection to the API server.
func newClient(apiURL *url.URL) (client DNSServiceClient, err error) { func newClient(apiURL *url.URL) (client *grpc.ClientConn, err error) {
var creds credentials.TransportCredentials var creds credentials.TransportCredentials
switch s := apiURL.Scheme; s { switch s := apiURL.Scheme; s {
case "grpc": case "grpc":
@ -38,7 +38,7 @@ func newClient(apiURL *url.URL) (client DNSServiceClient, err error) {
// called right before the initial refresh. // called right before the initial refresh.
conn.Connect() conn.Connect()
return NewDNSServiceClient(conn), nil return conn, nil
} }
// reportf is a helper method for reporting non-critical errors. // reportf is a helper method for reporting non-critical errors.

View File

@ -65,3 +65,39 @@ func (s *testDNSServiceServer) SaveDevicesBillingStat(
) (err error) { ) (err error) {
return s.OnSaveDevicesBillingStat(srv) return s.OnSaveDevicesBillingStat(srv)
} }
// testRemoteKVServiceServer is the [backendpb.RemoteKVServiceServer] for tests.
type testRemoteKVServiceServer struct {
backendpb.UnimplementedRemoteKVServiceServer
OnGet func(
ctx context.Context,
req *backendpb.RemoteKVGetRequest,
) (resp *backendpb.RemoteKVGetResponse, err error)
OnSet func(
ctx context.Context,
req *backendpb.RemoteKVSetRequest,
) (resp *backendpb.RemoteKVSetResponse, err error)
}
// type check
var _ backendpb.RemoteKVServiceServer = (*testRemoteKVServiceServer)(nil)
// Get implements the [backendpb.RemoteKVServiceServer] interface for
// *testRemoteKVServiceServer.
func (s *testRemoteKVServiceServer) Get(
ctx context.Context,
req *backendpb.RemoteKVGetRequest,
) (resp *backendpb.RemoteKVGetResponse, err error) {
return s.OnGet(ctx, req)
}
// Set implements the [backendpb.RemoteKVServiceServer] interface for
// *testRemoteKVServiceServer.
func (s *testRemoteKVServiceServer) Set(
ctx context.Context,
req *backendpb.RemoteKVSetRequest,
) (resp *backendpb.RemoteKVSetResponse, err error) {
return s.OnSet(ctx, req)
}

View File

@ -42,7 +42,7 @@ func NewBillStat(c *BillStatConfig) (b *BillStat, err error) {
return &BillStat{ return &BillStat{
errColl: c.ErrColl, errColl: c.ErrColl,
metrics: c.Metrics, metrics: c.Metrics,
client: client, client: NewDNSServiceClient(client),
apiKey: c.APIKey, apiKey: c.APIKey,
}, nil }, nil
} }

File diff suppressed because it is too large Load Diff

View File

@ -45,6 +45,42 @@ service DNSService {
rpc createDeviceByHumanId(CreateDeviceRequest) returns (CreateDeviceResponse); rpc createDeviceByHumanId(CreateDeviceRequest) returns (CreateDeviceResponse);
} }
service RateLimitService {
/*
Gets rate limit settings.
*/
rpc getRateLimitSettings(RateLimitSettingsRequest) returns (RateLimitSettingsResponse);
}
service RemoteKVService {
/**
Get the value for the specified key.
This method may return the following errors:
- AuthenticationFailedError: If the authentication failed.
*/
rpc get(RemoteKVGetRequest) returns (RemoteKVGetResponse);
/**
Set the value for the specified key.
This method may return the following errors:
- AuthenticationFailedError: If the authentication failed.
- BadRequestError: If the request is invalid: value size exceeds the 512kb.
*/
rpc set(RemoteKVSetRequest) returns (RemoteKVSetResponse);
}
message RateLimitSettingsRequest {
}
message RateLimitSettingsResponse {
repeated CidrRange allowed_subnets = 1;
}
message DNSProfilesRequest { message DNSProfilesRequest {
google.protobuf.Timestamp sync_time = 1; google.protobuf.Timestamp sync_time = 1;
} }
@ -212,3 +248,24 @@ message RateLimitSettings {
uint32 rps = 2; uint32 rps = 2;
repeated CidrRange client_cidr = 3; repeated CidrRange client_cidr = 3;
} }
message RemoteKVGetRequest {
string key = 1;
}
message RemoteKVGetResponse {
oneof value {
bytes data = 1;
google.protobuf.Empty empty = 2;
}
}
message RemoteKVSetRequest {
string key = 1;
bytes data = 2;
google.protobuf.Duration ttl = 3;
}
message RemoteKVSetResponse {
}

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT. // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions: // versions:
// - protoc-gen-go-grpc v1.5.1 // - protoc-gen-go-grpc v1.5.1
// - protoc v5.27.1 // - protoc v5.28.3
// source: dns.proto // source: dns.proto
package backendpb package backendpb
@ -235,3 +235,269 @@ var DNSService_ServiceDesc = grpc.ServiceDesc{
}, },
Metadata: "dns.proto", Metadata: "dns.proto",
} }
const (
RateLimitService_GetRateLimitSettings_FullMethodName = "/RateLimitService/getRateLimitSettings"
)
// RateLimitServiceClient is the client API for RateLimitService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type RateLimitServiceClient interface {
// Gets rate limit settings.
GetRateLimitSettings(ctx context.Context, in *RateLimitSettingsRequest, opts ...grpc.CallOption) (*RateLimitSettingsResponse, error)
}
type rateLimitServiceClient struct {
cc grpc.ClientConnInterface
}
func NewRateLimitServiceClient(cc grpc.ClientConnInterface) RateLimitServiceClient {
return &rateLimitServiceClient{cc}
}
func (c *rateLimitServiceClient) GetRateLimitSettings(ctx context.Context, in *RateLimitSettingsRequest, opts ...grpc.CallOption) (*RateLimitSettingsResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RateLimitSettingsResponse)
err := c.cc.Invoke(ctx, RateLimitService_GetRateLimitSettings_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// RateLimitServiceServer is the server API for RateLimitService service.
// All implementations must embed UnimplementedRateLimitServiceServer
// for forward compatibility.
type RateLimitServiceServer interface {
// Gets rate limit settings.
GetRateLimitSettings(context.Context, *RateLimitSettingsRequest) (*RateLimitSettingsResponse, error)
mustEmbedUnimplementedRateLimitServiceServer()
}
// UnimplementedRateLimitServiceServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedRateLimitServiceServer struct{}
func (UnimplementedRateLimitServiceServer) GetRateLimitSettings(context.Context, *RateLimitSettingsRequest) (*RateLimitSettingsResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetRateLimitSettings not implemented")
}
func (UnimplementedRateLimitServiceServer) mustEmbedUnimplementedRateLimitServiceServer() {}
func (UnimplementedRateLimitServiceServer) testEmbeddedByValue() {}
// UnsafeRateLimitServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to RateLimitServiceServer will
// result in compilation errors.
type UnsafeRateLimitServiceServer interface {
mustEmbedUnimplementedRateLimitServiceServer()
}
func RegisterRateLimitServiceServer(s grpc.ServiceRegistrar, srv RateLimitServiceServer) {
// If the following call pancis, it indicates UnimplementedRateLimitServiceServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&RateLimitService_ServiceDesc, srv)
}
func _RateLimitService_GetRateLimitSettings_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RateLimitSettingsRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(RateLimitServiceServer).GetRateLimitSettings(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: RateLimitService_GetRateLimitSettings_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(RateLimitServiceServer).GetRateLimitSettings(ctx, req.(*RateLimitSettingsRequest))
}
return interceptor(ctx, in, info, handler)
}
// RateLimitService_ServiceDesc is the grpc.ServiceDesc for RateLimitService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var RateLimitService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "RateLimitService",
HandlerType: (*RateLimitServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "getRateLimitSettings",
Handler: _RateLimitService_GetRateLimitSettings_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "dns.proto",
}
const (
RemoteKVService_Get_FullMethodName = "/RemoteKVService/get"
RemoteKVService_Set_FullMethodName = "/RemoteKVService/set"
)
// RemoteKVServiceClient is the client API for RemoteKVService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type RemoteKVServiceClient interface {
// *
// Get the value for the specified key.
//
// This method may return the following errors:
// - AuthenticationFailedError: If the authentication failed.
Get(ctx context.Context, in *RemoteKVGetRequest, opts ...grpc.CallOption) (*RemoteKVGetResponse, error)
// *
// Set the value for the specified key.
//
// This method may return the following errors:
// - AuthenticationFailedError: If the authentication failed.
// - BadRequestError: If the request is invalid: value size exceeds the 512kb.
Set(ctx context.Context, in *RemoteKVSetRequest, opts ...grpc.CallOption) (*RemoteKVSetResponse, error)
}
type remoteKVServiceClient struct {
cc grpc.ClientConnInterface
}
func NewRemoteKVServiceClient(cc grpc.ClientConnInterface) RemoteKVServiceClient {
return &remoteKVServiceClient{cc}
}
func (c *remoteKVServiceClient) Get(ctx context.Context, in *RemoteKVGetRequest, opts ...grpc.CallOption) (*RemoteKVGetResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RemoteKVGetResponse)
err := c.cc.Invoke(ctx, RemoteKVService_Get_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *remoteKVServiceClient) Set(ctx context.Context, in *RemoteKVSetRequest, opts ...grpc.CallOption) (*RemoteKVSetResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RemoteKVSetResponse)
err := c.cc.Invoke(ctx, RemoteKVService_Set_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// RemoteKVServiceServer is the server API for RemoteKVService service.
// All implementations must embed UnimplementedRemoteKVServiceServer
// for forward compatibility.
type RemoteKVServiceServer interface {
// *
// Get the value for the specified key.
//
// This method may return the following errors:
// - AuthenticationFailedError: If the authentication failed.
Get(context.Context, *RemoteKVGetRequest) (*RemoteKVGetResponse, error)
// *
// Set the value for the specified key.
//
// This method may return the following errors:
// - AuthenticationFailedError: If the authentication failed.
// - BadRequestError: If the request is invalid: value size exceeds the 512kb.
Set(context.Context, *RemoteKVSetRequest) (*RemoteKVSetResponse, error)
mustEmbedUnimplementedRemoteKVServiceServer()
}
// UnimplementedRemoteKVServiceServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedRemoteKVServiceServer struct{}
func (UnimplementedRemoteKVServiceServer) Get(context.Context, *RemoteKVGetRequest) (*RemoteKVGetResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Get not implemented")
}
func (UnimplementedRemoteKVServiceServer) Set(context.Context, *RemoteKVSetRequest) (*RemoteKVSetResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Set not implemented")
}
func (UnimplementedRemoteKVServiceServer) mustEmbedUnimplementedRemoteKVServiceServer() {}
func (UnimplementedRemoteKVServiceServer) testEmbeddedByValue() {}
// UnsafeRemoteKVServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to RemoteKVServiceServer will
// result in compilation errors.
type UnsafeRemoteKVServiceServer interface {
mustEmbedUnimplementedRemoteKVServiceServer()
}
func RegisterRemoteKVServiceServer(s grpc.ServiceRegistrar, srv RemoteKVServiceServer) {
// If the following call pancis, it indicates UnimplementedRemoteKVServiceServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&RemoteKVService_ServiceDesc, srv)
}
func _RemoteKVService_Get_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RemoteKVGetRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(RemoteKVServiceServer).Get(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: RemoteKVService_Get_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(RemoteKVServiceServer).Get(ctx, req.(*RemoteKVGetRequest))
}
return interceptor(ctx, in, info, handler)
}
func _RemoteKVService_Set_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RemoteKVSetRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(RemoteKVServiceServer).Set(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: RemoteKVService_Set_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(RemoteKVServiceServer).Set(ctx, req.(*RemoteKVSetRequest))
}
return interceptor(ctx, in, info, handler)
}
// RemoteKVService_ServiceDesc is the grpc.ServiceDesc for RemoteKVService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var RemoteKVService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "RemoteKVService",
HandlerType: (*RemoteKVServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "get",
Handler: _RemoteKVService_Get_Handler,
},
{
MethodName: "set",
Handler: _RemoteKVService_Set_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "dns.proto",
}

View File

@ -81,7 +81,7 @@ func NewProfileStorage(c *ProfileStorageConfig) (s *ProfileStorage, err error) {
return &ProfileStorage{ return &ProfileStorage{
bindSet: c.BindSet, bindSet: c.BindSet,
errColl: c.ErrColl, errColl: c.ErrColl,
client: client, client: NewDNSServiceClient(client),
logger: c.Logger, logger: c.Logger,
metrics: c.Metrics, metrics: c.Metrics,
apiKey: c.APIKey, apiKey: c.APIKey,

View File

@ -0,0 +1,103 @@
package backendpb
import (
"context"
"fmt"
"log/slog"
"net/url"
"github.com/AdguardTeam/AdGuardDNS/internal/agdservice"
"github.com/AdguardTeam/AdGuardDNS/internal/consul"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
)
// RateLimiterConfig is the configuration structure for the business logic
// backend rate limiter.
type RateLimiterConfig struct {
// Logger is used for logging the operation of the rate limiter. It must
// not be nil.
Logger *slog.Logger
// GRPCMetrics is used for the collection of the protobuf errors.
GRPCMetrics Metrics
// Metrics is used to collect allowlist statistics.
Metrics consul.Metrics
// Allowlist is the allowlist to update.
Allowlist *ratelimit.DynamicAllowlist
// ErrColl is used to collect errors during refreshes.
ErrColl errcoll.Interface
// Endpoint is the backend API URL. The scheme should be either "grpc" or
// "grpcs". It must not be nil.
Endpoint *url.URL
// APIKey is the API key used for authentication, if any. If empty, no
// authentication is performed.
APIKey string
}
// RateLimiter is the implementation of the [agdservice.Refresher] interface
// that retrieves the rate limit settings from the business logic backend.
type RateLimiter struct {
logger *slog.Logger
grpcMetrics Metrics
metrics consul.Metrics
allowlist *ratelimit.DynamicAllowlist
errColl errcoll.Interface
client RateLimitServiceClient
apiKey string
}
// NewRateLimiter creates a new properly initialized rate limiter. c must not
// be nil.
func NewRateLimiter(c *RateLimiterConfig) (l *RateLimiter, err error) {
client, err := newClient(c.Endpoint)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
return &RateLimiter{
logger: c.Logger,
grpcMetrics: c.GRPCMetrics,
metrics: c.Metrics,
allowlist: c.Allowlist,
errColl: c.ErrColl,
client: NewRateLimitServiceClient(client),
apiKey: c.APIKey,
}, nil
}
// type check
var _ agdservice.Refresher = (*RateLimiter)(nil)
// Refresh implements the [agdservice.Refresher] interface for *RateLimiter.
func (l *RateLimiter) Refresh(ctx context.Context) (err error) {
l.logger.InfoContext(ctx, "refresh started")
defer l.logger.InfoContext(ctx, "refresh finished")
defer func() { l.metrics.SetStatus(ctx, err) }()
ctx = ctxWithAuthentication(ctx, l.apiKey)
backendResp, err := l.client.GetRateLimitSettings(ctx, &RateLimitSettingsRequest{})
if err != nil {
return fmt.Errorf(
"loading backend rate limit settings: %w",
fixGRPCError(ctx, l.grpcMetrics, err),
)
}
allowedSubnets := backendResp.AllowedSubnets
prefixes := cidrRangeToInternal(ctx, l.errColl, allowedSubnets)
l.allowlist.Update(prefixes)
l.logger.InfoContext(ctx, "refresh successful", "num_records", len(prefixes))
l.metrics.SetSize(ctx, len(prefixes))
return nil
}

View File

@ -0,0 +1,110 @@
package backendpb_test
import (
"context"
"net"
"net/netip"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
"github.com/AdguardTeam/AdGuardDNS/internal/consul"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
// testRateLimitServiceServer is the [backendpb.RateLimitServiceServer] for
// tests.
type testRateLimitServiceServer struct {
backendpb.UnimplementedRateLimitServiceServer
OnGetRateLimitSettings func(
ctx context.Context,
req *backendpb.RateLimitSettingsRequest,
) (resp *backendpb.RateLimitSettingsResponse, err error)
}
// type check
var _ backendpb.DNSServiceServer = (*testDNSServiceServer)(nil)
// GetRateLimitSettings implements the [backendpb.RateLimitServiceServer]
// interface for *testRateLimitServiceServer.
func (s *testRateLimitServiceServer) GetRateLimitSettings(
ctx context.Context,
req *backendpb.RateLimitSettingsRequest,
) (resp *backendpb.RateLimitSettingsResponse, err error) {
return s.OnGetRateLimitSettings(ctx, req)
}
func TestRateLimiter_Refresh(t *testing.T) {
var (
allowedIP = netip.MustParseAddr("1.2.3.4")
notAllowedIP = netip.MustParseAddr("4.3.2.1")
cidr = &backendpb.CidrRange{
Address: allowedIP.AsSlice(),
Prefix: 32,
}
)
srv := &testRateLimitServiceServer{
OnGetRateLimitSettings: func(
ctx context.Context,
req *backendpb.RateLimitSettingsRequest,
) (resp *backendpb.RateLimitSettingsResponse, err error) {
return &backendpb.RateLimitSettingsResponse{
AllowedSubnets: []*backendpb.CidrRange{cidr},
}, nil
},
}
ln, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
grpcSrv := grpc.NewServer(
grpc.ConnectionTimeout(1*time.Second),
grpc.Creds(insecure.NewCredentials()),
)
backendpb.RegisterRateLimitServiceServer(grpcSrv, srv)
go func() {
pt := testutil.PanicT{}
srvErr := grpcSrv.Serve(ln)
require.NoError(pt, srvErr)
}()
t.Cleanup(grpcSrv.GracefulStop)
allowlist := ratelimit.NewDynamicAllowlist(nil, nil)
l, err := backendpb.NewRateLimiter(&backendpb.RateLimiterConfig{
Logger: slogutil.NewDiscardLogger(),
Metrics: consul.EmptyMetrics{},
GRPCMetrics: backendpb.EmptyMetrics{},
Allowlist: allowlist,
Endpoint: &url.URL{
Scheme: "grpc",
Host: ln.Addr().String(),
},
})
require.NoError(t, err)
ctx := testutil.ContextWithTimeout(t, testTimeout)
err = l.Refresh(ctx)
require.NoError(t, err)
ok, err := allowlist.IsAllowed(ctx, allowedIP)
require.NoError(t, err)
assert.True(t, ok)
ok, err = allowlist.IsAllowed(ctx, notAllowedIP)
require.NoError(t, err)
assert.False(t, ok)
}

View File

@ -0,0 +1,96 @@
package backendpb
import (
"context"
"fmt"
"net/url"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/remotekv"
"google.golang.org/protobuf/types/known/durationpb"
)
// RemoteKVConfig is the configuration for the business logic backend key-value
// storage.
type RemoteKVConfig struct {
// Metrics is used for the collection of the remote key-value storage
// statistics.
//
// TODO(e.burkov): Perhaps, it worths of a separate metrics interface,
// since it's only used for the collection of the protobuf errors.
Metrics Metrics
// Endpoint is the backend API URL. The scheme should be either "grpc" or
// "grpcs".
Endpoint *url.URL
// APIKey is the API key used for authentication, if any.
APIKey string
// TTL is the TTL of the values in the storage.
TTL time.Duration
}
// RemoteKV is the implementation of the [remotekv.Interface] interface that
// uses the business logic backend as the key-value storage. It is safe for
// concurrent use.
type RemoteKV struct {
metrics Metrics
client RemoteKVServiceClient
apiKey string
ttl time.Duration
}
// NewRemoteKV returns a new [RemoteKV] that retrieves information from the
// business logic backend.
func NewRemoteKV(c *RemoteKVConfig) (kv *RemoteKV, err error) {
client, err := newClient(c.Endpoint)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
return &RemoteKV{
metrics: c.Metrics,
client: NewRemoteKVServiceClient(client),
apiKey: c.APIKey,
ttl: c.TTL,
}, nil
}
// type check
var _ remotekv.Interface = (*RemoteKV)(nil)
// Get implements the [remotekv.Interface] interface for *RemoteKV.
func (kv *RemoteKV) Get(ctx context.Context, key string) (val []byte, ok bool, err error) {
req := &RemoteKVGetRequest{
Key: key,
}
ctx = ctxWithAuthentication(ctx, kv.apiKey)
resp, err := kv.client.Get(ctx, req)
if err != nil {
return nil, false, fmt.Errorf("getting %q key: %w", key, fixGRPCError(ctx, kv.metrics, err))
}
val = resp.GetData()
return val, val != nil, nil
}
// Set implements the [remotekv.Interface] interface for *RemoteKV.
func (kv *RemoteKV) Set(ctx context.Context, key string, val []byte) (err error) {
req := &RemoteKVSetRequest{
Key: key,
Data: val,
Ttl: durationpb.New(kv.ttl),
}
ctx = ctxWithAuthentication(ctx, kv.apiKey)
_, err = kv.client.Set(ctx, req)
if err != nil {
return fmt.Errorf("setting %q key: %w", key, fixGRPCError(ctx, kv.metrics, err))
}
return nil
}

View File

@ -0,0 +1,106 @@
package backendpb_test
import (
"context"
"net"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
func TestRemoteKV_Get(t *testing.T) {
const testTTL = 10 * time.Second
pt := &testutil.PanicT{}
strg := map[string][]byte{}
srv := &testRemoteKVServiceServer{
OnGet: func(
ctx context.Context,
req *backendpb.RemoteKVGetRequest,
) (resp *backendpb.RemoteKVGetResponse, err error) {
resp = &backendpb.RemoteKVGetResponse{
Value: &backendpb.RemoteKVGetResponse_Empty{},
}
if val, ok := strg[req.Key]; ok {
resp.Value = &backendpb.RemoteKVGetResponse_Data{Data: val}
}
return resp, nil
},
OnSet: func(
ctx context.Context,
req *backendpb.RemoteKVSetRequest,
) (resp *backendpb.RemoteKVSetResponse, err error) {
require.Equal(pt, testTTL, req.Ttl.AsDuration())
strg[req.Key] = req.Data
return &backendpb.RemoteKVSetResponse{}, nil
},
}
l, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
grpcSrv := grpc.NewServer(
grpc.ConnectionTimeout(1*time.Second),
grpc.Creds(insecure.NewCredentials()),
)
backendpb.RegisterRemoteKVServiceServer(grpcSrv, srv)
go func() {
srvErr := grpcSrv.Serve(l)
require.NoError(pt, srvErr)
}()
t.Cleanup(grpcSrv.GracefulStop)
kv, err := backendpb.NewRemoteKV(&backendpb.RemoteKVConfig{
Metrics: backendpb.EmptyMetrics{},
Endpoint: &url.URL{
Scheme: "grpc",
Host: l.Addr().String(),
},
APIKey: "apikey",
TTL: testTTL,
})
require.NoError(t, err)
const (
keyWithData = "key"
keyNoData = "unknown"
)
t.Run("success", func(t *testing.T) {
val := []byte("value")
ctx := testutil.ContextWithTimeout(t, testTimeout)
setErr := kv.Set(ctx, keyWithData, val)
require.NoError(t, setErr)
gotVal, ok, getErr := kv.Get(ctx, keyWithData)
require.NoError(t, getErr)
require.True(t, ok)
assert.Equal(t, val, gotVal)
})
t.Run("not_found", func(t *testing.T) {
ctx := testutil.ContextWithTimeout(t, testTimeout)
val, ok, getErr := kv.Get(ctx, keyNoData)
require.NoError(t, getErr)
require.False(t, ok)
assert.Nil(t, val)
})
}

View File

@ -20,23 +20,19 @@ type connIndex struct {
} }
// subnetCompare is a comparison function for the two subnets. It returns -1 if // subnetCompare is a comparison function for the two subnets. It returns -1 if
// x sorts before y, 1 if x sorts after y, and 0 if their relative sorting // a sorts before b, 1 if a sorts after b, and 0 if their relative sorting
// position is the same. // position is the same.
func subnetCompare(x, y netip.Prefix) (cmp int) { func subnetCompare(a, b netip.Prefix) (cmp int) {
if x == y { aAddr, aBits := a.Addr(), a.Bits()
return 0 bAddr, bBits := b.Addr(), b.Bits()
}
xAddr, xBits := x.Addr(), x.Bits() switch {
yAddr, yBits := y.Addr(), y.Bits() case aBits > bBits:
if xBits == yBits {
return xAddr.Compare(yAddr)
}
if xBits > yBits {
return -1 return -1
} else { case aBits < bBits:
return 1 return 1
default:
return aAddr.Compare(bAddr)
} }
} }
@ -45,8 +41,8 @@ func subnetCompare(x, y netip.Prefix) (cmp int) {
// //
// TODO(a.garipov): Merge with [addListenerChannel]. // TODO(a.garipov): Merge with [addListenerChannel].
func (idx *connIndex) addPacketConn(c *chanPacketConn) (err error) { func (idx *connIndex) addPacketConn(c *chanPacketConn) (err error) {
cmpFunc := func(x, y *chanPacketConn) (cmp int) { cmpFunc := func(a, b *chanPacketConn) (cmp int) {
return subnetCompare(x.subnet, y.subnet) return subnetCompare(a.subnet, b.subnet)
} }
newIdx, ok := slices.BinarySearchFunc(idx.packetConns, c, cmpFunc) newIdx, ok := slices.BinarySearchFunc(idx.packetConns, c, cmpFunc)
@ -67,8 +63,8 @@ func (idx *connIndex) addPacketConn(c *chanPacketConn) (err error) {
// //
// TODO(a.garipov): Merge with [addPacketConnChannel]. // TODO(a.garipov): Merge with [addPacketConnChannel].
func (idx *connIndex) addListener(l *chanListener) (err error) { func (idx *connIndex) addListener(l *chanListener) (err error) {
cmpFunc := func(x, y *chanListener) (cmp int) { cmpFunc := func(a, b *chanListener) (cmp int) {
return subnetCompare(x.subnet, y.subnet) return subnetCompare(a.subnet, b.subnet)
} }
newIdx, ok := slices.BinarySearchFunc(idx.listeners, l, cmpFunc) newIdx, ok := slices.BinarySearchFunc(idx.listeners, l, cmpFunc)

View File

@ -6,7 +6,6 @@ import (
"log/slog" "log/slog"
"time" "time"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb" "github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat" "github.com/AdguardTeam/AdGuardDNS/internal/billstat"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll" "github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
@ -14,6 +13,7 @@ import (
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
) )
@ -100,8 +100,9 @@ func newBillStatUploader(
mtrc backendpb.Metrics, mtrc backendpb.Metrics,
) (s billstat.Uploader, err error) { ) (s billstat.Uploader, err error) {
apiURL := netutil.CloneURL(&envs.BillStatURL.URL) apiURL := netutil.CloneURL(&envs.BillStatURL.URL)
if !agdhttp.CheckGRPCURLScheme(apiURL.Scheme) { err = urlutil.ValidateGRPCURL(apiURL)
return nil, fmt.Errorf("invalid backend api url: %s", apiURL) if err != nil {
return nil, fmt.Errorf("billstat api url: %w", err)
} }
return backendpb.NewBillStat(&backendpb.BillStatConfig{ return backendpb.NewBillStat(&backendpb.BillStatConfig{

View File

@ -6,6 +6,7 @@ import (
"log/slog" "log/slog"
"maps" "maps"
"net/netip" "net/netip"
"net/url"
"path" "path"
"path/filepath" "path/filepath"
"slices" "slices"
@ -14,7 +15,6 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/access" "github.com/AdguardTeam/AdGuardDNS/internal/access"
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdcache" "github.com/AdguardTeam/AdGuardDNS/internal/agdcache"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/agdservice" "github.com/AdguardTeam/AdGuardDNS/internal/agdservice"
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb" "github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat" "github.com/AdguardTeam/AdGuardDNS/internal/billstat"
@ -38,9 +38,11 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb" "github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog" "github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/AdGuardDNS/internal/rulestat" "github.com/AdguardTeam/AdGuardDNS/internal/rulestat"
"github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig"
"github.com/AdguardTeam/AdGuardDNS/internal/websvc" "github.com/AdguardTeam/AdGuardDNS/internal/websvc"
"github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/osutil" "github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/service" "github.com/AdguardTeam/golibs/service"
"github.com/c2h5oh/datasize" "github.com/c2h5oh/datasize"
@ -112,6 +114,8 @@ type builder struct {
ruleStat rulestat.Interface ruleStat rulestat.Interface
safeBrowsing *hashprefix.Filter safeBrowsing *hashprefix.Filter
safeBrowsingHashes *hashprefix.Storage safeBrowsingHashes *hashprefix.Storage
sdeConf *dnsmsg.StructuredDNSErrorsConfig
tlsMtrc tlsconfig.Metrics
webSvc *websvc.Service webSvc *websvc.Service
// The fields below are initialized later, just like with the fields above, // The fields below are initialized later, just like with the fields above,
@ -556,12 +560,42 @@ func (b *builder) initBindToDevice(ctx context.Context) (err error) {
return nil return nil
} }
// Constants for the experimental Structured DNS Errors feature.
//
// TODO(a.garipov): Make configurable.
const (
sdeJustification = "Filtered by AdGuard DNS"
sdeOrganization = "AdGuard DNS"
)
// Variables for the experimental Structured DNS Errors feature.
//
// TODO(a.garipov): Make configurable.
var (
sdeContactURL = &url.URL{
Scheme: "mailto",
Opaque: "support@adguard-dns.io",
}
)
// initMsgConstructor initializes the common DNS message constructor. // initMsgConstructor initializes the common DNS message constructor.
func (b *builder) initMsgConstructor(ctx context.Context) (err error) { func (b *builder) initMsgConstructor(ctx context.Context) (err error) {
fltConf := b.conf.Filters
b.sdeConf = &dnsmsg.StructuredDNSErrorsConfig{
Contact: []*url.URL{
sdeContactURL,
},
Justification: sdeJustification,
Organization: sdeOrganization,
Enabled: fltConf.SDEEnabled,
}
b.messages, err = dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{ b.messages, err = dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: b.cloner, Cloner: b.cloner,
BlockingMode: &dnsmsg.BlockingModeNullIP{}, BlockingMode: &dnsmsg.BlockingModeNullIP{},
FilteredResponseTTL: b.conf.Filters.ResponseTTL.Duration, StructuredErrors: b.sdeConf,
FilteredResponseTTL: fltConf.ResponseTTL.Duration,
EDEEnabled: fltConf.EDEEnabled,
}) })
if err != nil { if err != nil {
return fmt.Errorf("creating dns message constructor: %w", err) return fmt.Errorf("creating dns message constructor: %w", err)
@ -579,8 +613,17 @@ func (b *builder) initMsgConstructor(ctx context.Context) (err error) {
// - [builder.initFilteringGroups] // - [builder.initFilteringGroups]
// - [builder.initMsgConstructor] // - [builder.initMsgConstructor]
func (b *builder) initServerGroups(ctx context.Context) (err error) { func (b *builder) initServerGroups(ctx context.Context) (err error) {
mtrc, err := metrics.NewTLSConfig(b.mtrcNamespace, b.promRegisterer)
if err != nil {
return fmt.Errorf("registering tls metrics: %w", err)
}
b.tlsMtrc = mtrc
c := b.conf c := b.conf
b.serverGroups, err = c.ServerGroups.toInternal( b.serverGroups, err = c.ServerGroups.toInternal(
ctx,
mtrc,
b.messages, b.messages,
b.btdManager, b.btdManager,
b.filteringGroups, b.filteringGroups,
@ -671,7 +714,7 @@ func (b *builder) initTLS(ctx context.Context) (err error) {
} }
} }
tickRot := newTicketRotator(b.baseLogger, b.errColl, b.serverGroups) tickRot := newTicketRotator(b.baseLogger, b.errColl, b.tlsMtrc, b.serverGroups)
err = tickRot.Refresh(ctx) err = tickRot.Refresh(ctx)
if err != nil { if err != nil {
return fmt.Errorf("initial session ticket refresh: %w", err) return fmt.Errorf("initial session ticket refresh: %w", err)
@ -700,6 +743,29 @@ func (b *builder) initTLS(ctx context.Context) (err error) {
return nil return nil
} }
// initGRPCMetrics initializes the gRPC metrics if necessary.
func (b *builder) initGRPCMetrics(ctx context.Context) (err error) {
switch {
case
b.profilesEnabled,
b.conf.Check.RemoteKV.Type == kvModeBackend,
b.conf.RateLimit.Allowlist.Type == rlAllowlistTypeBackend:
// Go on.
default:
// Don't initialize the metrics if no protobuf backend is used.
return nil
}
b.backendGRPCMtrc, err = metrics.NewBackendPB(b.mtrcNamespace, b.promRegisterer)
if err != nil {
return fmt.Errorf("registering backendbp metrics: %w", err)
}
b.logger.DebugContext(ctx, "initialized grpc metrics")
return nil
}
// initBillStat initializes the billing-statistics recorder if necessary. It // initBillStat initializes the billing-statistics recorder if necessary. It
// also adds the refresher with ID [debugIDBillStat] to the debug refreshers. // also adds the refresher with ID [debugIDBillStat] to the debug refreshers.
func (b *builder) initBillStat(ctx context.Context) (err error) { func (b *builder) initBillStat(ctx context.Context) (err error) {
@ -709,11 +775,6 @@ func (b *builder) initBillStat(ctx context.Context) (err error) {
return nil return nil
} }
b.backendGRPCMtrc, err = metrics.NewBackendPB(b.mtrcNamespace, b.promRegisterer)
if err != nil {
return fmt.Errorf("registering backendbp metrics: %w", err)
}
upl, err := newBillStatUploader(b.env, b.errColl, b.backendGRPCMtrc) upl, err := newBillStatUploader(b.env, b.errColl, b.backendGRPCMtrc)
if err != nil { if err != nil {
return fmt.Errorf("creating billstat uploader: %w", err) return fmt.Errorf("creating billstat uploader: %w", err)
@ -760,9 +821,9 @@ func (b *builder) initBillStat(ctx context.Context) (err error) {
// initProfileDB initializes the profile database if necessary. // initProfileDB initializes the profile database if necessary.
// //
// [builder.initBillStat] and [builder.initServerGroups] must be called before // [builder.initGRPCMetrics] and [builder.initServerGroups] must be called
// this method. It also adds the refresher with ID [debugIDProfileDB] to the // before this method. It also adds the refresher with ID [debugIDProfileDB] to
// debug refreshers. // the debug refreshers.
func (b *builder) initProfileDB(ctx context.Context) (err error) { func (b *builder) initProfileDB(ctx context.Context) (err error) {
if !b.profilesEnabled { if !b.profilesEnabled {
b.profileDB = &profiledb.Disabled{} b.profileDB = &profiledb.Disabled{}
@ -771,8 +832,9 @@ func (b *builder) initProfileDB(ctx context.Context) (err error) {
} }
apiURL := netutil.CloneURL(&b.env.ProfilesURL.URL) apiURL := netutil.CloneURL(&b.env.ProfilesURL.URL)
if !agdhttp.CheckGRPCURLScheme(apiURL.Scheme) { err = urlutil.ValidateGRPCURL(apiURL)
return fmt.Errorf("invalid backend api url: %s", apiURL) if err != nil {
return fmt.Errorf("profile api url: %w", err)
} }
respSzEst := b.conf.RateLimit.ResponseSizeEstimate respSzEst := b.conf.RateLimit.ResponseSizeEstimate
@ -842,7 +904,8 @@ func (b *builder) initProfileDB(ctx context.Context) (err error) {
// initDNSCheck initializes the DNS checker. // initDNSCheck initializes the DNS checker.
// //
// [builder.initMsgConstructor] must be called before this method. // [builder.initGRPCMetrics] and [builder.initMsgConstructor] must be called
// before this method.
func (b *builder) initDNSCheck(ctx context.Context) (err error) { func (b *builder) initDNSCheck(ctx context.Context) (err error) {
b.dnsCheck = b.plugins.DNSCheck() b.dnsCheck = b.plugins.DNSCheck()
if b.dnsCheck != nil { if b.dnsCheck != nil {
@ -853,7 +916,14 @@ func (b *builder) initDNSCheck(ctx context.Context) (err error) {
c := b.conf.Check c := b.conf.Check
checkConf, err := c.toInternal(b.env, b.messages, b.errColl, b.mtrcNamespace, b.promRegisterer) checkConf, err := c.toInternal(
b.env,
b.messages,
b.errColl,
b.mtrcNamespace,
b.promRegisterer,
b.backendGRPCMtrc,
)
if err != nil { if err != nil {
return fmt.Errorf("initializing dnscheck: %w", err) return fmt.Errorf("initializing dnscheck: %w", err)
} }
@ -911,18 +981,44 @@ func (b *builder) initRuleStat(ctx context.Context) (err error) {
// well as starts and registers the rate-limiter refresher in the signal // well as starts and registers the rate-limiter refresher in the signal
// handler. It also adds the refresher with ID [debugIDAllowlist] to the debug // handler. It also adds the refresher with ID [debugIDAllowlist] to the debug
// refreshers. // refreshers.
//
// [builder.initGRPCMetrics] must be called before this method.
func (b *builder) initRateLimiter(ctx context.Context) (err error) { func (b *builder) initRateLimiter(ctx context.Context) (err error) {
c := b.conf.RateLimit c := b.conf.RateLimit
allowSubnets := netutil.UnembedPrefixes(c.Allowlist.List) allowSubnets := netutil.UnembedPrefixes(c.Allowlist.List)
allowlist := ratelimit.NewDynamicAllowlist(allowSubnets, nil) allowlist := ratelimit.NewDynamicAllowlist(allowSubnets, nil)
updater := consul.NewAllowlistUpdater(&consul.AllowlistUpdaterConfig{
Logger: b.baseLogger.With(slogutil.KeyPrefix, "ratelimit_allowlist_updater"), typ := b.conf.RateLimit.Allowlist.Type
Allowlist: allowlist, mtrc, err := metrics.NewAllowlist(b.mtrcNamespace, b.promRegisterer, typ)
ConsulURL: &b.env.ConsulAllowlistURL.URL, if err != nil {
ErrColl: b.errColl, return fmt.Errorf("ratelimit metrics: %w", err)
// TODO(a.garipov): Make configurable. }
Timeout: 15 * time.Second,
}) var updater agdservice.Refresher
if typ == rlAllowlistTypeBackend {
updater, err = backendpb.NewRateLimiter(&backendpb.RateLimiterConfig{
Logger: b.baseLogger.With(slogutil.KeyPrefix, "backend_ratelimiter"),
Metrics: mtrc,
GRPCMetrics: b.backendGRPCMtrc,
Allowlist: allowlist,
Endpoint: &b.env.BackendRateLimitURL.URL,
ErrColl: b.errColl,
APIKey: b.env.BackendRateLimitAPIKey,
})
if err != nil {
return fmt.Errorf("ratelimit: %w", err)
}
} else {
updater = consul.NewAllowlistUpdater(&consul.AllowlistUpdaterConfig{
Logger: b.baseLogger.With(slogutil.KeyPrefix, "ratelimit_allowlist_updater"),
Allowlist: allowlist,
ConsulURL: &b.env.ConsulAllowlistURL.URL,
ErrColl: b.errColl,
Metrics: mtrc,
// TODO(a.garipov): Make configurable.
Timeout: 15 * time.Second,
})
}
err = updater.Refresh(ctx) err = updater.Refresh(ctx)
if err != nil { if err != nil {
@ -956,9 +1052,11 @@ func (b *builder) initRateLimiter(ctx context.Context) (err error) {
// initWeb initializes the web service, starts it, and registers it in the // initWeb initializes the web service, starts it, and registers it in the
// signal handler. // signal handler.
//
// [builder.initServerGroups] must be called before this method.
func (b *builder) initWeb(ctx context.Context) (err error) { func (b *builder) initWeb(ctx context.Context) (err error) {
c := b.conf.Web c := b.conf.Web
webConf, err := c.toInternal(b.env, b.dnsCheck, b.errColl) webConf, err := c.toInternal(ctx, b.env, b.dnsCheck, b.errColl, b.tlsMtrc)
if err != nil { if err != nil {
return fmt.Errorf("converting web configuration: %w", err) return fmt.Errorf("converting web configuration: %w", err)
} }
@ -1057,45 +1155,55 @@ func (b *builder) initDNS(ctx context.Context) (err error) {
b.fwdHandler = forward.NewHandler(b.conf.Upstream.toInternal(b.baseLogger)) b.fwdHandler = forward.NewHandler(b.conf.Upstream.toInternal(b.baseLogger))
b.dnsDB = b.conf.DNSDB.toInternal(b.errColl) b.dnsDB = b.conf.DNSDB.toInternal(b.errColl)
cacheConf := b.conf.Cache dnsHdlrsConf := &dnssvc.HandlersConfig{
dnsConf := &dnssvc.Config{
BaseLogger: b.baseLogger, BaseLogger: b.baseLogger,
Messages: b.messages, Cache: b.conf.Cache.toInternal(),
Cloner: b.cloner, Cloner: b.cloner,
ControlConf: b.controlConf,
ConnLimiter: b.connLimit,
HumanIDParser: agd.NewHumanIDParser(), HumanIDParser: agd.NewHumanIDParser(),
Messages: b.messages,
PluginRegistry: b.plugins, PluginRegistry: b.plugins,
StructuredErrors: b.sdeConf,
AccessManager: b.access, AccessManager: b.access,
SafeBrowsing: b.hashMatcher,
BillStat: b.billStat, BillStat: b.billStat,
CacheManager: b.cacheManager, CacheManager: b.cacheManager,
ProfileDB: b.profileDB,
PrometheusRegisterer: b.promRegisterer,
DNSCheck: b.dnsCheck, DNSCheck: b.dnsCheck,
NonDNS: b.webSvc,
DNSDB: b.dnsDB, DNSDB: b.dnsDB,
ErrColl: b.errColl, ErrColl: b.errColl,
FilterStorage: b.filterStorage, FilterStorage: b.filterStorage,
GeoIP: b.geoIP, GeoIP: b.geoIP,
Handler: b.fwdHandler, Handler: b.fwdHandler,
HashMatcher: b.hashMatcher,
ProfileDB: b.profileDB,
PrometheusRegisterer: b.promRegisterer,
QueryLog: b.queryLog(), QueryLog: b.queryLog(),
RuleStat: b.ruleStat,
RateLimit: b.rateLimit, RateLimit: b.rateLimit,
RuleStat: b.ruleStat,
MetricsNamespace: b.mtrcNamespace, MetricsNamespace: b.mtrcNamespace,
FilteringGroups: b.filteringGroups, FilteringGroups: b.filteringGroups,
ServerGroups: b.serverGroups, ServerGroups: b.serverGroups,
HandleTimeout: b.conf.DNS.HandleTimeout.Duration, EDEEnabled: b.conf.Filters.EDEEnabled,
CacheSize: cacheConf.Size, }
ECSCacheSize: cacheConf.ECSSize,
CacheMinTTL: cacheConf.TTLOverride.Min.Duration, dnsHdlrs, err := dnssvc.NewHandlers(ctx, dnsHdlrsConf)
UseCacheTTLOverride: cacheConf.TTLOverride.Enabled, if err != nil {
UseECSCache: cacheConf.Type == cacheTypeECS, return fmt.Errorf("dns handlers: %w", err)
}
dnsConf := &dnssvc.Config{
Handlers: dnsHdlrs,
Cloner: b.cloner,
ControlConf: b.controlConf,
ConnLimiter: b.connLimit,
NonDNS: b.webSvc,
ErrColl: b.errColl,
MetricsNamespace: b.mtrcNamespace,
ServerGroups: b.serverGroups,
HandleTimeout: b.conf.DNS.HandleTimeout.Duration,
} }
b.dnsSvc, err = dnssvc.New(dnsConf) b.dnsSvc, err = dnssvc.New(dnsConf)
if err != nil { if err != nil {
return fmt.Errorf("initializing dns: %w", err) return fmt.Errorf("dns service: %w", err)
} }
b.logger.DebugContext(ctx, "initialized dns") b.logger.DebugContext(ctx, "initialized dns")

View File

@ -3,6 +3,7 @@ package cmd
import ( import (
"fmt" "fmt"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
) )
@ -43,6 +44,28 @@ const (
cacheTypeSimple = "simple" cacheTypeSimple = "simple"
) )
// toInternal converts c to the cache configuration for the DNS server. c must
// be valid.
func (c *cacheConfig) toInternal() (cacheConf *dnssvc.CacheConfig) {
var typ dnssvc.CacheType
if c.Size == 0 {
// TODO(a.garipov): Add as a type in the configuration file.
typ = dnssvc.CacheTypeNone
} else if c.Type == cacheTypeSimple {
typ = dnssvc.CacheTypeSimple
} else {
typ = dnssvc.CacheTypeECS
}
return &dnssvc.CacheConfig{
MinTTL: c.TTLOverride.Min.Duration,
ECSCount: c.ECSSize,
NoECSCount: c.Size,
Type: typ,
OverrideCacheTTL: c.TTLOverride.Enabled,
}
}
// type check // type check
var _ validator = (*cacheConfig)(nil) var _ validator = (*cacheConfig)(nil)

View File

@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck" "github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll" "github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
@ -54,8 +55,9 @@ func (c *checkConfig) toInternal(
errColl errcoll.Interface, errColl errcoll.Interface,
namespace string, namespace string,
reg prometheus.Registerer, reg prometheus.Registerer,
backendMtrc backendpb.Metrics,
) (conf *dnscheck.RemoteKVConfig, err error) { ) (conf *dnscheck.RemoteKVConfig, err error) {
kv, err := newDNSCheckKV(c, envs, namespace, reg) kv, err := newRemoteKV(c.RemoteKV, envs, namespace, reg, backendMtrc)
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
return nil, err return nil, err
@ -85,21 +87,33 @@ const maxRespSize = 1 * datasize.MB
// [remotekv.KeyNamespace]. // [remotekv.KeyNamespace].
const keyNamespaceCheck = "check" const keyNamespaceCheck = "check"
// newDNSCheckKV returns a new properly initialized remote key-value storage. // newRemoteKV returns a new properly initialized remote key-value storage.
func newDNSCheckKV( func newRemoteKV(
conf *checkConfig, c *remoteKVConfig,
envs *environment, envs *environment,
namespace string, namespace string,
reg prometheus.Registerer, reg prometheus.Registerer,
backendMtrc backendpb.Metrics,
) (kv remotekv.Interface, err error) { ) (kv remotekv.Interface, err error) {
if conf.RemoteKV.Type == kvModeRedis { switch c.Type {
case kvModeBackend:
kv, err = backendpb.NewRemoteKV(&backendpb.RemoteKVConfig{
Metrics: backendMtrc,
Endpoint: &envs.DNSCheckRemoteKVURL.URL,
APIKey: envs.DNSCheckRemoteKVAPIKey,
TTL: c.TTL.Duration,
})
if err != nil {
return nil, fmt.Errorf("initializing backend dnscheck kv: %w", err)
}
case kvModeRedis:
var redisKVMtrc rediskv.Metrics var redisKVMtrc rediskv.Metrics
redisKVMtrc, err = metrics.NewRedisKV(namespace, reg) redisKVMtrc, err = metrics.NewRedisKV(namespace, reg)
if err != nil { if err != nil {
return nil, fmt.Errorf("registering redis kv metrics: %w", err) return nil, fmt.Errorf("registering redis kv metrics: %w", err)
} }
kv := rediskv.NewRedisKV(&rediskv.RedisKVConfig{ redisKV := rediskv.NewRedisKV(&rediskv.RedisKVConfig{
Metrics: redisKVMtrc, Metrics: redisKVMtrc,
Addr: &netutil.HostPort{ Addr: &netutil.HostPort{
Host: envs.RedisAddr, Host: envs.RedisAddr,
@ -108,18 +122,20 @@ func newDNSCheckKV(
MaxActive: envs.RedisMaxActive, MaxActive: envs.RedisMaxActive,
MaxIdle: envs.RedisMaxIdle, MaxIdle: envs.RedisMaxIdle,
IdleTimeout: envs.RedisIdleTimeout.Duration, IdleTimeout: envs.RedisIdleTimeout.Duration,
TTL: conf.RemoteKV.TTL.Duration, TTL: c.TTL.Duration,
}) })
return remotekv.NewKeyNamespace(&remotekv.KeyNamespaceConfig{ kv = remotekv.NewKeyNamespace(&remotekv.KeyNamespaceConfig{
KV: kv, KV: redisKV,
Prefix: fmt.Sprintf("%s:%s:", envs.RedisKeyPrefix, keyNamespaceCheck), Prefix: fmt.Sprintf("%s:%s:", envs.RedisKeyPrefix, keyNamespaceCheck),
}), nil })
} case kvModeConsul:
consulKVURL := envs.ConsulDNSCheckKVURL
consulSessionURL := envs.ConsulDNSCheckSessionURL
if consulKVURL == nil || consulSessionURL == nil {
return remotekv.Empty{}, nil
}
consulKVURL := envs.ConsulDNSCheckKVURL
consulSessionURL := envs.ConsulDNSCheckSessionURL
if consulKVURL != nil && consulSessionURL != nil {
kv, err = consulkv.NewKV(&consulkv.Config{ kv, err = consulkv.NewKV(&consulkv.Config{
URL: &consulKVURL.URL, URL: &consulKVURL.URL,
SessionURL: &consulSessionURL.URL, SessionURL: &consulSessionURL.URL,
@ -129,14 +145,14 @@ func newDNSCheckKV(
}), }),
// TODO(ameshkov): Consider making configurable. // TODO(ameshkov): Consider making configurable.
Limiter: rate.NewLimiter(rate.Limit(200)/60, 1), Limiter: rate.NewLimiter(rate.Limit(200)/60, 1),
TTL: conf.RemoteKV.TTL.Duration, TTL: c.TTL.Duration,
MaxRespSize: maxRespSize, MaxRespSize: maxRespSize,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("initializing consul dnscheck: %w", err) return nil, fmt.Errorf("initializing consul dnscheck kv: %w", err)
} }
} else { default:
kv = remotekv.Empty{} return remotekv.Empty{}, nil
} }
return kv, nil return kv, nil
@ -221,15 +237,16 @@ func validateNonNilIPs(ips []netip.Addr, fam netutil.AddrFamily) (err error) {
// DNSCheck key-value database modes. // DNSCheck key-value database modes.
const ( const (
kvModeConsul = "consul" kvModeBackend = "backend"
kvModeRedis = "redis" kvModeConsul = "consul"
kvModeRedis = "redis"
) )
// remoteKVConfig is remote key-value store configuration for DNS server // remoteKVConfig is remote key-value store configuration for DNS server
// checking. // checking.
type remoteKVConfig struct { type remoteKVConfig struct {
// Type defines the type of remote key-value store. Allowed values are // Type defines the type of remote key-value store. Allowed values are
// [kvModeConsul] and [kvModeRedis]. // [kvModeBackend], [kvModeConsul] and [kvModeRedis].
Type string `yaml:"type"` Type string `yaml:"type"`
// TTL defines, for how long to keep the information about a single client. // TTL defines, for how long to keep the information about a single client.
@ -248,6 +265,10 @@ func (c *remoteKVConfig) validate() (err error) {
ttl := c.TTL ttl := c.TTL
switch c.Type { switch c.Type {
case kvModeBackend:
if ttl.Duration <= 0 {
return newNotPositiveError("ttl", ttl)
}
case kvModeConsul: case kvModeConsul:
if ttl.Duration < consulkv.MinTTL || ttl.Duration > consulkv.MaxTTL { if ttl.Duration < consulkv.MinTTL || ttl.Duration > consulkv.MaxTTL {
return fmt.Errorf( return fmt.Errorf(
@ -270,7 +291,7 @@ func (c *remoteKVConfig) validate() (err error) {
case "": case "":
return fmt.Errorf("type: %w", errors.ErrEmptyValue) return fmt.Errorf("type: %w", errors.ErrEmptyValue)
default: default:
return fmt.Errorf("type: %q: %w", c.Type, errors.ErrBadEnumValue) return fmt.Errorf("type: %w: %q", errors.ErrBadEnumValue, c.Type)
} }
return nil return nil

View File

@ -100,6 +100,8 @@ func Main(plugins *plugin.Registry) {
errors.Check(b.initTLS(ctx)) errors.Check(b.initTLS(ctx))
errors.Check(b.initGRPCMetrics(ctx))
errors.Check(b.initBillStat(ctx)) errors.Check(b.initBillStat(ctx))
errors.Check(b.initProfileDB(ctx)) errors.Check(b.initProfileDB(ctx))

View File

@ -6,9 +6,10 @@ import (
"math" "math"
"net" "net"
"net/http" "net/http"
"net/url"
"os" "os"
"strings"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/debugsvc" "github.com/AdguardTeam/AdGuardDNS/internal/debugsvc"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb" "github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll" "github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
@ -25,13 +26,17 @@ import (
) )
// environment represents the configuration that is kept in the environment. // environment represents the configuration that is kept in the environment.
//
// TODO(e.burkov, a.garipov): Name variables more consistently.
type environment struct { type environment struct {
AdultBlockingURL *urlutil.URL `env:"ADULT_BLOCKING_URL"` AdultBlockingURL *urlutil.URL `env:"ADULT_BLOCKING_URL"`
BackendRateLimitURL *urlutil.URL `env:"BACKEND_RATELIMIT_URL"`
BillStatURL *urlutil.URL `env:"BILLSTAT_URL"` BillStatURL *urlutil.URL `env:"BILLSTAT_URL"`
BlockedServiceIndexURL *urlutil.URL `env:"BLOCKED_SERVICE_INDEX_URL"` BlockedServiceIndexURL *urlutil.URL `env:"BLOCKED_SERVICE_INDEX_URL"`
ConsulAllowlistURL *urlutil.URL `env:"CONSUL_ALLOWLIST_URL,notEmpty"` ConsulAllowlistURL *urlutil.URL `env:"CONSUL_ALLOWLIST_URL"`
ConsulDNSCheckKVURL *urlutil.URL `env:"CONSUL_DNSCHECK_KV_URL"` ConsulDNSCheckKVURL *urlutil.URL `env:"CONSUL_DNSCHECK_KV_URL"`
ConsulDNSCheckSessionURL *urlutil.URL `env:"CONSUL_DNSCHECK_SESSION_URL"` ConsulDNSCheckSessionURL *urlutil.URL `env:"CONSUL_DNSCHECK_SESSION_URL"`
DNSCheckRemoteKVURL *urlutil.URL `env:"DNSCHECK_REMOTEKV_URL"`
FilterIndexURL *urlutil.URL `env:"FILTER_INDEX_URL,notEmpty"` FilterIndexURL *urlutil.URL `env:"FILTER_INDEX_URL,notEmpty"`
GeneralSafeSearchURL *urlutil.URL `env:"GENERAL_SAFE_SEARCH_URL"` GeneralSafeSearchURL *urlutil.URL `env:"GENERAL_SAFE_SEARCH_URL"`
LinkedIPTargetURL *urlutil.URL `env:"LINKED_IP_TARGET_URL"` LinkedIPTargetURL *urlutil.URL `env:"LINKED_IP_TARGET_URL"`
@ -41,23 +46,25 @@ type environment struct {
SafeBrowsingURL *urlutil.URL `env:"SAFE_BROWSING_URL"` SafeBrowsingURL *urlutil.URL `env:"SAFE_BROWSING_URL"`
YoutubeSafeSearchURL *urlutil.URL `env:"YOUTUBE_SAFE_SEARCH_URL"` YoutubeSafeSearchURL *urlutil.URL `env:"YOUTUBE_SAFE_SEARCH_URL"`
BillStatAPIKey string `env:"BILLSTAT_API_KEY"` BackendRateLimitAPIKey string `env:"BACKEND_RATELIMIT_API_KEY"`
ConfPath string `env:"CONFIG_PATH" envDefault:"./config.yaml"` BillStatAPIKey string `env:"BILLSTAT_API_KEY"`
FilterCachePath string `env:"FILTER_CACHE_PATH" envDefault:"./filters/"` ConfPath string `env:"CONFIG_PATH" envDefault:"./config.yaml"`
GeoIPASNPath string `env:"GEOIP_ASN_PATH" envDefault:"./asn.mmdb"` DNSCheckRemoteKVAPIKey string `env:"DNSCHECK_REMOTEKV_API_KEY"`
GeoIPCountryPath string `env:"GEOIP_COUNTRY_PATH" envDefault:"./country.mmdb"` FilterCachePath string `env:"FILTER_CACHE_PATH" envDefault:"./filters/"`
ProfilesAPIKey string `env:"PROFILES_API_KEY"` GeoIPASNPath string `env:"GEOIP_ASN_PATH" envDefault:"./asn.mmdb"`
ProfilesCachePath string `env:"PROFILES_CACHE_PATH" envDefault:"./profilecache.pb"` GeoIPCountryPath string `env:"GEOIP_COUNTRY_PATH" envDefault:"./country.mmdb"`
RedisAddr string `env:"REDIS_ADDR"` ProfilesAPIKey string `env:"PROFILES_API_KEY"`
RedisKeyPrefix string `env:"REDIS_KEY_PREFIX" envDefault:"agdns"` ProfilesCachePath string `env:"PROFILES_CACHE_PATH" envDefault:"./profilecache.pb"`
QueryLogPath string `env:"QUERYLOG_PATH" envDefault:"./querylog.jsonl"` RedisAddr string `env:"REDIS_ADDR"`
SSLKeyLogFile string `env:"SSL_KEY_LOG_FILE"` RedisKeyPrefix string `env:"REDIS_KEY_PREFIX" envDefault:"agdns"`
SentryDSN string `env:"SENTRY_DSN" envDefault:"stderr"` QueryLogPath string `env:"QUERYLOG_PATH" envDefault:"./querylog.jsonl"`
WebStaticDir string `env:"WEB_STATIC_DIR"` SSLKeyLogFile string `env:"SSL_KEY_LOG_FILE"`
SentryDSN string `env:"SENTRY_DSN" envDefault:"stderr"`
WebStaticDir string `env:"WEB_STATIC_DIR"`
ListenAddr net.IP `env:"LISTEN_ADDR" envDefault:"127.0.0.1"` ListenAddr net.IP `env:"LISTEN_ADDR" envDefault:"127.0.0.1"`
ProfilesMaxRespSize datasize.ByteSize `env:"PROFILES_MAX_RESP_SIZE" envDefault:"8MB"` ProfilesMaxRespSize datasize.ByteSize `env:"PROFILES_MAX_RESP_SIZE" envDefault:"64MB"`
RedisIdleTimeout timeutil.Duration `env:"REDIS_IDLE_TIMEOUT" envDefault:"30s"` RedisIdleTimeout timeutil.Duration `env:"REDIS_IDLE_TIMEOUT" envDefault:"30s"`
@ -100,7 +107,8 @@ func (envs *environment) validate() (err error) {
errs = envs.validateHTTPURLs(errs) errs = envs.validateHTTPURLs(errs)
if s := envs.FilterIndexURL.Scheme; s != agdhttp.SchemeFile && !agdhttp.CheckHTTPURLScheme(s) { if s := envs.FilterIndexURL.Scheme; !strings.EqualFold(s, urlutil.SchemeFile) &&
!urlutil.IsValidHTTPURLScheme(s) {
errs = append(errs, fmt.Errorf( errs = append(errs, fmt.Errorf(
"env %s: not a valid http(s) url or file uri", "env %s: not a valid http(s) url or file uri",
"FILTER_INDEX_URL", "FILTER_INDEX_URL",
@ -140,10 +148,6 @@ func (envs *environment) validateHTTPURLs(errs []error) (res []error) {
url: envs.BlockedServiceIndexURL, url: envs.BlockedServiceIndexURL,
name: "BLOCKED_SERVICE_INDEX_URL", name: "BLOCKED_SERVICE_INDEX_URL",
isRequired: bool(envs.BlockedServiceEnabled), isRequired: bool(envs.BlockedServiceEnabled),
}, {
url: envs.ConsulAllowlistURL,
name: "CONSUL_ALLOWLIST_URL",
isRequired: true,
}, { }, {
url: envs.ConsulDNSCheckKVURL, url: envs.ConsulDNSCheckKVURL,
name: "CONSUL_DNSCHECK_KV_URL", name: "CONSUL_DNSCHECK_KV_URL",
@ -184,15 +188,14 @@ func (envs *environment) validateHTTPURLs(errs []error) (res []error) {
continue continue
} }
u := urlData.url var u *url.URL
if u == nil { if urlData.url != nil {
res = append(res, fmt.Errorf("env %s: %w", urlData.name, errors.ErrEmptyValue)) u = &urlData.url.URL
continue
} }
if !agdhttp.CheckHTTPURLScheme(u.Scheme) { err := urlutil.ValidateHTTPURL(u)
res = append(res, fmt.Errorf("env %s: not a valid http(s) url", urlData.name)) if err != nil {
res = append(res, fmt.Errorf("env %s: %w", urlData.name, err))
} }
} }
@ -227,59 +230,86 @@ func (envs *environment) validateWebStaticDir() (err error) {
// validateFromValidConfig returns an error if environment variables that depend // validateFromValidConfig returns an error if environment variables that depend
// on configuration properties contain errors. conf is expected to be valid. // on configuration properties contain errors. conf is expected to be valid.
func (envs *environment) validateFromValidConfig(conf *configuration) (err error) { func (envs *environment) validateFromValidConfig(conf *configuration) (err error) {
err = envs.validateRedis(conf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
if !conf.isProfilesEnabled() {
return nil
}
if envs.ProfilesMaxRespSize > math.MaxInt {
return fmt.Errorf(
"PROFILES_MAX_RESP_SIZE: %w: must be less than or equal to %s, got %s",
errors.ErrOutOfRange,
datasize.ByteSize(math.MaxInt),
envs.ProfilesMaxRespSize,
)
}
return envs.validateProfilesURLs()
}
// validateRedis returns an error if environment variables for Redis as a remote
// key-value store for DNS server checking contain errors.
func (envs *environment) validateRedis(conf *configuration) (err error) {
if conf.Check.RemoteKV.Type != kvModeRedis {
return nil
}
var errs []error var errs []error
if envs.RedisAddr == "" {
errs = append(errs, fmt.Errorf("REDIS_ADDR: %q", errors.ErrEmptyValue)) switch typ := conf.Check.RemoteKV.Type; typ {
case kvModeRedis:
errs = envs.validateRedis(errs)
case kvModeBackend:
errs = envs.validateBackendKV(errs)
default:
// Probably consul.
} }
if envs.RedisIdleTimeout.Duration <= 0 { if conf.isProfilesEnabled() {
errs = append(errs, newNotPositiveError("REDIS_IDLE_TIMEOUT", envs.RedisIdleTimeout)) errs = envs.validateProfilesURLs(errs)
if envs.ProfilesMaxRespSize > math.MaxInt {
errs = append(errs, fmt.Errorf(
"PROFILES_MAX_RESP_SIZE: %w: must be less than or equal to %s, got %s",
errors.ErrOutOfRange,
datasize.ByteSize(math.MaxInt),
envs.ProfilesMaxRespSize,
))
}
} }
if envs.RedisMaxActive < 0 { errs = envs.validateRateLimitURLs(conf, errs)
errs = append(errs, newNegativeError("REDIS_MAX_ACTIVE", envs.RedisMaxActive))
}
if envs.RedisMaxIdle < 0 {
errs = append(errs, newNegativeError("REDIS_MAX_IDLE", envs.RedisMaxIdle))
}
return errors.Join(errs...) return errors.Join(errs...)
} }
// validateRedis appends validation errors to the given errs if environment
// variables for Redis contain errors.
func (envs *environment) validateRedis(errs []error) (withRedis []error) {
withRedis = errs
if envs.RedisAddr == "" {
err := fmt.Errorf("REDIS_ADDR: %q", errors.ErrEmptyValue)
withRedis = append(withRedis, err)
}
if envs.RedisIdleTimeout.Duration <= 0 {
err := newNotPositiveError("REDIS_IDLE_TIMEOUT", envs.RedisIdleTimeout)
withRedis = append(withRedis, err)
}
if envs.RedisMaxActive < 0 {
err := newNegativeError("REDIS_MAX_ACTIVE", envs.RedisMaxActive)
withRedis = append(withRedis, err)
}
if envs.RedisMaxIdle < 0 {
err := newNegativeError("REDIS_MAX_IDLE", envs.RedisMaxIdle)
withRedis = append(withRedis, err)
}
return withRedis
}
// validateBackendKV appends validation errors to the given errs if environment
// variables for a backend key-value store contain errors.
func (envs *environment) validateBackendKV(errs []error) (withKV []error) {
withKV = errs
var u *url.URL
if envs.DNSCheckRemoteKVURL != nil {
u = &envs.DNSCheckRemoteKVURL.URL
}
err := urlutil.ValidateGRPCURL(u)
if err != nil {
withKV = append(withKV, fmt.Errorf("env DNSCHECK_REMOTEKV_URL: %w", err))
}
return withKV
}
// validateProfilesURLs appends validation errors to the given errs if profiles // validateProfilesURLs appends validation errors to the given errs if profiles
// URLs in environment variables are invalid. All errors are appended to errs // URLs in environment variables are invalid.
// and returned as res. func (envs *environment) validateProfilesURLs(errs []error) (withURLs []error) {
func (envs *environment) validateProfilesURLs() (err error) { withURLs = errs
grpcOnlyURLs := []*urlEnvData{{ grpcOnlyURLs := []*urlEnvData{{
url: envs.BillStatURL, url: envs.BillStatURL,
name: "BILLSTAT_URL", name: "BILLSTAT_URL",
@ -290,24 +320,52 @@ func (envs *environment) validateProfilesURLs() (err error) {
isRequired: true, isRequired: true,
}} }}
var res []error
for _, urlData := range grpcOnlyURLs { for _, urlData := range grpcOnlyURLs {
if !urlData.isRequired { if !urlData.isRequired {
continue continue
} }
if urlData.url == nil { var u *url.URL
res = append(res, fmt.Errorf("env %s: %w", urlData.name, errors.ErrEmptyValue)) if urlData.url != nil {
u = &urlData.url.URL
continue
} }
if !agdhttp.CheckGRPCURLScheme(urlData.url.Scheme) { err := urlutil.ValidateGRPCURL(u)
res = append(res, fmt.Errorf("env %s: not a valid grpc(s) url", urlData.name)) if err != nil {
withURLs = append(withURLs, fmt.Errorf("env %s: %w", urlData.name, err))
} }
} }
return errors.Join(res...) return withURLs
}
// validateRateLimitURLs appends validation errors to the given errs if rate
// limit URLs in environment variables are invalid.
func (envs *environment) validateRateLimitURLs(
conf *configuration,
errs []error,
) (withURLs []error) {
rlURL := envs.BackendRateLimitURL
rlEnv := "BACKEND_RATELIMIT_URL"
validateFunc := urlutil.ValidateGRPCURL
if conf.RateLimit.Allowlist.Type == rlAllowlistTypeConsul {
rlURL = envs.ConsulAllowlistURL
rlEnv = "CONSUL_ALLOWLIST_URL"
validateFunc = urlutil.ValidateHTTPURL
}
var u *url.URL
if rlURL != nil {
u = &rlURL.URL
}
err := validateFunc(u)
if err != nil {
return append(errs, fmt.Errorf("env %s: %w", rlEnv, err))
}
return errs
} }
// configureLogs sets the configuration for the plain text logs. It also // configureLogs sets the configuration for the plain text logs. It also

View File

@ -58,6 +58,12 @@ type filtersConfig struct {
// MaxSize is the maximum size of the downloadable filtering rule-list. // MaxSize is the maximum size of the downloadable filtering rule-list.
MaxSize datasize.ByteSize `yaml:"max_size"` MaxSize datasize.ByteSize `yaml:"max_size"`
// EDEEnabled enables the Extended DNS Errors feature.
EDEEnabled bool `yaml:"ede_enabled"`
// SDEEnabled enables the experimental Structured DNS Errors feature.
SDEEnabled bool `yaml:"sde_enabled"`
} }
// toInternal converts c to the filter storage configuration for the DNS server. // toInternal converts c to the filter storage configuration for the DNS server.
@ -117,33 +123,31 @@ var _ validator = (*filtersConfig)(nil)
// validate implements the [validator] interface for *filtersConfig. // validate implements the [validator] interface for *filtersConfig.
func (c *filtersConfig) validate() (err error) { func (c *filtersConfig) validate() (err error) {
switch { if c == nil {
case c == nil:
return errors.ErrNoValue return errors.ErrNoValue
case c.SafeSearchCacheSize <= 0: }
return newNotPositiveError("safe_search_cache_size", c.SafeSearchCacheSize)
case c.ResponseTTL.Duration <= 0: errs := []error{
return newNotPositiveError("response_ttl", c.ResponseTTL) validatePositive("custom_filter_cache_size", c.CustomFilterCacheSize),
case c.RefreshIvl.Duration <= 0: validatePositive("safe_search_cache_size", c.SafeSearchCacheSize),
return newNotPositiveError("refresh_interval", c.RefreshIvl) validatePositive("response_ttl", c.ResponseTTL),
case c.RefreshTimeout.Duration <= 0: validatePositive("refresh_interval", c.RefreshIvl),
return newNotPositiveError("refresh_timeout", c.RefreshTimeout) validatePositive("refresh_timeout", c.RefreshTimeout),
case c.IndexRefreshTimeout.Duration <= 0: validatePositive("index_refresh_timeout", c.IndexRefreshTimeout),
return newNotPositiveError("index_refresh_timeout", c.IndexRefreshTimeout) validatePositive("rule_list_refresh_timeout", c.RuleListRefreshTimeout),
case c.RuleListRefreshTimeout.Duration <= 0: validatePositive("max_size", c.MaxSize),
return newNotPositiveError("rule_list_refresh_timeout", c.RuleListRefreshTimeout) }
case c.MaxSize <= 0:
return newNotPositiveError("max_size", c.MaxSize) if !c.EDEEnabled && c.SDEEnabled {
default: errs = append(errs, errors.Error("ede must be enabled to enable sde"))
// Go on.
} }
err = c.RuleListCache.validate() err = c.RuleListCache.validate()
if err != nil { if err != nil {
return fmt.Errorf("rule_list_cache: %w", err) errs = append(errs, fmt.Errorf("rule_list_cache: %w", err))
} }
return nil return errors.Join(errs...)
} }
// fltRuleListCache contains filtering rule-list cache configuration. // fltRuleListCache contains filtering rule-list cache configuration.

View File

@ -4,6 +4,7 @@ package plugin
import ( import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck" "github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/AdGuardDNS/internal/metrics"
) )
@ -13,16 +14,19 @@ import (
type Registry struct { type Registry struct {
dnscheck dnscheck.Interface dnscheck dnscheck.Interface
mainMwMtrc metrics.MainMiddleware mainMwMtrc metrics.MainMiddleware
postInitMw dnsserver.Middleware
} }
// NewRegistry returns a new registry with the given custom implementations. // NewRegistry returns a new registry with the given custom implementations.
func NewRegistry( func NewRegistry(
dnsCk dnscheck.Interface, dnsCk dnscheck.Interface,
mainMwMtrc metrics.MainMiddleware, mainMwMtrc metrics.MainMiddleware,
postInitMw dnsserver.Middleware,
) (r *Registry) { ) (r *Registry) {
return &Registry{ return &Registry{
dnscheck: dnsCk, dnscheck: dnsCk,
mainMwMtrc: mainMwMtrc, mainMwMtrc: mainMwMtrc,
postInitMw: postInitMw,
} }
} }
@ -44,3 +48,13 @@ func (r *Registry) MainMiddlewareMetrics() (mainMwMtrc metrics.MainMiddleware) {
return r.mainMwMtrc return r.mainMwMtrc
} }
// PostInitialMiddleware returns a custom implementation of the post-initial
// middleware, if any.
func (r *Registry) PostInitialMiddleware() (postInitMw dnsserver.Middleware) {
if r == nil {
return nil
}
return r.postInitMw
}

View File

@ -15,6 +15,12 @@ import (
"github.com/c2h5oh/datasize" "github.com/c2h5oh/datasize"
) )
// Constants for rate limit settings endpoints.
const (
rlAllowlistTypeBackend = "backend"
rlAllowlistTypeConsul = "consul"
)
// rateLimitConfig is the configuration of the instance's rate limiting. // rateLimitConfig is the configuration of the instance's rate limiting.
type rateLimitConfig struct { type rateLimitConfig struct {
// AllowList is the allowlist of clients. // AllowList is the allowlist of clients.
@ -57,20 +63,14 @@ type rateLimitConfig struct {
RefuseANY bool `yaml:"refuse_any"` RefuseANY bool `yaml:"refuse_any"`
} }
// allowListConfig is the consul allow list configuration.
type allowListConfig struct {
// List contains IPs and CIDRs.
List []netutil.Prefix `yaml:"list"`
// RefreshIvl time between two updates of allow list from the Consul URL.
RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
}
// rateLimitOptions allows define maximum number of requests for IPv4 or IPv6 // rateLimitOptions allows define maximum number of requests for IPv4 or IPv6
// addresses. // addresses.
type rateLimitOptions struct { type rateLimitOptions struct {
// RPS is the maximum number of requests per second. // Count is the maximum number of requests per interval.
RPS uint `yaml:"rps"` Count uint `yaml:"count"`
// Interval is the time during which to count the number of requests.
Interval timeutil.Duration `yaml:"interval"`
// SubnetKeyLen is the length of the subnet prefix used to calculate // SubnetKeyLen is the length of the subnet prefix used to calculate
// rate limiter bucket keys. // rate limiter bucket keys.
@ -87,7 +87,8 @@ func (o *rateLimitOptions) validate() (err error) {
} }
return cmp.Or( return cmp.Or(
validatePositive("rps", o.RPS), validatePositive("count", o.Count),
validatePositive("interval", o.Interval),
validatePositive("subnet_key_len", o.SubnetKeyLen), validatePositive("subnet_key_len", o.SubnetKeyLen),
) )
} }
@ -100,9 +101,11 @@ func (c *rateLimitConfig) toInternal(al ratelimit.Allowlist) (conf *ratelimit.Ba
ResponseSizeEstimate: c.ResponseSizeEstimate, ResponseSizeEstimate: c.ResponseSizeEstimate,
Duration: c.BackoffDuration.Duration, Duration: c.BackoffDuration.Duration,
Period: c.BackoffPeriod.Duration, Period: c.BackoffPeriod.Duration,
IPv4RPS: c.IPv4.RPS, IPv4Count: c.IPv4.Count,
IPv4Interval: c.IPv4.Interval.Duration,
IPv4SubnetKeyLen: c.IPv4.SubnetKeyLen, IPv4SubnetKeyLen: c.IPv4.SubnetKeyLen,
IPv6RPS: c.IPv6.RPS, IPv6Count: c.IPv6.Count,
IPv6Interval: c.IPv6.Interval.Duration,
IPv6SubnetKeyLen: c.IPv6.SubnetKeyLen, IPv6SubnetKeyLen: c.IPv6.SubnetKeyLen,
Count: c.BackoffCount, Count: c.BackoffCount,
RefuseANY: c.RefuseANY, RefuseANY: c.RefuseANY,
@ -114,14 +117,12 @@ var _ validator = (*rateLimitConfig)(nil)
// validate implements the [validator] interface for *rateLimitConfig. // validate implements the [validator] interface for *rateLimitConfig.
func (c *rateLimitConfig) validate() (err error) { func (c *rateLimitConfig) validate() (err error) {
switch { if c == nil {
case c == nil:
return errors.ErrNoValue return errors.ErrNoValue
case c.Allowlist == nil:
return fmt.Errorf("allowlist: %w", errors.ErrNoValue)
} }
return cmp.Or( return cmp.Or(
validateProp("allowlist", c.Allowlist.validate),
validateProp("connection_limit", c.ConnectionLimit.validate), validateProp("connection_limit", c.ConnectionLimit.validate),
validateProp("ipv4", c.IPv4.validate), validateProp("ipv4", c.IPv4.validate),
validateProp("ipv6", c.IPv6.validate), validateProp("ipv6", c.IPv6.validate),
@ -131,10 +132,41 @@ func (c *rateLimitConfig) validate() (err error) {
validatePositive("backoff_duration", c.BackoffDuration), validatePositive("backoff_duration", c.BackoffDuration),
validatePositive("backoff_period", c.BackoffPeriod), validatePositive("backoff_period", c.BackoffPeriod),
validatePositive("response_size_estimate", c.ResponseSizeEstimate), validatePositive("response_size_estimate", c.ResponseSizeEstimate),
validatePositive("allowlist.refresh_interval", c.Allowlist.RefreshIvl),
) )
} }
// allowListConfig is the consul allow list configuration.
type allowListConfig struct {
// Type defines where the rate limit settings are received from. Allowed
// values are [rlAllowlistTypeBackend] and [rlAllowlistTypeConsul].
Type string `yaml:"type"`
// List contains IPs and CIDRs.
List []netutil.Prefix `yaml:"list"`
// RefreshIvl time between two updates of allow list from the Consul URL.
RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
}
// type check
var _ validator = (*allowListConfig)(nil)
// validate implements the [validator] interface for *allowListConfig.
func (c *allowListConfig) validate() (err error) {
if c == nil {
return errors.ErrNoValue
}
switch c.Type {
case rlAllowlistTypeBackend, rlAllowlistTypeConsul:
// Go on.
default:
return fmt.Errorf("type: %w: %q", errors.ErrBadEnumValue, c.Type)
}
return validatePositive("refresh_interval", c.RefreshIvl)
}
// connLimitConfig is the configuration structure for the stream-connection // connLimitConfig is the configuration structure for the stream-connection
// limiter. // limiter.
type connLimitConfig struct { type connLimitConfig struct {

View File

@ -6,7 +6,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice" "github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig"
"github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
) )
@ -14,6 +14,7 @@ import (
// toInternal returns the configuration of DNS servers for a single server // toInternal returns the configuration of DNS servers for a single server
// group. srvs and other parts of the configuration must be valid. // group. srvs and other parts of the configuration must be valid.
func (srvs servers) toInternal( func (srvs servers) toInternal(
mtrc tlsconfig.Metrics,
tlsConfig *agd.TLS, tlsConfig *agd.TLS,
btdMgr *bindtodevice.Manager, btdMgr *bindtodevice.Manager,
ratelimitConf *rateLimitConfig, ratelimitConf *rateLimitConfig,
@ -68,8 +69,8 @@ func (srvs servers) toInternal(
tlsConf := tlsConfig.Conf.Clone() tlsConf := tlsConfig.Conf.Clone()
// Attach the functions that will count TLS handshake metrics. // Attach the functions that will count TLS handshake metrics.
tlsConf.GetConfigForClient = metrics.TLSMetricsBeforeHandshake(string(srv.Protocol)) tlsConf.GetConfigForClient = mtrc.BeforeHandshake(string(srv.Protocol))
tlsConf.VerifyConnection = metrics.TLSMetricsAfterHandshake( tlsConf.VerifyConnection = mtrc.AfterHandshake(
string(srv.Protocol), string(srv.Protocol),
srv.Name, srv.Name,
tlsConfig.DeviceDomains, tlsConfig.DeviceDomains,

View File

@ -1,11 +1,13 @@
package cmd package cmd
import ( import (
"context"
"fmt" "fmt"
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice" "github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig"
"github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
) )
@ -17,6 +19,8 @@ type serverGroups []*serverGroup
// toInternal returns the configuration for all server groups in the DNS // toInternal returns the configuration for all server groups in the DNS
// service. srvGrps and other parts of the configuration must be valid. // service. srvGrps and other parts of the configuration must be valid.
func (srvGrps serverGroups) toInternal( func (srvGrps serverGroups) toInternal(
ctx context.Context,
mtrc tlsconfig.Metrics,
messages *dnsmsg.Constructor, messages *dnsmsg.Constructor,
btdMgr *bindtodevice.Manager, btdMgr *bindtodevice.Manager,
fltGrps map[agd.FilteringGroupID]*agd.FilteringGroup, fltGrps map[agd.FilteringGroupID]*agd.FilteringGroup,
@ -32,7 +36,7 @@ func (srvGrps serverGroups) toInternal(
} }
var tlsConf *agd.TLS var tlsConf *agd.TLS
tlsConf, err = g.TLS.toInternal() tlsConf, err = g.TLS.toInternal(ctx, mtrc)
if err != nil { if err != nil {
return nil, fmt.Errorf("tls: %w", err) return nil, fmt.Errorf("tls: %w", err)
} }
@ -45,7 +49,13 @@ func (srvGrps serverGroups) toInternal(
ProfilesEnabled: g.ProfilesEnabled, ProfilesEnabled: g.ProfilesEnabled,
} }
svcSrvGrps[i].Servers, err = g.Servers.toInternal(tlsConf, btdMgr, ratelimitConf, dnsConf) svcSrvGrps[i].Servers, err = g.Servers.toInternal(
mtrc,
tlsConf,
btdMgr,
ratelimitConf,
dnsConf,
)
if err != nil { if err != nil {
return nil, fmt.Errorf("server group %q: %w", g.Name, err) return nil, fmt.Errorf("server group %q: %w", g.Name, err)
} }

View File

@ -10,7 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdservice" "github.com/AdguardTeam/AdGuardDNS/internal/agdservice"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll" "github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig"
"github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/logutil/slogutil"
) )
@ -19,6 +19,7 @@ import (
type ticketRotator struct { type ticketRotator struct {
logger *slog.Logger logger *slog.Logger
errColl errcoll.Interface errColl errcoll.Interface
mtrc tlsconfig.Metrics
confs map[*tls.Config][]string confs map[*tls.Config][]string
} }
@ -29,6 +30,7 @@ type ticketRotator struct {
func newTicketRotator( func newTicketRotator(
logger *slog.Logger, logger *slog.Logger,
errColl errcoll.Interface, errColl errcoll.Interface,
mtrc tlsconfig.Metrics,
grps []*agd.ServerGroup, grps []*agd.ServerGroup,
) (tr *ticketRotator) { ) (tr *ticketRotator) {
confs := map[*tls.Config][]string{} confs := map[*tls.Config][]string{}
@ -49,6 +51,7 @@ func newTicketRotator(
return &ticketRotator{ return &ticketRotator{
logger: logger.With(slogutil.KeyPrefix, "tickrot"), logger: logger.With(slogutil.KeyPrefix, "tickrot"),
errColl: errColl, errColl: errColl,
mtrc: mtrc,
confs: confs, confs: confs,
} }
} }
@ -81,7 +84,7 @@ func (r *ticketRotator) Refresh(ctx context.Context) (err error) {
var key [sessTickLen]byte var key [sessTickLen]byte
key, err = readSessionTicketKey(fileName) key, err = readSessionTicketKey(fileName)
if err != nil { if err != nil {
metrics.TLSSessionTicketsRotateStatus.Set(0) r.mtrc.SetSessionTicketRotationStatus(ctx, false)
return fmt.Errorf("session ticket for srv %s: %w", conf.ServerName, err) return fmt.Errorf("session ticket for srv %s: %w", conf.ServerName, err)
} }
@ -96,8 +99,7 @@ func (r *ticketRotator) Refresh(ctx context.Context) (err error) {
conf.SetSessionTicketKeys(keys) conf.SetSessionTicketKeys(keys)
} }
metrics.TLSSessionTicketsRotateStatus.Set(1) r.mtrc.SetSessionTicketRotationStatus(ctx, true)
metrics.TLSSessionTicketsRotateTime.SetToCurrentTime()
return nil return nil
} }

View File

@ -1,6 +1,7 @@
package cmd package cmd
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
@ -9,10 +10,9 @@ import (
"strings" "strings"
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig"
"github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/prometheus/client_golang/prometheus"
) )
// tlsConfig are the TLS settings of a DNS server, if any. // tlsConfig are the TLS settings of a DNS server, if any.
@ -37,12 +37,15 @@ type tlsConfig struct {
// toInternal converts c to the TLS configuration for a DNS server. c must be // toInternal converts c to the TLS configuration for a DNS server. c must be
// valid. // valid.
func (c *tlsConfig) toInternal() (conf *agd.TLS, err error) { func (c *tlsConfig) toInternal(
ctx context.Context,
mtrc tlsconfig.Metrics,
) (conf *agd.TLS, err error) {
if c == nil { if c == nil {
return nil, nil return nil, nil
} }
tlsConf, err := c.Certificates.toInternal() tlsConf, err := c.Certificates.toInternal(ctx, mtrc)
if err != nil { if err != nil {
return nil, fmt.Errorf("certificates: %w", err) return nil, fmt.Errorf("certificates: %w", err)
} }
@ -123,7 +126,10 @@ type tlsConfigCert struct {
type tlsConfigCerts []*tlsConfigCert type tlsConfigCerts []*tlsConfigCert
// toInternal converts certs to a TLS configuration. certs must be valid. // toInternal converts certs to a TLS configuration. certs must be valid.
func (certs tlsConfigCerts) toInternal() (conf *tls.Config, err error) { func (certs tlsConfigCerts) toInternal(
ctx context.Context,
mtrc tlsconfig.Metrics,
) (conf *tls.Config, err error) {
if len(certs) == 0 { if len(certs) == 0 {
return nil, nil return nil, nil
} }
@ -146,13 +152,8 @@ func (certs tlsConfigCerts) toInternal() (conf *tls.Config, err error) {
tlsCerts[i] = cert tlsCerts[i] = cert
authAlgo, subj := leaf.PublicKeyAlgorithm.String(), leaf.Subject.String() authAlgo, subj := leaf.PublicKeyAlgorithm.String(), leaf.Subject.String()
metrics.TLSCertificateInfo.With(prometheus.Labels{
"auth_algo": authAlgo, mtrc.SetCertificateInfo(ctx, authAlgo, subj, leaf.NotAfter)
"subject": subj,
}).Set(1)
metrics.TLSCertificateNotAfter.With(prometheus.Labels{
"subject": subj,
}).Set(float64(leaf.NotAfter.Unix()))
} }
return &tls.Config{ return &tls.Config{

View File

@ -1,6 +1,7 @@
package cmd package cmd
import ( import (
"context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"maps" "maps"
@ -13,6 +14,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck" "github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll" "github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
"github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig"
"github.com/AdguardTeam/AdGuardDNS/internal/websvc" "github.com/AdguardTeam/AdGuardDNS/internal/websvc"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/httphdr"
@ -62,9 +64,11 @@ type webConfig struct {
// toInternal converts c to the AdGuardDNS web service configuration. c must be // toInternal converts c to the AdGuardDNS web service configuration. c must be
// valid. // valid.
func (c *webConfig) toInternal( func (c *webConfig) toInternal(
ctx context.Context,
envs *environment, envs *environment,
dnsCk dnscheck.Interface, dnsCk dnscheck.Interface,
errColl errcoll.Interface, errColl errcoll.Interface,
mtrc tlsconfig.Metrics,
) (conf *websvc.Config, err error) { ) (conf *websvc.Config, err error) {
if c == nil { if c == nil {
return nil, nil return nil, nil
@ -83,7 +87,7 @@ func (c *webConfig) toInternal(
conf.RootRedirectURL = netutil.CloneURL(&c.RootRedirectURL.URL) conf.RootRedirectURL = netutil.CloneURL(&c.RootRedirectURL.URL)
} }
conf.LinkedIP, err = c.LinkedIP.toInternal(envs.LinkedIPTargetURL) conf.LinkedIP, err = c.LinkedIP.toInternal(ctx, mtrc, envs.LinkedIPTargetURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("converting linked_ip: %w", err) return nil, fmt.Errorf("converting linked_ip: %w", err)
} }
@ -107,7 +111,7 @@ func (c *webConfig) toInternal(
}} }}
for _, bp := range blockPages { for _, bp := range blockPages {
*bp.webConfPtr, err = bp.conf.toInternal() *bp.webConfPtr, err = bp.conf.toInternal(ctx, mtrc)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s: %w", bp.name, err) return nil, fmt.Errorf("%s: %w", bp.name, err)
} }
@ -119,7 +123,7 @@ func (c *webConfig) toInternal(
return nil, err return nil, err
} }
conf.NonDoHBind, err = c.NonDoHBind.toInternal() conf.NonDoHBind, err = c.NonDoHBind.toInternal(ctx, mtrc)
if err != nil { if err != nil {
return nil, fmt.Errorf("converting non_doh_bind: %w", err) return nil, fmt.Errorf("converting non_doh_bind: %w", err)
} }
@ -225,6 +229,8 @@ type linkedIPServer struct {
// toInternal converts s to a linkedIP server configuration. s must be valid. // toInternal converts s to a linkedIP server configuration. s must be valid.
func (s *linkedIPServer) toInternal( func (s *linkedIPServer) toInternal(
ctx context.Context,
mtrc tlsconfig.Metrics,
targetURL *urlutil.URL, targetURL *urlutil.URL,
) (srv *websvc.LinkedIPServer, err error) { ) (srv *websvc.LinkedIPServer, err error) {
if s == nil { if s == nil {
@ -232,7 +238,7 @@ func (s *linkedIPServer) toInternal(
} }
srv = &websvc.LinkedIPServer{} srv = &websvc.LinkedIPServer{}
srv.Bind, err = s.Bind.toInternal() srv.Bind, err = s.Bind.toInternal(ctx, mtrc)
if err != nil { if err != nil {
return nil, fmt.Errorf("converting bind: %w", err) return nil, fmt.Errorf("converting bind: %w", err)
} }
@ -280,7 +286,10 @@ type blockPageServer struct {
} }
// toInternal converts s to a block page server configuration. s must be valid. // toInternal converts s to a block page server configuration. s must be valid.
func (s *blockPageServer) toInternal() (conf *websvc.BlockPageServerConfig, err error) { func (s *blockPageServer) toInternal(
ctx context.Context,
mtrc tlsconfig.Metrics,
) (conf *websvc.BlockPageServerConfig, err error) {
if s == nil { if s == nil {
return nil, nil return nil, nil
} }
@ -289,7 +298,7 @@ func (s *blockPageServer) toInternal() (conf *websvc.BlockPageServerConfig, err
ContentFilePath: s.BlockPage, ContentFilePath: s.BlockPage,
} }
conf.Bind, err = s.Bind.toInternal() conf.Bind, err = s.Bind.toInternal(ctx, mtrc)
if err != nil { if err != nil {
return nil, fmt.Errorf("converting bind: %w", err) return nil, fmt.Errorf("converting bind: %w", err)
} }
@ -326,11 +335,14 @@ type bindData []*bindItem
// toInternal converts bd to bind data for the AdGuard DNS web service. bd must // toInternal converts bd to bind data for the AdGuard DNS web service. bd must
// be valid. // be valid.
func (bd bindData) toInternal() (data []*websvc.BindData, err error) { func (bd bindData) toInternal(
ctx context.Context,
mtrc tlsconfig.Metrics,
) (data []*websvc.BindData, err error) {
data = make([]*websvc.BindData, len(bd)) data = make([]*websvc.BindData, len(bd))
for i, d := range bd { for i, d := range bd {
data[i], err = d.toInternal() data[i], err = d.toInternal(ctx, mtrc)
if err != nil { if err != nil {
return nil, fmt.Errorf("bind data at index %d: %w", i, err) return nil, fmt.Errorf("bind data at index %d: %w", i, err)
} }
@ -369,8 +381,11 @@ type bindItem struct {
// toInternal converts i to bind data for the AdGuard DNS web service. i must // toInternal converts i to bind data for the AdGuard DNS web service. i must
// be valid. // be valid.
func (i *bindItem) toInternal() (data *websvc.BindData, err error) { func (i *bindItem) toInternal(
tlsConf, err := i.Certificates.toInternal() ctx context.Context,
mtrc tlsconfig.Metrics,
) (data *websvc.BindData, err error) {
tlsConf, err := i.Certificates.toInternal(ctx, mtrc)
if err != nil { if err != nil {
return nil, fmt.Errorf("certificates: %w", err) return nil, fmt.Errorf("certificates: %w", err)
} }

View File

@ -15,7 +15,6 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agdservice" "github.com/AdguardTeam/AdGuardDNS/internal/agdservice"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll" "github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil/urlutil" "github.com/AdguardTeam/golibs/netutil/urlutil"
) )
@ -28,6 +27,7 @@ type AllowlistUpdater struct {
http *agdhttp.Client http *agdhttp.Client
url *url.URL url *url.URL
errColl errcoll.Interface errColl errcoll.Interface
metrics Metrics
} }
// AllowlistUpdaterConfig is the configuration structure for the allowlist // AllowlistUpdaterConfig is the configuration structure for the allowlist
@ -45,6 +45,9 @@ type AllowlistUpdaterConfig struct {
// ErrColl is used to collect errors during refreshes. // ErrColl is used to collect errors during refreshes.
ErrColl errcoll.Interface ErrColl errcoll.Interface
// Metrics is used to collect allowlist statistics.
Metrics Metrics
// Timeout is the timeout for Consul queries. // Timeout is the timeout for Consul queries.
Timeout time.Duration Timeout time.Duration
} }
@ -60,6 +63,7 @@ func NewAllowlistUpdater(c *AllowlistUpdaterConfig) (upd *AllowlistUpdater) {
}), }),
url: c.ConsulURL, url: c.ConsulURL,
errColl: c.ErrColl, errColl: c.ErrColl,
metrics: c.Metrics,
} }
} }
@ -72,10 +76,7 @@ func (upd *AllowlistUpdater) Refresh(ctx context.Context) (err error) {
upd.logger.InfoContext(ctx, "refresh started") upd.logger.InfoContext(ctx, "refresh started")
defer upd.logger.InfoContext(ctx, "refresh finished") defer upd.logger.InfoContext(ctx, "refresh finished")
defer func() { defer func() { upd.metrics.SetStatus(ctx, err) }()
metrics.ConsulAllowlistUpdateTime.SetToCurrentTime()
metrics.SetStatusGauge(metrics.ConsulAllowlistUpdateStatus, err)
}()
consulNets, err := upd.loadConsul(ctx) consulNets, err := upd.loadConsul(ctx)
if err != nil { if err != nil {
@ -89,13 +90,11 @@ func (upd *AllowlistUpdater) Refresh(ctx context.Context) (err error) {
ctx, ctx,
"refresh successful", "refresh successful",
"num_records", len(consulNets), "num_records", len(consulNets),
"url", &urlutil.URL{ "url", urlutil.RedactUserinfo(upd.url),
URL: *upd.url,
},
) )
upd.allowlist.Update(consulNets) upd.allowlist.Update(consulNets)
metrics.ConsulAllowlistSize.Set(float64(len(consulNets))) upd.metrics.SetSize(ctx, len(consulNets))
return nil return nil
} }
@ -107,7 +106,7 @@ type consulRecord struct {
// loadConsul fetches, decodes, and returns the list of IP networks from consul. // loadConsul fetches, decodes, and returns the list of IP networks from consul.
func (upd *AllowlistUpdater) loadConsul(ctx context.Context) (nets []netip.Prefix, err error) { func (upd *AllowlistUpdater) loadConsul(ctx context.Context) (nets []netip.Prefix, err error) {
defer func() { err = errors.Annotate(err, "loading allowlist nets from %s: %w", upd.url) }() defer func() { err = errors.Annotate(err, "loading allowlist nets: %w") }()
httpResp, err := upd.http.Get(ctx, upd.url) httpResp, err := upd.http.Get(ctx, upd.url)
if err != nil { if err != nil {

View File

@ -83,6 +83,7 @@ func TestNewAllowlistUpdater(t *testing.T) {
Allowlist: al, Allowlist: al,
ConsulURL: u, ConsulURL: u,
ErrColl: agdtest.NewErrorCollector(), ErrColl: agdtest.NewErrorCollector(),
Metrics: consul.EmptyMetrics{},
Timeout: testTimeout, Timeout: testTimeout,
}) })
@ -128,6 +129,7 @@ func TestNewAllowlistUpdater(t *testing.T) {
Allowlist: al, Allowlist: al,
ConsulURL: u, ConsulURL: u,
ErrColl: errColl, ErrColl: errColl,
Metrics: consul.EmptyMetrics{},
Timeout: testTimeout, Timeout: testTimeout,
}) })
@ -161,6 +163,7 @@ func TestAllowlistUpdater_Refresh_deadline(t *testing.T) {
Allowlist: al, Allowlist: al,
ConsulURL: u, ConsulURL: u,
ErrColl: errColl, ErrColl: errColl,
Metrics: consul.EmptyMetrics{},
Timeout: testTimeout, Timeout: testTimeout,
}) })

View File

@ -0,0 +1,26 @@
package consul
import "context"
// Metrics is an interface that is used for the collection of the allowlist
// statistics.
type Metrics interface {
// SetSize sets the number of received subnets.
SetSize(ctx context.Context, n int)
// SetStatus sets the status and time of the allowlist refresh attempt.
SetStatus(ctx context.Context, err error)
}
// EmptyMetrics is the implementation of the [Metrics] interface that does
// nothing.
type EmptyMetrics struct{}
// type check
var _ Metrics = EmptyMetrics{}
// SetSize implements the [Metrics] interface for EmptyMetrics.
func (EmptyMetrics) SetSize(_ context.Context, _ int) {}
// SetStatus plements the [Metrics] interface for EmptyMetrics.
func (EmptyMetrics) SetStatus(_ context.Context, _ error) {}

View File

@ -15,6 +15,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/debugsvc" "github.com/AdguardTeam/AdGuardDNS/internal/debugsvc"
"github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil/httputil" "github.com/AdguardTeam/golibs/netutil/httputil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -92,7 +93,7 @@ func TestService_Start(t *testing.T) {
}) })
srvURL := &url.URL{ srvURL := &url.URL{
Scheme: agdhttp.SchemeHTTP, Scheme: urlutil.SchemeHTTP,
Host: addr, Host: addr,
} }

View File

@ -201,18 +201,21 @@ func (cc *RemoteKV) newInfo(ri *agd.RequestInfo) (inf *info) {
} }
// resp returns the corresponding response. // resp returns the corresponding response.
//
// TODO(e.burkov): Inspect the reason for using different message constructors
// for different DNS types, and consider using only one of them.
func (cc *RemoteKV) resp(ri *agd.RequestInfo, req *dns.Msg) (resp *dns.Msg, err error) { func (cc *RemoteKV) resp(ri *agd.RequestInfo, req *dns.Msg) (resp *dns.Msg, err error) {
qt := ri.QType qt := ri.QType
if qt != dns.TypeA && qt != dns.TypeAAAA { if qt != dns.TypeA && qt != dns.TypeAAAA {
return ri.Messages.NewMsgNODATA(req), nil return ri.Messages.NewRespRCode(req, dns.RcodeSuccess), nil
} }
if qt == dns.TypeA { if qt == dns.TypeA {
return cc.messages.NewIPRespMsg(req, cc.ipv4...) return cc.messages.NewRespIP(req, cc.ipv4...)
} }
return cc.messages.NewIPRespMsg(req, cc.ipv6...) return cc.messages.NewRespIP(req, cc.ipv6...)
} }
// type check // type check

View File

@ -17,6 +17,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck" "github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/remotekv" "github.com/AdguardTeam/AdGuardDNS/internal/remotekv"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -86,7 +87,7 @@ func TestConsul_ServeHTTP(t *testing.T) {
t.Run("hit", func(t *testing.T) { t.Run("hit", func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, (&url.URL{ r := httptest.NewRequest(http.MethodGet, (&url.URL{
Scheme: "http", Scheme: urlutil.SchemeHTTP,
Host: randomid + "-" + checkDomain, Host: randomid + "-" + checkDomain,
Path: "/dnscheck/test", Path: "/dnscheck/test",
}).String(), strings.NewReader("")) }).String(), strings.NewReader(""))
@ -103,7 +104,7 @@ func TestConsul_ServeHTTP(t *testing.T) {
t.Run("miss", func(t *testing.T) { t.Run("miss", func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, (&url.URL{ r := httptest.NewRequest(http.MethodGet, (&url.URL{
Scheme: "http", Scheme: urlutil.SchemeHTTP,
Host: "non" + randomid + "-" + checkDomain, Host: "non" + randomid + "-" + checkDomain,
Path: "/dnscheck/test", Path: "/dnscheck/test",
}).String(), strings.NewReader("")) }).String(), strings.NewReader(""))

View File

@ -18,13 +18,14 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb" "github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestDefault_ServeHTTP(t *testing.T) { func TestDefault_ServeHTTP(t *testing.T) {
const dname = "some-domain.name" const domain = "domain.example"
testIP := netip.MustParseAddr("1.2.3.4") testIP := netip.MustParseAddr("1.2.3.4")
@ -39,7 +40,7 @@ func TestDefault_ServeHTTP(t *testing.T) {
rcode, rcode,
dnsservertest.NewReq(name, qtype, dns.ClassINET), dnsservertest.NewReq(name, qtype, dns.ClassINET),
dnsservertest.SectionAnswer{ dnsservertest.SectionAnswer{
dnsservertest.NewA(dname, 0, testIP), dnsservertest.NewA(domain, 0, testIP),
}, },
) )
} }
@ -52,38 +53,38 @@ func TestDefault_ServeHTTP(t *testing.T) {
}{{ }{{
name: "single", name: "single",
msgs: []*dns.Msg{ msgs: []*dns.Msg{
newMsg(dns.RcodeSuccess, dname, dns.TypeA), newMsg(dns.RcodeSuccess, domain, dns.TypeA),
}, },
wantHdr: successHdr, wantHdr: successHdr,
wantResp: [][]byte{[]byte(dname + `,A,NOERROR,` + testIP.String() + `,1`)}, wantResp: [][]byte{[]byte(domain + `,A,NOERROR,` + testIP.String() + `,1`)},
}, { }, {
name: "existing", name: "existing",
msgs: []*dns.Msg{ msgs: []*dns.Msg{
newMsg(dns.RcodeSuccess, dname, dns.TypeA), newMsg(dns.RcodeSuccess, domain, dns.TypeA),
newMsg(dns.RcodeSuccess, dname, dns.TypeA), newMsg(dns.RcodeSuccess, domain, dns.TypeA),
}, },
wantHdr: successHdr, wantHdr: successHdr,
wantResp: [][]byte{[]byte(dname + `,A,NOERROR,` + testIP.String() + `,2`)}, wantResp: [][]byte{[]byte(domain + `,A,NOERROR,` + testIP.String() + `,2`)},
}, { }, {
name: "different", name: "different",
msgs: []*dns.Msg{ msgs: []*dns.Msg{
newMsg(dns.RcodeSuccess, dname, dns.TypeA), newMsg(dns.RcodeSuccess, domain, dns.TypeA),
newMsg(dns.RcodeSuccess, "sub."+dname, dns.TypeA), newMsg(dns.RcodeSuccess, "sub."+domain, dns.TypeA),
}, },
wantHdr: successHdr, wantHdr: successHdr,
wantResp: [][]byte{ wantResp: [][]byte{
[]byte("sub." + dname + `,A,NOERROR,` + testIP.String() + `,1`), []byte("sub." + domain + `,A,NOERROR,` + testIP.String() + `,1`),
[]byte(dname + `,A,NOERROR,` + testIP.String() + `,1`), []byte(domain + `,A,NOERROR,` + testIP.String() + `,1`),
}, },
}, { }, {
name: "non-recordable", name: "non-recordable",
msgs: []*dns.Msg{ msgs: []*dns.Msg{
// Not NOERROR. // Not NOERROR.
newMsg(dns.RcodeBadName, dname, dns.TypeA), newMsg(dns.RcodeBadName, domain, dns.TypeA),
// Not A/AAAA. // Not A/AAAA.
newMsg(dns.RcodeSuccess, dname, dns.TypeSRV), newMsg(dns.RcodeSuccess, domain, dns.TypeSRV),
// Android metrics. // Android metrics.
newMsg(dns.RcodeSuccess, dname+"-dnsotls-ds.metric.gstatic.com.", dns.TypeA), newMsg(dns.RcodeSuccess, domain+"-dnsotls-ds.metric.gstatic.com.", dns.TypeA),
}, },
wantHdr: successHdr, wantHdr: successHdr,
wantResp: [][]byte{}, wantResp: [][]byte{},
@ -109,7 +110,10 @@ func TestDefault_ServeHTTP(t *testing.T) {
r := httptest.NewRequest( r := httptest.NewRequest(
http.MethodGet, http.MethodGet,
(&url.URL{Scheme: "http", Host: "example.com"}).String(), (&url.URL{
Scheme: urlutil.SchemeHTTP,
Host: "dnsdb.example",
}).String(),
nil, nil,
) )
r.Header.Add(httphdr.AcceptEncoding, agdhttp.HdrValGzip) r.Header.Add(httphdr.AcceptEncoding, agdhttp.HdrValGzip)

View File

@ -15,6 +15,11 @@ type ConstructorConfig struct {
// Cloner used to clone DNS messages. It must not be nil. // Cloner used to clone DNS messages. It must not be nil.
Cloner *Cloner Cloner *Cloner
// StructuredErrors is the configuration for the experimental Structured DNS
// Errors feature. It must not be nil. If enabled,
// [ConstructorConfig.Enabled] should also be true.
StructuredErrors *StructuredDNSErrorsConfig
// BlockingMode is the blocking mode to use in // BlockingMode is the blocking mode to use in
// [Constructor.NewBlockedRespMsg]. It must not be nil. // [Constructor.NewBlockedRespMsg]. It must not be nil.
BlockingMode BlockingMode BlockingMode BlockingMode
@ -22,6 +27,9 @@ type ConstructorConfig struct {
// FilteredResponseTTL is the time-to-live value used for responses created // FilteredResponseTTL is the time-to-live value used for responses created
// by this message constructor. It must be non-negative. // by this message constructor. It must be non-negative.
FilteredResponseTTL time.Duration FilteredResponseTTL time.Duration
// EDEEnabled enables the addition of the Extended DNS Error (EDE) codes.
EDEEnabled bool
} }
// validate checks the configuration for errors. // validate checks the configuration for errors.
@ -33,13 +41,19 @@ func (conf *ConstructorConfig) validate() (err error) {
errs = append(errs, err) errs = append(errs, err)
} }
err = conf.StructuredErrors.validate(conf.EDEEnabled)
if err != nil {
err = fmt.Errorf("structured errors: %w", err)
errs = append(errs, err)
}
if conf.BlockingMode == nil { if conf.BlockingMode == nil {
err = fmt.Errorf("blocking mode: %w", errors.ErrNoValue) err = fmt.Errorf("blocking mode: %w", errors.ErrNoValue)
errs = append(errs, err) errs = append(errs, err)
} }
if conf.FilteredResponseTTL < 0 { if conf.FilteredResponseTTL < 0 {
err = fmt.Errorf("filtered response TTL: %w", errors.ErrNegative) err = fmt.Errorf("filtered response ttl: %w", errors.ErrNegative)
errs = append(errs, err) errs = append(errs, err)
} }
@ -51,7 +65,9 @@ func (conf *ConstructorConfig) validate() (err error) {
type Constructor struct { type Constructor struct {
cloner *Cloner cloner *Cloner
blockingMode BlockingMode blockingMode BlockingMode
sde string
fltRespTTL time.Duration fltRespTTL time.Duration
edeEnabled bool
} }
// NewConstructor returns a properly initialized constructor using conf. // NewConstructor returns a properly initialized constructor using conf.
@ -60,10 +76,17 @@ func NewConstructor(conf *ConstructorConfig) (c *Constructor, err error) {
return nil, fmt.Errorf("configuration: %w", err) return nil, fmt.Errorf("configuration: %w", err)
} }
var sde string
if sdeConf := conf.StructuredErrors; sdeConf.Enabled {
sde = sdeConf.iJSON()
}
return &Constructor{ return &Constructor{
cloner: conf.Cloner, cloner: conf.Cloner,
blockingMode: conf.BlockingMode, blockingMode: conf.BlockingMode,
sde: sde,
fltRespTTL: conf.FilteredResponseTTL, fltRespTTL: conf.FilteredResponseTTL,
edeEnabled: conf.EDEEnabled,
}, nil }, nil
} }
@ -72,156 +95,6 @@ func (c *Constructor) Cloner() (cloner *Cloner) {
return c.cloner return c.cloner
} }
// FilteredResponseTTL returns the TTL that the constructor uses to build
// blocked responses.
func (c *Constructor) FilteredResponseTTL() (ttl time.Duration) {
return c.fltRespTTL
}
// NewBlockedRespMsg returns a blocked DNS response message based on the
// constructor's blocking mode.
func (c *Constructor) NewBlockedRespMsg(req *dns.Msg) (msg *dns.Msg, err error) {
switch m := c.blockingMode.(type) {
case *BlockingModeCustomIP:
return c.newBlockedCustomIPRespMsg(req, m)
case *BlockingModeNullIP:
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA, dns.TypeAAAA:
return c.NewIPRespMsg(req, netip.Addr{})
default:
return c.NewMsgNODATA(req), nil
}
case *BlockingModeNXDOMAIN:
return c.NewMsgNXDOMAIN(req), nil
case *BlockingModeREFUSED:
return c.NewMsgREFUSED(req), nil
default:
// Consider unhandled sum type members as unrecoverable programmer
// errors.
panic(fmt.Errorf("unexpected type %T", c.blockingMode))
}
}
// newBlockedCustomIPRespMsg returns a blocked DNS response message with either
// the custom IPs from the blocking mode options or a NODATA one.
func (c *Constructor) newBlockedCustomIPRespMsg(
req *dns.Msg,
m *BlockingModeCustomIP,
) (msg *dns.Msg, err error) {
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA:
if len(m.IPv4) > 0 {
return c.NewIPRespMsg(req, m.IPv4...)
}
case dns.TypeAAAA:
if len(m.IPv6) > 0 {
return c.NewIPRespMsg(req, m.IPv6...)
}
default:
// Go on.
}
return c.NewMsgNODATA(req), nil
}
// NewIPRespMsg returns a DNS A or AAAA response message with the given IP
// addresses. If any IP address is nil, it is replaced by an unspecified (aka
// null) IP. The TTL is also set to c.FilteredResponseTTL.
func (c *Constructor) NewIPRespMsg(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA:
return c.newMsgA(req, ips...)
case dns.TypeAAAA:
return c.newMsgAAAA(req, ips...)
default:
return nil, fmt.Errorf("bad qtype for a or aaaa resp: %d", qt)
}
}
// NewCNAMEWithIPs generates a filtered response to req with CNAME record and
// provided ips. cname is the fully-qualified name and must not be empty, ips
// must be of the same family.
func (c *Constructor) NewCNAMEWithIPs(
req *dns.Msg,
cname string,
ips ...netip.Addr,
) (resp *dns.Msg, err error) {
resp = c.NewRespMsg(req)
resp.Answer = make([]dns.RR, 0, len(ips)+1)
resp.Answer = append(resp.Answer, c.NewAnswerCNAME(req, cname))
var ans dns.RR
for i, ip := range ips {
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA:
ans, err = c.NewAnswerA(cname, ip)
case dns.TypeAAAA:
ans, err = c.NewAnswerAAAA(cname, ip)
default:
return nil, fmt.Errorf("bad qtype for a or aaaa resp: %d", qt)
}
if err != nil {
return nil, fmt.Errorf("bad ip at idx %d: %w", i, err)
}
resp.Answer = append(resp.Answer, ans)
}
return resp, err
}
// NewMsgFORMERR returns a properly initialized FORMERR response.
func (c *Constructor) NewMsgFORMERR(req *dns.Msg) (resp *dns.Msg) {
return c.newMsgRCode(req, dns.RcodeFormatError)
}
// NewMsgNXDOMAIN returns a properly initialized NXDOMAIN response.
func (c *Constructor) NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) {
return c.newMsgRCode(req, dns.RcodeNameError)
}
// NewMsgREFUSED returns a properly initialized REFUSED response.
func (c *Constructor) NewMsgREFUSED(req *dns.Msg) (resp *dns.Msg) {
return c.newMsgRCode(req, dns.RcodeRefused)
}
// NewMsgSERVFAIL returns a properly initialized SERVFAIL response.
func (c *Constructor) NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) {
return c.newMsgRCode(req, dns.RcodeServerFailure)
}
// NewMsgNODATA returns a properly initialized NODATA response.
//
// See https://www.rfc-editor.org/rfc/rfc2308#section-2.2.
func (c *Constructor) NewMsgNODATA(req *dns.Msg) (resp *dns.Msg) {
return c.newMsgRCode(req, dns.RcodeSuccess)
}
// newMsgRCode returns a properly initialized response with the given RCode.
func (c *Constructor) newMsgRCode(req *dns.Msg, rc RCode) (resp *dns.Msg) {
resp = (&dns.Msg{}).SetRcode(req, int(rc))
resp.Ns = c.newSOARecords(req)
resp.RecursionAvailable = true
return resp
}
// NewTXTRespMsg returns a DNS TXT response message with the given strings as
// content. The TTL is also set to c.FilteredResponseTTL.
func (c *Constructor) NewTXTRespMsg(req *dns.Msg, strs ...string) (msg *dns.Msg, err error) {
ans, err := c.NewAnswerTXT(req, strs)
if err != nil {
return nil, err
}
msg = c.NewRespMsg(req)
msg.Answer = append(msg.Answer, ans)
return msg, nil
}
// AppendDebugExtra appends to response message a DNS TXT extra with CHAOS // AppendDebugExtra appends to response message a DNS TXT extra with CHAOS
// class. // class.
func (c *Constructor) AppendDebugExtra(req, resp *dns.Msg, str string) (err error) { func (c *Constructor) AppendDebugExtra(req, resp *dns.Msg, str string) (err error) {
@ -386,26 +259,23 @@ func (c *Constructor) newSOARecords(req *dns.Msg) (soaRecs []dns.RR) {
zone = req.Question[0].Name zone = req.Question[0].Name
} }
// TODO(a.garipov): A lot of this is copied from AdGuard Home and needs // TODO(a.garipov): A lot of this is copied from AdGuard Home and needs to
// to be inspected and refactored. // be inspected and refactored.
soa := &dns.SOA{ soa := &dns.SOA{
// values copied from verisign's nonexistent .com domain // Use values from verisign's nonexistent.com domain. Their exact
// their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers // values are not important in our use case because they are used for
// domain transfers between primary/secondary DNS servers.
Refresh: 1800, Refresh: 1800,
Retry: 900, Retry: 900,
Expire: 604800, Expire: 604800,
Minttl: 86400, Minttl: 86400,
// copied from AdGuard DNS // Copied from AdGuard DNS.
Ns: "fake-for-negative-caching.adguard.com.", Ns: "fake-for-negative-caching.adguard.com.",
Serial: 100500, Serial: 100500,
// rest is request-specific // Rest is request-specific.
Hdr: dns.RR_Header{ Hdr: c.newHdrWithClass(zone, dns.TypeSOA, dns.ClassINET),
Name: zone, // Zone will be appended later if it's not empty or ".".
Rrtype: dns.TypeSOA, Mbox: "hostmaster.",
Ttl: uint32(c.fltRespTTL.Seconds()),
Class: dns.ClassINET,
},
Mbox: "hostmaster.", // zone will be appended later if it's not empty or "."
} }
if len(zone) > 0 && zone[0] != '.' { if len(zone) > 0 && zone[0] != '.' {
@ -415,25 +285,10 @@ func (c *Constructor) newSOARecords(req *dns.Msg) (soaRecs []dns.RR) {
return []dns.RR{soa} return []dns.RR{soa}
} }
// NewRespMsg creates a DNS response for req and sets all necessary flags and
// fields. It also guarantees that req.Question will be not empty.
func (c *Constructor) NewRespMsg(req *dns.Msg) (resp *dns.Msg) {
resp = &dns.Msg{
MsgHdr: dns.MsgHdr{
RecursionAvailable: true,
},
Compress: true,
}
resp.SetReply(req)
return resp
}
// newMsgA returns a new DNS response with the given IPv4 addresses. If any IP // newMsgA returns a new DNS response with the given IPv4 addresses. If any IP
// address is nil, it is replaced by an unspecified (aka null) IP, 0.0.0.0. // address is nil, it is replaced by an unspecified (aka null) IP, 0.0.0.0.
func (c *Constructor) newMsgA(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) { func (c *Constructor) newMsgA(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
msg = c.NewRespMsg(req) msg = c.NewResp(req)
for i, ip := range ips { for i, ip := range ips {
var ans dns.RR var ans dns.RR
ans, err = c.NewAnswerA(req.Question[0].Name, ip) ans, err = c.NewAnswerA(req.Question[0].Name, ip)
@ -450,7 +305,7 @@ func (c *Constructor) newMsgA(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, er
// newMsgAAAA returns a new DNS response with the given IPv6 addresses. If any // newMsgAAAA returns a new DNS response with the given IPv6 addresses. If any
// IP address is nil, it is replaced by an unspecified (aka null) IP, [::]. // IP address is nil, it is replaced by an unspecified (aka null) IP, [::].
func (c *Constructor) newMsgAAAA(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) { func (c *Constructor) newMsgAAAA(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
msg = c.NewRespMsg(req) msg = c.NewResp(req)
for i, ip := range ips { for i, ip := range ips {
var ans dns.RR var ans dns.RR
ans, err = c.NewAnswerAAAA(req.Question[0].Name, ip) ans, err = c.NewAnswerAAAA(req.Question[0].Name, ip)

View File

@ -1,18 +1,16 @@
package dnsmsg_test package dnsmsg_test
import ( import (
"net" "net/url"
"net/netip"
"strings" "strings"
"testing" "testing"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest" "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
// newTXTExtra is a helper constructor of the expected extra data. // newTXTExtra is a helper constructor of the expected extra data.
@ -27,311 +25,88 @@ func newTXTExtra(ttl uint32, strs ...string) (extra []dns.RR) {
}} }}
} }
// newConstructor returns a new dnsmsg.Constructor with [testFltRespTTL]. func TestNewConstructor(t *testing.T) {
func newConstructor(tb testing.TB) (c *dnsmsg.Constructor) {
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: dnsmsg.NewCloner(dnsmsg.EmptyClonerStat{}),
BlockingMode: &dnsmsg.BlockingModeNullIP{},
FilteredResponseTTL: agdtest.FilteredResponseTTL,
})
require.NoError(tb, err)
return msgs
}
func TestConstructor_NewBlockedRespMsg_nullIP(t *testing.T) {
t.Parallel()
msgs := newConstructor(t)
testCases := []struct {
name string
wantAnsNum int
qt dnsmsg.RRType
}{{
name: "a",
wantAnsNum: 1,
qt: dns.TypeA,
}, {
name: "aaaa",
wantAnsNum: 1,
qt: dns.TypeAAAA,
}, {
name: "txt",
wantAnsNum: 0,
qt: dns.TypeTXT,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
req := dnsservertest.NewReq(testFQDN, tc.qt, dns.ClassINET)
resp, respErr := msgs.NewBlockedRespMsg(req)
require.NoError(t, respErr)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
if tc.wantAnsNum == 0 {
assert.Empty(t, resp.Answer)
require.Len(t, resp.Ns, 1)
nsTTL := resp.Ns[0].Header().Ttl
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), nsTTL)
} else {
require.Len(t, resp.Answer, 1)
ansTTL := resp.Answer[0].Header().Ttl
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), ansTTL)
}
})
}
}
func TestConstructor_NewBlockedRespMsg_customIP(t *testing.T) {
t.Parallel() t.Parallel()
cloner := agdtest.NewCloner() cloner := agdtest.NewCloner()
badContactURL := errors.Must(url.Parse("invalid-scheme://devteam@adguard.com"))
testCases := []struct {
blockingMode dnsmsg.BlockingMode
name string
wantA bool
wantAAAA bool
}{{
blockingMode: &dnsmsg.BlockingModeCustomIP{
IPv4: []netip.Addr{testIPv4},
IPv6: []netip.Addr{testIPv6},
},
name: "both",
wantA: true,
wantAAAA: true,
}, {
blockingMode: &dnsmsg.BlockingModeCustomIP{
IPv4: []netip.Addr{testIPv4},
},
name: "ipv4_only",
wantA: true,
wantAAAA: false,
}, {
blockingMode: &dnsmsg.BlockingModeCustomIP{
IPv6: []netip.Addr{testIPv6},
},
name: "ipv6_only",
wantA: false,
wantAAAA: true,
}, {
blockingMode: &dnsmsg.BlockingModeCustomIP{
IPv4: []netip.Addr{},
IPv6: []netip.Addr{},
},
name: "empty",
wantA: false,
wantAAAA: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: cloner,
BlockingMode: tc.blockingMode,
FilteredResponseTTL: agdtest.FilteredResponseTTL,
})
require.NoError(t, err)
reqA := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET)
respA, err := msgs.NewBlockedRespMsg(reqA)
require.NoError(t, err)
require.NotNil(t, respA)
assert.Equal(t, dns.RcodeSuccess, respA.Rcode)
if tc.wantA {
require.Len(t, respA.Answer, 1)
a := testutil.RequireTypeAssert[*dns.A](t, respA.Answer[0])
assert.Equal(t, net.IP(testIPv4.AsSlice()), a.A)
} else {
assert.Empty(t, respA.Answer)
}
reqAAAA := dnsservertest.NewReq(testFQDN, dns.TypeAAAA, dns.ClassINET)
respAAAA, err := msgs.NewBlockedRespMsg(reqAAAA)
require.NoError(t, err)
require.NotNil(t, respAAAA)
assert.Equal(t, dns.RcodeSuccess, respAAAA.Rcode)
if tc.wantAAAA {
require.Len(t, respAAAA.Answer, 1)
aaaa := testutil.RequireTypeAssert[*dns.AAAA](t, respAAAA.Answer[0])
assert.Equal(t, net.IP(testIPv6.AsSlice()), aaaa.AAAA)
} else {
assert.Empty(t, respAAAA.Answer)
}
})
}
}
func TestConstructor_NewBlockedRespMsg_noAnswer(t *testing.T) {
t.Parallel()
req := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET)
cloner := agdtest.NewCloner()
testCases := []struct {
blockingMode dnsmsg.BlockingMode
name string
rcode dnsmsg.RCode
}{{
blockingMode: &dnsmsg.BlockingModeNXDOMAIN{},
name: "nxdomain",
rcode: dns.RcodeNameError,
}, {
blockingMode: &dnsmsg.BlockingModeREFUSED{},
name: "refused",
rcode: dns.RcodeRefused,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: cloner,
BlockingMode: tc.blockingMode,
FilteredResponseTTL: agdtest.FilteredResponseTTL,
})
require.NoError(t, err)
resp, err := msgs.NewBlockedRespMsg(req)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, tc.rcode, dnsmsg.RCode(resp.Rcode))
assert.Empty(t, resp.Answer)
require.Len(t, resp.Ns, 1)
nsTTL := resp.Ns[0].Header().Ttl
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), nsTTL)
})
}
}
func TestConstructor_noAnswerMethods(t *testing.T) {
t.Parallel()
msgs := newConstructor(t)
req := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET)
testCases := []struct {
method func(req *dns.Msg) (resp *dns.Msg)
name string
want dnsmsg.RCode
}{{
method: msgs.NewMsgFORMERR,
name: "formerr",
want: dns.RcodeFormatError,
}, {
method: msgs.NewMsgNXDOMAIN,
name: "nxdomain",
want: dns.RcodeNameError,
}, {
method: msgs.NewMsgREFUSED,
name: "refused",
want: dns.RcodeRefused,
}, {
method: msgs.NewMsgSERVFAIL,
name: "servfail",
want: dns.RcodeServerFailure,
}, {
method: msgs.NewMsgNODATA,
name: "nodata",
want: dns.RcodeSuccess,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
resp := tc.method(req)
require.NotNil(t, resp)
require.Len(t, resp.Ns, 1)
assert.Empty(t, resp.Answer)
assert.Equal(t, tc.want, dnsmsg.RCode(resp.Rcode))
nsTTL := resp.Ns[0].Header().Ttl
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), nsTTL)
})
}
}
func TestConstructor_NewTXTRespMsg(t *testing.T) {
t.Parallel()
msgs := newConstructor(t)
req := dnsservertest.NewReq(testFQDN, dns.TypeTXT, dns.ClassINET)
tooLong := strings.Repeat("1", dnsmsg.MaxTXTStringLen+1)
testCases := []struct { testCases := []struct {
name string name string
conf *dnsmsg.ConstructorConfig
wantErrMsg string wantErrMsg string
strs []string
}{{ }{{
name: "success", name: "good",
conf: &dnsmsg.ConstructorConfig{
Cloner: cloner,
StructuredErrors: agdtest.NewSDEConfig(true),
BlockingMode: &dnsmsg.BlockingModeNullIP{},
FilteredResponseTTL: agdtest.FilteredResponseTTL,
EDEEnabled: true,
},
wantErrMsg: "", wantErrMsg: "",
strs: []string{"111"},
}, { }, {
name: "success_many", name: "all_bad",
wantErrMsg: "", conf: &dnsmsg.ConstructorConfig{
strs: []string{"111", "222"}, FilteredResponseTTL: -1,
},
wantErrMsg: "configuration: " +
"cloner: no value\n" +
"structured errors: no value\n" +
"blocking mode: no value\n" +
"filtered response ttl: negative value",
}, { }, {
name: "success_nil", name: "sde_enabled",
wantErrMsg: "", conf: &dnsmsg.ConstructorConfig{
strs: nil, Cloner: cloner,
StructuredErrors: agdtest.NewSDEConfig(true),
BlockingMode: &dnsmsg.BlockingModeNullIP{},
FilteredResponseTTL: agdtest.FilteredResponseTTL,
EDEEnabled: false,
},
wantErrMsg: "configuration: structured errors: " +
"ede must be enabled to enable sde",
}, { }, {
name: "success_empty", name: "sde_empty",
wantErrMsg: "", conf: &dnsmsg.ConstructorConfig{
strs: []string{}, Cloner: cloner,
StructuredErrors: &dnsmsg.StructuredDNSErrorsConfig{
Enabled: true,
},
BlockingMode: &dnsmsg.BlockingModeNullIP{},
FilteredResponseTTL: agdtest.FilteredResponseTTL,
EDEEnabled: true,
},
wantErrMsg: "configuration: structured errors: " +
"contact data: empty value\n" +
"justification: empty value",
}, { }, {
name: "too_long", name: "sde_bad",
wantErrMsg: "txt string at index 0: too long: got 256 bytes, max 255", conf: &dnsmsg.ConstructorConfig{
strs: []string{tooLong}, Cloner: cloner,
StructuredErrors: &dnsmsg.StructuredDNSErrorsConfig{
Enabled: true,
Contact: []*url.URL{badContactURL, nil},
Justification: "\uFFFE",
Organization: "\uFFFE",
},
BlockingMode: &dnsmsg.BlockingModeNullIP{},
FilteredResponseTTL: agdtest.FilteredResponseTTL,
EDEEnabled: true,
},
wantErrMsg: "configuration: structured errors: " +
`contact data: at index 0: scheme: bad enum value: "invalid-scheme"` + "\n" +
"contact data: at index 1: no value\n" +
"justification: bad code point at index 0\n" +
"organization: bad code point at index 0",
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
resp, respErr := msgs.NewTXTRespMsg(req, tc.strs...) _, err := dnsmsg.NewConstructor(tc.conf)
testutil.AssertErrorMsg(t, tc.wantErrMsg, respErr) testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
if tc.wantErrMsg != "" {
return
}
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
ans := resp.Answer[0]
ansTTL := ans.Header().Ttl
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), ansTTL)
txt := testutil.RequireTypeAssert[*dns.TXT](t, ans)
assert.Equal(t, tc.strs, txt.Txt)
}) })
} }
} }
@ -339,7 +114,7 @@ func TestConstructor_NewTXTRespMsg(t *testing.T) {
func TestConstructor_AppendDebugExtra(t *testing.T) { func TestConstructor_AppendDebugExtra(t *testing.T) {
t.Parallel() t.Parallel()
msgs := newConstructor(t) msgs := agdtest.NewConstructor(t)
shortText := "This is a short test text" shortText := "This is a short test text"
longText := strings.Repeat("a", 2*dnsmsg.MaxTXTStringLen) longText := strings.Repeat("a", 2*dnsmsg.MaxTXTStringLen)

View File

@ -78,7 +78,9 @@ func TestClone(t *testing.T) {
}, },
name: "empty_slice_ans", name: "empty_slice_ans",
}, { }, {
msg: dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET), msg: dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET, dnsservertest.SectionExtra{
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
}),
name: "a", name: "a",
}} }}

View File

@ -72,8 +72,7 @@ func (c *optCloner) clone(rr *dns.OPT) (clone *dns.OPT, full bool) {
optClone = opt optClone = opt
case *dns.EDNS0_EDE: case *dns.EDNS0_EDE:
opt := c.ede.Get() opt := c.ede.Get()
opt.InfoCode = orig.InfoCode *opt = *orig
opt.ExtraText = orig.ExtraText
optClone = opt optClone = opt
case *dns.EDNS0_SUBNET: case *dns.EDNS0_SUBNET:

228
internal/dnsmsg/response.go Normal file
View File

@ -0,0 +1,228 @@
package dnsmsg
import (
"fmt"
"net/netip"
"github.com/miekg/dns"
)
// NewResp creates a response DNS message for req and sets all necessary flags
// and fields. resp contains no resource records.
func (c *Constructor) NewResp(req *dns.Msg) (resp *dns.Msg) {
return (&dns.Msg{
MsgHdr: dns.MsgHdr{
RecursionAvailable: true,
},
Compress: true,
}).SetReply(req)
}
// NewBlockedResp returns a blocked response DNS message based on the
// constructor's blocking mode.
func (c *Constructor) NewBlockedResp(req *dns.Msg) (msg *dns.Msg, err error) {
switch m := c.blockingMode.(type) {
case *BlockingModeCustomIP:
return c.newBlockedCustomIPResp(req, m)
case *BlockingModeNullIP:
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA, dns.TypeAAAA:
return c.NewBlockedNullIPResp(req)
default:
msg = c.NewBlockedRespRCode(req, dns.RcodeSuccess)
msg.Ns = c.newSOARecords(req)
}
case *BlockingModeNXDOMAIN:
msg = c.NewBlockedRespRCode(req, dns.RcodeNameError)
msg.Ns = c.newSOARecords(req)
case *BlockingModeREFUSED:
msg = c.NewBlockedRespRCode(req, dns.RcodeRefused)
msg.Ns = c.newSOARecords(req)
default:
// Consider unhandled sum type members as unrecoverable programmer
// errors.
panic(fmt.Errorf("unexpected type %T", c.blockingMode))
}
return msg, nil
}
// NewRespRCode returns a response DNS message with given response code and a
// predefined authority section.
//
// Use [dns.RcodeSuccess] for a proper NODATA response, see
// https://www.rfc-editor.org/rfc/rfc2308#section-2.2.
func (c *Constructor) NewRespRCode(req *dns.Msg, rc RCode) (resp *dns.Msg) {
resp = c.NewResp(req)
resp.Rcode = int(rc)
resp.Ns = c.newSOARecords(req)
return resp
}
// NewBlockedRespRCode returns a blocked response DNS message with given
// response code.
//
// TODO(e.burkov): Add SOA records to the response, like in
// [Constructor.NewRespRCode].
func (c *Constructor) NewBlockedRespRCode(req *dns.Msg, rc RCode) (resp *dns.Msg) {
resp = c.NewResp(req)
resp.Rcode = int(rc)
c.AddEDE(req, resp, dns.ExtendedErrorCodeFiltered)
return resp
}
// NewRespTXT returns a DNS TXT response message with the given strings as
// content. The TTL of the TXT answer is set to c.FilteredResponseTTL.
func (c *Constructor) NewRespTXT(req *dns.Msg, strs ...string) (msg *dns.Msg, err error) {
ans, err := c.NewAnswerTXT(req, strs)
if err != nil {
return nil, err
}
msg = c.NewResp(req)
msg.Answer = append(msg.Answer, ans)
return msg, nil
}
// NewRespIP returns an A or AAAA DNS response message with the given IP
// addresses. If any IP address is nil, it is replaced by an unspecified (aka
// null) IP. The TTL is also set to c.FilteredResponseTTL.
func (c *Constructor) NewRespIP(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA:
return c.newMsgA(req, ips...)
case dns.TypeAAAA:
return c.newMsgAAAA(req, ips...)
default:
return nil, fmt.Errorf("bad qtype for a or aaaa resp: %d", qt)
}
}
// NewBlockedRespIP returns an A or AAAA DNS response message with the given IP
// addresses. The TTL of each record is set to c.FilteredResponseTTL. ips
// should not contain zero values due to the extended error code semantics, use
// [NewBlockedNullIPResp] for this case.
//
// TODO(a.garipov): Consider merging with [NewRespIP] if AddEDE with the Forged
// Answer code isn't used again.
func (c *Constructor) NewBlockedRespIP(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA:
msg, err = c.newMsgA(req, ips...)
case dns.TypeAAAA:
msg, err = c.newMsgAAAA(req, ips...)
default:
return nil, fmt.Errorf("bad qtype for an ip resp: %d", qt)
}
if err != nil {
return nil, err
}
return msg, nil
}
// NewBlockedNullIPResp returns a blocked A or AAAA DNS response message with an
// unspecified (aka null) IP address. The TTL of the record is set to the
// constructor's FilteredResponseTTL.
func (c *Constructor) NewBlockedNullIPResp(req *dns.Msg) (resp *dns.Msg, err error) {
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA:
resp, err = c.newMsgA(req, netip.Addr{})
case dns.TypeAAAA:
resp, err = c.newMsgAAAA(req, netip.Addr{})
default:
err = fmt.Errorf("bad qtype for an ip resp: %d", qt)
}
if err != nil {
return nil, err
}
c.AddEDE(req, resp, dns.ExtendedErrorCodeFiltered)
return resp, nil
}
// AddEDE adds an Extended DNS Error (EDE) option to the blocked response
// message, if the feature is enabled in the Constructor and the request
// indicates EDNS support. It does not overwrite EDE if there already is one.
// req and resp must not be nil.
func (c *Constructor) AddEDE(req, resp *dns.Msg, code uint16) {
if !c.edeEnabled {
return
}
reqOpt := req.IsEdns0()
if reqOpt == nil {
// Requestor doesn't implement EDNS, see
// https://datatracker.ietf.org/doc/html/rfc6891#section-7.
return
}
respOpt := resp.IsEdns0()
if respOpt == nil {
respOpt = newOPT(c.cloner, reqOpt.UDPSize(), reqOpt.Do())
resp.Extra = append(resp.Extra, respOpt)
} else if findEDE(respOpt) != nil {
// Do not add an EDE option if there already is one.
return
}
sdeText := c.sdeForReqOpt(reqOpt)
respOpt.Option = append(respOpt.Option, newEDNS0EDE(c.cloner, code, sdeText))
}
// findEDE returns the EDE option if there is one. opt must not be nil.
func findEDE(opt *dns.OPT) (ede *dns.EDNS0_EDE) {
for _, o := range opt.Option {
var ok bool
if ede, ok = o.(*dns.EDNS0_EDE); ok {
return ede
}
}
return nil
}
// sdeForReqOpt returns either the configured SDE text or empty string depending
// on the request's EDNS options.
func (c *Constructor) sdeForReqOpt(reqOpt *dns.OPT) (sde string) {
ede := findEDE(reqOpt)
if ede != nil && ede.InfoCode == 0 && ede.ExtraText == "" {
return c.sde
}
return ""
}
// newBlockedCustomIPResp returns a blocked DNS response message with either the
// custom IPs from the blocking mode options or a NODATA one.
func (c *Constructor) newBlockedCustomIPResp(
req *dns.Msg,
m *BlockingModeCustomIP,
) (msg *dns.Msg, err error) {
switch qt := req.Question[0].Qtype; qt {
case dns.TypeA:
if len(m.IPv4) > 0 {
return c.NewBlockedRespIP(req, m.IPv4...)
}
case dns.TypeAAAA:
if len(m.IPv6) > 0 {
return c.NewBlockedRespIP(req, m.IPv6...)
}
default:
// Go on.
}
msg = c.NewBlockedRespRCode(req, dns.RcodeSuccess)
msg.Ns = c.newSOARecords(req)
return msg, nil
}

View File

@ -0,0 +1,416 @@
package dnsmsg_test
import (
"net/netip"
"strings"
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestConstructor_NewBlockedResp_nullIP(t *testing.T) {
t.Parallel()
msgs := agdtest.NewConstructor(t)
reqExtra := dnsservertest.SectionExtra{
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
}
filteredSDE := dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{
InfoCode: dns.ExtendedErrorCodeFiltered,
ExtraText: agdtest.SDEText,
})
testCases := []struct {
name string
wantAns []dns.RR
wantExtra []dns.RR
qt dnsmsg.RRType
}{{
name: "a",
wantAns: []dns.RR{dnsservertest.NewA(
testFQDN, agdtest.FilteredResponseTTLSec, netip.IPv4Unspecified(),
)},
wantExtra: []dns.RR{filteredSDE},
qt: dns.TypeA,
}, {
name: "aaaa",
wantAns: []dns.RR{dnsservertest.NewAAAA(
testFQDN, agdtest.FilteredResponseTTLSec, netip.IPv6Unspecified(),
)},
wantExtra: []dns.RR{filteredSDE},
qt: dns.TypeAAAA,
}, {
name: "txt",
wantAns: nil,
wantExtra: []dns.RR{filteredSDE},
qt: dns.TypeTXT,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
req := dnsservertest.NewReq(testFQDN, tc.qt, dns.ClassINET, reqExtra)
resp, respErr := msgs.NewBlockedResp(req)
require.NoError(t, respErr)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Equal(t, tc.wantAns, resp.Answer)
assert.Equal(t, tc.wantExtra, resp.Extra)
})
}
}
func TestConstructor_NewBlockedResp_customIP(t *testing.T) {
t.Parallel()
cloner := agdtest.NewCloner()
// TODO(a.garipov): Test the forged extra as well if the EDE with that code
// is used again.
reqExtra := dnsservertest.SectionExtra{
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
}
filteredExtra := dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{
InfoCode: dns.ExtendedErrorCodeFiltered,
ExtraText: agdtest.SDEText,
})
ansA := dnsservertest.NewA(testFQDN, agdtest.FilteredResponseTTLSec, testIPv4)
ansAAAA := dnsservertest.NewAAAA(testFQDN, agdtest.FilteredResponseTTLSec, testIPv6)
testCases := []struct {
blockingMode dnsmsg.BlockingMode
name string
wantAnsA []dns.RR
wantAnsAAAA []dns.RR
wantExtraA []dns.RR
wantExtraAAAA []dns.RR
}{{
blockingMode: &dnsmsg.BlockingModeCustomIP{
IPv4: []netip.Addr{testIPv4},
IPv6: []netip.Addr{testIPv6},
},
name: "both",
wantAnsA: []dns.RR{ansA},
wantAnsAAAA: []dns.RR{ansAAAA},
wantExtraA: nil,
wantExtraAAAA: nil,
}, {
blockingMode: &dnsmsg.BlockingModeCustomIP{
IPv4: []netip.Addr{testIPv4},
},
name: "ipv4_only",
wantAnsA: []dns.RR{ansA},
wantAnsAAAA: nil,
wantExtraA: nil,
wantExtraAAAA: []dns.RR{filteredExtra},
}, {
blockingMode: &dnsmsg.BlockingModeCustomIP{
IPv6: []netip.Addr{testIPv6},
},
name: "ipv6_only",
wantAnsA: nil,
wantAnsAAAA: []dns.RR{ansAAAA},
wantExtraA: []dns.RR{filteredExtra},
wantExtraAAAA: nil,
}, {
blockingMode: &dnsmsg.BlockingModeCustomIP{
IPv4: []netip.Addr{},
IPv6: []netip.Addr{},
},
name: "empty",
wantAnsA: nil,
wantAnsAAAA: nil,
wantExtraA: []dns.RR{filteredExtra},
wantExtraAAAA: []dns.RR{filteredExtra},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: cloner,
BlockingMode: tc.blockingMode,
StructuredErrors: agdtest.NewSDEConfig(true),
FilteredResponseTTL: agdtest.FilteredResponseTTL,
EDEEnabled: true,
})
require.NoError(t, err)
t.Run("a", func(t *testing.T) {
t.Parallel()
req := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET, reqExtra)
resp, respErr := msgs.NewBlockedResp(req)
require.NoError(t, respErr)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Equal(t, tc.wantAnsA, resp.Answer)
assert.Equal(t, tc.wantExtraA, resp.Extra)
})
t.Run("aaaa", func(t *testing.T) {
t.Parallel()
req := dnsservertest.NewReq(testFQDN, dns.TypeAAAA, dns.ClassINET, reqExtra)
resp, respErr := msgs.NewBlockedResp(req)
require.NoError(t, respErr)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Equal(t, tc.wantAnsAAAA, resp.Answer)
assert.Equal(t, tc.wantExtraAAAA, resp.Extra)
})
})
}
}
func TestConstructor_NewBlockedResp_nodata(t *testing.T) {
t.Parallel()
req := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET, dnsservertest.SectionExtra{
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
})
cloner := agdtest.NewCloner()
wantExtra := []dns.RR{dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{
InfoCode: dns.ExtendedErrorCodeFiltered,
ExtraText: agdtest.SDEText,
})}
testCases := []struct {
blockingMode dnsmsg.BlockingMode
name string
rcode dnsmsg.RCode
}{{
blockingMode: &dnsmsg.BlockingModeNXDOMAIN{},
name: "nxdomain",
rcode: dns.RcodeNameError,
}, {
blockingMode: &dnsmsg.BlockingModeREFUSED{},
name: "refused",
rcode: dns.RcodeRefused,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: cloner,
BlockingMode: tc.blockingMode,
StructuredErrors: agdtest.NewSDEConfig(true),
FilteredResponseTTL: agdtest.FilteredResponseTTL,
EDEEnabled: true,
})
require.NoError(t, err)
resp, err := msgs.NewBlockedResp(req)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, tc.rcode, dnsmsg.RCode(resp.Rcode))
assert.Empty(t, resp.Answer)
require.Len(t, resp.Ns, 1)
nsTTL := resp.Ns[0].Header().Ttl
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), nsTTL)
assert.Equal(t, wantExtra, resp.Extra)
})
}
}
func TestConstructor_NewBlockedResp_sde(t *testing.T) {
t.Parallel()
reqEDNS := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET, dnsservertest.SectionExtra{
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
})
reqNoEDNS := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET)
wantAns := []dns.RR{
dnsservertest.NewA(testFQDN, agdtest.FilteredResponseTTLSec, netip.IPv4Unspecified()),
}
testCases := []struct {
req *dns.Msg
sde *dnsmsg.StructuredDNSErrorsConfig
name string
wantExtra []dns.RR
ede bool
}{{
req: reqEDNS,
sde: agdtest.NewSDEConfig(true),
name: "ede_sde",
wantExtra: []dns.RR{
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{
InfoCode: dns.ExtendedErrorCodeFiltered,
ExtraText: agdtest.SDEText,
}),
},
ede: true,
}, {
req: reqEDNS,
sde: agdtest.NewSDEConfig(false),
name: "ede_no_sde",
wantExtra: []dns.RR{
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{
InfoCode: dns.ExtendedErrorCodeFiltered,
}),
},
ede: true,
}, {
req: reqEDNS,
sde: agdtest.NewSDEConfig(false),
name: "no_ede",
wantExtra: nil,
ede: false,
}, {
req: reqNoEDNS,
sde: agdtest.NewSDEConfig(true),
name: "unsupported_ede_sde",
wantExtra: nil,
ede: true,
}, {
req: reqNoEDNS,
sde: agdtest.NewSDEConfig(false),
name: "unsupported_ede_no_sde",
wantExtra: nil,
ede: true,
}, {
req: reqNoEDNS,
sde: agdtest.NewSDEConfig(false),
name: "unsupported_no_ede",
wantExtra: nil,
ede: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: agdtest.NewCloner(),
BlockingMode: &dnsmsg.BlockingModeNullIP{},
StructuredErrors: tc.sde,
FilteredResponseTTL: agdtest.FilteredResponseTTL,
EDEEnabled: tc.ede,
})
require.NoError(t, err)
resp, err := msgs.NewBlockedResp(tc.req)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Equal(t, wantAns, resp.Answer)
assert.Equal(t, tc.wantExtra, resp.Extra)
})
}
}
func TestConstructor_NewRespRCode(t *testing.T) {
t.Parallel()
msgs := agdtest.NewConstructor(t)
req := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET, dnsservertest.SectionExtra{
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
})
for rcode, name := range dns.RcodeToString {
t.Run(name, func(t *testing.T) {
t.Parallel()
resp := msgs.NewRespRCode(req, dnsmsg.RCode(rcode))
require.NotNil(t, resp)
require.Empty(t, resp.Answer)
assert.Equal(t, rcode, resp.Rcode)
require.Len(t, resp.Ns, 1)
nsTTL := resp.Ns[0].Header().Ttl
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), nsTTL)
assert.Empty(t, resp.Extra)
})
}
}
func TestConstructor_NewRespTXT(t *testing.T) {
t.Parallel()
msgs := agdtest.NewConstructor(t)
req := dnsservertest.NewReq(testFQDN, dns.TypeTXT, dns.ClassINET, dnsservertest.SectionExtra{
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
})
tooLong := strings.Repeat("1", dnsmsg.MaxTXTStringLen+1)
testCases := []struct {
name string
wantErrMsg string
strs []string
}{{
name: "success",
wantErrMsg: "",
strs: []string{"111"},
}, {
name: "success_many",
wantErrMsg: "",
strs: []string{"111", "222"},
}, {
name: "success_nil",
wantErrMsg: "",
strs: nil,
}, {
name: "success_empty",
wantErrMsg: "",
strs: []string{},
}, {
name: "too_long",
wantErrMsg: "txt string at index 0: too long: got 256 bytes, max 255",
strs: []string{tooLong},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
resp, respErr := msgs.NewRespTXT(req, tc.strs...)
testutil.AssertErrorMsg(t, tc.wantErrMsg, respErr)
if tc.wantErrMsg != "" {
return
}
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
ans := resp.Answer[0]
txt := testutil.RequireTypeAssert[*dns.TXT](t, ans)
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), txt.Hdr.Ttl)
assert.Equal(t, tc.strs, txt.Txt)
assert.Empty(t, resp.Extra)
})
}
}

View File

@ -140,3 +140,36 @@ func newTXT(c *Cloner, txt []string) (rr *dns.TXT) {
return rr return rr
} }
// newOPT constructs a new resource record of type OPT, optionally using c to
// allocate the structure.
func newOPT(c *Cloner, udpSize uint16, doBit bool) (opt *dns.OPT) {
if c == nil {
opt = &dns.OPT{}
} else {
opt = c.opt.rr.Get()
opt.Option = opt.Option[:0]
}
opt.Hdr.Name = "."
opt.Hdr.Rrtype = dns.TypeOPT
opt.SetUDPSize(udpSize)
opt.SetDo(doBit)
return opt
}
// newEDNS0EDE constructs a new resource record of type EDNS0_EDE, optionally
// using c to allocate the structure.
func newEDNS0EDE(c *Cloner, infoCode uint16, extraText string) (opt *dns.EDNS0_EDE) {
if c == nil {
opt = &dns.EDNS0_EDE{}
} else {
opt = c.opt.ede.Get()
}
opt.InfoCode = infoCode
opt.ExtraText = extraText
return opt
}

View File

@ -0,0 +1,143 @@
package dnsmsg
import (
"encoding/json"
"fmt"
"net/url"
"strings"
"unicode"
"github.com/AdguardTeam/golibs/errors"
)
// StructuredDNSErrorsConfig is the configuration structure for the experimental
// Structured DNS Errors feature.
//
// See https://www.ietf.org/archive/id/draft-ietf-dnsop-structured-dns-error-09.html.
//
// TODO(a.garipov): Add sub-error?
type StructuredDNSErrorsConfig struct {
// Justification for this particular DNS filtering. It must not be empty.
Justification string
// Organization is an optional description of the organization.
Organization string
// Contact information for the DNS service. It must not be empty. All
// items must not be nil and must be valid mailto, sips, or tel URLs.
Contact []*url.URL
// Enabled, if true, enables the experimental Structured DNS Errors feature.
Enabled bool
}
// iJSON returns the I-JSON representation of this configuration. c must be
// valid.
func (c *StructuredDNSErrorsConfig) iJSON() (s string) {
data := &structuredDNSErrorData{
Justification: c.Justification,
Organization: c.Organization,
}
for _, cont := range c.Contact {
data.Contact = append(data.Contact, cont.String())
}
// The only error that could be returned here is a type error from JSON
// encoding, and these should never happen.
b := errors.Must(json.Marshal(data))
return string(b)
}
// structuredDNSErrorData is the structure for the JSON representation of the
// SDE data.
//
// TODO(a.garipov): Add sub-error?
type structuredDNSErrorData struct {
Justification string `json:"j"`
Organization string `json:"o,omitempty"`
Contact []string `json:"c"`
}
// forbiddenRanges contains the ranges of forbidden code points for structured
// DNS errors according to the I-JSON specification.
//
// See https://datatracker.ietf.org/doc/html/rfc7493#section-2.1.
var forbiddenRanges = []*unicode.RangeTable{unicode.Cs, unicode.Noncharacter_Code_Point}
// isSurrogateOrNonCharacter returns true if r is a surrogate or a non-character
// code point.
func isSurrogateOrNonCharacter(r rune) (ok bool) {
return unicode.IsOneOf(forbiddenRanges, r)
}
// validateSDEString returns an error if s contains a surrogate or a
// non-character code point. It always returns nil for an empty string.
func validateSDEString(s string) (err error) {
if i := strings.IndexFunc(s, isSurrogateOrNonCharacter); i >= 0 {
return fmt.Errorf("bad code point at index %d", i)
}
return nil
}
// validate checks the configuration for errors.
func (c *StructuredDNSErrorsConfig) validate(edeEnabled bool) (err error) {
if c == nil {
return errors.ErrNoValue
}
if !c.Enabled {
return nil
} else if !edeEnabled {
return errors.Error("ede must be enabled to enable sde")
}
var errs []error
if len(c.Contact) == 0 {
err = fmt.Errorf("contact data: %w", errors.ErrEmptyValue)
errs = append(errs, err)
}
for i, cont := range c.Contact {
err = validateSDEContactURL(cont)
if err != nil {
err = fmt.Errorf("contact data: at index %d: %w", i, err)
errs = append(errs, err)
}
}
if c.Justification == "" {
err = fmt.Errorf("justification: %w", errors.ErrEmptyValue)
errs = append(errs, err)
} else if err = validateSDEString(c.Justification); err != nil {
err = fmt.Errorf("justification: %w", err)
errs = append(errs, err)
}
if err = validateSDEString(c.Organization); err != nil {
err = fmt.Errorf("organization: %w", err)
errs = append(errs, err)
}
return errors.Join(errs...)
}
// validateSDEContactURL returns an error if u is not a valid SDE contact URL.
// It doesn't check for bad code points in the URL since [url.URL.String]
// escapes them.
func validateSDEContactURL(u *url.URL) (err error) {
if u == nil {
return errors.ErrNoValue
}
switch strings.ToLower(u.Scheme) {
case "mailto", "sips", "tel":
// TODO(a.garipov): Consider more thorough validations for each scheme.
default:
return fmt.Errorf("scheme: %w: %q", errors.ErrBadEnumValue, u.Scheme)
}
return nil
}

View File

@ -31,8 +31,8 @@ type Middleware struct {
// cacheMinTTL is the minimum supported TTL for cache items. // cacheMinTTL is the minimum supported TTL for cache items.
cacheMinTTL time.Duration cacheMinTTL time.Duration
// useTTLOverride shows if the TTL overrides logic should be used. // overrideTTL shows if the TTL overrides logic should be used.
useTTLOverride bool overrideTTL bool
} }
// MiddlewareConfig is the configuration structure for NewMiddleware. // MiddlewareConfig is the configuration structure for NewMiddleware.
@ -49,8 +49,8 @@ type MiddlewareConfig struct {
// MinTTL is the minimum supported TTL for cache items. // MinTTL is the minimum supported TTL for cache items.
MinTTL time.Duration MinTTL time.Duration
// UseTTLOverride shows if the TTL overrides logic should be used. // OverrideTTL shows if the TTL overrides logic should be used.
UseTTLOverride bool OverrideTTL bool
} }
// NewMiddleware initializes a new LRU caching middleware. c must not be nil. // NewMiddleware initializes a new LRU caching middleware. c must not be nil.
@ -63,10 +63,10 @@ func NewMiddleware(c *MiddlewareConfig) (m *Middleware) {
} }
return &Middleware{ return &Middleware{
metrics: metrics, metrics: metrics,
cache: gcache.New(c.Size).LRU().Build(), cache: gcache.New(c.Size).LRU().Build(),
cacheMinTTL: c.MinTTL, cacheMinTTL: c.MinTTL,
useTTLOverride: c.UseTTLOverride, overrideTTL: c.OverrideTTL,
} }
} }
@ -150,7 +150,7 @@ func (m *Middleware) set(msg *dns.Msg) (err error) {
} }
exp := time.Duration(ttl) * time.Second exp := time.Duration(ttl) * time.Second
if m.useTTLOverride && msg.Rcode != dns.RcodeServerFailure { if m.overrideTTL && msg.Rcode != dns.RcodeServerFailure {
exp = max(exp, m.cacheMinTTL) exp = max(exp, m.cacheMinTTL)
setMinTTL(msg, uint32(exp.Seconds())) setMinTTL(msg, uint32(exp.Seconds()))
} }

View File

@ -186,9 +186,9 @@ func TestMiddleware_Wrap(t *testing.T) {
withCache := dnsserver.WithMiddlewares( withCache := dnsserver.WithMiddlewares(
handler, handler,
cache.NewMiddleware(&cache.MiddlewareConfig{ cache.NewMiddleware(&cache.MiddlewareConfig{
Size: 100, Size: 100,
MinTTL: minTTL, MinTTL: minTTL,
UseTTLOverride: tc.minTTL != nil, OverrideTTL: tc.minTTL != nil,
}), }),
) )

View File

@ -2,6 +2,7 @@ package dnsservertest
import ( import (
"context" "context"
"fmt"
"time" "time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
@ -13,9 +14,15 @@ import (
// AnswerTTL is the default TTL of the test handler's answers. // AnswerTTL is the default TTL of the test handler's answers.
const AnswerTTL time.Duration = 100 * time.Second const AnswerTTL time.Duration = 100 * time.Second
// CreateTestHandler creates a [dnsserver.Handler] with the specified // NewDefaultHandler returns a simple handler that always returns a response
// with a single A record.
func NewDefaultHandler() (handler dnsserver.Handler) {
return NewDefaultHandlerWithCount(1)
}
// NewDefaultHandlerWithCount creates a [dnsserver.Handler] with the specified
// parameters. All responses will have the [TestAnsTTL] TTL. // parameters. All responses will have the [TestAnsTTL] TTL.
func CreateTestHandler(recordsCount int) (h dnsserver.Handler) { func NewDefaultHandlerWithCount(recordsCount int) (h dnsserver.Handler) {
f := func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) (err error) { f := func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) (err error) {
// Check that necessary context keys are set. // Check that necessary context keys are set.
si := dnsserver.MustServerInfoFromContext(ctx) si := dnsserver.MustServerInfoFromContext(ctx)
@ -49,8 +56,12 @@ func CreateTestHandler(recordsCount int) (h dnsserver.Handler) {
return dnsserver.HandlerFunc(f) return dnsserver.HandlerFunc(f)
} }
// DefaultHandler returns a simple handler that always returns a response with // NewPanicHandler returns a DNS handler that panics with an error.
// a single A record. func NewPanicHandler() (handler dnsserver.Handler) {
func DefaultHandler() (handler dnsserver.Handler) { f := func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) (err error) {
return CreateTestHandler(1) // TODO(a.garipov): Add a helper for these kinds of errors to golibs.
panic(fmt.Errorf("unexpected call to ServeDNS(%v, %v)", rw, req))
}
return dnsserver.HandlerFunc(f)
} }

View File

@ -273,6 +273,22 @@ func NewNS(name string, ttl uint32, ns string) (rr dns.RR) {
} }
} }
// NewOPT constructs the new resource record of type OPT.
func NewOPT(do bool, udpSize uint16, opts ...dns.EDNS0) (rr dns.RR) {
opt := &dns.OPT{
Hdr: dns.RR_Header{
Name: ".",
Rrtype: dns.TypeOPT,
},
Option: opts,
}
opt.SetDo(do)
opt.SetUDPSize(udpSize)
return opt
}
// NewECSExtra constructs a new OPT RR for the extra section. // NewECSExtra constructs a new OPT RR for the extra section.
func NewECSExtra(ip net.IP, fam uint16, mask, scope uint8) (extra dns.RR) { func NewECSExtra(ip net.IP, fam uint16, mask, scope uint8) (extra dns.RR) {
return &dns.OPT{ return &dns.OPT{
@ -300,11 +316,8 @@ func NewEDNS0Padding(msgLen int, UDPBufferSize uint16) (extra dns.RR) {
padLen := requestPaddingBlockSize - msgLen%requestPaddingBlockSize padLen := requestPaddingBlockSize - msgLen%requestPaddingBlockSize
// Truncate padding to fit in UDP buffer. // Truncate padding to fit in UDP buffer.
if msgLen+padLen > int(UDPBufferSize) { if bufSzInt := int(UDPBufferSize); msgLen+padLen > bufSzInt {
padLen = int(UDPBufferSize) - msgLen padLen = max(bufSzInt-msgLen, 0)
if padLen < 0 {
padLen = 0
}
} }
return &dns.OPT{ return &dns.OPT{

View File

@ -23,7 +23,7 @@ func TestMain(m *testing.M) {
const testTimeout = 1 * time.Second const testTimeout = 1 * time.Second
func TestHandler_ServeDNS(t *testing.T) { func TestHandler_ServeDNS(t *testing.T) {
srv, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler()) srv, addr := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
// No-fallbacks handler. // No-fallbacks handler.
handler := forward.NewHandler(&forward.HandlerConfig{ handler := forward.NewHandler(&forward.HandlerConfig{
@ -47,7 +47,7 @@ func TestHandler_ServeDNS(t *testing.T) {
} }
func TestHandler_ServeDNS_fallbackNetError(t *testing.T) { func TestHandler_ServeDNS_fallbackNetError(t *testing.T) {
srv, _ := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler()) srv, _ := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
handler := forward.NewHandler(&forward.HandlerConfig{ handler := forward.NewHandler(&forward.HandlerConfig{
UpstreamsAddresses: []*forward.UpstreamPlainConfig{{ UpstreamsAddresses: []*forward.UpstreamPlainConfig{{
Network: forward.NetworkAny, Network: forward.NetworkAny,

View File

@ -20,7 +20,7 @@ func TestHandler_Refresh(t *testing.T) {
var upstreamIsUp atomic.Bool var upstreamIsUp atomic.Bool
var upstreamRequestsCount atomic.Int64 var upstreamRequestsCount atomic.Int64
defaultHandler := dnsservertest.DefaultHandler() defaultHandler := dnsservertest.NewDefaultHandler()
// This handler writes an empty message if upstreamUp flag is false. // This handler writes an empty message if upstreamUp flag is false.
handlerFunc := dnsserver.HandlerFunc(func( handlerFunc := dnsserver.HandlerFunc(func(

View File

@ -36,7 +36,7 @@ func TestUpstreamPlain_Exchange(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler()) _, addr := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
ups := forward.NewUpstreamPlain(&forward.UpstreamPlainConfig{ ups := forward.NewUpstreamPlain(&forward.UpstreamPlainConfig{
Network: tc.network, Network: tc.network,
Address: netip.MustParseAddrPort(addr), Address: netip.MustParseAddrPort(addr),
@ -65,7 +65,7 @@ func TestUpstreamPlain_Exchange_truncated(t *testing.T) {
rw.LocalAddr(), rw.LocalAddr(),
rw.RemoteAddr(), rw.RemoteAddr(),
) )
handler := dnsservertest.DefaultHandler() handler := dnsservertest.NewDefaultHandler()
err = handler.ServeDNS(ctx, nrw, req) err = handler.ServeDNS(ctx, nrw, req)
if err != nil { if err != nil {
return err return err

View File

@ -1,9 +1,9 @@
module github.com/AdguardTeam/AdGuardDNS/internal/dnsserver module github.com/AdguardTeam/AdGuardDNS/internal/dnsserver
go 1.23.1 go 1.23.2
require ( require (
github.com/AdguardTeam/golibs v0.28.0 github.com/AdguardTeam/golibs v0.30.1
github.com/ameshkov/dnscrypt/v2 v2.3.0 github.com/ameshkov/dnscrypt/v2 v2.3.0
github.com/ameshkov/dnsstamps v1.0.3 github.com/ameshkov/dnsstamps v1.0.3
github.com/bluele/gcache v0.0.2 github.com/bluele/gcache v0.0.2
@ -11,12 +11,12 @@ require (
github.com/miekg/dns v1.1.62 github.com/miekg/dns v1.1.62
github.com/panjf2000/ants/v2 v2.10.0 github.com/panjf2000/ants/v2 v2.10.0
github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible
github.com/prometheus/client_golang v1.20.1 github.com/prometheus/client_golang v1.20.5
github.com/quic-go/quic-go v0.47.0 github.com/quic-go/quic-go v0.48.1
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
golang.org/x/net v0.29.0 golang.org/x/net v0.30.0
golang.org/x/sys v0.25.0 golang.org/x/sys v0.26.0
) )
require ( require (
@ -26,22 +26,23 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
github.com/google/pprof v0.0.0-20240929191954-255acd752d31 // indirect github.com/google/pprof v0.0.0-20241023014458-598669927662 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/kr/text v0.2.0 // indirect github.com/kr/text v0.2.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/onsi/ginkgo/v2 v2.20.2 // indirect github.com/onsi/ginkgo/v2 v2.20.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.55.0 // indirect github.com/prometheus/common v0.60.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect github.com/prometheus/procfs v0.15.1 // indirect
github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/qpack v0.5.1 // indirect
go.uber.org/mock v0.4.0 // indirect go.uber.org/mock v0.5.0 // indirect
golang.org/x/crypto v0.27.0 // indirect golang.org/x/crypto v0.28.0 // indirect
golang.org/x/mod v0.21.0 // indirect golang.org/x/mod v0.21.0 // indirect
golang.org/x/sync v0.8.0 // indirect golang.org/x/sync v0.8.0 // indirect
golang.org/x/text v0.18.0 // indirect golang.org/x/text v0.19.0 // indirect
golang.org/x/time v0.6.0 // indirect golang.org/x/time v0.7.0 // indirect
golang.org/x/tools v0.25.0 // indirect golang.org/x/tools v0.26.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect google.golang.org/protobuf v1.35.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

View File

@ -1,4 +1,5 @@
github.com/AdguardTeam/golibs v0.28.0 h1:SK1q8SqkkJ/61pp2abTmio90S4QpteYK9rtgROfnrb4= github.com/AdguardTeam/golibs v0.30.1 h1:/yv7dq2h7WXw/jTDxkE3FP9zHerRT+i03PZRHJX4fPU=
github.com/AdguardTeam/golibs v0.30.1/go.mod h1:FkwcNQEJoGsgDGXcalrVa/4gWbE68KsmE2guXWtBQUE=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw= github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw=
@ -25,10 +26,10 @@ github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1v
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/pprof v0.0.0-20240929191954-255acd752d31 h1:LcRdQWywSgfi5jPsYZ1r2avbbs5IQ5wtyhMBCcokyo4= github.com/google/pprof v0.0.0-20241023014458-598669927662 h1:SKMkD83p7FwUqKmBsPdLHF5dNyxq3jOWwu9w9UyH5vA=
github.com/google/pprof v0.0.0-20240929191954-255acd752d31/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/google/pprof v0.0.0-20241023014458-598669927662/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
@ -47,18 +48,18 @@ github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible
github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.20.1 h1:IMJXHOD6eARkQpxo8KkhgEVFlBNm+nkrFUyGlIu7Na8= github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y=
github.com/prometheus/client_golang v1.20.1/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc= github.com/prometheus/common v0.60.0 h1:+V9PAREWNvJMAuJ1x1BaWl9dewMW4YrHZQbx0sJNllA=
github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= github.com/prometheus/common v0.60.0/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
github.com/quic-go/quic-go v0.47.0 h1:yXs3v7r2bm1wmPTYNLKAAJTHMYkPEsfYJmTazXrCZ7Y= github.com/quic-go/quic-go v0.48.1 h1:y/8xmfWI9qmGTc+lBr4jKRUWLGSlSigv847ULJ4hYXA=
github.com/quic-go/quic-go v0.47.0/go.mod h1:3bCapYsJvXGZcipOHuu7plYtaV6tnF+z7wIFsU0WK9E= github.com/quic-go/quic-go v0.48.1/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@ -69,29 +70,29 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk= golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY= golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE= golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ=
golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg= golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

View File

@ -23,7 +23,7 @@ func TestCacheMetricsListener_integration_cache(t *testing.T) {
}) })
handlerWithMiddleware := dnsserver.WithMiddlewares( handlerWithMiddleware := dnsserver.WithMiddlewares(
dnsservertest.DefaultHandler(), dnsservertest.NewDefaultHandler(),
cacheMiddleware, cacheMiddleware,
) )

View File

@ -18,7 +18,7 @@ import (
// normal unit test, we create a forward handler, emulate a query and then // normal unit test, we create a forward handler, emulate a query and then
// check if prom metrics were incremented. // check if prom metrics were incremented.
func TestForwardMetricsListener_integration_request(t *testing.T) { func TestForwardMetricsListener_integration_request(t *testing.T) {
srv, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler()) srv, addr := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
// Initialize a new forward.Handler and set the metrics listener. // Initialize a new forward.Handler and set the metrics listener.
handler := forward.NewHandler(&forward.HandlerConfig{ handler := forward.NewHandler(&forward.HandlerConfig{

View File

@ -19,16 +19,21 @@ import (
// normal unit test, we create a cache middleware, emulate a query and then // normal unit test, we create a cache middleware, emulate a query and then
// check if prom metrics were incremented. // check if prom metrics were incremented.
func TestRateLimiterMetricsListener_integration_cache(t *testing.T) { func TestRateLimiterMetricsListener_integration_cache(t *testing.T) {
rps := 5 const (
count = 5
ivl = time.Second
)
rl := ratelimit.NewBackoff(&ratelimit.BackoffConfig{ rl := ratelimit.NewBackoff(&ratelimit.BackoffConfig{
Allowlist: ratelimit.NewDynamicAllowlist([]netip.Prefix{}, []netip.Prefix{}), Allowlist: ratelimit.NewDynamicAllowlist([]netip.Prefix{}, []netip.Prefix{}),
Period: time.Minute, Period: time.Minute,
Duration: time.Minute, Duration: time.Minute,
Count: uint(rps), Count: count,
ResponseSizeEstimate: 1 * datasize.KB, ResponseSizeEstimate: 1 * datasize.KB,
IPv4RPS: uint(rps), IPv4Count: count,
IPv6RPS: uint(rps), IPv4Interval: ivl,
IPv6Count: count,
IPv6Interval: ivl,
RefuseANY: true, RefuseANY: true,
}) })
rlMw, err := ratelimit.NewMiddleware(&ratelimit.MiddlewareConfig{ rlMw, err := ratelimit.NewMiddleware(&ratelimit.MiddlewareConfig{
@ -38,7 +43,7 @@ func TestRateLimiterMetricsListener_integration_cache(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
handlerWithMiddleware := dnsserver.WithMiddlewares( handlerWithMiddleware := dnsserver.WithMiddlewares(
dnsservertest.DefaultHandler(), dnsservertest.NewDefaultHandler(),
rlMw, rlMw,
) )
@ -55,7 +60,7 @@ func TestRateLimiterMetricsListener_integration_cache(t *testing.T) {
err = handlerWithMiddleware.ServeDNS(ctx, nrw, req) err = handlerWithMiddleware.ServeDNS(ctx, nrw, req)
require.NoError(t, err) require.NoError(t, err)
if i < rps { if i < count {
dnsservertest.RequireResponse(t, req, nrw.Msg(), 1, dns.RcodeSuccess, false) dnsservertest.RequireResponse(t, req, nrw.Msg(), 1, dns.RcodeSuccess, false)
} else { } else {
require.Nil(t, nrw.Msg()) require.Nil(t, nrw.Msg())

View File

@ -24,7 +24,7 @@ func TestServerMetricsListener_integration_requestLifetime(t *testing.T) {
ConfigBase: dnsserver.ConfigBase{ ConfigBase: dnsserver.ConfigBase{
Name: "test", Name: "test",
Addr: "127.0.0.1:0", Addr: "127.0.0.1:0",
Handler: dnsservertest.DefaultHandler(), Handler: dnsservertest.NewDefaultHandler(),
Metrics: prometheus.NewServerMetricsListener(testNamespace), Metrics: prometheus.NewServerMetricsListener(testNamespace),
}, },
} }

View File

@ -36,19 +36,29 @@ type BackoffConfig struct {
// as several responses. // as several responses.
ResponseSizeEstimate datasize.ByteSize ResponseSizeEstimate datasize.ByteSize
// IPv4RPS is the maximum number of requests per second allowed from a // IPv4Count is the maximum number of requests per a specified interval
// single subnet for IPv4 addresses. Any requests above this rate are // allowed from a single subnet for IPv4 addresses. Any requests above this
// counted as the client's backoff count. RPS must be greater than zero. // rate are counted as the client's backoff count. It must be greater than
IPv4RPS uint // zero.
IPv4Count uint
// IPv4Interval is the time during which to count the number of requests
// for IPv4 addresses.
IPv4Interval time.Duration
// IPv4SubnetKeyLen is the length of the subnet prefix used to calculate // IPv4SubnetKeyLen is the length of the subnet prefix used to calculate
// rate limiter bucket keys for IPv4 addresses. Must be greater than zero. // rate limiter bucket keys for IPv4 addresses. Must be greater than zero.
IPv4SubnetKeyLen int IPv4SubnetKeyLen int
// IPv6RPS is the maximum number of requests per second allowed from a // IPv6Count is the maximum number of requests per a specified interval
// single subnet for IPv6 addresses. Any requests above this rate are // allowed from a single subnet for IPv6 addresses. Any requests above this
// counted as the client's backoff count. RPS must be greater than zero. // rate are counted as the client's backoff count. It must be greater than
IPv6RPS uint // zero.
IPv6Count uint
// IPv6Interval is the time during which to count the number of requests
// for IPv6 addresses.
IPv6Interval time.Duration
// IPv6SubnetKeyLen is the length of the subnet prefix used to calculate // IPv6SubnetKeyLen is the length of the subnet prefix used to calculate
// rate limiter bucket keys for IPv6 addresses. Must be greater than zero. // rate limiter bucket keys for IPv6 addresses. Must be greater than zero.
@ -75,9 +85,11 @@ type Backoff struct {
allowlist Allowlist allowlist Allowlist
respSzEst datasize.ByteSize respSzEst datasize.ByteSize
count uint count uint
ipv4rps uint ipv4Count uint
ipv4Interval time.Duration
ipv4SubnetKeyLen int ipv4SubnetKeyLen int
ipv6rps uint ipv6Count uint
ipv6Interval time.Duration
ipv6SubnetKeyLen int ipv6SubnetKeyLen int
refuseANY bool refuseANY bool
} }
@ -93,9 +105,11 @@ func NewBackoff(c *BackoffConfig) (l *Backoff) {
allowlist: c.Allowlist, allowlist: c.Allowlist,
respSzEst: c.ResponseSizeEstimate, respSzEst: c.ResponseSizeEstimate,
count: c.Count, count: c.Count,
ipv4rps: c.IPv4RPS, ipv4Count: c.IPv4Count,
ipv4Interval: c.IPv4Interval,
ipv4SubnetKeyLen: c.IPv4SubnetKeyLen, ipv4SubnetKeyLen: c.IPv4SubnetKeyLen,
ipv6rps: c.IPv6RPS, ipv6Count: c.IPv6Count,
ipv6Interval: c.IPv6Interval,
ipv6SubnetKeyLen: c.IPv6SubnetKeyLen, ipv6SubnetKeyLen: c.IPv6SubnetKeyLen,
refuseANY: c.RefuseANY, refuseANY: c.RefuseANY,
} }
@ -133,12 +147,12 @@ func (l *Backoff) IsRateLimited(
return true, false, nil return true, false, nil
} }
rps := l.ipv4rps count, ivl := l.ipv4Count, l.ipv4Interval
if ip.Is6() { if ip.Is6() {
rps = l.ipv6rps count, ivl = l.ipv6Count, l.ipv6Interval
} }
return l.hasHitRateLimit(key, rps), false, nil return l.hasHitRateLimit(key, count, ivl), false, nil
} }
// validateAddr returns an error if addr is not a valid IPv4 or IPv6 address. // validateAddr returns an error if addr is not a valid IPv4 or IPv6 address.
@ -198,15 +212,15 @@ func (l *Backoff) incBackoff(key string) {
l.hitCounters.SetDefault(key, counter) l.hitCounters.SetDefault(key, counter)
} }
// hasHitRateLimit checks value for a subnet with rps as a maximum number // hasHitRateLimit checks if the value of requests for given subnet hit the
// requests per second. // maximum count of requests per given interval.
func (l *Backoff) hasHitRateLimit(subnetIPStr string, rps uint) (ok bool) { func (l *Backoff) hasHitRateLimit(subnetIPStr string, count uint, ivl time.Duration) (ok bool) {
var r *RequestCounter var r *RequestCounter
rVal, ok := l.reqCounters.Get(subnetIPStr) rVal, ok := l.reqCounters.Get(subnetIPStr)
if ok { if ok {
r = rVal.(*RequestCounter) r = rVal.(*RequestCounter)
} else { } else {
r = NewRequestCounter(rps, time.Second) r = NewRequestCounter(count, ivl)
l.reqCounters.SetDefault(subnetIPStr, r) l.reqCounters.SetDefault(subnetIPStr, r)
} }

View File

@ -22,7 +22,10 @@ func TestMain(m *testing.M) {
} }
func TestRatelimitMiddleware(t *testing.T) { func TestRatelimitMiddleware(t *testing.T) {
const rps = 10 const (
rps = 10
ivl = time.Second
)
persistent := []netip.Prefix{ persistent := []netip.Prefix{
netip.MustParsePrefix("4.3.2.1/8"), netip.MustParsePrefix("4.3.2.1/8"),
@ -99,9 +102,11 @@ func TestRatelimitMiddleware(t *testing.T) {
Duration: time.Minute, Duration: time.Minute,
Count: rps, Count: rps,
ResponseSizeEstimate: 128 * datasize.B, ResponseSizeEstimate: 128 * datasize.B,
IPv4RPS: rps, IPv4Count: rps,
IPv4Interval: ivl,
IPv4SubnetKeyLen: 24, IPv4SubnetKeyLen: 24,
IPv6RPS: rps, IPv6Count: rps,
IPv6Interval: ivl,
IPv6SubnetKeyLen: 48, IPv6SubnetKeyLen: 48,
RefuseANY: true, RefuseANY: true,
}) })
@ -112,7 +117,7 @@ func TestRatelimitMiddleware(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
withMw := dnsserver.WithMiddlewares( withMw := dnsserver.WithMiddlewares(
dnsservertest.CreateTestHandler(tc.respCount), dnsservertest.NewDefaultHandlerWithCount(tc.respCount),
rlMw, rlMw,
) )

View File

@ -309,6 +309,10 @@ func (s *ServerBase) serveDNSMsgInternal(
s.metrics.OnError(ctx, err) s.metrics.OnError(ctx, err)
resp = genErrorResponse(req, dns.RcodeServerFailure) resp = genErrorResponse(req, dns.RcodeServerFailure)
if isNonCriticalNetError(err) {
addEDE(req, resp, dns.ExtendedErrorCodeNetworkError, "")
}
err = rw.WriteMsg(ctx, req, resp) err = rw.WriteMsg(ctx, req, resp)
if err != nil { if err != nil {
log.Debug("[%d]: error writing a response: %s", req.Id, err) log.Debug("[%d]: error writing a response: %s", req.Id, err)
@ -316,6 +320,28 @@ func (s *ServerBase) serveDNSMsgInternal(
} }
} }
// addEDE adds an Extended DNS Error (EDE) option to the blocked response
// message, if the request indicates EDNS support.
func addEDE(req, resp *dns.Msg, code uint16, text string) {
reqOpt := req.IsEdns0()
if reqOpt == nil {
// Requestor doesn't implement EDNS, see
// https://datatracker.ietf.org/doc/html/rfc6891#section-7.
return
}
respOpt := resp.IsEdns0()
if respOpt == nil {
resp.SetEdns0(reqOpt.UDPSize(), reqOpt.Do())
respOpt = resp.Extra[len(resp.Extra)-1].(*dns.OPT)
}
respOpt.Option = append(respOpt.Option, &dns.EDNS0_EDE{
InfoCode: code,
ExtraText: text,
})
}
// acceptMsg checks if we should process the incoming DNS query. // acceptMsg checks if we should process the incoming DNS query.
func (s *ServerBase) acceptMsg(m *dns.Msg) (action dns.MsgAcceptAction) { func (s *ServerBase) acceptMsg(m *dns.Msg) (action dns.MsgAcceptAction) {
if m.Response { if m.Response {

View File

@ -37,7 +37,7 @@ func BenchmarkServeDNS(b *testing.B) {
for _, tc := range testCases { for _, tc := range testCases {
b.Run(tc.name, func(b *testing.B) { b.Run(tc.name, func(b *testing.B) {
_, addr := dnsservertest.RunDNSServer(b, dnsservertest.DefaultHandler()) _, addr := dnsservertest.RunDNSServer(b, dnsservertest.NewDefaultHandler())
// Prepare a test message. // Prepare a test message.
m := new(dns.Msg) m := new(dns.Msg)
@ -105,7 +105,7 @@ func readMsg(resBuf []byte, network dnsserver.Network, conn net.Conn) (err error
func BenchmarkServeTLS(b *testing.B) { func BenchmarkServeTLS(b *testing.B) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org") tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
addr := dnsservertest.RunTLSServer(b, dnsservertest.DefaultHandler(), tlsConfig) addr := dnsservertest.RunTLSServer(b, dnsservertest.NewDefaultHandler(), tlsConfig)
// Prepare a test message // Prepare a test message
m := new(dns.Msg) m := new(dns.Msg)
@ -171,7 +171,7 @@ func BenchmarkServeDoH(b *testing.B) {
for _, tc := range testCases { for _, tc := range testCases {
b.Run(tc.name, func(b *testing.B) { b.Run(tc.name, func(b *testing.B) {
srv, err := dnsservertest.RunLocalHTTPSServer( srv, err := dnsservertest.RunLocalHTTPSServer(
dnsservertest.DefaultHandler(), dnsservertest.NewDefaultHandler(),
tc.tlsConfig, tc.tlsConfig,
nil, nil,
) )
@ -250,7 +250,7 @@ func BenchmarkServeDNSCrypt(b *testing.B) {
Net: string(tc.network), Net: string(tc.network),
} }
s := dnsservertest.RunDNSCryptServer(b, dnsservertest.DefaultHandler()) s := dnsservertest.RunDNSCryptServer(b, dnsservertest.NewDefaultHandler())
stamp := dnsstamps.ServerStamp{ stamp := dnsstamps.ServerStamp{
ServerAddrStr: s.ServerAddr, ServerAddrStr: s.ServerAddr,
ServerPk: s.ResolverPk, ServerPk: s.ResolverPk,
@ -282,7 +282,7 @@ func BenchmarkServeDNSCrypt(b *testing.B) {
func BenchmarkServeQUIC(b *testing.B) { func BenchmarkServeQUIC(b *testing.B) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org") tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
srv, addr, err := dnsservertest.RunLocalQUICServer( srv, addr, err := dnsservertest.RunLocalQUICServer(
dnsservertest.DefaultHandler(), dnsservertest.NewDefaultHandler(),
tlsConfig, tlsConfig,
) )
require.NoError(b, err) require.NoError(b, err)

View File

@ -20,7 +20,7 @@ import (
) )
func TestServerDNS_StartShutdown(t *testing.T) { func TestServerDNS_StartShutdown(t *testing.T) {
_, _ = dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler()) _, _ = dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
} }
func TestServerDNS_integration_query(t *testing.T) { func TestServerDNS_integration_query(t *testing.T) {
@ -42,7 +42,7 @@ func TestServerDNS_integration_query(t *testing.T) {
{Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, {Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}, },
}, },
handler: dnsservertest.DefaultHandler(), handler: dnsservertest.NewDefaultHandler(),
wantRecordsCount: 1, wantRecordsCount: 1,
wantRCode: dns.RcodeSuccess, wantRCode: dns.RcodeSuccess,
}, { }, {
@ -54,7 +54,7 @@ func TestServerDNS_integration_query(t *testing.T) {
{Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, {Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}, },
}, },
handler: dnsservertest.DefaultHandler(), handler: dnsservertest.NewDefaultHandler(),
wantRecordsCount: 1, wantRecordsCount: 1,
wantRCode: dns.RcodeSuccess, wantRCode: dns.RcodeSuccess,
}, { }, {
@ -89,7 +89,7 @@ func TestServerDNS_integration_query(t *testing.T) {
}, },
}, },
}, },
handler: dnsservertest.DefaultHandler(), handler: dnsservertest.NewDefaultHandler(),
wantMsg: func(t *testing.T, m *dns.Msg) { wantMsg: func(t *testing.T, m *dns.Msg) {
opt := m.IsEdns0() opt := m.IsEdns0()
require.NotNil(t, opt) require.NotNil(t, opt)
@ -109,7 +109,7 @@ func TestServerDNS_integration_query(t *testing.T) {
{Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, {Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}, },
}, },
handler: dnsservertest.DefaultHandler(), handler: dnsservertest.NewDefaultHandler(),
wantRecordsCount: 0, wantRecordsCount: 0,
wantRCode: dns.RcodeFormatError, wantRCode: dns.RcodeFormatError,
}, { }, {
@ -122,7 +122,7 @@ func TestServerDNS_integration_query(t *testing.T) {
{Name: "eXaMplE.oRg.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, {Name: "eXaMplE.oRg.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}, },
}, },
handler: dnsservertest.DefaultHandler(), handler: dnsservertest.NewDefaultHandler(),
wantRecordsCount: 1, wantRecordsCount: 1,
wantRCode: dns.RcodeSuccess, wantRCode: dns.RcodeSuccess,
}, { }, {
@ -136,7 +136,7 @@ func TestServerDNS_integration_query(t *testing.T) {
{Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, {Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}, },
}, },
handler: dnsservertest.DefaultHandler(), handler: dnsservertest.NewDefaultHandler(),
wantRecordsCount: 0, wantRecordsCount: 0,
wantRCode: dns.RcodeNotImplemented, wantRCode: dns.RcodeNotImplemented,
}, { }, {
@ -170,7 +170,7 @@ func TestServerDNS_integration_query(t *testing.T) {
{Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, {Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}, },
}, },
handler: dnsservertest.DefaultHandler(), handler: dnsservertest.NewDefaultHandler(),
wantRecordsCount: 1, wantRecordsCount: 1,
wantRCode: dns.RcodeSuccess, wantRCode: dns.RcodeSuccess,
}, { }, {
@ -185,7 +185,7 @@ func TestServerDNS_integration_query(t *testing.T) {
}, },
}, },
// Set a handler that generates a large response // Set a handler that generates a large response
handler: dnsservertest.CreateTestHandler(64), handler: dnsservertest.NewDefaultHandlerWithCount(64),
wantRecordsCount: 0, wantRecordsCount: 0,
wantRCode: dns.RcodeSuccess, wantRCode: dns.RcodeSuccess,
wantTruncated: true, wantTruncated: true,
@ -210,7 +210,7 @@ func TestServerDNS_integration_query(t *testing.T) {
}, },
}, },
// Set a handler that generates a large response // Set a handler that generates a large response
handler: dnsservertest.CreateTestHandler(64), handler: dnsservertest.NewDefaultHandlerWithCount(64),
wantRecordsCount: 64, wantRecordsCount: 64,
wantRCode: dns.RcodeSuccess, wantRCode: dns.RcodeSuccess,
wantTruncated: false, wantTruncated: false,
@ -226,7 +226,7 @@ func TestServerDNS_integration_query(t *testing.T) {
}, },
}, },
// Set a handler that generates a large response // Set a handler that generates a large response
handler: dnsservertest.CreateTestHandler(64), handler: dnsservertest.NewDefaultHandlerWithCount(64),
// No truncate // No truncate
wantRecordsCount: 64, wantRecordsCount: 64,
wantRCode: dns.RcodeSuccess, wantRCode: dns.RcodeSuccess,
@ -257,7 +257,7 @@ func TestServerDNS_integration_query(t *testing.T) {
}, },
}, },
}, },
handler: dnsservertest.DefaultHandler(), handler: dnsservertest.NewDefaultHandler(),
wantRecordsCount: 1, wantRecordsCount: 1,
wantRCode: dns.RcodeSuccess, wantRCode: dns.RcodeSuccess,
}} }}
@ -304,7 +304,7 @@ func TestServerDNS_integration_tcpQueriesPipelining(t *testing.T) {
// As per RFC 7766 we should support queries pipelining for TCP, that is // As per RFC 7766 we should support queries pipelining for TCP, that is
// server must be able to process incoming queries in parallel and write // server must be able to process incoming queries in parallel and write
// responses possibly out of order within the same connection. // responses possibly out of order within the same connection.
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler()) _, addr := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
// Establish a connection. // Establish a connection.
conn, err := net.Dial("tcp", addr) conn, err := net.Dial("tcp", addr)
@ -372,7 +372,7 @@ func TestServerDNS_integration_tcpQueriesPipelining(t *testing.T) {
} }
func TestServerDNS_integration_udpMsgIgnore(t *testing.T) { func TestServerDNS_integration_udpMsgIgnore(t *testing.T) {
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler()) _, addr := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
conn, err := net.Dial("udp", addr) conn, err := net.Dial("udp", addr)
require.Nil(t, err) require.Nil(t, err)
@ -469,7 +469,7 @@ func TestServerDNS_integration_tcpMsgIgnore(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler()) _, addr := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
conn, err := net.Dial("tcp", addr) conn, err := net.Dial("tcp", addr)
require.Nil(t, err) require.Nil(t, err)

View File

@ -49,7 +49,7 @@ func TestServerDNSCrypt_integration_query(t *testing.T) {
name: "udp_truncate_response", name: "udp_truncate_response",
network: dnsserver.NetworkUDP, network: dnsserver.NetworkUDP,
// Set a handler that generates a large response // Set a handler that generates a large response
handler: dnsservertest.CreateTestHandler(64), handler: dnsservertest.NewDefaultHandlerWithCount(64),
// DNSCrypt server removes all records from a truncated response // DNSCrypt server removes all records from a truncated response
expectedRecordsCount: 0, expectedRecordsCount: 0,
expectedRCode: dns.RcodeSuccess, expectedRCode: dns.RcodeSuccess,
@ -66,7 +66,7 @@ func TestServerDNSCrypt_integration_query(t *testing.T) {
name: "udp_edns0_no_truncate", name: "udp_edns0_no_truncate",
network: dnsserver.NetworkUDP, network: dnsserver.NetworkUDP,
// Set a handler that generates a large response // Set a handler that generates a large response
handler: dnsservertest.CreateTestHandler(64), handler: dnsservertest.NewDefaultHandlerWithCount(64),
expectedRecordsCount: 64, expectedRecordsCount: 64,
expectedRCode: dns.RcodeSuccess, expectedRCode: dns.RcodeSuccess,
expectedTruncated: false, expectedTruncated: false,
@ -89,7 +89,7 @@ func TestServerDNSCrypt_integration_query(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
handler := tc.handler handler := tc.handler
if tc.handler == nil { if tc.handler == nil {
handler = dnsservertest.DefaultHandler() handler = dnsservertest.NewDefaultHandler()
} }
s := dnsservertest.RunDNSCryptServer(t, handler) s := dnsservertest.RunDNSCryptServer(t, handler)

View File

@ -19,6 +19,7 @@ import (
"github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3" "github.com/quic-go/quic-go/http3"
@ -304,9 +305,9 @@ func (s *ServerHTTPS) serveHTTPS(ctx context.Context, hs *http.Server, l net.Lis
// application won't be able to continue listening to DoH. // application won't be able to continue listening to DoH.
defer s.handlePanicAndExit(ctx) defer s.handlePanicAndExit(ctx)
scheme := "https" scheme := urlutil.SchemeHTTPS
if s.conf.TLSConfig == nil { if s.conf.TLSConfig == nil {
scheme = "http" scheme = urlutil.SchemeHTTP
} }
u := &url.URL{ u := &url.URL{

View File

@ -109,7 +109,7 @@ func TestServerHTTPS_integration_serveRequests(t *testing.T) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org") tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
srv, err := dnsservertest.RunLocalHTTPSServer( srv, err := dnsservertest.RunLocalHTTPSServer(
dnsservertest.DefaultHandler(), dnsservertest.NewDefaultHandler(),
tlsConfig, tlsConfig,
nil, nil,
) )
@ -146,7 +146,7 @@ func TestServerHTTPS_integration_nonDNSHandler(t *testing.T) {
}) })
srv, err := dnsservertest.RunLocalHTTPSServer( srv, err := dnsservertest.RunLocalHTTPSServer(
dnsservertest.DefaultHandler(), dnsservertest.NewDefaultHandler(),
nil, nil,
testHandler, testHandler,
) )
@ -321,7 +321,7 @@ func TestDNSMsgToJSONMsg(t *testing.T) {
func TestServerHTTPS_integration_ENDS0Padding(t *testing.T) { func TestServerHTTPS_integration_ENDS0Padding(t *testing.T) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org") tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
srv, err := dnsservertest.RunLocalHTTPSServer( srv, err := dnsservertest.RunLocalHTTPSServer(
dnsservertest.DefaultHandler(), dnsservertest.NewDefaultHandler(),
tlsConfig, tlsConfig,
nil, nil,
) )
@ -346,7 +346,7 @@ func TestServerHTTPS_integration_ENDS0Padding(t *testing.T) {
func TestServerHTTPS_0RTT(t *testing.T) { func TestServerHTTPS_0RTT(t *testing.T) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org") tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
srv, err := dnsservertest.RunLocalHTTPSServer( srv, err := dnsservertest.RunLocalHTTPSServer(
dnsservertest.DefaultHandler(), dnsservertest.NewDefaultHandler(),
tlsConfig, tlsConfig,
nil, nil,
) )
@ -514,7 +514,7 @@ func createDoH3Client(
tlsConfig = tlsConfig.Clone() tlsConfig = tlsConfig.Clone()
tlsConfig.NextProtos = []string{http3.NextProtoH3} tlsConfig.NextProtos = []string{http3.NextProtoH3}
transport := &http3.RoundTripper{ transport := &http3.Transport{
DisableCompression: true, DisableCompression: true,
Dial: func( Dial: func(
ctx context.Context, ctx context.Context,

View File

@ -7,13 +7,18 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/AdguardTeam/golibs/errors"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// JSONMsg represents a *dns.Msg in the JSON format defined here: // JSONMsg represents a *dns.Msg in the JSON format defined here:
// https://developers.google.com/speed/public-dns/docs/doh/json#dns_response_in_json // https://developers.google.com/speed/public-dns/docs/doh/json#dns_response_in_json
// Note, that we do not implement some parts of it. There is no "Comment" field //
// and there's no "edns_client_subnet". // NOTE: This API differs from the Google one in the following ways:
// 1. The "Comment" field is not implemented.
// 2. The "edns_client_subnet" query parameter is not supported.
// 3. The "sde" query parameter is added and supported for the experimental
// Structured DNS Errors feature.
type JSONMsg struct { type JSONMsg struct {
Question []JSONQuestion `json:"Question"` Question []JSONQuestion `json:"Question"`
Answer []JSONAnswer `json:"Answer"` Answer []JSONAnswer `json:"Answer"`
@ -26,13 +31,13 @@ type JSONMsg struct {
Status int `json:"Status"` Status int `json:"Status"`
} }
// JSONQuestion is a part of JSONMsg definition. // JSONQuestion is a part of [JSONMsg] definition.
type JSONQuestion struct { type JSONQuestion struct {
Name string `json:"name"` Name string `json:"name"`
Type uint16 `json:"type"` Type uint16 `json:"type"`
} }
// JSONAnswer is a part of JSONMsg definition. // JSONAnswer is a part of [JSONMsg] definition.
type JSONAnswer struct { type JSONAnswer struct {
Name string `json:"name"` Name string `json:"name"`
Data string `json:"data"` Data string `json:"data"`
@ -41,7 +46,7 @@ type JSONAnswer struct {
Class uint16 `json:"class"` Class uint16 `json:"class"`
} }
// DNSMsgToJSONMsg converts the *dns.Msg to the JSON format (*JSONMsg). // DNSMsgToJSONMsg converts the *dns.Msg to the JSON format.
func DNSMsgToJSONMsg(m *dns.Msg) (msg *JSONMsg) { func DNSMsgToJSONMsg(m *dns.Msg) (msg *JSONMsg) {
msg = &JSONMsg{ msg = &JSONMsg{
Status: m.Rcode, Status: m.Rcode,
@ -74,9 +79,9 @@ func DNSMsgToJSONMsg(m *dns.Msg) (msg *JSONMsg) {
func rrToJSON(rr dns.RR) (j JSONAnswer) { func rrToJSON(rr dns.RR) (j JSONAnswer) {
hdr := rr.Header() hdr := rr.Header()
// Extracting the RR value is a bit tricky since miekg/dns does not // Extracting the RR value is a bit tricky since miekg/dns does not expose
// expose the necessary methods. This way we can benefit from the // the necessary methods. This way we can benefit from the proper string
// proper string serialization code that's used inside miekg/dns. // serialization code that's used inside miekg/dns.
hdrStr := hdr.String() hdrStr := hdr.String()
valStr := rr.String() valStr := rr.String()
data := strings.TrimLeft(strings.TrimPrefix(valStr, hdrStr), " ") data := strings.TrimLeft(strings.TrimPrefix(valStr, hdrStr), " ")
@ -90,19 +95,17 @@ func rrToJSON(rr dns.RR) (j JSONAnswer) {
} }
} }
// dnsMsgToJSON converts the *dns.Msg to the JSON format (JSONMsg) and returns // dnsMsgToJSON converts the *dns.Msg to the JSON format and returns it in the
// it in the serialized form. // serialized form.
func dnsMsgToJSON(m *dns.Msg) (b []byte, err error) { func dnsMsgToJSON(m *dns.Msg) (b []byte, err error) {
msg := DNSMsgToJSONMsg(m) return json.Marshal(DNSMsgToJSONMsg(m))
return json.Marshal(msg)
} }
// httpRequestToMsgJSON builds a DNS message from the request parameters. // httpRequestToMsgJSON builds a DNS message from the request parameters.
// We use the same parameters as the ones defined here: //
// https://developers.google.com/speed/public-dns/docs/doh/json#supported_parameters // See [JSONMsg].
// Some parameters are not supported: "ct", "edns_client_subnet". func httpRequestToMsgJSON(httpReq *http.Request) (b []byte, err error) {
func httpRequestToMsgJSON(req *http.Request) (b []byte, err error) { q := httpReq.URL.Query()
q := req.URL.Query()
// Query name, the only required parameter. // Query name, the only required parameter.
name := q.Get("name") name := q.Get("name")
@ -111,71 +114,91 @@ func httpRequestToMsgJSON(req *http.Request) (b []byte, err error) {
return nil, ErrInvalidArgument return nil, ErrInvalidArgument
} }
// RR type can be represented as a number in [1, 65535] or a // RR type can be represented as a number in [1, 65535] or a canonical
// canonical string (case-insensitive, such as A or AAAA). // string (case-insensitive, such as A or AAAA).
var t uint16 qt, err := urlQueryParameterToUint16(q, "type", dns.TypeA, dns.StringToType)
t, err = urlQueryParameterToUint16(q, "type", dns.TypeA, dns.StringToType)
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err return nil, err
} }
// Query class can be represented as a number in [1, 65535] or a // Query class can be represented as a number in [1, 65535] or a canonical
// canonical string (case-insensitive). // string (case-insensitive).
var qc uint16 qc, err := urlQueryParameterToUint16(q, "qc", dns.ClassINET, dns.StringToClass)
qc, err = urlQueryParameterToUint16(q, "qc", dns.ClassINET, dns.StringToClass)
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err return nil, err
} }
// The CD (Checking Disabled) flag. Use cd=1, or cd=true to disable // The CD (Checking Disabled) flag. Use cd=1, or cd=true to disable DNSSEC
// DNSSEC validation; use cd=0, cd=false, or no cd parameter to // validation; use cd=0, cd=false, or no cd parameter to enable DNSSEC
// enable DNSSEC validation. // validation.
var cd bool cd, err := urlQueryParameterToBoolean(q, "cd", false)
cd, err = urlQueryParameterToBoolean(q, "cd", false)
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err return nil, err
} }
// The DO (DNSSEC OK) flag. Use do=1, or do=true to include DNSSEC // The DO (DNSSEC OK) flag. Use do=1 (or do=true) to include DNSSEC records
// records (RRSIG, NSEC, NSEC3); use do=0, do=false, or no do parameter // (RRSIG, NSEC, NSEC3); use do=0 (do=false) or no do parameter to omit
// to omit DNSSEC records. // DNSSEC records.
var do bool do, err := urlQueryParameterToBoolean(q, "do", false)
do, err = urlQueryParameterToBoolean(q, "do", false)
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
// The experimental Structured DNS Errors feature.
sde, err := urlQueryParameterToBoolean(q, "sde", false)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err return nil, err
} }
// Now build a DNS message with all those parameters // Now build a DNS message with all those parameters
r := &dns.Msg{ req := &dns.Msg{
MsgHdr: dns.MsgHdr{ MsgHdr: dns.MsgHdr{
Id: dns.Id(), Id: dns.Id(),
CheckingDisabled: cd, CheckingDisabled: cd,
RecursionDesired: true, RecursionDesired: true,
}, },
Question: []dns.Question{ Question: []dns.Question{{
{ Name: dns.Fqdn(name),
Name: dns.Fqdn(name), Qtype: qt,
Qtype: t, Qclass: qc,
Qclass: qc, }},
},
},
} }
if do { setEDNSFromQuery(req, do, sde)
r.SetEdns0(dns.MaxMsgSize, do)
return req.Pack()
}
// setEDNSFromQuery sets the EDNS parameters on the request depending on the
// query parameters.
func setEDNSFromQuery(req *dns.Msg, do, sde bool) {
if !do && !sde {
return
} }
return r.Pack() req.SetEdns0(dns.MaxMsgSize, do)
if sde {
opt := req.Extra[0].(*dns.OPT)
opt.Option = append(opt.Option, &dns.EDNS0_EDE{})
}
} }
// urlQueryParameterToUint16 is a helper function that extracts a uint16 value // urlQueryParameterToUint16 is a helper function that extracts a uint16 value
// from a query parameter. See httpRequestToMsgJSON to see how it's used. // from a query parameter.
func urlQueryParameterToUint16( func urlQueryParameterToUint16(
q url.Values, q url.Values,
name string, name string,
defaultValue uint16, defaultValue uint16,
strValuesMap map[string]uint16, strValuesMap map[string]uint16,
) (v uint16, err error) { ) (v uint16, err error) {
defer func() { err = errors.Annotate(err, "parameter %q: %w", name) }()
strValue := q.Get(name) strValue := q.Get(name)
uintValue, convErr := strconv.ParseUint(strValue, 10, 16) uintValue, convErr := strconv.ParseUint(strValue, 10, 16)
switch { switch {
@ -199,12 +222,8 @@ func urlQueryParameterToUint16(
} }
// urlQueryParameterToBoolean is a helper function that extracts a boolean value // urlQueryParameterToBoolean is a helper function that extracts a boolean value
// from a query parameter. See httpRequestToMsgJSON to see how it's used. // from a query parameter.
func urlQueryParameterToBoolean( func urlQueryParameterToBoolean(q url.Values, name string, defaultValue bool) (v bool, err error) {
q url.Values,
name string,
defaultValue bool,
) (v bool, err error) {
strValue := q.Get(name) strValue := q.Get(name)
switch strValue { switch strValue {
case "1", "true", "True": case "1", "true", "True":

View File

@ -25,7 +25,7 @@ import (
func TestServerQUIC_integration_query(t *testing.T) { func TestServerQUIC_integration_query(t *testing.T) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org") tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
srv, addr, err := dnsservertest.RunLocalQUICServer( srv, addr, err := dnsservertest.RunLocalQUICServer(
dnsservertest.DefaultHandler(), dnsservertest.NewDefaultHandler(),
tlsConfig, tlsConfig,
) )
require.NoError(t, err) require.NoError(t, err)
@ -74,7 +74,7 @@ func TestServerQUIC_integration_query(t *testing.T) {
func TestServerQUIC_integration_ENDS0Padding(t *testing.T) { func TestServerQUIC_integration_ENDS0Padding(t *testing.T) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org") tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
srv, addr, err := dnsservertest.RunLocalQUICServer( srv, addr, err := dnsservertest.RunLocalQUICServer(
dnsservertest.DefaultHandler(), dnsservertest.NewDefaultHandler(),
tlsConfig, tlsConfig,
) )
require.NoError(t, err) require.NoError(t, err)
@ -108,7 +108,7 @@ func TestServerQUIC_integration_ENDS0Padding(t *testing.T) {
func TestServerQUIC_integration_0RTT(t *testing.T) { func TestServerQUIC_integration_0RTT(t *testing.T) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org") tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
srv, addr, err := dnsservertest.RunLocalQUICServer( srv, addr, err := dnsservertest.RunLocalQUICServer(
dnsservertest.DefaultHandler(), dnsservertest.NewDefaultHandler(),
tlsConfig, tlsConfig,
) )
require.NoError(t, err) require.NoError(t, err)
@ -146,7 +146,7 @@ func TestServerQUIC_integration_0RTT(t *testing.T) {
func TestServerQUIC_integration_largeQuery(t *testing.T) { func TestServerQUIC_integration_largeQuery(t *testing.T) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org") tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
srv, addr, err := dnsservertest.RunLocalQUICServer( srv, addr, err := dnsservertest.RunLocalQUICServer(
dnsservertest.DefaultHandler(), dnsservertest.NewDefaultHandler(),
tlsConfig, tlsConfig,
) )
require.NoError(t, err) require.NoError(t, err)

View File

@ -17,7 +17,7 @@ import (
func TestServerTLS_integration_queryTLS(t *testing.T) { func TestServerTLS_integration_queryTLS(t *testing.T) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org") tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
addr := dnsservertest.RunTLSServer(t, dnsservertest.DefaultHandler(), tlsConfig) addr := dnsservertest.RunTLSServer(t, dnsservertest.NewDefaultHandler(), tlsConfig)
// Create a test message. // Create a test message.
req := new(dns.Msg) req := new(dns.Msg)
@ -94,7 +94,7 @@ func TestServerTLS_integration_msgIgnore(t *testing.T) {
t.Parallel() t.Parallel()
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org") tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
h := dnsservertest.DefaultHandler() h := dnsservertest.NewDefaultHandler()
addr := dnsservertest.RunTLSServer(t, h, tlsConfig) addr := dnsservertest.RunTLSServer(t, h, tlsConfig)
conn, err := tls.Dial("tcp", addr.String(), tlsConfig) conn, err := tls.Dial("tcp", addr.String(), tlsConfig)
@ -120,7 +120,7 @@ func TestServerTLS_integration_msgIgnore(t *testing.T) {
func TestServerTLS_integration_noTruncateQuery(t *testing.T) { func TestServerTLS_integration_noTruncateQuery(t *testing.T) {
// Handler that writes a huge response which would not fit // Handler that writes a huge response which would not fit
// into a UDP response, but it should fit a TCP response just okay. // into a UDP response, but it should fit a TCP response just okay.
handler := dnsservertest.CreateTestHandler(64) handler := dnsservertest.NewDefaultHandlerWithCount(64)
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org") tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
addr := dnsservertest.RunTLSServer(t, handler, tlsConfig) addr := dnsservertest.RunTLSServer(t, handler, tlsConfig)
@ -155,7 +155,7 @@ func TestServerTLS_integration_queriesPipelining(t *testing.T) {
// i.e. we should be able to process incoming queries in parallel and // i.e. we should be able to process incoming queries in parallel and
// write responses out of order. // write responses out of order.
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org") tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
addr := dnsservertest.RunTLSServer(t, dnsservertest.DefaultHandler(), tlsConfig) addr := dnsservertest.RunTLSServer(t, dnsservertest.NewDefaultHandler(), tlsConfig)
// First - establish a connection // First - establish a connection
conn, err := tls.Dial("tcp", addr.String(), tlsConfig) conn, err := tls.Dial("tcp", addr.String(), tlsConfig)
@ -221,7 +221,7 @@ func TestServerTLS_integration_queriesPipelining(t *testing.T) {
func TestServerTLS_integration_ENDS0Padding(t *testing.T) { func TestServerTLS_integration_ENDS0Padding(t *testing.T) {
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org") tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
addr := dnsservertest.RunTLSServer(t, dnsservertest.DefaultHandler(), tlsConfig) addr := dnsservertest.RunTLSServer(t, dnsservertest.NewDefaultHandler(), tlsConfig)
req := dnsservertest.CreateMessage("example.org.", dns.TypeA) req := dnsservertest.CreateMessage("example.org.", dns.TypeA)
req.Extra = []dns.RR{dnsservertest.NewEDNS0Padding(req.Len(), dns.DefaultMsgSize)} req.Extra = []dns.RR{dnsservertest.NewEDNS0Padding(req.Len(), dns.DefaultMsgSize)}

228
internal/dnssvc/config.go Normal file
View File

@ -0,0 +1,228 @@
package dnssvc
import (
"log/slog"
"net/http"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/access"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdcache"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
"github.com/AdguardTeam/AdGuardDNS/internal/cmd/plugin"
"github.com/AdguardTeam/AdGuardDNS/internal/connlimiter"
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/AdGuardDNS/internal/rulestat"
"github.com/prometheus/client_golang/prometheus"
)
// Config is the configuration of the AdGuard DNS service.
type Config struct {
// Handlers are the handlers to use in this DNS service.
Handlers Handlers
// NewListener, when set, is used instead of the package-level function
// [NewListener] when creating a DNS listener.
//
// TODO(a.garipov): This is only used for tests. Replace with a
// [netext.ListenConfig].
NewListener NewListenerFunc
// Cloner is used to clone messages more efficiently by disposing of parts
// of DNS responses for later reuse. It must not be nil.
Cloner *dnsmsg.Cloner
// ControlConf is the configuration of socket options.
ControlConf *netext.ControlConfig
// ConnLimiter, if not nil, is used to limit the number of simultaneously
// active stream-connections.
ConnLimiter *connlimiter.Limiter
// ErrColl is the error collector that is used to collect critical and
// non-critical errors. It must not be nil.
ErrColl errcoll.Interface
// NonDNS is the handler for non-DNS HTTP requests. It must not be nil.
NonDNS http.Handler
// MetricsNamespace is a namespace for Prometheus metrics. It must be a
// valid Prometheus metric label.
MetricsNamespace string
// ServerGroups are the DNS server groups. Each element must be non-nil.
ServerGroups []*agd.ServerGroup
// HandleTimeout defines the timeout for the entire handling of a single
// query. It must be greater than zero.
HandleTimeout time.Duration
}
// NewListenerFunc is the type for DNS listener constructors.
type NewListenerFunc func(
srv *agd.Server,
baseConf dnsserver.ConfigBase,
nonDNS http.Handler,
) (l Listener, err error)
// Listener is a type alias for dnsserver.Server to make internal naming more
// consistent.
type Listener = dnsserver.Server
// HandlersConfig is the configuration necessary to create or wrap the main DNS
// handler.
//
// TODO(a.garipov): Consider adding validation functions.
type HandlersConfig struct {
// BaseLogger is used to create loggers with custom prefixes for middlewares
// and the service itself. It must not be nil.
BaseLogger *slog.Logger
// Cloner is used to clone messages more efficiently by disposing of parts
// of DNS responses for later reuse. It must not be nil.
Cloner *dnsmsg.Cloner
// Cache is the configuration for the DNS cache.
Cache *CacheConfig
// HumanIDParser is used to normalize and parse human-readable device
// identifiers. It must not be nil if at least one server group has
// profiles enabled.
HumanIDParser *agd.HumanIDParser
// Messages is the message constructor used to create blocked and other
// messages for this DNS service. It must not be nil.
Messages *dnsmsg.Constructor
// PluginRegistry is used to override configuration parameters.
PluginRegistry *plugin.Registry
// StructuredErrors is the configuration for the experimental Structured DNS
// Errors feature in the profiles' message constructors. It must not be
// nil.
StructuredErrors *dnsmsg.StructuredDNSErrorsConfig
// AccessManager is used to block requests. It must not be nil.
AccessManager access.Interface
// BillStat is used to collect billing statistics. It must not be nil.
BillStat billstat.Recorder
// CacheManager is the global cache manager. It must not be nil.
CacheManager agdcache.Manager
// DNSCheck is used by clients to check if they use AdGuard DNS. It must
// not be nil.
DNSCheck dnscheck.Interface
// DNSDB is used to update anonymous statistics about DNS queries. It must
// not be nil.
DNSDB dnsdb.Interface
// ErrColl is the error collector that is used to collect critical and
// non-critical errors. It must not be nil.
ErrColl errcoll.Interface
// FilterStorage is the storage of all filters. It must not be nil.
FilterStorage filter.Storage
// GeoIP is the GeoIP database used to detect geographic data about IP
// addresses in requests and responses. It must not be nil.
GeoIP geoip.Interface
// Handler is the ultimate handler of the DNS query to be wrapped by
// middlewares. It must not be nil.
Handler dnsserver.Handler
// HashMatcher is the safe-browsing hash matcher for TXT queries. It must
// not be nil.
HashMatcher filter.HashMatcher
// ProfileDB is the AdGuard DNS profile database used to fetch data about
// profiles, devices, and so on. It must not be nil if at least one server
// group has profiles enabled.
ProfileDB profiledb.Interface
// PrometheusRegisterer is used to register Prometheus metrics. It must not
// be nil.
PrometheusRegisterer prometheus.Registerer
// QueryLog is used to write the logs into. It must not be nil.
QueryLog querylog.Interface
// RateLimit is used for allow or decline requests. It must not be nil.
RateLimit ratelimit.Interface
// RuleStat is used to collect statistics about matched filtering rules and
// rule lists. It must not be nil.
RuleStat rulestat.Interface
// MetricsNamespace is a namespace for Prometheus metrics. It must be a
// valid Prometheus metric label.
MetricsNamespace string
// FilteringGroups are the DNS filtering groups. Each element must be
// non-nil.
FilteringGroups map[agd.FilteringGroupID]*agd.FilteringGroup
// ServerGroups are the DNS server groups for which to build handlers. Each
// element and its servers must be non-nil.
ServerGroups []*agd.ServerGroup
// EDEEnabled enables the addition of the Extended DNS Error (EDE) codes in
// the profiles' message constructors.
EDEEnabled bool
}
// Handlers contains the map of handlers for each server of each server group.
// The pointers are the same as those passed in a [HandlersConfig] to
// [NewHandlers].
type Handlers map[HandlerKey]dnsserver.Handler
// HandlerKey is a key for the [Handlers] map.
type HandlerKey struct {
Server *agd.Server
ServerGroup *agd.ServerGroup
}
// CacheConfig is the configuration for the DNS cache.
type CacheConfig struct {
// MinTTL is the minimum supported TTL for cache items.
MinTTL time.Duration
// ECSCount is the size of the DNS cache for domain names that support
// ECS, in entries. It must be greater than zero if [CacheConfig.CacheType]
// is [CacheTypeECS].
ECSCount int
// NoECSCount is the size of the DNS cache for domain names that don't
// support ECS, in entries. It must be greater than zero if
// [CacheConfig.CacheType] is [CacheTypeSimple] or [CacheTypeECS].
NoECSCount int
// Type is the cache type. It must be valid.
Type CacheType
// OverrideCacheTTL shows if the TTL overriding logic should be used.
OverrideCacheTTL bool
}
// CacheType is the type of the cache to use.
type CacheType uint8
// CacheType constants.
const (
CacheTypeNone CacheType = iota + 1
CacheTypeSimple
CacheTypeECS
)

View File

@ -0,0 +1,35 @@
package dnssvc
import (
"context"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
)
// contextConstructor is a [dnsserver.ContextConstructor] implementation that
// returns a context with the given timeout as well as a new [agd.RequestID].
type contextConstructor struct {
timeout time.Duration
}
// newContextConstructor returns a new properly initialized *contextConstructor.
func newContextConstructor(timeout time.Duration) (c *contextConstructor) {
return &contextConstructor{
timeout: timeout,
}
}
// type check
var _ dnsserver.ContextConstructor = (*contextConstructor)(nil)
// New implements the [dnsserver.ContextConstructor] interface for
// *contextConstructor. It returns a context with a new [agd.RequestID] as well
// as its timeout and the corresponding cancelation function.
func (c *contextConstructor) New() (ctx context.Context, cancel context.CancelFunc) {
ctx, cancel = context.WithTimeout(context.Background(), c.timeout)
ctx = agd.WithRequestID(ctx, agd.NewRequestID())
return ctx, cancel
}

View File

@ -7,289 +7,27 @@ package dnssvc
import ( import (
"context" "context"
"fmt" "fmt"
"log/slog"
"net/http" "net/http"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/access"
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdcache"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
"github.com/AdguardTeam/AdGuardDNS/internal/cmd/plugin"
"github.com/AdguardTeam/AdGuardDNS/internal/connlimiter" "github.com/AdguardTeam/AdGuardDNS/internal/connlimiter"
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
dnssrvprom "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus" dnssrvprom "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/devicefinder"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/initial"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/mainmw"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/preservice"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/preupstream"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/ratelimitmw"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/AdGuardDNS/internal/rulestat"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/service" "github.com/AdguardTeam/golibs/service"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
) )
// Config is the configuration of the AdGuard DNS service. // Service is the main DNS service of AdGuard DNS.
type Config struct { type Service struct {
// BaseLogger is used to create loggers with custom prefixes for middlewares groups []*serverGroup
// and the service itself.
BaseLogger *slog.Logger
// Messages is the message constructor used to create blocked and other
// messages for this DNS service.
Messages *dnsmsg.Constructor
// Cloner is used to clone messages more efficiently by disposing of parts
// of DNS responses for later reuse.
Cloner *dnsmsg.Cloner
// ControlConf is the configuration of socket options.
ControlConf *netext.ControlConfig
// ConnLimiter, if not nil, is used to limit the number of simultaneously
// active stream-connections.
ConnLimiter *connlimiter.Limiter
// HumanIDParser is used to normalize and parse human-readable device
// identifiers.
HumanIDParser *agd.HumanIDParser
// PluginRegistry is used to override configuration parameters.
PluginRegistry *plugin.Registry
// AccessManager is used to block requests.
AccessManager access.Interface
// SafeBrowsing is the safe browsing TXT hash matcher.
SafeBrowsing filter.HashMatcher
// BillStat is used to collect billing statistics.
BillStat billstat.Recorder
// CacheManager is the global cache manager. CacheManager must not be nil.
CacheManager agdcache.Manager
// ProfileDB is the AdGuard DNS profile database used to fetch data about
// profiles, devices, and so on.
ProfileDB profiledb.Interface
// PrometheusRegisterer is used to register Prometheus metrics.
PrometheusRegisterer prometheus.Registerer
// DNSCheck is used by clients to check if they use AdGuard DNS.
DNSCheck dnscheck.Interface
// NonDNS is the handler for non-DNS HTTP requests.
NonDNS http.Handler
// DNSDB is used to update anonymous statistics about DNS queries.
DNSDB dnsdb.Interface
// ErrColl is the error collector that is used to collect critical and
// non-critical errors.
ErrColl errcoll.Interface
// FilterStorage is the storage of all filters.
FilterStorage filter.Storage
// GeoIP is the GeoIP database used to detect geographic data about IP
// addresses in requests and responses.
GeoIP geoip.Interface
// QueryLog is used to write the logs into.
QueryLog querylog.Interface
// RuleStat is used to collect statistics about matched filtering rules and
// rule lists.
RuleStat rulestat.Interface
// NewListener, when set, is used instead of the package-level function
// NewListener when creating a DNS listener.
//
// TODO(a.garipov): The handler and service logic should really not be
// intertwined in this way. See AGDNS-1327.
NewListener NewListenerFunc
// Handler is used as the main DNS handler instead of a simple forwarder.
// It must not be nil.
//
// TODO(a.garipov): Think of a better way to make the DNS server logic more
// testable.
Handler dnsserver.Handler
// RateLimit is used for allow or decline requests.
RateLimit ratelimit.Interface
// MetricsNamespace is a namespace for Prometheus metrics. It must be a
// valid Prometheus metric label.
MetricsNamespace string
// FilteringGroups are the DNS filtering groups. Each element must be
// non-nil.
FilteringGroups map[agd.FilteringGroupID]*agd.FilteringGroup
// ServerGroups are the DNS server groups. Each element must be non-nil.
ServerGroups []*agd.ServerGroup
// HandleTimeout defines the timeout for the entire handling of a single
// query.
HandleTimeout time.Duration
// CacheSize is the size of the DNS cache for domain names that don't
// support ECS.
//
// TODO(a.garipov): Extract this and following fields to cache configuration
// struct.
CacheSize int
// ECSCacheSize is the size of the DNS cache for domain names that support
// ECS.
ECSCacheSize int
// CacheMinTTL is the minimum supported TTL for cache items. This setting
// is used when UseCacheTTLOverride set to true.
CacheMinTTL time.Duration
// UseCacheTTLOverride shows if the TTL overrides logic should be used.
UseCacheTTLOverride bool
// UseECSCache shows if the EDNS Client Subnet (ECS) aware cache should be
// used.
UseECSCache bool
} }
type ( // serverGroup is a group of servers.
// MainMiddlewareMetrics is a re-export of the internal filtering-middleware type serverGroup struct {
// metrics interface. name agd.ServerGroupName
MainMiddlewareMetrics = mainmw.Metrics servers []*server
// RatelimitMiddlewareMetrics is a re-export of the metrics interface of the
// internal access and ratelimiting middleware.
RatelimitMiddlewareMetrics = ratelimitmw.Metrics
)
// New returns a new DNS service.
func New(c *Config) (svc *Service, err error) {
// Use either the configured listener initializer or the default one.
newListener := c.NewListener
if newListener == nil {
newListener = NewListener
}
// Configure the end of the request handling pipeline.
handler := c.Handler
if handler == nil {
return nil, errors.Error("handler in config must not be nil")
}
// Configure the pre-upstream middleware common for all servers of all
// groups.
preUps := preupstream.New(&preupstream.Config{
Cloner: c.Cloner,
CacheManager: c.CacheManager,
DB: c.DNSDB,
GeoIP: c.GeoIP,
CacheSize: c.CacheSize,
ECSCacheSize: c.ECSCacheSize,
UseECSCache: c.UseECSCache,
CacheMinTTL: c.CacheMinTTL,
UseCacheTTLOverride: c.UseCacheTTLOverride,
})
handler = preUps.Wrap(handler)
errCollListener := &errCollMetricsListener{
errColl: c.ErrColl,
baseListener: dnssrvprom.NewServerMetricsListener(c.MetricsNamespace),
}
// Configure the service itself.
groups := make([]*serverGroup, len(c.ServerGroups))
svc = &Service{
groups: groups,
}
mainMwMtrc, err := newMainMiddlewareMetrics(c)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
rlMwMtrc, err := metrics.NewDefaultRatelimitMiddleware(c.MetricsNamespace, c.PrometheusRegisterer)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
for i, srvGrp := range c.ServerGroups {
// The Filtering Middlewares
//
// These are middlewares common to all filtering and server groups.
// They change the flow of request handling, so they are separated.
dnsHdlr := dnsserver.WithMiddlewares(
handler,
preservice.New(&preservice.Config{
Messages: c.Messages,
HashMatcher: c.SafeBrowsing,
Checker: c.DNSCheck,
}),
mainmw.New(&mainmw.Config{
Metrics: mainMwMtrc,
Messages: c.Messages,
Cloner: c.Cloner,
BillStat: c.BillStat,
ErrColl: c.ErrColl,
FilterStorage: c.FilterStorage,
GeoIP: c.GeoIP,
QueryLog: c.QueryLog,
RuleStat: c.RuleStat,
}),
)
var servers []*server
servers, err = newServers(c, srvGrp, dnsHdlr, rlMwMtrc, errCollListener, newListener)
if err != nil {
return nil, fmt.Errorf("group %q: %w", srvGrp.Name, err)
}
groups[i] = &serverGroup{
name: srvGrp.Name,
servers: servers,
}
}
return svc, nil
}
// newMainMiddlewareMetrics returns a filtering-middleware metrics
// implementation from the config.
func newMainMiddlewareMetrics(c *Config) (mainMwMtrc MainMiddlewareMetrics, err error) {
mainMwMtrc = c.PluginRegistry.MainMiddlewareMetrics()
if mainMwMtrc != nil {
return mainMwMtrc, nil
}
mainMwMtrc, err = metrics.NewDefaultMainMiddleware(c.MetricsNamespace, c.PrometheusRegisterer)
if err != nil {
return nil, fmt.Errorf("mainmw metrics: %w", err)
}
return mainMwMtrc, nil
} }
// server is a group of listeners. // server is a group of listeners.
@ -303,27 +41,171 @@ type server struct {
listeners []*listener listeners []*listener
} }
// serverGroup is a group of servers. // listener is a Listener along with some of its associated data.
type serverGroup struct { type listener struct {
name agd.ServerGroupName Listener
servers []*server
name string
} }
// Service is the main DNS service of AdGuard DNS. // New returns a new DNS service.
type Service struct { func New(c *Config) (svc *Service, err error) {
groups []*serverGroup // Use either the configured listener initializer or the default one.
} newListener := c.NewListener
if newListener == nil {
// mustStartListener starts l and panics on any error. newListener = NewListener
func mustStartListener(
grp agd.ServerGroupName,
srv agd.ServerName,
l *listener,
) {
err := l.Start(context.Background())
if err != nil {
panic(fmt.Errorf("group %q: server %q: starting %q: %w", grp, srv, l.name, err))
} }
errCollListener := &errCollMetricsListener{
errColl: c.ErrColl,
baseListener: dnssrvprom.NewServerMetricsListener(c.MetricsNamespace),
}
// Configure the service itself.
groups := make([]*serverGroup, 0, len(c.ServerGroups))
for _, srvGrp := range c.ServerGroups {
g := &serverGroup{
name: srvGrp.Name,
}
g.servers, err = newServers(c, srvGrp, errCollListener, newListener)
if err != nil {
return nil, fmt.Errorf("group %q: %w", srvGrp.Name, err)
}
groups = append(groups, g)
}
svc = &Service{
groups: groups,
}
return svc, nil
}
// newServers creates a slice of servers.
func newServers(
c *Config,
srvGrp *agd.ServerGroup,
errCollListener *errCollMetricsListener,
newListener NewListenerFunc,
) (servers []*server, err error) {
servers = make([]*server, 0, len(srvGrp.Servers))
for _, srv := range srvGrp.Servers {
k := HandlerKey{
Server: srv,
ServerGroup: srvGrp,
}
handler, ok := c.Handlers[k]
if !ok {
return nil, fmt.Errorf("no handler for server %q of group %q", srv.Name, srvGrp.Name)
}
s := &server{
name: srv.Name,
handler: handler,
}
s.listeners, err = newListeners(c, srv, handler, errCollListener, newListener)
if err != nil {
return nil, fmt.Errorf("server %q: %w", s.name, err)
}
servers = append(servers, s)
}
return servers, nil
}
// newListeners creates a slice of listeners for a server.
func newListeners(
c *Config,
srv *agd.Server,
handler dnsserver.Handler,
errCollListener *errCollMetricsListener,
newListener NewListenerFunc,
) (listeners []*listener, err error) {
bindData := srv.BindData()
listeners = make([]*listener, 0, len(bindData))
for i, bindData := range bindData {
var addr string
if bindData.PrefixAddr == nil {
addr = bindData.AddrPort.String()
} else {
addr = bindData.PrefixAddr.String()
}
proto := srv.Protocol
name := listenerName(srv.Name, addr, proto)
baseConf := dnsserver.ConfigBase{
Network: dnsserver.NetworkAny,
Handler: handler,
Metrics: errCollListener,
Disposer: c.Cloner,
RequestContext: newContextConstructor(c.HandleTimeout),
ListenConfig: newListenConfig(
bindData.ListenConfig,
c.ControlConf,
c.ConnLimiter,
proto,
),
Name: name,
Addr: addr,
}
l := &listener{
name: name,
}
l.Listener, err = newListener(srv, baseConf, c.NonDNS)
if err != nil {
return nil, fmt.Errorf("bind data at index %d: %w", i, err)
}
listeners = append(listeners, l)
}
return listeners, nil
}
// listenerName returns a standard name for a listener.
func listenerName(srvName agd.ServerName, addr string, proto agd.Protocol) (name string) {
return fmt.Sprintf("%s/%s/%s", srvName, proto, addr)
}
// newListenConfig returns the netext.ListenConfig used by the plain-DNS
// servers. The resulting ListenConfig sets additional socket flags and
// processes the control messages of connections created with ListenPacket.
// Additionally, if l is not nil, it is used to limit the number of
// simultaneously active stream-connections.
func newListenConfig(
original netext.ListenConfig,
ctrlConf *netext.ControlConfig,
l *connlimiter.Limiter,
p agd.Protocol,
) (lc netext.ListenConfig) {
if original != nil {
if l == nil {
return original
}
return connlimiter.NewListenConfig(original, l)
}
if p == agd.ProtoDNS {
lc = netext.DefaultListenConfigWithOOB(ctrlConf)
} else {
lc = netext.DefaultListenConfig(ctrlConf)
}
if l != nil {
lc = connlimiter.NewListenConfig(lc, l)
}
return lc
} }
// type check // type check
@ -331,13 +213,13 @@ var _ service.Interface = (*Service)(nil)
// Start implements the [service.Interface] interface for *Service. It panics // Start implements the [service.Interface] interface for *Service. It panics
// if one of the listeners could not start. // if one of the listeners could not start.
func (svc *Service) Start(_ context.Context) (err error) { func (svc *Service) Start(ctx context.Context) (err error) {
for _, g := range svc.groups { for _, g := range svc.groups {
for _, s := range g.servers { for _, s := range g.servers {
for _, l := range s.listeners { for _, l := range s.listeners {
// Consider inability to start any one DNS listener a fatal // Consider inability to start any one DNS listener a fatal
// error. // error.
mustStartListener(g.name, s.name, l) mustStartListener(ctx, g.name, s.name, l)
} }
} }
} }
@ -345,17 +227,17 @@ func (svc *Service) Start(_ context.Context) (err error) {
return nil return nil
} }
// shutdownListeners is a helper function that shuts down all listeners of a // mustStartListener starts l and panics on any error.
// server. func mustStartListener(
func shutdownListeners(ctx context.Context, listeners []*listener) (err error) { ctx context.Context,
for _, l := range listeners { srvGrp agd.ServerGroupName,
err = l.Shutdown(ctx) srv agd.ServerName,
if err != nil { l *listener,
return fmt.Errorf("shutting down listener %q: %w", l.name, err) ) {
} err := l.Start(ctx)
if err != nil {
panic(fmt.Errorf("group %q: server %q: starting %q: %w", srvGrp, srv, l.name, err))
} }
return nil
} }
// Shutdown implements the [service.Interface] interface for *Service. // Shutdown implements the [service.Interface] interface for *Service.
@ -378,9 +260,22 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) {
return nil return nil
} }
// shutdownListeners is a helper function that shuts down all listeners of a
// server.
func shutdownListeners(ctx context.Context, listeners []*listener) (err error) {
for _, l := range listeners {
err = l.Shutdown(ctx)
if err != nil {
return fmt.Errorf("shutting down listener %q: %w", l.name, err)
}
}
return nil
}
// Handle is a simple helper to test the handling of DNS requests. // Handle is a simple helper to test the handling of DNS requests.
// //
// TODO(a.garipov): Remove once the mainmw refactoring is complete. // TODO(a.garipov): Remove once the refactoring is complete.
func (svc *Service) Handle( func (svc *Service) Handle(
ctx context.Context, ctx context.Context,
grpName agd.ServerGroupName, grpName agd.ServerGroupName,
@ -417,33 +312,10 @@ func (svc *Service) Handle(
return srv.handler.ServeDNS(ctx, rw, r) return srv.handler.ServeDNS(ctx, rw, r)
} }
// Listener is a type alias for dnsserver.Server to make internal naming more
// consistent.
type Listener = dnsserver.Server
// NewListenerFunc is the type for DNS listener constructors.
type NewListenerFunc func(
s *agd.Server,
baseConf dnsserver.ConfigBase,
nonDNS http.Handler,
) (l Listener, err error)
// listener is a Listener along with some of its associated data.
type listener struct {
Listener
name string
}
// listenerName returns a standard name for a listener.
func listenerName(srvName agd.ServerName, addr string, proto agd.Protocol) (name string) {
return fmt.Sprintf("%s/%s/%s", srvName, proto, addr)
}
// NewListener returns a new Listener. It is the default DNS listener // NewListener returns a new Listener. It is the default DNS listener
// constructor. // constructor.
// //
// TODO(a.garipov): Replace this in tests with [netext.ListenConfig]. // TODO(a.garipov): Replace this in tests with [netext.ListenConfig].
func NewListener( func NewListener(
s *agd.Server, s *agd.Server,
baseConf dnsserver.ConfigBase, baseConf dnsserver.ConfigBase,
@ -500,213 +372,8 @@ func NewListener(
TLSConfig: s.TLS, TLSConfig: s.TLS,
}) })
default: default:
return nil, fmt.Errorf("bad protocol %v", p) return nil, fmt.Errorf("protocol: %w: %d", errors.ErrBadEnumValue, p)
} }
return l, nil return l, nil
} }
// contextConstructor is a [dnsserver.ContextConstructor] implementation that
// that returns a context with the given timeout as well as a new
// [agd.RequestID].
type contextConstructor struct {
timeout time.Duration
}
// newContextConstructor returns a new properly initialized *contextConstructor.
func newContextConstructor(timeout time.Duration) (c *contextConstructor) {
return &contextConstructor{
timeout: timeout,
}
}
// type check
var _ dnsserver.ContextConstructor = (*contextConstructor)(nil)
// New implements the [dnsserver.ContextConstructor] interface for
// *contextConstructor. It returns a context with a new [agd.RequestID] as well
// as its timeout and the corresponding cancelation function.
func (c *contextConstructor) New() (ctx context.Context, cancel context.CancelFunc) {
ctx, cancel = context.WithTimeout(context.Background(), c.timeout)
ctx = agd.WithRequestID(ctx, agd.NewRequestID())
return ctx, cancel
}
// newServers creates a slice of servers.
//
// TODO(a.garipov): Refactor this into a builder pattern.
func newServers(
c *Config,
srvGrp *agd.ServerGroup,
handler dnsserver.Handler,
rlMwMtrc ratelimitmw.Metrics,
errCollListener *errCollMetricsListener,
newListener NewListenerFunc,
) (servers []*server, err error) {
servers = make([]*server, len(srvGrp.Servers))
for i, s := range srvGrp.Servers {
// The Initial Middlewares
//
// These middlewares are either specific to the server or must be the
// furthest away from the handler and thus are the first to process
// a request.
// Assume that all the validations have been made during the
// configuration validation step back in package cmd. If we ever get
// new ways of receiving configuration, remove this assumption and
// validate fg.
fg := c.FilteringGroups[srvGrp.FilteringGroup]
df := newDeviceFinder(c, srvGrp, s)
rlm := ratelimitmw.New(&ratelimitmw.Config{
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "ratelimitmw"),
Messages: c.Messages,
FilteringGroup: fg,
ServerGroup: srvGrp,
Server: s,
AccessManager: c.AccessManager,
DeviceFinder: df,
ErrColl: c.ErrColl,
GeoIP: c.GeoIP,
Metrics: rlMwMtrc,
Limiter: c.RateLimit,
// Only apply rate-limiting logic to plain DNS.
Protocols: []agd.Protocol{agd.ProtoDNS},
})
if err != nil {
return nil, fmt.Errorf("ratelimit: %w", err)
}
imw := initial.New(&initial.Config{
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "initmw"),
})
h := dnsserver.WithMiddlewares(
handler,
// Keep the rate limiting and access middlewares as the outer ones
// to make sure that the application logic isn't touched if the
// request is ratelimited or blocked by access settings.
rlm,
imw,
)
srvName := s.Name
var listeners []*listener
listeners, err = newListeners(c, s, h, errCollListener, newListener)
if err != nil {
return nil, fmt.Errorf("server %q: %w", srvName, err)
}
servers[i] = &server{
name: srvName,
handler: h,
listeners: listeners,
}
}
return servers, nil
}
// newDeviceFinder returns a new [agd.DeviceFinder] for a server based on the
// configuration.
func newDeviceFinder(c *Config, g *agd.ServerGroup, s *agd.Server) (df agd.DeviceFinder) {
if !g.ProfilesEnabled {
return agd.EmptyDeviceFinder{}
}
return devicefinder.NewDefault(&devicefinder.Config{
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "devicefinder"),
ProfileDB: c.ProfileDB,
HumanIDParser: c.HumanIDParser,
Server: s,
DeviceDomains: g.TLS.DeviceDomains,
})
}
// newServers creates a slice of listeners for a server.
func newListeners(
c *Config,
srv *agd.Server,
handler dnsserver.Handler,
errCollListener *errCollMetricsListener,
newListener NewListenerFunc,
) (listeners []*listener, err error) {
bindData := srv.BindData()
listeners = make([]*listener, 0, len(bindData))
for i, bindData := range bindData {
var addr string
if bindData.PrefixAddr == nil {
addr = bindData.AddrPort.String()
} else {
addr = bindData.PrefixAddr.String()
}
proto := srv.Protocol
name := listenerName(srv.Name, addr, proto)
baseConf := dnsserver.ConfigBase{
Network: dnsserver.NetworkAny,
Handler: handler,
Metrics: errCollListener,
Disposer: c.Cloner,
RequestContext: newContextConstructor(c.HandleTimeout),
ListenConfig: newListenConfig(
bindData.ListenConfig,
c.ControlConf,
c.ConnLimiter,
proto,
),
Name: name,
Addr: addr,
}
var l Listener
l, err = newListener(srv, baseConf, c.NonDNS)
if err != nil {
return nil, fmt.Errorf("bind data at index %d: %w", i, err)
}
listeners = append(listeners, &listener{
name: name,
Listener: l,
})
}
return listeners, nil
}
// newListenConfig returns the netext.ListenConfig used by the plain-DNS
// servers. The resulting ListenConfig sets additional socket flags and
// processes the control messages of connections created with ListenPacket.
// Additionally, if l is not nil, it is used to limit the number of
// simultaneously active stream-connections.
func newListenConfig(
original netext.ListenConfig,
ctrlConf *netext.ControlConfig,
l *connlimiter.Limiter,
p agd.Protocol,
) (lc netext.ListenConfig) {
if original != nil {
if l == nil {
return original
}
return connlimiter.NewListenConfig(original, l)
}
if p == agd.ProtoDNS {
lc = netext.DefaultListenConfigWithOOB(ctrlConf)
} else {
lc = netext.DefaultListenConfig(ctrlConf)
}
if l != nil {
lc = connlimiter.NewListenConfig(lc, l)
}
return lc
}

View File

@ -10,26 +10,17 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdservice" "github.com/AdguardTeam/AdGuardDNS/internal/agdservice"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/forward" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/forward"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc" "github.com/AdguardTeam/AdGuardDNS/internal/dnssvc"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/dnssvctest" "github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/dnssvctest"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// testSrvGrpName is the [agd.ServerGroupName] for tests.
const testSrvGrpName agd.ServerGroupName = "test_group"
// type check // type check
var _ agdservice.Refresher = (*forward.Handler)(nil) var _ agdservice.Refresher = (*forward.Handler)(nil)
@ -159,16 +150,23 @@ func TestService_Start(t *testing.T) {
AddrPort: netip.MustParseAddrPort("127.0.0.1:53"), AddrPort: netip.MustParseAddrPort("127.0.0.1:53"),
}) })
srvGrp := &agd.ServerGroup{
Name: dnssvctest.ServerGroupName,
Servers: []*agd.Server{srv},
}
k := dnssvc.HandlerKey{
Server: srv,
ServerGroup: srvGrp,
}
c := &dnssvc.Config{ c := &dnssvc.Config{
BaseLogger: slogutil.NewDiscardLogger(), NewListener: newTestListenerFunc(tl),
NewListener: newTestListenerFunc(tl), Handlers: dnssvc.Handlers{
PrometheusRegisterer: agdtest.NewTestPrometheusRegisterer(), k: dnsservertest.NewDefaultHandler(),
Handler: dnsservertest.DefaultHandler(), },
MetricsNamespace: "test_start", MetricsNamespace: "test_start",
ServerGroups: []*agd.ServerGroup{{ ServerGroups: []*agd.ServerGroup{srvGrp},
Name: "test_group",
Servers: []*agd.Server{srv},
}},
} }
svc, err := dnssvc.New(c) svc, err := dnssvc.New(c)
@ -206,15 +204,25 @@ func TestNew(t *testing.T) {
}), }),
} }
srvGrp := &agd.ServerGroup{
Name: dnssvctest.ServerGroupName,
Servers: srvs,
}
handlers := dnssvc.Handlers{}
for _, srv := range srvs {
k := dnssvc.HandlerKey{
Server: srv,
ServerGroup: srvGrp,
}
handlers[k] = dnsservertest.NewDefaultHandler()
}
c := &dnssvc.Config{ c := &dnssvc.Config{
BaseLogger: slogutil.NewDiscardLogger(), Handlers: handlers,
Handler: dnsservertest.DefaultHandler(), MetricsNamespace: "test_new",
PrometheusRegisterer: agdtest.NewTestPrometheusRegisterer(), ServerGroups: []*agd.ServerGroup{srvGrp},
MetricsNamespace: "test_new",
ServerGroups: []*agd.ServerGroup{{
Name: "test_group",
Servers: srvs,
}},
} }
svc, err := dnssvc.New(c) svc, err := dnssvc.New(c)

211
internal/dnssvc/handler.go Normal file
View File

@ -0,0 +1,211 @@
package dnssvc
import (
"context"
"fmt"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/cache"
dnssrvprom "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/devicefinder"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/initial"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/mainmw"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/preservice"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/preupstream"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/ratelimitmw"
"github.com/AdguardTeam/AdGuardDNS/internal/ecscache"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// NewHandlers returns the main DNS handlers wrapped in all necessary
// middlewares. c must not be nil.
func NewHandlers(ctx context.Context, c *HandlersConfig) (handlers Handlers, err error) {
handler := wrapPreUpstreamMw(ctx, c)
mainMwMtrc, err := newMainMiddlewareMetrics(c)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
mainMw := mainmw.New(&mainmw.Config{
Cloner: c.Cloner,
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "mainmw"),
Messages: c.Messages,
BillStat: c.BillStat,
ErrColl: c.ErrColl,
FilterStorage: c.FilterStorage,
GeoIP: c.GeoIP,
QueryLog: c.QueryLog,
Metrics: mainMwMtrc,
RuleStat: c.RuleStat,
})
handler = mainMw.Wrap(handler)
preSvcMw := preservice.New(&preservice.Config{
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "presvcmw"),
Messages: c.Messages,
HashMatcher: c.HashMatcher,
Checker: c.DNSCheck,
})
handler = preSvcMw.Wrap(handler)
postInitMw := c.PluginRegistry.PostInitialMiddleware()
if postInitMw != nil {
handler = postInitMw.Wrap(handler)
}
initMw := initial.New(&initial.Config{
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "initmw"),
})
handler = initMw.Wrap(handler)
return newHandlersForServers(c, handler)
}
// wrapPreUpstreamMw returns the handler wrapped into the pre-upstream
// middlewares.
//
// TODO(a.garipov): Adapt the cache tests that previously were in package
// preupstream.
func wrapPreUpstreamMw(ctx context.Context, c *HandlersConfig) (wrapped dnsserver.Handler) {
// TODO(a.garipov): Use in other places if necessary.
l := c.BaseLogger.With(slogutil.KeyPrefix, "dnssvc")
wrapped = c.Handler
switch conf := c.Cache; conf.Type {
case CacheTypeNone:
l.WarnContext(ctx, "cache disabled")
case CacheTypeSimple:
l.InfoContext(ctx, "plain cache enabled", "count", conf.NoECSCount)
cacheMw := cache.NewMiddleware(&cache.MiddlewareConfig{
// TODO(a.garipov): Do not use promauto and refactor.
MetricsListener: dnssrvprom.NewCacheMetricsListener(metrics.Namespace()),
Size: conf.NoECSCount,
MinTTL: conf.MinTTL,
OverrideTTL: conf.OverrideCacheTTL,
})
wrapped = cacheMw.Wrap(wrapped)
case CacheTypeECS:
l.InfoContext(
ctx,
"ecs cache enabled",
"ecs_count", conf.ECSCount,
"no_ecs_count", conf.NoECSCount,
)
cacheMw := ecscache.NewMiddleware(&ecscache.MiddlewareConfig{
Cloner: c.Cloner,
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "ecscache"),
CacheManager: c.CacheManager,
GeoIP: c.GeoIP,
NoECSCount: conf.NoECSCount,
ECSCount: conf.ECSCount,
MinTTL: conf.MinTTL,
OverrideTTL: conf.OverrideCacheTTL,
})
wrapped = cacheMw.Wrap(wrapped)
default:
panic(fmt.Errorf("cache type: %w: %d", errors.ErrBadEnumValue, conf.Type))
}
preUps := preupstream.New(ctx, &preupstream.Config{
DB: c.DNSDB,
})
wrapped = preUps.Wrap(wrapped)
return wrapped
}
// newMainMiddlewareMetrics returns a filtering-middleware metrics
// implementation from the config.
func newMainMiddlewareMetrics(c *HandlersConfig) (mainMwMtrc MainMiddlewareMetrics, err error) {
mainMwMtrc = c.PluginRegistry.MainMiddlewareMetrics()
if mainMwMtrc != nil {
return mainMwMtrc, nil
}
mainMwMtrc, err = metrics.NewDefaultMainMiddleware(c.MetricsNamespace, c.PrometheusRegisterer)
if err != nil {
return nil, fmt.Errorf("mainmw metrics: %w", err)
}
return mainMwMtrc, nil
}
// newHandlersForServers returns a handler map for each server group and each
// server.
func newHandlersForServers(c *HandlersConfig, h dnsserver.Handler) (handlers Handlers, err error) {
rlMwMtrc, err := metrics.NewDefaultRatelimitMiddleware(c.MetricsNamespace, c.PrometheusRegisterer)
if err != nil {
return nil, fmt.Errorf("ratelimit middleware metrics: %w", err)
}
handlers = Handlers{}
rlMwLogger := c.BaseLogger.With(slogutil.KeyPrefix, "ratelimitmw")
for _, srvGrp := range c.ServerGroups {
fltGrp, ok := c.FilteringGroups[srvGrp.FilteringGroup]
if !ok {
return nil, fmt.Errorf(
"no filtering group %q for server group %q",
srvGrp.FilteringGroup,
srvGrp.Name,
)
}
for _, srv := range srvGrp.Servers {
rlMw := ratelimitmw.New(&ratelimitmw.Config{
Logger: rlMwLogger,
Messages: c.Messages,
FilteringGroup: fltGrp,
ServerGroup: srvGrp,
Server: srv,
StructuredErrors: c.StructuredErrors,
AccessManager: c.AccessManager,
DeviceFinder: newDeviceFinder(c, srvGrp, srv),
ErrColl: c.ErrColl,
GeoIP: c.GeoIP,
Metrics: rlMwMtrc,
Limiter: c.RateLimit,
Protocols: []agd.Protocol{agd.ProtoDNS},
EDEEnabled: c.EDEEnabled,
})
k := HandlerKey{
Server: srv,
ServerGroup: srvGrp,
}
handlers[k] = rlMw.Wrap(h)
}
}
return handlers, nil
}
// newDeviceFinder returns a new agd.DeviceFinder for a server based on the
// configuration. All arguments must not be nil.
func newDeviceFinder(c *HandlersConfig, g *agd.ServerGroup, s *agd.Server) (df agd.DeviceFinder) {
if !g.ProfilesEnabled {
return agd.EmptyDeviceFinder{}
}
return devicefinder.NewDefault(&devicefinder.Config{
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "devicefinder"),
ProfileDB: c.ProfileDB,
HumanIDParser: c.HumanIDParser,
Server: s,
DeviceDomains: g.TLS.DeviceDomains,
})
}

View File

@ -0,0 +1,188 @@
package dnssvc_test
import (
"context"
"net/netip"
"path"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdcache"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/dnssvctest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewHandlers(t *testing.T) {
t.Parallel()
accessMgr := &agdtest.AccessManager{
OnIsBlockedHost: func(host string, qt uint16) (blocked bool) { panic("not implemented") },
OnIsBlockedIP: func(ip netip.Addr) (blocked bool) { panic("not implemented") },
}
billStat := &agdtest.BillStatRecorder{
OnRecord: func(
_ context.Context,
_ agd.DeviceID,
_ geoip.Country,
_ geoip.ASN,
_ time.Time,
_ agd.Protocol,
) {
panic("not implemented")
},
}
dnsCk := &agdtest.DNSCheck{
OnCheck: func(
_ context.Context,
_ *dns.Msg,
_ *agd.RequestInfo,
) (resp *dns.Msg, err error) {
panic("not implemented")
},
}
dnsDB := &agdtest.DNSDB{
OnRecord: func(_ context.Context, _ *dns.Msg, _ *agd.RequestInfo) {
panic("not implemented")
},
}
fltGrps := map[agd.FilteringGroupID]*agd.FilteringGroup{
dnssvctest.FilteringGroupID: {
ID: dnssvctest.FilteringGroupID,
RuleListIDs: []agd.FilterListID{dnssvctest.FilterListID1},
RuleListsEnabled: true,
},
}
fltStrg := &agdtest.FilterStorage{
OnFilterFromContext: func(_ context.Context, _ *agd.RequestInfo) (f filter.Interface) {
panic("not implemented")
},
OnHasListID: func(_ agd.FilterListID) (ok bool) { panic("not implemented") },
}
hashMatcher := &agdtest.HashMatcher{
OnMatchByPrefix: func(
_ context.Context,
_ string,
) (hashes []string, matched bool, err error) {
panic("not implemented")
},
}
queryLog := &agdtest.QueryLog{
OnWrite: func(_ context.Context, _ *querylog.Entry) (err error) {
panic("not implemented")
},
}
ruleStat := &agdtest.RuleStat{
OnCollect: func(_ context.Context, _ agd.FilterListID, _ agd.FilterRuleText) {
panic("not implemented")
},
}
srv := dnssvctest.NewServer(dnssvctest.ServerName, agd.ProtoDoT, &agd.ServerBindData{
AddrPort: dnssvctest.ServerAddrPort,
})
srvGrp := &agd.ServerGroup{
DDR: &agd.DDR{
Enabled: true,
},
TLS: &agd.TLS{},
Name: dnssvctest.ServerGroupName,
FilteringGroup: dnssvctest.FilteringGroupID,
Servers: []*agd.Server{srv},
ProfilesEnabled: true,
}
testCases := []struct {
cacheConf *dnssvc.CacheConfig
name string
}{{
cacheConf: &dnssvc.CacheConfig{
Type: dnssvc.CacheTypeNone,
},
name: "no_cache",
}, {
cacheConf: &dnssvc.CacheConfig{
MinTTL: 10 * time.Second,
NoECSCount: 100,
Type: dnssvc.CacheTypeSimple,
OverrideCacheTTL: true,
},
name: "cache_simple",
}, {
cacheConf: &dnssvc.CacheConfig{
MinTTL: 10 * time.Second,
ECSCount: 100,
NoECSCount: 100,
Type: dnssvc.CacheTypeECS,
OverrideCacheTTL: true,
},
name: "cache_ecs",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := testutil.ContextWithTimeout(t, dnssvctest.Timeout)
handlers, err := dnssvc.NewHandlers(ctx, &dnssvc.HandlersConfig{
BaseLogger: slogutil.NewDiscardLogger(),
Cloner: agdtest.NewCloner(),
Cache: tc.cacheConf,
HumanIDParser: agd.NewHumanIDParser(),
Messages: agdtest.NewConstructor(t),
PluginRegistry: nil,
StructuredErrors: agdtest.NewSDEConfig(true),
AccessManager: accessMgr,
BillStat: billStat,
// TODO(a.garipov): Create a test implementation?
CacheManager: agdcache.EmptyManager{},
DNSCheck: dnsCk,
DNSDB: dnsDB,
ErrColl: agdtest.NewErrorCollector(),
FilterStorage: fltStrg,
GeoIP: agdtest.NewGeoIP(),
Handler: dnsservertest.NewPanicHandler(),
HashMatcher: hashMatcher,
ProfileDB: agdtest.NewProfileDB(),
PrometheusRegisterer: agdtest.NewTestPrometheusRegisterer(),
QueryLog: queryLog,
RateLimit: agdtest.NewRateLimit(),
RuleStat: ruleStat,
MetricsNamespace: path.Base(t.Name()),
FilteringGroups: fltGrps,
ServerGroups: []*agd.ServerGroup{srvGrp},
EDEEnabled: true,
})
require.NoError(t, err)
assert.Len(t, handlers, 1)
for k, v := range handlers {
assert.Same(t, srv, k.Server)
assert.Same(t, srvGrp, k.ServerGroup)
assert.NotNil(t, v)
break
}
})
}
}

View File

@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/access" "github.com/AdguardTeam/AdGuardDNS/internal/access"
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdcache"
"github.com/AdguardTeam/AdGuardDNS/internal/agdpasswd" "github.com/AdguardTeam/AdGuardDNS/internal/agdpasswd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest" "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
@ -19,6 +20,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc" "github.com/AdguardTeam/AdGuardDNS/internal/dnssvc"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/dnssvctest" "github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/dnssvctest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter" "github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip" "github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog" "github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/logutil/slogutil"
@ -46,7 +48,7 @@ import (
// from the service. The channels must not be nil. Each sending to a channel // from the service. The channels must not be nil. Each sending to a channel
// wrapped with [testutil.RequireSend] using [dnssvctest.Timeout]. // wrapped with [testutil.RequireSend] using [dnssvctest.Timeout].
// //
// It also uses the [dnsservertest.DefaultHandler] to create the DNS handler. // It also uses the [dnsservertest.NewDefaultHandler] to create the DNS handler.
func newTestService( func newTestService(
t testing.TB, t testing.TB,
flt filter.Interface, flt filter.Interface,
@ -81,46 +83,14 @@ func newTestService(
QueryLogEnabled: true, QueryLogEnabled: true,
} }
db := &agdtest.ProfileDB{ profDB := agdtest.NewProfileDB()
OnCreateAutoDevice: func( profDB.OnProfileByDeviceID = func(
ctx context.Context, _ context.Context,
id agd.ProfileID, id agd.DeviceID,
humanID agd.HumanID, ) (p *agd.Profile, d *agd.Device, err error) {
devType agd.DeviceType, testutil.RequireSend(pt, profileDBCh, id, dnssvctest.Timeout)
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDedicatedIP: func( return prof, dev, nil
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDeviceID: func(
_ context.Context,
id agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
testutil.RequireSend(pt, profileDBCh, id, dnssvctest.Timeout)
return prof, dev, nil
},
OnProfileByHumanID: func(
_ context.Context,
_ agd.ProfileID,
_ agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
ctx context.Context,
ip netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
} }
accessManager := &agdtest.AccessManager{ accessManager := &agdtest.AccessManager{
@ -145,18 +115,12 @@ func newTestService(
Continent: geoip.ContinentEU, Continent: geoip.ContinentEU,
ASN: 42, ASN: 42,
} }
geoIP := &agdtest.GeoIP{
OnSubnetByLocation: func(
_ *geoip.Location,
_ netutil.AddrFamily,
) (n netip.Prefix, err error) {
panic("not implemented")
},
OnData: func(host string, _ netip.Addr) (l *geoip.Location, err error) {
testutil.RequireSend(pt, geoIPCh, host, dnssvctest.Timeout)
return loc, nil geoIP := agdtest.NewGeoIP()
}, geoIP.OnData = func(host string, _ netip.Addr) (l *geoip.Location, err error) {
testutil.RequireSend(pt, geoIPCh, host, dnssvctest.Timeout)
return loc, nil
} }
fltStrg := &agdtest.FilterStorage{ fltStrg := &agdtest.FilterStorage{
@ -205,25 +169,38 @@ func newTestService(
}, },
} }
rl := &agdtest.RateLimit{ rl := agdtest.NewRateLimit()
OnIsRateLimited: func( rl.OnIsRateLimited = func(
_ context.Context, _ context.Context,
_ *dns.Msg, _ *dns.Msg,
_ netip.Addr, _ netip.Addr,
) (drop, allowlisted bool, err error) { ) (shouldDrop, isAllowlisted bool, err error) {
return true, false, nil return true, false, nil
},
OnCountResponses: func(_ context.Context, _ *dns.Msg, _ netip.Addr) {
panic("not implemented")
},
} }
testFltGrpID := agd.FilteringGroupID("1234") srvGrps := []*agd.ServerGroup{{
DDR: &agd.DDR{
Enabled: true,
},
TLS: &agd.TLS{
DeviceDomains: []string{dnssvctest.DomainForDevices},
},
Name: dnssvctest.ServerGroupName,
FilteringGroup: dnssvctest.FilteringGroupID,
Servers: []*agd.Server{srv},
ProfilesEnabled: true,
}}
c := &dnssvc.Config{ hdlrConf := &dnssvc.HandlersConfig{
BaseLogger: slogutil.NewDiscardLogger(), BaseLogger: slogutil.NewDiscardLogger(),
AccessManager: accessManager, Cache: &dnssvc.CacheConfig{
Messages: agdtest.NewConstructor(t), Type: dnssvc.CacheTypeNone,
},
StructuredErrors: agdtest.NewSDEConfig(true),
Cloner: agdtest.NewCloner(),
HumanIDParser: agd.NewHumanIDParser(),
Messages: agdtest.NewConstructor(t),
AccessManager: accessManager,
BillStat: &agdtest.BillStatRecorder{ BillStat: &agdtest.BillStatRecorder{
OnRecord: func( OnRecord: func(
_ context.Context, _ context.Context,
@ -235,42 +212,47 @@ func newTestService(
) { ) {
}, },
}, },
ProfileDB: db, CacheManager: agdcache.EmptyManager{},
PrometheusRegisterer: agdtest.NewTestPrometheusRegisterer(),
DNSCheck: dnsCk, DNSCheck: dnsCk,
NonDNS: http.NotFoundHandler(),
DNSDB: dnsDB, DNSDB: dnsDB,
ErrColl: errColl, ErrColl: errColl,
FilterStorage: fltStrg, FilterStorage: fltStrg,
GeoIP: geoIP, GeoIP: geoIP,
Handler: dnsservertest.NewDefaultHandler(),
HashMatcher: hashprefix.NewMatcher(nil),
ProfileDB: profDB,
PrometheusRegisterer: agdtest.NewTestPrometheusRegisterer(),
QueryLog: ql, QueryLog: ql,
RuleStat: ruleStat,
NewListener: newTestListenerFunc(tl),
Handler: dnsservertest.DefaultHandler(),
RateLimit: rl, RateLimit: rl,
RuleStat: ruleStat,
MetricsNamespace: path.Base(t.Name()), MetricsNamespace: path.Base(t.Name()),
FilteringGroups: map[agd.FilteringGroupID]*agd.FilteringGroup{ FilteringGroups: map[agd.FilteringGroupID]*agd.FilteringGroup{
testFltGrpID: { dnssvctest.FilteringGroupID: {
ID: testFltGrpID, ID: dnssvctest.FilteringGroupID,
RuleListIDs: []agd.FilterListID{dnssvctest.FilterListID1}, RuleListIDs: []agd.FilterListID{dnssvctest.FilterListID1},
RuleListsEnabled: true, RuleListsEnabled: true,
}, },
}, },
ServerGroups: []*agd.ServerGroup{{ ServerGroups: srvGrps,
DDR: &agd.DDR{ EDEEnabled: true,
Enabled: true,
},
TLS: &agd.TLS{
DeviceDomains: []string{dnssvctest.DomainForDevices},
},
Name: testSrvGrpName,
FilteringGroup: testFltGrpID,
Servers: []*agd.Server{srv},
ProfilesEnabled: true,
}},
} }
svc, err := dnssvc.New(c) ctx := context.Background()
handlers, err := dnssvc.NewHandlers(ctx, hdlrConf)
require.NoError(t, err)
c := &dnssvc.Config{
Handlers: handlers,
NewListener: newTestListenerFunc(tl),
Cloner: agdtest.NewCloner(),
ErrColl: errColl,
NonDNS: http.NotFoundHandler(),
MetricsNamespace: path.Base(t.Name()),
ServerGroups: srvGrps,
HandleTimeout: dnssvctest.Timeout,
}
svc, err = dnssvc.New(c)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, svc) require.NotNil(t, svc)
@ -283,6 +265,7 @@ func newTestService(
return svc, srvAddr return svc, srvAddr
} }
// TODO(a.garipov): Refactor to test handlers separately from the service.
func TestService_Wrap(t *testing.T) { func TestService_Wrap(t *testing.T) {
profileDBCh := make(chan agd.DeviceID, 1) profileDBCh := make(chan agd.DeviceID, 1)
querylogCh := make(chan *querylog.Entry, 1) querylogCh := make(chan *querylog.Entry, 1)
@ -346,7 +329,7 @@ func TestService_Wrap(t *testing.T) {
TLSServerName: dnssvctest.DeviceIDSrvName, TLSServerName: dnssvctest.DeviceIDSrvName,
}) })
err := svc.Handle(ctx, testSrvGrpName, dnssvctest.ServerName, rw, req) err := svc.Handle(ctx, dnssvctest.ServerGroupName, dnssvctest.ServerName, rw, req)
require.NoError(t, err) require.NoError(t, err)
resp := rw.Msg() resp := rw.Msg()
@ -420,7 +403,7 @@ func TestService_Wrap(t *testing.T) {
TLSServerName: dnssvctest.DeviceIDSrvName, TLSServerName: dnssvctest.DeviceIDSrvName,
}) })
err := svc.Handle(ctx, testSrvGrpName, dnssvctest.ServerName, rw, req) err := svc.Handle(ctx, dnssvctest.ServerGroupName, dnssvctest.ServerName, rw, req)
require.NoError(t, err) require.NoError(t, err)
resp := rw.Msg() resp := rw.Msg()

View File

@ -80,35 +80,9 @@ func TestDefault_Find_plainAddrs(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
profDB := &agdtest.ProfileDB{ profDB := agdtest.NewProfileDB()
OnCreateAutoDevice: func( profDB.OnProfileByDedicatedIP = newOnProfileByDedicatedIP(dnssvctest.DedicatedAddr)
ctx context.Context, profDB.OnProfileByLinkedIP = newOnProfileByLinkedIP(dnssvctest.LinkedAddr)
id agd.ProfileID,
humanID agd.HumanID,
devType agd.DeviceType,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDedicatedIP: newOnProfileByDedicatedIP(dnssvctest.DedicatedAddr),
OnProfileByDeviceID: func(
_ context.Context,
_ agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByHumanID: func(
_ context.Context,
_ agd.ProfileID,
_ agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: newOnProfileByLinkedIP(dnssvctest.LinkedAddr),
}
df := devicefinder.NewDefault(&devicefinder.Config{ df := devicefinder.NewDefault(&devicefinder.Config{
Logger: slogutil.NewDiscardLogger(), Logger: slogutil.NewDiscardLogger(),
@ -177,40 +151,8 @@ func TestDefault_Find_plainEDNS(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
profDB := &agdtest.ProfileDB{ profDB := agdtest.NewProfileDB()
OnCreateAutoDevice: func( profDB.OnProfileByDeviceID = newOnProfileByDeviceID(dnssvctest.DeviceID)
ctx context.Context,
id agd.ProfileID,
humanID agd.HumanID,
devType agd.DeviceType,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDeviceID: newOnProfileByDeviceID(dnssvctest.DeviceID),
OnProfileByHumanID: func(
_ context.Context,
_ agd.ProfileID,
_ agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
df := devicefinder.NewDefault(&devicefinder.Config{ df := devicefinder.NewDefault(&devicefinder.Config{
Logger: slogutil.NewDiscardLogger(), Logger: slogutil.NewDiscardLogger(),
@ -231,44 +173,12 @@ func TestDefault_Find_plainEDNS(t *testing.T) {
func TestDefault_Find_deleted(t *testing.T) { func TestDefault_Find_deleted(t *testing.T) {
t.Parallel() t.Parallel()
profDB := &agdtest.ProfileDB{ profDB := agdtest.NewProfileDB()
OnCreateAutoDevice: func( profDB.OnProfileByLinkedIP = func(
ctx context.Context, _ context.Context,
id agd.ProfileID, _ netip.Addr,
humanID agd.HumanID, ) (p *agd.Profile, d *agd.Device, err error) {
devType agd.DeviceType, return profDeleted, devNormal, nil
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDeviceID: func(
_ context.Context,
_ agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByHumanID: func(
_ context.Context,
_ agd.ProfileID,
_ agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
return profDeleted, devNormal, nil
},
} }
df := devicefinder.NewDefault(&devicefinder.Config{ df := devicefinder.NewDefault(&devicefinder.Config{
@ -291,44 +201,21 @@ func TestDefault_Find_byHumanID(t *testing.T) {
// device-type and profile data regardless of the case. // device-type and profile data regardless of the case.
extIDStr := "OTR-" + strings.ToUpper(dnssvctest.ProfileIDStr) + "-" + dnssvctest.HumanIDStr + "-!!!" extIDStr := "OTR-" + strings.ToUpper(dnssvctest.ProfileIDStr) + "-" + dnssvctest.HumanIDStr + "-!!!"
profDB := &agdtest.ProfileDB{ profDB := agdtest.NewProfileDB()
OnCreateAutoDevice: func( profDB.OnCreateAutoDevice = func(
ctx context.Context, ctx context.Context,
id agd.ProfileID, id agd.ProfileID,
humanID agd.HumanID, humanID agd.HumanID,
devType agd.DeviceType, devType agd.DeviceType,
) (p *agd.Profile, d *agd.Device, err error) { ) (p *agd.Profile, d *agd.Device, err error) {
return profNormal, devAuto, nil return profNormal, devAuto, nil
}, }
profDB.OnProfileByHumanID = func(
OnProfileByDedicatedIP: func( _ context.Context,
_ context.Context, _ agd.ProfileID,
_ netip.Addr, _ agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) { ) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented") return nil, nil, profiledb.ErrDeviceNotFound
},
OnProfileByDeviceID: func(
_ context.Context,
devID agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByHumanID: func(
_ context.Context,
_ agd.ProfileID,
_ agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) {
return nil, nil, profiledb.ErrDeviceNotFound
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
} }
df := devicefinder.NewDefault(&devicefinder.Config{ df := devicefinder.NewDefault(&devicefinder.Config{

View File

@ -2,7 +2,6 @@ package devicefinder_test
import ( import (
"context" "context"
"net/netip"
"net/url" "net/url"
"path" "path"
"testing" "testing"
@ -88,48 +87,16 @@ func TestDefault_Find_DoHAuth(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
profDB := &agdtest.ProfileDB{ profDB := agdtest.NewProfileDB()
OnCreateAutoDevice: func( profDB.OnProfileByDeviceID = func(
ctx context.Context, _ context.Context,
id agd.ProfileID, devID agd.DeviceID,
humanID agd.HumanID, ) (p *agd.Profile, d *agd.Device, err error) {
devType agd.DeviceType, if tc.profDBDev != nil {
) (p *agd.Profile, d *agd.Device, err error) { return profNormal, tc.profDBDev, nil
panic("not implemented") }
},
OnProfileByDedicatedIP: func( return nil, nil, profiledb.ErrDeviceNotFound
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDeviceID: func(
_ context.Context,
devID agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
if tc.profDBDev != nil {
return profNormal, tc.profDBDev, nil
}
return nil, nil, profiledb.ErrDeviceNotFound
},
OnProfileByHumanID: func(
_ context.Context,
_ agd.ProfileID,
_ agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
} }
df := devicefinder.NewDefault(&devicefinder.Config{ df := devicefinder.NewDefault(&devicefinder.Config{
@ -220,44 +187,12 @@ func TestDefault_Find_DoHAuthOnly(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
profDB := &agdtest.ProfileDB{ profDB := agdtest.NewProfileDB()
OnCreateAutoDevice: func( profDB.OnProfileByDeviceID = func(
ctx context.Context, _ context.Context,
id agd.ProfileID, devID agd.DeviceID,
humanID agd.HumanID, ) (p *agd.Profile, d *agd.Device, err error) {
devType agd.DeviceType, return profNormal, tc.profDBDev, nil
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDeviceID: func(
_ context.Context,
devID agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
return profNormal, tc.profDBDev, nil
},
OnProfileByHumanID: func(
_ context.Context,
_ agd.ProfileID,
_ agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
} }
df := devicefinder.NewDefault(&devicefinder.Config{ df := devicefinder.NewDefault(&devicefinder.Config{
@ -351,34 +286,9 @@ func TestDefault_Find_DoH(t *testing.T) {
name: "human_id_path_match", name: "human_id_path_match",
}} }}
profDB := &agdtest.ProfileDB{ profDB := agdtest.NewProfileDB()
OnCreateAutoDevice: func( profDB.OnProfileByDeviceID = newOnProfileByDeviceID(dnssvctest.DeviceID)
ctx context.Context, profDB.OnProfileByHumanID = newOnProfileByHumanID(dnssvctest.ProfileID, dnssvctest.HumanIDLower)
id agd.ProfileID,
humanID agd.HumanID,
devType agd.DeviceType,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDeviceID: newOnProfileByDeviceID(dnssvctest.DeviceID),
OnProfileByHumanID: newOnProfileByHumanID(dnssvctest.ProfileID, dnssvctest.HumanIDLower),
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@ -450,34 +360,9 @@ func TestDefault_Find_stdEncrypted(t *testing.T) {
deviceDomains: []string{dnssvctest.DomainForDevices}, deviceDomains: []string{dnssvctest.DomainForDevices},
}} }}
profDB := &agdtest.ProfileDB{ profDB := agdtest.NewProfileDB()
OnCreateAutoDevice: func( profDB.OnProfileByDeviceID = newOnProfileByDeviceID(dnssvctest.DeviceID)
ctx context.Context, profDB.OnProfileByHumanID = newOnProfileByHumanID(dnssvctest.ProfileID, dnssvctest.HumanIDLower)
id agd.ProfileID,
humanID agd.HumanID,
devType agd.DeviceType,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDeviceID: newOnProfileByDeviceID(dnssvctest.DeviceID),
OnProfileByHumanID: newOnProfileByHumanID(dnssvctest.ProfileID, dnssvctest.HumanIDLower),
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
srvData := []struct { srvData := []struct {
srv *agd.Server srv *agd.Server

View File

@ -1,8 +1,6 @@
package devicefinder_test package devicefinder_test
import ( import (
"context"
"net/netip"
"testing" "testing"
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
@ -49,53 +47,13 @@ func TestDefault_Find_humanID(t *testing.T) {
in: "otr-abcd1234-!!!", in: "otr-abcd1234-!!!",
}} }}
profDB := &agdtest.ProfileDB{
OnCreateAutoDevice: func(
ctx context.Context,
id agd.ProfileID,
humanID agd.HumanID,
devType agd.DeviceType,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByDeviceID: func(
_ context.Context,
devID agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByHumanID: func(
_ context.Context,
_ agd.ProfileID,
_ agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
OnProfileByLinkedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
df := devicefinder.NewDefault(&devicefinder.Config{ df := devicefinder.NewDefault(&devicefinder.Config{
Logger: slogutil.NewDiscardLogger(), Logger: slogutil.NewDiscardLogger(),
ProfileDB: profDB, ProfileDB: agdtest.NewProfileDB(),
HumanIDParser: agd.NewHumanIDParser(), HumanIDParser: agd.NewHumanIDParser(),
Server: srvDoT, Server: srvDoT,
DeviceDomains: []string{dnssvctest.DomainForDevices}, DeviceDomains: []string{dnssvctest.DomainForDevices},

View File

@ -55,8 +55,16 @@ const (
DomainRewrittenCNAMEFQDN = DomainRewrittenCNAME + "." DomainRewrittenCNAMEFQDN = DomainRewrittenCNAME + "."
) )
// ServerName is the common server name for tests. const (
const ServerName agd.ServerName = "test_server_dns_tls" // FilteringGroupID is the common filtering-group ID for tests.
FilteringGroupID agd.FilteringGroupID = "test_filtering_group"
// ServerName is the common server name for tests.
ServerName agd.ServerName = "test_server_dns_tls"
// ServerGroupName is the common server-group name for tests.
ServerGroupName agd.ServerGroupName = "test_server_group"
)
const ( const (
// DomainForDevices is the upper-level domain name for requests with device // DomainForDevices is the upper-level domain name for requests with device

View File

@ -29,8 +29,8 @@ const (
// Resolvers for querying the resolver with unknown or absent name. // Resolvers for querying the resolver with unknown or absent name.
DDRDomain = DDRLabel + "." + ResolverARPADomain DDRDomain = DDRLabel + "." + ResolverARPADomain
// FirefoxCanaryHost is the hostname that Firefox uses to check if it // FirefoxCanaryHost is the hostname that Firefox uses to check if it should
// should use its own DNS-over-HTTPS settings. // use its own DNS-over-HTTPS settings.
// //
// See https://support.mozilla.org/en-US/kb/configuring-networks-disable-dns-over-https. // See https://support.mozilla.org/en-US/kb/configuring-networks-disable-dns-over-https.
FirefoxCanaryHost = "use-application-dns.net" FirefoxCanaryHost = "use-application-dns.net"
@ -56,6 +56,11 @@ func (mw *Middleware) reqInfoSpecialHandler(
return nil, "" return nil, ""
} }
// As per RFC-9462 section 6.4, resolvers SHOULD respond to queries of any
// type other than SVCB for _dns.resolver.arpa. with NODATA and queries of
// any type for any domain name under resolver.arpa with NODATA.
//
// TODO(e.burkov): Consider adding SOA records for these NODATA responses.
if mw.isDDRRequest(ri) { if mw.isDDRRequest(ri) {
if _, ok := ri.DeviceResult.(*agd.DeviceResultAuthenticationFailure); ok { if _, ok := ri.DeviceResult.(*agd.DeviceResultAuthenticationFailure); ok {
return mw.handleDDRNoData, "ddr_doh" return mw.handleDDRNoData, "ddr_doh"
@ -84,8 +89,6 @@ type reqInfoHandlerFunc func(
ri *agd.RequestInfo, ri *agd.RequestInfo,
) (err error) ) (err error)
// DDR And Resolver ARPA Domain
// isDDRRequest determines if the message is the request for Discovery of // isDDRRequest determines if the message is the request for Discovery of
// Designated Resolvers as defined by the RFC draft. The request is considered // Designated Resolvers as defined by the RFC draft. The request is considered
// ARPA if the requested host is a subdomain of resolver.arpa SUDN. // ARPA if the requested host is a subdomain of resolver.arpa SUDN.
@ -141,8 +144,8 @@ func isDDRDomain(ri *agd.RequestInfo, host string) (ok bool) {
return false return false
} }
// handleDDR checks if the request is for the Discovery of Designated Resolvers // handleDDR responds to Discovery of Designated Resolvers (DDR) queries with a
// and writes a response if needed. // response containing the designated resolvers.
func (mw *Middleware) handleDDR( func (mw *Middleware) handleDDR(
ctx context.Context, ctx context.Context,
rw dnsserver.ResponseWriter, rw dnsserver.ResponseWriter,
@ -157,11 +160,11 @@ func (mw *Middleware) handleDDR(
return rw.WriteMsg(ctx, req, mw.newRespDDR(req, ri)) return rw.WriteMsg(ctx, req, mw.newRespDDR(req, ri))
} }
return rw.WriteMsg(ctx, req, ri.Messages.NewMsgNXDOMAIN(req)) return rw.WriteMsg(ctx, req, ri.Messages.NewRespRCode(req, dns.RcodeNameError))
} }
// handleDDRNoData processes DDR (Discovery of Designated Resolvers) requests // handleDDRNoData responds to Discovery of Designated Resolvers (DDR) queries
// for devices which need NODATA response and writes the response if needed. // with a NODATA response.
func (mw *Middleware) handleDDRNoData( func (mw *Middleware) handleDDRNoData(
ctx context.Context, ctx context.Context,
rw dnsserver.ResponseWriter, rw dnsserver.ResponseWriter,
@ -173,17 +176,17 @@ func (mw *Middleware) handleDDRNoData(
metrics.DNSSvcDDRRequestsTotal.Inc() metrics.DNSSvcDDRRequestsTotal.Inc()
if ri.ServerGroup.DDR.Enabled { if ri.ServerGroup.DDR.Enabled {
return rw.WriteMsg(ctx, req, ri.Messages.NewMsgNODATA(req)) return rw.WriteMsg(ctx, req, ri.Messages.NewRespRCode(req, dns.RcodeSuccess))
} }
return rw.WriteMsg(ctx, req, ri.Messages.NewMsgNXDOMAIN(req)) return rw.WriteMsg(ctx, req, ri.Messages.NewRespRCode(req, dns.RcodeNameError))
} }
// newRespDDR returns a new Discovery of Designated Resolvers response copying // newRespDDR returns a new Discovery of Designated Resolvers response copying
// it from the prebuilt templates in srvGrp and modifying it in accordance with // it from the prebuilt templates in srvGrp and modifying it in accordance with
// the request data. req must not be nil. // the request data. req must not be nil.
func (mw *Middleware) newRespDDR(req *dns.Msg, ri *agd.RequestInfo) (resp *dns.Msg) { func (mw *Middleware) newRespDDR(req *dns.Msg, ri *agd.RequestInfo) (resp *dns.Msg) {
resp = ri.Messages.NewRespMsg(req) resp = ri.Messages.NewResp(req)
name := req.Question[0].Name name := req.Question[0].Name
ddr := ri.ServerGroup.DDR ddr := ri.ServerGroup.DDR
@ -210,7 +213,8 @@ func (mw *Middleware) newRespDDR(req *dns.Msg, ri *agd.RequestInfo) (resp *dns.M
return resp return resp
} }
// handleBadResolverARPA writes a NODATA response. // handleBadResolverARPA responds to badly formed resolver.arpa queries with a
// NODATA response.
func (mw *Middleware) handleBadResolverARPA( func (mw *Middleware) handleBadResolverARPA(
ctx context.Context, ctx context.Context,
rw dnsserver.ResponseWriter, rw dnsserver.ResponseWriter,
@ -219,7 +223,8 @@ func (mw *Middleware) handleBadResolverARPA(
) (err error) { ) (err error) {
metrics.DNSSvcBadResolverARPA.Inc() metrics.DNSSvcBadResolverARPA.Inc()
err = rw.WriteMsg(ctx, req, ri.Messages.NewRespMsg(req)) resp := ri.Messages.NewRespRCode(req, dns.RcodeSuccess)
err = rw.WriteMsg(ctx, req, resp)
return errors.Annotate(err, "writing nodata resp for %q: %w", ri.Host) return errors.Annotate(err, "writing nodata resp for %q: %w", ri.Host)
} }
@ -257,8 +262,6 @@ func (mw *Middleware) specialDomainHandler(
return nil, "" return nil, ""
} }
// Apple Private Relay
// shouldBlockPrivateRelay returns true if the query is for an Apple Private // shouldBlockPrivateRelay returns true if the query is for an Apple Private
// Relay check domain and the request information or profile indicates that // Relay check domain and the request information or profile indicates that
// Apple Private Relay should be blocked. // Apple Private Relay should be blocked.
@ -280,13 +283,12 @@ func (mw *Middleware) handlePrivateRelay(
) (err error) { ) (err error) {
metrics.DNSSvcApplePrivateRelayRequestsTotal.Inc() metrics.DNSSvcApplePrivateRelayRequestsTotal.Inc()
err = rw.WriteMsg(ctx, req, ri.Messages.NewMsgNXDOMAIN(req)) resp := ri.Messages.NewRespRCode(req, dns.RcodeNameError)
err = rw.WriteMsg(ctx, req, resp)
return errors.Annotate(err, "writing private relay resp: %w") return errors.Annotate(err, "writing private relay resp: %w")
} }
// Firefox canary domain
// shouldBlockFirefoxCanary returns true if the query is for a Firefox canary // shouldBlockFirefoxCanary returns true if the query is for a Firefox canary
// domain and the request information or profile indicates that Firefox canary // domain and the request information or profile indicates that Firefox canary
// domain should be blocked. // domain should be blocked.
@ -298,8 +300,8 @@ func shouldBlockFirefoxCanary(ri *agd.RequestInfo, prof *agd.Profile) (ok bool)
return ri.FilteringGroup.BlockFirefoxCanary return ri.FilteringGroup.BlockFirefoxCanary
} }
// handleFirefoxCanary checks if the request is for the fully-qualified domain // handleFirefoxCanary responds to Firefox canary domain queries with a REFUSED
// name that Firefox uses to check DoH settings and writes a response if needed. // response.
func (mw *Middleware) handleFirefoxCanary( func (mw *Middleware) handleFirefoxCanary(
ctx context.Context, ctx context.Context,
rw dnsserver.ResponseWriter, rw dnsserver.ResponseWriter,
@ -308,7 +310,7 @@ func (mw *Middleware) handleFirefoxCanary(
) (err error) { ) (err error) {
metrics.DNSSvcFirefoxRequestsTotal.Inc() metrics.DNSSvcFirefoxRequestsTotal.Inc()
resp := ri.Messages.NewMsgREFUSED(req) resp := ri.Messages.NewRespRCode(req, dns.RcodeRefused)
err = rw.WriteMsg(ctx, req, resp) err = rw.WriteMsg(ctx, req, resp)
return errors.Annotate(err, "writing firefox canary resp: %w") return errors.Annotate(err, "writing firefox canary resp: %w")

View File

@ -44,7 +44,9 @@ func TestMiddleware_writeDebugResponse(t *testing.T) {
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{ msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: cloner, Cloner: cloner,
BlockingMode: &dnsmsg.BlockingModeNullIP{}, BlockingMode: &dnsmsg.BlockingModeNullIP{},
StructuredErrors: agdtest.NewSDEConfig(true),
FilteredResponseTTL: agdtest.FilteredResponseTTL, FilteredResponseTTL: agdtest.FilteredResponseTTL,
EDEEnabled: true,
}) })
require.NoError(t, err) require.NoError(t, err)

View File

@ -162,7 +162,7 @@ func (mw *Middleware) setFilteredResponse(
mw.setFilteredResponseNoReq(ctx, fctx, ri) mw.setFilteredResponseNoReq(ctx, fctx, ri)
case *filter.ResultBlocked: case *filter.ResultBlocked:
var err error var err error
fctx.filteredResponse, err = ri.Messages.NewBlockedRespMsg(fctx.originalRequest) fctx.filteredResponse, err = ri.Messages.NewBlockedResp(fctx.originalRequest)
if err != nil { if err != nil {
mw.reportf(ctx, "creating blocked resp for filtered req: %w", err) mw.reportf(ctx, "creating blocked resp for filtered req: %w", err)
fctx.filteredResponse = fctx.originalResponse fctx.filteredResponse = fctx.originalResponse
@ -199,7 +199,7 @@ func (mw *Middleware) setFilteredResponseNoReq(
fctx.filteredResponse = fctx.originalResponse fctx.filteredResponse = fctx.originalResponse
case *filter.ResultBlocked: case *filter.ResultBlocked:
var err error var err error
fctx.filteredResponse, err = ri.Messages.NewBlockedRespMsg(fctx.originalRequest) fctx.filteredResponse, err = ri.Messages.NewBlockedResp(fctx.originalRequest)
if err != nil { if err != nil {
mw.reportf(ctx, "creating blocked resp for filtered resp: %w", err) mw.reportf(ctx, "creating blocked resp for filtered resp: %w", err)
fctx.filteredResponse = fctx.originalResponse fctx.filteredResponse = fctx.originalResponse

View File

@ -4,6 +4,7 @@ package mainmw
import ( import (
"context" "context"
"log/slog"
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet" "github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
@ -14,7 +15,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll" "github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
"github.com/AdguardTeam/AdGuardDNS/internal/filter" "github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip" "github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog" "github.com/AdguardTeam/AdGuardDNS/internal/optslog"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog" "github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/AdGuardDNS/internal/rulestat" "github.com/AdguardTeam/AdGuardDNS/internal/rulestat"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
@ -24,14 +25,15 @@ import (
// Middleware is the main middleware of AdGuard DNS. // Middleware is the main middleware of AdGuard DNS.
type Middleware struct { type Middleware struct {
messages *dnsmsg.Constructor
cloner *dnsmsg.Cloner cloner *dnsmsg.Cloner
fltCtxPool *syncutil.Pool[filteringContext] fltCtxPool *syncutil.Pool[filteringContext]
metrics Metrics logger *slog.Logger
messages *dnsmsg.Constructor
billStat billstat.Recorder billStat billstat.Recorder
errColl errcoll.Interface errColl errcoll.Interface
fltStrg filter.Storage fltStrg filter.Storage
geoIP geoip.Interface geoIP geoip.Interface
metrics Metrics
queryLog querylog.Interface queryLog querylog.Interface
ruleStat rulestat.Interface ruleStat rulestat.Interface
} }
@ -39,19 +41,17 @@ type Middleware struct {
// Config is the configuration structure for the main middleware. All fields // Config is the configuration structure for the main middleware. All fields
// must be non-nil. // must be non-nil.
type Config struct { type Config struct {
// Metrics is used to collect the statistics. // Cloner is used to clone messages more efficiently by disposing of parts
Metrics Metrics // of DNS responses for later reuse.
Cloner *dnsmsg.Cloner
// Logger is used to log the operation of the middleware.
Logger *slog.Logger
// Messages is the message constructor used to create blocked and other // Messages is the message constructor used to create blocked and other
// messages for this middleware. // messages for this middleware.
Messages *dnsmsg.Constructor Messages *dnsmsg.Constructor
// Cloner is used to clone messages more efficiently by disposing of parts
// of DNS responses for later reuse.
//
// TODO(a.garipov): Use.
Cloner *dnsmsg.Cloner
// BillStat is used to collect billing statistics. // BillStat is used to collect billing statistics.
BillStat billstat.Recorder BillStat billstat.Recorder
@ -66,6 +66,9 @@ type Config struct {
// addresses in requests and responses. // addresses in requests and responses.
GeoIP geoip.Interface GeoIP geoip.Interface
// Metrics is used to collect the statistics.
Metrics Metrics
// QueryLog is used to write the logs into. // QueryLog is used to write the logs into.
QueryLog querylog.Interface QueryLog querylog.Interface
@ -77,16 +80,17 @@ type Config struct {
// New returns a new main middleware. c must not be nil. // New returns a new main middleware. c must not be nil.
func New(c *Config) (mw *Middleware) { func New(c *Config) (mw *Middleware) {
return &Middleware{ return &Middleware{
metrics: c.Metrics, cloner: c.Cloner,
messages: c.Messages,
cloner: c.Cloner,
fltCtxPool: syncutil.NewPool(func() (v *filteringContext) { fltCtxPool: syncutil.NewPool(func() (v *filteringContext) {
return &filteringContext{} return &filteringContext{}
}), }),
logger: c.Logger,
messages: c.Messages,
billStat: c.BillStat, billStat: c.BillStat,
errColl: c.ErrColl, errColl: c.ErrColl,
fltStrg: c.FilterStorage, fltStrg: c.FilterStorage,
geoIP: c.GeoIP, geoIP: c.GeoIP,
metrics: c.Metrics,
queryLog: c.QueryLog, queryLog: c.QueryLog,
ruleStat: c.RuleStat, ruleStat: c.RuleStat,
} }
@ -106,8 +110,20 @@ func (mw *Middleware) Wrap(next dnsserver.Handler) (wrapped dnsserver.Handler) {
defer mw.fltCtxPool.Put(fctx) defer mw.fltCtxPool.Put(fctx)
ri := agd.MustRequestInfoFromContext(ctx) ri := agd.MustRequestInfoFromContext(ctx)
optlog.Debug2("processing request %q from %s", ri.ID, ri.RemoteIP) optslog.Debug2(
defer optlog.Debug2("finished processing request %q from %s", ri.ID, ri.RemoteIP) ctx,
mw.logger,
"processing request",
"req_id", ri.ID,
"remote_ip", ri.RemoteIP,
)
defer optslog.Debug2(
ctx,
mw.logger,
"finished processing request",
"req_id", ri.ID,
"remote_ip", ri.RemoteIP,
)
flt := mw.fltStrg.FilterFromContext(ctx, ri) flt := mw.fltStrg.FilterFromContext(ctx, ri)
mw.filterRequest(ctx, fctx, flt, ri) mw.filterRequest(ctx, fctx, flt, ri)
@ -184,10 +200,12 @@ func (mw *Middleware) nextParams(
ctx = agd.ContextWithRequestInfo(ctx, modReqInfo) ctx = agd.ContextWithRequestInfo(ctx, modReqInfo)
optlog.Debug2( optslog.Debug2(
"mainmw: request for %q rewritten to %q by CNAME rewrite rule", ctx,
ri.Host, mw.logger,
modReqInfo.Host, "request rewritten by cname rewrite rule",
"orig_host", ri.Host,
"mod_host", modReqInfo.Host,
) )
return ctx, origRW, modReq return ctx, origRW, modReq

View File

@ -17,17 +17,13 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/filter" "github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip" "github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog" "github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// Common constants for tests. // Common constants for tests.
const ( const (
testASN geoip.ASN = 12345 testASN geoip.ASN = 12345
@ -121,24 +117,17 @@ func TestMiddleware_Wrap(t *testing.T) {
OnHasListID: func(_ agd.FilterListID) (ok bool) { panic("not implemented") }, OnHasListID: func(_ agd.FilterListID) (ok bool) { panic("not implemented") },
} }
geoIP := &agdtest.GeoIP{ geoIP := agdtest.NewGeoIP()
OnSubnetByLocation: func( geoIP.OnData = func(host string, addr netip.Addr) (l *geoip.Location, err error) {
_ *geoip.Location, pt := testutil.PanicT{}
_ netutil.AddrFamily, require.Equal(pt, dnssvctest.Domain, host)
) (n netip.Prefix, err error) { if addr.Is4() {
panic("not implemented") require.Equal(pt, addr, testRespAddr4)
}, } else if addr.Is6() {
OnData: func(host string, addr netip.Addr) (l *geoip.Location, err error) { require.Equal(pt, addr, testRespAddr6)
pt := testutil.PanicT{} }
require.Equal(pt, dnssvctest.Domain, host)
if addr.Is4() {
require.Equal(pt, addr, testRespAddr4)
} else if addr.Is6() {
require.Equal(pt, addr, testRespAddr6)
}
return nil, nil return nil, nil
},
} }
ruleStat := &agdtest.RuleStat{ ruleStat := &agdtest.RuleStat{
@ -153,7 +142,9 @@ func TestMiddleware_Wrap(t *testing.T) {
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{ msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: cloner, Cloner: cloner,
BlockingMode: &dnsmsg.BlockingModeNullIP{}, BlockingMode: &dnsmsg.BlockingModeNullIP{},
StructuredErrors: agdtest.NewSDEConfig(true),
FilteredResponseTTL: agdtest.FilteredResponseTTL, FilteredResponseTTL: agdtest.FilteredResponseTTL,
EDEEnabled: true,
}) })
require.NoError(t, err) require.NoError(t, err)
@ -227,13 +218,14 @@ func TestMiddleware_Wrap(t *testing.T) {
} }
c := &mainmw.Config{ c := &mainmw.Config{
Metrics: mainmw.EmptyMetrics{},
Messages: msgs,
Cloner: cloner, Cloner: cloner,
Logger: slogutil.NewDiscardLogger(),
Messages: msgs,
BillStat: tc.billStat, BillStat: tc.billStat,
ErrColl: agdtest.NewErrorCollector(), ErrColl: agdtest.NewErrorCollector(),
FilterStorage: fltStrg, FilterStorage: fltStrg,
GeoIP: geoIP, GeoIP: geoIP,
Metrics: mainmw.EmptyMetrics{},
QueryLog: queryLog, QueryLog: queryLog,
RuleStat: ruleStat, RuleStat: ruleStat,
} }
@ -411,16 +403,9 @@ func TestMiddleware_Wrap_filtering(t *testing.T) {
} }
) )
geoIP := &agdtest.GeoIP{ geoIP := agdtest.NewGeoIP()
OnSubnetByLocation: func( geoIP.OnData = func(_ string, _ netip.Addr) (l *geoip.Location, err error) {
_ *geoip.Location, return nil, nil
_ netutil.AddrFamily,
) (n netip.Prefix, err error) {
panic("not implemented")
},
OnData: func(host string, addr netip.Addr) (l *geoip.Location, err error) {
return nil, nil
},
} }
var ( var (
@ -537,7 +522,9 @@ func TestMiddleware_Wrap_filtering(t *testing.T) {
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{ msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
Cloner: cloner, Cloner: cloner,
BlockingMode: &dnsmsg.BlockingModeNullIP{}, BlockingMode: &dnsmsg.BlockingModeNullIP{},
StructuredErrors: agdtest.NewSDEConfig(true),
FilteredResponseTTL: agdtest.FilteredResponseTTL, FilteredResponseTTL: agdtest.FilteredResponseTTL,
EDEEnabled: true,
}) })
require.NoError(t, err) require.NoError(t, err)
@ -701,13 +688,14 @@ func TestMiddleware_Wrap_filtering(t *testing.T) {
} }
c := &mainmw.Config{ c := &mainmw.Config{
Metrics: mainmw.EmptyMetrics{},
Messages: msgs,
Cloner: cloner, Cloner: cloner,
Logger: slogutil.NewDiscardLogger(),
Messages: msgs,
BillStat: tc.billStat, BillStat: tc.billStat,
ErrColl: agdtest.NewErrorCollector(), ErrColl: agdtest.NewErrorCollector(),
FilterStorage: fltStrg, FilterStorage: fltStrg,
GeoIP: geoIP, GeoIP: geoIP,
Metrics: mainmw.EmptyMetrics{},
QueryLog: queryLog, QueryLog: queryLog,
RuleStat: ruleStat, RuleStat: ruleStat,
} }

View File

@ -12,7 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip" "github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog" "github.com/AdguardTeam/AdGuardDNS/internal/optslog"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog" "github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -112,8 +112,7 @@ func (mw *Middleware) responseCountry(
} }
ctry = mw.country(ctx, host, respIP) ctry = mw.country(ctx, host, respIP)
// TODO(a.garipov): Use optslog.Trace2. optslog.Trace2(ctx, mw.logger, "geoip for resp", "ctry", ctry, "resp_ip", respIP)
optlog.Debug2("mainmw: got ctry %q for resp ip %v", ctry, respIP)
return ctry return ctry
} }

View File

@ -16,7 +16,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/geoip" "github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog" "github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/AdGuardDNS/internal/rulestat" "github.com/AdguardTeam/AdGuardDNS/internal/rulestat"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -134,20 +134,13 @@ func TestMiddleware_recordQueryInfo_respCtry(t *testing.T) {
Country: testCtry, Country: testCtry,
} }
geoIP := &agdtest.GeoIP{ geoIP := agdtest.NewGeoIP()
OnSubnetByLocation: func( geoIP.OnData = func(_ string, _ netip.Addr) (l *geoip.Location, err error) {
_ *geoip.Location, if !tc.wantGeoIP {
_ netutil.AddrFamily, t.Error("unexpected call to geoip")
) (n netip.Prefix, err error) { }
panic("not implemented")
},
OnData: func(_ string, _ netip.Addr) (l *geoip.Location, err error) {
if !tc.wantGeoIP {
t.Error("unexpected call to geoip")
}
return loc, nil return loc, nil
},
} }
queryLogCalled := false queryLogCalled := false
@ -164,6 +157,7 @@ func TestMiddleware_recordQueryInfo_respCtry(t *testing.T) {
} }
mw := &Middleware{ mw := &Middleware{
logger: slogutil.NewDiscardLogger(),
billStat: billstat.EmptyRecorder{}, billStat: billstat.EmptyRecorder{},
geoIP: geoIP, geoIP: geoIP,
queryLog: queryLog, queryLog: queryLog,

View File

@ -5,15 +5,16 @@ package preservice
import ( import (
"context" "context"
"fmt" "fmt"
"log/slog"
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck" "github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/filter" "github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog" "github.com/AdguardTeam/AdGuardDNS/internal/optslog"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -22,19 +23,18 @@ import (
// names that may be filtered by safe browsing or parental control filters as // names that may be filtered by safe browsing or parental control filters as
// well as handling of the DNS-server check queries. // well as handling of the DNS-server check queries.
type Middleware struct { type Middleware struct {
// messages is used to construct TXT responses. logger *slog.Logger
messages *dnsmsg.Constructor messages *dnsmsg.Constructor
// hashMatcher is the safe browsing DNS hashMatcher.
hashMatcher filter.HashMatcher hashMatcher filter.HashMatcher
checker dnscheck.Interface
// checker is used to detect and process DNS-check requests.
checker dnscheck.Interface
} }
// Config is the configurational structure for the preservice middleware. All // Config is the configurational structure for the preservice middleware. All
// fields must be non-nil. // fields must be non-nil.
type Config struct { type Config struct {
// Logger is used to log the operation of the middleware.
Logger *slog.Logger
// Messages is used to construct TXT responses. // Messages is used to construct TXT responses.
Messages *dnsmsg.Constructor Messages *dnsmsg.Constructor
@ -48,6 +48,7 @@ type Config struct {
// New returns a new preservice middleware. c must not be nil. // New returns a new preservice middleware. c must not be nil.
func New(c *Config) (mw *Middleware) { func New(c *Config) (mw *Middleware) {
return &Middleware{ return &Middleware{
logger: c.Logger,
messages: c.Messages, messages: c.Messages,
hashMatcher: c.HashMatcher, hashMatcher: c.HashMatcher,
checker: c.Checker, checker: c.Checker,
@ -60,7 +61,7 @@ var _ dnsserver.Middleware = (*Middleware)(nil)
// Wrap implements the [dnsserver.Middleware] interface for *Middleware. // Wrap implements the [dnsserver.Middleware] interface for *Middleware.
func (mw *Middleware) Wrap(next dnsserver.Handler) (wrapped dnsserver.Handler) { func (mw *Middleware) Wrap(next dnsserver.Handler) (wrapped dnsserver.Handler) {
f := func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) (err error) { f := func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) (err error) {
defer func() { err = errors.Annotate(err, "preservice mw: %w") }() defer func() { err = errors.Annotate(err, "presvcmw: %w") }()
ri := agd.MustRequestInfoFromContext(ctx) ri := agd.MustRequestInfoFromContext(ctx)
if ri.QType == dns.TypeTXT { if ri.QType == dns.TypeTXT {
@ -71,13 +72,20 @@ func (mw *Middleware) Wrap(next dnsserver.Handler) (wrapped dnsserver.Handler) {
resp, err := mw.checker.Check(ctx, req, ri) resp, err := mw.checker.Check(ctx, req, ri)
if err != nil { if err != nil {
return fmt.Errorf("calling dnscheck: %w", err) return fmt.Errorf("calling dnscheck: %w", err)
} else if resp != nil {
return errors.Annotate(rw.WriteMsg(ctx, req, resp), "writing dnscheck response: %w")
} }
// Don't wrap the error, because this is the main flow, and there is if resp == nil {
// already [errors.Annotate] here. // Don't wrap the error, because this is the main flow, and there is
return next.ServeDNS(ctx, rw, req) // already [errors.Annotate] here.
return next.ServeDNS(ctx, rw, req)
}
err = rw.WriteMsg(ctx, req, resp)
if err != nil {
return fmt.Errorf("writing dnscheck response: %w", err)
}
return nil
} }
return dnsserver.HandlerFunc(f) return dnsserver.HandlerFunc(f)
@ -92,15 +100,19 @@ func (mw *Middleware) respondWithHashes(
req *dns.Msg, req *dns.Msg,
ri *agd.RequestInfo, ri *agd.RequestInfo,
) (err error) { ) (err error) {
optlog.Debug1("preservice mw: safe browsing: got txt req for %q", ri.Host) optslog.Debug1(ctx, mw.logger, "got txt req", "host", ri.Host)
hashes, matched, err := mw.hashMatcher.MatchByPrefix(ctx, ri.Host) hashes, matched, err := mw.hashMatcher.MatchByPrefix(ctx, ri.Host)
if err != nil { if err != nil {
// Don't return or collect this error to prevent DDoS of the error // Don't return or collect this error to prevent DDoS of the error
// collector by sending bad requests. // collector by sending bad requests.
log.Error("preservice mw: safe browsing: matching hashes: %s", err) mw.logger.ErrorContext(
ctx,
"matching hashes",
slogutil.KeyError, err,
)
resp := mw.messages.NewMsgREFUSED(req) resp := mw.messages.NewRespRCode(req, dns.RcodeRefused)
err = rw.WriteMsg(ctx, req, resp) err = rw.WriteMsg(ctx, req, resp)
return errors.Annotate(err, "writing refused response: %w") return errors.Annotate(err, "writing refused response: %w")
@ -110,7 +122,7 @@ func (mw *Middleware) respondWithHashes(
return next.ServeDNS(ctx, rw, req) return next.ServeDNS(ctx, rw, req)
} }
resp, err := mw.messages.NewTXTRespMsg(req, hashes...) resp, err := mw.messages.NewRespTXT(req, hashes...)
if err != nil { if err != nil {
// Technically should never happen since the only error that could arise // Technically should never happen since the only error that could arise
// in [dnsmsg.Constructor.NewTXTRespMsg] is the one about request type // in [dnsmsg.Constructor.NewTXTRespMsg] is the one about request type
@ -118,7 +130,7 @@ func (mw *Middleware) respondWithHashes(
return fmt.Errorf("creating safe browsing result: %w", err) return fmt.Errorf("creating safe browsing result: %w", err)
} }
optlog.Debug1("preservice mw: safe browsing: writing hashes %q", hashes) optslog.Debug1(ctx, mw.logger, "writing hashes", "hashes", hashes)
err = rw.WriteMsg(ctx, req, resp) err = rw.WriteMsg(ctx, req, resp)
if err != nil { if err != nil {

Some files were not shown because too many files have changed in this diff Show More