mirror of
https://github.com/AdguardTeam/AdGuardDNS.git
synced 2025-02-20 11:23:36 +08:00
Sync v2.10.0
This commit is contained in:
parent
da0cb6fd0e
commit
87137bddcf
88
CHANGELOG.md
88
CHANGELOG.md
@ -7,6 +7,94 @@ The format is **not** based on [Keep a Changelog][kec], since the project **does
|
||||
[kec]: https://keepachangelog.com/en/1.0.0/
|
||||
[sem]: https://semver.org/spec/v2.0.0.html
|
||||
|
||||
## AGDNS-2484/ Build 886
|
||||
|
||||
- Property `type` of the `ratelimit` object has been moved to the underlying `allowlist` object. So replace this:
|
||||
|
||||
```yaml
|
||||
ratelimit:
|
||||
type: 'consul'
|
||||
# …
|
||||
allowlist:
|
||||
# …
|
||||
```
|
||||
|
||||
with this:
|
||||
|
||||
```yaml
|
||||
ratelimit:
|
||||
# …
|
||||
allowlist:
|
||||
type: 'consul'
|
||||
# …
|
||||
```
|
||||
|
||||
## AGDNS-2443 / Build 877
|
||||
|
||||
- The object `filters` has new properties: `ede_enabled`, and `sde_enabled`. So replace this:
|
||||
|
||||
```yaml
|
||||
filters:
|
||||
# …
|
||||
```
|
||||
|
||||
with this:
|
||||
|
||||
```yaml
|
||||
filters:
|
||||
# …
|
||||
ede_enabled: true
|
||||
sde_enabled: true
|
||||
```
|
||||
|
||||
## AGDNS-2456 / Build 873
|
||||
|
||||
- The environment variables `BACKEND_RATELIMIT_URL` and `BACKEND_RATELIMIT_API_KEY` have been added.
|
||||
|
||||
- Added the `type` property within the `ratelimit` object. So add it:
|
||||
|
||||
```yaml
|
||||
ratelimit:
|
||||
type: 'consul'
|
||||
# …
|
||||
```
|
||||
|
||||
## AGDNS-2431 / Build 872
|
||||
|
||||
- The objects `ratelimit.ipv4` and `ratelimit.ipv6` have been modified. Its `rps` properties have been replaced with the new properties `count` and `interval`. So replace this:
|
||||
|
||||
```yaml
|
||||
ratelimit:
|
||||
# …
|
||||
ipv4:
|
||||
rps: 30
|
||||
ipv6:
|
||||
rps: 300
|
||||
```
|
||||
|
||||
with this:
|
||||
|
||||
```yaml
|
||||
ratelimit:
|
||||
# …
|
||||
ipv4:
|
||||
# …
|
||||
count: 300
|
||||
interval: 10s
|
||||
ipv6:
|
||||
# …
|
||||
count: 3000
|
||||
interval: 10s
|
||||
```
|
||||
|
||||
Adjust the value and add new ones, if necessary.
|
||||
|
||||
## AGDNS-2457 / Build 871
|
||||
|
||||
- The environment variables `DNSCHECK_REMOTEKV_URL` and `DNSCHECK_REMOTEKV_API_KEY` have been added.
|
||||
|
||||
- The property `kv.type` within the `check` object now supports the `backend` value.
|
||||
|
||||
## AGDNS-2468 / Build 869
|
||||
|
||||
- The environment variable `PROFILES_MAX_RESP_SIZE` has been added. It sets the maximum size of the response from the profiles endpoint of the backend API. The default value is `8MB`.
|
||||
|
2
Makefile
2
Makefile
@ -24,7 +24,7 @@ BRANCH = $${BRANCH:-$$(git rev-parse --abbrev-ref HEAD)}
|
||||
GOAMD64 = v1
|
||||
GOPROXY = https://proxy.golang.org|direct
|
||||
GOTELEMETRY = off
|
||||
GOTOOLCHAIN = go1.23.1
|
||||
GOTOOLCHAIN = go1.23.2
|
||||
RACE = 0
|
||||
REVISION = $${REVISION:-$$(git rev-parse --short HEAD)}
|
||||
VERSION = 0
|
||||
|
@ -10,15 +10,19 @@ ratelimit:
|
||||
response_size_estimate: 1KB
|
||||
# Rate limit options for IPv4 addresses.
|
||||
ipv4:
|
||||
# Rate of requests per second for one subnet for IPv4 addresses.
|
||||
rps: 30
|
||||
# Requests per configured interval for one subnet for IPv4 addresses.
|
||||
count: 300
|
||||
# The time during which to count the number of requests.
|
||||
interval: 10s
|
||||
# The lengths of the subnet prefixes used to calculate rate limiter
|
||||
# bucket keys for IPv4 addresses.
|
||||
subnet_key_len: 24
|
||||
# Rate limit options for IPv6 addresses.
|
||||
ipv6:
|
||||
# Rate of requests per second for one subnet for IPv6 addresses.
|
||||
rps: 300
|
||||
# Requests per configured interval for one subnet for IPv6 addresses.
|
||||
count: 3000
|
||||
# The time during which to count the number of requests.
|
||||
interval: 10s
|
||||
# The lengths of the subnet prefixes used to calculate rate limiter
|
||||
# bucket keys for IPv6 addresses.
|
||||
subnet_key_len: 48
|
||||
@ -40,6 +44,9 @@ ratelimit:
|
||||
- '127.0.0.1/24'
|
||||
# Time between two updates of allow list.
|
||||
refresh_interval: 1h
|
||||
# Defines where the rate limiting settings are received from. Allowed
|
||||
# values are "backend" and "consul".
|
||||
type: 'consul'
|
||||
|
||||
# Configuration for the stream connection limiting.
|
||||
connection_limit:
|
||||
@ -167,8 +174,8 @@ geoip:
|
||||
check:
|
||||
# Domains to use for DNS checking.
|
||||
kv:
|
||||
# Defines the type of remote kay-value storage. Allowed values are
|
||||
# "consul" and "redis".
|
||||
# Defines the type of remote key-value storage. Allowed values are
|
||||
# "backend", "consul", and "redis".
|
||||
type: 'consul'
|
||||
# For how long to keep the information about the client.
|
||||
ttl: 30s
|
||||
@ -313,6 +320,10 @@ filters:
|
||||
enabled: true
|
||||
# The size of the LRU cache of rule-list filtering results.
|
||||
size: 10000
|
||||
# Enable the Extended DNS Errors feature.
|
||||
ede_enabled: true
|
||||
# Enable the Structured DNS Errors feature. Requires ede_enabled: true.
|
||||
sde_enabled: true
|
||||
|
||||
# Filtering groups are a set of different filtering configurations. These
|
||||
# filtering configurations are then used by server_groups.
|
||||
|
@ -104,9 +104,13 @@ The `ratelimit` object has the following properties:
|
||||
|
||||
- <a href="#ratelimit-ipv4" id="ratelimit-ipv4" name="ratelimit-ipv4">`ipv4`</a>: The ipv4 configuration object. It has the following fields:
|
||||
|
||||
- <a href="#ratelimit-ipv4-rps" id="ratelimit-ipv4-rps" name="ratelimit-ipv4-rps">`rps`</a>: The rate of requests per second for one subnet. Requests above this are counted in the backoff count.
|
||||
- <a href="#ratelimit-ipv4-count" id="ratelimit-ipv4-count" name="ratelimit-ipv4-count">`count`</a>: Requests per configured interval for one subnet for IPv4 addresses. Requests above this are counted in the backoff count.
|
||||
|
||||
**Example:** `30`.
|
||||
**Example:** `300`.
|
||||
|
||||
- <a href="#ratelimit-ipv4-interval" id="ratelimit-ipv4-interval" name="ratelimit-ipv4-interval">`interval`</a>: The time during which to count the number of requests.
|
||||
|
||||
**Example:** `10s`.
|
||||
|
||||
- <a href="#ratelimit-ipv4-subnet_key_len" id="ratelimit-ipv4-subnet_key_len" name="ratelimit-ipv4-subnet_key_len">`ipv4-subnet_key_len`</a>: The length of the subnet prefix used to calculate rate limiter bucket keys.
|
||||
|
||||
@ -134,7 +138,11 @@ The `ratelimit` object has the following properties:
|
||||
|
||||
**Example:** `30s`.
|
||||
|
||||
For example, if `backoff_period` is `1m`, `backoff_count` is `10`, and `ipv4-rps` is `5`, a client (meaning all IP addresses within the subnet defined by `ipv4-subnet_key_len`) that made 15 requests in one second or 6 requests (one above `rps`) every second for 10 seconds within one minute, the client is blocked for `backoff_duration`.
|
||||
- <a href="#ratelimit-allowlist-type" id="ratelimit-allowlist-type" name="ratelimit-allowlist-type">`type`</a>: Defines where the rate limit settings are received from. Allowed values are `backend` and `consul`.
|
||||
|
||||
**Example:** `consul`.
|
||||
|
||||
For example, if `backoff_period` is `1m`, `backoff_count` is `10`, `ipv4-count` is `5`, and `ipv4-interval` is `1s`, a client (meaning all IP addresses within the subnet defined by `ipv4-subnet_key_len`) that made 15 requests in one second or 6 requests (one above `rps`) every second for 10 seconds within one minute, the client is blocked for `backoff_duration`.
|
||||
|
||||
### <a href="#ratelimit-connection_limit" id="ratelimit-connection_limit" name="ratelimit-connection_limit">Stream connection limit</a>
|
||||
|
||||
@ -370,12 +378,14 @@ The `check` object has the following properties:
|
||||
|
||||
- <a href="#check_kv" id="check_kv" name="check_kv">`kv`</a>: Remote key-value storage settings. It has the following properties:
|
||||
|
||||
- <a href="#check-kv-type" id="check-kv-type" name="check-kv-type">`type`</a>: Type of the remote KV storage. Allowed values are `consul` and `redis`.
|
||||
- <a href="#check-kv-type" id="check-kv-type" name="check-kv-type">`type`</a>: Type of the remote KV storage. Allowed values are `backend`, `consul`, and `redis`.
|
||||
|
||||
**Example:** `consul`.
|
||||
|
||||
- <a href="#check-kv-ttl" id="check-kv-ttl" name="check-kv-ttl">`ttl`</a>: For how long to keep the information about a single user in remote KV, as a human-readable duration.
|
||||
|
||||
For `backend`, the TTL must be greater than `0s`.
|
||||
|
||||
For `consul`, the TTL must be between `10s` and `1d`. Note that the actual TTL can be up to twice as long.
|
||||
|
||||
For `redis`, the TTL must be greater than or equal to `1ms`.
|
||||
@ -592,6 +602,14 @@ The `filters` object has the following properties:
|
||||
|
||||
**Example:** `10000`.
|
||||
|
||||
- <a href="#filters-ede_enabled" id="filters-ede_enabled" name="filters-ede_enabled">`ede_enabled`</a>: Shows if Extended DNS Error codes should be added.
|
||||
|
||||
**Example:** `true`.
|
||||
|
||||
- <a href="#filters-sde_enabled" id="filters-sde_enabled" name="filters-sde_enabled">`sde_enabled`</a>: Shows if the experimental Structured DNS Errors feature should be enabled. `ede_enabled` must be `true` to enable SDE.
|
||||
|
||||
**Example:** `true`.
|
||||
|
||||
[env-blocked_services]: environment.md#BLOCKED_SERVICE_INDEX_URL
|
||||
|
||||
## <a href="#filtering_groups" id="filtering_groups" name="filtering_groups">Filtering groups</a>
|
||||
|
@ -159,7 +159,7 @@ You'll need to supply the following:
|
||||
|
||||
See the [external HTTP API documentation][externalhttp].
|
||||
|
||||
You may use `go run ./scripts/backend` to start mock GRPC server for `BILLSTAT_URL` and `PROFILES_URL` endpoints.
|
||||
You may use `go run ./scripts/backend` to start mock GRPC server for `BACKEND_PROFILES_URL`, `BILLSTAT_URL`, `DNSCHECK_REMOTEKV_URL`, and `PROFILES_URL` endpoints.
|
||||
|
||||
You may need to change the listen ports in `config.yaml` which are less than 1024 to some other ports. Otherwise, `sudo` or `doas` is required to run `AdGuardDNS`.
|
||||
|
||||
|
@ -6,6 +6,8 @@ AdGuard DNS uses [environment variables][wiki-env] to store some of the more sen
|
||||
|
||||
- [`ADULT_BLOCKING_ENABLED`](#ADULT_BLOCKING_ENABLED)
|
||||
- [`ADULT_BLOCKING_URL`](#ADULT_BLOCKING_URL)
|
||||
- [`BACKEND_RATELIMIT_API_KEY`](#BACKEND_RATELIMIT_API_KEY)
|
||||
- [`BACKEND_RATELIMIT_URL`](#BACKEND_RATELIMIT_URL)
|
||||
- [`BILLSTAT_API_KEY`](#BILLSTAT_API_KEY)
|
||||
- [`BILLSTAT_URL`](#BILLSTAT_URL)
|
||||
- [`BLOCKED_SERVICE_ENABLED`](#BLOCKED_SERVICE_ENABLED)
|
||||
@ -14,6 +16,8 @@ AdGuard DNS uses [environment variables][wiki-env] to store some of the more sen
|
||||
- [`CONSUL_ALLOWLIST_URL`](#CONSUL_ALLOWLIST_URL)
|
||||
- [`CONSUL_DNSCHECK_KV_URL`](#CONSUL_DNSCHECK_KV_URL)
|
||||
- [`CONSUL_DNSCHECK_SESSION_URL`](#CONSUL_DNSCHECK_SESSION_URL)
|
||||
- [`DNSCHECK_REMOTEKV_API_KEY`](#DNSCHECK_REMOTEKV_API_KEY)
|
||||
- [`DNSCHECK_REMOTEKV_URL`](#DNSCHECK_REMOTEKV_URL)
|
||||
- [`FILTER_CACHE_PATH`](#FILTER_CACHE_PATH)
|
||||
- [`FILTER_INDEX_URL`](#FILTER_INDEX_URL)
|
||||
- [`GENERAL_SAFE_ENABLED`](#GENERAL_SAFE_SEARCH_ENABLED)
|
||||
@ -62,6 +66,21 @@ The HTTP(S) URL of source list of rules for adult blocking filter.
|
||||
|
||||
**Default:** No default value, the variable is required if `ADULT_BLOCKING_ENABLED` is set to `1`.
|
||||
|
||||
## <a href="#BACKEND_RATELIMIT_API_KEY" id="BACKEND_RATELIMIT_API_KEY" name="BACKEND_RATELIMIT_API_KEY">`BACKEND_RATELIMIT_API_KEY`</a>
|
||||
|
||||
The API key to use when authenticating requests to the backend rate limiter API, if any. The API key should be valid as defined by [RFC 6750].
|
||||
|
||||
**Default:** **Unset.**
|
||||
|
||||
## <a href="#BACKEND_RATELIMIT_URL" id="BACKEND_RATELIMIT_URL" name="BACKEND_RATELIMIT_URL">`BACKEND_RATELIMIT_URL`</a>
|
||||
|
||||
The base backend URL for backend rate limiter. Supports gRPC(S) (`grpc://` and `grpcs://`) URLs. See the [external API requirements section][ext-backend-ratelimit].
|
||||
|
||||
**Default:** No default value, the variable is required if the [type][conf-ratelimit-type] of rate limiter is `backend` in the configuration file.
|
||||
|
||||
[conf-ratelimit-type]: configuration.md#ratelimit-type
|
||||
[ext-backend-ratelimit]: externalhttp.md#backend-ratelimit
|
||||
|
||||
## <a href="#BILLSTAT_API_KEY" id="BILLSTAT_API_KEY" name="BILLSTAT_API_KEY">`BILLSTAT_API_KEY`</a>
|
||||
|
||||
The API key to use when authenticating queries to the billing statistics API, if any. The API key should be valid as defined by [RFC 6750].
|
||||
@ -72,7 +91,7 @@ The API key to use when authenticating queries to the billing statistics API, if
|
||||
|
||||
## <a href="#BILLSTAT_URL" id="BILLSTAT_URL" name="BILLSTAT_URL">`BILLSTAT_URL`</a>
|
||||
|
||||
The base backend URL for backend billing statistics uploader API. Supports gRPC(S) (`grpc://` and`grpcs://`) URLs. See the [external HTTP API requirements section][ext-billstat].
|
||||
The base backend URL for backend billing statistics uploader API. Supports gRPC(S) (`grpc://` and `grpcs://`) URLs. See the [external HTTP API requirements section][ext-billstat].
|
||||
|
||||
**Default:** No default value, the variable is required if there is at least one [server group][conf-sg] with profiles enabled.
|
||||
|
||||
@ -103,7 +122,7 @@ The path to the configuration file.
|
||||
|
||||
The HTTP(S) URL of the Consul instance serving the dynamic part of the rate-limit allowlist. See the [external HTTP API requirements section][ext-consul] on the expected format of the response.
|
||||
|
||||
**Default:** No default value, the variable is **required.**
|
||||
**Default:** No default value, the variable is required if the [type][conf-ratelimit-type] of rate limiter is `consul` in the configuration file.
|
||||
|
||||
[ext-consul]: externalhttp.md#consul
|
||||
|
||||
@ -123,6 +142,20 @@ The HTTP(S) URL of the session API of the Consul instance used as a key-value da
|
||||
|
||||
**Example:** `http://localhost:8500/v1/session/create`
|
||||
|
||||
## <a href="#DNSCHECK_REMOTEKV_API_KEY" id="DNSCHECK_REMOTEKV_API_KEY" name="DNSCHECK_REMOTEKV_API_KEY">`DNSCHECK_REMOTEKV_API_KEY`</a>
|
||||
|
||||
The API key to use when authenticating queries to the backend key-value database API, if any. The API key should be valid as defined by [RFC 6750].
|
||||
|
||||
**Default:** **Unset.**
|
||||
|
||||
## <a href="#DNSCHECK_REMOTEKV_URL" id="DNSCHECK_REMOTEKV_URL" name="DNSCHECK_REMOTEKV_URL">`DNSCHECK_REMOTEKV_URL`</a>
|
||||
|
||||
The base backend URL used as a key-value database for the DNS server checking. Supports gRPC(S) (`grpc://` and`grpcs://`) URLs. See the [external API requirements section][ext-backend-dnscheck].
|
||||
|
||||
**Default:** **Unset.**
|
||||
|
||||
[ext-backend-dnscheck]: externalhttp.md#backend-dnscheck
|
||||
|
||||
## <a href="#FILTER_CACHE_PATH" id="FILTER_CACHE_PATH" name="FILTER_CACHE_PATH">`FILTER_CACHE_PATH`</a>
|
||||
|
||||
The path to the directory used to store the cached version of all filters and filter indexes.
|
||||
@ -238,11 +271,11 @@ The profile cache is read on start and is later updated on every [full refresh][
|
||||
|
||||
The maximum size of the response from the profiles API in a human-readable format.
|
||||
|
||||
**Default:** `8MB`.
|
||||
**Default:** `64MB`.
|
||||
|
||||
## <a href="#PROFILES_URL" id="PROFILES_URL" name="PROFILES_URL">`PROFILES_URL`</a>
|
||||
|
||||
The base backend URL for profiles API. Supports gRPC(S) (`grpc://` and`grpcs://`) URLs. See the [external API requirements section][ext-profiles].
|
||||
The base backend URL for profiles API. Supports gRPC(S) (`grpc://` and `grpcs://`) URLs. See the [external API requirements section][ext-profiles].
|
||||
|
||||
**Default:** No default value, the variable is required if there is at least one [server group][conf-sg] with profiles enabled.
|
||||
|
||||
@ -252,7 +285,7 @@ The base backend URL for profiles API. Supports gRPC(S) (`grpc://` and`grpcs://`
|
||||
|
||||
Redis server address. Can be an IP address or a hostname.
|
||||
|
||||
**Default:** No default value, the variable if required if the [type][conf-check-kv-type] of remote KV storage for DNS server checking is `redis` in the configuration file.
|
||||
**Default:** No default value, the variable is required if the [type][conf-check-kv-type] of remote KV storage for DNS server checking is `redis` in the configuration file.
|
||||
|
||||
[conf-check-kv-type]: configuration.md#check-kv-type
|
||||
|
||||
|
@ -10,7 +10,9 @@ AdGuard DNS uses information from external HTTP APIs for filtering and other pie
|
||||
## Contents
|
||||
|
||||
- [Backend billing statistics](#backend-billstat)
|
||||
- [Backend DNSCheck service](#backend-dnscheck)
|
||||
- [Backend profiles service](#backend-profiles)
|
||||
- [Backend ratelimit service](#backend-ratelimit)
|
||||
- [Consul key-value storage](#consul)
|
||||
- [Filtering](#filters)
|
||||
- [Blocked services](#filters-blocked-services)
|
||||
@ -28,6 +30,15 @@ This service is disabled when all server groups have property [`profiles_enabled
|
||||
[env-billstat_url]: environment.md#BILLSTAT_URL
|
||||
[conf-srvgrp-prof]: configuration.md#sg-*-profiles_enabled
|
||||
|
||||
## <a href="#backend-dnscheck" id="backend-dnscheck" name="backend-dnscheck">Backend DNSCheck service</a>
|
||||
|
||||
This is the service to which the [`DNSCHECK_REMOTEKV_URL`][env-dnscheck_remotekv_url] environment variable points. Supports gRPC(s) URLs. The service must correspond to `./internal/backendpb/dns.proto`.
|
||||
|
||||
This service is only enabled when the `check.kv` object has the [`type`][conf-check-kv-type] property set to `backend`.
|
||||
|
||||
[env-dnscheck_remotekv_url]: environment.md#DNSCHECK_REMOTEKV_URL
|
||||
[conf-check-kv-type]: configuration.md#check-kv-type
|
||||
|
||||
## <a href="#backend-profiles" id="backend-profiles" name="backend-profiles">Backend profiles service</a>
|
||||
|
||||
This is the service to which the [`PROFILES_URL`][env-profiles_url] environment variable points. Supports gRPC(s) URLs. The service must correspond to `./internal/backendpb/dns.proto`.
|
||||
@ -36,6 +47,15 @@ This service is disabled when all server groups have property [`profiles_enabled
|
||||
|
||||
[env-profiles_url]: environment.md#PROFILES_URL
|
||||
|
||||
## <a href="#backend-ratelimit" id="backend-ratelimit" name="backend-ratelimit">Backend ratelimit_service</a>
|
||||
|
||||
This is the service to which the [`BACKEND_RATELIMIT_URL`][env-backend_ratelimit_url] environment variable points. Supports gRPC(s) URLs. The service must correspond to `./internal/backendpb/dns.proto`.
|
||||
|
||||
This service is only enabled when the `ratelimit` object has the [`type`][conf-ratelimit-type] property set to `backend`.
|
||||
|
||||
[conf-ratelimit-type]: configuration.md#ratelimit-type
|
||||
[env-backend_ratelimit_url]: environment.md#BACKEND_RATELIMIT_URL
|
||||
|
||||
## <a href="#consul" id="consul" name="consul">Consul key-value storage</a>
|
||||
|
||||
A [Consul][consul-io] service can be used for the DNS server check and dynamic rate-limit allowlist features. Currently used endpoints can be seen in the documentation of the [`CONSUL_ALLOWLIST_URL`][env-consul-allowlist], [`CONSUL_DNSCHECK_KV_URL`][env-consul-dnscheck-kv], and [`CONSUL_DNSCHECK_SESSION_URL`][env-consul-dnscheck-session] environment variables.
|
||||
|
43
go.mod
43
go.mod
@ -1,34 +1,34 @@
|
||||
module github.com/AdguardTeam/AdGuardDNS
|
||||
|
||||
go 1.23.1
|
||||
go 1.23.2
|
||||
|
||||
require (
|
||||
github.com/AdguardTeam/AdGuardDNS/internal/dnsserver v0.0.0-20240607112746-5690301129fe
|
||||
github.com/AdguardTeam/golibs v0.28.0
|
||||
github.com/AdguardTeam/urlfilter v0.19.0
|
||||
github.com/AdguardTeam/golibs v0.30.1
|
||||
github.com/AdguardTeam/urlfilter v0.20.0
|
||||
github.com/ameshkov/dnscrypt/v2 v2.3.0
|
||||
github.com/axiomhq/hyperloglog v0.2.0
|
||||
github.com/bluele/gcache v0.0.2
|
||||
github.com/c2h5oh/datasize v0.0.0-20231215233829-aa82cc1e6500
|
||||
github.com/caarlos0/env/v7 v7.1.0
|
||||
github.com/getsentry/sentry-go v0.28.1
|
||||
github.com/getsentry/sentry-go v0.29.1
|
||||
github.com/gomodule/redigo v1.9.2
|
||||
github.com/google/renameio/v2 v2.0.0
|
||||
github.com/miekg/dns v1.1.62
|
||||
github.com/oschwald/maxminddb-golang v1.13.1
|
||||
github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible
|
||||
github.com/prometheus/client_golang v1.20.1
|
||||
github.com/prometheus/client_golang v1.20.5
|
||||
github.com/prometheus/client_model v0.6.1
|
||||
github.com/prometheus/common v0.55.0
|
||||
github.com/quic-go/quic-go v0.47.0
|
||||
github.com/prometheus/common v0.60.0
|
||||
github.com/quic-go/quic-go v0.48.1
|
||||
github.com/stretchr/testify v1.9.0
|
||||
golang.org/x/crypto v0.27.0
|
||||
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0
|
||||
golang.org/x/net v0.29.0
|
||||
golang.org/x/sys v0.25.0
|
||||
golang.org/x/time v0.6.0
|
||||
google.golang.org/grpc v1.65.0
|
||||
google.golang.org/protobuf v1.34.2
|
||||
golang.org/x/crypto v0.28.0
|
||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
|
||||
golang.org/x/net v0.30.0
|
||||
golang.org/x/sys v0.26.0
|
||||
golang.org/x/time v0.7.0
|
||||
google.golang.org/grpc v1.67.1
|
||||
google.golang.org/protobuf v1.35.1
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
)
|
||||
|
||||
@ -41,24 +41,21 @@ require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140 // indirect
|
||||
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
|
||||
github.com/google/pprof v0.0.0-20240929191954-255acd752d31 // indirect
|
||||
github.com/klauspost/compress v1.17.9 // indirect
|
||||
github.com/google/pprof v0.0.0-20241023014458-598669927662 // indirect
|
||||
github.com/klauspost/compress v1.17.11 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.20.2 // indirect
|
||||
github.com/panjf2000/ants/v2 v2.10.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/quic-go/qpack v0.5.1 // indirect
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
go.uber.org/mock v0.5.0 // indirect
|
||||
golang.org/x/mod v0.21.0 // indirect
|
||||
golang.org/x/sync v0.8.0 // indirect
|
||||
golang.org/x/text v0.18.0 // indirect
|
||||
golang.org/x/tools v0.25.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240822170219-fc7c04adadcd // indirect
|
||||
golang.org/x/text v0.19.0 // indirect
|
||||
golang.org/x/tools v0.26.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
replace github.com/AdguardTeam/AdGuardDNS/internal/dnsserver => ./internal/dnsserver
|
||||
|
||||
// TODO(a.garipov): Remove once https://github.com/quic-go/quic-go/pull/4685 is merged.
|
||||
replace github.com/quic-go/quic-go => github.com/ainar-g/quic-go v0.0.0-20240930125330-446bd86056fd
|
||||
|
76
go.sum
76
go.sum
@ -1,13 +1,11 @@
|
||||
github.com/AdguardTeam/golibs v0.28.0 h1:SK1q8SqkkJ/61pp2abTmio90S4QpteYK9rtgROfnrb4=
|
||||
github.com/AdguardTeam/golibs v0.28.0/go.mod h1:iWdjXPCwmK2g2FKIb/OwEPnovSXeMqRhI8FWLxF5oxE=
|
||||
github.com/AdguardTeam/urlfilter v0.19.0 h1:q7eH13+yNETlpD/VD3u5rLQOripcUdEktqZFy+KiQLk=
|
||||
github.com/AdguardTeam/urlfilter v0.19.0/go.mod h1:+N54ZvxqXYLnXuvpaUhK2exDQW+djZBRSb6F6j0rkBY=
|
||||
github.com/AdguardTeam/golibs v0.30.1 h1:/yv7dq2h7WXw/jTDxkE3FP9zHerRT+i03PZRHJX4fPU=
|
||||
github.com/AdguardTeam/golibs v0.30.1/go.mod h1:FkwcNQEJoGsgDGXcalrVa/4gWbE68KsmE2guXWtBQUE=
|
||||
github.com/AdguardTeam/urlfilter v0.20.0 h1:X32qiuVCVd8WDYCEsbdZKfXMzwdVqrdulamtUi4rmzs=
|
||||
github.com/AdguardTeam/urlfilter v0.20.0/go.mod h1:gjrywLTxfJh6JOkwi9SU+frhP7kVVEZ5exFGkR99qpk=
|
||||
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
|
||||
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
|
||||
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw=
|
||||
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635/go.mod h1:lmLxL+FV291OopO93Bwf9fQLQeLyt33VJRUg5VJ30us=
|
||||
github.com/ainar-g/quic-go v0.0.0-20240930125330-446bd86056fd h1:mw4LqrCiv3vcKuCxBRg7kA17xfHKM+9hZgFWmyhe/AY=
|
||||
github.com/ainar-g/quic-go v0.0.0-20240930125330-446bd86056fd/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
|
||||
github.com/ameshkov/dnscrypt/v2 v2.3.0 h1:pDXDF7eFa6Lw+04C0hoMh8kCAQM8NwUdFEllSP2zNLs=
|
||||
github.com/ameshkov/dnscrypt/v2 v2.3.0/go.mod h1:N5hDwgx2cNb4Ay7AhvOSKst+eUiOZ/vbKRO9qMpQttE=
|
||||
github.com/ameshkov/dnsstamps v1.0.3 h1:Srzik+J9mivH1alRACTbys2xOxs0lRH9qnTA7Y1OYVo=
|
||||
@ -29,8 +27,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140 h1:y7y0Oa6UawqTFPCDw9JG6pdKt4F9pAhHv0B7FMGaGD0=
|
||||
github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw=
|
||||
github.com/getsentry/sentry-go v0.28.1 h1:zzaSm/vHmGllRM6Tpx1492r0YDzauArdBfkJRtY6P5k=
|
||||
github.com/getsentry/sentry-go v0.28.1/go.mod h1:1fQZ+7l7eeJ3wYi82q5Hg8GqAPgefRq+FP/QhafYVgg=
|
||||
github.com/getsentry/sentry-go v0.29.1 h1:DyZuChN8Hz3ARxGVV8ePaNXh1dQ7d76AiB117xcREwA=
|
||||
github.com/getsentry/sentry-go v0.29.1/go.mod h1:x3AtIzN01d6SiWkderzaH28Tm0lgkafpJ5Bm3li39O0=
|
||||
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
|
||||
github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
|
||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||
@ -43,12 +41,12 @@ github.com/gomodule/redigo v1.9.2 h1:HrutZBLhSIU8abiSfW8pj8mPhOyMYjZT/wcA4/L9L9s
|
||||
github.com/gomodule/redigo v1.9.2/go.mod h1:KsU3hiK/Ay8U42qpaJk+kuNa3C+spxapWpM+ywhcgtw=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/pprof v0.0.0-20240929191954-255acd752d31 h1:LcRdQWywSgfi5jPsYZ1r2avbbs5IQ5wtyhMBCcokyo4=
|
||||
github.com/google/pprof v0.0.0-20240929191954-255acd752d31/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
|
||||
github.com/google/pprof v0.0.0-20241023014458-598669927662 h1:SKMkD83p7FwUqKmBsPdLHF5dNyxq3jOWwu9w9UyH5vA=
|
||||
github.com/google/pprof v0.0.0-20241023014458-598669927662/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
|
||||
github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg=
|
||||
github.com/google/renameio/v2 v2.0.0/go.mod h1:BtmJXm5YlszgC+TD4HOEEUFgkJP3nLxehU6hfe7jRt4=
|
||||
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
|
||||
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
|
||||
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
@ -79,16 +77,18 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/prometheus/client_golang v1.20.1 h1:IMJXHOD6eARkQpxo8KkhgEVFlBNm+nkrFUyGlIu7Na8=
|
||||
github.com/prometheus/client_golang v1.20.1/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
|
||||
github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y=
|
||||
github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
|
||||
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
|
||||
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
|
||||
github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc=
|
||||
github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8=
|
||||
github.com/prometheus/common v0.60.0 h1:+V9PAREWNvJMAuJ1x1BaWl9dewMW4YrHZQbx0sJNllA=
|
||||
github.com/prometheus/common v0.60.0/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
|
||||
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
|
||||
github.com/quic-go/quic-go v0.48.1 h1:y/8xmfWI9qmGTc+lBr4jKRUWLGSlSigv847ULJ4hYXA=
|
||||
github.com/quic-go/quic-go v0.48.1/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI=
|
||||
@ -109,33 +109,33 @@ github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+F
|
||||
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
|
||||
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
|
||||
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk=
|
||||
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY=
|
||||
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
|
||||
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
|
||||
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
|
||||
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
|
||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
|
||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
|
||||
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
|
||||
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
||||
golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
|
||||
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
|
||||
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
|
||||
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
|
||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
|
||||
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
|
||||
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224=
|
||||
golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
|
||||
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
|
||||
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE=
|
||||
golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240822170219-fc7c04adadcd h1:6TEm2ZxXoQmFWFlt1vNxvVOa1Q0dXFQD1m/rYjXmS0E=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240822170219-fc7c04adadcd/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU=
|
||||
google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc=
|
||||
google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ=
|
||||
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
|
||||
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
|
||||
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
||||
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
|
||||
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
|
||||
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
|
||||
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ=
|
||||
golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 h1:zciRKQ4kBpFgpfC5QQCVtnnNAcLIqweL7plyZRQHVpI=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI=
|
||||
google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E=
|
||||
google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA=
|
||||
google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
|
||||
google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
|
14
go.work.sum
14
go.work.sum
@ -1,5 +1,6 @@
|
||||
cel.dev/expr v0.15.0 h1:O1jzfJCQBfL5BFoYktaxwIhuttaQPsVWerH9/EEKx0w=
|
||||
cel.dev/expr v0.15.0/go.mod h1:TRSuuV7DlVCE/uwv5QbAiW/v8l5O8C4eEPHeu7gf7Sg=
|
||||
cel.dev/expr v0.16.0/go.mod h1:TRSuuV7DlVCE/uwv5QbAiW/v8l5O8C4eEPHeu7gf7Sg=
|
||||
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
@ -45,6 +46,7 @@ cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGB
|
||||
cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA=
|
||||
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
|
||||
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||
cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY=
|
||||
cloud.google.com/go/contactcenterinsights v1.13.0/go.mod h1:ieq5d5EtHsu8vhe2y3amtZ+BE+AQwX5qAy7cpo0POsI=
|
||||
cloud.google.com/go/container v1.31.0/go.mod h1:7yABn5s3Iv3lmw7oMmyGbeV6tQj86njcTijkkGuvdZA=
|
||||
cloud.google.com/go/containeranalysis v0.11.4/go.mod h1:cVZT7rXYBS9NG1rhQbWL9pWbXCKHWJPYraE8/FTSYPE=
|
||||
@ -156,6 +158,9 @@ github.com/AdguardTeam/golibs v0.19.0/go.mod h1:3WunclLLfrVAq7fYQRhd6f168FHOEMss
|
||||
github.com/AdguardTeam/golibs v0.25.2/go.mod h1:HaTyS2wCbxFudjht9N/+/Qf1b5cMad2BAYSwe7DPCXI=
|
||||
github.com/AdguardTeam/golibs v0.25.3 h1:A06JZGSuAhAC0uq/s7IlNsv/V8TyNJfLalB0vhkd1vA=
|
||||
github.com/AdguardTeam/golibs v0.25.3/go.mod h1:HaTyS2wCbxFudjht9N/+/Qf1b5cMad2BAYSwe7DPCXI=
|
||||
github.com/AdguardTeam/golibs v0.30.0/go.mod h1:vjw1OVZG6BYyoqGRY88U4LCJLOMfhBFhU0UJBdaSAuQ=
|
||||
github.com/AdguardTeam/golibs v0.30.1 h1:/yv7dq2h7WXw/jTDxkE3FP9zHerRT+i03PZRHJX4fPU=
|
||||
github.com/AdguardTeam/golibs v0.30.1/go.mod h1:FkwcNQEJoGsgDGXcalrVa/4gWbE68KsmE2guXWtBQUE=
|
||||
github.com/AdguardTeam/gomitmproxy v0.2.0 h1:rvCOf17pd1/CnMyMQW891zrEiIQBpQ8cIGjKN9pinUU=
|
||||
github.com/AdguardTeam/gomitmproxy v0.2.1 h1:p9gr8Er1TYvf+7ic81Ax1sZ62UNCsMTZNbm7tC59S9o=
|
||||
github.com/AdguardTeam/gomitmproxy v0.2.1/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
|
||||
@ -245,6 +250,7 @@ github.com/cncf/xds/go v0.0.0-20240318125728-8a4994d93e50 h1:DBmgJDC9dTfkVyGgipa
|
||||
github.com/cncf/xds/go v0.0.0-20240318125728-8a4994d93e50/go.mod h1:5e1+Vvlzido69INQaVO6d87Qn543Xr6nooe9Kz7oBFM=
|
||||
github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b h1:ga8SEFjZ60pxLcmhnThWgvH2wg8376yUJmPhEH4H3kw=
|
||||
github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8=
|
||||
github.com/cncf/xds/go v0.0.0-20240723142845-024c85f92f20/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8=
|
||||
github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0 h1:sDMmm+q/3+BukdIpxwO365v/Rbspp2Nt5XntgQRXq8Q=
|
||||
github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0/go.mod h1:4Zcjuz89kmFXt9morQgcfYZAYZ5n8WHjt81YYWIwtTM=
|
||||
github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d h1:t5Wuyh53qYyg9eqn4BbnlIT+vmhyww0TatL+zT3uWgI=
|
||||
@ -262,12 +268,14 @@ github.com/envoyproxy/go-control-plane v0.11.1 h1:wSUXTlLfiAQRWs2F+p+EKOY9rUyis1
|
||||
github.com/envoyproxy/go-control-plane v0.11.1/go.mod h1:uhMcXKCQMEJHiAb0w+YGefQLaTEw+YhGluxZkrTmD0g=
|
||||
github.com/envoyproxy/go-control-plane v0.12.0 h1:4X+VP1GHd1Mhj6IB5mMeGbLCleqxjletLK6K0rbxyZI=
|
||||
github.com/envoyproxy/go-control-plane v0.12.0/go.mod h1:ZBTaoJ23lqITozF0M6G4/IragXCQKCnYbmlmtHvwRG0=
|
||||
github.com/envoyproxy/go-control-plane v0.13.0/go.mod h1:GRaKG3dwvFoTg4nj7aXdZnvMg4d7nvT/wl9WgVXn3Q8=
|
||||
github.com/envoyproxy/protoc-gen-validate v0.1.0 h1:EQciDnbrYxy13PgWoY8AqoxGiPrpgBZ1R8UNe3ddc+A=
|
||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||
github.com/envoyproxy/protoc-gen-validate v1.0.2 h1:QkIBuU5k+x7/QXPvPPnWXWlCdaBFApVqftFV6k087DA=
|
||||
github.com/envoyproxy/protoc-gen-validate v1.0.2/go.mod h1:GpiZQP3dDbg4JouG/NNS7QWXpgx6x8QiMKdmN72jogE=
|
||||
github.com/envoyproxy/protoc-gen-validate v1.0.4 h1:gVPz/FMfvh57HdSJQyvBtF00j8JU4zdyUgIUNhlgg0A=
|
||||
github.com/envoyproxy/protoc-gen-validate v1.0.4/go.mod h1:qys6tmnRsYrQqIhm2bvKZH4Blx/1gTIZ2UKVY1M+Yew=
|
||||
github.com/envoyproxy/protoc-gen-validate v1.1.0/go.mod h1:sXRDRVmzEbkM7CVcM06s9shE/m23dg3wzjl0UWqJ2q4=
|
||||
github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo=
|
||||
github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
@ -338,6 +346,7 @@ github.com/golang/glog v1.2.0 h1:uCdmnmatrKCgMBlM4rMuJZWOkPDqdbZPnrMXDY4gI68=
|
||||
github.com/golang/glog v1.2.0/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
|
||||
github.com/golang/glog v1.2.1 h1:OptwRhECazUx5ix5TTWC3EZhsZEHWcYWY4FQHTIubm4=
|
||||
github.com/golang/glog v1.2.1/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
|
||||
github.com/golang/glog v1.2.2/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
|
||||
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY=
|
||||
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
@ -576,6 +585,7 @@ github.com/pelletier/go-toml/v2 v2.0.5 h1:ipoSadvV8oGUjnUbMub59IDPPwfxF694nG/jwb
|
||||
github.com/pelletier/go-toml/v2 v2.0.5/go.mod h1:OMHamSCAODeSsVrwwvcJOaoN0LIUIaFVNZzmWyNfXas=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8=
|
||||
github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
|
||||
github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k=
|
||||
github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho=
|
||||
@ -677,6 +687,7 @@ github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0
|
||||
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
|
||||
github.com/urfave/negroni v1.0.0 h1:kIimOitoypq34K7TG7DUaJ9kq/N4Ofuwi1sjz0KipXc=
|
||||
github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4=
|
||||
github.com/urfave/negroni/v3 v3.1.1/go.mod h1:jWvnX03kcSjDBl/ShB0iHvx5uOs7mAzZXW+JvJ5XYAs=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.40.0 h1:CRq/00MfruPGFLTQKY8b+8SfdK60TxNztjRMnH0t1Yc=
|
||||
@ -828,6 +839,7 @@ golang.org/x/oauth2 v0.19.0 h1:9+E/EZBCbTLNrbN35fHv/a/d/mOBatymz1zbtQrXpIg=
|
||||
golang.org/x/oauth2 v0.19.0/go.mod h1:vYi7skDa1x015PmRRYZ7+s1cWyPgrPiSYRe4rnsexc8=
|
||||
golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs=
|
||||
golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
||||
golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
||||
golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852 h1:xYq6+9AtI+xP3M4r0N1hCkHrInHDBohhquRgx9Kk6gI=
|
||||
golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@ -912,6 +924,7 @@ golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4=
|
||||
golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
|
||||
golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM=
|
||||
golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8=
|
||||
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
@ -1000,6 +1013,7 @@ google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 h1:
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237/go.mod h1:Z5Iiy3jtmioajWHDGFk7CeugTyHtPvMHA4UTmUkyalE=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157 h1:7whR9kGa5LUwFtpLm2ArCEejtnxlGeLbAyjFY8sGNFw=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157/go.mod h1:99sLkeliLXfdj2J75X3Ho+rrVCaJze0uwN7zDDkjPVU=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240814211410-ddb44dafa142/go.mod h1:d6be+8HhtEtucleCbxpPW9PA9XwISACu8nvpPqF0BVo=
|
||||
google.golang.org/genproto/googleapis/bytestream v0.0.0-20240304161311-37d4d3c04a78/go.mod h1:vh/N7795ftP0AkN1w8XKqN4w1OdUKXW5Eummda+ofv8=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240304161311-37d4d3c04a78/go.mod h1:UCOku4NytXMJuLQE5VuqA5lX3PcHCBo8pxNyvkf4xBs=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240314234333-6e1732d8331c/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY=
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
)
|
||||
|
||||
// Client is a wrapper around http.Client.
|
||||
@ -93,6 +94,7 @@ func (c *Client) do(
|
||||
req.Header.Set(httphdr.UserAgent, c.userAgent)
|
||||
|
||||
resp, err = c.http.Do(req)
|
||||
urlutil.RedactUserinfoInURLError(u, err)
|
||||
if err != nil && resp != nil && resp.Header != nil {
|
||||
// A non-nil Response with a non-nil error only occurs when
|
||||
// CheckRedirect fails.
|
||||
|
@ -5,51 +5,13 @@ import (
|
||||
"net/url"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
)
|
||||
|
||||
// Known scheme constants.
|
||||
//
|
||||
// TODO(a.garipov): Move to agdurl or golibs.
|
||||
//
|
||||
// TODO(a.garipov): Use more.
|
||||
const (
|
||||
SchemeFile = "file"
|
||||
SchemeGRPC = "grpc"
|
||||
SchemeGRPCS = "grpcs"
|
||||
SchemeHTTP = "http"
|
||||
SchemeHTTPS = "https"
|
||||
)
|
||||
|
||||
// CheckGRPCURLScheme returns true if s is a valid gRPC URL scheme. That is,
|
||||
// [SchemeGRPC] or [SchemeGRPCS]
|
||||
//
|
||||
// TODO(a.garipov): Move to golibs?
|
||||
func CheckGRPCURLScheme(s string) (ok bool) {
|
||||
switch s {
|
||||
case SchemeGRPC, SchemeGRPCS:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// CheckHTTPURLScheme returns true if s is a valid HTTP URL scheme. That is,
|
||||
// [SchemeHTTP] or [SchemeHTTPS]
|
||||
//
|
||||
// TODO(a.garipov): Move to golibs?
|
||||
func CheckHTTPURLScheme(s string) (ok bool) {
|
||||
switch s {
|
||||
case SchemeHTTP, SchemeHTTPS:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ParseHTTPURL parses an absolute URL and makes sure that it is a valid HTTP(S)
|
||||
// URL. All returned errors will have the underlying type [*url.Error].
|
||||
//
|
||||
// TODO(a.garipov): Define as a type?
|
||||
// TODO(a.garipov): Define as a type?
|
||||
func ParseHTTPURL(s string) (u *url.URL, err error) {
|
||||
u, err = url.Parse(s)
|
||||
if err != nil {
|
||||
@ -63,7 +25,7 @@ func ParseHTTPURL(s string) (u *url.URL, err error) {
|
||||
URL: s,
|
||||
Err: errors.Error("empty host"),
|
||||
}
|
||||
case !CheckHTTPURLScheme(u.Scheme):
|
||||
case !urlutil.IsValidHTTPURLScheme(u.Scheme):
|
||||
return nil, &url.Error{
|
||||
Op: "parse",
|
||||
URL: s,
|
||||
|
@ -6,12 +6,19 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Common user credentials for tests.
|
||||
const (
|
||||
testUsername = "user"
|
||||
testPassword = "pass"
|
||||
)
|
||||
|
||||
func TestParseHTTPURL(t *testing.T) {
|
||||
goodURL := testURL()
|
||||
goodURL := testURL(url.UserPassword(testUsername, testPassword))
|
||||
|
||||
badSchemeURL := netutil.CloneURL(goodURL)
|
||||
badSchemeURL.Scheme = "ftp"
|
||||
@ -61,10 +68,11 @@ func TestParseHTTPURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func testURL() (u *url.URL) {
|
||||
// testURL is a helper function that returns an url with dummy values.
|
||||
func testURL(info *url.Userinfo) (u *url.URL) {
|
||||
return &url.URL{
|
||||
Scheme: agdhttp.SchemeHTTP,
|
||||
User: url.UserPassword("user", "pass"),
|
||||
Scheme: urlutil.SchemeHTTP,
|
||||
User: info,
|
||||
Host: "example.com",
|
||||
Path: "/a/b/c/",
|
||||
RawQuery: "d=e",
|
||||
|
@ -3,6 +3,7 @@
|
||||
package agdtest
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -18,7 +19,7 @@ const FilteredResponseTTL = FilteredResponseTTLSec * time.Second
|
||||
// number to simplify message creation.
|
||||
const FilteredResponseTTLSec = 10
|
||||
|
||||
// NewConstructorWithTTL returns a standard dnsmsg.Constructor for tests, using
|
||||
// NewConstructorWithTTL returns a standard *dnsmsg.Constructor for tests, using
|
||||
// ttl as the TTL for filtered responses.
|
||||
func NewConstructorWithTTL(tb testing.TB, ttl time.Duration) (c *dnsmsg.Constructor) {
|
||||
tb.Helper()
|
||||
@ -26,28 +27,57 @@ func NewConstructorWithTTL(tb testing.TB, ttl time.Duration) (c *dnsmsg.Construc
|
||||
c, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
|
||||
Cloner: NewCloner(),
|
||||
BlockingMode: &dnsmsg.BlockingModeNullIP{},
|
||||
StructuredErrors: NewSDEConfig(true),
|
||||
FilteredResponseTTL: ttl,
|
||||
EDEEnabled: true,
|
||||
})
|
||||
require.NoError(tb, err)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// NewConstructor returns a standard dnsmsg.Constructor for tests, using
|
||||
// [FilteredResponseTTL] as the TTL for filtered responses.
|
||||
// NewConstructor returns a standard *dnsmsg.Constructor for tests, using
|
||||
// [FilteredResponseTTL] as the TTL for filtered responses. The returned
|
||||
// constructor also has the Structured DNS Errors feature enabled.
|
||||
func NewConstructor(tb testing.TB) (c *dnsmsg.Constructor) {
|
||||
tb.Helper()
|
||||
|
||||
c, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
|
||||
Cloner: NewCloner(),
|
||||
BlockingMode: &dnsmsg.BlockingModeNullIP{},
|
||||
StructuredErrors: NewSDEConfig(true),
|
||||
FilteredResponseTTL: FilteredResponseTTL,
|
||||
EDEEnabled: true,
|
||||
})
|
||||
require.NoError(tb, err)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// SDEText is a test Structured DNS Error text.
|
||||
//
|
||||
// NOTE: Keep in sync with [NewSDEConfig].
|
||||
//
|
||||
// TODO(e.burkov): Add some helper when this message becomes configurable.
|
||||
const SDEText = `{` +
|
||||
`"j":"Filtering",` +
|
||||
`"o":"Test Org",` +
|
||||
`"c":["mailto:support@dns.example"]` +
|
||||
`}`
|
||||
|
||||
// NewSDEConfig returns a standard *dnsmsg.StructuredDNSErrorsConfig for tests.
|
||||
func NewSDEConfig(enabled bool) (c *dnsmsg.StructuredDNSErrorsConfig) {
|
||||
return &dnsmsg.StructuredDNSErrorsConfig{
|
||||
Contact: []*url.URL{{
|
||||
Scheme: "mailto",
|
||||
Opaque: "support@dns.example",
|
||||
}},
|
||||
Justification: "Filtering",
|
||||
Organization: "Test Org",
|
||||
Enabled: enabled,
|
||||
}
|
||||
}
|
||||
|
||||
// NewCloner returns a standard dnsmsg.Cloner for tests.
|
||||
func NewCloner() (c *dnsmsg.Cloner) {
|
||||
return dnsmsg.NewCloner(dnsmsg.EmptyClonerStat{})
|
||||
|
@ -116,7 +116,7 @@ func (r *Refresher) Refresh(ctx context.Context) (err error) {
|
||||
// type check
|
||||
var _ billstat.Recorder = (*BillStatRecorder)(nil)
|
||||
|
||||
// BillStatRecorder is a billstat.Recorder for tests.
|
||||
// BillStatRecorder is a [billstat.Recorder] for tests.
|
||||
type BillStatRecorder struct {
|
||||
OnRecord func(
|
||||
ctx context.Context,
|
||||
@ -128,7 +128,7 @@ type BillStatRecorder struct {
|
||||
)
|
||||
}
|
||||
|
||||
// Record implements the billstat.Recorder interface for *BillStatRecorder.
|
||||
// Record implements the [billstat.Recorder] interface for *BillStatRecorder.
|
||||
func (r *BillStatRecorder) Record(
|
||||
ctx context.Context,
|
||||
id agd.DeviceID,
|
||||
@ -143,12 +143,12 @@ func (r *BillStatRecorder) Record(
|
||||
// type check
|
||||
var _ billstat.Uploader = (*BillStatUploader)(nil)
|
||||
|
||||
// BillStatUploader is a billstat.Uploader for tests.
|
||||
// BillStatUploader is a [billstat.Uploader] for tests.
|
||||
type BillStatUploader struct {
|
||||
OnUpload func(ctx context.Context, records billstat.Records) (err error)
|
||||
}
|
||||
|
||||
// Upload implements the billstat.Uploader interface for *BillStatUploader.
|
||||
// Upload implements the [billstat.Uploader] interface for *BillStatUploader.
|
||||
func (b *BillStatUploader) Upload(ctx context.Context, records billstat.Records) (err error) {
|
||||
return b.OnUpload(ctx, records)
|
||||
}
|
||||
@ -158,7 +158,7 @@ func (b *BillStatUploader) Upload(ctx context.Context, records billstat.Records)
|
||||
// type check
|
||||
var _ dnscheck.Interface = (*DNSCheck)(nil)
|
||||
|
||||
// DNSCheck is a dnscheck.Interface for tests.
|
||||
// DNSCheck is a [dnscheck.Interface] for tests.
|
||||
type DNSCheck struct {
|
||||
OnCheck func(ctx context.Context, req *dns.Msg, ri *agd.RequestInfo) (reqp *dns.Msg, err error)
|
||||
}
|
||||
@ -177,12 +177,12 @@ func (db *DNSCheck) Check(
|
||||
// type check
|
||||
var _ dnsdb.Interface = (*DNSDB)(nil)
|
||||
|
||||
// DNSDB is a dnsdb.Interface for tests.
|
||||
// DNSDB is a [dnsdb.Interface] for tests.
|
||||
type DNSDB struct {
|
||||
OnRecord func(ctx context.Context, resp *dns.Msg, ri *agd.RequestInfo)
|
||||
}
|
||||
|
||||
// Record implements the dnsdb.Interface interface for *DNSDB.
|
||||
// Record implements the [dnsdb.Interface] interface for *DNSDB.
|
||||
func (db *DNSDB) Record(ctx context.Context, resp *dns.Msg, ri *agd.RequestInfo) {
|
||||
db.OnRecord(ctx, resp, ri)
|
||||
}
|
||||
@ -204,7 +204,7 @@ func (c *ErrorCollector) Collect(ctx context.Context, err error) {
|
||||
c.OnCollect(ctx, err)
|
||||
}
|
||||
|
||||
// NewErrorCollector returns a new [ErrorCollector] all methods of which panic.
|
||||
// NewErrorCollector returns a new *ErrorCollector all methods of which panic.
|
||||
func NewErrorCollector() (c *ErrorCollector) {
|
||||
return &ErrorCollector{
|
||||
OnCollect: func(_ context.Context, err error) {
|
||||
@ -297,21 +297,38 @@ func (s *FilterStorage) HasListID(id agd.FilterListID) (ok bool) {
|
||||
// type check
|
||||
var _ geoip.Interface = (*GeoIP)(nil)
|
||||
|
||||
// GeoIP is a geoip.Interface for tests.
|
||||
// GeoIP is a [geoip.Interface] for tests.
|
||||
type GeoIP struct {
|
||||
OnSubnetByLocation func(l *geoip.Location, fam netutil.AddrFamily) (n netip.Prefix, err error)
|
||||
OnData func(host string, ip netip.Addr) (l *geoip.Location, err error)
|
||||
OnSubnetByLocation func(l *geoip.Location, fam netutil.AddrFamily) (n netip.Prefix, err error)
|
||||
}
|
||||
|
||||
// SubnetByLocation implements the geoip.Interface interface for *GeoIP.
|
||||
func (g *GeoIP) SubnetByLocation(l *geoip.Location, fam netutil.AddrFamily,
|
||||
// Data implements the [geoip.Interface] interface for *GeoIP.
|
||||
func (g *GeoIP) Data(host string, ip netip.Addr) (l *geoip.Location, err error) {
|
||||
return g.OnData(host, ip)
|
||||
}
|
||||
|
||||
// SubnetByLocation implements the [geoip.Interface] interface for *GeoIP.
|
||||
func (g *GeoIP) SubnetByLocation(
|
||||
l *geoip.Location,
|
||||
fam netutil.AddrFamily,
|
||||
) (n netip.Prefix, err error) {
|
||||
return g.OnSubnetByLocation(l, fam)
|
||||
}
|
||||
|
||||
// Data implements the geoip.Interface interface for *GeoIP.
|
||||
func (g *GeoIP) Data(host string, ip netip.Addr) (l *geoip.Location, err error) {
|
||||
return g.OnData(host, ip)
|
||||
// NewGeoIP returns a new *GeoIP all methods of which panic.
|
||||
func NewGeoIP() (c *GeoIP) {
|
||||
return &GeoIP{
|
||||
OnData: func(host string, ip netip.Addr) (l *geoip.Location, err error) {
|
||||
panic(fmt.Errorf("unexpected call to GeoIP.Data(%v, %v)", host, ip))
|
||||
},
|
||||
OnSubnetByLocation: func(
|
||||
l *geoip.Location,
|
||||
fam netutil.AddrFamily,
|
||||
) (n netip.Prefix, err error) {
|
||||
panic(fmt.Errorf("unexpected call to GeoIP.SubnetByLocation(%v, %v)", l, fam))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Package profiledb
|
||||
@ -398,6 +415,58 @@ func (db *ProfileDB) ProfileByLinkedIP(
|
||||
return db.OnProfileByLinkedIP(ctx, ip)
|
||||
}
|
||||
|
||||
// NewProfileDB returns a new *ProfileDB all methods of which panic.
|
||||
func NewProfileDB() (db *ProfileDB) {
|
||||
return &ProfileDB{
|
||||
OnCreateAutoDevice: func(
|
||||
_ context.Context,
|
||||
id agd.ProfileID,
|
||||
humanID agd.HumanID,
|
||||
devType agd.DeviceType,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic(fmt.Errorf(
|
||||
"unexpected call to ProfileDB.CreateAutoDevice(%v, %v, %v)",
|
||||
id,
|
||||
humanID,
|
||||
devType,
|
||||
))
|
||||
},
|
||||
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
ip netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic(fmt.Errorf("unexpected call to ProfileDB.ProfileByDedicatedIP(%v)", ip))
|
||||
},
|
||||
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
id agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic(fmt.Errorf("unexpected call to ProfileDB.ProfileByDeviceID(%v)", id))
|
||||
},
|
||||
|
||||
OnProfileByHumanID: func(
|
||||
_ context.Context,
|
||||
profID agd.ProfileID,
|
||||
humanID agd.HumanIDLower,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic(fmt.Errorf(
|
||||
"unexpected call to ProfileDB.ProfileByHumanID(%v, %v)",
|
||||
profID,
|
||||
humanID,
|
||||
))
|
||||
},
|
||||
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
ip netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic(fmt.Errorf("unexpected call to ProfileDB.ProfileByLinkedIP(%v)", ip))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ profiledb.Storage = (*ProfileStorage)(nil)
|
||||
|
||||
@ -436,12 +505,12 @@ func (s *ProfileStorage) Profiles(
|
||||
// type check
|
||||
var _ querylog.Interface = (*QueryLog)(nil)
|
||||
|
||||
// QueryLog is a querylog.Interface for tests.
|
||||
// QueryLog is a [querylog.Interface] for tests.
|
||||
type QueryLog struct {
|
||||
OnWrite func(ctx context.Context, e *querylog.Entry) (err error)
|
||||
}
|
||||
|
||||
// Write implements the querylog.Interface interface for *QueryLog.
|
||||
// Write implements the [querylog.Interface] interface for *QueryLog.
|
||||
func (ql *QueryLog) Write(ctx context.Context, e *querylog.Entry) (err error) {
|
||||
return ql.OnWrite(ctx, e)
|
||||
}
|
||||
@ -451,12 +520,12 @@ func (ql *QueryLog) Write(ctx context.Context, e *querylog.Entry) (err error) {
|
||||
// type check
|
||||
var _ rulestat.Interface = (*RuleStat)(nil)
|
||||
|
||||
// RuleStat is a rulestat.Interface for tests.
|
||||
// RuleStat is a [rulestat.Interface] for tests.
|
||||
type RuleStat struct {
|
||||
OnCollect func(ctx context.Context, id agd.FilterListID, text agd.FilterRuleText)
|
||||
}
|
||||
|
||||
// Collect implements the rulestat.Interface interface for *RuleStat.
|
||||
// Collect implements the [rulestat.Interface] interface for *RuleStat.
|
||||
func (s *RuleStat) Collect(ctx context.Context, id agd.FilterListID, text agd.FilterRuleText) {
|
||||
s.OnCollect(ctx, id, text)
|
||||
}
|
||||
@ -501,7 +570,7 @@ func (c *ListenConfig) ListenPacket(
|
||||
// type check
|
||||
var _ ratelimit.Interface = (*RateLimit)(nil)
|
||||
|
||||
// RateLimit is a ratelimit.Interface for tests.
|
||||
// RateLimit is a [ratelimit.Interface] for tests.
|
||||
type RateLimit struct {
|
||||
OnIsRateLimited func(
|
||||
ctx context.Context,
|
||||
@ -511,7 +580,7 @@ type RateLimit struct {
|
||||
OnCountResponses func(ctx context.Context, resp *dns.Msg, ip netip.Addr)
|
||||
}
|
||||
|
||||
// IsRateLimited implements the ratelimit.Interface interface for *RateLimit.
|
||||
// IsRateLimited implements the [ratelimit.Interface] interface for *RateLimit.
|
||||
func (l *RateLimit) IsRateLimited(
|
||||
ctx context.Context,
|
||||
req *dns.Msg,
|
||||
@ -520,12 +589,27 @@ func (l *RateLimit) IsRateLimited(
|
||||
return l.OnIsRateLimited(ctx, req, ip)
|
||||
}
|
||||
|
||||
// CountResponses implements the ratelimit.Interface interface for
|
||||
// *RateLimit.
|
||||
// CountResponses implements the [ratelimit.Interface] interface for *RateLimit.
|
||||
func (l *RateLimit) CountResponses(ctx context.Context, req *dns.Msg, ip netip.Addr) {
|
||||
l.OnCountResponses(ctx, req, ip)
|
||||
}
|
||||
|
||||
// NewRateLimit returns a new *RateLimit all methods of which panic.
|
||||
func NewRateLimit() (c *RateLimit) {
|
||||
return &RateLimit{
|
||||
OnIsRateLimited: func(
|
||||
_ context.Context,
|
||||
req *dns.Msg,
|
||||
addr netip.Addr,
|
||||
) (shouldDrop, isAllowlisted bool, err error) {
|
||||
panic(fmt.Errorf("unexpected call to RateLimit.IsRateLimited(%v, %v)", req, addr))
|
||||
},
|
||||
OnCountResponses: func(_ context.Context, resp *dns.Msg, addr netip.Addr) {
|
||||
panic(fmt.Errorf("unexpected call to RateLimit.CountResponses(%v, %v)", resp, addr))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RemoteKV is an [remotekv.Interface] implementation for tests.
|
||||
type RemoteKV struct {
|
||||
OnGet func(ctx context.Context, key string) (val []byte, ok bool, err error)
|
||||
|
@ -16,8 +16,8 @@ import (
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
// newClient returns new properly initialized DNSServiceClient.
|
||||
func newClient(apiURL *url.URL) (client DNSServiceClient, err error) {
|
||||
// newClient returns new properly initialized gRPC connection to the API server.
|
||||
func newClient(apiURL *url.URL) (client *grpc.ClientConn, err error) {
|
||||
var creds credentials.TransportCredentials
|
||||
switch s := apiURL.Scheme; s {
|
||||
case "grpc":
|
||||
@ -38,7 +38,7 @@ func newClient(apiURL *url.URL) (client DNSServiceClient, err error) {
|
||||
// called right before the initial refresh.
|
||||
conn.Connect()
|
||||
|
||||
return NewDNSServiceClient(conn), nil
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// reportf is a helper method for reporting non-critical errors.
|
||||
|
@ -65,3 +65,39 @@ func (s *testDNSServiceServer) SaveDevicesBillingStat(
|
||||
) (err error) {
|
||||
return s.OnSaveDevicesBillingStat(srv)
|
||||
}
|
||||
|
||||
// testRemoteKVServiceServer is the [backendpb.RemoteKVServiceServer] for tests.
|
||||
type testRemoteKVServiceServer struct {
|
||||
backendpb.UnimplementedRemoteKVServiceServer
|
||||
|
||||
OnGet func(
|
||||
ctx context.Context,
|
||||
req *backendpb.RemoteKVGetRequest,
|
||||
) (resp *backendpb.RemoteKVGetResponse, err error)
|
||||
|
||||
OnSet func(
|
||||
ctx context.Context,
|
||||
req *backendpb.RemoteKVSetRequest,
|
||||
) (resp *backendpb.RemoteKVSetResponse, err error)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ backendpb.RemoteKVServiceServer = (*testRemoteKVServiceServer)(nil)
|
||||
|
||||
// Get implements the [backendpb.RemoteKVServiceServer] interface for
|
||||
// *testRemoteKVServiceServer.
|
||||
func (s *testRemoteKVServiceServer) Get(
|
||||
ctx context.Context,
|
||||
req *backendpb.RemoteKVGetRequest,
|
||||
) (resp *backendpb.RemoteKVGetResponse, err error) {
|
||||
return s.OnGet(ctx, req)
|
||||
}
|
||||
|
||||
// Set implements the [backendpb.RemoteKVServiceServer] interface for
|
||||
// *testRemoteKVServiceServer.
|
||||
func (s *testRemoteKVServiceServer) Set(
|
||||
ctx context.Context,
|
||||
req *backendpb.RemoteKVSetRequest,
|
||||
) (resp *backendpb.RemoteKVSetResponse, err error) {
|
||||
return s.OnSet(ctx, req)
|
||||
}
|
||||
|
@ -42,7 +42,7 @@ func NewBillStat(c *BillStatConfig) (b *BillStat, err error) {
|
||||
return &BillStat{
|
||||
errColl: c.ErrColl,
|
||||
metrics: c.Metrics,
|
||||
client: client,
|
||||
client: NewDNSServiceClient(client),
|
||||
apiKey: c.APIKey,
|
||||
}, nil
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -45,6 +45,42 @@ service DNSService {
|
||||
rpc createDeviceByHumanId(CreateDeviceRequest) returns (CreateDeviceResponse);
|
||||
}
|
||||
|
||||
service RateLimitService {
|
||||
|
||||
/*
|
||||
Gets rate limit settings.
|
||||
*/
|
||||
rpc getRateLimitSettings(RateLimitSettingsRequest) returns (RateLimitSettingsResponse);
|
||||
}
|
||||
|
||||
service RemoteKVService {
|
||||
|
||||
/**
|
||||
Get the value for the specified key.
|
||||
|
||||
This method may return the following errors:
|
||||
- AuthenticationFailedError: If the authentication failed.
|
||||
*/
|
||||
rpc get(RemoteKVGetRequest) returns (RemoteKVGetResponse);
|
||||
|
||||
/**
|
||||
Set the value for the specified key.
|
||||
|
||||
This method may return the following errors:
|
||||
- AuthenticationFailedError: If the authentication failed.
|
||||
- BadRequestError: If the request is invalid: value size exceeds the 512kb.
|
||||
*/
|
||||
rpc set(RemoteKVSetRequest) returns (RemoteKVSetResponse);
|
||||
}
|
||||
|
||||
message RateLimitSettingsRequest {
|
||||
|
||||
}
|
||||
|
||||
message RateLimitSettingsResponse {
|
||||
repeated CidrRange allowed_subnets = 1;
|
||||
}
|
||||
|
||||
message DNSProfilesRequest {
|
||||
google.protobuf.Timestamp sync_time = 1;
|
||||
}
|
||||
@ -212,3 +248,24 @@ message RateLimitSettings {
|
||||
uint32 rps = 2;
|
||||
repeated CidrRange client_cidr = 3;
|
||||
}
|
||||
|
||||
message RemoteKVGetRequest {
|
||||
string key = 1;
|
||||
}
|
||||
|
||||
message RemoteKVGetResponse {
|
||||
oneof value {
|
||||
bytes data = 1;
|
||||
google.protobuf.Empty empty = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message RemoteKVSetRequest {
|
||||
string key = 1;
|
||||
bytes data = 2;
|
||||
google.protobuf.Duration ttl = 3;
|
||||
}
|
||||
|
||||
message RemoteKVSetResponse {
|
||||
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.5.1
|
||||
// - protoc v5.27.1
|
||||
// - protoc v5.28.3
|
||||
// source: dns.proto
|
||||
|
||||
package backendpb
|
||||
@ -235,3 +235,269 @@ var DNSService_ServiceDesc = grpc.ServiceDesc{
|
||||
},
|
||||
Metadata: "dns.proto",
|
||||
}
|
||||
|
||||
const (
|
||||
RateLimitService_GetRateLimitSettings_FullMethodName = "/RateLimitService/getRateLimitSettings"
|
||||
)
|
||||
|
||||
// RateLimitServiceClient is the client API for RateLimitService service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type RateLimitServiceClient interface {
|
||||
// Gets rate limit settings.
|
||||
GetRateLimitSettings(ctx context.Context, in *RateLimitSettingsRequest, opts ...grpc.CallOption) (*RateLimitSettingsResponse, error)
|
||||
}
|
||||
|
||||
type rateLimitServiceClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewRateLimitServiceClient(cc grpc.ClientConnInterface) RateLimitServiceClient {
|
||||
return &rateLimitServiceClient{cc}
|
||||
}
|
||||
|
||||
func (c *rateLimitServiceClient) GetRateLimitSettings(ctx context.Context, in *RateLimitSettingsRequest, opts ...grpc.CallOption) (*RateLimitSettingsResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(RateLimitSettingsResponse)
|
||||
err := c.cc.Invoke(ctx, RateLimitService_GetRateLimitSettings_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// RateLimitServiceServer is the server API for RateLimitService service.
|
||||
// All implementations must embed UnimplementedRateLimitServiceServer
|
||||
// for forward compatibility.
|
||||
type RateLimitServiceServer interface {
|
||||
// Gets rate limit settings.
|
||||
GetRateLimitSettings(context.Context, *RateLimitSettingsRequest) (*RateLimitSettingsResponse, error)
|
||||
mustEmbedUnimplementedRateLimitServiceServer()
|
||||
}
|
||||
|
||||
// UnimplementedRateLimitServiceServer must be embedded to have
|
||||
// forward compatible implementations.
|
||||
//
|
||||
// NOTE: this should be embedded by value instead of pointer to avoid a nil
|
||||
// pointer dereference when methods are called.
|
||||
type UnimplementedRateLimitServiceServer struct{}
|
||||
|
||||
func (UnimplementedRateLimitServiceServer) GetRateLimitSettings(context.Context, *RateLimitSettingsRequest) (*RateLimitSettingsResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetRateLimitSettings not implemented")
|
||||
}
|
||||
func (UnimplementedRateLimitServiceServer) mustEmbedUnimplementedRateLimitServiceServer() {}
|
||||
func (UnimplementedRateLimitServiceServer) testEmbeddedByValue() {}
|
||||
|
||||
// UnsafeRateLimitServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to RateLimitServiceServer will
|
||||
// result in compilation errors.
|
||||
type UnsafeRateLimitServiceServer interface {
|
||||
mustEmbedUnimplementedRateLimitServiceServer()
|
||||
}
|
||||
|
||||
func RegisterRateLimitServiceServer(s grpc.ServiceRegistrar, srv RateLimitServiceServer) {
|
||||
// If the following call pancis, it indicates UnimplementedRateLimitServiceServer was
|
||||
// embedded by pointer and is nil. This will cause panics if an
|
||||
// unimplemented method is ever invoked, so we test this at initialization
|
||||
// time to prevent it from happening at runtime later due to I/O.
|
||||
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
|
||||
t.testEmbeddedByValue()
|
||||
}
|
||||
s.RegisterService(&RateLimitService_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _RateLimitService_GetRateLimitSettings_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RateLimitSettingsRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(RateLimitServiceServer).GetRateLimitSettings(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: RateLimitService_GetRateLimitSettings_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(RateLimitServiceServer).GetRateLimitSettings(ctx, req.(*RateLimitSettingsRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// RateLimitService_ServiceDesc is the grpc.ServiceDesc for RateLimitService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var RateLimitService_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "RateLimitService",
|
||||
HandlerType: (*RateLimitServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "getRateLimitSettings",
|
||||
Handler: _RateLimitService_GetRateLimitSettings_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "dns.proto",
|
||||
}
|
||||
|
||||
const (
|
||||
RemoteKVService_Get_FullMethodName = "/RemoteKVService/get"
|
||||
RemoteKVService_Set_FullMethodName = "/RemoteKVService/set"
|
||||
)
|
||||
|
||||
// RemoteKVServiceClient is the client API for RemoteKVService service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type RemoteKVServiceClient interface {
|
||||
// *
|
||||
// Get the value for the specified key.
|
||||
//
|
||||
// This method may return the following errors:
|
||||
// - AuthenticationFailedError: If the authentication failed.
|
||||
Get(ctx context.Context, in *RemoteKVGetRequest, opts ...grpc.CallOption) (*RemoteKVGetResponse, error)
|
||||
// *
|
||||
// Set the value for the specified key.
|
||||
//
|
||||
// This method may return the following errors:
|
||||
// - AuthenticationFailedError: If the authentication failed.
|
||||
// - BadRequestError: If the request is invalid: value size exceeds the 512kb.
|
||||
Set(ctx context.Context, in *RemoteKVSetRequest, opts ...grpc.CallOption) (*RemoteKVSetResponse, error)
|
||||
}
|
||||
|
||||
type remoteKVServiceClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewRemoteKVServiceClient(cc grpc.ClientConnInterface) RemoteKVServiceClient {
|
||||
return &remoteKVServiceClient{cc}
|
||||
}
|
||||
|
||||
func (c *remoteKVServiceClient) Get(ctx context.Context, in *RemoteKVGetRequest, opts ...grpc.CallOption) (*RemoteKVGetResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(RemoteKVGetResponse)
|
||||
err := c.cc.Invoke(ctx, RemoteKVService_Get_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *remoteKVServiceClient) Set(ctx context.Context, in *RemoteKVSetRequest, opts ...grpc.CallOption) (*RemoteKVSetResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(RemoteKVSetResponse)
|
||||
err := c.cc.Invoke(ctx, RemoteKVService_Set_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// RemoteKVServiceServer is the server API for RemoteKVService service.
|
||||
// All implementations must embed UnimplementedRemoteKVServiceServer
|
||||
// for forward compatibility.
|
||||
type RemoteKVServiceServer interface {
|
||||
// *
|
||||
// Get the value for the specified key.
|
||||
//
|
||||
// This method may return the following errors:
|
||||
// - AuthenticationFailedError: If the authentication failed.
|
||||
Get(context.Context, *RemoteKVGetRequest) (*RemoteKVGetResponse, error)
|
||||
// *
|
||||
// Set the value for the specified key.
|
||||
//
|
||||
// This method may return the following errors:
|
||||
// - AuthenticationFailedError: If the authentication failed.
|
||||
// - BadRequestError: If the request is invalid: value size exceeds the 512kb.
|
||||
Set(context.Context, *RemoteKVSetRequest) (*RemoteKVSetResponse, error)
|
||||
mustEmbedUnimplementedRemoteKVServiceServer()
|
||||
}
|
||||
|
||||
// UnimplementedRemoteKVServiceServer must be embedded to have
|
||||
// forward compatible implementations.
|
||||
//
|
||||
// NOTE: this should be embedded by value instead of pointer to avoid a nil
|
||||
// pointer dereference when methods are called.
|
||||
type UnimplementedRemoteKVServiceServer struct{}
|
||||
|
||||
func (UnimplementedRemoteKVServiceServer) Get(context.Context, *RemoteKVGetRequest) (*RemoteKVGetResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Get not implemented")
|
||||
}
|
||||
func (UnimplementedRemoteKVServiceServer) Set(context.Context, *RemoteKVSetRequest) (*RemoteKVSetResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Set not implemented")
|
||||
}
|
||||
func (UnimplementedRemoteKVServiceServer) mustEmbedUnimplementedRemoteKVServiceServer() {}
|
||||
func (UnimplementedRemoteKVServiceServer) testEmbeddedByValue() {}
|
||||
|
||||
// UnsafeRemoteKVServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to RemoteKVServiceServer will
|
||||
// result in compilation errors.
|
||||
type UnsafeRemoteKVServiceServer interface {
|
||||
mustEmbedUnimplementedRemoteKVServiceServer()
|
||||
}
|
||||
|
||||
func RegisterRemoteKVServiceServer(s grpc.ServiceRegistrar, srv RemoteKVServiceServer) {
|
||||
// If the following call pancis, it indicates UnimplementedRemoteKVServiceServer was
|
||||
// embedded by pointer and is nil. This will cause panics if an
|
||||
// unimplemented method is ever invoked, so we test this at initialization
|
||||
// time to prevent it from happening at runtime later due to I/O.
|
||||
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
|
||||
t.testEmbeddedByValue()
|
||||
}
|
||||
s.RegisterService(&RemoteKVService_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _RemoteKVService_Get_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RemoteKVGetRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(RemoteKVServiceServer).Get(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: RemoteKVService_Get_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(RemoteKVServiceServer).Get(ctx, req.(*RemoteKVGetRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _RemoteKVService_Set_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RemoteKVSetRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(RemoteKVServiceServer).Set(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: RemoteKVService_Set_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(RemoteKVServiceServer).Set(ctx, req.(*RemoteKVSetRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// RemoteKVService_ServiceDesc is the grpc.ServiceDesc for RemoteKVService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var RemoteKVService_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "RemoteKVService",
|
||||
HandlerType: (*RemoteKVServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "get",
|
||||
Handler: _RemoteKVService_Get_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "set",
|
||||
Handler: _RemoteKVService_Set_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "dns.proto",
|
||||
}
|
||||
|
@ -81,7 +81,7 @@ func NewProfileStorage(c *ProfileStorageConfig) (s *ProfileStorage, err error) {
|
||||
return &ProfileStorage{
|
||||
bindSet: c.BindSet,
|
||||
errColl: c.ErrColl,
|
||||
client: client,
|
||||
client: NewDNSServiceClient(client),
|
||||
logger: c.Logger,
|
||||
metrics: c.Metrics,
|
||||
apiKey: c.APIKey,
|
||||
|
103
internal/backendpb/ratelimiter.go
Normal file
103
internal/backendpb/ratelimiter.go
Normal file
@ -0,0 +1,103 @@
|
||||
package backendpb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdservice"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/consul"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
|
||||
)
|
||||
|
||||
// RateLimiterConfig is the configuration structure for the business logic
|
||||
// backend rate limiter.
|
||||
type RateLimiterConfig struct {
|
||||
// Logger is used for logging the operation of the rate limiter. It must
|
||||
// not be nil.
|
||||
Logger *slog.Logger
|
||||
|
||||
// GRPCMetrics is used for the collection of the protobuf errors.
|
||||
GRPCMetrics Metrics
|
||||
|
||||
// Metrics is used to collect allowlist statistics.
|
||||
Metrics consul.Metrics
|
||||
|
||||
// Allowlist is the allowlist to update.
|
||||
Allowlist *ratelimit.DynamicAllowlist
|
||||
|
||||
// ErrColl is used to collect errors during refreshes.
|
||||
ErrColl errcoll.Interface
|
||||
|
||||
// Endpoint is the backend API URL. The scheme should be either "grpc" or
|
||||
// "grpcs". It must not be nil.
|
||||
Endpoint *url.URL
|
||||
|
||||
// APIKey is the API key used for authentication, if any. If empty, no
|
||||
// authentication is performed.
|
||||
APIKey string
|
||||
}
|
||||
|
||||
// RateLimiter is the implementation of the [agdservice.Refresher] interface
|
||||
// that retrieves the rate limit settings from the business logic backend.
|
||||
type RateLimiter struct {
|
||||
logger *slog.Logger
|
||||
grpcMetrics Metrics
|
||||
metrics consul.Metrics
|
||||
allowlist *ratelimit.DynamicAllowlist
|
||||
errColl errcoll.Interface
|
||||
client RateLimitServiceClient
|
||||
apiKey string
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new properly initialized rate limiter. c must not
|
||||
// be nil.
|
||||
func NewRateLimiter(c *RateLimiterConfig) (l *RateLimiter, err error) {
|
||||
client, err := newClient(c.Endpoint)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &RateLimiter{
|
||||
logger: c.Logger,
|
||||
grpcMetrics: c.GRPCMetrics,
|
||||
metrics: c.Metrics,
|
||||
allowlist: c.Allowlist,
|
||||
errColl: c.ErrColl,
|
||||
client: NewRateLimitServiceClient(client),
|
||||
apiKey: c.APIKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ agdservice.Refresher = (*RateLimiter)(nil)
|
||||
|
||||
// Refresh implements the [agdservice.Refresher] interface for *RateLimiter.
|
||||
func (l *RateLimiter) Refresh(ctx context.Context) (err error) {
|
||||
l.logger.InfoContext(ctx, "refresh started")
|
||||
defer l.logger.InfoContext(ctx, "refresh finished")
|
||||
|
||||
defer func() { l.metrics.SetStatus(ctx, err) }()
|
||||
|
||||
ctx = ctxWithAuthentication(ctx, l.apiKey)
|
||||
backendResp, err := l.client.GetRateLimitSettings(ctx, &RateLimitSettingsRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"loading backend rate limit settings: %w",
|
||||
fixGRPCError(ctx, l.grpcMetrics, err),
|
||||
)
|
||||
}
|
||||
|
||||
allowedSubnets := backendResp.AllowedSubnets
|
||||
prefixes := cidrRangeToInternal(ctx, l.errColl, allowedSubnets)
|
||||
l.allowlist.Update(prefixes)
|
||||
|
||||
l.logger.InfoContext(ctx, "refresh successful", "num_records", len(prefixes))
|
||||
|
||||
l.metrics.SetSize(ctx, len(prefixes))
|
||||
|
||||
return nil
|
||||
}
|
110
internal/backendpb/ratelimiter_test.go
Normal file
110
internal/backendpb/ratelimiter_test.go
Normal file
@ -0,0 +1,110 @@
|
||||
package backendpb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/consul"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
// testRateLimitServiceServer is the [backendpb.RateLimitServiceServer] for
|
||||
// tests.
|
||||
type testRateLimitServiceServer struct {
|
||||
backendpb.UnimplementedRateLimitServiceServer
|
||||
|
||||
OnGetRateLimitSettings func(
|
||||
ctx context.Context,
|
||||
req *backendpb.RateLimitSettingsRequest,
|
||||
) (resp *backendpb.RateLimitSettingsResponse, err error)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ backendpb.DNSServiceServer = (*testDNSServiceServer)(nil)
|
||||
|
||||
// GetRateLimitSettings implements the [backendpb.RateLimitServiceServer]
|
||||
// interface for *testRateLimitServiceServer.
|
||||
func (s *testRateLimitServiceServer) GetRateLimitSettings(
|
||||
ctx context.Context,
|
||||
req *backendpb.RateLimitSettingsRequest,
|
||||
) (resp *backendpb.RateLimitSettingsResponse, err error) {
|
||||
return s.OnGetRateLimitSettings(ctx, req)
|
||||
}
|
||||
|
||||
func TestRateLimiter_Refresh(t *testing.T) {
|
||||
var (
|
||||
allowedIP = netip.MustParseAddr("1.2.3.4")
|
||||
notAllowedIP = netip.MustParseAddr("4.3.2.1")
|
||||
|
||||
cidr = &backendpb.CidrRange{
|
||||
Address: allowedIP.AsSlice(),
|
||||
Prefix: 32,
|
||||
}
|
||||
)
|
||||
|
||||
srv := &testRateLimitServiceServer{
|
||||
OnGetRateLimitSettings: func(
|
||||
ctx context.Context,
|
||||
req *backendpb.RateLimitSettingsRequest,
|
||||
) (resp *backendpb.RateLimitSettingsResponse, err error) {
|
||||
return &backendpb.RateLimitSettingsResponse{
|
||||
AllowedSubnets: []*backendpb.CidrRange{cidr},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
grpcSrv := grpc.NewServer(
|
||||
grpc.ConnectionTimeout(1*time.Second),
|
||||
grpc.Creds(insecure.NewCredentials()),
|
||||
)
|
||||
backendpb.RegisterRateLimitServiceServer(grpcSrv, srv)
|
||||
|
||||
go func() {
|
||||
pt := testutil.PanicT{}
|
||||
|
||||
srvErr := grpcSrv.Serve(ln)
|
||||
require.NoError(pt, srvErr)
|
||||
}()
|
||||
t.Cleanup(grpcSrv.GracefulStop)
|
||||
|
||||
allowlist := ratelimit.NewDynamicAllowlist(nil, nil)
|
||||
l, err := backendpb.NewRateLimiter(&backendpb.RateLimiterConfig{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
Metrics: consul.EmptyMetrics{},
|
||||
GRPCMetrics: backendpb.EmptyMetrics{},
|
||||
Allowlist: allowlist,
|
||||
Endpoint: &url.URL{
|
||||
Scheme: "grpc",
|
||||
Host: ln.Addr().String(),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
err = l.Refresh(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
ok, err := allowlist.IsAllowed(ctx, allowedIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
ok, err = allowlist.IsAllowed(ctx, notAllowedIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, ok)
|
||||
}
|
96
internal/backendpb/remotekv.go
Normal file
96
internal/backendpb/remotekv.go
Normal file
@ -0,0 +1,96 @@
|
||||
package backendpb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/remotekv"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
)
|
||||
|
||||
// RemoteKVConfig is the configuration for the business logic backend key-value
|
||||
// storage.
|
||||
type RemoteKVConfig struct {
|
||||
// Metrics is used for the collection of the remote key-value storage
|
||||
// statistics.
|
||||
//
|
||||
// TODO(e.burkov): Perhaps, it worths of a separate metrics interface,
|
||||
// since it's only used for the collection of the protobuf errors.
|
||||
Metrics Metrics
|
||||
|
||||
// Endpoint is the backend API URL. The scheme should be either "grpc" or
|
||||
// "grpcs".
|
||||
Endpoint *url.URL
|
||||
|
||||
// APIKey is the API key used for authentication, if any.
|
||||
APIKey string
|
||||
|
||||
// TTL is the TTL of the values in the storage.
|
||||
TTL time.Duration
|
||||
}
|
||||
|
||||
// RemoteKV is the implementation of the [remotekv.Interface] interface that
|
||||
// uses the business logic backend as the key-value storage. It is safe for
|
||||
// concurrent use.
|
||||
type RemoteKV struct {
|
||||
metrics Metrics
|
||||
client RemoteKVServiceClient
|
||||
apiKey string
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// NewRemoteKV returns a new [RemoteKV] that retrieves information from the
|
||||
// business logic backend.
|
||||
func NewRemoteKV(c *RemoteKVConfig) (kv *RemoteKV, err error) {
|
||||
client, err := newClient(c.Endpoint)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &RemoteKV{
|
||||
metrics: c.Metrics,
|
||||
client: NewRemoteKVServiceClient(client),
|
||||
apiKey: c.APIKey,
|
||||
ttl: c.TTL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ remotekv.Interface = (*RemoteKV)(nil)
|
||||
|
||||
// Get implements the [remotekv.Interface] interface for *RemoteKV.
|
||||
func (kv *RemoteKV) Get(ctx context.Context, key string) (val []byte, ok bool, err error) {
|
||||
req := &RemoteKVGetRequest{
|
||||
Key: key,
|
||||
}
|
||||
|
||||
ctx = ctxWithAuthentication(ctx, kv.apiKey)
|
||||
resp, err := kv.client.Get(ctx, req)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("getting %q key: %w", key, fixGRPCError(ctx, kv.metrics, err))
|
||||
}
|
||||
|
||||
val = resp.GetData()
|
||||
|
||||
return val, val != nil, nil
|
||||
}
|
||||
|
||||
// Set implements the [remotekv.Interface] interface for *RemoteKV.
|
||||
func (kv *RemoteKV) Set(ctx context.Context, key string, val []byte) (err error) {
|
||||
req := &RemoteKVSetRequest{
|
||||
Key: key,
|
||||
Data: val,
|
||||
Ttl: durationpb.New(kv.ttl),
|
||||
}
|
||||
|
||||
ctx = ctxWithAuthentication(ctx, kv.apiKey)
|
||||
_, err = kv.client.Set(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting %q key: %w", key, fixGRPCError(ctx, kv.metrics, err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
106
internal/backendpb/remotekv_test.go
Normal file
106
internal/backendpb/remotekv_test.go
Normal file
@ -0,0 +1,106 @@
|
||||
package backendpb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
func TestRemoteKV_Get(t *testing.T) {
|
||||
const testTTL = 10 * time.Second
|
||||
|
||||
pt := &testutil.PanicT{}
|
||||
|
||||
strg := map[string][]byte{}
|
||||
srv := &testRemoteKVServiceServer{
|
||||
OnGet: func(
|
||||
ctx context.Context,
|
||||
req *backendpb.RemoteKVGetRequest,
|
||||
) (resp *backendpb.RemoteKVGetResponse, err error) {
|
||||
resp = &backendpb.RemoteKVGetResponse{
|
||||
Value: &backendpb.RemoteKVGetResponse_Empty{},
|
||||
}
|
||||
|
||||
if val, ok := strg[req.Key]; ok {
|
||||
resp.Value = &backendpb.RemoteKVGetResponse_Data{Data: val}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
},
|
||||
|
||||
OnSet: func(
|
||||
ctx context.Context,
|
||||
req *backendpb.RemoteKVSetRequest,
|
||||
) (resp *backendpb.RemoteKVSetResponse, err error) {
|
||||
require.Equal(pt, testTTL, req.Ttl.AsDuration())
|
||||
|
||||
strg[req.Key] = req.Data
|
||||
|
||||
return &backendpb.RemoteKVSetResponse{}, nil
|
||||
},
|
||||
}
|
||||
|
||||
l, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
grpcSrv := grpc.NewServer(
|
||||
grpc.ConnectionTimeout(1*time.Second),
|
||||
grpc.Creds(insecure.NewCredentials()),
|
||||
)
|
||||
backendpb.RegisterRemoteKVServiceServer(grpcSrv, srv)
|
||||
|
||||
go func() {
|
||||
srvErr := grpcSrv.Serve(l)
|
||||
require.NoError(pt, srvErr)
|
||||
}()
|
||||
t.Cleanup(grpcSrv.GracefulStop)
|
||||
|
||||
kv, err := backendpb.NewRemoteKV(&backendpb.RemoteKVConfig{
|
||||
Metrics: backendpb.EmptyMetrics{},
|
||||
Endpoint: &url.URL{
|
||||
Scheme: "grpc",
|
||||
Host: l.Addr().String(),
|
||||
},
|
||||
APIKey: "apikey",
|
||||
TTL: testTTL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
const (
|
||||
keyWithData = "key"
|
||||
keyNoData = "unknown"
|
||||
)
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
val := []byte("value")
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
setErr := kv.Set(ctx, keyWithData, val)
|
||||
require.NoError(t, setErr)
|
||||
|
||||
gotVal, ok, getErr := kv.Get(ctx, keyWithData)
|
||||
require.NoError(t, getErr)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, val, gotVal)
|
||||
})
|
||||
|
||||
t.Run("not_found", func(t *testing.T) {
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
val, ok, getErr := kv.Get(ctx, keyNoData)
|
||||
require.NoError(t, getErr)
|
||||
require.False(t, ok)
|
||||
|
||||
assert.Nil(t, val)
|
||||
})
|
||||
}
|
@ -20,23 +20,19 @@ type connIndex struct {
|
||||
}
|
||||
|
||||
// subnetCompare is a comparison function for the two subnets. It returns -1 if
|
||||
// x sorts before y, 1 if x sorts after y, and 0 if their relative sorting
|
||||
// a sorts before b, 1 if a sorts after b, and 0 if their relative sorting
|
||||
// position is the same.
|
||||
func subnetCompare(x, y netip.Prefix) (cmp int) {
|
||||
if x == y {
|
||||
return 0
|
||||
}
|
||||
func subnetCompare(a, b netip.Prefix) (cmp int) {
|
||||
aAddr, aBits := a.Addr(), a.Bits()
|
||||
bAddr, bBits := b.Addr(), b.Bits()
|
||||
|
||||
xAddr, xBits := x.Addr(), x.Bits()
|
||||
yAddr, yBits := y.Addr(), y.Bits()
|
||||
if xBits == yBits {
|
||||
return xAddr.Compare(yAddr)
|
||||
}
|
||||
|
||||
if xBits > yBits {
|
||||
switch {
|
||||
case aBits > bBits:
|
||||
return -1
|
||||
} else {
|
||||
case aBits < bBits:
|
||||
return 1
|
||||
default:
|
||||
return aAddr.Compare(bAddr)
|
||||
}
|
||||
}
|
||||
|
||||
@ -45,8 +41,8 @@ func subnetCompare(x, y netip.Prefix) (cmp int) {
|
||||
//
|
||||
// TODO(a.garipov): Merge with [addListenerChannel].
|
||||
func (idx *connIndex) addPacketConn(c *chanPacketConn) (err error) {
|
||||
cmpFunc := func(x, y *chanPacketConn) (cmp int) {
|
||||
return subnetCompare(x.subnet, y.subnet)
|
||||
cmpFunc := func(a, b *chanPacketConn) (cmp int) {
|
||||
return subnetCompare(a.subnet, b.subnet)
|
||||
}
|
||||
|
||||
newIdx, ok := slices.BinarySearchFunc(idx.packetConns, c, cmpFunc)
|
||||
@ -67,8 +63,8 @@ func (idx *connIndex) addPacketConn(c *chanPacketConn) (err error) {
|
||||
//
|
||||
// TODO(a.garipov): Merge with [addPacketConnChannel].
|
||||
func (idx *connIndex) addListener(l *chanListener) (err error) {
|
||||
cmpFunc := func(x, y *chanListener) (cmp int) {
|
||||
return subnetCompare(x.subnet, y.subnet)
|
||||
cmpFunc := func(a, b *chanListener) (cmp int) {
|
||||
return subnetCompare(a.subnet, b.subnet)
|
||||
}
|
||||
|
||||
newIdx, ok := slices.BinarySearchFunc(idx.listeners, l, cmpFunc)
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
|
||||
@ -14,6 +13,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
)
|
||||
|
||||
@ -100,8 +100,9 @@ func newBillStatUploader(
|
||||
mtrc backendpb.Metrics,
|
||||
) (s billstat.Uploader, err error) {
|
||||
apiURL := netutil.CloneURL(&envs.BillStatURL.URL)
|
||||
if !agdhttp.CheckGRPCURLScheme(apiURL.Scheme) {
|
||||
return nil, fmt.Errorf("invalid backend api url: %s", apiURL)
|
||||
err = urlutil.ValidateGRPCURL(apiURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("billstat api url: %w", err)
|
||||
}
|
||||
|
||||
return backendpb.NewBillStat(&backendpb.BillStatConfig{
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"log/slog"
|
||||
"maps"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
@ -14,7 +15,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/access"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdcache"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdservice"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
|
||||
@ -38,9 +38,11 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/rulestat"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/websvc"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
"github.com/AdguardTeam/golibs/service"
|
||||
"github.com/c2h5oh/datasize"
|
||||
@ -112,6 +114,8 @@ type builder struct {
|
||||
ruleStat rulestat.Interface
|
||||
safeBrowsing *hashprefix.Filter
|
||||
safeBrowsingHashes *hashprefix.Storage
|
||||
sdeConf *dnsmsg.StructuredDNSErrorsConfig
|
||||
tlsMtrc tlsconfig.Metrics
|
||||
webSvc *websvc.Service
|
||||
|
||||
// The fields below are initialized later, just like with the fields above,
|
||||
@ -556,12 +560,42 @@ func (b *builder) initBindToDevice(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Constants for the experimental Structured DNS Errors feature.
|
||||
//
|
||||
// TODO(a.garipov): Make configurable.
|
||||
const (
|
||||
sdeJustification = "Filtered by AdGuard DNS"
|
||||
sdeOrganization = "AdGuard DNS"
|
||||
)
|
||||
|
||||
// Variables for the experimental Structured DNS Errors feature.
|
||||
//
|
||||
// TODO(a.garipov): Make configurable.
|
||||
var (
|
||||
sdeContactURL = &url.URL{
|
||||
Scheme: "mailto",
|
||||
Opaque: "support@adguard-dns.io",
|
||||
}
|
||||
)
|
||||
|
||||
// initMsgConstructor initializes the common DNS message constructor.
|
||||
func (b *builder) initMsgConstructor(ctx context.Context) (err error) {
|
||||
fltConf := b.conf.Filters
|
||||
b.sdeConf = &dnsmsg.StructuredDNSErrorsConfig{
|
||||
Contact: []*url.URL{
|
||||
sdeContactURL,
|
||||
},
|
||||
Justification: sdeJustification,
|
||||
Organization: sdeOrganization,
|
||||
Enabled: fltConf.SDEEnabled,
|
||||
}
|
||||
|
||||
b.messages, err = dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
|
||||
Cloner: b.cloner,
|
||||
BlockingMode: &dnsmsg.BlockingModeNullIP{},
|
||||
FilteredResponseTTL: b.conf.Filters.ResponseTTL.Duration,
|
||||
StructuredErrors: b.sdeConf,
|
||||
FilteredResponseTTL: fltConf.ResponseTTL.Duration,
|
||||
EDEEnabled: fltConf.EDEEnabled,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating dns message constructor: %w", err)
|
||||
@ -579,8 +613,17 @@ func (b *builder) initMsgConstructor(ctx context.Context) (err error) {
|
||||
// - [builder.initFilteringGroups]
|
||||
// - [builder.initMsgConstructor]
|
||||
func (b *builder) initServerGroups(ctx context.Context) (err error) {
|
||||
mtrc, err := metrics.NewTLSConfig(b.mtrcNamespace, b.promRegisterer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("registering tls metrics: %w", err)
|
||||
}
|
||||
|
||||
b.tlsMtrc = mtrc
|
||||
|
||||
c := b.conf
|
||||
b.serverGroups, err = c.ServerGroups.toInternal(
|
||||
ctx,
|
||||
mtrc,
|
||||
b.messages,
|
||||
b.btdManager,
|
||||
b.filteringGroups,
|
||||
@ -671,7 +714,7 @@ func (b *builder) initTLS(ctx context.Context) (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
tickRot := newTicketRotator(b.baseLogger, b.errColl, b.serverGroups)
|
||||
tickRot := newTicketRotator(b.baseLogger, b.errColl, b.tlsMtrc, b.serverGroups)
|
||||
err = tickRot.Refresh(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initial session ticket refresh: %w", err)
|
||||
@ -700,6 +743,29 @@ func (b *builder) initTLS(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// initGRPCMetrics initializes the gRPC metrics if necessary.
|
||||
func (b *builder) initGRPCMetrics(ctx context.Context) (err error) {
|
||||
switch {
|
||||
case
|
||||
b.profilesEnabled,
|
||||
b.conf.Check.RemoteKV.Type == kvModeBackend,
|
||||
b.conf.RateLimit.Allowlist.Type == rlAllowlistTypeBackend:
|
||||
// Go on.
|
||||
default:
|
||||
// Don't initialize the metrics if no protobuf backend is used.
|
||||
return nil
|
||||
}
|
||||
|
||||
b.backendGRPCMtrc, err = metrics.NewBackendPB(b.mtrcNamespace, b.promRegisterer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("registering backendbp metrics: %w", err)
|
||||
}
|
||||
|
||||
b.logger.DebugContext(ctx, "initialized grpc metrics")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initBillStat initializes the billing-statistics recorder if necessary. It
|
||||
// also adds the refresher with ID [debugIDBillStat] to the debug refreshers.
|
||||
func (b *builder) initBillStat(ctx context.Context) (err error) {
|
||||
@ -709,11 +775,6 @@ func (b *builder) initBillStat(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
b.backendGRPCMtrc, err = metrics.NewBackendPB(b.mtrcNamespace, b.promRegisterer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("registering backendbp metrics: %w", err)
|
||||
}
|
||||
|
||||
upl, err := newBillStatUploader(b.env, b.errColl, b.backendGRPCMtrc)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating billstat uploader: %w", err)
|
||||
@ -760,9 +821,9 @@ func (b *builder) initBillStat(ctx context.Context) (err error) {
|
||||
|
||||
// initProfileDB initializes the profile database if necessary.
|
||||
//
|
||||
// [builder.initBillStat] and [builder.initServerGroups] must be called before
|
||||
// this method. It also adds the refresher with ID [debugIDProfileDB] to the
|
||||
// debug refreshers.
|
||||
// [builder.initGRPCMetrics] and [builder.initServerGroups] must be called
|
||||
// before this method. It also adds the refresher with ID [debugIDProfileDB] to
|
||||
// the debug refreshers.
|
||||
func (b *builder) initProfileDB(ctx context.Context) (err error) {
|
||||
if !b.profilesEnabled {
|
||||
b.profileDB = &profiledb.Disabled{}
|
||||
@ -771,8 +832,9 @@ func (b *builder) initProfileDB(ctx context.Context) (err error) {
|
||||
}
|
||||
|
||||
apiURL := netutil.CloneURL(&b.env.ProfilesURL.URL)
|
||||
if !agdhttp.CheckGRPCURLScheme(apiURL.Scheme) {
|
||||
return fmt.Errorf("invalid backend api url: %s", apiURL)
|
||||
err = urlutil.ValidateGRPCURL(apiURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("profile api url: %w", err)
|
||||
}
|
||||
|
||||
respSzEst := b.conf.RateLimit.ResponseSizeEstimate
|
||||
@ -842,7 +904,8 @@ func (b *builder) initProfileDB(ctx context.Context) (err error) {
|
||||
|
||||
// initDNSCheck initializes the DNS checker.
|
||||
//
|
||||
// [builder.initMsgConstructor] must be called before this method.
|
||||
// [builder.initGRPCMetrics] and [builder.initMsgConstructor] must be called
|
||||
// before this method.
|
||||
func (b *builder) initDNSCheck(ctx context.Context) (err error) {
|
||||
b.dnsCheck = b.plugins.DNSCheck()
|
||||
if b.dnsCheck != nil {
|
||||
@ -853,7 +916,14 @@ func (b *builder) initDNSCheck(ctx context.Context) (err error) {
|
||||
|
||||
c := b.conf.Check
|
||||
|
||||
checkConf, err := c.toInternal(b.env, b.messages, b.errColl, b.mtrcNamespace, b.promRegisterer)
|
||||
checkConf, err := c.toInternal(
|
||||
b.env,
|
||||
b.messages,
|
||||
b.errColl,
|
||||
b.mtrcNamespace,
|
||||
b.promRegisterer,
|
||||
b.backendGRPCMtrc,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initializing dnscheck: %w", err)
|
||||
}
|
||||
@ -911,18 +981,44 @@ func (b *builder) initRuleStat(ctx context.Context) (err error) {
|
||||
// well as starts and registers the rate-limiter refresher in the signal
|
||||
// handler. It also adds the refresher with ID [debugIDAllowlist] to the debug
|
||||
// refreshers.
|
||||
//
|
||||
// [builder.initGRPCMetrics] must be called before this method.
|
||||
func (b *builder) initRateLimiter(ctx context.Context) (err error) {
|
||||
c := b.conf.RateLimit
|
||||
allowSubnets := netutil.UnembedPrefixes(c.Allowlist.List)
|
||||
allowlist := ratelimit.NewDynamicAllowlist(allowSubnets, nil)
|
||||
updater := consul.NewAllowlistUpdater(&consul.AllowlistUpdaterConfig{
|
||||
Logger: b.baseLogger.With(slogutil.KeyPrefix, "ratelimit_allowlist_updater"),
|
||||
Allowlist: allowlist,
|
||||
ConsulURL: &b.env.ConsulAllowlistURL.URL,
|
||||
ErrColl: b.errColl,
|
||||
// TODO(a.garipov): Make configurable.
|
||||
Timeout: 15 * time.Second,
|
||||
})
|
||||
|
||||
typ := b.conf.RateLimit.Allowlist.Type
|
||||
mtrc, err := metrics.NewAllowlist(b.mtrcNamespace, b.promRegisterer, typ)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ratelimit metrics: %w", err)
|
||||
}
|
||||
|
||||
var updater agdservice.Refresher
|
||||
if typ == rlAllowlistTypeBackend {
|
||||
updater, err = backendpb.NewRateLimiter(&backendpb.RateLimiterConfig{
|
||||
Logger: b.baseLogger.With(slogutil.KeyPrefix, "backend_ratelimiter"),
|
||||
Metrics: mtrc,
|
||||
GRPCMetrics: b.backendGRPCMtrc,
|
||||
Allowlist: allowlist,
|
||||
Endpoint: &b.env.BackendRateLimitURL.URL,
|
||||
ErrColl: b.errColl,
|
||||
APIKey: b.env.BackendRateLimitAPIKey,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("ratelimit: %w", err)
|
||||
}
|
||||
} else {
|
||||
updater = consul.NewAllowlistUpdater(&consul.AllowlistUpdaterConfig{
|
||||
Logger: b.baseLogger.With(slogutil.KeyPrefix, "ratelimit_allowlist_updater"),
|
||||
Allowlist: allowlist,
|
||||
ConsulURL: &b.env.ConsulAllowlistURL.URL,
|
||||
ErrColl: b.errColl,
|
||||
Metrics: mtrc,
|
||||
// TODO(a.garipov): Make configurable.
|
||||
Timeout: 15 * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
err = updater.Refresh(ctx)
|
||||
if err != nil {
|
||||
@ -956,9 +1052,11 @@ func (b *builder) initRateLimiter(ctx context.Context) (err error) {
|
||||
|
||||
// initWeb initializes the web service, starts it, and registers it in the
|
||||
// signal handler.
|
||||
//
|
||||
// [builder.initServerGroups] must be called before this method.
|
||||
func (b *builder) initWeb(ctx context.Context) (err error) {
|
||||
c := b.conf.Web
|
||||
webConf, err := c.toInternal(b.env, b.dnsCheck, b.errColl)
|
||||
webConf, err := c.toInternal(ctx, b.env, b.dnsCheck, b.errColl, b.tlsMtrc)
|
||||
if err != nil {
|
||||
return fmt.Errorf("converting web configuration: %w", err)
|
||||
}
|
||||
@ -1057,45 +1155,55 @@ func (b *builder) initDNS(ctx context.Context) (err error) {
|
||||
b.fwdHandler = forward.NewHandler(b.conf.Upstream.toInternal(b.baseLogger))
|
||||
b.dnsDB = b.conf.DNSDB.toInternal(b.errColl)
|
||||
|
||||
cacheConf := b.conf.Cache
|
||||
dnsConf := &dnssvc.Config{
|
||||
dnsHdlrsConf := &dnssvc.HandlersConfig{
|
||||
BaseLogger: b.baseLogger,
|
||||
Messages: b.messages,
|
||||
Cache: b.conf.Cache.toInternal(),
|
||||
Cloner: b.cloner,
|
||||
ControlConf: b.controlConf,
|
||||
ConnLimiter: b.connLimit,
|
||||
HumanIDParser: agd.NewHumanIDParser(),
|
||||
Messages: b.messages,
|
||||
PluginRegistry: b.plugins,
|
||||
StructuredErrors: b.sdeConf,
|
||||
AccessManager: b.access,
|
||||
SafeBrowsing: b.hashMatcher,
|
||||
BillStat: b.billStat,
|
||||
CacheManager: b.cacheManager,
|
||||
ProfileDB: b.profileDB,
|
||||
PrometheusRegisterer: b.promRegisterer,
|
||||
DNSCheck: b.dnsCheck,
|
||||
NonDNS: b.webSvc,
|
||||
DNSDB: b.dnsDB,
|
||||
ErrColl: b.errColl,
|
||||
FilterStorage: b.filterStorage,
|
||||
GeoIP: b.geoIP,
|
||||
Handler: b.fwdHandler,
|
||||
HashMatcher: b.hashMatcher,
|
||||
ProfileDB: b.profileDB,
|
||||
PrometheusRegisterer: b.promRegisterer,
|
||||
QueryLog: b.queryLog(),
|
||||
RuleStat: b.ruleStat,
|
||||
RateLimit: b.rateLimit,
|
||||
RuleStat: b.ruleStat,
|
||||
MetricsNamespace: b.mtrcNamespace,
|
||||
FilteringGroups: b.filteringGroups,
|
||||
ServerGroups: b.serverGroups,
|
||||
HandleTimeout: b.conf.DNS.HandleTimeout.Duration,
|
||||
CacheSize: cacheConf.Size,
|
||||
ECSCacheSize: cacheConf.ECSSize,
|
||||
CacheMinTTL: cacheConf.TTLOverride.Min.Duration,
|
||||
UseCacheTTLOverride: cacheConf.TTLOverride.Enabled,
|
||||
UseECSCache: cacheConf.Type == cacheTypeECS,
|
||||
EDEEnabled: b.conf.Filters.EDEEnabled,
|
||||
}
|
||||
|
||||
dnsHdlrs, err := dnssvc.NewHandlers(ctx, dnsHdlrsConf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dns handlers: %w", err)
|
||||
}
|
||||
|
||||
dnsConf := &dnssvc.Config{
|
||||
Handlers: dnsHdlrs,
|
||||
Cloner: b.cloner,
|
||||
ControlConf: b.controlConf,
|
||||
ConnLimiter: b.connLimit,
|
||||
NonDNS: b.webSvc,
|
||||
ErrColl: b.errColl,
|
||||
MetricsNamespace: b.mtrcNamespace,
|
||||
ServerGroups: b.serverGroups,
|
||||
HandleTimeout: b.conf.DNS.HandleTimeout.Duration,
|
||||
}
|
||||
|
||||
b.dnsSvc, err = dnssvc.New(dnsConf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initializing dns: %w", err)
|
||||
return fmt.Errorf("dns service: %w", err)
|
||||
}
|
||||
|
||||
b.logger.DebugContext(ctx, "initialized dns")
|
||||
|
@ -3,6 +3,7 @@ package cmd
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
)
|
||||
@ -43,6 +44,28 @@ const (
|
||||
cacheTypeSimple = "simple"
|
||||
)
|
||||
|
||||
// toInternal converts c to the cache configuration for the DNS server. c must
|
||||
// be valid.
|
||||
func (c *cacheConfig) toInternal() (cacheConf *dnssvc.CacheConfig) {
|
||||
var typ dnssvc.CacheType
|
||||
if c.Size == 0 {
|
||||
// TODO(a.garipov): Add as a type in the configuration file.
|
||||
typ = dnssvc.CacheTypeNone
|
||||
} else if c.Type == cacheTypeSimple {
|
||||
typ = dnssvc.CacheTypeSimple
|
||||
} else {
|
||||
typ = dnssvc.CacheTypeECS
|
||||
}
|
||||
|
||||
return &dnssvc.CacheConfig{
|
||||
MinTTL: c.TTLOverride.Min.Duration,
|
||||
ECSCount: c.ECSSize,
|
||||
NoECSCount: c.Size,
|
||||
Type: typ,
|
||||
OverrideCacheTTL: c.TTLOverride.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ validator = (*cacheConfig)(nil)
|
||||
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
|
||||
@ -54,8 +55,9 @@ func (c *checkConfig) toInternal(
|
||||
errColl errcoll.Interface,
|
||||
namespace string,
|
||||
reg prometheus.Registerer,
|
||||
backendMtrc backendpb.Metrics,
|
||||
) (conf *dnscheck.RemoteKVConfig, err error) {
|
||||
kv, err := newDNSCheckKV(c, envs, namespace, reg)
|
||||
kv, err := newRemoteKV(c.RemoteKV, envs, namespace, reg, backendMtrc)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
@ -85,21 +87,33 @@ const maxRespSize = 1 * datasize.MB
|
||||
// [remotekv.KeyNamespace].
|
||||
const keyNamespaceCheck = "check"
|
||||
|
||||
// newDNSCheckKV returns a new properly initialized remote key-value storage.
|
||||
func newDNSCheckKV(
|
||||
conf *checkConfig,
|
||||
// newRemoteKV returns a new properly initialized remote key-value storage.
|
||||
func newRemoteKV(
|
||||
c *remoteKVConfig,
|
||||
envs *environment,
|
||||
namespace string,
|
||||
reg prometheus.Registerer,
|
||||
backendMtrc backendpb.Metrics,
|
||||
) (kv remotekv.Interface, err error) {
|
||||
if conf.RemoteKV.Type == kvModeRedis {
|
||||
switch c.Type {
|
||||
case kvModeBackend:
|
||||
kv, err = backendpb.NewRemoteKV(&backendpb.RemoteKVConfig{
|
||||
Metrics: backendMtrc,
|
||||
Endpoint: &envs.DNSCheckRemoteKVURL.URL,
|
||||
APIKey: envs.DNSCheckRemoteKVAPIKey,
|
||||
TTL: c.TTL.Duration,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing backend dnscheck kv: %w", err)
|
||||
}
|
||||
case kvModeRedis:
|
||||
var redisKVMtrc rediskv.Metrics
|
||||
redisKVMtrc, err = metrics.NewRedisKV(namespace, reg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("registering redis kv metrics: %w", err)
|
||||
}
|
||||
|
||||
kv := rediskv.NewRedisKV(&rediskv.RedisKVConfig{
|
||||
redisKV := rediskv.NewRedisKV(&rediskv.RedisKVConfig{
|
||||
Metrics: redisKVMtrc,
|
||||
Addr: &netutil.HostPort{
|
||||
Host: envs.RedisAddr,
|
||||
@ -108,18 +122,20 @@ func newDNSCheckKV(
|
||||
MaxActive: envs.RedisMaxActive,
|
||||
MaxIdle: envs.RedisMaxIdle,
|
||||
IdleTimeout: envs.RedisIdleTimeout.Duration,
|
||||
TTL: conf.RemoteKV.TTL.Duration,
|
||||
TTL: c.TTL.Duration,
|
||||
})
|
||||
|
||||
return remotekv.NewKeyNamespace(&remotekv.KeyNamespaceConfig{
|
||||
KV: kv,
|
||||
kv = remotekv.NewKeyNamespace(&remotekv.KeyNamespaceConfig{
|
||||
KV: redisKV,
|
||||
Prefix: fmt.Sprintf("%s:%s:", envs.RedisKeyPrefix, keyNamespaceCheck),
|
||||
}), nil
|
||||
}
|
||||
})
|
||||
case kvModeConsul:
|
||||
consulKVURL := envs.ConsulDNSCheckKVURL
|
||||
consulSessionURL := envs.ConsulDNSCheckSessionURL
|
||||
if consulKVURL == nil || consulSessionURL == nil {
|
||||
return remotekv.Empty{}, nil
|
||||
}
|
||||
|
||||
consulKVURL := envs.ConsulDNSCheckKVURL
|
||||
consulSessionURL := envs.ConsulDNSCheckSessionURL
|
||||
if consulKVURL != nil && consulSessionURL != nil {
|
||||
kv, err = consulkv.NewKV(&consulkv.Config{
|
||||
URL: &consulKVURL.URL,
|
||||
SessionURL: &consulSessionURL.URL,
|
||||
@ -129,14 +145,14 @@ func newDNSCheckKV(
|
||||
}),
|
||||
// TODO(ameshkov): Consider making configurable.
|
||||
Limiter: rate.NewLimiter(rate.Limit(200)/60, 1),
|
||||
TTL: conf.RemoteKV.TTL.Duration,
|
||||
TTL: c.TTL.Duration,
|
||||
MaxRespSize: maxRespSize,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing consul dnscheck: %w", err)
|
||||
return nil, fmt.Errorf("initializing consul dnscheck kv: %w", err)
|
||||
}
|
||||
} else {
|
||||
kv = remotekv.Empty{}
|
||||
default:
|
||||
return remotekv.Empty{}, nil
|
||||
}
|
||||
|
||||
return kv, nil
|
||||
@ -221,15 +237,16 @@ func validateNonNilIPs(ips []netip.Addr, fam netutil.AddrFamily) (err error) {
|
||||
|
||||
// DNSCheck key-value database modes.
|
||||
const (
|
||||
kvModeConsul = "consul"
|
||||
kvModeRedis = "redis"
|
||||
kvModeBackend = "backend"
|
||||
kvModeConsul = "consul"
|
||||
kvModeRedis = "redis"
|
||||
)
|
||||
|
||||
// remoteKVConfig is remote key-value store configuration for DNS server
|
||||
// checking.
|
||||
type remoteKVConfig struct {
|
||||
// Type defines the type of remote key-value store. Allowed values are
|
||||
// [kvModeConsul] and [kvModeRedis].
|
||||
// [kvModeBackend], [kvModeConsul] and [kvModeRedis].
|
||||
Type string `yaml:"type"`
|
||||
|
||||
// TTL defines, for how long to keep the information about a single client.
|
||||
@ -248,6 +265,10 @@ func (c *remoteKVConfig) validate() (err error) {
|
||||
ttl := c.TTL
|
||||
|
||||
switch c.Type {
|
||||
case kvModeBackend:
|
||||
if ttl.Duration <= 0 {
|
||||
return newNotPositiveError("ttl", ttl)
|
||||
}
|
||||
case kvModeConsul:
|
||||
if ttl.Duration < consulkv.MinTTL || ttl.Duration > consulkv.MaxTTL {
|
||||
return fmt.Errorf(
|
||||
@ -270,7 +291,7 @@ func (c *remoteKVConfig) validate() (err error) {
|
||||
case "":
|
||||
return fmt.Errorf("type: %w", errors.ErrEmptyValue)
|
||||
default:
|
||||
return fmt.Errorf("type: %q: %w", c.Type, errors.ErrBadEnumValue)
|
||||
return fmt.Errorf("type: %w: %q", errors.ErrBadEnumValue, c.Type)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -100,6 +100,8 @@ func Main(plugins *plugin.Registry) {
|
||||
|
||||
errors.Check(b.initTLS(ctx))
|
||||
|
||||
errors.Check(b.initGRPCMetrics(ctx))
|
||||
|
||||
errors.Check(b.initBillStat(ctx))
|
||||
|
||||
errors.Check(b.initProfileDB(ctx))
|
||||
|
@ -6,9 +6,10 @@ import (
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/debugsvc"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
|
||||
@ -25,13 +26,17 @@ import (
|
||||
)
|
||||
|
||||
// environment represents the configuration that is kept in the environment.
|
||||
//
|
||||
// TODO(e.burkov, a.garipov): Name variables more consistently.
|
||||
type environment struct {
|
||||
AdultBlockingURL *urlutil.URL `env:"ADULT_BLOCKING_URL"`
|
||||
BackendRateLimitURL *urlutil.URL `env:"BACKEND_RATELIMIT_URL"`
|
||||
BillStatURL *urlutil.URL `env:"BILLSTAT_URL"`
|
||||
BlockedServiceIndexURL *urlutil.URL `env:"BLOCKED_SERVICE_INDEX_URL"`
|
||||
ConsulAllowlistURL *urlutil.URL `env:"CONSUL_ALLOWLIST_URL,notEmpty"`
|
||||
ConsulAllowlistURL *urlutil.URL `env:"CONSUL_ALLOWLIST_URL"`
|
||||
ConsulDNSCheckKVURL *urlutil.URL `env:"CONSUL_DNSCHECK_KV_URL"`
|
||||
ConsulDNSCheckSessionURL *urlutil.URL `env:"CONSUL_DNSCHECK_SESSION_URL"`
|
||||
DNSCheckRemoteKVURL *urlutil.URL `env:"DNSCHECK_REMOTEKV_URL"`
|
||||
FilterIndexURL *urlutil.URL `env:"FILTER_INDEX_URL,notEmpty"`
|
||||
GeneralSafeSearchURL *urlutil.URL `env:"GENERAL_SAFE_SEARCH_URL"`
|
||||
LinkedIPTargetURL *urlutil.URL `env:"LINKED_IP_TARGET_URL"`
|
||||
@ -41,23 +46,25 @@ type environment struct {
|
||||
SafeBrowsingURL *urlutil.URL `env:"SAFE_BROWSING_URL"`
|
||||
YoutubeSafeSearchURL *urlutil.URL `env:"YOUTUBE_SAFE_SEARCH_URL"`
|
||||
|
||||
BillStatAPIKey string `env:"BILLSTAT_API_KEY"`
|
||||
ConfPath string `env:"CONFIG_PATH" envDefault:"./config.yaml"`
|
||||
FilterCachePath string `env:"FILTER_CACHE_PATH" envDefault:"./filters/"`
|
||||
GeoIPASNPath string `env:"GEOIP_ASN_PATH" envDefault:"./asn.mmdb"`
|
||||
GeoIPCountryPath string `env:"GEOIP_COUNTRY_PATH" envDefault:"./country.mmdb"`
|
||||
ProfilesAPIKey string `env:"PROFILES_API_KEY"`
|
||||
ProfilesCachePath string `env:"PROFILES_CACHE_PATH" envDefault:"./profilecache.pb"`
|
||||
RedisAddr string `env:"REDIS_ADDR"`
|
||||
RedisKeyPrefix string `env:"REDIS_KEY_PREFIX" envDefault:"agdns"`
|
||||
QueryLogPath string `env:"QUERYLOG_PATH" envDefault:"./querylog.jsonl"`
|
||||
SSLKeyLogFile string `env:"SSL_KEY_LOG_FILE"`
|
||||
SentryDSN string `env:"SENTRY_DSN" envDefault:"stderr"`
|
||||
WebStaticDir string `env:"WEB_STATIC_DIR"`
|
||||
BackendRateLimitAPIKey string `env:"BACKEND_RATELIMIT_API_KEY"`
|
||||
BillStatAPIKey string `env:"BILLSTAT_API_KEY"`
|
||||
ConfPath string `env:"CONFIG_PATH" envDefault:"./config.yaml"`
|
||||
DNSCheckRemoteKVAPIKey string `env:"DNSCHECK_REMOTEKV_API_KEY"`
|
||||
FilterCachePath string `env:"FILTER_CACHE_PATH" envDefault:"./filters/"`
|
||||
GeoIPASNPath string `env:"GEOIP_ASN_PATH" envDefault:"./asn.mmdb"`
|
||||
GeoIPCountryPath string `env:"GEOIP_COUNTRY_PATH" envDefault:"./country.mmdb"`
|
||||
ProfilesAPIKey string `env:"PROFILES_API_KEY"`
|
||||
ProfilesCachePath string `env:"PROFILES_CACHE_PATH" envDefault:"./profilecache.pb"`
|
||||
RedisAddr string `env:"REDIS_ADDR"`
|
||||
RedisKeyPrefix string `env:"REDIS_KEY_PREFIX" envDefault:"agdns"`
|
||||
QueryLogPath string `env:"QUERYLOG_PATH" envDefault:"./querylog.jsonl"`
|
||||
SSLKeyLogFile string `env:"SSL_KEY_LOG_FILE"`
|
||||
SentryDSN string `env:"SENTRY_DSN" envDefault:"stderr"`
|
||||
WebStaticDir string `env:"WEB_STATIC_DIR"`
|
||||
|
||||
ListenAddr net.IP `env:"LISTEN_ADDR" envDefault:"127.0.0.1"`
|
||||
|
||||
ProfilesMaxRespSize datasize.ByteSize `env:"PROFILES_MAX_RESP_SIZE" envDefault:"8MB"`
|
||||
ProfilesMaxRespSize datasize.ByteSize `env:"PROFILES_MAX_RESP_SIZE" envDefault:"64MB"`
|
||||
|
||||
RedisIdleTimeout timeutil.Duration `env:"REDIS_IDLE_TIMEOUT" envDefault:"30s"`
|
||||
|
||||
@ -100,7 +107,8 @@ func (envs *environment) validate() (err error) {
|
||||
|
||||
errs = envs.validateHTTPURLs(errs)
|
||||
|
||||
if s := envs.FilterIndexURL.Scheme; s != agdhttp.SchemeFile && !agdhttp.CheckHTTPURLScheme(s) {
|
||||
if s := envs.FilterIndexURL.Scheme; !strings.EqualFold(s, urlutil.SchemeFile) &&
|
||||
!urlutil.IsValidHTTPURLScheme(s) {
|
||||
errs = append(errs, fmt.Errorf(
|
||||
"env %s: not a valid http(s) url or file uri",
|
||||
"FILTER_INDEX_URL",
|
||||
@ -140,10 +148,6 @@ func (envs *environment) validateHTTPURLs(errs []error) (res []error) {
|
||||
url: envs.BlockedServiceIndexURL,
|
||||
name: "BLOCKED_SERVICE_INDEX_URL",
|
||||
isRequired: bool(envs.BlockedServiceEnabled),
|
||||
}, {
|
||||
url: envs.ConsulAllowlistURL,
|
||||
name: "CONSUL_ALLOWLIST_URL",
|
||||
isRequired: true,
|
||||
}, {
|
||||
url: envs.ConsulDNSCheckKVURL,
|
||||
name: "CONSUL_DNSCHECK_KV_URL",
|
||||
@ -184,15 +188,14 @@ func (envs *environment) validateHTTPURLs(errs []error) (res []error) {
|
||||
continue
|
||||
}
|
||||
|
||||
u := urlData.url
|
||||
if u == nil {
|
||||
res = append(res, fmt.Errorf("env %s: %w", urlData.name, errors.ErrEmptyValue))
|
||||
|
||||
continue
|
||||
var u *url.URL
|
||||
if urlData.url != nil {
|
||||
u = &urlData.url.URL
|
||||
}
|
||||
|
||||
if !agdhttp.CheckHTTPURLScheme(u.Scheme) {
|
||||
res = append(res, fmt.Errorf("env %s: not a valid http(s) url", urlData.name))
|
||||
err := urlutil.ValidateHTTPURL(u)
|
||||
if err != nil {
|
||||
res = append(res, fmt.Errorf("env %s: %w", urlData.name, err))
|
||||
}
|
||||
}
|
||||
|
||||
@ -227,59 +230,86 @@ func (envs *environment) validateWebStaticDir() (err error) {
|
||||
// validateFromValidConfig returns an error if environment variables that depend
|
||||
// on configuration properties contain errors. conf is expected to be valid.
|
||||
func (envs *environment) validateFromValidConfig(conf *configuration) (err error) {
|
||||
err = envs.validateRedis(conf)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
if !conf.isProfilesEnabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if envs.ProfilesMaxRespSize > math.MaxInt {
|
||||
return fmt.Errorf(
|
||||
"PROFILES_MAX_RESP_SIZE: %w: must be less than or equal to %s, got %s",
|
||||
errors.ErrOutOfRange,
|
||||
datasize.ByteSize(math.MaxInt),
|
||||
envs.ProfilesMaxRespSize,
|
||||
)
|
||||
}
|
||||
|
||||
return envs.validateProfilesURLs()
|
||||
}
|
||||
|
||||
// validateRedis returns an error if environment variables for Redis as a remote
|
||||
// key-value store for DNS server checking contain errors.
|
||||
func (envs *environment) validateRedis(conf *configuration) (err error) {
|
||||
if conf.Check.RemoteKV.Type != kvModeRedis {
|
||||
return nil
|
||||
}
|
||||
|
||||
var errs []error
|
||||
if envs.RedisAddr == "" {
|
||||
errs = append(errs, fmt.Errorf("REDIS_ADDR: %q", errors.ErrEmptyValue))
|
||||
|
||||
switch typ := conf.Check.RemoteKV.Type; typ {
|
||||
case kvModeRedis:
|
||||
errs = envs.validateRedis(errs)
|
||||
case kvModeBackend:
|
||||
errs = envs.validateBackendKV(errs)
|
||||
default:
|
||||
// Probably consul.
|
||||
}
|
||||
|
||||
if envs.RedisIdleTimeout.Duration <= 0 {
|
||||
errs = append(errs, newNotPositiveError("REDIS_IDLE_TIMEOUT", envs.RedisIdleTimeout))
|
||||
if conf.isProfilesEnabled() {
|
||||
errs = envs.validateProfilesURLs(errs)
|
||||
|
||||
if envs.ProfilesMaxRespSize > math.MaxInt {
|
||||
errs = append(errs, fmt.Errorf(
|
||||
"PROFILES_MAX_RESP_SIZE: %w: must be less than or equal to %s, got %s",
|
||||
errors.ErrOutOfRange,
|
||||
datasize.ByteSize(math.MaxInt),
|
||||
envs.ProfilesMaxRespSize,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
if envs.RedisMaxActive < 0 {
|
||||
errs = append(errs, newNegativeError("REDIS_MAX_ACTIVE", envs.RedisMaxActive))
|
||||
}
|
||||
|
||||
if envs.RedisMaxIdle < 0 {
|
||||
errs = append(errs, newNegativeError("REDIS_MAX_IDLE", envs.RedisMaxIdle))
|
||||
}
|
||||
errs = envs.validateRateLimitURLs(conf, errs)
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// validateRedis appends validation errors to the given errs if environment
|
||||
// variables for Redis contain errors.
|
||||
func (envs *environment) validateRedis(errs []error) (withRedis []error) {
|
||||
withRedis = errs
|
||||
|
||||
if envs.RedisAddr == "" {
|
||||
err := fmt.Errorf("REDIS_ADDR: %q", errors.ErrEmptyValue)
|
||||
withRedis = append(withRedis, err)
|
||||
}
|
||||
|
||||
if envs.RedisIdleTimeout.Duration <= 0 {
|
||||
err := newNotPositiveError("REDIS_IDLE_TIMEOUT", envs.RedisIdleTimeout)
|
||||
withRedis = append(withRedis, err)
|
||||
}
|
||||
|
||||
if envs.RedisMaxActive < 0 {
|
||||
err := newNegativeError("REDIS_MAX_ACTIVE", envs.RedisMaxActive)
|
||||
withRedis = append(withRedis, err)
|
||||
}
|
||||
|
||||
if envs.RedisMaxIdle < 0 {
|
||||
err := newNegativeError("REDIS_MAX_IDLE", envs.RedisMaxIdle)
|
||||
withRedis = append(withRedis, err)
|
||||
}
|
||||
|
||||
return withRedis
|
||||
}
|
||||
|
||||
// validateBackendKV appends validation errors to the given errs if environment
|
||||
// variables for a backend key-value store contain errors.
|
||||
func (envs *environment) validateBackendKV(errs []error) (withKV []error) {
|
||||
withKV = errs
|
||||
|
||||
var u *url.URL
|
||||
if envs.DNSCheckRemoteKVURL != nil {
|
||||
u = &envs.DNSCheckRemoteKVURL.URL
|
||||
}
|
||||
|
||||
err := urlutil.ValidateGRPCURL(u)
|
||||
if err != nil {
|
||||
withKV = append(withKV, fmt.Errorf("env DNSCHECK_REMOTEKV_URL: %w", err))
|
||||
}
|
||||
|
||||
return withKV
|
||||
}
|
||||
|
||||
// validateProfilesURLs appends validation errors to the given errs if profiles
|
||||
// URLs in environment variables are invalid. All errors are appended to errs
|
||||
// and returned as res.
|
||||
func (envs *environment) validateProfilesURLs() (err error) {
|
||||
// URLs in environment variables are invalid.
|
||||
func (envs *environment) validateProfilesURLs(errs []error) (withURLs []error) {
|
||||
withURLs = errs
|
||||
|
||||
grpcOnlyURLs := []*urlEnvData{{
|
||||
url: envs.BillStatURL,
|
||||
name: "BILLSTAT_URL",
|
||||
@ -290,24 +320,52 @@ func (envs *environment) validateProfilesURLs() (err error) {
|
||||
isRequired: true,
|
||||
}}
|
||||
|
||||
var res []error
|
||||
for _, urlData := range grpcOnlyURLs {
|
||||
if !urlData.isRequired {
|
||||
continue
|
||||
}
|
||||
|
||||
if urlData.url == nil {
|
||||
res = append(res, fmt.Errorf("env %s: %w", urlData.name, errors.ErrEmptyValue))
|
||||
|
||||
continue
|
||||
var u *url.URL
|
||||
if urlData.url != nil {
|
||||
u = &urlData.url.URL
|
||||
}
|
||||
|
||||
if !agdhttp.CheckGRPCURLScheme(urlData.url.Scheme) {
|
||||
res = append(res, fmt.Errorf("env %s: not a valid grpc(s) url", urlData.name))
|
||||
err := urlutil.ValidateGRPCURL(u)
|
||||
if err != nil {
|
||||
withURLs = append(withURLs, fmt.Errorf("env %s: %w", urlData.name, err))
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Join(res...)
|
||||
return withURLs
|
||||
}
|
||||
|
||||
// validateRateLimitURLs appends validation errors to the given errs if rate
|
||||
// limit URLs in environment variables are invalid.
|
||||
func (envs *environment) validateRateLimitURLs(
|
||||
conf *configuration,
|
||||
errs []error,
|
||||
) (withURLs []error) {
|
||||
rlURL := envs.BackendRateLimitURL
|
||||
rlEnv := "BACKEND_RATELIMIT_URL"
|
||||
validateFunc := urlutil.ValidateGRPCURL
|
||||
|
||||
if conf.RateLimit.Allowlist.Type == rlAllowlistTypeConsul {
|
||||
rlURL = envs.ConsulAllowlistURL
|
||||
rlEnv = "CONSUL_ALLOWLIST_URL"
|
||||
validateFunc = urlutil.ValidateHTTPURL
|
||||
}
|
||||
|
||||
var u *url.URL
|
||||
if rlURL != nil {
|
||||
u = &rlURL.URL
|
||||
}
|
||||
|
||||
err := validateFunc(u)
|
||||
if err != nil {
|
||||
return append(errs, fmt.Errorf("env %s: %w", rlEnv, err))
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
// configureLogs sets the configuration for the plain text logs. It also
|
||||
|
@ -58,6 +58,12 @@ type filtersConfig struct {
|
||||
|
||||
// MaxSize is the maximum size of the downloadable filtering rule-list.
|
||||
MaxSize datasize.ByteSize `yaml:"max_size"`
|
||||
|
||||
// EDEEnabled enables the Extended DNS Errors feature.
|
||||
EDEEnabled bool `yaml:"ede_enabled"`
|
||||
|
||||
// SDEEnabled enables the experimental Structured DNS Errors feature.
|
||||
SDEEnabled bool `yaml:"sde_enabled"`
|
||||
}
|
||||
|
||||
// toInternal converts c to the filter storage configuration for the DNS server.
|
||||
@ -117,33 +123,31 @@ var _ validator = (*filtersConfig)(nil)
|
||||
|
||||
// validate implements the [validator] interface for *filtersConfig.
|
||||
func (c *filtersConfig) validate() (err error) {
|
||||
switch {
|
||||
case c == nil:
|
||||
if c == nil {
|
||||
return errors.ErrNoValue
|
||||
case c.SafeSearchCacheSize <= 0:
|
||||
return newNotPositiveError("safe_search_cache_size", c.SafeSearchCacheSize)
|
||||
case c.ResponseTTL.Duration <= 0:
|
||||
return newNotPositiveError("response_ttl", c.ResponseTTL)
|
||||
case c.RefreshIvl.Duration <= 0:
|
||||
return newNotPositiveError("refresh_interval", c.RefreshIvl)
|
||||
case c.RefreshTimeout.Duration <= 0:
|
||||
return newNotPositiveError("refresh_timeout", c.RefreshTimeout)
|
||||
case c.IndexRefreshTimeout.Duration <= 0:
|
||||
return newNotPositiveError("index_refresh_timeout", c.IndexRefreshTimeout)
|
||||
case c.RuleListRefreshTimeout.Duration <= 0:
|
||||
return newNotPositiveError("rule_list_refresh_timeout", c.RuleListRefreshTimeout)
|
||||
case c.MaxSize <= 0:
|
||||
return newNotPositiveError("max_size", c.MaxSize)
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
errs := []error{
|
||||
validatePositive("custom_filter_cache_size", c.CustomFilterCacheSize),
|
||||
validatePositive("safe_search_cache_size", c.SafeSearchCacheSize),
|
||||
validatePositive("response_ttl", c.ResponseTTL),
|
||||
validatePositive("refresh_interval", c.RefreshIvl),
|
||||
validatePositive("refresh_timeout", c.RefreshTimeout),
|
||||
validatePositive("index_refresh_timeout", c.IndexRefreshTimeout),
|
||||
validatePositive("rule_list_refresh_timeout", c.RuleListRefreshTimeout),
|
||||
validatePositive("max_size", c.MaxSize),
|
||||
}
|
||||
|
||||
if !c.EDEEnabled && c.SDEEnabled {
|
||||
errs = append(errs, errors.Error("ede must be enabled to enable sde"))
|
||||
}
|
||||
|
||||
err = c.RuleListCache.validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("rule_list_cache: %w", err)
|
||||
errs = append(errs, fmt.Errorf("rule_list_cache: %w", err))
|
||||
}
|
||||
|
||||
return nil
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// fltRuleListCache contains filtering rule-list cache configuration.
|
||||
|
@ -4,6 +4,7 @@ package plugin
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
||||
)
|
||||
|
||||
@ -13,16 +14,19 @@ import (
|
||||
type Registry struct {
|
||||
dnscheck dnscheck.Interface
|
||||
mainMwMtrc metrics.MainMiddleware
|
||||
postInitMw dnsserver.Middleware
|
||||
}
|
||||
|
||||
// NewRegistry returns a new registry with the given custom implementations.
|
||||
func NewRegistry(
|
||||
dnsCk dnscheck.Interface,
|
||||
mainMwMtrc metrics.MainMiddleware,
|
||||
postInitMw dnsserver.Middleware,
|
||||
) (r *Registry) {
|
||||
return &Registry{
|
||||
dnscheck: dnsCk,
|
||||
mainMwMtrc: mainMwMtrc,
|
||||
postInitMw: postInitMw,
|
||||
}
|
||||
}
|
||||
|
||||
@ -44,3 +48,13 @@ func (r *Registry) MainMiddlewareMetrics() (mainMwMtrc metrics.MainMiddleware) {
|
||||
|
||||
return r.mainMwMtrc
|
||||
}
|
||||
|
||||
// PostInitialMiddleware returns a custom implementation of the post-initial
|
||||
// middleware, if any.
|
||||
func (r *Registry) PostInitialMiddleware() (postInitMw dnsserver.Middleware) {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.postInitMw
|
||||
}
|
||||
|
@ -15,6 +15,12 @@ import (
|
||||
"github.com/c2h5oh/datasize"
|
||||
)
|
||||
|
||||
// Constants for rate limit settings endpoints.
|
||||
const (
|
||||
rlAllowlistTypeBackend = "backend"
|
||||
rlAllowlistTypeConsul = "consul"
|
||||
)
|
||||
|
||||
// rateLimitConfig is the configuration of the instance's rate limiting.
|
||||
type rateLimitConfig struct {
|
||||
// AllowList is the allowlist of clients.
|
||||
@ -57,20 +63,14 @@ type rateLimitConfig struct {
|
||||
RefuseANY bool `yaml:"refuse_any"`
|
||||
}
|
||||
|
||||
// allowListConfig is the consul allow list configuration.
|
||||
type allowListConfig struct {
|
||||
// List contains IPs and CIDRs.
|
||||
List []netutil.Prefix `yaml:"list"`
|
||||
|
||||
// RefreshIvl time between two updates of allow list from the Consul URL.
|
||||
RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
|
||||
}
|
||||
|
||||
// rateLimitOptions allows define maximum number of requests for IPv4 or IPv6
|
||||
// addresses.
|
||||
type rateLimitOptions struct {
|
||||
// RPS is the maximum number of requests per second.
|
||||
RPS uint `yaml:"rps"`
|
||||
// Count is the maximum number of requests per interval.
|
||||
Count uint `yaml:"count"`
|
||||
|
||||
// Interval is the time during which to count the number of requests.
|
||||
Interval timeutil.Duration `yaml:"interval"`
|
||||
|
||||
// SubnetKeyLen is the length of the subnet prefix used to calculate
|
||||
// rate limiter bucket keys.
|
||||
@ -87,7 +87,8 @@ func (o *rateLimitOptions) validate() (err error) {
|
||||
}
|
||||
|
||||
return cmp.Or(
|
||||
validatePositive("rps", o.RPS),
|
||||
validatePositive("count", o.Count),
|
||||
validatePositive("interval", o.Interval),
|
||||
validatePositive("subnet_key_len", o.SubnetKeyLen),
|
||||
)
|
||||
}
|
||||
@ -100,9 +101,11 @@ func (c *rateLimitConfig) toInternal(al ratelimit.Allowlist) (conf *ratelimit.Ba
|
||||
ResponseSizeEstimate: c.ResponseSizeEstimate,
|
||||
Duration: c.BackoffDuration.Duration,
|
||||
Period: c.BackoffPeriod.Duration,
|
||||
IPv4RPS: c.IPv4.RPS,
|
||||
IPv4Count: c.IPv4.Count,
|
||||
IPv4Interval: c.IPv4.Interval.Duration,
|
||||
IPv4SubnetKeyLen: c.IPv4.SubnetKeyLen,
|
||||
IPv6RPS: c.IPv6.RPS,
|
||||
IPv6Count: c.IPv6.Count,
|
||||
IPv6Interval: c.IPv6.Interval.Duration,
|
||||
IPv6SubnetKeyLen: c.IPv6.SubnetKeyLen,
|
||||
Count: c.BackoffCount,
|
||||
RefuseANY: c.RefuseANY,
|
||||
@ -114,14 +117,12 @@ var _ validator = (*rateLimitConfig)(nil)
|
||||
|
||||
// validate implements the [validator] interface for *rateLimitConfig.
|
||||
func (c *rateLimitConfig) validate() (err error) {
|
||||
switch {
|
||||
case c == nil:
|
||||
if c == nil {
|
||||
return errors.ErrNoValue
|
||||
case c.Allowlist == nil:
|
||||
return fmt.Errorf("allowlist: %w", errors.ErrNoValue)
|
||||
}
|
||||
|
||||
return cmp.Or(
|
||||
validateProp("allowlist", c.Allowlist.validate),
|
||||
validateProp("connection_limit", c.ConnectionLimit.validate),
|
||||
validateProp("ipv4", c.IPv4.validate),
|
||||
validateProp("ipv6", c.IPv6.validate),
|
||||
@ -131,10 +132,41 @@ func (c *rateLimitConfig) validate() (err error) {
|
||||
validatePositive("backoff_duration", c.BackoffDuration),
|
||||
validatePositive("backoff_period", c.BackoffPeriod),
|
||||
validatePositive("response_size_estimate", c.ResponseSizeEstimate),
|
||||
validatePositive("allowlist.refresh_interval", c.Allowlist.RefreshIvl),
|
||||
)
|
||||
}
|
||||
|
||||
// allowListConfig is the consul allow list configuration.
|
||||
type allowListConfig struct {
|
||||
// Type defines where the rate limit settings are received from. Allowed
|
||||
// values are [rlAllowlistTypeBackend] and [rlAllowlistTypeConsul].
|
||||
Type string `yaml:"type"`
|
||||
|
||||
// List contains IPs and CIDRs.
|
||||
List []netutil.Prefix `yaml:"list"`
|
||||
|
||||
// RefreshIvl time between two updates of allow list from the Consul URL.
|
||||
RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ validator = (*allowListConfig)(nil)
|
||||
|
||||
// validate implements the [validator] interface for *allowListConfig.
|
||||
func (c *allowListConfig) validate() (err error) {
|
||||
if c == nil {
|
||||
return errors.ErrNoValue
|
||||
}
|
||||
|
||||
switch c.Type {
|
||||
case rlAllowlistTypeBackend, rlAllowlistTypeConsul:
|
||||
// Go on.
|
||||
default:
|
||||
return fmt.Errorf("type: %w: %q", errors.ErrBadEnumValue, c.Type)
|
||||
}
|
||||
|
||||
return validatePositive("refresh_interval", c.RefreshIvl)
|
||||
}
|
||||
|
||||
// connLimitConfig is the configuration structure for the stream-connection
|
||||
// limiter.
|
||||
type connLimitConfig struct {
|
||||
|
@ -6,7 +6,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
)
|
||||
@ -14,6 +14,7 @@ import (
|
||||
// toInternal returns the configuration of DNS servers for a single server
|
||||
// group. srvs and other parts of the configuration must be valid.
|
||||
func (srvs servers) toInternal(
|
||||
mtrc tlsconfig.Metrics,
|
||||
tlsConfig *agd.TLS,
|
||||
btdMgr *bindtodevice.Manager,
|
||||
ratelimitConf *rateLimitConfig,
|
||||
@ -68,8 +69,8 @@ func (srvs servers) toInternal(
|
||||
tlsConf := tlsConfig.Conf.Clone()
|
||||
|
||||
// Attach the functions that will count TLS handshake metrics.
|
||||
tlsConf.GetConfigForClient = metrics.TLSMetricsBeforeHandshake(string(srv.Protocol))
|
||||
tlsConf.VerifyConnection = metrics.TLSMetricsAfterHandshake(
|
||||
tlsConf.GetConfigForClient = mtrc.BeforeHandshake(string(srv.Protocol))
|
||||
tlsConf.VerifyConnection = mtrc.AfterHandshake(
|
||||
string(srv.Protocol),
|
||||
srv.Name,
|
||||
tlsConfig.DeviceDomains,
|
||||
|
@ -1,11 +1,13 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
)
|
||||
@ -17,6 +19,8 @@ type serverGroups []*serverGroup
|
||||
// toInternal returns the configuration for all server groups in the DNS
|
||||
// service. srvGrps and other parts of the configuration must be valid.
|
||||
func (srvGrps serverGroups) toInternal(
|
||||
ctx context.Context,
|
||||
mtrc tlsconfig.Metrics,
|
||||
messages *dnsmsg.Constructor,
|
||||
btdMgr *bindtodevice.Manager,
|
||||
fltGrps map[agd.FilteringGroupID]*agd.FilteringGroup,
|
||||
@ -32,7 +36,7 @@ func (srvGrps serverGroups) toInternal(
|
||||
}
|
||||
|
||||
var tlsConf *agd.TLS
|
||||
tlsConf, err = g.TLS.toInternal()
|
||||
tlsConf, err = g.TLS.toInternal(ctx, mtrc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tls: %w", err)
|
||||
}
|
||||
@ -45,7 +49,13 @@ func (srvGrps serverGroups) toInternal(
|
||||
ProfilesEnabled: g.ProfilesEnabled,
|
||||
}
|
||||
|
||||
svcSrvGrps[i].Servers, err = g.Servers.toInternal(tlsConf, btdMgr, ratelimitConf, dnsConf)
|
||||
svcSrvGrps[i].Servers, err = g.Servers.toInternal(
|
||||
mtrc,
|
||||
tlsConf,
|
||||
btdMgr,
|
||||
ratelimitConf,
|
||||
dnsConf,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("server group %q: %w", g.Name, err)
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdservice"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
)
|
||||
|
||||
@ -19,6 +19,7 @@ import (
|
||||
type ticketRotator struct {
|
||||
logger *slog.Logger
|
||||
errColl errcoll.Interface
|
||||
mtrc tlsconfig.Metrics
|
||||
confs map[*tls.Config][]string
|
||||
}
|
||||
|
||||
@ -29,6 +30,7 @@ type ticketRotator struct {
|
||||
func newTicketRotator(
|
||||
logger *slog.Logger,
|
||||
errColl errcoll.Interface,
|
||||
mtrc tlsconfig.Metrics,
|
||||
grps []*agd.ServerGroup,
|
||||
) (tr *ticketRotator) {
|
||||
confs := map[*tls.Config][]string{}
|
||||
@ -49,6 +51,7 @@ func newTicketRotator(
|
||||
return &ticketRotator{
|
||||
logger: logger.With(slogutil.KeyPrefix, "tickrot"),
|
||||
errColl: errColl,
|
||||
mtrc: mtrc,
|
||||
confs: confs,
|
||||
}
|
||||
}
|
||||
@ -81,7 +84,7 @@ func (r *ticketRotator) Refresh(ctx context.Context) (err error) {
|
||||
var key [sessTickLen]byte
|
||||
key, err = readSessionTicketKey(fileName)
|
||||
if err != nil {
|
||||
metrics.TLSSessionTicketsRotateStatus.Set(0)
|
||||
r.mtrc.SetSessionTicketRotationStatus(ctx, false)
|
||||
|
||||
return fmt.Errorf("session ticket for srv %s: %w", conf.ServerName, err)
|
||||
}
|
||||
@ -96,8 +99,7 @@ func (r *ticketRotator) Refresh(ctx context.Context) (err error) {
|
||||
conf.SetSessionTicketKeys(keys)
|
||||
}
|
||||
|
||||
metrics.TLSSessionTicketsRotateStatus.Set(1)
|
||||
metrics.TLSSessionTicketsRotateTime.SetToCurrentTime()
|
||||
r.mtrc.SetSessionTicketRotationStatus(ctx, true)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
@ -9,10 +10,9 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// tlsConfig are the TLS settings of a DNS server, if any.
|
||||
@ -37,12 +37,15 @@ type tlsConfig struct {
|
||||
|
||||
// toInternal converts c to the TLS configuration for a DNS server. c must be
|
||||
// valid.
|
||||
func (c *tlsConfig) toInternal() (conf *agd.TLS, err error) {
|
||||
func (c *tlsConfig) toInternal(
|
||||
ctx context.Context,
|
||||
mtrc tlsconfig.Metrics,
|
||||
) (conf *agd.TLS, err error) {
|
||||
if c == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tlsConf, err := c.Certificates.toInternal()
|
||||
tlsConf, err := c.Certificates.toInternal(ctx, mtrc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("certificates: %w", err)
|
||||
}
|
||||
@ -123,7 +126,10 @@ type tlsConfigCert struct {
|
||||
type tlsConfigCerts []*tlsConfigCert
|
||||
|
||||
// toInternal converts certs to a TLS configuration. certs must be valid.
|
||||
func (certs tlsConfigCerts) toInternal() (conf *tls.Config, err error) {
|
||||
func (certs tlsConfigCerts) toInternal(
|
||||
ctx context.Context,
|
||||
mtrc tlsconfig.Metrics,
|
||||
) (conf *tls.Config, err error) {
|
||||
if len(certs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@ -146,13 +152,8 @@ func (certs tlsConfigCerts) toInternal() (conf *tls.Config, err error) {
|
||||
tlsCerts[i] = cert
|
||||
|
||||
authAlgo, subj := leaf.PublicKeyAlgorithm.String(), leaf.Subject.String()
|
||||
metrics.TLSCertificateInfo.With(prometheus.Labels{
|
||||
"auth_algo": authAlgo,
|
||||
"subject": subj,
|
||||
}).Set(1)
|
||||
metrics.TLSCertificateNotAfter.With(prometheus.Labels{
|
||||
"subject": subj,
|
||||
}).Set(float64(leaf.NotAfter.Unix()))
|
||||
|
||||
mtrc.SetCertificateInfo(ctx, authAlgo, subj, leaf.NotAfter)
|
||||
}
|
||||
|
||||
return &tls.Config{
|
||||
|
@ -1,6 +1,7 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"maps"
|
||||
@ -13,6 +14,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/websvc"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
@ -62,9 +64,11 @@ type webConfig struct {
|
||||
// toInternal converts c to the AdGuardDNS web service configuration. c must be
|
||||
// valid.
|
||||
func (c *webConfig) toInternal(
|
||||
ctx context.Context,
|
||||
envs *environment,
|
||||
dnsCk dnscheck.Interface,
|
||||
errColl errcoll.Interface,
|
||||
mtrc tlsconfig.Metrics,
|
||||
) (conf *websvc.Config, err error) {
|
||||
if c == nil {
|
||||
return nil, nil
|
||||
@ -83,7 +87,7 @@ func (c *webConfig) toInternal(
|
||||
conf.RootRedirectURL = netutil.CloneURL(&c.RootRedirectURL.URL)
|
||||
}
|
||||
|
||||
conf.LinkedIP, err = c.LinkedIP.toInternal(envs.LinkedIPTargetURL)
|
||||
conf.LinkedIP, err = c.LinkedIP.toInternal(ctx, mtrc, envs.LinkedIPTargetURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting linked_ip: %w", err)
|
||||
}
|
||||
@ -107,7 +111,7 @@ func (c *webConfig) toInternal(
|
||||
}}
|
||||
|
||||
for _, bp := range blockPages {
|
||||
*bp.webConfPtr, err = bp.conf.toInternal()
|
||||
*bp.webConfPtr, err = bp.conf.toInternal(ctx, mtrc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: %w", bp.name, err)
|
||||
}
|
||||
@ -119,7 +123,7 @@ func (c *webConfig) toInternal(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conf.NonDoHBind, err = c.NonDoHBind.toInternal()
|
||||
conf.NonDoHBind, err = c.NonDoHBind.toInternal(ctx, mtrc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting non_doh_bind: %w", err)
|
||||
}
|
||||
@ -225,6 +229,8 @@ type linkedIPServer struct {
|
||||
|
||||
// toInternal converts s to a linkedIP server configuration. s must be valid.
|
||||
func (s *linkedIPServer) toInternal(
|
||||
ctx context.Context,
|
||||
mtrc tlsconfig.Metrics,
|
||||
targetURL *urlutil.URL,
|
||||
) (srv *websvc.LinkedIPServer, err error) {
|
||||
if s == nil {
|
||||
@ -232,7 +238,7 @@ func (s *linkedIPServer) toInternal(
|
||||
}
|
||||
|
||||
srv = &websvc.LinkedIPServer{}
|
||||
srv.Bind, err = s.Bind.toInternal()
|
||||
srv.Bind, err = s.Bind.toInternal(ctx, mtrc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting bind: %w", err)
|
||||
}
|
||||
@ -280,7 +286,10 @@ type blockPageServer struct {
|
||||
}
|
||||
|
||||
// toInternal converts s to a block page server configuration. s must be valid.
|
||||
func (s *blockPageServer) toInternal() (conf *websvc.BlockPageServerConfig, err error) {
|
||||
func (s *blockPageServer) toInternal(
|
||||
ctx context.Context,
|
||||
mtrc tlsconfig.Metrics,
|
||||
) (conf *websvc.BlockPageServerConfig, err error) {
|
||||
if s == nil {
|
||||
return nil, nil
|
||||
}
|
||||
@ -289,7 +298,7 @@ func (s *blockPageServer) toInternal() (conf *websvc.BlockPageServerConfig, err
|
||||
ContentFilePath: s.BlockPage,
|
||||
}
|
||||
|
||||
conf.Bind, err = s.Bind.toInternal()
|
||||
conf.Bind, err = s.Bind.toInternal(ctx, mtrc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting bind: %w", err)
|
||||
}
|
||||
@ -326,11 +335,14 @@ type bindData []*bindItem
|
||||
|
||||
// toInternal converts bd to bind data for the AdGuard DNS web service. bd must
|
||||
// be valid.
|
||||
func (bd bindData) toInternal() (data []*websvc.BindData, err error) {
|
||||
func (bd bindData) toInternal(
|
||||
ctx context.Context,
|
||||
mtrc tlsconfig.Metrics,
|
||||
) (data []*websvc.BindData, err error) {
|
||||
data = make([]*websvc.BindData, len(bd))
|
||||
|
||||
for i, d := range bd {
|
||||
data[i], err = d.toInternal()
|
||||
data[i], err = d.toInternal(ctx, mtrc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bind data at index %d: %w", i, err)
|
||||
}
|
||||
@ -369,8 +381,11 @@ type bindItem struct {
|
||||
|
||||
// toInternal converts i to bind data for the AdGuard DNS web service. i must
|
||||
// be valid.
|
||||
func (i *bindItem) toInternal() (data *websvc.BindData, err error) {
|
||||
tlsConf, err := i.Certificates.toInternal()
|
||||
func (i *bindItem) toInternal(
|
||||
ctx context.Context,
|
||||
mtrc tlsconfig.Metrics,
|
||||
) (data *websvc.BindData, err error) {
|
||||
tlsConf, err := i.Certificates.toInternal(ctx, mtrc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("certificates: %w", err)
|
||||
}
|
||||
|
@ -15,7 +15,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdservice"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
)
|
||||
@ -28,6 +27,7 @@ type AllowlistUpdater struct {
|
||||
http *agdhttp.Client
|
||||
url *url.URL
|
||||
errColl errcoll.Interface
|
||||
metrics Metrics
|
||||
}
|
||||
|
||||
// AllowlistUpdaterConfig is the configuration structure for the allowlist
|
||||
@ -45,6 +45,9 @@ type AllowlistUpdaterConfig struct {
|
||||
// ErrColl is used to collect errors during refreshes.
|
||||
ErrColl errcoll.Interface
|
||||
|
||||
// Metrics is used to collect allowlist statistics.
|
||||
Metrics Metrics
|
||||
|
||||
// Timeout is the timeout for Consul queries.
|
||||
Timeout time.Duration
|
||||
}
|
||||
@ -60,6 +63,7 @@ func NewAllowlistUpdater(c *AllowlistUpdaterConfig) (upd *AllowlistUpdater) {
|
||||
}),
|
||||
url: c.ConsulURL,
|
||||
errColl: c.ErrColl,
|
||||
metrics: c.Metrics,
|
||||
}
|
||||
}
|
||||
|
||||
@ -72,10 +76,7 @@ func (upd *AllowlistUpdater) Refresh(ctx context.Context) (err error) {
|
||||
upd.logger.InfoContext(ctx, "refresh started")
|
||||
defer upd.logger.InfoContext(ctx, "refresh finished")
|
||||
|
||||
defer func() {
|
||||
metrics.ConsulAllowlistUpdateTime.SetToCurrentTime()
|
||||
metrics.SetStatusGauge(metrics.ConsulAllowlistUpdateStatus, err)
|
||||
}()
|
||||
defer func() { upd.metrics.SetStatus(ctx, err) }()
|
||||
|
||||
consulNets, err := upd.loadConsul(ctx)
|
||||
if err != nil {
|
||||
@ -89,13 +90,11 @@ func (upd *AllowlistUpdater) Refresh(ctx context.Context) (err error) {
|
||||
ctx,
|
||||
"refresh successful",
|
||||
"num_records", len(consulNets),
|
||||
"url", &urlutil.URL{
|
||||
URL: *upd.url,
|
||||
},
|
||||
"url", urlutil.RedactUserinfo(upd.url),
|
||||
)
|
||||
|
||||
upd.allowlist.Update(consulNets)
|
||||
metrics.ConsulAllowlistSize.Set(float64(len(consulNets)))
|
||||
upd.metrics.SetSize(ctx, len(consulNets))
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -107,7 +106,7 @@ type consulRecord struct {
|
||||
|
||||
// loadConsul fetches, decodes, and returns the list of IP networks from consul.
|
||||
func (upd *AllowlistUpdater) loadConsul(ctx context.Context) (nets []netip.Prefix, err error) {
|
||||
defer func() { err = errors.Annotate(err, "loading allowlist nets from %s: %w", upd.url) }()
|
||||
defer func() { err = errors.Annotate(err, "loading allowlist nets: %w") }()
|
||||
|
||||
httpResp, err := upd.http.Get(ctx, upd.url)
|
||||
if err != nil {
|
||||
|
@ -83,6 +83,7 @@ func TestNewAllowlistUpdater(t *testing.T) {
|
||||
Allowlist: al,
|
||||
ConsulURL: u,
|
||||
ErrColl: agdtest.NewErrorCollector(),
|
||||
Metrics: consul.EmptyMetrics{},
|
||||
Timeout: testTimeout,
|
||||
})
|
||||
|
||||
@ -128,6 +129,7 @@ func TestNewAllowlistUpdater(t *testing.T) {
|
||||
Allowlist: al,
|
||||
ConsulURL: u,
|
||||
ErrColl: errColl,
|
||||
Metrics: consul.EmptyMetrics{},
|
||||
Timeout: testTimeout,
|
||||
})
|
||||
|
||||
@ -161,6 +163,7 @@ func TestAllowlistUpdater_Refresh_deadline(t *testing.T) {
|
||||
Allowlist: al,
|
||||
ConsulURL: u,
|
||||
ErrColl: errColl,
|
||||
Metrics: consul.EmptyMetrics{},
|
||||
Timeout: testTimeout,
|
||||
})
|
||||
|
||||
|
26
internal/consul/metrics.go
Normal file
26
internal/consul/metrics.go
Normal file
@ -0,0 +1,26 @@
|
||||
package consul
|
||||
|
||||
import "context"
|
||||
|
||||
// Metrics is an interface that is used for the collection of the allowlist
|
||||
// statistics.
|
||||
type Metrics interface {
|
||||
// SetSize sets the number of received subnets.
|
||||
SetSize(ctx context.Context, n int)
|
||||
|
||||
// SetStatus sets the status and time of the allowlist refresh attempt.
|
||||
SetStatus(ctx context.Context, err error)
|
||||
}
|
||||
|
||||
// EmptyMetrics is the implementation of the [Metrics] interface that does
|
||||
// nothing.
|
||||
type EmptyMetrics struct{}
|
||||
|
||||
// type check
|
||||
var _ Metrics = EmptyMetrics{}
|
||||
|
||||
// SetSize implements the [Metrics] interface for EmptyMetrics.
|
||||
func (EmptyMetrics) SetSize(_ context.Context, _ int) {}
|
||||
|
||||
// SetStatus plements the [Metrics] interface for EmptyMetrics.
|
||||
func (EmptyMetrics) SetStatus(_ context.Context, _ error) {}
|
@ -15,6 +15,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/debugsvc"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/httputil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -92,7 +93,7 @@ func TestService_Start(t *testing.T) {
|
||||
})
|
||||
|
||||
srvURL := &url.URL{
|
||||
Scheme: agdhttp.SchemeHTTP,
|
||||
Scheme: urlutil.SchemeHTTP,
|
||||
Host: addr,
|
||||
}
|
||||
|
||||
|
@ -201,18 +201,21 @@ func (cc *RemoteKV) newInfo(ri *agd.RequestInfo) (inf *info) {
|
||||
}
|
||||
|
||||
// resp returns the corresponding response.
|
||||
//
|
||||
// TODO(e.burkov): Inspect the reason for using different message constructors
|
||||
// for different DNS types, and consider using only one of them.
|
||||
func (cc *RemoteKV) resp(ri *agd.RequestInfo, req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
qt := ri.QType
|
||||
|
||||
if qt != dns.TypeA && qt != dns.TypeAAAA {
|
||||
return ri.Messages.NewMsgNODATA(req), nil
|
||||
return ri.Messages.NewRespRCode(req, dns.RcodeSuccess), nil
|
||||
}
|
||||
|
||||
if qt == dns.TypeA {
|
||||
return cc.messages.NewIPRespMsg(req, cc.ipv4...)
|
||||
return cc.messages.NewRespIP(req, cc.ipv4...)
|
||||
}
|
||||
|
||||
return cc.messages.NewIPRespMsg(req, cc.ipv6...)
|
||||
return cc.messages.NewRespIP(req, cc.ipv6...)
|
||||
}
|
||||
|
||||
// type check
|
||||
|
@ -17,6 +17,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/remotekv"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -86,7 +87,7 @@ func TestConsul_ServeHTTP(t *testing.T) {
|
||||
|
||||
t.Run("hit", func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, (&url.URL{
|
||||
Scheme: "http",
|
||||
Scheme: urlutil.SchemeHTTP,
|
||||
Host: randomid + "-" + checkDomain,
|
||||
Path: "/dnscheck/test",
|
||||
}).String(), strings.NewReader(""))
|
||||
@ -103,7 +104,7 @@ func TestConsul_ServeHTTP(t *testing.T) {
|
||||
|
||||
t.Run("miss", func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, (&url.URL{
|
||||
Scheme: "http",
|
||||
Scheme: urlutil.SchemeHTTP,
|
||||
Host: "non" + randomid + "-" + checkDomain,
|
||||
Path: "/dnscheck/test",
|
||||
}).String(), strings.NewReader(""))
|
||||
|
@ -18,13 +18,14 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDefault_ServeHTTP(t *testing.T) {
|
||||
const dname = "some-domain.name"
|
||||
const domain = "domain.example"
|
||||
|
||||
testIP := netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
@ -39,7 +40,7 @@ func TestDefault_ServeHTTP(t *testing.T) {
|
||||
rcode,
|
||||
dnsservertest.NewReq(name, qtype, dns.ClassINET),
|
||||
dnsservertest.SectionAnswer{
|
||||
dnsservertest.NewA(dname, 0, testIP),
|
||||
dnsservertest.NewA(domain, 0, testIP),
|
||||
},
|
||||
)
|
||||
}
|
||||
@ -52,38 +53,38 @@ func TestDefault_ServeHTTP(t *testing.T) {
|
||||
}{{
|
||||
name: "single",
|
||||
msgs: []*dns.Msg{
|
||||
newMsg(dns.RcodeSuccess, dname, dns.TypeA),
|
||||
newMsg(dns.RcodeSuccess, domain, dns.TypeA),
|
||||
},
|
||||
wantHdr: successHdr,
|
||||
wantResp: [][]byte{[]byte(dname + `,A,NOERROR,` + testIP.String() + `,1`)},
|
||||
wantResp: [][]byte{[]byte(domain + `,A,NOERROR,` + testIP.String() + `,1`)},
|
||||
}, {
|
||||
name: "existing",
|
||||
msgs: []*dns.Msg{
|
||||
newMsg(dns.RcodeSuccess, dname, dns.TypeA),
|
||||
newMsg(dns.RcodeSuccess, dname, dns.TypeA),
|
||||
newMsg(dns.RcodeSuccess, domain, dns.TypeA),
|
||||
newMsg(dns.RcodeSuccess, domain, dns.TypeA),
|
||||
},
|
||||
wantHdr: successHdr,
|
||||
wantResp: [][]byte{[]byte(dname + `,A,NOERROR,` + testIP.String() + `,2`)},
|
||||
wantResp: [][]byte{[]byte(domain + `,A,NOERROR,` + testIP.String() + `,2`)},
|
||||
}, {
|
||||
name: "different",
|
||||
msgs: []*dns.Msg{
|
||||
newMsg(dns.RcodeSuccess, dname, dns.TypeA),
|
||||
newMsg(dns.RcodeSuccess, "sub."+dname, dns.TypeA),
|
||||
newMsg(dns.RcodeSuccess, domain, dns.TypeA),
|
||||
newMsg(dns.RcodeSuccess, "sub."+domain, dns.TypeA),
|
||||
},
|
||||
wantHdr: successHdr,
|
||||
wantResp: [][]byte{
|
||||
[]byte("sub." + dname + `,A,NOERROR,` + testIP.String() + `,1`),
|
||||
[]byte(dname + `,A,NOERROR,` + testIP.String() + `,1`),
|
||||
[]byte("sub." + domain + `,A,NOERROR,` + testIP.String() + `,1`),
|
||||
[]byte(domain + `,A,NOERROR,` + testIP.String() + `,1`),
|
||||
},
|
||||
}, {
|
||||
name: "non-recordable",
|
||||
msgs: []*dns.Msg{
|
||||
// Not NOERROR.
|
||||
newMsg(dns.RcodeBadName, dname, dns.TypeA),
|
||||
newMsg(dns.RcodeBadName, domain, dns.TypeA),
|
||||
// Not A/AAAA.
|
||||
newMsg(dns.RcodeSuccess, dname, dns.TypeSRV),
|
||||
newMsg(dns.RcodeSuccess, domain, dns.TypeSRV),
|
||||
// Android metrics.
|
||||
newMsg(dns.RcodeSuccess, dname+"-dnsotls-ds.metric.gstatic.com.", dns.TypeA),
|
||||
newMsg(dns.RcodeSuccess, domain+"-dnsotls-ds.metric.gstatic.com.", dns.TypeA),
|
||||
},
|
||||
wantHdr: successHdr,
|
||||
wantResp: [][]byte{},
|
||||
@ -109,7 +110,10 @@ func TestDefault_ServeHTTP(t *testing.T) {
|
||||
|
||||
r := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
(&url.URL{Scheme: "http", Host: "example.com"}).String(),
|
||||
(&url.URL{
|
||||
Scheme: urlutil.SchemeHTTP,
|
||||
Host: "dnsdb.example",
|
||||
}).String(),
|
||||
nil,
|
||||
)
|
||||
r.Header.Add(httphdr.AcceptEncoding, agdhttp.HdrValGzip)
|
||||
|
@ -15,6 +15,11 @@ type ConstructorConfig struct {
|
||||
// Cloner used to clone DNS messages. It must not be nil.
|
||||
Cloner *Cloner
|
||||
|
||||
// StructuredErrors is the configuration for the experimental Structured DNS
|
||||
// Errors feature. It must not be nil. If enabled,
|
||||
// [ConstructorConfig.Enabled] should also be true.
|
||||
StructuredErrors *StructuredDNSErrorsConfig
|
||||
|
||||
// BlockingMode is the blocking mode to use in
|
||||
// [Constructor.NewBlockedRespMsg]. It must not be nil.
|
||||
BlockingMode BlockingMode
|
||||
@ -22,6 +27,9 @@ type ConstructorConfig struct {
|
||||
// FilteredResponseTTL is the time-to-live value used for responses created
|
||||
// by this message constructor. It must be non-negative.
|
||||
FilteredResponseTTL time.Duration
|
||||
|
||||
// EDEEnabled enables the addition of the Extended DNS Error (EDE) codes.
|
||||
EDEEnabled bool
|
||||
}
|
||||
|
||||
// validate checks the configuration for errors.
|
||||
@ -33,13 +41,19 @@ func (conf *ConstructorConfig) validate() (err error) {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
err = conf.StructuredErrors.validate(conf.EDEEnabled)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("structured errors: %w", err)
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
if conf.BlockingMode == nil {
|
||||
err = fmt.Errorf("blocking mode: %w", errors.ErrNoValue)
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
if conf.FilteredResponseTTL < 0 {
|
||||
err = fmt.Errorf("filtered response TTL: %w", errors.ErrNegative)
|
||||
err = fmt.Errorf("filtered response ttl: %w", errors.ErrNegative)
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
@ -51,7 +65,9 @@ func (conf *ConstructorConfig) validate() (err error) {
|
||||
type Constructor struct {
|
||||
cloner *Cloner
|
||||
blockingMode BlockingMode
|
||||
sde string
|
||||
fltRespTTL time.Duration
|
||||
edeEnabled bool
|
||||
}
|
||||
|
||||
// NewConstructor returns a properly initialized constructor using conf.
|
||||
@ -60,10 +76,17 @@ func NewConstructor(conf *ConstructorConfig) (c *Constructor, err error) {
|
||||
return nil, fmt.Errorf("configuration: %w", err)
|
||||
}
|
||||
|
||||
var sde string
|
||||
if sdeConf := conf.StructuredErrors; sdeConf.Enabled {
|
||||
sde = sdeConf.iJSON()
|
||||
}
|
||||
|
||||
return &Constructor{
|
||||
cloner: conf.Cloner,
|
||||
blockingMode: conf.BlockingMode,
|
||||
sde: sde,
|
||||
fltRespTTL: conf.FilteredResponseTTL,
|
||||
edeEnabled: conf.EDEEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -72,156 +95,6 @@ func (c *Constructor) Cloner() (cloner *Cloner) {
|
||||
return c.cloner
|
||||
}
|
||||
|
||||
// FilteredResponseTTL returns the TTL that the constructor uses to build
|
||||
// blocked responses.
|
||||
func (c *Constructor) FilteredResponseTTL() (ttl time.Duration) {
|
||||
return c.fltRespTTL
|
||||
}
|
||||
|
||||
// NewBlockedRespMsg returns a blocked DNS response message based on the
|
||||
// constructor's blocking mode.
|
||||
func (c *Constructor) NewBlockedRespMsg(req *dns.Msg) (msg *dns.Msg, err error) {
|
||||
switch m := c.blockingMode.(type) {
|
||||
case *BlockingModeCustomIP:
|
||||
return c.newBlockedCustomIPRespMsg(req, m)
|
||||
case *BlockingModeNullIP:
|
||||
switch qt := req.Question[0].Qtype; qt {
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
return c.NewIPRespMsg(req, netip.Addr{})
|
||||
default:
|
||||
return c.NewMsgNODATA(req), nil
|
||||
}
|
||||
case *BlockingModeNXDOMAIN:
|
||||
return c.NewMsgNXDOMAIN(req), nil
|
||||
case *BlockingModeREFUSED:
|
||||
return c.NewMsgREFUSED(req), nil
|
||||
default:
|
||||
// Consider unhandled sum type members as unrecoverable programmer
|
||||
// errors.
|
||||
panic(fmt.Errorf("unexpected type %T", c.blockingMode))
|
||||
}
|
||||
}
|
||||
|
||||
// newBlockedCustomIPRespMsg returns a blocked DNS response message with either
|
||||
// the custom IPs from the blocking mode options or a NODATA one.
|
||||
func (c *Constructor) newBlockedCustomIPRespMsg(
|
||||
req *dns.Msg,
|
||||
m *BlockingModeCustomIP,
|
||||
) (msg *dns.Msg, err error) {
|
||||
switch qt := req.Question[0].Qtype; qt {
|
||||
case dns.TypeA:
|
||||
if len(m.IPv4) > 0 {
|
||||
return c.NewIPRespMsg(req, m.IPv4...)
|
||||
}
|
||||
case dns.TypeAAAA:
|
||||
if len(m.IPv6) > 0 {
|
||||
return c.NewIPRespMsg(req, m.IPv6...)
|
||||
}
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
return c.NewMsgNODATA(req), nil
|
||||
}
|
||||
|
||||
// NewIPRespMsg returns a DNS A or AAAA response message with the given IP
|
||||
// addresses. If any IP address is nil, it is replaced by an unspecified (aka
|
||||
// null) IP. The TTL is also set to c.FilteredResponseTTL.
|
||||
func (c *Constructor) NewIPRespMsg(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
|
||||
switch qt := req.Question[0].Qtype; qt {
|
||||
case dns.TypeA:
|
||||
return c.newMsgA(req, ips...)
|
||||
case dns.TypeAAAA:
|
||||
return c.newMsgAAAA(req, ips...)
|
||||
default:
|
||||
return nil, fmt.Errorf("bad qtype for a or aaaa resp: %d", qt)
|
||||
}
|
||||
}
|
||||
|
||||
// NewCNAMEWithIPs generates a filtered response to req with CNAME record and
|
||||
// provided ips. cname is the fully-qualified name and must not be empty, ips
|
||||
// must be of the same family.
|
||||
func (c *Constructor) NewCNAMEWithIPs(
|
||||
req *dns.Msg,
|
||||
cname string,
|
||||
ips ...netip.Addr,
|
||||
) (resp *dns.Msg, err error) {
|
||||
resp = c.NewRespMsg(req)
|
||||
|
||||
resp.Answer = make([]dns.RR, 0, len(ips)+1)
|
||||
resp.Answer = append(resp.Answer, c.NewAnswerCNAME(req, cname))
|
||||
|
||||
var ans dns.RR
|
||||
for i, ip := range ips {
|
||||
switch qt := req.Question[0].Qtype; qt {
|
||||
case dns.TypeA:
|
||||
ans, err = c.NewAnswerA(cname, ip)
|
||||
case dns.TypeAAAA:
|
||||
ans, err = c.NewAnswerAAAA(cname, ip)
|
||||
default:
|
||||
return nil, fmt.Errorf("bad qtype for a or aaaa resp: %d", qt)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bad ip at idx %d: %w", i, err)
|
||||
}
|
||||
|
||||
resp.Answer = append(resp.Answer, ans)
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// NewMsgFORMERR returns a properly initialized FORMERR response.
|
||||
func (c *Constructor) NewMsgFORMERR(req *dns.Msg) (resp *dns.Msg) {
|
||||
return c.newMsgRCode(req, dns.RcodeFormatError)
|
||||
}
|
||||
|
||||
// NewMsgNXDOMAIN returns a properly initialized NXDOMAIN response.
|
||||
func (c *Constructor) NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) {
|
||||
return c.newMsgRCode(req, dns.RcodeNameError)
|
||||
}
|
||||
|
||||
// NewMsgREFUSED returns a properly initialized REFUSED response.
|
||||
func (c *Constructor) NewMsgREFUSED(req *dns.Msg) (resp *dns.Msg) {
|
||||
return c.newMsgRCode(req, dns.RcodeRefused)
|
||||
}
|
||||
|
||||
// NewMsgSERVFAIL returns a properly initialized SERVFAIL response.
|
||||
func (c *Constructor) NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) {
|
||||
return c.newMsgRCode(req, dns.RcodeServerFailure)
|
||||
}
|
||||
|
||||
// NewMsgNODATA returns a properly initialized NODATA response.
|
||||
//
|
||||
// See https://www.rfc-editor.org/rfc/rfc2308#section-2.2.
|
||||
func (c *Constructor) NewMsgNODATA(req *dns.Msg) (resp *dns.Msg) {
|
||||
return c.newMsgRCode(req, dns.RcodeSuccess)
|
||||
}
|
||||
|
||||
// newMsgRCode returns a properly initialized response with the given RCode.
|
||||
func (c *Constructor) newMsgRCode(req *dns.Msg, rc RCode) (resp *dns.Msg) {
|
||||
resp = (&dns.Msg{}).SetRcode(req, int(rc))
|
||||
resp.Ns = c.newSOARecords(req)
|
||||
resp.RecursionAvailable = true
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// NewTXTRespMsg returns a DNS TXT response message with the given strings as
|
||||
// content. The TTL is also set to c.FilteredResponseTTL.
|
||||
func (c *Constructor) NewTXTRespMsg(req *dns.Msg, strs ...string) (msg *dns.Msg, err error) {
|
||||
ans, err := c.NewAnswerTXT(req, strs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg = c.NewRespMsg(req)
|
||||
msg.Answer = append(msg.Answer, ans)
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// AppendDebugExtra appends to response message a DNS TXT extra with CHAOS
|
||||
// class.
|
||||
func (c *Constructor) AppendDebugExtra(req, resp *dns.Msg, str string) (err error) {
|
||||
@ -386,26 +259,23 @@ func (c *Constructor) newSOARecords(req *dns.Msg) (soaRecs []dns.RR) {
|
||||
zone = req.Question[0].Name
|
||||
}
|
||||
|
||||
// TODO(a.garipov): A lot of this is copied from AdGuard Home and needs
|
||||
// to be inspected and refactored.
|
||||
// TODO(a.garipov): A lot of this is copied from AdGuard Home and needs to
|
||||
// be inspected and refactored.
|
||||
soa := &dns.SOA{
|
||||
// values copied from verisign's nonexistent .com domain
|
||||
// their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers
|
||||
// Use values from verisign's nonexistent.com domain. Their exact
|
||||
// values are not important in our use case because they are used for
|
||||
// domain transfers between primary/secondary DNS servers.
|
||||
Refresh: 1800,
|
||||
Retry: 900,
|
||||
Expire: 604800,
|
||||
Minttl: 86400,
|
||||
// copied from AdGuard DNS
|
||||
// Copied from AdGuard DNS.
|
||||
Ns: "fake-for-negative-caching.adguard.com.",
|
||||
Serial: 100500,
|
||||
// rest is request-specific
|
||||
Hdr: dns.RR_Header{
|
||||
Name: zone,
|
||||
Rrtype: dns.TypeSOA,
|
||||
Ttl: uint32(c.fltRespTTL.Seconds()),
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
Mbox: "hostmaster.", // zone will be appended later if it's not empty or "."
|
||||
// Rest is request-specific.
|
||||
Hdr: c.newHdrWithClass(zone, dns.TypeSOA, dns.ClassINET),
|
||||
// Zone will be appended later if it's not empty or ".".
|
||||
Mbox: "hostmaster.",
|
||||
}
|
||||
|
||||
if len(zone) > 0 && zone[0] != '.' {
|
||||
@ -415,25 +285,10 @@ func (c *Constructor) newSOARecords(req *dns.Msg) (soaRecs []dns.RR) {
|
||||
return []dns.RR{soa}
|
||||
}
|
||||
|
||||
// NewRespMsg creates a DNS response for req and sets all necessary flags and
|
||||
// fields. It also guarantees that req.Question will be not empty.
|
||||
func (c *Constructor) NewRespMsg(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
RecursionAvailable: true,
|
||||
},
|
||||
Compress: true,
|
||||
}
|
||||
|
||||
resp.SetReply(req)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// newMsgA returns a new DNS response with the given IPv4 addresses. If any IP
|
||||
// address is nil, it is replaced by an unspecified (aka null) IP, 0.0.0.0.
|
||||
func (c *Constructor) newMsgA(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
|
||||
msg = c.NewRespMsg(req)
|
||||
msg = c.NewResp(req)
|
||||
for i, ip := range ips {
|
||||
var ans dns.RR
|
||||
ans, err = c.NewAnswerA(req.Question[0].Name, ip)
|
||||
@ -450,7 +305,7 @@ func (c *Constructor) newMsgA(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, er
|
||||
// newMsgAAAA returns a new DNS response with the given IPv6 addresses. If any
|
||||
// IP address is nil, it is replaced by an unspecified (aka null) IP, [::].
|
||||
func (c *Constructor) newMsgAAAA(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
|
||||
msg = c.NewRespMsg(req)
|
||||
msg = c.NewResp(req)
|
||||
for i, ip := range ips {
|
||||
var ans dns.RR
|
||||
ans, err = c.NewAnswerAAAA(req.Question[0].Name, ip)
|
||||
|
@ -1,18 +1,16 @@
|
||||
package dnsmsg_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// newTXTExtra is a helper constructor of the expected extra data.
|
||||
@ -27,311 +25,88 @@ func newTXTExtra(ttl uint32, strs ...string) (extra []dns.RR) {
|
||||
}}
|
||||
}
|
||||
|
||||
// newConstructor returns a new dnsmsg.Constructor with [testFltRespTTL].
|
||||
func newConstructor(tb testing.TB) (c *dnsmsg.Constructor) {
|
||||
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
|
||||
Cloner: dnsmsg.NewCloner(dnsmsg.EmptyClonerStat{}),
|
||||
BlockingMode: &dnsmsg.BlockingModeNullIP{},
|
||||
FilteredResponseTTL: agdtest.FilteredResponseTTL,
|
||||
})
|
||||
require.NoError(tb, err)
|
||||
|
||||
return msgs
|
||||
}
|
||||
|
||||
func TestConstructor_NewBlockedRespMsg_nullIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
msgs := newConstructor(t)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantAnsNum int
|
||||
qt dnsmsg.RRType
|
||||
}{{
|
||||
name: "a",
|
||||
wantAnsNum: 1,
|
||||
qt: dns.TypeA,
|
||||
}, {
|
||||
name: "aaaa",
|
||||
wantAnsNum: 1,
|
||||
qt: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "txt",
|
||||
wantAnsNum: 0,
|
||||
qt: dns.TypeTXT,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := dnsservertest.NewReq(testFQDN, tc.qt, dns.ClassINET)
|
||||
|
||||
resp, respErr := msgs.NewBlockedRespMsg(req)
|
||||
require.NoError(t, respErr)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
|
||||
if tc.wantAnsNum == 0 {
|
||||
assert.Empty(t, resp.Answer)
|
||||
|
||||
require.Len(t, resp.Ns, 1)
|
||||
|
||||
nsTTL := resp.Ns[0].Header().Ttl
|
||||
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), nsTTL)
|
||||
} else {
|
||||
require.Len(t, resp.Answer, 1)
|
||||
|
||||
ansTTL := resp.Answer[0].Header().Ttl
|
||||
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), ansTTL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstructor_NewBlockedRespMsg_customIP(t *testing.T) {
|
||||
func TestNewConstructor(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cloner := agdtest.NewCloner()
|
||||
|
||||
testCases := []struct {
|
||||
blockingMode dnsmsg.BlockingMode
|
||||
name string
|
||||
wantA bool
|
||||
wantAAAA bool
|
||||
}{{
|
||||
blockingMode: &dnsmsg.BlockingModeCustomIP{
|
||||
IPv4: []netip.Addr{testIPv4},
|
||||
IPv6: []netip.Addr{testIPv6},
|
||||
},
|
||||
name: "both",
|
||||
wantA: true,
|
||||
wantAAAA: true,
|
||||
}, {
|
||||
blockingMode: &dnsmsg.BlockingModeCustomIP{
|
||||
IPv4: []netip.Addr{testIPv4},
|
||||
},
|
||||
name: "ipv4_only",
|
||||
wantA: true,
|
||||
wantAAAA: false,
|
||||
}, {
|
||||
blockingMode: &dnsmsg.BlockingModeCustomIP{
|
||||
IPv6: []netip.Addr{testIPv6},
|
||||
},
|
||||
name: "ipv6_only",
|
||||
wantA: false,
|
||||
wantAAAA: true,
|
||||
}, {
|
||||
blockingMode: &dnsmsg.BlockingModeCustomIP{
|
||||
IPv4: []netip.Addr{},
|
||||
IPv6: []netip.Addr{},
|
||||
},
|
||||
name: "empty",
|
||||
wantA: false,
|
||||
wantAAAA: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
|
||||
Cloner: cloner,
|
||||
BlockingMode: tc.blockingMode,
|
||||
FilteredResponseTTL: agdtest.FilteredResponseTTL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
reqA := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET)
|
||||
respA, err := msgs.NewBlockedRespMsg(reqA)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, respA)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, respA.Rcode)
|
||||
|
||||
if tc.wantA {
|
||||
require.Len(t, respA.Answer, 1)
|
||||
|
||||
a := testutil.RequireTypeAssert[*dns.A](t, respA.Answer[0])
|
||||
assert.Equal(t, net.IP(testIPv4.AsSlice()), a.A)
|
||||
} else {
|
||||
assert.Empty(t, respA.Answer)
|
||||
}
|
||||
|
||||
reqAAAA := dnsservertest.NewReq(testFQDN, dns.TypeAAAA, dns.ClassINET)
|
||||
respAAAA, err := msgs.NewBlockedRespMsg(reqAAAA)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, respAAAA)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, respAAAA.Rcode)
|
||||
|
||||
if tc.wantAAAA {
|
||||
require.Len(t, respAAAA.Answer, 1)
|
||||
|
||||
aaaa := testutil.RequireTypeAssert[*dns.AAAA](t, respAAAA.Answer[0])
|
||||
assert.Equal(t, net.IP(testIPv6.AsSlice()), aaaa.AAAA)
|
||||
} else {
|
||||
assert.Empty(t, respAAAA.Answer)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstructor_NewBlockedRespMsg_noAnswer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET)
|
||||
cloner := agdtest.NewCloner()
|
||||
|
||||
testCases := []struct {
|
||||
blockingMode dnsmsg.BlockingMode
|
||||
name string
|
||||
rcode dnsmsg.RCode
|
||||
}{{
|
||||
blockingMode: &dnsmsg.BlockingModeNXDOMAIN{},
|
||||
name: "nxdomain",
|
||||
rcode: dns.RcodeNameError,
|
||||
}, {
|
||||
blockingMode: &dnsmsg.BlockingModeREFUSED{},
|
||||
name: "refused",
|
||||
rcode: dns.RcodeRefused,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
|
||||
Cloner: cloner,
|
||||
BlockingMode: tc.blockingMode,
|
||||
FilteredResponseTTL: agdtest.FilteredResponseTTL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := msgs.NewBlockedRespMsg(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
assert.Equal(t, tc.rcode, dnsmsg.RCode(resp.Rcode))
|
||||
assert.Empty(t, resp.Answer)
|
||||
|
||||
require.Len(t, resp.Ns, 1)
|
||||
|
||||
nsTTL := resp.Ns[0].Header().Ttl
|
||||
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), nsTTL)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstructor_noAnswerMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
msgs := newConstructor(t)
|
||||
|
||||
req := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET)
|
||||
|
||||
testCases := []struct {
|
||||
method func(req *dns.Msg) (resp *dns.Msg)
|
||||
name string
|
||||
want dnsmsg.RCode
|
||||
}{{
|
||||
method: msgs.NewMsgFORMERR,
|
||||
name: "formerr",
|
||||
want: dns.RcodeFormatError,
|
||||
}, {
|
||||
method: msgs.NewMsgNXDOMAIN,
|
||||
name: "nxdomain",
|
||||
want: dns.RcodeNameError,
|
||||
}, {
|
||||
method: msgs.NewMsgREFUSED,
|
||||
name: "refused",
|
||||
want: dns.RcodeRefused,
|
||||
}, {
|
||||
method: msgs.NewMsgSERVFAIL,
|
||||
name: "servfail",
|
||||
want: dns.RcodeServerFailure,
|
||||
}, {
|
||||
method: msgs.NewMsgNODATA,
|
||||
name: "nodata",
|
||||
want: dns.RcodeSuccess,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
resp := tc.method(req)
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Ns, 1)
|
||||
|
||||
assert.Empty(t, resp.Answer)
|
||||
assert.Equal(t, tc.want, dnsmsg.RCode(resp.Rcode))
|
||||
|
||||
nsTTL := resp.Ns[0].Header().Ttl
|
||||
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), nsTTL)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstructor_NewTXTRespMsg(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
msgs := newConstructor(t)
|
||||
|
||||
req := dnsservertest.NewReq(testFQDN, dns.TypeTXT, dns.ClassINET)
|
||||
tooLong := strings.Repeat("1", dnsmsg.MaxTXTStringLen+1)
|
||||
badContactURL := errors.Must(url.Parse("invalid-scheme://devteam@adguard.com"))
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
conf *dnsmsg.ConstructorConfig
|
||||
wantErrMsg string
|
||||
strs []string
|
||||
}{{
|
||||
name: "success",
|
||||
name: "good",
|
||||
conf: &dnsmsg.ConstructorConfig{
|
||||
Cloner: cloner,
|
||||
StructuredErrors: agdtest.NewSDEConfig(true),
|
||||
BlockingMode: &dnsmsg.BlockingModeNullIP{},
|
||||
FilteredResponseTTL: agdtest.FilteredResponseTTL,
|
||||
EDEEnabled: true,
|
||||
},
|
||||
wantErrMsg: "",
|
||||
strs: []string{"111"},
|
||||
}, {
|
||||
name: "success_many",
|
||||
wantErrMsg: "",
|
||||
strs: []string{"111", "222"},
|
||||
name: "all_bad",
|
||||
conf: &dnsmsg.ConstructorConfig{
|
||||
FilteredResponseTTL: -1,
|
||||
},
|
||||
wantErrMsg: "configuration: " +
|
||||
"cloner: no value\n" +
|
||||
"structured errors: no value\n" +
|
||||
"blocking mode: no value\n" +
|
||||
"filtered response ttl: negative value",
|
||||
}, {
|
||||
name: "success_nil",
|
||||
wantErrMsg: "",
|
||||
strs: nil,
|
||||
name: "sde_enabled",
|
||||
conf: &dnsmsg.ConstructorConfig{
|
||||
Cloner: cloner,
|
||||
StructuredErrors: agdtest.NewSDEConfig(true),
|
||||
BlockingMode: &dnsmsg.BlockingModeNullIP{},
|
||||
FilteredResponseTTL: agdtest.FilteredResponseTTL,
|
||||
EDEEnabled: false,
|
||||
},
|
||||
wantErrMsg: "configuration: structured errors: " +
|
||||
"ede must be enabled to enable sde",
|
||||
}, {
|
||||
name: "success_empty",
|
||||
wantErrMsg: "",
|
||||
strs: []string{},
|
||||
name: "sde_empty",
|
||||
conf: &dnsmsg.ConstructorConfig{
|
||||
Cloner: cloner,
|
||||
StructuredErrors: &dnsmsg.StructuredDNSErrorsConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
BlockingMode: &dnsmsg.BlockingModeNullIP{},
|
||||
FilteredResponseTTL: agdtest.FilteredResponseTTL,
|
||||
EDEEnabled: true,
|
||||
},
|
||||
wantErrMsg: "configuration: structured errors: " +
|
||||
"contact data: empty value\n" +
|
||||
"justification: empty value",
|
||||
}, {
|
||||
name: "too_long",
|
||||
wantErrMsg: "txt string at index 0: too long: got 256 bytes, max 255",
|
||||
strs: []string{tooLong},
|
||||
name: "sde_bad",
|
||||
conf: &dnsmsg.ConstructorConfig{
|
||||
Cloner: cloner,
|
||||
StructuredErrors: &dnsmsg.StructuredDNSErrorsConfig{
|
||||
Enabled: true,
|
||||
Contact: []*url.URL{badContactURL, nil},
|
||||
Justification: "\uFFFE",
|
||||
Organization: "\uFFFE",
|
||||
},
|
||||
BlockingMode: &dnsmsg.BlockingModeNullIP{},
|
||||
FilteredResponseTTL: agdtest.FilteredResponseTTL,
|
||||
EDEEnabled: true,
|
||||
},
|
||||
wantErrMsg: "configuration: structured errors: " +
|
||||
`contact data: at index 0: scheme: bad enum value: "invalid-scheme"` + "\n" +
|
||||
"contact data: at index 1: no value\n" +
|
||||
"justification: bad code point at index 0\n" +
|
||||
"organization: bad code point at index 0",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
resp, respErr := msgs.NewTXTRespMsg(req, tc.strs...)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, respErr)
|
||||
|
||||
if tc.wantErrMsg != "" {
|
||||
return
|
||||
}
|
||||
|
||||
require.NotNil(t, resp)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
|
||||
require.Len(t, resp.Answer, 1)
|
||||
|
||||
ans := resp.Answer[0]
|
||||
ansTTL := ans.Header().Ttl
|
||||
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), ansTTL)
|
||||
|
||||
txt := testutil.RequireTypeAssert[*dns.TXT](t, ans)
|
||||
assert.Equal(t, tc.strs, txt.Txt)
|
||||
_, err := dnsmsg.NewConstructor(tc.conf)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -339,7 +114,7 @@ func TestConstructor_NewTXTRespMsg(t *testing.T) {
|
||||
func TestConstructor_AppendDebugExtra(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
msgs := newConstructor(t)
|
||||
msgs := agdtest.NewConstructor(t)
|
||||
|
||||
shortText := "This is a short test text"
|
||||
longText := strings.Repeat("a", 2*dnsmsg.MaxTXTStringLen)
|
||||
|
@ -78,7 +78,9 @@ func TestClone(t *testing.T) {
|
||||
},
|
||||
name: "empty_slice_ans",
|
||||
}, {
|
||||
msg: dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET),
|
||||
msg: dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET, dnsservertest.SectionExtra{
|
||||
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
|
||||
}),
|
||||
name: "a",
|
||||
}}
|
||||
|
||||
|
@ -72,8 +72,7 @@ func (c *optCloner) clone(rr *dns.OPT) (clone *dns.OPT, full bool) {
|
||||
optClone = opt
|
||||
case *dns.EDNS0_EDE:
|
||||
opt := c.ede.Get()
|
||||
opt.InfoCode = orig.InfoCode
|
||||
opt.ExtraText = orig.ExtraText
|
||||
*opt = *orig
|
||||
|
||||
optClone = opt
|
||||
case *dns.EDNS0_SUBNET:
|
||||
|
228
internal/dnsmsg/response.go
Normal file
228
internal/dnsmsg/response.go
Normal file
@ -0,0 +1,228 @@
|
||||
package dnsmsg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// NewResp creates a response DNS message for req and sets all necessary flags
|
||||
// and fields. resp contains no resource records.
|
||||
func (c *Constructor) NewResp(req *dns.Msg) (resp *dns.Msg) {
|
||||
return (&dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
RecursionAvailable: true,
|
||||
},
|
||||
Compress: true,
|
||||
}).SetReply(req)
|
||||
}
|
||||
|
||||
// NewBlockedResp returns a blocked response DNS message based on the
|
||||
// constructor's blocking mode.
|
||||
func (c *Constructor) NewBlockedResp(req *dns.Msg) (msg *dns.Msg, err error) {
|
||||
switch m := c.blockingMode.(type) {
|
||||
case *BlockingModeCustomIP:
|
||||
return c.newBlockedCustomIPResp(req, m)
|
||||
case *BlockingModeNullIP:
|
||||
switch qt := req.Question[0].Qtype; qt {
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
return c.NewBlockedNullIPResp(req)
|
||||
default:
|
||||
msg = c.NewBlockedRespRCode(req, dns.RcodeSuccess)
|
||||
msg.Ns = c.newSOARecords(req)
|
||||
}
|
||||
case *BlockingModeNXDOMAIN:
|
||||
msg = c.NewBlockedRespRCode(req, dns.RcodeNameError)
|
||||
msg.Ns = c.newSOARecords(req)
|
||||
case *BlockingModeREFUSED:
|
||||
msg = c.NewBlockedRespRCode(req, dns.RcodeRefused)
|
||||
msg.Ns = c.newSOARecords(req)
|
||||
default:
|
||||
// Consider unhandled sum type members as unrecoverable programmer
|
||||
// errors.
|
||||
panic(fmt.Errorf("unexpected type %T", c.blockingMode))
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// NewRespRCode returns a response DNS message with given response code and a
|
||||
// predefined authority section.
|
||||
//
|
||||
// Use [dns.RcodeSuccess] for a proper NODATA response, see
|
||||
// https://www.rfc-editor.org/rfc/rfc2308#section-2.2.
|
||||
func (c *Constructor) NewRespRCode(req *dns.Msg, rc RCode) (resp *dns.Msg) {
|
||||
resp = c.NewResp(req)
|
||||
resp.Rcode = int(rc)
|
||||
|
||||
resp.Ns = c.newSOARecords(req)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// NewBlockedRespRCode returns a blocked response DNS message with given
|
||||
// response code.
|
||||
//
|
||||
// TODO(e.burkov): Add SOA records to the response, like in
|
||||
// [Constructor.NewRespRCode].
|
||||
func (c *Constructor) NewBlockedRespRCode(req *dns.Msg, rc RCode) (resp *dns.Msg) {
|
||||
resp = c.NewResp(req)
|
||||
resp.Rcode = int(rc)
|
||||
|
||||
c.AddEDE(req, resp, dns.ExtendedErrorCodeFiltered)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// NewRespTXT returns a DNS TXT response message with the given strings as
|
||||
// content. The TTL of the TXT answer is set to c.FilteredResponseTTL.
|
||||
func (c *Constructor) NewRespTXT(req *dns.Msg, strs ...string) (msg *dns.Msg, err error) {
|
||||
ans, err := c.NewAnswerTXT(req, strs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg = c.NewResp(req)
|
||||
msg.Answer = append(msg.Answer, ans)
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// NewRespIP returns an A or AAAA DNS response message with the given IP
|
||||
// addresses. If any IP address is nil, it is replaced by an unspecified (aka
|
||||
// null) IP. The TTL is also set to c.FilteredResponseTTL.
|
||||
func (c *Constructor) NewRespIP(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
|
||||
switch qt := req.Question[0].Qtype; qt {
|
||||
case dns.TypeA:
|
||||
return c.newMsgA(req, ips...)
|
||||
case dns.TypeAAAA:
|
||||
return c.newMsgAAAA(req, ips...)
|
||||
default:
|
||||
return nil, fmt.Errorf("bad qtype for a or aaaa resp: %d", qt)
|
||||
}
|
||||
}
|
||||
|
||||
// NewBlockedRespIP returns an A or AAAA DNS response message with the given IP
|
||||
// addresses. The TTL of each record is set to c.FilteredResponseTTL. ips
|
||||
// should not contain zero values due to the extended error code semantics, use
|
||||
// [NewBlockedNullIPResp] for this case.
|
||||
//
|
||||
// TODO(a.garipov): Consider merging with [NewRespIP] if AddEDE with the Forged
|
||||
// Answer code isn't used again.
|
||||
func (c *Constructor) NewBlockedRespIP(req *dns.Msg, ips ...netip.Addr) (msg *dns.Msg, err error) {
|
||||
switch qt := req.Question[0].Qtype; qt {
|
||||
case dns.TypeA:
|
||||
msg, err = c.newMsgA(req, ips...)
|
||||
case dns.TypeAAAA:
|
||||
msg, err = c.newMsgAAAA(req, ips...)
|
||||
default:
|
||||
return nil, fmt.Errorf("bad qtype for an ip resp: %d", qt)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// NewBlockedNullIPResp returns a blocked A or AAAA DNS response message with an
|
||||
// unspecified (aka null) IP address. The TTL of the record is set to the
|
||||
// constructor's FilteredResponseTTL.
|
||||
func (c *Constructor) NewBlockedNullIPResp(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
switch qt := req.Question[0].Qtype; qt {
|
||||
case dns.TypeA:
|
||||
resp, err = c.newMsgA(req, netip.Addr{})
|
||||
case dns.TypeAAAA:
|
||||
resp, err = c.newMsgAAAA(req, netip.Addr{})
|
||||
default:
|
||||
err = fmt.Errorf("bad qtype for an ip resp: %d", qt)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.AddEDE(req, resp, dns.ExtendedErrorCodeFiltered)
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// AddEDE adds an Extended DNS Error (EDE) option to the blocked response
|
||||
// message, if the feature is enabled in the Constructor and the request
|
||||
// indicates EDNS support. It does not overwrite EDE if there already is one.
|
||||
// req and resp must not be nil.
|
||||
func (c *Constructor) AddEDE(req, resp *dns.Msg, code uint16) {
|
||||
if !c.edeEnabled {
|
||||
return
|
||||
}
|
||||
|
||||
reqOpt := req.IsEdns0()
|
||||
if reqOpt == nil {
|
||||
// Requestor doesn't implement EDNS, see
|
||||
// https://datatracker.ietf.org/doc/html/rfc6891#section-7.
|
||||
return
|
||||
}
|
||||
|
||||
respOpt := resp.IsEdns0()
|
||||
if respOpt == nil {
|
||||
respOpt = newOPT(c.cloner, reqOpt.UDPSize(), reqOpt.Do())
|
||||
resp.Extra = append(resp.Extra, respOpt)
|
||||
} else if findEDE(respOpt) != nil {
|
||||
// Do not add an EDE option if there already is one.
|
||||
return
|
||||
}
|
||||
|
||||
sdeText := c.sdeForReqOpt(reqOpt)
|
||||
|
||||
respOpt.Option = append(respOpt.Option, newEDNS0EDE(c.cloner, code, sdeText))
|
||||
}
|
||||
|
||||
// findEDE returns the EDE option if there is one. opt must not be nil.
|
||||
func findEDE(opt *dns.OPT) (ede *dns.EDNS0_EDE) {
|
||||
for _, o := range opt.Option {
|
||||
var ok bool
|
||||
if ede, ok = o.(*dns.EDNS0_EDE); ok {
|
||||
return ede
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sdeForReqOpt returns either the configured SDE text or empty string depending
|
||||
// on the request's EDNS options.
|
||||
func (c *Constructor) sdeForReqOpt(reqOpt *dns.OPT) (sde string) {
|
||||
ede := findEDE(reqOpt)
|
||||
if ede != nil && ede.InfoCode == 0 && ede.ExtraText == "" {
|
||||
return c.sde
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// newBlockedCustomIPResp returns a blocked DNS response message with either the
|
||||
// custom IPs from the blocking mode options or a NODATA one.
|
||||
func (c *Constructor) newBlockedCustomIPResp(
|
||||
req *dns.Msg,
|
||||
m *BlockingModeCustomIP,
|
||||
) (msg *dns.Msg, err error) {
|
||||
switch qt := req.Question[0].Qtype; qt {
|
||||
case dns.TypeA:
|
||||
if len(m.IPv4) > 0 {
|
||||
return c.NewBlockedRespIP(req, m.IPv4...)
|
||||
}
|
||||
case dns.TypeAAAA:
|
||||
if len(m.IPv6) > 0 {
|
||||
return c.NewBlockedRespIP(req, m.IPv6...)
|
||||
}
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
msg = c.NewBlockedRespRCode(req, dns.RcodeSuccess)
|
||||
msg.Ns = c.newSOARecords(req)
|
||||
|
||||
return msg, nil
|
||||
}
|
416
internal/dnsmsg/response_test.go
Normal file
416
internal/dnsmsg/response_test.go
Normal file
@ -0,0 +1,416 @@
|
||||
package dnsmsg_test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConstructor_NewBlockedResp_nullIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
msgs := agdtest.NewConstructor(t)
|
||||
reqExtra := dnsservertest.SectionExtra{
|
||||
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
|
||||
}
|
||||
|
||||
filteredSDE := dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{
|
||||
InfoCode: dns.ExtendedErrorCodeFiltered,
|
||||
ExtraText: agdtest.SDEText,
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantAns []dns.RR
|
||||
wantExtra []dns.RR
|
||||
qt dnsmsg.RRType
|
||||
}{{
|
||||
name: "a",
|
||||
wantAns: []dns.RR{dnsservertest.NewA(
|
||||
testFQDN, agdtest.FilteredResponseTTLSec, netip.IPv4Unspecified(),
|
||||
)},
|
||||
wantExtra: []dns.RR{filteredSDE},
|
||||
qt: dns.TypeA,
|
||||
}, {
|
||||
name: "aaaa",
|
||||
wantAns: []dns.RR{dnsservertest.NewAAAA(
|
||||
testFQDN, agdtest.FilteredResponseTTLSec, netip.IPv6Unspecified(),
|
||||
)},
|
||||
wantExtra: []dns.RR{filteredSDE},
|
||||
qt: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "txt",
|
||||
wantAns: nil,
|
||||
wantExtra: []dns.RR{filteredSDE},
|
||||
qt: dns.TypeTXT,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := dnsservertest.NewReq(testFQDN, tc.qt, dns.ClassINET, reqExtra)
|
||||
|
||||
resp, respErr := msgs.NewBlockedResp(req)
|
||||
require.NoError(t, respErr)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
assert.Equal(t, tc.wantAns, resp.Answer)
|
||||
assert.Equal(t, tc.wantExtra, resp.Extra)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstructor_NewBlockedResp_customIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cloner := agdtest.NewCloner()
|
||||
|
||||
// TODO(a.garipov): Test the forged extra as well if the EDE with that code
|
||||
// is used again.
|
||||
reqExtra := dnsservertest.SectionExtra{
|
||||
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
|
||||
}
|
||||
filteredExtra := dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{
|
||||
InfoCode: dns.ExtendedErrorCodeFiltered,
|
||||
ExtraText: agdtest.SDEText,
|
||||
})
|
||||
|
||||
ansA := dnsservertest.NewA(testFQDN, agdtest.FilteredResponseTTLSec, testIPv4)
|
||||
ansAAAA := dnsservertest.NewAAAA(testFQDN, agdtest.FilteredResponseTTLSec, testIPv6)
|
||||
|
||||
testCases := []struct {
|
||||
blockingMode dnsmsg.BlockingMode
|
||||
name string
|
||||
wantAnsA []dns.RR
|
||||
wantAnsAAAA []dns.RR
|
||||
wantExtraA []dns.RR
|
||||
wantExtraAAAA []dns.RR
|
||||
}{{
|
||||
blockingMode: &dnsmsg.BlockingModeCustomIP{
|
||||
IPv4: []netip.Addr{testIPv4},
|
||||
IPv6: []netip.Addr{testIPv6},
|
||||
},
|
||||
name: "both",
|
||||
wantAnsA: []dns.RR{ansA},
|
||||
wantAnsAAAA: []dns.RR{ansAAAA},
|
||||
wantExtraA: nil,
|
||||
wantExtraAAAA: nil,
|
||||
}, {
|
||||
blockingMode: &dnsmsg.BlockingModeCustomIP{
|
||||
IPv4: []netip.Addr{testIPv4},
|
||||
},
|
||||
name: "ipv4_only",
|
||||
wantAnsA: []dns.RR{ansA},
|
||||
wantAnsAAAA: nil,
|
||||
wantExtraA: nil,
|
||||
wantExtraAAAA: []dns.RR{filteredExtra},
|
||||
}, {
|
||||
blockingMode: &dnsmsg.BlockingModeCustomIP{
|
||||
IPv6: []netip.Addr{testIPv6},
|
||||
},
|
||||
name: "ipv6_only",
|
||||
wantAnsA: nil,
|
||||
wantAnsAAAA: []dns.RR{ansAAAA},
|
||||
wantExtraA: []dns.RR{filteredExtra},
|
||||
wantExtraAAAA: nil,
|
||||
}, {
|
||||
blockingMode: &dnsmsg.BlockingModeCustomIP{
|
||||
IPv4: []netip.Addr{},
|
||||
IPv6: []netip.Addr{},
|
||||
},
|
||||
name: "empty",
|
||||
wantAnsA: nil,
|
||||
wantAnsAAAA: nil,
|
||||
wantExtraA: []dns.RR{filteredExtra},
|
||||
wantExtraAAAA: []dns.RR{filteredExtra},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
|
||||
Cloner: cloner,
|
||||
BlockingMode: tc.blockingMode,
|
||||
StructuredErrors: agdtest.NewSDEConfig(true),
|
||||
FilteredResponseTTL: agdtest.FilteredResponseTTL,
|
||||
EDEEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("a", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET, reqExtra)
|
||||
resp, respErr := msgs.NewBlockedResp(req)
|
||||
require.NoError(t, respErr)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
assert.Equal(t, tc.wantAnsA, resp.Answer)
|
||||
assert.Equal(t, tc.wantExtraA, resp.Extra)
|
||||
})
|
||||
|
||||
t.Run("aaaa", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := dnsservertest.NewReq(testFQDN, dns.TypeAAAA, dns.ClassINET, reqExtra)
|
||||
resp, respErr := msgs.NewBlockedResp(req)
|
||||
require.NoError(t, respErr)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
assert.Equal(t, tc.wantAnsAAAA, resp.Answer)
|
||||
assert.Equal(t, tc.wantExtraAAAA, resp.Extra)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstructor_NewBlockedResp_nodata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET, dnsservertest.SectionExtra{
|
||||
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
|
||||
})
|
||||
cloner := agdtest.NewCloner()
|
||||
|
||||
wantExtra := []dns.RR{dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{
|
||||
InfoCode: dns.ExtendedErrorCodeFiltered,
|
||||
ExtraText: agdtest.SDEText,
|
||||
})}
|
||||
|
||||
testCases := []struct {
|
||||
blockingMode dnsmsg.BlockingMode
|
||||
name string
|
||||
rcode dnsmsg.RCode
|
||||
}{{
|
||||
blockingMode: &dnsmsg.BlockingModeNXDOMAIN{},
|
||||
name: "nxdomain",
|
||||
rcode: dns.RcodeNameError,
|
||||
}, {
|
||||
blockingMode: &dnsmsg.BlockingModeREFUSED{},
|
||||
name: "refused",
|
||||
rcode: dns.RcodeRefused,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
|
||||
Cloner: cloner,
|
||||
BlockingMode: tc.blockingMode,
|
||||
StructuredErrors: agdtest.NewSDEConfig(true),
|
||||
FilteredResponseTTL: agdtest.FilteredResponseTTL,
|
||||
EDEEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := msgs.NewBlockedResp(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
assert.Equal(t, tc.rcode, dnsmsg.RCode(resp.Rcode))
|
||||
assert.Empty(t, resp.Answer)
|
||||
|
||||
require.Len(t, resp.Ns, 1)
|
||||
|
||||
nsTTL := resp.Ns[0].Header().Ttl
|
||||
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), nsTTL)
|
||||
|
||||
assert.Equal(t, wantExtra, resp.Extra)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstructor_NewBlockedResp_sde(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
reqEDNS := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET, dnsservertest.SectionExtra{
|
||||
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
|
||||
})
|
||||
reqNoEDNS := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET)
|
||||
|
||||
wantAns := []dns.RR{
|
||||
dnsservertest.NewA(testFQDN, agdtest.FilteredResponseTTLSec, netip.IPv4Unspecified()),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
req *dns.Msg
|
||||
sde *dnsmsg.StructuredDNSErrorsConfig
|
||||
name string
|
||||
wantExtra []dns.RR
|
||||
ede bool
|
||||
}{{
|
||||
req: reqEDNS,
|
||||
sde: agdtest.NewSDEConfig(true),
|
||||
name: "ede_sde",
|
||||
wantExtra: []dns.RR{
|
||||
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{
|
||||
InfoCode: dns.ExtendedErrorCodeFiltered,
|
||||
ExtraText: agdtest.SDEText,
|
||||
}),
|
||||
},
|
||||
ede: true,
|
||||
}, {
|
||||
req: reqEDNS,
|
||||
sde: agdtest.NewSDEConfig(false),
|
||||
name: "ede_no_sde",
|
||||
wantExtra: []dns.RR{
|
||||
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{
|
||||
InfoCode: dns.ExtendedErrorCodeFiltered,
|
||||
}),
|
||||
},
|
||||
ede: true,
|
||||
}, {
|
||||
req: reqEDNS,
|
||||
sde: agdtest.NewSDEConfig(false),
|
||||
name: "no_ede",
|
||||
wantExtra: nil,
|
||||
ede: false,
|
||||
}, {
|
||||
req: reqNoEDNS,
|
||||
sde: agdtest.NewSDEConfig(true),
|
||||
name: "unsupported_ede_sde",
|
||||
wantExtra: nil,
|
||||
ede: true,
|
||||
}, {
|
||||
req: reqNoEDNS,
|
||||
sde: agdtest.NewSDEConfig(false),
|
||||
name: "unsupported_ede_no_sde",
|
||||
wantExtra: nil,
|
||||
ede: true,
|
||||
}, {
|
||||
req: reqNoEDNS,
|
||||
sde: agdtest.NewSDEConfig(false),
|
||||
name: "unsupported_no_ede",
|
||||
wantExtra: nil,
|
||||
ede: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
|
||||
Cloner: agdtest.NewCloner(),
|
||||
BlockingMode: &dnsmsg.BlockingModeNullIP{},
|
||||
StructuredErrors: tc.sde,
|
||||
FilteredResponseTTL: agdtest.FilteredResponseTTL,
|
||||
EDEEnabled: tc.ede,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := msgs.NewBlockedResp(tc.req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
assert.Equal(t, wantAns, resp.Answer)
|
||||
assert.Equal(t, tc.wantExtra, resp.Extra)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstructor_NewRespRCode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
msgs := agdtest.NewConstructor(t)
|
||||
req := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET, dnsservertest.SectionExtra{
|
||||
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
|
||||
})
|
||||
|
||||
for rcode, name := range dns.RcodeToString {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
resp := msgs.NewRespRCode(req, dnsmsg.RCode(rcode))
|
||||
require.NotNil(t, resp)
|
||||
require.Empty(t, resp.Answer)
|
||||
|
||||
assert.Equal(t, rcode, resp.Rcode)
|
||||
|
||||
require.Len(t, resp.Ns, 1)
|
||||
|
||||
nsTTL := resp.Ns[0].Header().Ttl
|
||||
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), nsTTL)
|
||||
|
||||
assert.Empty(t, resp.Extra)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstructor_NewRespTXT(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
msgs := agdtest.NewConstructor(t)
|
||||
|
||||
req := dnsservertest.NewReq(testFQDN, dns.TypeTXT, dns.ClassINET, dnsservertest.SectionExtra{
|
||||
dnsservertest.NewOPT(true, dns.MaxMsgSize, &dns.EDNS0_EDE{}),
|
||||
})
|
||||
tooLong := strings.Repeat("1", dnsmsg.MaxTXTStringLen+1)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErrMsg string
|
||||
strs []string
|
||||
}{{
|
||||
name: "success",
|
||||
wantErrMsg: "",
|
||||
strs: []string{"111"},
|
||||
}, {
|
||||
name: "success_many",
|
||||
wantErrMsg: "",
|
||||
strs: []string{"111", "222"},
|
||||
}, {
|
||||
name: "success_nil",
|
||||
wantErrMsg: "",
|
||||
strs: nil,
|
||||
}, {
|
||||
name: "success_empty",
|
||||
wantErrMsg: "",
|
||||
strs: []string{},
|
||||
}, {
|
||||
name: "too_long",
|
||||
wantErrMsg: "txt string at index 0: too long: got 256 bytes, max 255",
|
||||
strs: []string{tooLong},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
resp, respErr := msgs.NewRespTXT(req, tc.strs...)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, respErr)
|
||||
|
||||
if tc.wantErrMsg != "" {
|
||||
return
|
||||
}
|
||||
|
||||
require.NotNil(t, resp)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
|
||||
require.Len(t, resp.Answer, 1)
|
||||
|
||||
ans := resp.Answer[0]
|
||||
txt := testutil.RequireTypeAssert[*dns.TXT](t, ans)
|
||||
|
||||
assert.Equal(t, uint32(agdtest.FilteredResponseTTLSec), txt.Hdr.Ttl)
|
||||
assert.Equal(t, tc.strs, txt.Txt)
|
||||
|
||||
assert.Empty(t, resp.Extra)
|
||||
})
|
||||
}
|
||||
}
|
@ -140,3 +140,36 @@ func newTXT(c *Cloner, txt []string) (rr *dns.TXT) {
|
||||
|
||||
return rr
|
||||
}
|
||||
|
||||
// newOPT constructs a new resource record of type OPT, optionally using c to
|
||||
// allocate the structure.
|
||||
func newOPT(c *Cloner, udpSize uint16, doBit bool) (opt *dns.OPT) {
|
||||
if c == nil {
|
||||
opt = &dns.OPT{}
|
||||
} else {
|
||||
opt = c.opt.rr.Get()
|
||||
opt.Option = opt.Option[:0]
|
||||
}
|
||||
|
||||
opt.Hdr.Name = "."
|
||||
opt.Hdr.Rrtype = dns.TypeOPT
|
||||
opt.SetUDPSize(udpSize)
|
||||
opt.SetDo(doBit)
|
||||
|
||||
return opt
|
||||
}
|
||||
|
||||
// newEDNS0EDE constructs a new resource record of type EDNS0_EDE, optionally
|
||||
// using c to allocate the structure.
|
||||
func newEDNS0EDE(c *Cloner, infoCode uint16, extraText string) (opt *dns.EDNS0_EDE) {
|
||||
if c == nil {
|
||||
opt = &dns.EDNS0_EDE{}
|
||||
} else {
|
||||
opt = c.opt.ede.Get()
|
||||
}
|
||||
|
||||
opt.InfoCode = infoCode
|
||||
opt.ExtraText = extraText
|
||||
|
||||
return opt
|
||||
}
|
||||
|
143
internal/dnsmsg/structurederror.go
Normal file
143
internal/dnsmsg/structurederror.go
Normal file
@ -0,0 +1,143 @@
|
||||
package dnsmsg
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
)
|
||||
|
||||
// StructuredDNSErrorsConfig is the configuration structure for the experimental
|
||||
// Structured DNS Errors feature.
|
||||
//
|
||||
// See https://www.ietf.org/archive/id/draft-ietf-dnsop-structured-dns-error-09.html.
|
||||
//
|
||||
// TODO(a.garipov): Add sub-error?
|
||||
type StructuredDNSErrorsConfig struct {
|
||||
// Justification for this particular DNS filtering. It must not be empty.
|
||||
Justification string
|
||||
|
||||
// Organization is an optional description of the organization.
|
||||
Organization string
|
||||
|
||||
// Contact information for the DNS service. It must not be empty. All
|
||||
// items must not be nil and must be valid mailto, sips, or tel URLs.
|
||||
Contact []*url.URL
|
||||
|
||||
// Enabled, if true, enables the experimental Structured DNS Errors feature.
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// iJSON returns the I-JSON representation of this configuration. c must be
|
||||
// valid.
|
||||
func (c *StructuredDNSErrorsConfig) iJSON() (s string) {
|
||||
data := &structuredDNSErrorData{
|
||||
Justification: c.Justification,
|
||||
Organization: c.Organization,
|
||||
}
|
||||
|
||||
for _, cont := range c.Contact {
|
||||
data.Contact = append(data.Contact, cont.String())
|
||||
}
|
||||
|
||||
// The only error that could be returned here is a type error from JSON
|
||||
// encoding, and these should never happen.
|
||||
b := errors.Must(json.Marshal(data))
|
||||
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// structuredDNSErrorData is the structure for the JSON representation of the
|
||||
// SDE data.
|
||||
//
|
||||
// TODO(a.garipov): Add sub-error?
|
||||
type structuredDNSErrorData struct {
|
||||
Justification string `json:"j"`
|
||||
Organization string `json:"o,omitempty"`
|
||||
Contact []string `json:"c"`
|
||||
}
|
||||
|
||||
// forbiddenRanges contains the ranges of forbidden code points for structured
|
||||
// DNS errors according to the I-JSON specification.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc7493#section-2.1.
|
||||
var forbiddenRanges = []*unicode.RangeTable{unicode.Cs, unicode.Noncharacter_Code_Point}
|
||||
|
||||
// isSurrogateOrNonCharacter returns true if r is a surrogate or a non-character
|
||||
// code point.
|
||||
func isSurrogateOrNonCharacter(r rune) (ok bool) {
|
||||
return unicode.IsOneOf(forbiddenRanges, r)
|
||||
}
|
||||
|
||||
// validateSDEString returns an error if s contains a surrogate or a
|
||||
// non-character code point. It always returns nil for an empty string.
|
||||
func validateSDEString(s string) (err error) {
|
||||
if i := strings.IndexFunc(s, isSurrogateOrNonCharacter); i >= 0 {
|
||||
return fmt.Errorf("bad code point at index %d", i)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validate checks the configuration for errors.
|
||||
func (c *StructuredDNSErrorsConfig) validate(edeEnabled bool) (err error) {
|
||||
if c == nil {
|
||||
return errors.ErrNoValue
|
||||
}
|
||||
|
||||
if !c.Enabled {
|
||||
return nil
|
||||
} else if !edeEnabled {
|
||||
return errors.Error("ede must be enabled to enable sde")
|
||||
}
|
||||
|
||||
var errs []error
|
||||
if len(c.Contact) == 0 {
|
||||
err = fmt.Errorf("contact data: %w", errors.ErrEmptyValue)
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
for i, cont := range c.Contact {
|
||||
err = validateSDEContactURL(cont)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("contact data: at index %d: %w", i, err)
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if c.Justification == "" {
|
||||
err = fmt.Errorf("justification: %w", errors.ErrEmptyValue)
|
||||
errs = append(errs, err)
|
||||
} else if err = validateSDEString(c.Justification); err != nil {
|
||||
err = fmt.Errorf("justification: %w", err)
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
if err = validateSDEString(c.Organization); err != nil {
|
||||
err = fmt.Errorf("organization: %w", err)
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// validateSDEContactURL returns an error if u is not a valid SDE contact URL.
|
||||
// It doesn't check for bad code points in the URL since [url.URL.String]
|
||||
// escapes them.
|
||||
func validateSDEContactURL(u *url.URL) (err error) {
|
||||
if u == nil {
|
||||
return errors.ErrNoValue
|
||||
}
|
||||
|
||||
switch strings.ToLower(u.Scheme) {
|
||||
case "mailto", "sips", "tel":
|
||||
// TODO(a.garipov): Consider more thorough validations for each scheme.
|
||||
default:
|
||||
return fmt.Errorf("scheme: %w: %q", errors.ErrBadEnumValue, u.Scheme)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
18
internal/dnsserver/cache/cache.go
vendored
18
internal/dnsserver/cache/cache.go
vendored
@ -31,8 +31,8 @@ type Middleware struct {
|
||||
// cacheMinTTL is the minimum supported TTL for cache items.
|
||||
cacheMinTTL time.Duration
|
||||
|
||||
// useTTLOverride shows if the TTL overrides logic should be used.
|
||||
useTTLOverride bool
|
||||
// overrideTTL shows if the TTL overrides logic should be used.
|
||||
overrideTTL bool
|
||||
}
|
||||
|
||||
// MiddlewareConfig is the configuration structure for NewMiddleware.
|
||||
@ -49,8 +49,8 @@ type MiddlewareConfig struct {
|
||||
// MinTTL is the minimum supported TTL for cache items.
|
||||
MinTTL time.Duration
|
||||
|
||||
// UseTTLOverride shows if the TTL overrides logic should be used.
|
||||
UseTTLOverride bool
|
||||
// OverrideTTL shows if the TTL overrides logic should be used.
|
||||
OverrideTTL bool
|
||||
}
|
||||
|
||||
// NewMiddleware initializes a new LRU caching middleware. c must not be nil.
|
||||
@ -63,10 +63,10 @@ func NewMiddleware(c *MiddlewareConfig) (m *Middleware) {
|
||||
}
|
||||
|
||||
return &Middleware{
|
||||
metrics: metrics,
|
||||
cache: gcache.New(c.Size).LRU().Build(),
|
||||
cacheMinTTL: c.MinTTL,
|
||||
useTTLOverride: c.UseTTLOverride,
|
||||
metrics: metrics,
|
||||
cache: gcache.New(c.Size).LRU().Build(),
|
||||
cacheMinTTL: c.MinTTL,
|
||||
overrideTTL: c.OverrideTTL,
|
||||
}
|
||||
}
|
||||
|
||||
@ -150,7 +150,7 @@ func (m *Middleware) set(msg *dns.Msg) (err error) {
|
||||
}
|
||||
|
||||
exp := time.Duration(ttl) * time.Second
|
||||
if m.useTTLOverride && msg.Rcode != dns.RcodeServerFailure {
|
||||
if m.overrideTTL && msg.Rcode != dns.RcodeServerFailure {
|
||||
exp = max(exp, m.cacheMinTTL)
|
||||
setMinTTL(msg, uint32(exp.Seconds()))
|
||||
}
|
||||
|
6
internal/dnsserver/cache/cache_test.go
vendored
6
internal/dnsserver/cache/cache_test.go
vendored
@ -186,9 +186,9 @@ func TestMiddleware_Wrap(t *testing.T) {
|
||||
withCache := dnsserver.WithMiddlewares(
|
||||
handler,
|
||||
cache.NewMiddleware(&cache.MiddlewareConfig{
|
||||
Size: 100,
|
||||
MinTTL: minTTL,
|
||||
UseTTLOverride: tc.minTTL != nil,
|
||||
Size: 100,
|
||||
MinTTL: minTTL,
|
||||
OverrideTTL: tc.minTTL != nil,
|
||||
}),
|
||||
)
|
||||
|
||||
|
@ -2,6 +2,7 @@ package dnsservertest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
@ -13,9 +14,15 @@ import (
|
||||
// AnswerTTL is the default TTL of the test handler's answers.
|
||||
const AnswerTTL time.Duration = 100 * time.Second
|
||||
|
||||
// CreateTestHandler creates a [dnsserver.Handler] with the specified
|
||||
// NewDefaultHandler returns a simple handler that always returns a response
|
||||
// with a single A record.
|
||||
func NewDefaultHandler() (handler dnsserver.Handler) {
|
||||
return NewDefaultHandlerWithCount(1)
|
||||
}
|
||||
|
||||
// NewDefaultHandlerWithCount creates a [dnsserver.Handler] with the specified
|
||||
// parameters. All responses will have the [TestAnsTTL] TTL.
|
||||
func CreateTestHandler(recordsCount int) (h dnsserver.Handler) {
|
||||
func NewDefaultHandlerWithCount(recordsCount int) (h dnsserver.Handler) {
|
||||
f := func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) (err error) {
|
||||
// Check that necessary context keys are set.
|
||||
si := dnsserver.MustServerInfoFromContext(ctx)
|
||||
@ -49,8 +56,12 @@ func CreateTestHandler(recordsCount int) (h dnsserver.Handler) {
|
||||
return dnsserver.HandlerFunc(f)
|
||||
}
|
||||
|
||||
// DefaultHandler returns a simple handler that always returns a response with
|
||||
// a single A record.
|
||||
func DefaultHandler() (handler dnsserver.Handler) {
|
||||
return CreateTestHandler(1)
|
||||
// NewPanicHandler returns a DNS handler that panics with an error.
|
||||
func NewPanicHandler() (handler dnsserver.Handler) {
|
||||
f := func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) (err error) {
|
||||
// TODO(a.garipov): Add a helper for these kinds of errors to golibs.
|
||||
panic(fmt.Errorf("unexpected call to ServeDNS(%v, %v)", rw, req))
|
||||
}
|
||||
|
||||
return dnsserver.HandlerFunc(f)
|
||||
}
|
||||
|
@ -273,6 +273,22 @@ func NewNS(name string, ttl uint32, ns string) (rr dns.RR) {
|
||||
}
|
||||
}
|
||||
|
||||
// NewOPT constructs the new resource record of type OPT.
|
||||
func NewOPT(do bool, udpSize uint16, opts ...dns.EDNS0) (rr dns.RR) {
|
||||
opt := &dns.OPT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: ".",
|
||||
Rrtype: dns.TypeOPT,
|
||||
},
|
||||
Option: opts,
|
||||
}
|
||||
|
||||
opt.SetDo(do)
|
||||
opt.SetUDPSize(udpSize)
|
||||
|
||||
return opt
|
||||
}
|
||||
|
||||
// NewECSExtra constructs a new OPT RR for the extra section.
|
||||
func NewECSExtra(ip net.IP, fam uint16, mask, scope uint8) (extra dns.RR) {
|
||||
return &dns.OPT{
|
||||
@ -300,11 +316,8 @@ func NewEDNS0Padding(msgLen int, UDPBufferSize uint16) (extra dns.RR) {
|
||||
padLen := requestPaddingBlockSize - msgLen%requestPaddingBlockSize
|
||||
|
||||
// Truncate padding to fit in UDP buffer.
|
||||
if msgLen+padLen > int(UDPBufferSize) {
|
||||
padLen = int(UDPBufferSize) - msgLen
|
||||
if padLen < 0 {
|
||||
padLen = 0
|
||||
}
|
||||
if bufSzInt := int(UDPBufferSize); msgLen+padLen > bufSzInt {
|
||||
padLen = max(bufSzInt-msgLen, 0)
|
||||
}
|
||||
|
||||
return &dns.OPT{
|
||||
|
@ -23,7 +23,7 @@ func TestMain(m *testing.M) {
|
||||
const testTimeout = 1 * time.Second
|
||||
|
||||
func TestHandler_ServeDNS(t *testing.T) {
|
||||
srv, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
|
||||
srv, addr := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
|
||||
|
||||
// No-fallbacks handler.
|
||||
handler := forward.NewHandler(&forward.HandlerConfig{
|
||||
@ -47,7 +47,7 @@ func TestHandler_ServeDNS(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandler_ServeDNS_fallbackNetError(t *testing.T) {
|
||||
srv, _ := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
|
||||
srv, _ := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
|
||||
handler := forward.NewHandler(&forward.HandlerConfig{
|
||||
UpstreamsAddresses: []*forward.UpstreamPlainConfig{{
|
||||
Network: forward.NetworkAny,
|
||||
|
@ -20,7 +20,7 @@ func TestHandler_Refresh(t *testing.T) {
|
||||
var upstreamIsUp atomic.Bool
|
||||
var upstreamRequestsCount atomic.Int64
|
||||
|
||||
defaultHandler := dnsservertest.DefaultHandler()
|
||||
defaultHandler := dnsservertest.NewDefaultHandler()
|
||||
|
||||
// This handler writes an empty message if upstreamUp flag is false.
|
||||
handlerFunc := dnsserver.HandlerFunc(func(
|
||||
|
@ -36,7 +36,7 @@ func TestUpstreamPlain_Exchange(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
|
||||
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
|
||||
ups := forward.NewUpstreamPlain(&forward.UpstreamPlainConfig{
|
||||
Network: tc.network,
|
||||
Address: netip.MustParseAddrPort(addr),
|
||||
@ -65,7 +65,7 @@ func TestUpstreamPlain_Exchange_truncated(t *testing.T) {
|
||||
rw.LocalAddr(),
|
||||
rw.RemoteAddr(),
|
||||
)
|
||||
handler := dnsservertest.DefaultHandler()
|
||||
handler := dnsservertest.NewDefaultHandler()
|
||||
err = handler.ServeDNS(ctx, nrw, req)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -1,9 +1,9 @@
|
||||
module github.com/AdguardTeam/AdGuardDNS/internal/dnsserver
|
||||
|
||||
go 1.23.1
|
||||
go 1.23.2
|
||||
|
||||
require (
|
||||
github.com/AdguardTeam/golibs v0.28.0
|
||||
github.com/AdguardTeam/golibs v0.30.1
|
||||
github.com/ameshkov/dnscrypt/v2 v2.3.0
|
||||
github.com/ameshkov/dnsstamps v1.0.3
|
||||
github.com/bluele/gcache v0.0.2
|
||||
@ -11,12 +11,12 @@ require (
|
||||
github.com/miekg/dns v1.1.62
|
||||
github.com/panjf2000/ants/v2 v2.10.0
|
||||
github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible
|
||||
github.com/prometheus/client_golang v1.20.1
|
||||
github.com/quic-go/quic-go v0.47.0
|
||||
github.com/prometheus/client_golang v1.20.5
|
||||
github.com/quic-go/quic-go v0.48.1
|
||||
github.com/stretchr/testify v1.9.0
|
||||
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0
|
||||
golang.org/x/net v0.29.0
|
||||
golang.org/x/sys v0.25.0
|
||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
|
||||
golang.org/x/net v0.30.0
|
||||
golang.org/x/sys v0.26.0
|
||||
)
|
||||
|
||||
require (
|
||||
@ -26,22 +26,23 @@ require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
|
||||
github.com/google/pprof v0.0.0-20240929191954-255acd752d31 // indirect
|
||||
github.com/google/pprof v0.0.0-20241023014458-598669927662 // indirect
|
||||
github.com/klauspost/compress v1.17.11 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.20.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.55.0 // indirect
|
||||
github.com/prometheus/common v0.60.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/quic-go/qpack v0.5.1 // indirect
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
golang.org/x/crypto v0.27.0 // indirect
|
||||
go.uber.org/mock v0.5.0 // indirect
|
||||
golang.org/x/crypto v0.28.0 // indirect
|
||||
golang.org/x/mod v0.21.0 // indirect
|
||||
golang.org/x/sync v0.8.0 // indirect
|
||||
golang.org/x/text v0.18.0 // indirect
|
||||
golang.org/x/time v0.6.0 // indirect
|
||||
golang.org/x/tools v0.25.0 // indirect
|
||||
google.golang.org/protobuf v1.34.2 // indirect
|
||||
golang.org/x/text v0.19.0 // indirect
|
||||
golang.org/x/time v0.7.0 // indirect
|
||||
golang.org/x/tools v0.26.0 // indirect
|
||||
google.golang.org/protobuf v1.35.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
@ -1,4 +1,5 @@
|
||||
github.com/AdguardTeam/golibs v0.28.0 h1:SK1q8SqkkJ/61pp2abTmio90S4QpteYK9rtgROfnrb4=
|
||||
github.com/AdguardTeam/golibs v0.30.1 h1:/yv7dq2h7WXw/jTDxkE3FP9zHerRT+i03PZRHJX4fPU=
|
||||
github.com/AdguardTeam/golibs v0.30.1/go.mod h1:FkwcNQEJoGsgDGXcalrVa/4gWbE68KsmE2guXWtBQUE=
|
||||
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
|
||||
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
|
||||
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw=
|
||||
@ -25,10 +26,10 @@ github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1v
|
||||
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/pprof v0.0.0-20240929191954-255acd752d31 h1:LcRdQWywSgfi5jPsYZ1r2avbbs5IQ5wtyhMBCcokyo4=
|
||||
github.com/google/pprof v0.0.0-20240929191954-255acd752d31/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
|
||||
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
|
||||
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||
github.com/google/pprof v0.0.0-20241023014458-598669927662 h1:SKMkD83p7FwUqKmBsPdLHF5dNyxq3jOWwu9w9UyH5vA=
|
||||
github.com/google/pprof v0.0.0-20241023014458-598669927662/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
|
||||
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
|
||||
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
@ -47,18 +48,18 @@ github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible
|
||||
github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_golang v1.20.1 h1:IMJXHOD6eARkQpxo8KkhgEVFlBNm+nkrFUyGlIu7Na8=
|
||||
github.com/prometheus/client_golang v1.20.1/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
|
||||
github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y=
|
||||
github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
|
||||
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
|
||||
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
|
||||
github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc=
|
||||
github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8=
|
||||
github.com/prometheus/common v0.60.0 h1:+V9PAREWNvJMAuJ1x1BaWl9dewMW4YrHZQbx0sJNllA=
|
||||
github.com/prometheus/common v0.60.0/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
|
||||
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
|
||||
github.com/quic-go/quic-go v0.47.0 h1:yXs3v7r2bm1wmPTYNLKAAJTHMYkPEsfYJmTazXrCZ7Y=
|
||||
github.com/quic-go/quic-go v0.47.0/go.mod h1:3bCapYsJvXGZcipOHuu7plYtaV6tnF+z7wIFsU0WK9E=
|
||||
github.com/quic-go/quic-go v0.48.1 h1:y/8xmfWI9qmGTc+lBr4jKRUWLGSlSigv847ULJ4hYXA=
|
||||
github.com/quic-go/quic-go v0.48.1/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
@ -69,29 +70,29 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
|
||||
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
|
||||
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk=
|
||||
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY=
|
||||
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
|
||||
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
|
||||
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
|
||||
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
|
||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
|
||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
|
||||
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
|
||||
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
||||
golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
|
||||
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
|
||||
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
|
||||
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
|
||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
|
||||
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
|
||||
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224=
|
||||
golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
|
||||
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
|
||||
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE=
|
||||
golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg=
|
||||
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
|
||||
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
|
||||
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
||||
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
|
||||
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
|
||||
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
|
||||
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ=
|
||||
golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0=
|
||||
google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
|
||||
google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
|
@ -23,7 +23,7 @@ func TestCacheMetricsListener_integration_cache(t *testing.T) {
|
||||
})
|
||||
|
||||
handlerWithMiddleware := dnsserver.WithMiddlewares(
|
||||
dnsservertest.DefaultHandler(),
|
||||
dnsservertest.NewDefaultHandler(),
|
||||
cacheMiddleware,
|
||||
)
|
||||
|
||||
|
@ -18,7 +18,7 @@ import (
|
||||
// normal unit test, we create a forward handler, emulate a query and then
|
||||
// check if prom metrics were incremented.
|
||||
func TestForwardMetricsListener_integration_request(t *testing.T) {
|
||||
srv, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
|
||||
srv, addr := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
|
||||
|
||||
// Initialize a new forward.Handler and set the metrics listener.
|
||||
handler := forward.NewHandler(&forward.HandlerConfig{
|
||||
|
@ -19,16 +19,21 @@ import (
|
||||
// normal unit test, we create a cache middleware, emulate a query and then
|
||||
// check if prom metrics were incremented.
|
||||
func TestRateLimiterMetricsListener_integration_cache(t *testing.T) {
|
||||
rps := 5
|
||||
const (
|
||||
count = 5
|
||||
ivl = time.Second
|
||||
)
|
||||
|
||||
rl := ratelimit.NewBackoff(&ratelimit.BackoffConfig{
|
||||
Allowlist: ratelimit.NewDynamicAllowlist([]netip.Prefix{}, []netip.Prefix{}),
|
||||
Period: time.Minute,
|
||||
Duration: time.Minute,
|
||||
Count: uint(rps),
|
||||
Count: count,
|
||||
ResponseSizeEstimate: 1 * datasize.KB,
|
||||
IPv4RPS: uint(rps),
|
||||
IPv6RPS: uint(rps),
|
||||
IPv4Count: count,
|
||||
IPv4Interval: ivl,
|
||||
IPv6Count: count,
|
||||
IPv6Interval: ivl,
|
||||
RefuseANY: true,
|
||||
})
|
||||
rlMw, err := ratelimit.NewMiddleware(&ratelimit.MiddlewareConfig{
|
||||
@ -38,7 +43,7 @@ func TestRateLimiterMetricsListener_integration_cache(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
handlerWithMiddleware := dnsserver.WithMiddlewares(
|
||||
dnsservertest.DefaultHandler(),
|
||||
dnsservertest.NewDefaultHandler(),
|
||||
rlMw,
|
||||
)
|
||||
|
||||
@ -55,7 +60,7 @@ func TestRateLimiterMetricsListener_integration_cache(t *testing.T) {
|
||||
|
||||
err = handlerWithMiddleware.ServeDNS(ctx, nrw, req)
|
||||
require.NoError(t, err)
|
||||
if i < rps {
|
||||
if i < count {
|
||||
dnsservertest.RequireResponse(t, req, nrw.Msg(), 1, dns.RcodeSuccess, false)
|
||||
} else {
|
||||
require.Nil(t, nrw.Msg())
|
||||
|
@ -24,7 +24,7 @@ func TestServerMetricsListener_integration_requestLifetime(t *testing.T) {
|
||||
ConfigBase: dnsserver.ConfigBase{
|
||||
Name: "test",
|
||||
Addr: "127.0.0.1:0",
|
||||
Handler: dnsservertest.DefaultHandler(),
|
||||
Handler: dnsservertest.NewDefaultHandler(),
|
||||
Metrics: prometheus.NewServerMetricsListener(testNamespace),
|
||||
},
|
||||
}
|
||||
|
@ -36,19 +36,29 @@ type BackoffConfig struct {
|
||||
// as several responses.
|
||||
ResponseSizeEstimate datasize.ByteSize
|
||||
|
||||
// IPv4RPS is the maximum number of requests per second allowed from a
|
||||
// single subnet for IPv4 addresses. Any requests above this rate are
|
||||
// counted as the client's backoff count. RPS must be greater than zero.
|
||||
IPv4RPS uint
|
||||
// IPv4Count is the maximum number of requests per a specified interval
|
||||
// allowed from a single subnet for IPv4 addresses. Any requests above this
|
||||
// rate are counted as the client's backoff count. It must be greater than
|
||||
// zero.
|
||||
IPv4Count uint
|
||||
|
||||
// IPv4Interval is the time during which to count the number of requests
|
||||
// for IPv4 addresses.
|
||||
IPv4Interval time.Duration
|
||||
|
||||
// IPv4SubnetKeyLen is the length of the subnet prefix used to calculate
|
||||
// rate limiter bucket keys for IPv4 addresses. Must be greater than zero.
|
||||
IPv4SubnetKeyLen int
|
||||
|
||||
// IPv6RPS is the maximum number of requests per second allowed from a
|
||||
// single subnet for IPv6 addresses. Any requests above this rate are
|
||||
// counted as the client's backoff count. RPS must be greater than zero.
|
||||
IPv6RPS uint
|
||||
// IPv6Count is the maximum number of requests per a specified interval
|
||||
// allowed from a single subnet for IPv6 addresses. Any requests above this
|
||||
// rate are counted as the client's backoff count. It must be greater than
|
||||
// zero.
|
||||
IPv6Count uint
|
||||
|
||||
// IPv6Interval is the time during which to count the number of requests
|
||||
// for IPv6 addresses.
|
||||
IPv6Interval time.Duration
|
||||
|
||||
// IPv6SubnetKeyLen is the length of the subnet prefix used to calculate
|
||||
// rate limiter bucket keys for IPv6 addresses. Must be greater than zero.
|
||||
@ -75,9 +85,11 @@ type Backoff struct {
|
||||
allowlist Allowlist
|
||||
respSzEst datasize.ByteSize
|
||||
count uint
|
||||
ipv4rps uint
|
||||
ipv4Count uint
|
||||
ipv4Interval time.Duration
|
||||
ipv4SubnetKeyLen int
|
||||
ipv6rps uint
|
||||
ipv6Count uint
|
||||
ipv6Interval time.Duration
|
||||
ipv6SubnetKeyLen int
|
||||
refuseANY bool
|
||||
}
|
||||
@ -93,9 +105,11 @@ func NewBackoff(c *BackoffConfig) (l *Backoff) {
|
||||
allowlist: c.Allowlist,
|
||||
respSzEst: c.ResponseSizeEstimate,
|
||||
count: c.Count,
|
||||
ipv4rps: c.IPv4RPS,
|
||||
ipv4Count: c.IPv4Count,
|
||||
ipv4Interval: c.IPv4Interval,
|
||||
ipv4SubnetKeyLen: c.IPv4SubnetKeyLen,
|
||||
ipv6rps: c.IPv6RPS,
|
||||
ipv6Count: c.IPv6Count,
|
||||
ipv6Interval: c.IPv6Interval,
|
||||
ipv6SubnetKeyLen: c.IPv6SubnetKeyLen,
|
||||
refuseANY: c.RefuseANY,
|
||||
}
|
||||
@ -133,12 +147,12 @@ func (l *Backoff) IsRateLimited(
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
rps := l.ipv4rps
|
||||
count, ivl := l.ipv4Count, l.ipv4Interval
|
||||
if ip.Is6() {
|
||||
rps = l.ipv6rps
|
||||
count, ivl = l.ipv6Count, l.ipv6Interval
|
||||
}
|
||||
|
||||
return l.hasHitRateLimit(key, rps), false, nil
|
||||
return l.hasHitRateLimit(key, count, ivl), false, nil
|
||||
}
|
||||
|
||||
// validateAddr returns an error if addr is not a valid IPv4 or IPv6 address.
|
||||
@ -198,15 +212,15 @@ func (l *Backoff) incBackoff(key string) {
|
||||
l.hitCounters.SetDefault(key, counter)
|
||||
}
|
||||
|
||||
// hasHitRateLimit checks value for a subnet with rps as a maximum number
|
||||
// requests per second.
|
||||
func (l *Backoff) hasHitRateLimit(subnetIPStr string, rps uint) (ok bool) {
|
||||
// hasHitRateLimit checks if the value of requests for given subnet hit the
|
||||
// maximum count of requests per given interval.
|
||||
func (l *Backoff) hasHitRateLimit(subnetIPStr string, count uint, ivl time.Duration) (ok bool) {
|
||||
var r *RequestCounter
|
||||
rVal, ok := l.reqCounters.Get(subnetIPStr)
|
||||
if ok {
|
||||
r = rVal.(*RequestCounter)
|
||||
} else {
|
||||
r = NewRequestCounter(rps, time.Second)
|
||||
r = NewRequestCounter(count, ivl)
|
||||
l.reqCounters.SetDefault(subnetIPStr, r)
|
||||
}
|
||||
|
||||
|
@ -22,7 +22,10 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
|
||||
func TestRatelimitMiddleware(t *testing.T) {
|
||||
const rps = 10
|
||||
const (
|
||||
rps = 10
|
||||
ivl = time.Second
|
||||
)
|
||||
|
||||
persistent := []netip.Prefix{
|
||||
netip.MustParsePrefix("4.3.2.1/8"),
|
||||
@ -99,9 +102,11 @@ func TestRatelimitMiddleware(t *testing.T) {
|
||||
Duration: time.Minute,
|
||||
Count: rps,
|
||||
ResponseSizeEstimate: 128 * datasize.B,
|
||||
IPv4RPS: rps,
|
||||
IPv4Count: rps,
|
||||
IPv4Interval: ivl,
|
||||
IPv4SubnetKeyLen: 24,
|
||||
IPv6RPS: rps,
|
||||
IPv6Count: rps,
|
||||
IPv6Interval: ivl,
|
||||
IPv6SubnetKeyLen: 48,
|
||||
RefuseANY: true,
|
||||
})
|
||||
@ -112,7 +117,7 @@ func TestRatelimitMiddleware(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
withMw := dnsserver.WithMiddlewares(
|
||||
dnsservertest.CreateTestHandler(tc.respCount),
|
||||
dnsservertest.NewDefaultHandlerWithCount(tc.respCount),
|
||||
rlMw,
|
||||
)
|
||||
|
||||
|
@ -309,6 +309,10 @@ func (s *ServerBase) serveDNSMsgInternal(
|
||||
s.metrics.OnError(ctx, err)
|
||||
|
||||
resp = genErrorResponse(req, dns.RcodeServerFailure)
|
||||
if isNonCriticalNetError(err) {
|
||||
addEDE(req, resp, dns.ExtendedErrorCodeNetworkError, "")
|
||||
}
|
||||
|
||||
err = rw.WriteMsg(ctx, req, resp)
|
||||
if err != nil {
|
||||
log.Debug("[%d]: error writing a response: %s", req.Id, err)
|
||||
@ -316,6 +320,28 @@ func (s *ServerBase) serveDNSMsgInternal(
|
||||
}
|
||||
}
|
||||
|
||||
// addEDE adds an Extended DNS Error (EDE) option to the blocked response
|
||||
// message, if the request indicates EDNS support.
|
||||
func addEDE(req, resp *dns.Msg, code uint16, text string) {
|
||||
reqOpt := req.IsEdns0()
|
||||
if reqOpt == nil {
|
||||
// Requestor doesn't implement EDNS, see
|
||||
// https://datatracker.ietf.org/doc/html/rfc6891#section-7.
|
||||
return
|
||||
}
|
||||
|
||||
respOpt := resp.IsEdns0()
|
||||
if respOpt == nil {
|
||||
resp.SetEdns0(reqOpt.UDPSize(), reqOpt.Do())
|
||||
respOpt = resp.Extra[len(resp.Extra)-1].(*dns.OPT)
|
||||
}
|
||||
|
||||
respOpt.Option = append(respOpt.Option, &dns.EDNS0_EDE{
|
||||
InfoCode: code,
|
||||
ExtraText: text,
|
||||
})
|
||||
}
|
||||
|
||||
// acceptMsg checks if we should process the incoming DNS query.
|
||||
func (s *ServerBase) acceptMsg(m *dns.Msg) (action dns.MsgAcceptAction) {
|
||||
if m.Response {
|
||||
|
@ -37,7 +37,7 @@ func BenchmarkServeDNS(b *testing.B) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
_, addr := dnsservertest.RunDNSServer(b, dnsservertest.DefaultHandler())
|
||||
_, addr := dnsservertest.RunDNSServer(b, dnsservertest.NewDefaultHandler())
|
||||
|
||||
// Prepare a test message.
|
||||
m := new(dns.Msg)
|
||||
@ -105,7 +105,7 @@ func readMsg(resBuf []byte, network dnsserver.Network, conn net.Conn) (err error
|
||||
|
||||
func BenchmarkServeTLS(b *testing.B) {
|
||||
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
|
||||
addr := dnsservertest.RunTLSServer(b, dnsservertest.DefaultHandler(), tlsConfig)
|
||||
addr := dnsservertest.RunTLSServer(b, dnsservertest.NewDefaultHandler(), tlsConfig)
|
||||
|
||||
// Prepare a test message
|
||||
m := new(dns.Msg)
|
||||
@ -171,7 +171,7 @@ func BenchmarkServeDoH(b *testing.B) {
|
||||
for _, tc := range testCases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
srv, err := dnsservertest.RunLocalHTTPSServer(
|
||||
dnsservertest.DefaultHandler(),
|
||||
dnsservertest.NewDefaultHandler(),
|
||||
tc.tlsConfig,
|
||||
nil,
|
||||
)
|
||||
@ -250,7 +250,7 @@ func BenchmarkServeDNSCrypt(b *testing.B) {
|
||||
Net: string(tc.network),
|
||||
}
|
||||
|
||||
s := dnsservertest.RunDNSCryptServer(b, dnsservertest.DefaultHandler())
|
||||
s := dnsservertest.RunDNSCryptServer(b, dnsservertest.NewDefaultHandler())
|
||||
stamp := dnsstamps.ServerStamp{
|
||||
ServerAddrStr: s.ServerAddr,
|
||||
ServerPk: s.ResolverPk,
|
||||
@ -282,7 +282,7 @@ func BenchmarkServeDNSCrypt(b *testing.B) {
|
||||
func BenchmarkServeQUIC(b *testing.B) {
|
||||
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
|
||||
srv, addr, err := dnsservertest.RunLocalQUICServer(
|
||||
dnsservertest.DefaultHandler(),
|
||||
dnsservertest.NewDefaultHandler(),
|
||||
tlsConfig,
|
||||
)
|
||||
require.NoError(b, err)
|
||||
|
@ -20,7 +20,7 @@ import (
|
||||
)
|
||||
|
||||
func TestServerDNS_StartShutdown(t *testing.T) {
|
||||
_, _ = dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
|
||||
_, _ = dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
|
||||
}
|
||||
|
||||
func TestServerDNS_integration_query(t *testing.T) {
|
||||
@ -42,7 +42,7 @@ func TestServerDNS_integration_query(t *testing.T) {
|
||||
{Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
},
|
||||
},
|
||||
handler: dnsservertest.DefaultHandler(),
|
||||
handler: dnsservertest.NewDefaultHandler(),
|
||||
wantRecordsCount: 1,
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
}, {
|
||||
@ -54,7 +54,7 @@ func TestServerDNS_integration_query(t *testing.T) {
|
||||
{Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
},
|
||||
},
|
||||
handler: dnsservertest.DefaultHandler(),
|
||||
handler: dnsservertest.NewDefaultHandler(),
|
||||
wantRecordsCount: 1,
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
}, {
|
||||
@ -89,7 +89,7 @@ func TestServerDNS_integration_query(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
handler: dnsservertest.DefaultHandler(),
|
||||
handler: dnsservertest.NewDefaultHandler(),
|
||||
wantMsg: func(t *testing.T, m *dns.Msg) {
|
||||
opt := m.IsEdns0()
|
||||
require.NotNil(t, opt)
|
||||
@ -109,7 +109,7 @@ func TestServerDNS_integration_query(t *testing.T) {
|
||||
{Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
},
|
||||
},
|
||||
handler: dnsservertest.DefaultHandler(),
|
||||
handler: dnsservertest.NewDefaultHandler(),
|
||||
wantRecordsCount: 0,
|
||||
wantRCode: dns.RcodeFormatError,
|
||||
}, {
|
||||
@ -122,7 +122,7 @@ func TestServerDNS_integration_query(t *testing.T) {
|
||||
{Name: "eXaMplE.oRg.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
},
|
||||
},
|
||||
handler: dnsservertest.DefaultHandler(),
|
||||
handler: dnsservertest.NewDefaultHandler(),
|
||||
wantRecordsCount: 1,
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
}, {
|
||||
@ -136,7 +136,7 @@ func TestServerDNS_integration_query(t *testing.T) {
|
||||
{Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
},
|
||||
},
|
||||
handler: dnsservertest.DefaultHandler(),
|
||||
handler: dnsservertest.NewDefaultHandler(),
|
||||
wantRecordsCount: 0,
|
||||
wantRCode: dns.RcodeNotImplemented,
|
||||
}, {
|
||||
@ -170,7 +170,7 @@ func TestServerDNS_integration_query(t *testing.T) {
|
||||
{Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
},
|
||||
},
|
||||
handler: dnsservertest.DefaultHandler(),
|
||||
handler: dnsservertest.NewDefaultHandler(),
|
||||
wantRecordsCount: 1,
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
}, {
|
||||
@ -185,7 +185,7 @@ func TestServerDNS_integration_query(t *testing.T) {
|
||||
},
|
||||
},
|
||||
// Set a handler that generates a large response
|
||||
handler: dnsservertest.CreateTestHandler(64),
|
||||
handler: dnsservertest.NewDefaultHandlerWithCount(64),
|
||||
wantRecordsCount: 0,
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantTruncated: true,
|
||||
@ -210,7 +210,7 @@ func TestServerDNS_integration_query(t *testing.T) {
|
||||
},
|
||||
},
|
||||
// Set a handler that generates a large response
|
||||
handler: dnsservertest.CreateTestHandler(64),
|
||||
handler: dnsservertest.NewDefaultHandlerWithCount(64),
|
||||
wantRecordsCount: 64,
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantTruncated: false,
|
||||
@ -226,7 +226,7 @@ func TestServerDNS_integration_query(t *testing.T) {
|
||||
},
|
||||
},
|
||||
// Set a handler that generates a large response
|
||||
handler: dnsservertest.CreateTestHandler(64),
|
||||
handler: dnsservertest.NewDefaultHandlerWithCount(64),
|
||||
// No truncate
|
||||
wantRecordsCount: 64,
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
@ -257,7 +257,7 @@ func TestServerDNS_integration_query(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
handler: dnsservertest.DefaultHandler(),
|
||||
handler: dnsservertest.NewDefaultHandler(),
|
||||
wantRecordsCount: 1,
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
}}
|
||||
@ -304,7 +304,7 @@ func TestServerDNS_integration_tcpQueriesPipelining(t *testing.T) {
|
||||
// As per RFC 7766 we should support queries pipelining for TCP, that is
|
||||
// server must be able to process incoming queries in parallel and write
|
||||
// responses possibly out of order within the same connection.
|
||||
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
|
||||
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
|
||||
|
||||
// Establish a connection.
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
@ -372,7 +372,7 @@ func TestServerDNS_integration_tcpQueriesPipelining(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServerDNS_integration_udpMsgIgnore(t *testing.T) {
|
||||
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
|
||||
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
|
||||
conn, err := net.Dial("udp", addr)
|
||||
require.Nil(t, err)
|
||||
|
||||
@ -469,7 +469,7 @@ func TestServerDNS_integration_tcpMsgIgnore(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
|
||||
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.NewDefaultHandler())
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
require.Nil(t, err)
|
||||
|
||||
|
@ -49,7 +49,7 @@ func TestServerDNSCrypt_integration_query(t *testing.T) {
|
||||
name: "udp_truncate_response",
|
||||
network: dnsserver.NetworkUDP,
|
||||
// Set a handler that generates a large response
|
||||
handler: dnsservertest.CreateTestHandler(64),
|
||||
handler: dnsservertest.NewDefaultHandlerWithCount(64),
|
||||
// DNSCrypt server removes all records from a truncated response
|
||||
expectedRecordsCount: 0,
|
||||
expectedRCode: dns.RcodeSuccess,
|
||||
@ -66,7 +66,7 @@ func TestServerDNSCrypt_integration_query(t *testing.T) {
|
||||
name: "udp_edns0_no_truncate",
|
||||
network: dnsserver.NetworkUDP,
|
||||
// Set a handler that generates a large response
|
||||
handler: dnsservertest.CreateTestHandler(64),
|
||||
handler: dnsservertest.NewDefaultHandlerWithCount(64),
|
||||
expectedRecordsCount: 64,
|
||||
expectedRCode: dns.RcodeSuccess,
|
||||
expectedTruncated: false,
|
||||
@ -89,7 +89,7 @@ func TestServerDNSCrypt_integration_query(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
handler := tc.handler
|
||||
if tc.handler == nil {
|
||||
handler = dnsservertest.DefaultHandler()
|
||||
handler = dnsservertest.NewDefaultHandler()
|
||||
}
|
||||
|
||||
s := dnsservertest.RunDNSCryptServer(t, handler)
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
@ -304,9 +305,9 @@ func (s *ServerHTTPS) serveHTTPS(ctx context.Context, hs *http.Server, l net.Lis
|
||||
// application won't be able to continue listening to DoH.
|
||||
defer s.handlePanicAndExit(ctx)
|
||||
|
||||
scheme := "https"
|
||||
scheme := urlutil.SchemeHTTPS
|
||||
if s.conf.TLSConfig == nil {
|
||||
scheme = "http"
|
||||
scheme = urlutil.SchemeHTTP
|
||||
}
|
||||
|
||||
u := &url.URL{
|
||||
|
@ -109,7 +109,7 @@ func TestServerHTTPS_integration_serveRequests(t *testing.T) {
|
||||
|
||||
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
|
||||
srv, err := dnsservertest.RunLocalHTTPSServer(
|
||||
dnsservertest.DefaultHandler(),
|
||||
dnsservertest.NewDefaultHandler(),
|
||||
tlsConfig,
|
||||
nil,
|
||||
)
|
||||
@ -146,7 +146,7 @@ func TestServerHTTPS_integration_nonDNSHandler(t *testing.T) {
|
||||
})
|
||||
|
||||
srv, err := dnsservertest.RunLocalHTTPSServer(
|
||||
dnsservertest.DefaultHandler(),
|
||||
dnsservertest.NewDefaultHandler(),
|
||||
nil,
|
||||
testHandler,
|
||||
)
|
||||
@ -321,7 +321,7 @@ func TestDNSMsgToJSONMsg(t *testing.T) {
|
||||
func TestServerHTTPS_integration_ENDS0Padding(t *testing.T) {
|
||||
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
|
||||
srv, err := dnsservertest.RunLocalHTTPSServer(
|
||||
dnsservertest.DefaultHandler(),
|
||||
dnsservertest.NewDefaultHandler(),
|
||||
tlsConfig,
|
||||
nil,
|
||||
)
|
||||
@ -346,7 +346,7 @@ func TestServerHTTPS_integration_ENDS0Padding(t *testing.T) {
|
||||
func TestServerHTTPS_0RTT(t *testing.T) {
|
||||
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
|
||||
srv, err := dnsservertest.RunLocalHTTPSServer(
|
||||
dnsservertest.DefaultHandler(),
|
||||
dnsservertest.NewDefaultHandler(),
|
||||
tlsConfig,
|
||||
nil,
|
||||
)
|
||||
@ -514,7 +514,7 @@ func createDoH3Client(
|
||||
tlsConfig = tlsConfig.Clone()
|
||||
tlsConfig.NextProtos = []string{http3.NextProtoH3}
|
||||
|
||||
transport := &http3.RoundTripper{
|
||||
transport := &http3.Transport{
|
||||
DisableCompression: true,
|
||||
Dial: func(
|
||||
ctx context.Context,
|
||||
|
@ -7,13 +7,18 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// JSONMsg represents a *dns.Msg in the JSON format defined here:
|
||||
// https://developers.google.com/speed/public-dns/docs/doh/json#dns_response_in_json
|
||||
// Note, that we do not implement some parts of it. There is no "Comment" field
|
||||
// and there's no "edns_client_subnet".
|
||||
//
|
||||
// NOTE: This API differs from the Google one in the following ways:
|
||||
// 1. The "Comment" field is not implemented.
|
||||
// 2. The "edns_client_subnet" query parameter is not supported.
|
||||
// 3. The "sde" query parameter is added and supported for the experimental
|
||||
// Structured DNS Errors feature.
|
||||
type JSONMsg struct {
|
||||
Question []JSONQuestion `json:"Question"`
|
||||
Answer []JSONAnswer `json:"Answer"`
|
||||
@ -26,13 +31,13 @@ type JSONMsg struct {
|
||||
Status int `json:"Status"`
|
||||
}
|
||||
|
||||
// JSONQuestion is a part of JSONMsg definition.
|
||||
// JSONQuestion is a part of [JSONMsg] definition.
|
||||
type JSONQuestion struct {
|
||||
Name string `json:"name"`
|
||||
Type uint16 `json:"type"`
|
||||
}
|
||||
|
||||
// JSONAnswer is a part of JSONMsg definition.
|
||||
// JSONAnswer is a part of [JSONMsg] definition.
|
||||
type JSONAnswer struct {
|
||||
Name string `json:"name"`
|
||||
Data string `json:"data"`
|
||||
@ -41,7 +46,7 @@ type JSONAnswer struct {
|
||||
Class uint16 `json:"class"`
|
||||
}
|
||||
|
||||
// DNSMsgToJSONMsg converts the *dns.Msg to the JSON format (*JSONMsg).
|
||||
// DNSMsgToJSONMsg converts the *dns.Msg to the JSON format.
|
||||
func DNSMsgToJSONMsg(m *dns.Msg) (msg *JSONMsg) {
|
||||
msg = &JSONMsg{
|
||||
Status: m.Rcode,
|
||||
@ -74,9 +79,9 @@ func DNSMsgToJSONMsg(m *dns.Msg) (msg *JSONMsg) {
|
||||
func rrToJSON(rr dns.RR) (j JSONAnswer) {
|
||||
hdr := rr.Header()
|
||||
|
||||
// Extracting the RR value is a bit tricky since miekg/dns does not
|
||||
// expose the necessary methods. This way we can benefit from the
|
||||
// proper string serialization code that's used inside miekg/dns.
|
||||
// Extracting the RR value is a bit tricky since miekg/dns does not expose
|
||||
// the necessary methods. This way we can benefit from the proper string
|
||||
// serialization code that's used inside miekg/dns.
|
||||
hdrStr := hdr.String()
|
||||
valStr := rr.String()
|
||||
data := strings.TrimLeft(strings.TrimPrefix(valStr, hdrStr), " ")
|
||||
@ -90,19 +95,17 @@ func rrToJSON(rr dns.RR) (j JSONAnswer) {
|
||||
}
|
||||
}
|
||||
|
||||
// dnsMsgToJSON converts the *dns.Msg to the JSON format (JSONMsg) and returns
|
||||
// it in the serialized form.
|
||||
// dnsMsgToJSON converts the *dns.Msg to the JSON format and returns it in the
|
||||
// serialized form.
|
||||
func dnsMsgToJSON(m *dns.Msg) (b []byte, err error) {
|
||||
msg := DNSMsgToJSONMsg(m)
|
||||
return json.Marshal(msg)
|
||||
return json.Marshal(DNSMsgToJSONMsg(m))
|
||||
}
|
||||
|
||||
// httpRequestToMsgJSON builds a DNS message from the request parameters.
|
||||
// We use the same parameters as the ones defined here:
|
||||
// https://developers.google.com/speed/public-dns/docs/doh/json#supported_parameters
|
||||
// Some parameters are not supported: "ct", "edns_client_subnet".
|
||||
func httpRequestToMsgJSON(req *http.Request) (b []byte, err error) {
|
||||
q := req.URL.Query()
|
||||
//
|
||||
// See [JSONMsg].
|
||||
func httpRequestToMsgJSON(httpReq *http.Request) (b []byte, err error) {
|
||||
q := httpReq.URL.Query()
|
||||
|
||||
// Query name, the only required parameter.
|
||||
name := q.Get("name")
|
||||
@ -111,71 +114,91 @@ func httpRequestToMsgJSON(req *http.Request) (b []byte, err error) {
|
||||
return nil, ErrInvalidArgument
|
||||
}
|
||||
|
||||
// RR type can be represented as a number in [1, 65535] or a
|
||||
// canonical string (case-insensitive, such as A or AAAA).
|
||||
var t uint16
|
||||
t, err = urlQueryParameterToUint16(q, "type", dns.TypeA, dns.StringToType)
|
||||
// RR type can be represented as a number in [1, 65535] or a canonical
|
||||
// string (case-insensitive, such as A or AAAA).
|
||||
qt, err := urlQueryParameterToUint16(q, "type", dns.TypeA, dns.StringToType)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Query class can be represented as a number in [1, 65535] or a
|
||||
// canonical string (case-insensitive).
|
||||
var qc uint16
|
||||
qc, err = urlQueryParameterToUint16(q, "qc", dns.ClassINET, dns.StringToClass)
|
||||
// Query class can be represented as a number in [1, 65535] or a canonical
|
||||
// string (case-insensitive).
|
||||
qc, err := urlQueryParameterToUint16(q, "qc", dns.ClassINET, dns.StringToClass)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// The CD (Checking Disabled) flag. Use cd=1, or cd=true to disable
|
||||
// DNSSEC validation; use cd=0, cd=false, or no cd parameter to
|
||||
// enable DNSSEC validation.
|
||||
var cd bool
|
||||
cd, err = urlQueryParameterToBoolean(q, "cd", false)
|
||||
// The CD (Checking Disabled) flag. Use cd=1, or cd=true to disable DNSSEC
|
||||
// validation; use cd=0, cd=false, or no cd parameter to enable DNSSEC
|
||||
// validation.
|
||||
cd, err := urlQueryParameterToBoolean(q, "cd", false)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// The DO (DNSSEC OK) flag. Use do=1, or do=true to include DNSSEC
|
||||
// records (RRSIG, NSEC, NSEC3); use do=0, do=false, or no do parameter
|
||||
// to omit DNSSEC records.
|
||||
var do bool
|
||||
do, err = urlQueryParameterToBoolean(q, "do", false)
|
||||
// The DO (DNSSEC OK) flag. Use do=1 (or do=true) to include DNSSEC records
|
||||
// (RRSIG, NSEC, NSEC3); use do=0 (do=false) or no do parameter to omit
|
||||
// DNSSEC records.
|
||||
do, err := urlQueryParameterToBoolean(q, "do", false)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// The experimental Structured DNS Errors feature.
|
||||
sde, err := urlQueryParameterToBoolean(q, "sde", false)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Now build a DNS message with all those parameters
|
||||
r := &dns.Msg{
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: dns.Id(),
|
||||
CheckingDisabled: cd,
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []dns.Question{
|
||||
{
|
||||
Name: dns.Fqdn(name),
|
||||
Qtype: t,
|
||||
Qclass: qc,
|
||||
},
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: dns.Fqdn(name),
|
||||
Qtype: qt,
|
||||
Qclass: qc,
|
||||
}},
|
||||
}
|
||||
|
||||
if do {
|
||||
r.SetEdns0(dns.MaxMsgSize, do)
|
||||
setEDNSFromQuery(req, do, sde)
|
||||
|
||||
return req.Pack()
|
||||
}
|
||||
|
||||
// setEDNSFromQuery sets the EDNS parameters on the request depending on the
|
||||
// query parameters.
|
||||
func setEDNSFromQuery(req *dns.Msg, do, sde bool) {
|
||||
if !do && !sde {
|
||||
return
|
||||
}
|
||||
|
||||
return r.Pack()
|
||||
req.SetEdns0(dns.MaxMsgSize, do)
|
||||
|
||||
if sde {
|
||||
opt := req.Extra[0].(*dns.OPT)
|
||||
opt.Option = append(opt.Option, &dns.EDNS0_EDE{})
|
||||
}
|
||||
}
|
||||
|
||||
// urlQueryParameterToUint16 is a helper function that extracts a uint16 value
|
||||
// from a query parameter. See httpRequestToMsgJSON to see how it's used.
|
||||
// from a query parameter.
|
||||
func urlQueryParameterToUint16(
|
||||
q url.Values,
|
||||
name string,
|
||||
defaultValue uint16,
|
||||
strValuesMap map[string]uint16,
|
||||
) (v uint16, err error) {
|
||||
defer func() { err = errors.Annotate(err, "parameter %q: %w", name) }()
|
||||
|
||||
strValue := q.Get(name)
|
||||
uintValue, convErr := strconv.ParseUint(strValue, 10, 16)
|
||||
switch {
|
||||
@ -199,12 +222,8 @@ func urlQueryParameterToUint16(
|
||||
}
|
||||
|
||||
// urlQueryParameterToBoolean is a helper function that extracts a boolean value
|
||||
// from a query parameter. See httpRequestToMsgJSON to see how it's used.
|
||||
func urlQueryParameterToBoolean(
|
||||
q url.Values,
|
||||
name string,
|
||||
defaultValue bool,
|
||||
) (v bool, err error) {
|
||||
// from a query parameter.
|
||||
func urlQueryParameterToBoolean(q url.Values, name string, defaultValue bool) (v bool, err error) {
|
||||
strValue := q.Get(name)
|
||||
switch strValue {
|
||||
case "1", "true", "True":
|
||||
|
@ -25,7 +25,7 @@ import (
|
||||
func TestServerQUIC_integration_query(t *testing.T) {
|
||||
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
|
||||
srv, addr, err := dnsservertest.RunLocalQUICServer(
|
||||
dnsservertest.DefaultHandler(),
|
||||
dnsservertest.NewDefaultHandler(),
|
||||
tlsConfig,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
@ -74,7 +74,7 @@ func TestServerQUIC_integration_query(t *testing.T) {
|
||||
func TestServerQUIC_integration_ENDS0Padding(t *testing.T) {
|
||||
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
|
||||
srv, addr, err := dnsservertest.RunLocalQUICServer(
|
||||
dnsservertest.DefaultHandler(),
|
||||
dnsservertest.NewDefaultHandler(),
|
||||
tlsConfig,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
@ -108,7 +108,7 @@ func TestServerQUIC_integration_ENDS0Padding(t *testing.T) {
|
||||
func TestServerQUIC_integration_0RTT(t *testing.T) {
|
||||
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
|
||||
srv, addr, err := dnsservertest.RunLocalQUICServer(
|
||||
dnsservertest.DefaultHandler(),
|
||||
dnsservertest.NewDefaultHandler(),
|
||||
tlsConfig,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
@ -146,7 +146,7 @@ func TestServerQUIC_integration_0RTT(t *testing.T) {
|
||||
func TestServerQUIC_integration_largeQuery(t *testing.T) {
|
||||
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
|
||||
srv, addr, err := dnsservertest.RunLocalQUICServer(
|
||||
dnsservertest.DefaultHandler(),
|
||||
dnsservertest.NewDefaultHandler(),
|
||||
tlsConfig,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
@ -17,7 +17,7 @@ import (
|
||||
|
||||
func TestServerTLS_integration_queryTLS(t *testing.T) {
|
||||
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
|
||||
addr := dnsservertest.RunTLSServer(t, dnsservertest.DefaultHandler(), tlsConfig)
|
||||
addr := dnsservertest.RunTLSServer(t, dnsservertest.NewDefaultHandler(), tlsConfig)
|
||||
|
||||
// Create a test message.
|
||||
req := new(dns.Msg)
|
||||
@ -94,7 +94,7 @@ func TestServerTLS_integration_msgIgnore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
|
||||
h := dnsservertest.DefaultHandler()
|
||||
h := dnsservertest.NewDefaultHandler()
|
||||
addr := dnsservertest.RunTLSServer(t, h, tlsConfig)
|
||||
|
||||
conn, err := tls.Dial("tcp", addr.String(), tlsConfig)
|
||||
@ -120,7 +120,7 @@ func TestServerTLS_integration_msgIgnore(t *testing.T) {
|
||||
func TestServerTLS_integration_noTruncateQuery(t *testing.T) {
|
||||
// Handler that writes a huge response which would not fit
|
||||
// into a UDP response, but it should fit a TCP response just okay.
|
||||
handler := dnsservertest.CreateTestHandler(64)
|
||||
handler := dnsservertest.NewDefaultHandlerWithCount(64)
|
||||
|
||||
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
|
||||
addr := dnsservertest.RunTLSServer(t, handler, tlsConfig)
|
||||
@ -155,7 +155,7 @@ func TestServerTLS_integration_queriesPipelining(t *testing.T) {
|
||||
// i.e. we should be able to process incoming queries in parallel and
|
||||
// write responses out of order.
|
||||
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
|
||||
addr := dnsservertest.RunTLSServer(t, dnsservertest.DefaultHandler(), tlsConfig)
|
||||
addr := dnsservertest.RunTLSServer(t, dnsservertest.NewDefaultHandler(), tlsConfig)
|
||||
|
||||
// First - establish a connection
|
||||
conn, err := tls.Dial("tcp", addr.String(), tlsConfig)
|
||||
@ -221,7 +221,7 @@ func TestServerTLS_integration_queriesPipelining(t *testing.T) {
|
||||
|
||||
func TestServerTLS_integration_ENDS0Padding(t *testing.T) {
|
||||
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
|
||||
addr := dnsservertest.RunTLSServer(t, dnsservertest.DefaultHandler(), tlsConfig)
|
||||
addr := dnsservertest.RunTLSServer(t, dnsservertest.NewDefaultHandler(), tlsConfig)
|
||||
|
||||
req := dnsservertest.CreateMessage("example.org.", dns.TypeA)
|
||||
req.Extra = []dns.RR{dnsservertest.NewEDNS0Padding(req.Len(), dns.DefaultMsgSize)}
|
||||
|
228
internal/dnssvc/config.go
Normal file
228
internal/dnssvc/config.go
Normal file
@ -0,0 +1,228 @@
|
||||
package dnssvc
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/access"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdcache"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/cmd/plugin"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/connlimiter"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/rulestat"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// Config is the configuration of the AdGuard DNS service.
|
||||
type Config struct {
|
||||
// Handlers are the handlers to use in this DNS service.
|
||||
Handlers Handlers
|
||||
|
||||
// NewListener, when set, is used instead of the package-level function
|
||||
// [NewListener] when creating a DNS listener.
|
||||
//
|
||||
// TODO(a.garipov): This is only used for tests. Replace with a
|
||||
// [netext.ListenConfig].
|
||||
NewListener NewListenerFunc
|
||||
|
||||
// Cloner is used to clone messages more efficiently by disposing of parts
|
||||
// of DNS responses for later reuse. It must not be nil.
|
||||
Cloner *dnsmsg.Cloner
|
||||
|
||||
// ControlConf is the configuration of socket options.
|
||||
ControlConf *netext.ControlConfig
|
||||
|
||||
// ConnLimiter, if not nil, is used to limit the number of simultaneously
|
||||
// active stream-connections.
|
||||
ConnLimiter *connlimiter.Limiter
|
||||
|
||||
// ErrColl is the error collector that is used to collect critical and
|
||||
// non-critical errors. It must not be nil.
|
||||
ErrColl errcoll.Interface
|
||||
|
||||
// NonDNS is the handler for non-DNS HTTP requests. It must not be nil.
|
||||
NonDNS http.Handler
|
||||
|
||||
// MetricsNamespace is a namespace for Prometheus metrics. It must be a
|
||||
// valid Prometheus metric label.
|
||||
MetricsNamespace string
|
||||
|
||||
// ServerGroups are the DNS server groups. Each element must be non-nil.
|
||||
ServerGroups []*agd.ServerGroup
|
||||
|
||||
// HandleTimeout defines the timeout for the entire handling of a single
|
||||
// query. It must be greater than zero.
|
||||
HandleTimeout time.Duration
|
||||
}
|
||||
|
||||
// NewListenerFunc is the type for DNS listener constructors.
|
||||
type NewListenerFunc func(
|
||||
srv *agd.Server,
|
||||
baseConf dnsserver.ConfigBase,
|
||||
nonDNS http.Handler,
|
||||
) (l Listener, err error)
|
||||
|
||||
// Listener is a type alias for dnsserver.Server to make internal naming more
|
||||
// consistent.
|
||||
type Listener = dnsserver.Server
|
||||
|
||||
// HandlersConfig is the configuration necessary to create or wrap the main DNS
|
||||
// handler.
|
||||
//
|
||||
// TODO(a.garipov): Consider adding validation functions.
|
||||
type HandlersConfig struct {
|
||||
// BaseLogger is used to create loggers with custom prefixes for middlewares
|
||||
// and the service itself. It must not be nil.
|
||||
BaseLogger *slog.Logger
|
||||
|
||||
// Cloner is used to clone messages more efficiently by disposing of parts
|
||||
// of DNS responses for later reuse. It must not be nil.
|
||||
Cloner *dnsmsg.Cloner
|
||||
|
||||
// Cache is the configuration for the DNS cache.
|
||||
Cache *CacheConfig
|
||||
|
||||
// HumanIDParser is used to normalize and parse human-readable device
|
||||
// identifiers. It must not be nil if at least one server group has
|
||||
// profiles enabled.
|
||||
HumanIDParser *agd.HumanIDParser
|
||||
|
||||
// Messages is the message constructor used to create blocked and other
|
||||
// messages for this DNS service. It must not be nil.
|
||||
Messages *dnsmsg.Constructor
|
||||
|
||||
// PluginRegistry is used to override configuration parameters.
|
||||
PluginRegistry *plugin.Registry
|
||||
|
||||
// StructuredErrors is the configuration for the experimental Structured DNS
|
||||
// Errors feature in the profiles' message constructors. It must not be
|
||||
// nil.
|
||||
StructuredErrors *dnsmsg.StructuredDNSErrorsConfig
|
||||
|
||||
// AccessManager is used to block requests. It must not be nil.
|
||||
AccessManager access.Interface
|
||||
|
||||
// BillStat is used to collect billing statistics. It must not be nil.
|
||||
BillStat billstat.Recorder
|
||||
|
||||
// CacheManager is the global cache manager. It must not be nil.
|
||||
CacheManager agdcache.Manager
|
||||
|
||||
// DNSCheck is used by clients to check if they use AdGuard DNS. It must
|
||||
// not be nil.
|
||||
DNSCheck dnscheck.Interface
|
||||
|
||||
// DNSDB is used to update anonymous statistics about DNS queries. It must
|
||||
// not be nil.
|
||||
DNSDB dnsdb.Interface
|
||||
|
||||
// ErrColl is the error collector that is used to collect critical and
|
||||
// non-critical errors. It must not be nil.
|
||||
ErrColl errcoll.Interface
|
||||
|
||||
// FilterStorage is the storage of all filters. It must not be nil.
|
||||
FilterStorage filter.Storage
|
||||
|
||||
// GeoIP is the GeoIP database used to detect geographic data about IP
|
||||
// addresses in requests and responses. It must not be nil.
|
||||
GeoIP geoip.Interface
|
||||
|
||||
// Handler is the ultimate handler of the DNS query to be wrapped by
|
||||
// middlewares. It must not be nil.
|
||||
Handler dnsserver.Handler
|
||||
|
||||
// HashMatcher is the safe-browsing hash matcher for TXT queries. It must
|
||||
// not be nil.
|
||||
HashMatcher filter.HashMatcher
|
||||
|
||||
// ProfileDB is the AdGuard DNS profile database used to fetch data about
|
||||
// profiles, devices, and so on. It must not be nil if at least one server
|
||||
// group has profiles enabled.
|
||||
ProfileDB profiledb.Interface
|
||||
|
||||
// PrometheusRegisterer is used to register Prometheus metrics. It must not
|
||||
// be nil.
|
||||
PrometheusRegisterer prometheus.Registerer
|
||||
|
||||
// QueryLog is used to write the logs into. It must not be nil.
|
||||
QueryLog querylog.Interface
|
||||
|
||||
// RateLimit is used for allow or decline requests. It must not be nil.
|
||||
RateLimit ratelimit.Interface
|
||||
|
||||
// RuleStat is used to collect statistics about matched filtering rules and
|
||||
// rule lists. It must not be nil.
|
||||
RuleStat rulestat.Interface
|
||||
|
||||
// MetricsNamespace is a namespace for Prometheus metrics. It must be a
|
||||
// valid Prometheus metric label.
|
||||
MetricsNamespace string
|
||||
|
||||
// FilteringGroups are the DNS filtering groups. Each element must be
|
||||
// non-nil.
|
||||
FilteringGroups map[agd.FilteringGroupID]*agd.FilteringGroup
|
||||
|
||||
// ServerGroups are the DNS server groups for which to build handlers. Each
|
||||
// element and its servers must be non-nil.
|
||||
ServerGroups []*agd.ServerGroup
|
||||
|
||||
// EDEEnabled enables the addition of the Extended DNS Error (EDE) codes in
|
||||
// the profiles' message constructors.
|
||||
EDEEnabled bool
|
||||
}
|
||||
|
||||
// Handlers contains the map of handlers for each server of each server group.
|
||||
// The pointers are the same as those passed in a [HandlersConfig] to
|
||||
// [NewHandlers].
|
||||
type Handlers map[HandlerKey]dnsserver.Handler
|
||||
|
||||
// HandlerKey is a key for the [Handlers] map.
|
||||
type HandlerKey struct {
|
||||
Server *agd.Server
|
||||
ServerGroup *agd.ServerGroup
|
||||
}
|
||||
|
||||
// CacheConfig is the configuration for the DNS cache.
|
||||
type CacheConfig struct {
|
||||
// MinTTL is the minimum supported TTL for cache items.
|
||||
MinTTL time.Duration
|
||||
|
||||
// ECSCount is the size of the DNS cache for domain names that support
|
||||
// ECS, in entries. It must be greater than zero if [CacheConfig.CacheType]
|
||||
// is [CacheTypeECS].
|
||||
ECSCount int
|
||||
|
||||
// NoECSCount is the size of the DNS cache for domain names that don't
|
||||
// support ECS, in entries. It must be greater than zero if
|
||||
// [CacheConfig.CacheType] is [CacheTypeSimple] or [CacheTypeECS].
|
||||
NoECSCount int
|
||||
|
||||
// Type is the cache type. It must be valid.
|
||||
Type CacheType
|
||||
|
||||
// OverrideCacheTTL shows if the TTL overriding logic should be used.
|
||||
OverrideCacheTTL bool
|
||||
}
|
||||
|
||||
// CacheType is the type of the cache to use.
|
||||
type CacheType uint8
|
||||
|
||||
// CacheType constants.
|
||||
const (
|
||||
CacheTypeNone CacheType = iota + 1
|
||||
CacheTypeSimple
|
||||
CacheTypeECS
|
||||
)
|
35
internal/dnssvc/context.go
Normal file
35
internal/dnssvc/context.go
Normal file
@ -0,0 +1,35 @@
|
||||
package dnssvc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
)
|
||||
|
||||
// contextConstructor is a [dnsserver.ContextConstructor] implementation that
|
||||
// returns a context with the given timeout as well as a new [agd.RequestID].
|
||||
type contextConstructor struct {
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// newContextConstructor returns a new properly initialized *contextConstructor.
|
||||
func newContextConstructor(timeout time.Duration) (c *contextConstructor) {
|
||||
return &contextConstructor{
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ dnsserver.ContextConstructor = (*contextConstructor)(nil)
|
||||
|
||||
// New implements the [dnsserver.ContextConstructor] interface for
|
||||
// *contextConstructor. It returns a context with a new [agd.RequestID] as well
|
||||
// as its timeout and the corresponding cancelation function.
|
||||
func (c *contextConstructor) New() (ctx context.Context, cancel context.CancelFunc) {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), c.timeout)
|
||||
ctx = agd.WithRequestID(ctx, agd.NewRequestID())
|
||||
|
||||
return ctx, cancel
|
||||
}
|
@ -7,289 +7,27 @@ package dnssvc
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/access"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdcache"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/cmd/plugin"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/connlimiter"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
||||
dnssrvprom "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/devicefinder"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/initial"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/mainmw"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/preservice"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/preupstream"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/ratelimitmw"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/rulestat"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/service"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// Config is the configuration of the AdGuard DNS service.
|
||||
type Config struct {
|
||||
// BaseLogger is used to create loggers with custom prefixes for middlewares
|
||||
// and the service itself.
|
||||
BaseLogger *slog.Logger
|
||||
|
||||
// Messages is the message constructor used to create blocked and other
|
||||
// messages for this DNS service.
|
||||
Messages *dnsmsg.Constructor
|
||||
|
||||
// Cloner is used to clone messages more efficiently by disposing of parts
|
||||
// of DNS responses for later reuse.
|
||||
Cloner *dnsmsg.Cloner
|
||||
|
||||
// ControlConf is the configuration of socket options.
|
||||
ControlConf *netext.ControlConfig
|
||||
|
||||
// ConnLimiter, if not nil, is used to limit the number of simultaneously
|
||||
// active stream-connections.
|
||||
ConnLimiter *connlimiter.Limiter
|
||||
|
||||
// HumanIDParser is used to normalize and parse human-readable device
|
||||
// identifiers.
|
||||
HumanIDParser *agd.HumanIDParser
|
||||
|
||||
// PluginRegistry is used to override configuration parameters.
|
||||
PluginRegistry *plugin.Registry
|
||||
|
||||
// AccessManager is used to block requests.
|
||||
AccessManager access.Interface
|
||||
|
||||
// SafeBrowsing is the safe browsing TXT hash matcher.
|
||||
SafeBrowsing filter.HashMatcher
|
||||
|
||||
// BillStat is used to collect billing statistics.
|
||||
BillStat billstat.Recorder
|
||||
|
||||
// CacheManager is the global cache manager. CacheManager must not be nil.
|
||||
CacheManager agdcache.Manager
|
||||
|
||||
// ProfileDB is the AdGuard DNS profile database used to fetch data about
|
||||
// profiles, devices, and so on.
|
||||
ProfileDB profiledb.Interface
|
||||
|
||||
// PrometheusRegisterer is used to register Prometheus metrics.
|
||||
PrometheusRegisterer prometheus.Registerer
|
||||
|
||||
// DNSCheck is used by clients to check if they use AdGuard DNS.
|
||||
DNSCheck dnscheck.Interface
|
||||
|
||||
// NonDNS is the handler for non-DNS HTTP requests.
|
||||
NonDNS http.Handler
|
||||
|
||||
// DNSDB is used to update anonymous statistics about DNS queries.
|
||||
DNSDB dnsdb.Interface
|
||||
|
||||
// ErrColl is the error collector that is used to collect critical and
|
||||
// non-critical errors.
|
||||
ErrColl errcoll.Interface
|
||||
|
||||
// FilterStorage is the storage of all filters.
|
||||
FilterStorage filter.Storage
|
||||
|
||||
// GeoIP is the GeoIP database used to detect geographic data about IP
|
||||
// addresses in requests and responses.
|
||||
GeoIP geoip.Interface
|
||||
|
||||
// QueryLog is used to write the logs into.
|
||||
QueryLog querylog.Interface
|
||||
|
||||
// RuleStat is used to collect statistics about matched filtering rules and
|
||||
// rule lists.
|
||||
RuleStat rulestat.Interface
|
||||
|
||||
// NewListener, when set, is used instead of the package-level function
|
||||
// NewListener when creating a DNS listener.
|
||||
//
|
||||
// TODO(a.garipov): The handler and service logic should really not be
|
||||
// intertwined in this way. See AGDNS-1327.
|
||||
NewListener NewListenerFunc
|
||||
|
||||
// Handler is used as the main DNS handler instead of a simple forwarder.
|
||||
// It must not be nil.
|
||||
//
|
||||
// TODO(a.garipov): Think of a better way to make the DNS server logic more
|
||||
// testable.
|
||||
Handler dnsserver.Handler
|
||||
|
||||
// RateLimit is used for allow or decline requests.
|
||||
RateLimit ratelimit.Interface
|
||||
|
||||
// MetricsNamespace is a namespace for Prometheus metrics. It must be a
|
||||
// valid Prometheus metric label.
|
||||
MetricsNamespace string
|
||||
|
||||
// FilteringGroups are the DNS filtering groups. Each element must be
|
||||
// non-nil.
|
||||
FilteringGroups map[agd.FilteringGroupID]*agd.FilteringGroup
|
||||
|
||||
// ServerGroups are the DNS server groups. Each element must be non-nil.
|
||||
ServerGroups []*agd.ServerGroup
|
||||
|
||||
// HandleTimeout defines the timeout for the entire handling of a single
|
||||
// query.
|
||||
HandleTimeout time.Duration
|
||||
|
||||
// CacheSize is the size of the DNS cache for domain names that don't
|
||||
// support ECS.
|
||||
//
|
||||
// TODO(a.garipov): Extract this and following fields to cache configuration
|
||||
// struct.
|
||||
CacheSize int
|
||||
|
||||
// ECSCacheSize is the size of the DNS cache for domain names that support
|
||||
// ECS.
|
||||
ECSCacheSize int
|
||||
|
||||
// CacheMinTTL is the minimum supported TTL for cache items. This setting
|
||||
// is used when UseCacheTTLOverride set to true.
|
||||
CacheMinTTL time.Duration
|
||||
|
||||
// UseCacheTTLOverride shows if the TTL overrides logic should be used.
|
||||
UseCacheTTLOverride bool
|
||||
|
||||
// UseECSCache shows if the EDNS Client Subnet (ECS) aware cache should be
|
||||
// used.
|
||||
UseECSCache bool
|
||||
// Service is the main DNS service of AdGuard DNS.
|
||||
type Service struct {
|
||||
groups []*serverGroup
|
||||
}
|
||||
|
||||
type (
|
||||
// MainMiddlewareMetrics is a re-export of the internal filtering-middleware
|
||||
// metrics interface.
|
||||
MainMiddlewareMetrics = mainmw.Metrics
|
||||
|
||||
// RatelimitMiddlewareMetrics is a re-export of the metrics interface of the
|
||||
// internal access and ratelimiting middleware.
|
||||
RatelimitMiddlewareMetrics = ratelimitmw.Metrics
|
||||
)
|
||||
|
||||
// New returns a new DNS service.
|
||||
func New(c *Config) (svc *Service, err error) {
|
||||
// Use either the configured listener initializer or the default one.
|
||||
newListener := c.NewListener
|
||||
if newListener == nil {
|
||||
newListener = NewListener
|
||||
}
|
||||
|
||||
// Configure the end of the request handling pipeline.
|
||||
handler := c.Handler
|
||||
if handler == nil {
|
||||
return nil, errors.Error("handler in config must not be nil")
|
||||
}
|
||||
|
||||
// Configure the pre-upstream middleware common for all servers of all
|
||||
// groups.
|
||||
preUps := preupstream.New(&preupstream.Config{
|
||||
Cloner: c.Cloner,
|
||||
CacheManager: c.CacheManager,
|
||||
DB: c.DNSDB,
|
||||
GeoIP: c.GeoIP,
|
||||
CacheSize: c.CacheSize,
|
||||
ECSCacheSize: c.ECSCacheSize,
|
||||
UseECSCache: c.UseECSCache,
|
||||
CacheMinTTL: c.CacheMinTTL,
|
||||
UseCacheTTLOverride: c.UseCacheTTLOverride,
|
||||
})
|
||||
handler = preUps.Wrap(handler)
|
||||
|
||||
errCollListener := &errCollMetricsListener{
|
||||
errColl: c.ErrColl,
|
||||
baseListener: dnssrvprom.NewServerMetricsListener(c.MetricsNamespace),
|
||||
}
|
||||
|
||||
// Configure the service itself.
|
||||
groups := make([]*serverGroup, len(c.ServerGroups))
|
||||
svc = &Service{
|
||||
groups: groups,
|
||||
}
|
||||
|
||||
mainMwMtrc, err := newMainMiddlewareMetrics(c)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rlMwMtrc, err := metrics.NewDefaultRatelimitMiddleware(c.MetricsNamespace, c.PrometheusRegisterer)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, srvGrp := range c.ServerGroups {
|
||||
// The Filtering Middlewares
|
||||
//
|
||||
// These are middlewares common to all filtering and server groups.
|
||||
// They change the flow of request handling, so they are separated.
|
||||
|
||||
dnsHdlr := dnsserver.WithMiddlewares(
|
||||
handler,
|
||||
preservice.New(&preservice.Config{
|
||||
Messages: c.Messages,
|
||||
HashMatcher: c.SafeBrowsing,
|
||||
Checker: c.DNSCheck,
|
||||
}),
|
||||
mainmw.New(&mainmw.Config{
|
||||
Metrics: mainMwMtrc,
|
||||
Messages: c.Messages,
|
||||
Cloner: c.Cloner,
|
||||
BillStat: c.BillStat,
|
||||
ErrColl: c.ErrColl,
|
||||
FilterStorage: c.FilterStorage,
|
||||
GeoIP: c.GeoIP,
|
||||
QueryLog: c.QueryLog,
|
||||
RuleStat: c.RuleStat,
|
||||
}),
|
||||
)
|
||||
|
||||
var servers []*server
|
||||
servers, err = newServers(c, srvGrp, dnsHdlr, rlMwMtrc, errCollListener, newListener)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("group %q: %w", srvGrp.Name, err)
|
||||
}
|
||||
|
||||
groups[i] = &serverGroup{
|
||||
name: srvGrp.Name,
|
||||
servers: servers,
|
||||
}
|
||||
}
|
||||
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
// newMainMiddlewareMetrics returns a filtering-middleware metrics
|
||||
// implementation from the config.
|
||||
func newMainMiddlewareMetrics(c *Config) (mainMwMtrc MainMiddlewareMetrics, err error) {
|
||||
mainMwMtrc = c.PluginRegistry.MainMiddlewareMetrics()
|
||||
if mainMwMtrc != nil {
|
||||
return mainMwMtrc, nil
|
||||
}
|
||||
|
||||
mainMwMtrc, err = metrics.NewDefaultMainMiddleware(c.MetricsNamespace, c.PrometheusRegisterer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mainmw metrics: %w", err)
|
||||
}
|
||||
|
||||
return mainMwMtrc, nil
|
||||
// serverGroup is a group of servers.
|
||||
type serverGroup struct {
|
||||
name agd.ServerGroupName
|
||||
servers []*server
|
||||
}
|
||||
|
||||
// server is a group of listeners.
|
||||
@ -303,27 +41,171 @@ type server struct {
|
||||
listeners []*listener
|
||||
}
|
||||
|
||||
// serverGroup is a group of servers.
|
||||
type serverGroup struct {
|
||||
name agd.ServerGroupName
|
||||
servers []*server
|
||||
// listener is a Listener along with some of its associated data.
|
||||
type listener struct {
|
||||
Listener
|
||||
|
||||
name string
|
||||
}
|
||||
|
||||
// Service is the main DNS service of AdGuard DNS.
|
||||
type Service struct {
|
||||
groups []*serverGroup
|
||||
}
|
||||
|
||||
// mustStartListener starts l and panics on any error.
|
||||
func mustStartListener(
|
||||
grp agd.ServerGroupName,
|
||||
srv agd.ServerName,
|
||||
l *listener,
|
||||
) {
|
||||
err := l.Start(context.Background())
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("group %q: server %q: starting %q: %w", grp, srv, l.name, err))
|
||||
// New returns a new DNS service.
|
||||
func New(c *Config) (svc *Service, err error) {
|
||||
// Use either the configured listener initializer or the default one.
|
||||
newListener := c.NewListener
|
||||
if newListener == nil {
|
||||
newListener = NewListener
|
||||
}
|
||||
|
||||
errCollListener := &errCollMetricsListener{
|
||||
errColl: c.ErrColl,
|
||||
baseListener: dnssrvprom.NewServerMetricsListener(c.MetricsNamespace),
|
||||
}
|
||||
|
||||
// Configure the service itself.
|
||||
groups := make([]*serverGroup, 0, len(c.ServerGroups))
|
||||
|
||||
for _, srvGrp := range c.ServerGroups {
|
||||
g := &serverGroup{
|
||||
name: srvGrp.Name,
|
||||
}
|
||||
|
||||
g.servers, err = newServers(c, srvGrp, errCollListener, newListener)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("group %q: %w", srvGrp.Name, err)
|
||||
}
|
||||
|
||||
groups = append(groups, g)
|
||||
}
|
||||
|
||||
svc = &Service{
|
||||
groups: groups,
|
||||
}
|
||||
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
// newServers creates a slice of servers.
|
||||
func newServers(
|
||||
c *Config,
|
||||
srvGrp *agd.ServerGroup,
|
||||
errCollListener *errCollMetricsListener,
|
||||
newListener NewListenerFunc,
|
||||
) (servers []*server, err error) {
|
||||
servers = make([]*server, 0, len(srvGrp.Servers))
|
||||
|
||||
for _, srv := range srvGrp.Servers {
|
||||
k := HandlerKey{
|
||||
Server: srv,
|
||||
ServerGroup: srvGrp,
|
||||
}
|
||||
handler, ok := c.Handlers[k]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no handler for server %q of group %q", srv.Name, srvGrp.Name)
|
||||
}
|
||||
|
||||
s := &server{
|
||||
name: srv.Name,
|
||||
handler: handler,
|
||||
}
|
||||
|
||||
s.listeners, err = newListeners(c, srv, handler, errCollListener, newListener)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("server %q: %w", s.name, err)
|
||||
}
|
||||
|
||||
servers = append(servers, s)
|
||||
}
|
||||
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
// newListeners creates a slice of listeners for a server.
|
||||
func newListeners(
|
||||
c *Config,
|
||||
srv *agd.Server,
|
||||
handler dnsserver.Handler,
|
||||
errCollListener *errCollMetricsListener,
|
||||
newListener NewListenerFunc,
|
||||
) (listeners []*listener, err error) {
|
||||
bindData := srv.BindData()
|
||||
listeners = make([]*listener, 0, len(bindData))
|
||||
for i, bindData := range bindData {
|
||||
var addr string
|
||||
if bindData.PrefixAddr == nil {
|
||||
addr = bindData.AddrPort.String()
|
||||
} else {
|
||||
addr = bindData.PrefixAddr.String()
|
||||
}
|
||||
|
||||
proto := srv.Protocol
|
||||
|
||||
name := listenerName(srv.Name, addr, proto)
|
||||
baseConf := dnsserver.ConfigBase{
|
||||
Network: dnsserver.NetworkAny,
|
||||
Handler: handler,
|
||||
Metrics: errCollListener,
|
||||
Disposer: c.Cloner,
|
||||
RequestContext: newContextConstructor(c.HandleTimeout),
|
||||
ListenConfig: newListenConfig(
|
||||
bindData.ListenConfig,
|
||||
c.ControlConf,
|
||||
c.ConnLimiter,
|
||||
proto,
|
||||
),
|
||||
Name: name,
|
||||
Addr: addr,
|
||||
}
|
||||
|
||||
l := &listener{
|
||||
name: name,
|
||||
}
|
||||
|
||||
l.Listener, err = newListener(srv, baseConf, c.NonDNS)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bind data at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
listeners = append(listeners, l)
|
||||
}
|
||||
|
||||
return listeners, nil
|
||||
}
|
||||
|
||||
// listenerName returns a standard name for a listener.
|
||||
func listenerName(srvName agd.ServerName, addr string, proto agd.Protocol) (name string) {
|
||||
return fmt.Sprintf("%s/%s/%s", srvName, proto, addr)
|
||||
}
|
||||
|
||||
// newListenConfig returns the netext.ListenConfig used by the plain-DNS
|
||||
// servers. The resulting ListenConfig sets additional socket flags and
|
||||
// processes the control messages of connections created with ListenPacket.
|
||||
// Additionally, if l is not nil, it is used to limit the number of
|
||||
// simultaneously active stream-connections.
|
||||
func newListenConfig(
|
||||
original netext.ListenConfig,
|
||||
ctrlConf *netext.ControlConfig,
|
||||
l *connlimiter.Limiter,
|
||||
p agd.Protocol,
|
||||
) (lc netext.ListenConfig) {
|
||||
if original != nil {
|
||||
if l == nil {
|
||||
return original
|
||||
}
|
||||
|
||||
return connlimiter.NewListenConfig(original, l)
|
||||
}
|
||||
|
||||
if p == agd.ProtoDNS {
|
||||
lc = netext.DefaultListenConfigWithOOB(ctrlConf)
|
||||
} else {
|
||||
lc = netext.DefaultListenConfig(ctrlConf)
|
||||
}
|
||||
|
||||
if l != nil {
|
||||
lc = connlimiter.NewListenConfig(lc, l)
|
||||
}
|
||||
|
||||
return lc
|
||||
}
|
||||
|
||||
// type check
|
||||
@ -331,13 +213,13 @@ var _ service.Interface = (*Service)(nil)
|
||||
|
||||
// Start implements the [service.Interface] interface for *Service. It panics
|
||||
// if one of the listeners could not start.
|
||||
func (svc *Service) Start(_ context.Context) (err error) {
|
||||
func (svc *Service) Start(ctx context.Context) (err error) {
|
||||
for _, g := range svc.groups {
|
||||
for _, s := range g.servers {
|
||||
for _, l := range s.listeners {
|
||||
// Consider inability to start any one DNS listener a fatal
|
||||
// error.
|
||||
mustStartListener(g.name, s.name, l)
|
||||
mustStartListener(ctx, g.name, s.name, l)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -345,17 +227,17 @@ func (svc *Service) Start(_ context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// shutdownListeners is a helper function that shuts down all listeners of a
|
||||
// server.
|
||||
func shutdownListeners(ctx context.Context, listeners []*listener) (err error) {
|
||||
for _, l := range listeners {
|
||||
err = l.Shutdown(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("shutting down listener %q: %w", l.name, err)
|
||||
}
|
||||
// mustStartListener starts l and panics on any error.
|
||||
func mustStartListener(
|
||||
ctx context.Context,
|
||||
srvGrp agd.ServerGroupName,
|
||||
srv agd.ServerName,
|
||||
l *listener,
|
||||
) {
|
||||
err := l.Start(ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("group %q: server %q: starting %q: %w", srvGrp, srv, l.name, err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown implements the [service.Interface] interface for *Service.
|
||||
@ -378,9 +260,22 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// shutdownListeners is a helper function that shuts down all listeners of a
|
||||
// server.
|
||||
func shutdownListeners(ctx context.Context, listeners []*listener) (err error) {
|
||||
for _, l := range listeners {
|
||||
err = l.Shutdown(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("shutting down listener %q: %w", l.name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle is a simple helper to test the handling of DNS requests.
|
||||
//
|
||||
// TODO(a.garipov): Remove once the mainmw refactoring is complete.
|
||||
// TODO(a.garipov): Remove once the refactoring is complete.
|
||||
func (svc *Service) Handle(
|
||||
ctx context.Context,
|
||||
grpName agd.ServerGroupName,
|
||||
@ -417,33 +312,10 @@ func (svc *Service) Handle(
|
||||
return srv.handler.ServeDNS(ctx, rw, r)
|
||||
}
|
||||
|
||||
// Listener is a type alias for dnsserver.Server to make internal naming more
|
||||
// consistent.
|
||||
type Listener = dnsserver.Server
|
||||
|
||||
// NewListenerFunc is the type for DNS listener constructors.
|
||||
type NewListenerFunc func(
|
||||
s *agd.Server,
|
||||
baseConf dnsserver.ConfigBase,
|
||||
nonDNS http.Handler,
|
||||
) (l Listener, err error)
|
||||
|
||||
// listener is a Listener along with some of its associated data.
|
||||
type listener struct {
|
||||
Listener
|
||||
|
||||
name string
|
||||
}
|
||||
|
||||
// listenerName returns a standard name for a listener.
|
||||
func listenerName(srvName agd.ServerName, addr string, proto agd.Protocol) (name string) {
|
||||
return fmt.Sprintf("%s/%s/%s", srvName, proto, addr)
|
||||
}
|
||||
|
||||
// NewListener returns a new Listener. It is the default DNS listener
|
||||
// constructor.
|
||||
//
|
||||
// TODO(a.garipov): Replace this in tests with [netext.ListenConfig].
|
||||
// TODO(a.garipov): Replace this in tests with [netext.ListenConfig].
|
||||
func NewListener(
|
||||
s *agd.Server,
|
||||
baseConf dnsserver.ConfigBase,
|
||||
@ -500,213 +372,8 @@ func NewListener(
|
||||
TLSConfig: s.TLS,
|
||||
})
|
||||
default:
|
||||
return nil, fmt.Errorf("bad protocol %v", p)
|
||||
return nil, fmt.Errorf("protocol: %w: %d", errors.ErrBadEnumValue, p)
|
||||
}
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// contextConstructor is a [dnsserver.ContextConstructor] implementation that
|
||||
// that returns a context with the given timeout as well as a new
|
||||
// [agd.RequestID].
|
||||
type contextConstructor struct {
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// newContextConstructor returns a new properly initialized *contextConstructor.
|
||||
func newContextConstructor(timeout time.Duration) (c *contextConstructor) {
|
||||
return &contextConstructor{
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ dnsserver.ContextConstructor = (*contextConstructor)(nil)
|
||||
|
||||
// New implements the [dnsserver.ContextConstructor] interface for
|
||||
// *contextConstructor. It returns a context with a new [agd.RequestID] as well
|
||||
// as its timeout and the corresponding cancelation function.
|
||||
func (c *contextConstructor) New() (ctx context.Context, cancel context.CancelFunc) {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), c.timeout)
|
||||
ctx = agd.WithRequestID(ctx, agd.NewRequestID())
|
||||
|
||||
return ctx, cancel
|
||||
}
|
||||
|
||||
// newServers creates a slice of servers.
|
||||
//
|
||||
// TODO(a.garipov): Refactor this into a builder pattern.
|
||||
func newServers(
|
||||
c *Config,
|
||||
srvGrp *agd.ServerGroup,
|
||||
handler dnsserver.Handler,
|
||||
rlMwMtrc ratelimitmw.Metrics,
|
||||
errCollListener *errCollMetricsListener,
|
||||
newListener NewListenerFunc,
|
||||
) (servers []*server, err error) {
|
||||
servers = make([]*server, len(srvGrp.Servers))
|
||||
|
||||
for i, s := range srvGrp.Servers {
|
||||
// The Initial Middlewares
|
||||
//
|
||||
// These middlewares are either specific to the server or must be the
|
||||
// furthest away from the handler and thus are the first to process
|
||||
// a request.
|
||||
|
||||
// Assume that all the validations have been made during the
|
||||
// configuration validation step back in package cmd. If we ever get
|
||||
// new ways of receiving configuration, remove this assumption and
|
||||
// validate fg.
|
||||
fg := c.FilteringGroups[srvGrp.FilteringGroup]
|
||||
|
||||
df := newDeviceFinder(c, srvGrp, s)
|
||||
rlm := ratelimitmw.New(&ratelimitmw.Config{
|
||||
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "ratelimitmw"),
|
||||
Messages: c.Messages,
|
||||
FilteringGroup: fg,
|
||||
ServerGroup: srvGrp,
|
||||
Server: s,
|
||||
AccessManager: c.AccessManager,
|
||||
DeviceFinder: df,
|
||||
ErrColl: c.ErrColl,
|
||||
GeoIP: c.GeoIP,
|
||||
Metrics: rlMwMtrc,
|
||||
Limiter: c.RateLimit,
|
||||
// Only apply rate-limiting logic to plain DNS.
|
||||
Protocols: []agd.Protocol{agd.ProtoDNS},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ratelimit: %w", err)
|
||||
}
|
||||
|
||||
imw := initial.New(&initial.Config{
|
||||
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "initmw"),
|
||||
})
|
||||
|
||||
h := dnsserver.WithMiddlewares(
|
||||
handler,
|
||||
|
||||
// Keep the rate limiting and access middlewares as the outer ones
|
||||
// to make sure that the application logic isn't touched if the
|
||||
// request is ratelimited or blocked by access settings.
|
||||
rlm,
|
||||
imw,
|
||||
)
|
||||
|
||||
srvName := s.Name
|
||||
|
||||
var listeners []*listener
|
||||
listeners, err = newListeners(c, s, h, errCollListener, newListener)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("server %q: %w", srvName, err)
|
||||
}
|
||||
|
||||
servers[i] = &server{
|
||||
name: srvName,
|
||||
handler: h,
|
||||
listeners: listeners,
|
||||
}
|
||||
}
|
||||
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
// newDeviceFinder returns a new [agd.DeviceFinder] for a server based on the
|
||||
// configuration.
|
||||
func newDeviceFinder(c *Config, g *agd.ServerGroup, s *agd.Server) (df agd.DeviceFinder) {
|
||||
if !g.ProfilesEnabled {
|
||||
return agd.EmptyDeviceFinder{}
|
||||
}
|
||||
|
||||
return devicefinder.NewDefault(&devicefinder.Config{
|
||||
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "devicefinder"),
|
||||
ProfileDB: c.ProfileDB,
|
||||
HumanIDParser: c.HumanIDParser,
|
||||
Server: s,
|
||||
DeviceDomains: g.TLS.DeviceDomains,
|
||||
})
|
||||
}
|
||||
|
||||
// newServers creates a slice of listeners for a server.
|
||||
func newListeners(
|
||||
c *Config,
|
||||
srv *agd.Server,
|
||||
handler dnsserver.Handler,
|
||||
errCollListener *errCollMetricsListener,
|
||||
newListener NewListenerFunc,
|
||||
) (listeners []*listener, err error) {
|
||||
bindData := srv.BindData()
|
||||
listeners = make([]*listener, 0, len(bindData))
|
||||
for i, bindData := range bindData {
|
||||
var addr string
|
||||
if bindData.PrefixAddr == nil {
|
||||
addr = bindData.AddrPort.String()
|
||||
} else {
|
||||
addr = bindData.PrefixAddr.String()
|
||||
}
|
||||
|
||||
proto := srv.Protocol
|
||||
|
||||
name := listenerName(srv.Name, addr, proto)
|
||||
baseConf := dnsserver.ConfigBase{
|
||||
Network: dnsserver.NetworkAny,
|
||||
Handler: handler,
|
||||
Metrics: errCollListener,
|
||||
Disposer: c.Cloner,
|
||||
RequestContext: newContextConstructor(c.HandleTimeout),
|
||||
ListenConfig: newListenConfig(
|
||||
bindData.ListenConfig,
|
||||
c.ControlConf,
|
||||
c.ConnLimiter,
|
||||
proto,
|
||||
),
|
||||
Name: name,
|
||||
Addr: addr,
|
||||
}
|
||||
|
||||
var l Listener
|
||||
l, err = newListener(srv, baseConf, c.NonDNS)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bind data at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
listeners = append(listeners, &listener{
|
||||
name: name,
|
||||
Listener: l,
|
||||
})
|
||||
}
|
||||
|
||||
return listeners, nil
|
||||
}
|
||||
|
||||
// newListenConfig returns the netext.ListenConfig used by the plain-DNS
|
||||
// servers. The resulting ListenConfig sets additional socket flags and
|
||||
// processes the control messages of connections created with ListenPacket.
|
||||
// Additionally, if l is not nil, it is used to limit the number of
|
||||
// simultaneously active stream-connections.
|
||||
func newListenConfig(
|
||||
original netext.ListenConfig,
|
||||
ctrlConf *netext.ControlConfig,
|
||||
l *connlimiter.Limiter,
|
||||
p agd.Protocol,
|
||||
) (lc netext.ListenConfig) {
|
||||
if original != nil {
|
||||
if l == nil {
|
||||
return original
|
||||
}
|
||||
|
||||
return connlimiter.NewListenConfig(original, l)
|
||||
}
|
||||
|
||||
if p == agd.ProtoDNS {
|
||||
lc = netext.DefaultListenConfigWithOOB(ctrlConf)
|
||||
} else {
|
||||
lc = netext.DefaultListenConfig(ctrlConf)
|
||||
}
|
||||
|
||||
if l != nil {
|
||||
lc = connlimiter.NewListenConfig(lc, l)
|
||||
}
|
||||
|
||||
return lc
|
||||
}
|
||||
|
@ -10,26 +10,17 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdservice"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/forward"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/dnssvctest"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// testSrvGrpName is the [agd.ServerGroupName] for tests.
|
||||
const testSrvGrpName agd.ServerGroupName = "test_group"
|
||||
|
||||
// type check
|
||||
var _ agdservice.Refresher = (*forward.Handler)(nil)
|
||||
|
||||
@ -159,16 +150,23 @@ func TestService_Start(t *testing.T) {
|
||||
AddrPort: netip.MustParseAddrPort("127.0.0.1:53"),
|
||||
})
|
||||
|
||||
srvGrp := &agd.ServerGroup{
|
||||
Name: dnssvctest.ServerGroupName,
|
||||
Servers: []*agd.Server{srv},
|
||||
}
|
||||
|
||||
k := dnssvc.HandlerKey{
|
||||
Server: srv,
|
||||
ServerGroup: srvGrp,
|
||||
}
|
||||
|
||||
c := &dnssvc.Config{
|
||||
BaseLogger: slogutil.NewDiscardLogger(),
|
||||
NewListener: newTestListenerFunc(tl),
|
||||
PrometheusRegisterer: agdtest.NewTestPrometheusRegisterer(),
|
||||
Handler: dnsservertest.DefaultHandler(),
|
||||
MetricsNamespace: "test_start",
|
||||
ServerGroups: []*agd.ServerGroup{{
|
||||
Name: "test_group",
|
||||
Servers: []*agd.Server{srv},
|
||||
}},
|
||||
NewListener: newTestListenerFunc(tl),
|
||||
Handlers: dnssvc.Handlers{
|
||||
k: dnsservertest.NewDefaultHandler(),
|
||||
},
|
||||
MetricsNamespace: "test_start",
|
||||
ServerGroups: []*agd.ServerGroup{srvGrp},
|
||||
}
|
||||
|
||||
svc, err := dnssvc.New(c)
|
||||
@ -206,15 +204,25 @@ func TestNew(t *testing.T) {
|
||||
}),
|
||||
}
|
||||
|
||||
srvGrp := &agd.ServerGroup{
|
||||
Name: dnssvctest.ServerGroupName,
|
||||
Servers: srvs,
|
||||
}
|
||||
|
||||
handlers := dnssvc.Handlers{}
|
||||
for _, srv := range srvs {
|
||||
k := dnssvc.HandlerKey{
|
||||
Server: srv,
|
||||
ServerGroup: srvGrp,
|
||||
}
|
||||
|
||||
handlers[k] = dnsservertest.NewDefaultHandler()
|
||||
}
|
||||
|
||||
c := &dnssvc.Config{
|
||||
BaseLogger: slogutil.NewDiscardLogger(),
|
||||
Handler: dnsservertest.DefaultHandler(),
|
||||
PrometheusRegisterer: agdtest.NewTestPrometheusRegisterer(),
|
||||
MetricsNamespace: "test_new",
|
||||
ServerGroups: []*agd.ServerGroup{{
|
||||
Name: "test_group",
|
||||
Servers: srvs,
|
||||
}},
|
||||
Handlers: handlers,
|
||||
MetricsNamespace: "test_new",
|
||||
ServerGroups: []*agd.ServerGroup{srvGrp},
|
||||
}
|
||||
|
||||
svc, err := dnssvc.New(c)
|
||||
|
211
internal/dnssvc/handler.go
Normal file
211
internal/dnssvc/handler.go
Normal file
@ -0,0 +1,211 @@
|
||||
package dnssvc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/cache"
|
||||
dnssrvprom "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/devicefinder"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/initial"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/mainmw"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/preservice"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/preupstream"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/ratelimitmw"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/ecscache"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
)
|
||||
|
||||
// NewHandlers returns the main DNS handlers wrapped in all necessary
|
||||
// middlewares. c must not be nil.
|
||||
func NewHandlers(ctx context.Context, c *HandlersConfig) (handlers Handlers, err error) {
|
||||
handler := wrapPreUpstreamMw(ctx, c)
|
||||
|
||||
mainMwMtrc, err := newMainMiddlewareMetrics(c)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mainMw := mainmw.New(&mainmw.Config{
|
||||
Cloner: c.Cloner,
|
||||
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "mainmw"),
|
||||
Messages: c.Messages,
|
||||
BillStat: c.BillStat,
|
||||
ErrColl: c.ErrColl,
|
||||
FilterStorage: c.FilterStorage,
|
||||
GeoIP: c.GeoIP,
|
||||
QueryLog: c.QueryLog,
|
||||
Metrics: mainMwMtrc,
|
||||
RuleStat: c.RuleStat,
|
||||
})
|
||||
|
||||
handler = mainMw.Wrap(handler)
|
||||
|
||||
preSvcMw := preservice.New(&preservice.Config{
|
||||
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "presvcmw"),
|
||||
Messages: c.Messages,
|
||||
HashMatcher: c.HashMatcher,
|
||||
Checker: c.DNSCheck,
|
||||
})
|
||||
|
||||
handler = preSvcMw.Wrap(handler)
|
||||
|
||||
postInitMw := c.PluginRegistry.PostInitialMiddleware()
|
||||
if postInitMw != nil {
|
||||
handler = postInitMw.Wrap(handler)
|
||||
}
|
||||
|
||||
initMw := initial.New(&initial.Config{
|
||||
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "initmw"),
|
||||
})
|
||||
|
||||
handler = initMw.Wrap(handler)
|
||||
|
||||
return newHandlersForServers(c, handler)
|
||||
}
|
||||
|
||||
// wrapPreUpstreamMw returns the handler wrapped into the pre-upstream
|
||||
// middlewares.
|
||||
//
|
||||
// TODO(a.garipov): Adapt the cache tests that previously were in package
|
||||
// preupstream.
|
||||
func wrapPreUpstreamMw(ctx context.Context, c *HandlersConfig) (wrapped dnsserver.Handler) {
|
||||
// TODO(a.garipov): Use in other places if necessary.
|
||||
l := c.BaseLogger.With(slogutil.KeyPrefix, "dnssvc")
|
||||
|
||||
wrapped = c.Handler
|
||||
switch conf := c.Cache; conf.Type {
|
||||
case CacheTypeNone:
|
||||
l.WarnContext(ctx, "cache disabled")
|
||||
case CacheTypeSimple:
|
||||
l.InfoContext(ctx, "plain cache enabled", "count", conf.NoECSCount)
|
||||
|
||||
cacheMw := cache.NewMiddleware(&cache.MiddlewareConfig{
|
||||
// TODO(a.garipov): Do not use promauto and refactor.
|
||||
MetricsListener: dnssrvprom.NewCacheMetricsListener(metrics.Namespace()),
|
||||
Size: conf.NoECSCount,
|
||||
MinTTL: conf.MinTTL,
|
||||
OverrideTTL: conf.OverrideCacheTTL,
|
||||
})
|
||||
|
||||
wrapped = cacheMw.Wrap(wrapped)
|
||||
case CacheTypeECS:
|
||||
l.InfoContext(
|
||||
ctx,
|
||||
"ecs cache enabled",
|
||||
"ecs_count", conf.ECSCount,
|
||||
"no_ecs_count", conf.NoECSCount,
|
||||
)
|
||||
|
||||
cacheMw := ecscache.NewMiddleware(&ecscache.MiddlewareConfig{
|
||||
Cloner: c.Cloner,
|
||||
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "ecscache"),
|
||||
CacheManager: c.CacheManager,
|
||||
GeoIP: c.GeoIP,
|
||||
NoECSCount: conf.NoECSCount,
|
||||
ECSCount: conf.ECSCount,
|
||||
MinTTL: conf.MinTTL,
|
||||
OverrideTTL: conf.OverrideCacheTTL,
|
||||
})
|
||||
|
||||
wrapped = cacheMw.Wrap(wrapped)
|
||||
default:
|
||||
panic(fmt.Errorf("cache type: %w: %d", errors.ErrBadEnumValue, conf.Type))
|
||||
}
|
||||
|
||||
preUps := preupstream.New(ctx, &preupstream.Config{
|
||||
DB: c.DNSDB,
|
||||
})
|
||||
|
||||
wrapped = preUps.Wrap(wrapped)
|
||||
|
||||
return wrapped
|
||||
}
|
||||
|
||||
// newMainMiddlewareMetrics returns a filtering-middleware metrics
|
||||
// implementation from the config.
|
||||
func newMainMiddlewareMetrics(c *HandlersConfig) (mainMwMtrc MainMiddlewareMetrics, err error) {
|
||||
mainMwMtrc = c.PluginRegistry.MainMiddlewareMetrics()
|
||||
if mainMwMtrc != nil {
|
||||
return mainMwMtrc, nil
|
||||
}
|
||||
|
||||
mainMwMtrc, err = metrics.NewDefaultMainMiddleware(c.MetricsNamespace, c.PrometheusRegisterer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mainmw metrics: %w", err)
|
||||
}
|
||||
|
||||
return mainMwMtrc, nil
|
||||
}
|
||||
|
||||
// newHandlersForServers returns a handler map for each server group and each
|
||||
// server.
|
||||
func newHandlersForServers(c *HandlersConfig, h dnsserver.Handler) (handlers Handlers, err error) {
|
||||
rlMwMtrc, err := metrics.NewDefaultRatelimitMiddleware(c.MetricsNamespace, c.PrometheusRegisterer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ratelimit middleware metrics: %w", err)
|
||||
}
|
||||
|
||||
handlers = Handlers{}
|
||||
|
||||
rlMwLogger := c.BaseLogger.With(slogutil.KeyPrefix, "ratelimitmw")
|
||||
for _, srvGrp := range c.ServerGroups {
|
||||
fltGrp, ok := c.FilteringGroups[srvGrp.FilteringGroup]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(
|
||||
"no filtering group %q for server group %q",
|
||||
srvGrp.FilteringGroup,
|
||||
srvGrp.Name,
|
||||
)
|
||||
}
|
||||
|
||||
for _, srv := range srvGrp.Servers {
|
||||
rlMw := ratelimitmw.New(&ratelimitmw.Config{
|
||||
Logger: rlMwLogger,
|
||||
Messages: c.Messages,
|
||||
FilteringGroup: fltGrp,
|
||||
ServerGroup: srvGrp,
|
||||
Server: srv,
|
||||
StructuredErrors: c.StructuredErrors,
|
||||
AccessManager: c.AccessManager,
|
||||
DeviceFinder: newDeviceFinder(c, srvGrp, srv),
|
||||
ErrColl: c.ErrColl,
|
||||
GeoIP: c.GeoIP,
|
||||
Metrics: rlMwMtrc,
|
||||
Limiter: c.RateLimit,
|
||||
Protocols: []agd.Protocol{agd.ProtoDNS},
|
||||
EDEEnabled: c.EDEEnabled,
|
||||
})
|
||||
|
||||
k := HandlerKey{
|
||||
Server: srv,
|
||||
ServerGroup: srvGrp,
|
||||
}
|
||||
|
||||
handlers[k] = rlMw.Wrap(h)
|
||||
}
|
||||
}
|
||||
|
||||
return handlers, nil
|
||||
}
|
||||
|
||||
// newDeviceFinder returns a new agd.DeviceFinder for a server based on the
|
||||
// configuration. All arguments must not be nil.
|
||||
func newDeviceFinder(c *HandlersConfig, g *agd.ServerGroup, s *agd.Server) (df agd.DeviceFinder) {
|
||||
if !g.ProfilesEnabled {
|
||||
return agd.EmptyDeviceFinder{}
|
||||
}
|
||||
|
||||
return devicefinder.NewDefault(&devicefinder.Config{
|
||||
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "devicefinder"),
|
||||
ProfileDB: c.ProfileDB,
|
||||
HumanIDParser: c.HumanIDParser,
|
||||
Server: s,
|
||||
DeviceDomains: g.TLS.DeviceDomains,
|
||||
})
|
||||
}
|
188
internal/dnssvc/handler_test.go
Normal file
188
internal/dnssvc/handler_test.go
Normal file
@ -0,0 +1,188 @@
|
||||
package dnssvc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"path"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdcache"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/dnssvctest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewHandlers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
accessMgr := &agdtest.AccessManager{
|
||||
OnIsBlockedHost: func(host string, qt uint16) (blocked bool) { panic("not implemented") },
|
||||
OnIsBlockedIP: func(ip netip.Addr) (blocked bool) { panic("not implemented") },
|
||||
}
|
||||
|
||||
billStat := &agdtest.BillStatRecorder{
|
||||
OnRecord: func(
|
||||
_ context.Context,
|
||||
_ agd.DeviceID,
|
||||
_ geoip.Country,
|
||||
_ geoip.ASN,
|
||||
_ time.Time,
|
||||
_ agd.Protocol,
|
||||
) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
|
||||
dnsCk := &agdtest.DNSCheck{
|
||||
OnCheck: func(
|
||||
_ context.Context,
|
||||
_ *dns.Msg,
|
||||
_ *agd.RequestInfo,
|
||||
) (resp *dns.Msg, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
|
||||
dnsDB := &agdtest.DNSDB{
|
||||
OnRecord: func(_ context.Context, _ *dns.Msg, _ *agd.RequestInfo) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
|
||||
fltGrps := map[agd.FilteringGroupID]*agd.FilteringGroup{
|
||||
dnssvctest.FilteringGroupID: {
|
||||
ID: dnssvctest.FilteringGroupID,
|
||||
RuleListIDs: []agd.FilterListID{dnssvctest.FilterListID1},
|
||||
RuleListsEnabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
fltStrg := &agdtest.FilterStorage{
|
||||
OnFilterFromContext: func(_ context.Context, _ *agd.RequestInfo) (f filter.Interface) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnHasListID: func(_ agd.FilterListID) (ok bool) { panic("not implemented") },
|
||||
}
|
||||
|
||||
hashMatcher := &agdtest.HashMatcher{
|
||||
OnMatchByPrefix: func(
|
||||
_ context.Context,
|
||||
_ string,
|
||||
) (hashes []string, matched bool, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
|
||||
queryLog := &agdtest.QueryLog{
|
||||
OnWrite: func(_ context.Context, _ *querylog.Entry) (err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
|
||||
ruleStat := &agdtest.RuleStat{
|
||||
OnCollect: func(_ context.Context, _ agd.FilterListID, _ agd.FilterRuleText) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
|
||||
srv := dnssvctest.NewServer(dnssvctest.ServerName, agd.ProtoDoT, &agd.ServerBindData{
|
||||
AddrPort: dnssvctest.ServerAddrPort,
|
||||
})
|
||||
|
||||
srvGrp := &agd.ServerGroup{
|
||||
DDR: &agd.DDR{
|
||||
Enabled: true,
|
||||
},
|
||||
TLS: &agd.TLS{},
|
||||
Name: dnssvctest.ServerGroupName,
|
||||
FilteringGroup: dnssvctest.FilteringGroupID,
|
||||
Servers: []*agd.Server{srv},
|
||||
ProfilesEnabled: true,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
cacheConf *dnssvc.CacheConfig
|
||||
name string
|
||||
}{{
|
||||
cacheConf: &dnssvc.CacheConfig{
|
||||
Type: dnssvc.CacheTypeNone,
|
||||
},
|
||||
name: "no_cache",
|
||||
}, {
|
||||
cacheConf: &dnssvc.CacheConfig{
|
||||
MinTTL: 10 * time.Second,
|
||||
NoECSCount: 100,
|
||||
Type: dnssvc.CacheTypeSimple,
|
||||
OverrideCacheTTL: true,
|
||||
},
|
||||
name: "cache_simple",
|
||||
}, {
|
||||
cacheConf: &dnssvc.CacheConfig{
|
||||
MinTTL: 10 * time.Second,
|
||||
ECSCount: 100,
|
||||
NoECSCount: 100,
|
||||
Type: dnssvc.CacheTypeECS,
|
||||
OverrideCacheTTL: true,
|
||||
},
|
||||
name: "cache_ecs",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, dnssvctest.Timeout)
|
||||
handlers, err := dnssvc.NewHandlers(ctx, &dnssvc.HandlersConfig{
|
||||
BaseLogger: slogutil.NewDiscardLogger(),
|
||||
Cloner: agdtest.NewCloner(),
|
||||
Cache: tc.cacheConf,
|
||||
HumanIDParser: agd.NewHumanIDParser(),
|
||||
Messages: agdtest.NewConstructor(t),
|
||||
PluginRegistry: nil,
|
||||
StructuredErrors: agdtest.NewSDEConfig(true),
|
||||
AccessManager: accessMgr,
|
||||
BillStat: billStat,
|
||||
// TODO(a.garipov): Create a test implementation?
|
||||
CacheManager: agdcache.EmptyManager{},
|
||||
DNSCheck: dnsCk,
|
||||
DNSDB: dnsDB,
|
||||
ErrColl: agdtest.NewErrorCollector(),
|
||||
FilterStorage: fltStrg,
|
||||
GeoIP: agdtest.NewGeoIP(),
|
||||
Handler: dnsservertest.NewPanicHandler(),
|
||||
HashMatcher: hashMatcher,
|
||||
ProfileDB: agdtest.NewProfileDB(),
|
||||
PrometheusRegisterer: agdtest.NewTestPrometheusRegisterer(),
|
||||
QueryLog: queryLog,
|
||||
RateLimit: agdtest.NewRateLimit(),
|
||||
RuleStat: ruleStat,
|
||||
MetricsNamespace: path.Base(t.Name()),
|
||||
FilteringGroups: fltGrps,
|
||||
ServerGroups: []*agd.ServerGroup{srvGrp},
|
||||
EDEEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, handlers, 1)
|
||||
|
||||
for k, v := range handlers {
|
||||
assert.Same(t, srv, k.Server)
|
||||
assert.Same(t, srvGrp, k.ServerGroup)
|
||||
assert.NotNil(t, v)
|
||||
|
||||
break
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/access"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdcache"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdpasswd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||
@ -19,6 +20,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/dnssvctest"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
@ -46,7 +48,7 @@ import (
|
||||
// from the service. The channels must not be nil. Each sending to a channel
|
||||
// wrapped with [testutil.RequireSend] using [dnssvctest.Timeout].
|
||||
//
|
||||
// It also uses the [dnsservertest.DefaultHandler] to create the DNS handler.
|
||||
// It also uses the [dnsservertest.NewDefaultHandler] to create the DNS handler.
|
||||
func newTestService(
|
||||
t testing.TB,
|
||||
flt filter.Interface,
|
||||
@ -81,46 +83,14 @@ func newTestService(
|
||||
QueryLogEnabled: true,
|
||||
}
|
||||
|
||||
db := &agdtest.ProfileDB{
|
||||
OnCreateAutoDevice: func(
|
||||
ctx context.Context,
|
||||
id agd.ProfileID,
|
||||
humanID agd.HumanID,
|
||||
devType agd.DeviceType,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
profDB := agdtest.NewProfileDB()
|
||||
profDB.OnProfileByDeviceID = func(
|
||||
_ context.Context,
|
||||
id agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
testutil.RequireSend(pt, profileDBCh, id, dnssvctest.Timeout)
|
||||
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
id agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
testutil.RequireSend(pt, profileDBCh, id, dnssvctest.Timeout)
|
||||
|
||||
return prof, dev, nil
|
||||
},
|
||||
|
||||
OnProfileByHumanID: func(
|
||||
_ context.Context,
|
||||
_ agd.ProfileID,
|
||||
_ agd.HumanIDLower,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByLinkedIP: func(
|
||||
ctx context.Context,
|
||||
ip netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
return prof, dev, nil
|
||||
}
|
||||
|
||||
accessManager := &agdtest.AccessManager{
|
||||
@ -145,18 +115,12 @@ func newTestService(
|
||||
Continent: geoip.ContinentEU,
|
||||
ASN: 42,
|
||||
}
|
||||
geoIP := &agdtest.GeoIP{
|
||||
OnSubnetByLocation: func(
|
||||
_ *geoip.Location,
|
||||
_ netutil.AddrFamily,
|
||||
) (n netip.Prefix, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnData: func(host string, _ netip.Addr) (l *geoip.Location, err error) {
|
||||
testutil.RequireSend(pt, geoIPCh, host, dnssvctest.Timeout)
|
||||
|
||||
return loc, nil
|
||||
},
|
||||
geoIP := agdtest.NewGeoIP()
|
||||
geoIP.OnData = func(host string, _ netip.Addr) (l *geoip.Location, err error) {
|
||||
testutil.RequireSend(pt, geoIPCh, host, dnssvctest.Timeout)
|
||||
|
||||
return loc, nil
|
||||
}
|
||||
|
||||
fltStrg := &agdtest.FilterStorage{
|
||||
@ -205,25 +169,38 @@ func newTestService(
|
||||
},
|
||||
}
|
||||
|
||||
rl := &agdtest.RateLimit{
|
||||
OnIsRateLimited: func(
|
||||
_ context.Context,
|
||||
_ *dns.Msg,
|
||||
_ netip.Addr,
|
||||
) (drop, allowlisted bool, err error) {
|
||||
return true, false, nil
|
||||
},
|
||||
OnCountResponses: func(_ context.Context, _ *dns.Msg, _ netip.Addr) {
|
||||
panic("not implemented")
|
||||
},
|
||||
rl := agdtest.NewRateLimit()
|
||||
rl.OnIsRateLimited = func(
|
||||
_ context.Context,
|
||||
_ *dns.Msg,
|
||||
_ netip.Addr,
|
||||
) (shouldDrop, isAllowlisted bool, err error) {
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
testFltGrpID := agd.FilteringGroupID("1234")
|
||||
srvGrps := []*agd.ServerGroup{{
|
||||
DDR: &agd.DDR{
|
||||
Enabled: true,
|
||||
},
|
||||
TLS: &agd.TLS{
|
||||
DeviceDomains: []string{dnssvctest.DomainForDevices},
|
||||
},
|
||||
Name: dnssvctest.ServerGroupName,
|
||||
FilteringGroup: dnssvctest.FilteringGroupID,
|
||||
Servers: []*agd.Server{srv},
|
||||
ProfilesEnabled: true,
|
||||
}}
|
||||
|
||||
c := &dnssvc.Config{
|
||||
BaseLogger: slogutil.NewDiscardLogger(),
|
||||
AccessManager: accessManager,
|
||||
Messages: agdtest.NewConstructor(t),
|
||||
hdlrConf := &dnssvc.HandlersConfig{
|
||||
BaseLogger: slogutil.NewDiscardLogger(),
|
||||
Cache: &dnssvc.CacheConfig{
|
||||
Type: dnssvc.CacheTypeNone,
|
||||
},
|
||||
StructuredErrors: agdtest.NewSDEConfig(true),
|
||||
Cloner: agdtest.NewCloner(),
|
||||
HumanIDParser: agd.NewHumanIDParser(),
|
||||
Messages: agdtest.NewConstructor(t),
|
||||
AccessManager: accessManager,
|
||||
BillStat: &agdtest.BillStatRecorder{
|
||||
OnRecord: func(
|
||||
_ context.Context,
|
||||
@ -235,42 +212,47 @@ func newTestService(
|
||||
) {
|
||||
},
|
||||
},
|
||||
ProfileDB: db,
|
||||
PrometheusRegisterer: agdtest.NewTestPrometheusRegisterer(),
|
||||
CacheManager: agdcache.EmptyManager{},
|
||||
DNSCheck: dnsCk,
|
||||
NonDNS: http.NotFoundHandler(),
|
||||
DNSDB: dnsDB,
|
||||
ErrColl: errColl,
|
||||
FilterStorage: fltStrg,
|
||||
GeoIP: geoIP,
|
||||
Handler: dnsservertest.NewDefaultHandler(),
|
||||
HashMatcher: hashprefix.NewMatcher(nil),
|
||||
ProfileDB: profDB,
|
||||
PrometheusRegisterer: agdtest.NewTestPrometheusRegisterer(),
|
||||
QueryLog: ql,
|
||||
RuleStat: ruleStat,
|
||||
NewListener: newTestListenerFunc(tl),
|
||||
Handler: dnsservertest.DefaultHandler(),
|
||||
RateLimit: rl,
|
||||
RuleStat: ruleStat,
|
||||
MetricsNamespace: path.Base(t.Name()),
|
||||
FilteringGroups: map[agd.FilteringGroupID]*agd.FilteringGroup{
|
||||
testFltGrpID: {
|
||||
ID: testFltGrpID,
|
||||
dnssvctest.FilteringGroupID: {
|
||||
ID: dnssvctest.FilteringGroupID,
|
||||
RuleListIDs: []agd.FilterListID{dnssvctest.FilterListID1},
|
||||
RuleListsEnabled: true,
|
||||
},
|
||||
},
|
||||
ServerGroups: []*agd.ServerGroup{{
|
||||
DDR: &agd.DDR{
|
||||
Enabled: true,
|
||||
},
|
||||
TLS: &agd.TLS{
|
||||
DeviceDomains: []string{dnssvctest.DomainForDevices},
|
||||
},
|
||||
Name: testSrvGrpName,
|
||||
FilteringGroup: testFltGrpID,
|
||||
Servers: []*agd.Server{srv},
|
||||
ProfilesEnabled: true,
|
||||
}},
|
||||
ServerGroups: srvGrps,
|
||||
EDEEnabled: true,
|
||||
}
|
||||
|
||||
svc, err := dnssvc.New(c)
|
||||
ctx := context.Background()
|
||||
handlers, err := dnssvc.NewHandlers(ctx, hdlrConf)
|
||||
require.NoError(t, err)
|
||||
|
||||
c := &dnssvc.Config{
|
||||
Handlers: handlers,
|
||||
NewListener: newTestListenerFunc(tl),
|
||||
Cloner: agdtest.NewCloner(),
|
||||
ErrColl: errColl,
|
||||
NonDNS: http.NotFoundHandler(),
|
||||
MetricsNamespace: path.Base(t.Name()),
|
||||
ServerGroups: srvGrps,
|
||||
HandleTimeout: dnssvctest.Timeout,
|
||||
}
|
||||
|
||||
svc, err = dnssvc.New(c)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, svc)
|
||||
|
||||
@ -283,6 +265,7 @@ func newTestService(
|
||||
return svc, srvAddr
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Refactor to test handlers separately from the service.
|
||||
func TestService_Wrap(t *testing.T) {
|
||||
profileDBCh := make(chan agd.DeviceID, 1)
|
||||
querylogCh := make(chan *querylog.Entry, 1)
|
||||
@ -346,7 +329,7 @@ func TestService_Wrap(t *testing.T) {
|
||||
TLSServerName: dnssvctest.DeviceIDSrvName,
|
||||
})
|
||||
|
||||
err := svc.Handle(ctx, testSrvGrpName, dnssvctest.ServerName, rw, req)
|
||||
err := svc.Handle(ctx, dnssvctest.ServerGroupName, dnssvctest.ServerName, rw, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := rw.Msg()
|
||||
@ -420,7 +403,7 @@ func TestService_Wrap(t *testing.T) {
|
||||
TLSServerName: dnssvctest.DeviceIDSrvName,
|
||||
})
|
||||
|
||||
err := svc.Handle(ctx, testSrvGrpName, dnssvctest.ServerName, rw, req)
|
||||
err := svc.Handle(ctx, dnssvctest.ServerGroupName, dnssvctest.ServerName, rw, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := rw.Msg()
|
||||
|
@ -80,35 +80,9 @@ func TestDefault_Find_plainAddrs(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
profDB := &agdtest.ProfileDB{
|
||||
OnCreateAutoDevice: func(
|
||||
ctx context.Context,
|
||||
id agd.ProfileID,
|
||||
humanID agd.HumanID,
|
||||
devType agd.DeviceType,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByDedicatedIP: newOnProfileByDedicatedIP(dnssvctest.DedicatedAddr),
|
||||
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
_ agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByHumanID: func(
|
||||
_ context.Context,
|
||||
_ agd.ProfileID,
|
||||
_ agd.HumanIDLower,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByLinkedIP: newOnProfileByLinkedIP(dnssvctest.LinkedAddr),
|
||||
}
|
||||
profDB := agdtest.NewProfileDB()
|
||||
profDB.OnProfileByDedicatedIP = newOnProfileByDedicatedIP(dnssvctest.DedicatedAddr)
|
||||
profDB.OnProfileByLinkedIP = newOnProfileByLinkedIP(dnssvctest.LinkedAddr)
|
||||
|
||||
df := devicefinder.NewDefault(&devicefinder.Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
@ -177,40 +151,8 @@ func TestDefault_Find_plainEDNS(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
profDB := &agdtest.ProfileDB{
|
||||
OnCreateAutoDevice: func(
|
||||
ctx context.Context,
|
||||
id agd.ProfileID,
|
||||
humanID agd.HumanID,
|
||||
devType agd.DeviceType,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByDeviceID: newOnProfileByDeviceID(dnssvctest.DeviceID),
|
||||
|
||||
OnProfileByHumanID: func(
|
||||
_ context.Context,
|
||||
_ agd.ProfileID,
|
||||
_ agd.HumanIDLower,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
profDB := agdtest.NewProfileDB()
|
||||
profDB.OnProfileByDeviceID = newOnProfileByDeviceID(dnssvctest.DeviceID)
|
||||
|
||||
df := devicefinder.NewDefault(&devicefinder.Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
@ -231,44 +173,12 @@ func TestDefault_Find_plainEDNS(t *testing.T) {
|
||||
func TestDefault_Find_deleted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
profDB := &agdtest.ProfileDB{
|
||||
OnCreateAutoDevice: func(
|
||||
ctx context.Context,
|
||||
id agd.ProfileID,
|
||||
humanID agd.HumanID,
|
||||
devType agd.DeviceType,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
_ agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByHumanID: func(
|
||||
_ context.Context,
|
||||
_ agd.ProfileID,
|
||||
_ agd.HumanIDLower,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return profDeleted, devNormal, nil
|
||||
},
|
||||
profDB := agdtest.NewProfileDB()
|
||||
profDB.OnProfileByLinkedIP = func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return profDeleted, devNormal, nil
|
||||
}
|
||||
|
||||
df := devicefinder.NewDefault(&devicefinder.Config{
|
||||
@ -291,44 +201,21 @@ func TestDefault_Find_byHumanID(t *testing.T) {
|
||||
// device-type and profile data regardless of the case.
|
||||
extIDStr := "OTR-" + strings.ToUpper(dnssvctest.ProfileIDStr) + "-" + dnssvctest.HumanIDStr + "-!!!"
|
||||
|
||||
profDB := &agdtest.ProfileDB{
|
||||
OnCreateAutoDevice: func(
|
||||
ctx context.Context,
|
||||
id agd.ProfileID,
|
||||
humanID agd.HumanID,
|
||||
devType agd.DeviceType,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return profNormal, devAuto, nil
|
||||
},
|
||||
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
devID agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByHumanID: func(
|
||||
_ context.Context,
|
||||
_ agd.ProfileID,
|
||||
_ agd.HumanIDLower,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return nil, nil, profiledb.ErrDeviceNotFound
|
||||
},
|
||||
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
profDB := agdtest.NewProfileDB()
|
||||
profDB.OnCreateAutoDevice = func(
|
||||
ctx context.Context,
|
||||
id agd.ProfileID,
|
||||
humanID agd.HumanID,
|
||||
devType agd.DeviceType,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return profNormal, devAuto, nil
|
||||
}
|
||||
profDB.OnProfileByHumanID = func(
|
||||
_ context.Context,
|
||||
_ agd.ProfileID,
|
||||
_ agd.HumanIDLower,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return nil, nil, profiledb.ErrDeviceNotFound
|
||||
}
|
||||
|
||||
df := devicefinder.NewDefault(&devicefinder.Config{
|
||||
|
@ -2,7 +2,6 @@ package devicefinder_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"path"
|
||||
"testing"
|
||||
@ -88,48 +87,16 @@ func TestDefault_Find_DoHAuth(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
profDB := &agdtest.ProfileDB{
|
||||
OnCreateAutoDevice: func(
|
||||
ctx context.Context,
|
||||
id agd.ProfileID,
|
||||
humanID agd.HumanID,
|
||||
devType agd.DeviceType,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
profDB := agdtest.NewProfileDB()
|
||||
profDB.OnProfileByDeviceID = func(
|
||||
_ context.Context,
|
||||
devID agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
if tc.profDBDev != nil {
|
||||
return profNormal, tc.profDBDev, nil
|
||||
}
|
||||
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
devID agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
if tc.profDBDev != nil {
|
||||
return profNormal, tc.profDBDev, nil
|
||||
}
|
||||
|
||||
return nil, nil, profiledb.ErrDeviceNotFound
|
||||
},
|
||||
|
||||
OnProfileByHumanID: func(
|
||||
_ context.Context,
|
||||
_ agd.ProfileID,
|
||||
_ agd.HumanIDLower,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
return nil, nil, profiledb.ErrDeviceNotFound
|
||||
}
|
||||
|
||||
df := devicefinder.NewDefault(&devicefinder.Config{
|
||||
@ -220,44 +187,12 @@ func TestDefault_Find_DoHAuthOnly(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
profDB := &agdtest.ProfileDB{
|
||||
OnCreateAutoDevice: func(
|
||||
ctx context.Context,
|
||||
id agd.ProfileID,
|
||||
humanID agd.HumanID,
|
||||
devType agd.DeviceType,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
devID agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return profNormal, tc.profDBDev, nil
|
||||
},
|
||||
|
||||
OnProfileByHumanID: func(
|
||||
_ context.Context,
|
||||
_ agd.ProfileID,
|
||||
_ agd.HumanIDLower,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
profDB := agdtest.NewProfileDB()
|
||||
profDB.OnProfileByDeviceID = func(
|
||||
_ context.Context,
|
||||
devID agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
return profNormal, tc.profDBDev, nil
|
||||
}
|
||||
|
||||
df := devicefinder.NewDefault(&devicefinder.Config{
|
||||
@ -351,34 +286,9 @@ func TestDefault_Find_DoH(t *testing.T) {
|
||||
name: "human_id_path_match",
|
||||
}}
|
||||
|
||||
profDB := &agdtest.ProfileDB{
|
||||
OnCreateAutoDevice: func(
|
||||
ctx context.Context,
|
||||
id agd.ProfileID,
|
||||
humanID agd.HumanID,
|
||||
devType agd.DeviceType,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByDeviceID: newOnProfileByDeviceID(dnssvctest.DeviceID),
|
||||
|
||||
OnProfileByHumanID: newOnProfileByHumanID(dnssvctest.ProfileID, dnssvctest.HumanIDLower),
|
||||
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
profDB := agdtest.NewProfileDB()
|
||||
profDB.OnProfileByDeviceID = newOnProfileByDeviceID(dnssvctest.DeviceID)
|
||||
profDB.OnProfileByHumanID = newOnProfileByHumanID(dnssvctest.ProfileID, dnssvctest.HumanIDLower)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@ -450,34 +360,9 @@ func TestDefault_Find_stdEncrypted(t *testing.T) {
|
||||
deviceDomains: []string{dnssvctest.DomainForDevices},
|
||||
}}
|
||||
|
||||
profDB := &agdtest.ProfileDB{
|
||||
OnCreateAutoDevice: func(
|
||||
ctx context.Context,
|
||||
id agd.ProfileID,
|
||||
humanID agd.HumanID,
|
||||
devType agd.DeviceType,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByDeviceID: newOnProfileByDeviceID(dnssvctest.DeviceID),
|
||||
|
||||
OnProfileByHumanID: newOnProfileByHumanID(dnssvctest.ProfileID, dnssvctest.HumanIDLower),
|
||||
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
profDB := agdtest.NewProfileDB()
|
||||
profDB.OnProfileByDeviceID = newOnProfileByDeviceID(dnssvctest.DeviceID)
|
||||
profDB.OnProfileByHumanID = newOnProfileByHumanID(dnssvctest.ProfileID, dnssvctest.HumanIDLower)
|
||||
|
||||
srvData := []struct {
|
||||
srv *agd.Server
|
||||
|
@ -1,8 +1,6 @@
|
||||
package devicefinder_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
@ -49,53 +47,13 @@ func TestDefault_Find_humanID(t *testing.T) {
|
||||
in: "otr-abcd1234-!!!",
|
||||
}}
|
||||
|
||||
profDB := &agdtest.ProfileDB{
|
||||
OnCreateAutoDevice: func(
|
||||
ctx context.Context,
|
||||
id agd.ProfileID,
|
||||
humanID agd.HumanID,
|
||||
devType agd.DeviceType,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByDedicatedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByDeviceID: func(
|
||||
_ context.Context,
|
||||
devID agd.DeviceID,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByHumanID: func(
|
||||
_ context.Context,
|
||||
_ agd.ProfileID,
|
||||
_ agd.HumanIDLower,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
|
||||
OnProfileByLinkedIP: func(
|
||||
_ context.Context,
|
||||
_ netip.Addr,
|
||||
) (p *agd.Profile, d *agd.Device, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
df := devicefinder.NewDefault(&devicefinder.Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
ProfileDB: profDB,
|
||||
ProfileDB: agdtest.NewProfileDB(),
|
||||
HumanIDParser: agd.NewHumanIDParser(),
|
||||
Server: srvDoT,
|
||||
DeviceDomains: []string{dnssvctest.DomainForDevices},
|
||||
|
@ -55,8 +55,16 @@ const (
|
||||
DomainRewrittenCNAMEFQDN = DomainRewrittenCNAME + "."
|
||||
)
|
||||
|
||||
// ServerName is the common server name for tests.
|
||||
const ServerName agd.ServerName = "test_server_dns_tls"
|
||||
const (
|
||||
// FilteringGroupID is the common filtering-group ID for tests.
|
||||
FilteringGroupID agd.FilteringGroupID = "test_filtering_group"
|
||||
|
||||
// ServerName is the common server name for tests.
|
||||
ServerName agd.ServerName = "test_server_dns_tls"
|
||||
|
||||
// ServerGroupName is the common server-group name for tests.
|
||||
ServerGroupName agd.ServerGroupName = "test_server_group"
|
||||
)
|
||||
|
||||
const (
|
||||
// DomainForDevices is the upper-level domain name for requests with device
|
||||
|
@ -29,8 +29,8 @@ const (
|
||||
// Resolvers for querying the resolver with unknown or absent name.
|
||||
DDRDomain = DDRLabel + "." + ResolverARPADomain
|
||||
|
||||
// FirefoxCanaryHost is the hostname that Firefox uses to check if it
|
||||
// should use its own DNS-over-HTTPS settings.
|
||||
// FirefoxCanaryHost is the hostname that Firefox uses to check if it should
|
||||
// use its own DNS-over-HTTPS settings.
|
||||
//
|
||||
// See https://support.mozilla.org/en-US/kb/configuring-networks-disable-dns-over-https.
|
||||
FirefoxCanaryHost = "use-application-dns.net"
|
||||
@ -56,6 +56,11 @@ func (mw *Middleware) reqInfoSpecialHandler(
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
// As per RFC-9462 section 6.4, resolvers SHOULD respond to queries of any
|
||||
// type other than SVCB for _dns.resolver.arpa. with NODATA and queries of
|
||||
// any type for any domain name under resolver.arpa with NODATA.
|
||||
//
|
||||
// TODO(e.burkov): Consider adding SOA records for these NODATA responses.
|
||||
if mw.isDDRRequest(ri) {
|
||||
if _, ok := ri.DeviceResult.(*agd.DeviceResultAuthenticationFailure); ok {
|
||||
return mw.handleDDRNoData, "ddr_doh"
|
||||
@ -84,8 +89,6 @@ type reqInfoHandlerFunc func(
|
||||
ri *agd.RequestInfo,
|
||||
) (err error)
|
||||
|
||||
// DDR And Resolver ARPA Domain
|
||||
|
||||
// isDDRRequest determines if the message is the request for Discovery of
|
||||
// Designated Resolvers as defined by the RFC draft. The request is considered
|
||||
// ARPA if the requested host is a subdomain of resolver.arpa SUDN.
|
||||
@ -141,8 +144,8 @@ func isDDRDomain(ri *agd.RequestInfo, host string) (ok bool) {
|
||||
return false
|
||||
}
|
||||
|
||||
// handleDDR checks if the request is for the Discovery of Designated Resolvers
|
||||
// and writes a response if needed.
|
||||
// handleDDR responds to Discovery of Designated Resolvers (DDR) queries with a
|
||||
// response containing the designated resolvers.
|
||||
func (mw *Middleware) handleDDR(
|
||||
ctx context.Context,
|
||||
rw dnsserver.ResponseWriter,
|
||||
@ -157,11 +160,11 @@ func (mw *Middleware) handleDDR(
|
||||
return rw.WriteMsg(ctx, req, mw.newRespDDR(req, ri))
|
||||
}
|
||||
|
||||
return rw.WriteMsg(ctx, req, ri.Messages.NewMsgNXDOMAIN(req))
|
||||
return rw.WriteMsg(ctx, req, ri.Messages.NewRespRCode(req, dns.RcodeNameError))
|
||||
}
|
||||
|
||||
// handleDDRNoData processes DDR (Discovery of Designated Resolvers) requests
|
||||
// for devices which need NODATA response and writes the response if needed.
|
||||
// handleDDRNoData responds to Discovery of Designated Resolvers (DDR) queries
|
||||
// with a NODATA response.
|
||||
func (mw *Middleware) handleDDRNoData(
|
||||
ctx context.Context,
|
||||
rw dnsserver.ResponseWriter,
|
||||
@ -173,17 +176,17 @@ func (mw *Middleware) handleDDRNoData(
|
||||
metrics.DNSSvcDDRRequestsTotal.Inc()
|
||||
|
||||
if ri.ServerGroup.DDR.Enabled {
|
||||
return rw.WriteMsg(ctx, req, ri.Messages.NewMsgNODATA(req))
|
||||
return rw.WriteMsg(ctx, req, ri.Messages.NewRespRCode(req, dns.RcodeSuccess))
|
||||
}
|
||||
|
||||
return rw.WriteMsg(ctx, req, ri.Messages.NewMsgNXDOMAIN(req))
|
||||
return rw.WriteMsg(ctx, req, ri.Messages.NewRespRCode(req, dns.RcodeNameError))
|
||||
}
|
||||
|
||||
// newRespDDR returns a new Discovery of Designated Resolvers response copying
|
||||
// it from the prebuilt templates in srvGrp and modifying it in accordance with
|
||||
// the request data. req must not be nil.
|
||||
func (mw *Middleware) newRespDDR(req *dns.Msg, ri *agd.RequestInfo) (resp *dns.Msg) {
|
||||
resp = ri.Messages.NewRespMsg(req)
|
||||
resp = ri.Messages.NewResp(req)
|
||||
name := req.Question[0].Name
|
||||
ddr := ri.ServerGroup.DDR
|
||||
|
||||
@ -210,7 +213,8 @@ func (mw *Middleware) newRespDDR(req *dns.Msg, ri *agd.RequestInfo) (resp *dns.M
|
||||
return resp
|
||||
}
|
||||
|
||||
// handleBadResolverARPA writes a NODATA response.
|
||||
// handleBadResolverARPA responds to badly formed resolver.arpa queries with a
|
||||
// NODATA response.
|
||||
func (mw *Middleware) handleBadResolverARPA(
|
||||
ctx context.Context,
|
||||
rw dnsserver.ResponseWriter,
|
||||
@ -219,7 +223,8 @@ func (mw *Middleware) handleBadResolverARPA(
|
||||
) (err error) {
|
||||
metrics.DNSSvcBadResolverARPA.Inc()
|
||||
|
||||
err = rw.WriteMsg(ctx, req, ri.Messages.NewRespMsg(req))
|
||||
resp := ri.Messages.NewRespRCode(req, dns.RcodeSuccess)
|
||||
err = rw.WriteMsg(ctx, req, resp)
|
||||
|
||||
return errors.Annotate(err, "writing nodata resp for %q: %w", ri.Host)
|
||||
}
|
||||
@ -257,8 +262,6 @@ func (mw *Middleware) specialDomainHandler(
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
// Apple Private Relay
|
||||
|
||||
// shouldBlockPrivateRelay returns true if the query is for an Apple Private
|
||||
// Relay check domain and the request information or profile indicates that
|
||||
// Apple Private Relay should be blocked.
|
||||
@ -280,13 +283,12 @@ func (mw *Middleware) handlePrivateRelay(
|
||||
) (err error) {
|
||||
metrics.DNSSvcApplePrivateRelayRequestsTotal.Inc()
|
||||
|
||||
err = rw.WriteMsg(ctx, req, ri.Messages.NewMsgNXDOMAIN(req))
|
||||
resp := ri.Messages.NewRespRCode(req, dns.RcodeNameError)
|
||||
err = rw.WriteMsg(ctx, req, resp)
|
||||
|
||||
return errors.Annotate(err, "writing private relay resp: %w")
|
||||
}
|
||||
|
||||
// Firefox canary domain
|
||||
|
||||
// shouldBlockFirefoxCanary returns true if the query is for a Firefox canary
|
||||
// domain and the request information or profile indicates that Firefox canary
|
||||
// domain should be blocked.
|
||||
@ -298,8 +300,8 @@ func shouldBlockFirefoxCanary(ri *agd.RequestInfo, prof *agd.Profile) (ok bool)
|
||||
return ri.FilteringGroup.BlockFirefoxCanary
|
||||
}
|
||||
|
||||
// handleFirefoxCanary checks if the request is for the fully-qualified domain
|
||||
// name that Firefox uses to check DoH settings and writes a response if needed.
|
||||
// handleFirefoxCanary responds to Firefox canary domain queries with a REFUSED
|
||||
// response.
|
||||
func (mw *Middleware) handleFirefoxCanary(
|
||||
ctx context.Context,
|
||||
rw dnsserver.ResponseWriter,
|
||||
@ -308,7 +310,7 @@ func (mw *Middleware) handleFirefoxCanary(
|
||||
) (err error) {
|
||||
metrics.DNSSvcFirefoxRequestsTotal.Inc()
|
||||
|
||||
resp := ri.Messages.NewMsgREFUSED(req)
|
||||
resp := ri.Messages.NewRespRCode(req, dns.RcodeRefused)
|
||||
err = rw.WriteMsg(ctx, req, resp)
|
||||
|
||||
return errors.Annotate(err, "writing firefox canary resp: %w")
|
||||
|
@ -44,7 +44,9 @@ func TestMiddleware_writeDebugResponse(t *testing.T) {
|
||||
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
|
||||
Cloner: cloner,
|
||||
BlockingMode: &dnsmsg.BlockingModeNullIP{},
|
||||
StructuredErrors: agdtest.NewSDEConfig(true),
|
||||
FilteredResponseTTL: agdtest.FilteredResponseTTL,
|
||||
EDEEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -162,7 +162,7 @@ func (mw *Middleware) setFilteredResponse(
|
||||
mw.setFilteredResponseNoReq(ctx, fctx, ri)
|
||||
case *filter.ResultBlocked:
|
||||
var err error
|
||||
fctx.filteredResponse, err = ri.Messages.NewBlockedRespMsg(fctx.originalRequest)
|
||||
fctx.filteredResponse, err = ri.Messages.NewBlockedResp(fctx.originalRequest)
|
||||
if err != nil {
|
||||
mw.reportf(ctx, "creating blocked resp for filtered req: %w", err)
|
||||
fctx.filteredResponse = fctx.originalResponse
|
||||
@ -199,7 +199,7 @@ func (mw *Middleware) setFilteredResponseNoReq(
|
||||
fctx.filteredResponse = fctx.originalResponse
|
||||
case *filter.ResultBlocked:
|
||||
var err error
|
||||
fctx.filteredResponse, err = ri.Messages.NewBlockedRespMsg(fctx.originalRequest)
|
||||
fctx.filteredResponse, err = ri.Messages.NewBlockedResp(fctx.originalRequest)
|
||||
if err != nil {
|
||||
mw.reportf(ctx, "creating blocked resp for filtered resp: %w", err)
|
||||
fctx.filteredResponse = fctx.originalResponse
|
||||
|
@ -4,6 +4,7 @@ package mainmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
|
||||
@ -14,7 +15,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/optslog"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/rulestat"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@ -24,14 +25,15 @@ import (
|
||||
|
||||
// Middleware is the main middleware of AdGuard DNS.
|
||||
type Middleware struct {
|
||||
messages *dnsmsg.Constructor
|
||||
cloner *dnsmsg.Cloner
|
||||
fltCtxPool *syncutil.Pool[filteringContext]
|
||||
metrics Metrics
|
||||
logger *slog.Logger
|
||||
messages *dnsmsg.Constructor
|
||||
billStat billstat.Recorder
|
||||
errColl errcoll.Interface
|
||||
fltStrg filter.Storage
|
||||
geoIP geoip.Interface
|
||||
metrics Metrics
|
||||
queryLog querylog.Interface
|
||||
ruleStat rulestat.Interface
|
||||
}
|
||||
@ -39,19 +41,17 @@ type Middleware struct {
|
||||
// Config is the configuration structure for the main middleware. All fields
|
||||
// must be non-nil.
|
||||
type Config struct {
|
||||
// Metrics is used to collect the statistics.
|
||||
Metrics Metrics
|
||||
// Cloner is used to clone messages more efficiently by disposing of parts
|
||||
// of DNS responses for later reuse.
|
||||
Cloner *dnsmsg.Cloner
|
||||
|
||||
// Logger is used to log the operation of the middleware.
|
||||
Logger *slog.Logger
|
||||
|
||||
// Messages is the message constructor used to create blocked and other
|
||||
// messages for this middleware.
|
||||
Messages *dnsmsg.Constructor
|
||||
|
||||
// Cloner is used to clone messages more efficiently by disposing of parts
|
||||
// of DNS responses for later reuse.
|
||||
//
|
||||
// TODO(a.garipov): Use.
|
||||
Cloner *dnsmsg.Cloner
|
||||
|
||||
// BillStat is used to collect billing statistics.
|
||||
BillStat billstat.Recorder
|
||||
|
||||
@ -66,6 +66,9 @@ type Config struct {
|
||||
// addresses in requests and responses.
|
||||
GeoIP geoip.Interface
|
||||
|
||||
// Metrics is used to collect the statistics.
|
||||
Metrics Metrics
|
||||
|
||||
// QueryLog is used to write the logs into.
|
||||
QueryLog querylog.Interface
|
||||
|
||||
@ -77,16 +80,17 @@ type Config struct {
|
||||
// New returns a new main middleware. c must not be nil.
|
||||
func New(c *Config) (mw *Middleware) {
|
||||
return &Middleware{
|
||||
metrics: c.Metrics,
|
||||
messages: c.Messages,
|
||||
cloner: c.Cloner,
|
||||
cloner: c.Cloner,
|
||||
fltCtxPool: syncutil.NewPool(func() (v *filteringContext) {
|
||||
return &filteringContext{}
|
||||
}),
|
||||
logger: c.Logger,
|
||||
messages: c.Messages,
|
||||
billStat: c.BillStat,
|
||||
errColl: c.ErrColl,
|
||||
fltStrg: c.FilterStorage,
|
||||
geoIP: c.GeoIP,
|
||||
metrics: c.Metrics,
|
||||
queryLog: c.QueryLog,
|
||||
ruleStat: c.RuleStat,
|
||||
}
|
||||
@ -106,8 +110,20 @@ func (mw *Middleware) Wrap(next dnsserver.Handler) (wrapped dnsserver.Handler) {
|
||||
defer mw.fltCtxPool.Put(fctx)
|
||||
|
||||
ri := agd.MustRequestInfoFromContext(ctx)
|
||||
optlog.Debug2("processing request %q from %s", ri.ID, ri.RemoteIP)
|
||||
defer optlog.Debug2("finished processing request %q from %s", ri.ID, ri.RemoteIP)
|
||||
optslog.Debug2(
|
||||
ctx,
|
||||
mw.logger,
|
||||
"processing request",
|
||||
"req_id", ri.ID,
|
||||
"remote_ip", ri.RemoteIP,
|
||||
)
|
||||
defer optslog.Debug2(
|
||||
ctx,
|
||||
mw.logger,
|
||||
"finished processing request",
|
||||
"req_id", ri.ID,
|
||||
"remote_ip", ri.RemoteIP,
|
||||
)
|
||||
|
||||
flt := mw.fltStrg.FilterFromContext(ctx, ri)
|
||||
mw.filterRequest(ctx, fctx, flt, ri)
|
||||
@ -184,10 +200,12 @@ func (mw *Middleware) nextParams(
|
||||
|
||||
ctx = agd.ContextWithRequestInfo(ctx, modReqInfo)
|
||||
|
||||
optlog.Debug2(
|
||||
"mainmw: request for %q rewritten to %q by CNAME rewrite rule",
|
||||
ri.Host,
|
||||
modReqInfo.Host,
|
||||
optslog.Debug2(
|
||||
ctx,
|
||||
mw.logger,
|
||||
"request rewritten by cname rewrite rule",
|
||||
"orig_host", ri.Host,
|
||||
"mod_host", modReqInfo.Host,
|
||||
)
|
||||
|
||||
return ctx, origRW, modReq
|
||||
|
@ -17,17 +17,13 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// Common constants for tests.
|
||||
const (
|
||||
testASN geoip.ASN = 12345
|
||||
@ -121,24 +117,17 @@ func TestMiddleware_Wrap(t *testing.T) {
|
||||
OnHasListID: func(_ agd.FilterListID) (ok bool) { panic("not implemented") },
|
||||
}
|
||||
|
||||
geoIP := &agdtest.GeoIP{
|
||||
OnSubnetByLocation: func(
|
||||
_ *geoip.Location,
|
||||
_ netutil.AddrFamily,
|
||||
) (n netip.Prefix, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnData: func(host string, addr netip.Addr) (l *geoip.Location, err error) {
|
||||
pt := testutil.PanicT{}
|
||||
require.Equal(pt, dnssvctest.Domain, host)
|
||||
if addr.Is4() {
|
||||
require.Equal(pt, addr, testRespAddr4)
|
||||
} else if addr.Is6() {
|
||||
require.Equal(pt, addr, testRespAddr6)
|
||||
}
|
||||
geoIP := agdtest.NewGeoIP()
|
||||
geoIP.OnData = func(host string, addr netip.Addr) (l *geoip.Location, err error) {
|
||||
pt := testutil.PanicT{}
|
||||
require.Equal(pt, dnssvctest.Domain, host)
|
||||
if addr.Is4() {
|
||||
require.Equal(pt, addr, testRespAddr4)
|
||||
} else if addr.Is6() {
|
||||
require.Equal(pt, addr, testRespAddr6)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
},
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ruleStat := &agdtest.RuleStat{
|
||||
@ -153,7 +142,9 @@ func TestMiddleware_Wrap(t *testing.T) {
|
||||
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
|
||||
Cloner: cloner,
|
||||
BlockingMode: &dnsmsg.BlockingModeNullIP{},
|
||||
StructuredErrors: agdtest.NewSDEConfig(true),
|
||||
FilteredResponseTTL: agdtest.FilteredResponseTTL,
|
||||
EDEEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -227,13 +218,14 @@ func TestMiddleware_Wrap(t *testing.T) {
|
||||
}
|
||||
|
||||
c := &mainmw.Config{
|
||||
Metrics: mainmw.EmptyMetrics{},
|
||||
Messages: msgs,
|
||||
Cloner: cloner,
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
Messages: msgs,
|
||||
BillStat: tc.billStat,
|
||||
ErrColl: agdtest.NewErrorCollector(),
|
||||
FilterStorage: fltStrg,
|
||||
GeoIP: geoIP,
|
||||
Metrics: mainmw.EmptyMetrics{},
|
||||
QueryLog: queryLog,
|
||||
RuleStat: ruleStat,
|
||||
}
|
||||
@ -411,16 +403,9 @@ func TestMiddleware_Wrap_filtering(t *testing.T) {
|
||||
}
|
||||
)
|
||||
|
||||
geoIP := &agdtest.GeoIP{
|
||||
OnSubnetByLocation: func(
|
||||
_ *geoip.Location,
|
||||
_ netutil.AddrFamily,
|
||||
) (n netip.Prefix, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnData: func(host string, addr netip.Addr) (l *geoip.Location, err error) {
|
||||
return nil, nil
|
||||
},
|
||||
geoIP := agdtest.NewGeoIP()
|
||||
geoIP.OnData = func(_ string, _ netip.Addr) (l *geoip.Location, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var (
|
||||
@ -537,7 +522,9 @@ func TestMiddleware_Wrap_filtering(t *testing.T) {
|
||||
msgs, err := dnsmsg.NewConstructor(&dnsmsg.ConstructorConfig{
|
||||
Cloner: cloner,
|
||||
BlockingMode: &dnsmsg.BlockingModeNullIP{},
|
||||
StructuredErrors: agdtest.NewSDEConfig(true),
|
||||
FilteredResponseTTL: agdtest.FilteredResponseTTL,
|
||||
EDEEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -701,13 +688,14 @@ func TestMiddleware_Wrap_filtering(t *testing.T) {
|
||||
}
|
||||
|
||||
c := &mainmw.Config{
|
||||
Metrics: mainmw.EmptyMetrics{},
|
||||
Messages: msgs,
|
||||
Cloner: cloner,
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
Messages: msgs,
|
||||
BillStat: tc.billStat,
|
||||
ErrColl: agdtest.NewErrorCollector(),
|
||||
FilterStorage: fltStrg,
|
||||
GeoIP: geoIP,
|
||||
Metrics: mainmw.EmptyMetrics{},
|
||||
QueryLog: queryLog,
|
||||
RuleStat: ruleStat,
|
||||
}
|
||||
|
@ -12,7 +12,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/optslog"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
@ -112,8 +112,7 @@ func (mw *Middleware) responseCountry(
|
||||
}
|
||||
|
||||
ctry = mw.country(ctx, host, respIP)
|
||||
// TODO(a.garipov): Use optslog.Trace2.
|
||||
optlog.Debug2("mainmw: got ctry %q for resp ip %v", ctry, respIP)
|
||||
optslog.Trace2(ctx, mw.logger, "geoip for resp", "ctry", ctry, "resp_ip", respIP)
|
||||
|
||||
return ctry
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/rulestat"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -134,20 +134,13 @@ func TestMiddleware_recordQueryInfo_respCtry(t *testing.T) {
|
||||
Country: testCtry,
|
||||
}
|
||||
|
||||
geoIP := &agdtest.GeoIP{
|
||||
OnSubnetByLocation: func(
|
||||
_ *geoip.Location,
|
||||
_ netutil.AddrFamily,
|
||||
) (n netip.Prefix, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
OnData: func(_ string, _ netip.Addr) (l *geoip.Location, err error) {
|
||||
if !tc.wantGeoIP {
|
||||
t.Error("unexpected call to geoip")
|
||||
}
|
||||
geoIP := agdtest.NewGeoIP()
|
||||
geoIP.OnData = func(_ string, _ netip.Addr) (l *geoip.Location, err error) {
|
||||
if !tc.wantGeoIP {
|
||||
t.Error("unexpected call to geoip")
|
||||
}
|
||||
|
||||
return loc, nil
|
||||
},
|
||||
return loc, nil
|
||||
}
|
||||
|
||||
queryLogCalled := false
|
||||
@ -164,6 +157,7 @@ func TestMiddleware_recordQueryInfo_respCtry(t *testing.T) {
|
||||
}
|
||||
|
||||
mw := &Middleware{
|
||||
logger: slogutil.NewDiscardLogger(),
|
||||
billStat: billstat.EmptyRecorder{},
|
||||
geoIP: geoIP,
|
||||
queryLog: queryLog,
|
||||
|
@ -5,15 +5,16 @@ package preservice
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
|
||||
"github.com/AdguardTeam/AdGuardDNS/internal/optslog"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
@ -22,19 +23,18 @@ import (
|
||||
// names that may be filtered by safe browsing or parental control filters as
|
||||
// well as handling of the DNS-server check queries.
|
||||
type Middleware struct {
|
||||
// messages is used to construct TXT responses.
|
||||
messages *dnsmsg.Constructor
|
||||
|
||||
// hashMatcher is the safe browsing DNS hashMatcher.
|
||||
logger *slog.Logger
|
||||
messages *dnsmsg.Constructor
|
||||
hashMatcher filter.HashMatcher
|
||||
|
||||
// checker is used to detect and process DNS-check requests.
|
||||
checker dnscheck.Interface
|
||||
checker dnscheck.Interface
|
||||
}
|
||||
|
||||
// Config is the configurational structure for the preservice middleware. All
|
||||
// fields must be non-nil.
|
||||
type Config struct {
|
||||
// Logger is used to log the operation of the middleware.
|
||||
Logger *slog.Logger
|
||||
|
||||
// Messages is used to construct TXT responses.
|
||||
Messages *dnsmsg.Constructor
|
||||
|
||||
@ -48,6 +48,7 @@ type Config struct {
|
||||
// New returns a new preservice middleware. c must not be nil.
|
||||
func New(c *Config) (mw *Middleware) {
|
||||
return &Middleware{
|
||||
logger: c.Logger,
|
||||
messages: c.Messages,
|
||||
hashMatcher: c.HashMatcher,
|
||||
checker: c.Checker,
|
||||
@ -60,7 +61,7 @@ var _ dnsserver.Middleware = (*Middleware)(nil)
|
||||
// Wrap implements the [dnsserver.Middleware] interface for *Middleware.
|
||||
func (mw *Middleware) Wrap(next dnsserver.Handler) (wrapped dnsserver.Handler) {
|
||||
f := func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "preservice mw: %w") }()
|
||||
defer func() { err = errors.Annotate(err, "presvcmw: %w") }()
|
||||
|
||||
ri := agd.MustRequestInfoFromContext(ctx)
|
||||
if ri.QType == dns.TypeTXT {
|
||||
@ -71,13 +72,20 @@ func (mw *Middleware) Wrap(next dnsserver.Handler) (wrapped dnsserver.Handler) {
|
||||
resp, err := mw.checker.Check(ctx, req, ri)
|
||||
if err != nil {
|
||||
return fmt.Errorf("calling dnscheck: %w", err)
|
||||
} else if resp != nil {
|
||||
return errors.Annotate(rw.WriteMsg(ctx, req, resp), "writing dnscheck response: %w")
|
||||
}
|
||||
|
||||
// Don't wrap the error, because this is the main flow, and there is
|
||||
// already [errors.Annotate] here.
|
||||
return next.ServeDNS(ctx, rw, req)
|
||||
if resp == nil {
|
||||
// Don't wrap the error, because this is the main flow, and there is
|
||||
// already [errors.Annotate] here.
|
||||
return next.ServeDNS(ctx, rw, req)
|
||||
}
|
||||
|
||||
err = rw.WriteMsg(ctx, req, resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing dnscheck response: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return dnsserver.HandlerFunc(f)
|
||||
@ -92,15 +100,19 @@ func (mw *Middleware) respondWithHashes(
|
||||
req *dns.Msg,
|
||||
ri *agd.RequestInfo,
|
||||
) (err error) {
|
||||
optlog.Debug1("preservice mw: safe browsing: got txt req for %q", ri.Host)
|
||||
optslog.Debug1(ctx, mw.logger, "got txt req", "host", ri.Host)
|
||||
|
||||
hashes, matched, err := mw.hashMatcher.MatchByPrefix(ctx, ri.Host)
|
||||
if err != nil {
|
||||
// Don't return or collect this error to prevent DDoS of the error
|
||||
// collector by sending bad requests.
|
||||
log.Error("preservice mw: safe browsing: matching hashes: %s", err)
|
||||
mw.logger.ErrorContext(
|
||||
ctx,
|
||||
"matching hashes",
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
|
||||
resp := mw.messages.NewMsgREFUSED(req)
|
||||
resp := mw.messages.NewRespRCode(req, dns.RcodeRefused)
|
||||
err = rw.WriteMsg(ctx, req, resp)
|
||||
|
||||
return errors.Annotate(err, "writing refused response: %w")
|
||||
@ -110,7 +122,7 @@ func (mw *Middleware) respondWithHashes(
|
||||
return next.ServeDNS(ctx, rw, req)
|
||||
}
|
||||
|
||||
resp, err := mw.messages.NewTXTRespMsg(req, hashes...)
|
||||
resp, err := mw.messages.NewRespTXT(req, hashes...)
|
||||
if err != nil {
|
||||
// Technically should never happen since the only error that could arise
|
||||
// in [dnsmsg.Constructor.NewTXTRespMsg] is the one about request type
|
||||
@ -118,7 +130,7 @@ func (mw *Middleware) respondWithHashes(
|
||||
return fmt.Errorf("creating safe browsing result: %w", err)
|
||||
}
|
||||
|
||||
optlog.Debug1("preservice mw: safe browsing: writing hashes %q", hashes)
|
||||
optslog.Debug1(ctx, mw.logger, "writing hashes", "hashes", hashes)
|
||||
|
||||
err = rw.WriteMsg(ctx, req, resp)
|
||||
if err != nil {
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user