mirror of
https://github.com/AdguardTeam/AdGuardDNS.git
synced 2025-02-20 11:23:36 +08:00
Sync v2.1.5
This commit is contained in:
parent
7dec041e0f
commit
f20c533cc3
93
CHANGELOG.md
93
CHANGELOG.md
@ -11,6 +11,99 @@ The format is **not** based on [Keep a Changelog][kec], since the project
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## AGDNS-916 / Build 456
|
||||||
|
|
||||||
|
* `ratelimit` now defines rate of requests per second for IPv4 and IPv6
|
||||||
|
addresses separately. So replace this:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
ratelimit:
|
||||||
|
rps: 30
|
||||||
|
ipv4_subnet_key_len: 24
|
||||||
|
ipv6_subnet_key_len: 48
|
||||||
|
```
|
||||||
|
|
||||||
|
with this:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
ratelimit:
|
||||||
|
ipv4:
|
||||||
|
rps: 30
|
||||||
|
subnet_key_len: 24
|
||||||
|
ipv6:
|
||||||
|
rps: 300
|
||||||
|
subnet_key_len: 48
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## AGDNS-907 / Build 449
|
||||||
|
|
||||||
|
* The objects within the `filtering_groups` have a new property,
|
||||||
|
`block_firefox_canary`. So replace this:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
filtering_groups:
|
||||||
|
-
|
||||||
|
id: default
|
||||||
|
# …
|
||||||
|
```
|
||||||
|
|
||||||
|
with this:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
filtering_groups:
|
||||||
|
-
|
||||||
|
id: default
|
||||||
|
# …
|
||||||
|
block_firefox_canary: true
|
||||||
|
```
|
||||||
|
|
||||||
|
The recommended default value is `true`.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## AGDNS-1308 / Build 447
|
||||||
|
|
||||||
|
* There is now a new env variable `RESEARCH_METRICS` that controls whether
|
||||||
|
collecting research metrics is enabled or not. Also, the first research
|
||||||
|
metric is added: `dns_research_blocked_per_country_total`, it counts the
|
||||||
|
number of blocked requests per country. Its default value is `0`, i.e.
|
||||||
|
research metrics collection is disabled by default.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## AGDNS-1051 / Build 443
|
||||||
|
|
||||||
|
* There are two changes in the keys of the `static_content` map. Firstly,
|
||||||
|
properties `allow_origin` and `content_type` are removed. Secondly, a new
|
||||||
|
property, called `headers`, is added. So replace this:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
static_content:
|
||||||
|
'/favicon.ico':
|
||||||
|
# …
|
||||||
|
allow_origin: '*'
|
||||||
|
content_type: 'image/x-icon'
|
||||||
|
```
|
||||||
|
|
||||||
|
with this:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
static_content:
|
||||||
|
'/favicon.ico':
|
||||||
|
# …
|
||||||
|
headers:
|
||||||
|
'Access-Control-Allow-Origin':
|
||||||
|
- '*'
|
||||||
|
'Content-Type':
|
||||||
|
- 'image/x-icon'
|
||||||
|
```
|
||||||
|
|
||||||
|
Adjust or add the values, if necessary.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## AGDNS-1278 / Build 423
|
## AGDNS-1278 / Build 423
|
||||||
|
|
||||||
* The object `filters` has two new properties, `rule_list_cache_size` and
|
* The object `filters` has two new properties, `rule_list_cache_size` and
|
||||||
|
@ -8,8 +8,21 @@ ratelimit:
|
|||||||
refuseany: true
|
refuseany: true
|
||||||
# If response is larger than this, it is counted as several responses.
|
# If response is larger than this, it is counted as several responses.
|
||||||
response_size_estimate: 1KB
|
response_size_estimate: 1KB
|
||||||
# Rate of requests per second for one subnet.
|
|
||||||
rps: 30
|
rps: 30
|
||||||
|
# Rate limit options for IPv4 addresses.
|
||||||
|
ipv4:
|
||||||
|
# Rate of requests per second for one subnet for IPv4 addresses.
|
||||||
|
rps: 30
|
||||||
|
# The lengths of the subnet prefixes used to calculate rate limiter
|
||||||
|
# bucket keys for IPv4 addresses.
|
||||||
|
subnet_key_len: 24
|
||||||
|
# Rate limit options for IPv6 addresses.
|
||||||
|
ipv6:
|
||||||
|
# Rate of requests per second for one subnet for IPv6 addresses.
|
||||||
|
rps: 300
|
||||||
|
# The lengths of the subnet prefixes used to calculate rate limiter
|
||||||
|
# bucket keys for IPv6 addresses.
|
||||||
|
subnet_key_len: 48
|
||||||
# The time during which to count the number of times a client has hit the
|
# The time during which to count the number of times a client has hit the
|
||||||
# rate limit for a back off.
|
# rate limit for a back off.
|
||||||
#
|
#
|
||||||
@ -21,10 +34,6 @@ ratelimit:
|
|||||||
# How much a client that has hit the rate limit too often stays in the back
|
# How much a client that has hit the rate limit too often stays in the back
|
||||||
# off.
|
# off.
|
||||||
back_off_duration: 30m
|
back_off_duration: 30m
|
||||||
# The lengths of the subnet prefixes used to calculate rate limiter bucket
|
|
||||||
# keys for IPv4 and IPv6 addresses correspondingly.
|
|
||||||
ipv4_subnet_key_len: 24
|
|
||||||
ipv6_subnet_key_len: 48
|
|
||||||
|
|
||||||
# Configuration for the allowlist.
|
# Configuration for the allowlist.
|
||||||
allowlist:
|
allowlist:
|
||||||
@ -156,9 +165,12 @@ web:
|
|||||||
# servers. Paths must not cross the ones used by the DNS-over-HTTPS server.
|
# servers. Paths must not cross the ones used by the DNS-over-HTTPS server.
|
||||||
static_content:
|
static_content:
|
||||||
'/favicon.ico':
|
'/favicon.ico':
|
||||||
allow_origin: '*'
|
|
||||||
content_type: 'image/x-icon'
|
|
||||||
content: ''
|
content: ''
|
||||||
|
headers:
|
||||||
|
Access-Control-Allow-Origin:
|
||||||
|
- '*'
|
||||||
|
Content-Type:
|
||||||
|
- 'image/x-icon'
|
||||||
# If not defined, AdGuard DNS will respond with a 404 page to all such
|
# If not defined, AdGuard DNS will respond with a 404 page to all such
|
||||||
# requests.
|
# requests.
|
||||||
root_redirect_url: 'https://adguard-dns.com'
|
root_redirect_url: 'https://adguard-dns.com'
|
||||||
@ -221,6 +233,7 @@ filtering_groups:
|
|||||||
safe_browsing:
|
safe_browsing:
|
||||||
enabled: true
|
enabled: true
|
||||||
block_private_relay: false
|
block_private_relay: false
|
||||||
|
block_firefox_canary: true
|
||||||
- id: 'family'
|
- id: 'family'
|
||||||
parental:
|
parental:
|
||||||
enabled: true
|
enabled: true
|
||||||
@ -234,6 +247,7 @@ filtering_groups:
|
|||||||
safe_browsing:
|
safe_browsing:
|
||||||
enabled: true
|
enabled: true
|
||||||
block_private_relay: false
|
block_private_relay: false
|
||||||
|
block_firefox_canary: true
|
||||||
- id: 'non_filtering'
|
- id: 'non_filtering'
|
||||||
rule_lists:
|
rule_lists:
|
||||||
enabled: false
|
enabled: false
|
||||||
@ -242,6 +256,7 @@ filtering_groups:
|
|||||||
safe_browsing:
|
safe_browsing:
|
||||||
enabled: false
|
enabled: false
|
||||||
block_private_relay: false
|
block_private_relay: false
|
||||||
|
block_firefox_canary: true
|
||||||
|
|
||||||
# Server groups and servers.
|
# Server groups and servers.
|
||||||
server_groups:
|
server_groups:
|
||||||
|
@ -85,11 +85,23 @@ The `ratelimit` object has the following properties:
|
|||||||
|
|
||||||
**Example:** `30m`.
|
**Example:** `30m`.
|
||||||
|
|
||||||
* <a href="#ratelimit-rps" id="ratelimit-rps" name="ratelimit-rps">`rps`</a>:
|
* <a href="#ratelimit-ipv4" id="ratelimit-ipv4" name="ratelimit-ipv4">`ipv4`</a>:
|
||||||
The rate of requests per second for one subnet. Requests above this are
|
The ipv4 configuration object. It has the following fields:
|
||||||
counted in the backoff count.
|
|
||||||
|
|
||||||
**Example:** `30`.
|
* <a href="#ratelimit-ipv4-rps" id="ratelimit-ipv4-rps" name="ratelimit-ipv4-rps">`rps`</a>:
|
||||||
|
The rate of requests per second for one subnet. Requests above this are
|
||||||
|
counted in the backoff count.
|
||||||
|
|
||||||
|
**Example:** `30`.
|
||||||
|
|
||||||
|
* <a href="#ratelimit-ipv4-subnet_key_len" id="ratelimit-ipv4-subnet_key_len" name="ratelimit-ipv4-subnet_key_len">`ipv4-subnet_key_len`</a>:
|
||||||
|
The length of the subnet prefix used to calculate rate limiter bucket keys.
|
||||||
|
|
||||||
|
**Example:** `24`.
|
||||||
|
|
||||||
|
* <a href="#ratelimit-ipv6" id="ratelimit-ipv6" name="ratelimit-ipv6">`ipv6`</a>:
|
||||||
|
The `ipv6` configuration object has the same properties as the `ipv4` one
|
||||||
|
above.
|
||||||
|
|
||||||
* <a href="#ratelimit-back_off_count" id="ratelimit-back_off_count" name="ratelimit-back_off_count">`back_off_count`</a>:
|
* <a href="#ratelimit-back_off_count" id="ratelimit-back_off_count" name="ratelimit-back_off_count">`back_off_count`</a>:
|
||||||
Maximum number of requests a client can make above the RPS within
|
Maximum number of requests a client can make above the RPS within
|
||||||
@ -120,22 +132,11 @@ The `ratelimit` object has the following properties:
|
|||||||
|
|
||||||
**Example:** `30s`.
|
**Example:** `30s`.
|
||||||
|
|
||||||
* <a href="#ratelimit-ipv4_subnet_key_len" id="ratelimit-ipv4_subnet_key_len" name="ratelimit-ipv4_subnet_key_len">`ipv4_subnet_key_len`</a>:
|
For example, if `back_off_period` is `1m`, `back_off_count` is `10`, and
|
||||||
The length of the subnet prefix used to calculate rate limiter bucket keys
|
`ipv4-rps` is `5`, a client (meaning all IP addresses within the subnet defined
|
||||||
for IPv4 addresses.
|
by `ipv4-subnet_key_len`) that made 15 requests in one second or 6 requests
|
||||||
|
(one above `rps`) every second for 10 seconds within one minute, the client is
|
||||||
**Example:** `24`.
|
blocked for `back_off_duration`.
|
||||||
|
|
||||||
* <a href="#ratelimit-ipv6_subnet_key_len" id="ratelimit-ipv6_subnet_key_len" name="ratelimit-ipv6_subnet_key_len">`ipv6_subnet_key_len`</a>:
|
|
||||||
Same as `ipv4_subnet_key_len` above but for IPv6 addresses.
|
|
||||||
|
|
||||||
**Example:** `48`.
|
|
||||||
|
|
||||||
For example, if `back_off_period` is `1m`, `back_off_count` is `10`, and `rps`
|
|
||||||
is `5`, a client (meaning all IP addresses within the subnet defined by
|
|
||||||
`ipv4_subnet_key_len` and `ipv6_subnet_key_len`) that made 15 requests in one
|
|
||||||
second or 6 requests (one above `rps`) every second for 10 seconds within one
|
|
||||||
minute, the client is blocked for `back_off_duration`.
|
|
||||||
|
|
||||||
[env-consul_allowlist_url]: environment.md#CONSUL_ALLOWLIST_URL
|
[env-consul_allowlist_url]: environment.md#CONSUL_ALLOWLIST_URL
|
||||||
|
|
||||||
@ -454,13 +455,17 @@ The optional `web` object has the following properties:
|
|||||||
`safe_browsing` and `adult_blocking` servers. Paths must not duplicate the
|
`safe_browsing` and `adult_blocking` servers. Paths must not duplicate the
|
||||||
ones used by the DNS-over-HTTPS server.
|
ones used by the DNS-over-HTTPS server.
|
||||||
|
|
||||||
|
Inside of the `headers` map, the header `Content-Type` is required.
|
||||||
|
|
||||||
**Property example:**
|
**Property example:**
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
'static_content':
|
static_content:
|
||||||
'/favicon.ico':
|
'/favicon.ico':
|
||||||
'content_type': 'image/x-icon'
|
content: 'base64content'
|
||||||
'content': 'base64content'
|
headers:
|
||||||
|
'Content-Type':
|
||||||
|
- 'image/x-icon'
|
||||||
```
|
```
|
||||||
|
|
||||||
* <a href="#web-root_redirect_url" id="web-root_redirect_url" name="web-root_redirect_url">`root_redirect_url`</a>:
|
* <a href="#web-root_redirect_url" id="web-root_redirect_url" name="web-root_redirect_url">`root_redirect_url`</a>:
|
||||||
@ -647,6 +652,12 @@ The items of the `filtering_groups` array have the following properties:
|
|||||||
|
|
||||||
**Example:** `false`.
|
**Example:** `false`.
|
||||||
|
|
||||||
|
* <a href="#fg-*-block_firefox_canary" id="fg-*-block_firefox_canary" name="fg-*-block_firefox_canary">`block_firefox_canary`</a>:
|
||||||
|
If true, Firefox canary domain queries are blocked for requests using this
|
||||||
|
filtering group.
|
||||||
|
|
||||||
|
**Example:** `true`.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## <a href="#server_groups" id="server_groups" name="server_groups">Server groups</a>
|
## <a href="#server_groups" id="server_groups" name="server_groups">Server groups</a>
|
||||||
|
@ -22,6 +22,7 @@ sensitive configuration. All other configuration is stored in the
|
|||||||
* [`LISTEN_PORT`](#LISTEN_PORT)
|
* [`LISTEN_PORT`](#LISTEN_PORT)
|
||||||
* [`LOG_TIMESTAMP`](#LOG_TIMESTAMP)
|
* [`LOG_TIMESTAMP`](#LOG_TIMESTAMP)
|
||||||
* [`QUERYLOG_PATH`](#QUERYLOG_PATH)
|
* [`QUERYLOG_PATH`](#QUERYLOG_PATH)
|
||||||
|
* [`RESEARCH_METRICS`](#RESEARCH_METRICS)
|
||||||
* [`RULESTAT_URL`](#RULESTAT_URL)
|
* [`RULESTAT_URL`](#RULESTAT_URL)
|
||||||
* [`SENTRY_DSN`](#SENTRY_DSN)
|
* [`SENTRY_DSN`](#SENTRY_DSN)
|
||||||
* [`SSL_KEY_LOG_FILE`](#SSL_KEY_LOG_FILE)
|
* [`SSL_KEY_LOG_FILE`](#SSL_KEY_LOG_FILE)
|
||||||
@ -198,6 +199,15 @@ The path to the file into which the query log is going to be written.
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## <a href="#RESEARCH_METRICS" id="RESEARCH_METRICS" name="RESEARCH_METRICS">`RESEARCH_METRICS`</a>
|
||||||
|
|
||||||
|
If `1`, enable collection of a set of special prometheus metrics (prefix is
|
||||||
|
`dns_research`). If `0`, disable collection of those metrics.
|
||||||
|
|
||||||
|
**Default:** `0`.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## <a href="#RULESTAT_URL" id="RULESTAT_URL" name="RULESTAT_URL">`RULESTAT_URL`</a>
|
## <a href="#RULESTAT_URL" id="RULESTAT_URL" name="RULESTAT_URL">`RULESTAT_URL`</a>
|
||||||
|
|
||||||
The URL to send filtering rule list statistics to. If empty or unset, the
|
The URL to send filtering rule list statistics to. If empty or unset, the
|
||||||
|
2
go.mod
2
go.mod
@ -4,7 +4,7 @@ go 1.19
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/AdguardTeam/AdGuardDNS/internal/dnsserver v0.100.0
|
github.com/AdguardTeam/AdGuardDNS/internal/dnsserver v0.100.0
|
||||||
github.com/AdguardTeam/golibs v0.11.3
|
github.com/AdguardTeam/golibs v0.11.4
|
||||||
github.com/AdguardTeam/urlfilter v0.16.1
|
github.com/AdguardTeam/urlfilter v0.16.1
|
||||||
github.com/ameshkov/dnscrypt/v2 v2.2.5
|
github.com/ameshkov/dnscrypt/v2 v2.2.5
|
||||||
github.com/axiomhq/hyperloglog v0.0.0-20220105174342-98591331716a
|
github.com/axiomhq/hyperloglog v0.0.0-20220105174342-98591331716a
|
||||||
|
4
go.sum
4
go.sum
@ -33,8 +33,8 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9
|
|||||||
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
||||||
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||||
github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
|
github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
|
||||||
github.com/AdguardTeam/golibs v0.11.3 h1:Oif+REq2WLycQ2Xm3ZPmJdfftptss0HbGWbxdFaC310=
|
github.com/AdguardTeam/golibs v0.11.4 h1:IltyvxwCTN+xxJF5sh6VadF8Zfbf8elgCm9dgijSVzM=
|
||||||
github.com/AdguardTeam/golibs v0.11.3/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA=
|
github.com/AdguardTeam/golibs v0.11.4/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA=
|
||||||
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
|
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
|
||||||
github.com/AdguardTeam/urlfilter v0.16.1 h1:ZPi0rjqo8cQf2FVdzo6cqumNoHZx2KPXj2yZa1A5BBw=
|
github.com/AdguardTeam/urlfilter v0.16.1 h1:ZPi0rjqo8cQf2FVdzo6cqumNoHZx2KPXj2yZa1A5BBw=
|
||||||
github.com/AdguardTeam/urlfilter v0.16.1/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI=
|
github.com/AdguardTeam/urlfilter v0.16.1/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI=
|
||||||
|
22
go.work.sum
22
go.work.sum
@ -9,7 +9,11 @@ dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0 h1:SPOUaucgtVl
|
|||||||
dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412 h1:GvWw74lx5noHocd+f6HBMXK6DuggBB1dhVkuGZbv7qM=
|
dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412 h1:GvWw74lx5noHocd+f6HBMXK6DuggBB1dhVkuGZbv7qM=
|
||||||
dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c h1:ivON6cwHK1OH26MZyWDCnbTRZZf0IhNsENoNAKFS1g4=
|
dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c h1:ivON6cwHK1OH26MZyWDCnbTRZZf0IhNsENoNAKFS1g4=
|
||||||
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999 h1:OR8VhtwhcAI3U48/rzBsVOuHi0zDPzYI1xASVcdSgR8=
|
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999 h1:OR8VhtwhcAI3U48/rzBsVOuHi0zDPzYI1xASVcdSgR8=
|
||||||
|
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||||
|
github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
|
||||||
github.com/AdguardTeam/golibs v0.10.7/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
|
github.com/AdguardTeam/golibs v0.10.7/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
|
||||||
|
github.com/AdguardTeam/golibs v0.11.4 h1:IltyvxwCTN+xxJF5sh6VadF8Zfbf8elgCm9dgijSVzM=
|
||||||
|
github.com/AdguardTeam/golibs v0.11.4/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA=
|
||||||
github.com/AdguardTeam/gomitmproxy v0.2.0 h1:rvCOf17pd1/CnMyMQW891zrEiIQBpQ8cIGjKN9pinUU=
|
github.com/AdguardTeam/gomitmproxy v0.2.0 h1:rvCOf17pd1/CnMyMQW891zrEiIQBpQ8cIGjKN9pinUU=
|
||||||
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
|
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
|
||||||
github.com/BurntSushi/toml v1.2.0 h1:Rt8g24XnyGTyglgET/PRUNlrUeu9F5L+7FilkXfZgs0=
|
github.com/BurntSushi/toml v1.2.0 h1:Rt8g24XnyGTyglgET/PRUNlrUeu9F5L+7FilkXfZgs0=
|
||||||
@ -74,7 +78,7 @@ github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18h
|
|||||||
github.com/golang/lint v0.0.0-20180702182130-06c8688daad7 h1:2hRPrmiwPrp3fQX967rNJIhQPtiGXdlQWAxKbKw3VHA=
|
github.com/golang/lint v0.0.0-20180702182130-06c8688daad7 h1:2hRPrmiwPrp3fQX967rNJIhQPtiGXdlQWAxKbKw3VHA=
|
||||||
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
|
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
|
||||||
github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
|
github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
|
||||||
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY=
|
github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY=
|
||||||
github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=
|
github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=
|
||||||
github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw=
|
github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw=
|
||||||
@ -198,41 +202,36 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh
|
|||||||
github.com/yosssi/ace v0.0.5 h1:tUkIP/BLdKqrlrPwcmH0shwEEhTRHoGnc1wFIWmaBUA=
|
github.com/yosssi/ace v0.0.5 h1:tUkIP/BLdKqrlrPwcmH0shwEEhTRHoGnc1wFIWmaBUA=
|
||||||
github.com/yuin/goldmark v1.4.1 h1:/vn0k+RBvwlxEmP5E7SZMqNxPhfMVFEJiykr15/0XKM=
|
github.com/yuin/goldmark v1.4.1 h1:/vn0k+RBvwlxEmP5E7SZMqNxPhfMVFEJiykr15/0XKM=
|
||||||
github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE=
|
github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE=
|
||||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
|
||||||
go.opencensus.io v0.22.4 h1:LYy1Hy3MJdrCdMwwzxA/dRok4ejH+RwNGbuoD9fCjto=
|
go.opencensus.io v0.22.4 h1:LYy1Hy3MJdrCdMwwzxA/dRok4ejH+RwNGbuoD9fCjto=
|
||||||
go4.org v0.0.0-20180809161055-417644f6feb5 h1:+hE86LblG4AyDgwMCLTE6FOlM9+qjHSYS+rKqxUVdsM=
|
go4.org v0.0.0-20180809161055-417644f6feb5 h1:+hE86LblG4AyDgwMCLTE6FOlM9+qjHSYS+rKqxUVdsM=
|
||||||
golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d h1:E2M5QgjZ/Jg+ObCQAudsXxuTsLj7Nl5RV/lZcQZmKSo=
|
golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d h1:E2M5QgjZ/Jg+ObCQAudsXxuTsLj7Nl5RV/lZcQZmKSo=
|
||||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
|
||||||
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||||
golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||||
golang.org/x/image v0.0.0-20190802002840-cff245a6509b h1:+qEpEAPhDZ1o0x3tHzZTQDArnOixOzGD9HUJfcg0mb4=
|
golang.org/x/image v0.0.0-20190802002840-cff245a6509b h1:+qEpEAPhDZ1o0x3tHzZTQDArnOixOzGD9HUJfcg0mb4=
|
||||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k=
|
golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k=
|
||||||
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 h1:4+4C/Iv2U4fMZBiMCc98MG1In4gJY5YRhtpDNeDeHWs=
|
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 h1:4+4C/Iv2U4fMZBiMCc98MG1In4gJY5YRhtpDNeDeHWs=
|
||||||
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
|
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
|
||||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI=
|
||||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||||
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||||
golang.org/x/net v0.0.0-20220516155154-20f960328961/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
golang.org/x/net v0.0.0-20220516155154-20f960328961/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||||
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
|
||||||
golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b h1:clP8eMhB30EHdc0bd2Twtq6kgU7yl5ub2cQLSdrv1Dg=
|
golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b h1:clP8eMhB30EHdc0bd2Twtq6kgU7yl5ub2cQLSdrv1Dg=
|
||||||
golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852 h1:xYq6+9AtI+xP3M4r0N1hCkHrInHDBohhquRgx9Kk6gI=
|
golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852 h1:xYq6+9AtI+xP3M4r0N1hCkHrInHDBohhquRgx9Kk6gI=
|
||||||
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
|
||||||
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 h1:v6hYoSR9T5oet+pMXwUWkbiVqx/63mlHjefrHmxwfeY=
|
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 h1:v6hYoSR9T5oet+pMXwUWkbiVqx/63mlHjefrHmxwfeY=
|
||||||
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
|
||||||
golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw=
|
golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw=
|
||||||
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
|
||||||
golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc=
|
golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc=
|
||||||
golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI=
|
golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI=
|
||||||
|
golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA=
|
||||||
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
|
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
|
||||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA=
|
||||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
||||||
golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
google.golang.org/api v0.30.0 h1:yfrXXP61wVuLb0vBcG6qaOoIoqYEzOQS8jum51jkv2w=
|
google.golang.org/api v0.30.0 h1:yfrXXP61wVuLb0vBcG6qaOoIoqYEzOQS8jum51jkv2w=
|
||||||
@ -246,6 +245,7 @@ gopkg.in/errgo.v2 v2.1.0 h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8=
|
|||||||
gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
|
gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
|
||||||
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
|
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
|
||||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||||
|
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
grpc.go4.org v0.0.0-20170609214715-11d0a25b4919 h1:tmXTu+dfa+d9Evp8NpJdgOy6+rt8/x4yG7qPBrtNfLY=
|
grpc.go4.org v0.0.0-20170609214715-11d0a25b4919 h1:tmXTu+dfa+d9Evp8NpJdgOy6+rt8/x4yG7qPBrtNfLY=
|
||||||
honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8=
|
honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8=
|
||||||
rsc.io/binaryregexp v0.2.0 h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE=
|
rsc.io/binaryregexp v0.2.0 h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE=
|
||||||
|
@ -1,11 +1,6 @@
|
|||||||
package agd
|
package agd
|
||||||
|
|
||||||
import (
|
import "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Common DNS Message Constants, Types, And Utilities
|
// Common DNS Message Constants, Types, And Utilities
|
||||||
|
|
||||||
@ -27,10 +22,3 @@ const (
|
|||||||
ProtoDoT = dnsserver.ProtoDoT
|
ProtoDoT = dnsserver.ProtoDoT
|
||||||
ProtoDNSCrypt = dnsserver.ProtoDNSCrypt
|
ProtoDNSCrypt = dnsserver.ProtoDNSCrypt
|
||||||
)
|
)
|
||||||
|
|
||||||
// Resolver is the DNS resolver interface.
|
|
||||||
//
|
|
||||||
// See go doc net.Resolver.
|
|
||||||
type Resolver interface {
|
|
||||||
LookupIP(ctx context.Context, network, host string) (ips []net.IP, err error)
|
|
||||||
}
|
|
||||||
|
@ -146,6 +146,10 @@ type FilteringGroup struct {
|
|||||||
// BlockPrivateRelay shows if Apple Private Relay is blocked for requests
|
// BlockPrivateRelay shows if Apple Private Relay is blocked for requests
|
||||||
// using this filtering group.
|
// using this filtering group.
|
||||||
BlockPrivateRelay bool
|
BlockPrivateRelay bool
|
||||||
|
|
||||||
|
// BlockFirefoxCanary shows if Firefox canary domain is blocked for
|
||||||
|
// requests using this filtering group.
|
||||||
|
BlockFirefoxCanary bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilteringGroupID is the ID of a filter group. It is an opaque string.
|
// FilteringGroupID is the ID of a filter group. It is an opaque string.
|
||||||
|
@ -73,6 +73,10 @@ type Profile struct {
|
|||||||
// BlockPrivateRelay shows if Apple Private Relay queries are blocked for
|
// BlockPrivateRelay shows if Apple Private Relay queries are blocked for
|
||||||
// requests from all devices in this profile.
|
// requests from all devices in this profile.
|
||||||
BlockPrivateRelay bool
|
BlockPrivateRelay bool
|
||||||
|
|
||||||
|
// BlockFirefoxCanary shows if Firefox canary domain is blocked for
|
||||||
|
// requests from all devices in this profile.
|
||||||
|
BlockFirefoxCanary bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProfileID is the ID of a profile. It is an opaque string.
|
// ProfileID is the ID of a profile. It is an opaque string.
|
||||||
|
@ -17,6 +17,8 @@ import (
|
|||||||
// Data Storage
|
// Data Storage
|
||||||
|
|
||||||
// ProfileDB is the local database of profiles and other data.
|
// ProfileDB is the local database of profiles and other data.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): move this logic to the backend package.
|
||||||
type ProfileDB interface {
|
type ProfileDB interface {
|
||||||
ProfileByDeviceID(ctx context.Context, id DeviceID) (p *Profile, d *Device, err error)
|
ProfileByDeviceID(ctx context.Context, id DeviceID) (p *Profile, d *Device, err error)
|
||||||
ProfileByIP(ctx context.Context, ip netip.Addr) (p *Profile, d *Device, err error)
|
ProfileByIP(ctx context.Context, ip netip.Addr) (p *Profile, d *Device, err error)
|
||||||
|
@ -7,6 +7,8 @@ package agdio
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/mathutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LimitError is returned when the Limit is reached.
|
// LimitError is returned when the Limit is reached.
|
||||||
@ -35,9 +37,8 @@ func (lr *limitedReader) Read(p []byte) (n int, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if int64(len(p)) > lr.n {
|
l := mathutil.Min(int64(len(p)), lr.n)
|
||||||
p = p[0:lr.n]
|
p = p[:l]
|
||||||
}
|
|
||||||
|
|
||||||
n, err = lr.r.Read(p)
|
n, err = lr.r.Read(p)
|
||||||
lr.n -= int64(n)
|
lr.n -= int64(n)
|
||||||
|
@ -165,6 +165,31 @@ type v1SettingsRespSettings struct {
|
|||||||
FilteringEnabled bool `json:"filtering_enabled"`
|
FilteringEnabled bool `json:"filtering_enabled"`
|
||||||
Deleted bool `json:"deleted"`
|
Deleted bool `json:"deleted"`
|
||||||
BlockPrivateRelay bool `json:"block_private_relay"`
|
BlockPrivateRelay bool `json:"block_private_relay"`
|
||||||
|
BlockFirefoxCanary bool `json:"block_firefox_canary"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ json.Unmarshaler = (*v1SettingsRespSettings)(nil)
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the [json.Unmarshaler] interface for
|
||||||
|
// *v1SettingsRespSettings. It puts default value into BlockFirefoxCanary
|
||||||
|
// field while it is not implemented on the backend side.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Remove once the backend starts to always send it.
|
||||||
|
func (rs *v1SettingsRespSettings) UnmarshalJSON(b []byte) (err error) {
|
||||||
|
type defaultDec v1SettingsRespSettings
|
||||||
|
|
||||||
|
s := defaultDec{
|
||||||
|
BlockFirefoxCanary: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.Unmarshal(b, &s); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*rs = v1SettingsRespSettings(s)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// v1SettingsRespRuleLists is the structure for decoding filtering rule lists
|
// v1SettingsRespRuleLists is the structure for decoding filtering rule lists
|
||||||
@ -414,10 +439,11 @@ const maxFltRespTTL = 1 * time.Hour
|
|||||||
func fltRespTTLToInternal(respTTL uint32) (ttl time.Duration, err error) {
|
func fltRespTTLToInternal(respTTL uint32) (ttl time.Duration, err error) {
|
||||||
ttl = time.Duration(respTTL) * time.Second
|
ttl = time.Duration(respTTL) * time.Second
|
||||||
if ttl > maxFltRespTTL {
|
if ttl > maxFltRespTTL {
|
||||||
return ttl, fmt.Errorf("too high: got %d, max %d", respTTL, maxFltRespTTL)
|
ttl = maxFltRespTTL
|
||||||
|
err = fmt.Errorf("too high: got %s, max %s", ttl, maxFltRespTTL)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ttl, nil
|
return ttl, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// toInternal converts r to an [agd.DSProfilesResponse] instance.
|
// toInternal converts r to an [agd.DSProfilesResponse] instance.
|
||||||
@ -461,6 +487,9 @@ func (r *v1SettingsResp) toInternal(
|
|||||||
reportf(ctx, errColl, "settings at index %d: filtered resp ttl: %w", i, err)
|
reportf(ctx, errColl, "settings at index %d: filtered resp ttl: %w", i, err)
|
||||||
|
|
||||||
// Go on and use the fixed value.
|
// Go on and use the fixed value.
|
||||||
|
//
|
||||||
|
// TODO(ameshkov, a.garipov): Consider continuing, like with all
|
||||||
|
// other validation errors.
|
||||||
}
|
}
|
||||||
|
|
||||||
sbEnabled := s.SafeBrowsing != nil && s.SafeBrowsing.Enabled
|
sbEnabled := s.SafeBrowsing != nil && s.SafeBrowsing.Enabled
|
||||||
@ -479,6 +508,7 @@ func (r *v1SettingsResp) toInternal(
|
|||||||
QueryLogEnabled: s.QueryLogEnabled,
|
QueryLogEnabled: s.QueryLogEnabled,
|
||||||
Deleted: s.Deleted,
|
Deleted: s.Deleted,
|
||||||
BlockPrivateRelay: s.BlockPrivateRelay,
|
BlockPrivateRelay: s.BlockPrivateRelay,
|
||||||
|
BlockFirefoxCanary: s.BlockFirefoxCanary,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,6 +131,7 @@ func testProfileResp(t *testing.T) *agd.PSProfilesResponse {
|
|||||||
QueryLogEnabled: true,
|
QueryLogEnabled: true,
|
||||||
Deleted: false,
|
Deleted: false,
|
||||||
BlockPrivateRelay: true,
|
BlockPrivateRelay: true,
|
||||||
|
BlockFirefoxCanary: true,
|
||||||
}, {
|
}, {
|
||||||
Parental: wantParental,
|
Parental: wantParental,
|
||||||
ID: "83f3ea8f",
|
ID: "83f3ea8f",
|
||||||
@ -162,6 +163,7 @@ func testProfileResp(t *testing.T) *agd.PSProfilesResponse {
|
|||||||
QueryLogEnabled: true,
|
QueryLogEnabled: true,
|
||||||
Deleted: true,
|
Deleted: true,
|
||||||
BlockPrivateRelay: false,
|
BlockPrivateRelay: false,
|
||||||
|
BlockFirefoxCanary: false,
|
||||||
}},
|
}},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
2
internal/backend/testdata/profiles.json
vendored
2
internal/backend/testdata/profiles.json
vendored
@ -11,6 +11,7 @@
|
|||||||
},
|
},
|
||||||
"deleted": false,
|
"deleted": false,
|
||||||
"block_private_relay": true,
|
"block_private_relay": true,
|
||||||
|
"block_firefox_canary": true,
|
||||||
"devices": [
|
"devices": [
|
||||||
{
|
{
|
||||||
"id": "118ffe93",
|
"id": "118ffe93",
|
||||||
@ -43,6 +44,7 @@
|
|||||||
},
|
},
|
||||||
"deleted": true,
|
"deleted": true,
|
||||||
"block_private_relay": false,
|
"block_private_relay": false,
|
||||||
|
"block_firefox_canary": false,
|
||||||
"devices": [
|
"devices": [
|
||||||
{
|
{
|
||||||
"id": "0d7724fa",
|
"id": "0d7724fa",
|
||||||
|
@ -28,6 +28,7 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/websvc"
|
"github.com/AdguardTeam/AdGuardDNS/internal/websvc"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
"github.com/AdguardTeam/golibs/timeutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Main is the entry point of application.
|
// Main is the entry point of application.
|
||||||
@ -89,52 +90,75 @@ func Main() {
|
|||||||
go func() {
|
go func() {
|
||||||
defer geoIPMu.Unlock()
|
defer geoIPMu.Unlock()
|
||||||
|
|
||||||
geoIP, geoIPErr = envs.geoIP(c.GeoIP, errColl)
|
geoIP, geoIPErr = envs.geoIP(c.GeoIP)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Safe Browsing Hosts
|
// Safe-browsing and adult-blocking filters
|
||||||
|
|
||||||
|
// TODO(ameshkov): Consider making configurable.
|
||||||
|
filteringResolver := agdnet.NewCachingResolver(
|
||||||
|
agdnet.DefaultResolver{},
|
||||||
|
1*timeutil.Day,
|
||||||
|
)
|
||||||
|
|
||||||
err = os.MkdirAll(envs.FilterCachePath, agd.DefaultDirPerm)
|
err = os.MkdirAll(envs.FilterCachePath, agd.DefaultDirPerm)
|
||||||
check(err)
|
check(err)
|
||||||
|
|
||||||
safeBrowsingConf := c.SafeBrowsing.toInternal(
|
safeBrowsingConf, err := c.SafeBrowsing.toInternal(
|
||||||
|
errColl,
|
||||||
|
filteringResolver,
|
||||||
agd.FilterListIDSafeBrowsing,
|
agd.FilterListIDSafeBrowsing,
|
||||||
envs.FilterCachePath,
|
envs.FilterCachePath,
|
||||||
errColl,
|
|
||||||
)
|
)
|
||||||
safeBrowsingHashes, err := filter.NewHashStorage(safeBrowsingConf)
|
|
||||||
check(err)
|
check(err)
|
||||||
|
|
||||||
err = safeBrowsingHashes.Start()
|
safeBrowsingFilter, err := filter.NewHashPrefix(safeBrowsingConf)
|
||||||
check(err)
|
check(err)
|
||||||
|
|
||||||
adultBlockingConf := c.AdultBlocking.toInternal(
|
safeBrowsingUpd := agd.NewRefreshWorker(&agd.RefreshWorkerConfig{
|
||||||
|
Context: ctxWithDefaultTimeout,
|
||||||
|
Refresher: safeBrowsingFilter,
|
||||||
|
ErrColl: errColl,
|
||||||
|
Name: string(agd.FilterListIDSafeBrowsing),
|
||||||
|
Interval: safeBrowsingConf.Staleness,
|
||||||
|
RefreshOnShutdown: false,
|
||||||
|
RoutineLogsAreDebug: false,
|
||||||
|
})
|
||||||
|
err = safeBrowsingUpd.Start()
|
||||||
|
check(err)
|
||||||
|
|
||||||
|
adultBlockingConf, err := c.AdultBlocking.toInternal(
|
||||||
|
errColl,
|
||||||
|
filteringResolver,
|
||||||
agd.FilterListIDAdultBlocking,
|
agd.FilterListIDAdultBlocking,
|
||||||
envs.FilterCachePath,
|
envs.FilterCachePath,
|
||||||
errColl,
|
|
||||||
)
|
)
|
||||||
adultBlockingHashes, err := filter.NewHashStorage(adultBlockingConf)
|
|
||||||
check(err)
|
check(err)
|
||||||
|
|
||||||
err = adultBlockingHashes.Start()
|
adultBlockingFilter, err := filter.NewHashPrefix(adultBlockingConf)
|
||||||
check(err)
|
check(err)
|
||||||
|
|
||||||
// Filters And Filtering Groups
|
adultBlockingUpd := agd.NewRefreshWorker(&agd.RefreshWorkerConfig{
|
||||||
|
Context: ctxWithDefaultTimeout,
|
||||||
|
Refresher: adultBlockingFilter,
|
||||||
|
ErrColl: errColl,
|
||||||
|
Name: string(agd.FilterListIDAdultBlocking),
|
||||||
|
Interval: adultBlockingConf.Staleness,
|
||||||
|
RefreshOnShutdown: false,
|
||||||
|
RoutineLogsAreDebug: false,
|
||||||
|
})
|
||||||
|
err = adultBlockingUpd.Start()
|
||||||
|
check(err)
|
||||||
|
|
||||||
fltStrgConf := c.Filters.toInternal(errColl, envs)
|
// Filter storage and filtering groups
|
||||||
fltStrgConf.SafeBrowsing = &filter.HashPrefixConfig{
|
|
||||||
Hashes: safeBrowsingHashes,
|
|
||||||
ReplacementHost: c.SafeBrowsing.BlockHost,
|
|
||||||
CacheTTL: c.SafeBrowsing.CacheTTL.Duration,
|
|
||||||
CacheSize: c.SafeBrowsing.CacheSize,
|
|
||||||
}
|
|
||||||
|
|
||||||
fltStrgConf.AdultBlocking = &filter.HashPrefixConfig{
|
fltStrgConf := c.Filters.toInternal(
|
||||||
Hashes: adultBlockingHashes,
|
errColl,
|
||||||
ReplacementHost: c.AdultBlocking.BlockHost,
|
filteringResolver,
|
||||||
CacheTTL: c.AdultBlocking.CacheTTL.Duration,
|
envs,
|
||||||
CacheSize: c.AdultBlocking.CacheSize,
|
safeBrowsingFilter,
|
||||||
}
|
adultBlockingFilter,
|
||||||
|
)
|
||||||
|
|
||||||
fltStrg, err := filter.NewDefaultStorage(fltStrgConf)
|
fltStrg, err := filter.NewDefaultStorage(fltStrgConf)
|
||||||
check(err)
|
check(err)
|
||||||
@ -153,8 +177,6 @@ func Main() {
|
|||||||
err = fltStrgUpd.Start()
|
err = fltStrgUpd.Start()
|
||||||
check(err)
|
check(err)
|
||||||
|
|
||||||
safeBrowsing := filter.NewSafeBrowsingServer(safeBrowsingHashes, adultBlockingHashes)
|
|
||||||
|
|
||||||
// Server Groups
|
// Server Groups
|
||||||
|
|
||||||
fltGroups, err := c.FilteringGroups.toInternal(fltStrg)
|
fltGroups, err := c.FilteringGroups.toInternal(fltStrg)
|
||||||
@ -329,8 +351,11 @@ func Main() {
|
|||||||
}, c.Upstream.Healthcheck.Enabled)
|
}, c.Upstream.Healthcheck.Enabled)
|
||||||
|
|
||||||
dnsConf := &dnssvc.Config{
|
dnsConf := &dnssvc.Config{
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
SafeBrowsing: safeBrowsing,
|
SafeBrowsing: filter.NewSafeBrowsingServer(
|
||||||
|
safeBrowsingConf.Hashes,
|
||||||
|
adultBlockingConf.Hashes,
|
||||||
|
),
|
||||||
BillStat: billStatRec,
|
BillStat: billStatRec,
|
||||||
ProfileDB: profDB,
|
ProfileDB: profDB,
|
||||||
DNSCheck: dnsCk,
|
DNSCheck: dnsCk,
|
||||||
@ -349,6 +374,7 @@ func Main() {
|
|||||||
CacheSize: c.Cache.Size,
|
CacheSize: c.Cache.Size,
|
||||||
ECSCacheSize: c.Cache.ECSSize,
|
ECSCacheSize: c.Cache.ECSSize,
|
||||||
UseECSCache: c.Cache.Type == cacheTypeECS,
|
UseECSCache: c.Cache.Type == cacheTypeECS,
|
||||||
|
ResearchMetrics: bool(envs.ResearchMetrics),
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsSvc, err := dnssvc.New(dnsConf)
|
dnsSvc, err := dnssvc.New(dnsConf)
|
||||||
@ -383,11 +409,11 @@ func Main() {
|
|||||||
)
|
)
|
||||||
|
|
||||||
h := newSignalHandler(
|
h := newSignalHandler(
|
||||||
adultBlockingHashes,
|
|
||||||
safeBrowsingHashes,
|
|
||||||
debugSvc,
|
debugSvc,
|
||||||
webSvc,
|
webSvc,
|
||||||
dnsSvc,
|
dnsSvc,
|
||||||
|
safeBrowsingUpd,
|
||||||
|
adultBlockingUpd,
|
||||||
profDBUpd,
|
profDBUpd,
|
||||||
dnsDBUpd,
|
dnsDBUpd,
|
||||||
geoIPUpd,
|
geoIPUpd,
|
||||||
|
@ -3,14 +3,9 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
|
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
|
||||||
"github.com/AdguardTeam/golibs/timeutil"
|
"github.com/AdguardTeam/golibs/timeutil"
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
@ -213,74 +208,6 @@ func (c *geoIPConfig) validate() (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// allowListConfig is the consul allow list configuration.
|
|
||||||
type allowListConfig struct {
|
|
||||||
// List contains IPs and CIDRs.
|
|
||||||
List []string `yaml:"list"`
|
|
||||||
|
|
||||||
// RefreshIvl time between two updates of allow list from the Consul URL.
|
|
||||||
RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// safeBrowsingConfig is the configuration for one of the safe browsing filters.
|
|
||||||
type safeBrowsingConfig struct {
|
|
||||||
// URL is the URL used to update the filter.
|
|
||||||
URL *agdhttp.URL `yaml:"url"`
|
|
||||||
|
|
||||||
// BlockHost is the hostname with which to respond to any requests that
|
|
||||||
// match the filter.
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): Consider replacing with a list of IPv4 and IPv6
|
|
||||||
// addresses.
|
|
||||||
BlockHost string `yaml:"block_host"`
|
|
||||||
|
|
||||||
// CacheSize is the size of the response cache, in entries.
|
|
||||||
CacheSize int `yaml:"cache_size"`
|
|
||||||
|
|
||||||
// CacheTTL is the TTL of the response cache.
|
|
||||||
CacheTTL timeutil.Duration `yaml:"cache_ttl"`
|
|
||||||
|
|
||||||
// RefreshIvl defines how often AdGuard DNS refreshes the filter.
|
|
||||||
RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// toInternal converts c to the safe browsing filter configuration for the
|
|
||||||
// filter storage of the DNS server. c is assumed to be valid.
|
|
||||||
func (c *safeBrowsingConfig) toInternal(
|
|
||||||
id agd.FilterListID,
|
|
||||||
cacheDir string,
|
|
||||||
errColl agd.ErrorCollector,
|
|
||||||
) (conf *filter.HashStorageConfig) {
|
|
||||||
return &filter.HashStorageConfig{
|
|
||||||
URL: netutil.CloneURL(&c.URL.URL),
|
|
||||||
ErrColl: errColl,
|
|
||||||
ID: id,
|
|
||||||
CachePath: filepath.Join(cacheDir, string(id)),
|
|
||||||
RefreshIvl: c.RefreshIvl.Duration,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// validate returns an error if the safe browsing filter configuration is
|
|
||||||
// invalid.
|
|
||||||
func (c *safeBrowsingConfig) validate() (err error) {
|
|
||||||
switch {
|
|
||||||
case c == nil:
|
|
||||||
return errNilConfig
|
|
||||||
case c.URL == nil:
|
|
||||||
return errors.Error("no url")
|
|
||||||
case c.BlockHost == "":
|
|
||||||
return errors.Error("no block_host")
|
|
||||||
case c.CacheSize <= 0:
|
|
||||||
return newMustBePositiveError("cache_size", c.CacheSize)
|
|
||||||
case c.CacheTTL.Duration <= 0:
|
|
||||||
return newMustBePositiveError("cache_ttl", c.CacheTTL)
|
|
||||||
case c.RefreshIvl.Duration <= 0:
|
|
||||||
return newMustBePositiveError("refresh_interval", c.RefreshIvl)
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// readConfig reads the configuration.
|
// readConfig reads the configuration.
|
||||||
func readConfig(confPath string) (c *configuration, err error) {
|
func readConfig(confPath string) (c *configuration, err error) {
|
||||||
// #nosec G304 -- Trust the path to the configuration file that is given
|
// #nosec G304 -- Trust the path to the configuration file that is given
|
||||||
|
@ -48,8 +48,9 @@ type environments struct {
|
|||||||
|
|
||||||
ListenPort int `env:"LISTEN_PORT" envDefault:"8181"`
|
ListenPort int `env:"LISTEN_PORT" envDefault:"8181"`
|
||||||
|
|
||||||
LogTimestamp strictBool `env:"LOG_TIMESTAMP" envDefault:"1"`
|
LogTimestamp strictBool `env:"LOG_TIMESTAMP" envDefault:"1"`
|
||||||
LogVerbose strictBool `env:"VERBOSE" envDefault:"0"`
|
LogVerbose strictBool `env:"VERBOSE" envDefault:"0"`
|
||||||
|
ResearchMetrics strictBool `env:"RESEARCH_METRICS" envDefault:"0"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// readEnvs reads the configuration.
|
// readEnvs reads the configuration.
|
||||||
@ -127,12 +128,10 @@ func (envs *environments) buildDNSDB(
|
|||||||
// geoIP returns an GeoIP database implementation from environment.
|
// geoIP returns an GeoIP database implementation from environment.
|
||||||
func (envs *environments) geoIP(
|
func (envs *environments) geoIP(
|
||||||
c *geoIPConfig,
|
c *geoIPConfig,
|
||||||
errColl agd.ErrorCollector,
|
|
||||||
) (g *geoip.File, err error) {
|
) (g *geoip.File, err error) {
|
||||||
log.Debug("using geoip files %q and %q", envs.GeoIPASNPath, envs.GeoIPCountryPath)
|
log.Debug("using geoip files %q and %q", envs.GeoIPASNPath, envs.GeoIPCountryPath)
|
||||||
|
|
||||||
g, err = geoip.NewFile(&geoip.FileConfig{
|
g, err = geoip.NewFile(&geoip.FileConfig{
|
||||||
ErrColl: errColl,
|
|
||||||
ASNPath: envs.GeoIPASNPath,
|
ASNPath: envs.GeoIPASNPath,
|
||||||
CountryPath: envs.GeoIPCountryPath,
|
CountryPath: envs.GeoIPCountryPath,
|
||||||
HostCacheSize: c.HostCacheSize,
|
HostCacheSize: c.HostCacheSize,
|
||||||
|
@ -47,16 +47,21 @@ type filtersConfig struct {
|
|||||||
// cacheDir must exist. c is assumed to be valid.
|
// cacheDir must exist. c is assumed to be valid.
|
||||||
func (c *filtersConfig) toInternal(
|
func (c *filtersConfig) toInternal(
|
||||||
errColl agd.ErrorCollector,
|
errColl agd.ErrorCollector,
|
||||||
|
resolver agdnet.Resolver,
|
||||||
envs *environments,
|
envs *environments,
|
||||||
|
safeBrowsing *filter.HashPrefix,
|
||||||
|
adultBlocking *filter.HashPrefix,
|
||||||
) (conf *filter.DefaultStorageConfig) {
|
) (conf *filter.DefaultStorageConfig) {
|
||||||
return &filter.DefaultStorageConfig{
|
return &filter.DefaultStorageConfig{
|
||||||
FilterIndexURL: netutil.CloneURL(&envs.FilterIndexURL.URL),
|
FilterIndexURL: netutil.CloneURL(&envs.FilterIndexURL.URL),
|
||||||
BlockedServiceIndexURL: netutil.CloneURL(&envs.BlockedServiceIndexURL.URL),
|
BlockedServiceIndexURL: netutil.CloneURL(&envs.BlockedServiceIndexURL.URL),
|
||||||
GeneralSafeSearchRulesURL: netutil.CloneURL(&envs.GeneralSafeSearchURL.URL),
|
GeneralSafeSearchRulesURL: netutil.CloneURL(&envs.GeneralSafeSearchURL.URL),
|
||||||
YoutubeSafeSearchRulesURL: netutil.CloneURL(&envs.YoutubeSafeSearchURL.URL),
|
YoutubeSafeSearchRulesURL: netutil.CloneURL(&envs.YoutubeSafeSearchURL.URL),
|
||||||
|
SafeBrowsing: safeBrowsing,
|
||||||
|
AdultBlocking: adultBlocking,
|
||||||
Now: time.Now,
|
Now: time.Now,
|
||||||
ErrColl: errColl,
|
ErrColl: errColl,
|
||||||
Resolver: agdnet.DefaultResolver{},
|
Resolver: resolver,
|
||||||
CacheDir: envs.FilterCachePath,
|
CacheDir: envs.FilterCachePath,
|
||||||
CustomFilterCacheSize: c.CustomFilterCacheSize,
|
CustomFilterCacheSize: c.CustomFilterCacheSize,
|
||||||
SafeSearchCacheSize: c.SafeSearchCacheSize,
|
SafeSearchCacheSize: c.SafeSearchCacheSize,
|
||||||
|
@ -29,6 +29,10 @@ type filteringGroup struct {
|
|||||||
// BlockPrivateRelay shows if Apple Private Relay queries are blocked for
|
// BlockPrivateRelay shows if Apple Private Relay queries are blocked for
|
||||||
// requests using this filtering group.
|
// requests using this filtering group.
|
||||||
BlockPrivateRelay bool `yaml:"block_private_relay"`
|
BlockPrivateRelay bool `yaml:"block_private_relay"`
|
||||||
|
|
||||||
|
// BlockFirefoxCanary shows if Firefox canary domain is blocked for
|
||||||
|
// requests using this filtering group.
|
||||||
|
BlockFirefoxCanary bool `yaml:"block_firefox_canary"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// fltGrpRuleLists contains filter rule lists configuration for a filtering
|
// fltGrpRuleLists contains filter rule lists configuration for a filtering
|
||||||
@ -133,6 +137,7 @@ func (groups filteringGroups) toInternal(
|
|||||||
GeneralSafeSearch: g.Parental.GeneralSafeSearch,
|
GeneralSafeSearch: g.Parental.GeneralSafeSearch,
|
||||||
YoutubeSafeSearch: g.Parental.YoutubeSafeSearch,
|
YoutubeSafeSearch: g.Parental.YoutubeSafeSearch,
|
||||||
BlockPrivateRelay: g.BlockPrivateRelay,
|
BlockPrivateRelay: g.BlockPrivateRelay,
|
||||||
|
BlockFirefoxCanary: g.BlockFirefoxCanary,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,14 +15,17 @@ type rateLimitConfig struct {
|
|||||||
// AllowList is the allowlist of clients.
|
// AllowList is the allowlist of clients.
|
||||||
Allowlist *allowListConfig `yaml:"allowlist"`
|
Allowlist *allowListConfig `yaml:"allowlist"`
|
||||||
|
|
||||||
|
// Rate limit options for IPv4 addresses.
|
||||||
|
IPv4 *rateLimitOptions `yaml:"ipv4"`
|
||||||
|
|
||||||
|
// Rate limit options for IPv6 addresses.
|
||||||
|
IPv6 *rateLimitOptions `yaml:"ipv6"`
|
||||||
|
|
||||||
// ResponseSizeEstimate is the size of the estimate of the size of one DNS
|
// ResponseSizeEstimate is the size of the estimate of the size of one DNS
|
||||||
// response for the purposes of rate limiting. Responses over this estimate
|
// response for the purposes of rate limiting. Responses over this estimate
|
||||||
// are counted as several responses.
|
// are counted as several responses.
|
||||||
ResponseSizeEstimate datasize.ByteSize `yaml:"response_size_estimate"`
|
ResponseSizeEstimate datasize.ByteSize `yaml:"response_size_estimate"`
|
||||||
|
|
||||||
// RPS is the maximum number of requests per second.
|
|
||||||
RPS int `yaml:"rps"`
|
|
||||||
|
|
||||||
// BackOffCount helps with repeated offenders. It defines, how many times
|
// BackOffCount helps with repeated offenders. It defines, how many times
|
||||||
// a client hits the rate limit before being held in the back off.
|
// a client hits the rate limit before being held in the back off.
|
||||||
BackOffCount int `yaml:"back_off_count"`
|
BackOffCount int `yaml:"back_off_count"`
|
||||||
@ -35,18 +38,42 @@ type rateLimitConfig struct {
|
|||||||
// a client has hit the rate limit for a back off.
|
// a client has hit the rate limit for a back off.
|
||||||
BackOffPeriod timeutil.Duration `yaml:"back_off_period"`
|
BackOffPeriod timeutil.Duration `yaml:"back_off_period"`
|
||||||
|
|
||||||
// IPv4SubnetKeyLen is the length of the subnet prefix used to calculate
|
|
||||||
// rate limiter bucket keys for IPv4 addresses.
|
|
||||||
IPv4SubnetKeyLen int `yaml:"ipv4_subnet_key_len"`
|
|
||||||
|
|
||||||
// IPv6SubnetKeyLen is the length of the subnet prefix used to calculate
|
|
||||||
// rate limiter bucket keys for IPv6 addresses.
|
|
||||||
IPv6SubnetKeyLen int `yaml:"ipv6_subnet_key_len"`
|
|
||||||
|
|
||||||
// RefuseANY, if true, makes the server refuse DNS * queries.
|
// RefuseANY, if true, makes the server refuse DNS * queries.
|
||||||
RefuseANY bool `yaml:"refuse_any"`
|
RefuseANY bool `yaml:"refuse_any"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// allowListConfig is the consul allow list configuration.
|
||||||
|
type allowListConfig struct {
|
||||||
|
// List contains IPs and CIDRs.
|
||||||
|
List []string `yaml:"list"`
|
||||||
|
|
||||||
|
// RefreshIvl time between two updates of allow list from the Consul URL.
|
||||||
|
RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// rateLimitOptions allows define maximum number of requests for IPv4 or IPv6
|
||||||
|
// addresses.
|
||||||
|
type rateLimitOptions struct {
|
||||||
|
// RPS is the maximum number of requests per second.
|
||||||
|
RPS int `yaml:"rps"`
|
||||||
|
|
||||||
|
// SubnetKeyLen is the length of the subnet prefix used to calculate
|
||||||
|
// rate limiter bucket keys.
|
||||||
|
SubnetKeyLen int `yaml:"subnet_key_len"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate returns an error if rate limit options are invalid.
|
||||||
|
func (o *rateLimitOptions) validate() (err error) {
|
||||||
|
if o == nil {
|
||||||
|
return errNilConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
return coalesceError(
|
||||||
|
validatePositive("rps", o.RPS),
|
||||||
|
validatePositive("subnet_key_len", o.SubnetKeyLen),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// toInternal converts c to the rate limiting configuration for the DNS server.
|
// toInternal converts c to the rate limiting configuration for the DNS server.
|
||||||
// c is assumed to be valid.
|
// c is assumed to be valid.
|
||||||
func (c *rateLimitConfig) toInternal(al ratelimit.Allowlist) (conf *ratelimit.BackOffConfig) {
|
func (c *rateLimitConfig) toInternal(al ratelimit.Allowlist) (conf *ratelimit.BackOffConfig) {
|
||||||
@ -55,10 +82,11 @@ func (c *rateLimitConfig) toInternal(al ratelimit.Allowlist) (conf *ratelimit.Ba
|
|||||||
ResponseSizeEstimate: int(c.ResponseSizeEstimate.Bytes()),
|
ResponseSizeEstimate: int(c.ResponseSizeEstimate.Bytes()),
|
||||||
Duration: c.BackOffDuration.Duration,
|
Duration: c.BackOffDuration.Duration,
|
||||||
Period: c.BackOffPeriod.Duration,
|
Period: c.BackOffPeriod.Duration,
|
||||||
RPS: c.RPS,
|
IPv4RPS: c.IPv4.RPS,
|
||||||
|
IPv4SubnetKeyLen: c.IPv4.SubnetKeyLen,
|
||||||
|
IPv6RPS: c.IPv6.RPS,
|
||||||
|
IPv6SubnetKeyLen: c.IPv6.SubnetKeyLen,
|
||||||
Count: c.BackOffCount,
|
Count: c.BackOffCount,
|
||||||
IPv4SubnetKeyLen: c.IPv4SubnetKeyLen,
|
|
||||||
IPv6SubnetKeyLen: c.IPv6SubnetKeyLen,
|
|
||||||
RefuseANY: c.RefuseANY,
|
RefuseANY: c.RefuseANY,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -71,14 +99,21 @@ func (c *rateLimitConfig) validate() (err error) {
|
|||||||
return fmt.Errorf("allowlist: %w", errNilConfig)
|
return fmt.Errorf("allowlist: %w", errNilConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = c.IPv4.validate()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ipv4: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.IPv6.validate()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ipv6: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return coalesceError(
|
return coalesceError(
|
||||||
validatePositive("rps", c.RPS),
|
|
||||||
validatePositive("back_off_count", c.BackOffCount),
|
validatePositive("back_off_count", c.BackOffCount),
|
||||||
validatePositive("back_off_duration", c.BackOffDuration),
|
validatePositive("back_off_duration", c.BackOffDuration),
|
||||||
validatePositive("back_off_period", c.BackOffPeriod),
|
validatePositive("back_off_period", c.BackOffPeriod),
|
||||||
validatePositive("response_size_estimate", c.ResponseSizeEstimate),
|
validatePositive("response_size_estimate", c.ResponseSizeEstimate),
|
||||||
validatePositive("allowlist.refresh_interval", c.Allowlist.RefreshIvl),
|
validatePositive("allowlist.refresh_interval", c.Allowlist.RefreshIvl),
|
||||||
validatePositive("ipv4_subnet_key_len", c.IPv4SubnetKeyLen),
|
|
||||||
validatePositive("ipv6_subnet_key_len", c.IPv6SubnetKeyLen),
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
86
internal/cmd/safebrowsing.go
Normal file
86
internal/cmd/safebrowsing.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
|
"github.com/AdguardTeam/golibs/timeutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Safe-browsing and adult-blocking configuration
|
||||||
|
|
||||||
|
// safeBrowsingConfig is the configuration for one of the safe browsing filters.
|
||||||
|
type safeBrowsingConfig struct {
|
||||||
|
// URL is the URL used to update the filter.
|
||||||
|
URL *agdhttp.URL `yaml:"url"`
|
||||||
|
|
||||||
|
// BlockHost is the hostname with which to respond to any requests that
|
||||||
|
// match the filter.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Consider replacing with a list of IPv4 and IPv6
|
||||||
|
// addresses.
|
||||||
|
BlockHost string `yaml:"block_host"`
|
||||||
|
|
||||||
|
// CacheSize is the size of the response cache, in entries.
|
||||||
|
CacheSize int `yaml:"cache_size"`
|
||||||
|
|
||||||
|
// CacheTTL is the TTL of the response cache.
|
||||||
|
CacheTTL timeutil.Duration `yaml:"cache_ttl"`
|
||||||
|
|
||||||
|
// RefreshIvl defines how often AdGuard DNS refreshes the filter.
|
||||||
|
RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// toInternal converts c to the safe browsing filter configuration for the
|
||||||
|
// filter storage of the DNS server. c is assumed to be valid.
|
||||||
|
func (c *safeBrowsingConfig) toInternal(
|
||||||
|
errColl agd.ErrorCollector,
|
||||||
|
resolver agdnet.Resolver,
|
||||||
|
id agd.FilterListID,
|
||||||
|
cacheDir string,
|
||||||
|
) (fltConf *filter.HashPrefixConfig, err error) {
|
||||||
|
hashes, err := hashstorage.New("")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &filter.HashPrefixConfig{
|
||||||
|
Hashes: hashes,
|
||||||
|
URL: netutil.CloneURL(&c.URL.URL),
|
||||||
|
ErrColl: errColl,
|
||||||
|
Resolver: resolver,
|
||||||
|
ID: id,
|
||||||
|
CachePath: filepath.Join(cacheDir, string(id)),
|
||||||
|
ReplacementHost: c.BlockHost,
|
||||||
|
Staleness: c.RefreshIvl.Duration,
|
||||||
|
CacheTTL: c.CacheTTL.Duration,
|
||||||
|
CacheSize: c.CacheSize,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate returns an error if the safe browsing filter configuration is
|
||||||
|
// invalid.
|
||||||
|
func (c *safeBrowsingConfig) validate() (err error) {
|
||||||
|
switch {
|
||||||
|
case c == nil:
|
||||||
|
return errNilConfig
|
||||||
|
case c.URL == nil:
|
||||||
|
return errors.Error("no url")
|
||||||
|
case c.BlockHost == "":
|
||||||
|
return errors.Error("no block_host")
|
||||||
|
case c.CacheSize <= 0:
|
||||||
|
return newMustBePositiveError("cache_size", c.CacheSize)
|
||||||
|
case c.CacheTTL.Duration <= 0:
|
||||||
|
return newMustBePositiveError("cache_ttl", c.CacheTTL)
|
||||||
|
case c.RefreshIvl.Duration <= 0:
|
||||||
|
return newMustBePositiveError("refresh_interval", c.RefreshIvl)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"net/textproto"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
|
||||||
@ -386,11 +387,8 @@ func (sc staticContent) validate() (err error) {
|
|||||||
|
|
||||||
// staticFile is a single file in a static content mapping.
|
// staticFile is a single file in a static content mapping.
|
||||||
type staticFile struct {
|
type staticFile struct {
|
||||||
// AllowOrigin is the value for the HTTP Access-Control-Allow-Origin header.
|
// Headers contains headers of the HTTP response.
|
||||||
AllowOrigin string `yaml:"allow_origin"`
|
Headers http.Header `yaml:"headers"`
|
||||||
|
|
||||||
// ContentType is the value for the HTTP Content-Type header.
|
|
||||||
ContentType string `yaml:"content_type"`
|
|
||||||
|
|
||||||
// Content is the file content.
|
// Content is the file content.
|
||||||
Content string `yaml:"content"`
|
Content string `yaml:"content"`
|
||||||
@ -400,8 +398,12 @@ type staticFile struct {
|
|||||||
// assumed to be valid.
|
// assumed to be valid.
|
||||||
func (f *staticFile) toInternal() (file *websvc.StaticFile, err error) {
|
func (f *staticFile) toInternal() (file *websvc.StaticFile, err error) {
|
||||||
file = &websvc.StaticFile{
|
file = &websvc.StaticFile{
|
||||||
AllowOrigin: f.AllowOrigin,
|
Headers: http.Header{},
|
||||||
ContentType: f.ContentType,
|
}
|
||||||
|
|
||||||
|
for k, vs := range f.Headers {
|
||||||
|
ck := textproto.CanonicalMIMEHeaderKey(k)
|
||||||
|
file.Headers[ck] = vs
|
||||||
}
|
}
|
||||||
|
|
||||||
file.Content, err = base64.StdEncoding.DecodeString(f.Content)
|
file.Content, err = base64.StdEncoding.DecodeString(f.Content)
|
||||||
@ -409,17 +411,20 @@ func (f *staticFile) toInternal() (file *websvc.StaticFile, err error) {
|
|||||||
return nil, fmt.Errorf("content: %w", err)
|
return nil, fmt.Errorf("content: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check Content-Type here as opposed to in validate, because we need
|
||||||
|
// all keys to be canonicalized first.
|
||||||
|
if file.Headers.Get(agdhttp.HdrNameContentType) == "" {
|
||||||
|
return nil, errors.Error("content: " + agdhttp.HdrNameContentType + " header is required")
|
||||||
|
}
|
||||||
|
|
||||||
return file, nil
|
return file, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate returns an error if the static content file is invalid.
|
// validate returns an error if the static content file is invalid.
|
||||||
func (f *staticFile) validate() (err error) {
|
func (f *staticFile) validate() (err error) {
|
||||||
switch {
|
if f == nil {
|
||||||
case f == nil:
|
|
||||||
return errors.Error("no file")
|
return errors.Error("no file")
|
||||||
case f.ContentType == "":
|
|
||||||
return errors.Error("no content_type")
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -222,9 +222,9 @@ func NewEDNS0Padding(msgLen int, UDPBufferSize uint16) (extra dns.RR) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindENDS0Option searches for the specified EDNS0 option in the OPT resource
|
// FindEDNS0Option searches for the specified EDNS0 option in the OPT resource
|
||||||
// record of the msg and returns it or nil if it's not present.
|
// record of the msg and returns it or nil if it's not present.
|
||||||
func FindENDS0Option[T dns.EDNS0](msg *dns.Msg) (o T) {
|
func FindEDNS0Option[T dns.EDNS0](msg *dns.Msg) (o T) {
|
||||||
rr := msg.IsEdns0()
|
rr := msg.IsEdns0()
|
||||||
if rr == nil {
|
if rr == nil {
|
||||||
return o
|
return o
|
||||||
|
@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardDNS/internal/dnsserver
|
|||||||
go 1.19
|
go 1.19
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/AdguardTeam/golibs v0.11.3
|
github.com/AdguardTeam/golibs v0.11.4
|
||||||
github.com/ameshkov/dnscrypt/v2 v2.2.5
|
github.com/ameshkov/dnscrypt/v2 v2.2.5
|
||||||
github.com/ameshkov/dnsstamps v1.0.3
|
github.com/ameshkov/dnsstamps v1.0.3
|
||||||
github.com/bluele/gcache v0.0.2
|
github.com/bluele/gcache v0.0.2
|
||||||
|
@ -31,8 +31,7 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl
|
|||||||
cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs=
|
cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs=
|
||||||
cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0=
|
cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0=
|
||||||
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
||||||
github.com/AdguardTeam/golibs v0.11.3 h1:Oif+REq2WLycQ2Xm3ZPmJdfftptss0HbGWbxdFaC310=
|
github.com/AdguardTeam/golibs v0.11.4 h1:IltyvxwCTN+xxJF5sh6VadF8Zfbf8elgCm9dgijSVzM=
|
||||||
github.com/AdguardTeam/golibs v0.11.3/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA=
|
|
||||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||||
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
||||||
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
|
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
|
||||||
|
@ -1,36 +0,0 @@
|
|||||||
//go:build !(aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd)
|
|
||||||
|
|
||||||
package dnsserver
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// listenUDP listens to the specified address on UDP.
|
|
||||||
func listenUDP(_ context.Context, addr string, _ bool) (conn *net.UDPConn, err error) {
|
|
||||||
defer func() { err = errors.Annotate(err, "opening packet listener: %w") }()
|
|
||||||
|
|
||||||
c, err := net.ListenPacket("udp", addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, ok := c.(*net.UDPConn)
|
|
||||||
if !ok {
|
|
||||||
// TODO(ameshkov): should not happen, consider panic here.
|
|
||||||
err = fmt.Errorf("expected conn of type %T, got %T", conn, c)
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// listenTCP listens to the specified address on TCP.
|
|
||||||
func listenTCP(_ context.Context, addr string) (conn net.Listener, err error) {
|
|
||||||
return net.Listen("tcp", addr)
|
|
||||||
}
|
|
@ -1,63 +0,0 @@
|
|||||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd
|
|
||||||
|
|
||||||
package dnsserver
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
func reuseportControl(_, _ string, c syscall.RawConn) (err error) {
|
|
||||||
var opErr error
|
|
||||||
err = c.Control(func(fd uintptr) {
|
|
||||||
opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return opErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// listenUDP listens to the specified address on UDP. If oob flag is set to
|
|
||||||
// true this method also enables OOB for the listen socket that enables using of
|
|
||||||
// ReadMsgUDP/WriteMsgUDP. Doing it this way is necessary to correctly discover
|
|
||||||
// the source address when it listens to 0.0.0.0.
|
|
||||||
func listenUDP(ctx context.Context, addr string, oob bool) (conn *net.UDPConn, err error) {
|
|
||||||
defer func() { err = errors.Annotate(err, "opening packet listener: %w") }()
|
|
||||||
|
|
||||||
var lc net.ListenConfig
|
|
||||||
lc.Control = reuseportControl
|
|
||||||
c, err := lc.ListenPacket(ctx, "udp", addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, ok := c.(*net.UDPConn)
|
|
||||||
if !ok {
|
|
||||||
// TODO(ameshkov): should not happen, consider panic here.
|
|
||||||
err = fmt.Errorf("expected conn of type %T, got %T", conn, c)
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if oob {
|
|
||||||
if err = setUDPSocketOptions(conn); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to set socket options: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return conn, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// listenTCP listens to the specified address on TCP.
|
|
||||||
func listenTCP(ctx context.Context, addr string) (l net.Listener, err error) {
|
|
||||||
var lc net.ListenConfig
|
|
||||||
lc.Control = reuseportControl
|
|
||||||
return lc.Listen(ctx, "tcp", addr)
|
|
||||||
}
|
|
73
internal/dnsserver/netext/listenconfig.go
Normal file
73
internal/dnsserver/netext/listenconfig.go
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
// Package netext contains extensions of package net in the Go standard library.
|
||||||
|
package netext
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ListenConfig is the interface that allows controlling options of connections
|
||||||
|
// used by the DNS servers defined in this module. Default ListenConfigs are
|
||||||
|
// the ones returned by [DefaultListenConfigWithOOB] for plain DNS and
|
||||||
|
// [DefaultListenConfig] for others.
|
||||||
|
//
|
||||||
|
// This interface is modeled after [net.ListenConfig].
|
||||||
|
type ListenConfig interface {
|
||||||
|
Listen(ctx context.Context, network, address string) (l net.Listener, err error)
|
||||||
|
ListenPacket(ctx context.Context, network, address string) (c net.PacketConn, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultListenConfig returns the default [ListenConfig] used by the servers in
|
||||||
|
// this module except for the plain-DNS ones, which use
|
||||||
|
// [DefaultListenConfigWithOOB].
|
||||||
|
func DefaultListenConfig() (lc ListenConfig) {
|
||||||
|
return &net.ListenConfig{
|
||||||
|
Control: defaultListenControl,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultListenConfigWithOOB returns the default [ListenConfig] used by the
|
||||||
|
// plain-DNS servers in this module. The resulting ListenConfig sets additional
|
||||||
|
// socket flags and processes the control-messages of connections created with
|
||||||
|
// ListenPacket.
|
||||||
|
func DefaultListenConfigWithOOB() (lc ListenConfig) {
|
||||||
|
return &listenConfigOOB{
|
||||||
|
ListenConfig: net.ListenConfig{
|
||||||
|
Control: defaultListenControl,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ ListenConfig = (*listenConfigOOB)(nil)
|
||||||
|
|
||||||
|
// listenConfigOOB is a wrapper around [net.ListenConfig] with modifications
|
||||||
|
// that set the control-message options on packet conns.
|
||||||
|
type listenConfigOOB struct {
|
||||||
|
net.ListenConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenPacket implements the [ListenConfig] interface for *listenConfigOOB.
|
||||||
|
// It sets the control-message flags to receive additional out-of-band data to
|
||||||
|
// correctly discover the source address when it listens to 0.0.0.0 as well as
|
||||||
|
// in situations when SO_BINDTODEVICE is used.
|
||||||
|
//
|
||||||
|
// network must be "udp", "udp4", or "udp6".
|
||||||
|
func (lc *listenConfigOOB) ListenPacket(
|
||||||
|
ctx context.Context,
|
||||||
|
network string,
|
||||||
|
address string,
|
||||||
|
) (c net.PacketConn, err error) {
|
||||||
|
c, err = lc.ListenConfig.ListenPacket(ctx, network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = setIPOpts(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("setting socket options: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return wrapPacketConn(c), nil
|
||||||
|
}
|
44
internal/dnsserver/netext/listenconfig_unix.go
Normal file
44
internal/dnsserver/netext/listenconfig_unix.go
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
//go:build unix
|
||||||
|
|
||||||
|
package netext
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultListenControl is used as a [net.ListenConfig.Control] function to set
|
||||||
|
// the SO_REUSEPORT socket option on all sockets used by the DNS servers in this
|
||||||
|
// package.
|
||||||
|
func defaultListenControl(_, _ string, c syscall.RawConn) (err error) {
|
||||||
|
var opErr error
|
||||||
|
err = c.Control(func(fd uintptr) {
|
||||||
|
opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.WithDeferred(opErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// setIPOpts sets the IPv4 and IPv6 options on a packet connection.
|
||||||
|
func setIPOpts(c net.PacketConn) (err error) {
|
||||||
|
// TODO(a.garipov): Returning an error only if both functions return one
|
||||||
|
// (which is what module dns does as well) seems rather fragile. Depending
|
||||||
|
// on the OS, the valid errors are ENOPROTOOPT, EINVAL, and maybe others.
|
||||||
|
// Investigate and make OS-specific versions to make sure we don't miss any
|
||||||
|
// real errors.
|
||||||
|
err6 := ipv6.NewPacketConn(c).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true)
|
||||||
|
err4 := ipv4.NewPacketConn(c).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true)
|
||||||
|
if err4 != nil && err6 != nil {
|
||||||
|
return errors.List("setting ipv4 and ipv6 options", err4, err6)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
67
internal/dnsserver/netext/listenconfig_unix_test.go
Normal file
67
internal/dnsserver/netext/listenconfig_unix_test.go
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
//go:build unix
|
||||||
|
|
||||||
|
package netext_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"syscall"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDefaultListenConfigWithOOB(t *testing.T) {
|
||||||
|
lc := netext.DefaultListenConfigWithOOB()
|
||||||
|
require.NotNil(t, lc)
|
||||||
|
|
||||||
|
type syscallConner interface {
|
||||||
|
SyscallConn() (c syscall.RawConn, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("ipv4", func(t *testing.T) {
|
||||||
|
c, err := lc.ListenPacket(context.Background(), "udp4", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, c)
|
||||||
|
require.Implements(t, (*syscallConner)(nil), c)
|
||||||
|
|
||||||
|
sc, err := c.(syscallConner).SyscallConn()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = sc.Control(func(fd uintptr) {
|
||||||
|
val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT)
|
||||||
|
require.NoError(t, opErr)
|
||||||
|
|
||||||
|
// TODO(a.garipov): Rewrite this to use actual expected values for
|
||||||
|
// each OS.
|
||||||
|
assert.NotEqual(t, 0, val)
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ipv6", func(t *testing.T) {
|
||||||
|
c, err := lc.ListenPacket(context.Background(), "udp6", "[::1]:0")
|
||||||
|
if errors.Is(err, syscall.EADDRNOTAVAIL) {
|
||||||
|
// Some CI machines have IPv6 disabled.
|
||||||
|
t.Skipf("ipv6 seems to not be supported: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, c)
|
||||||
|
require.Implements(t, (*syscallConner)(nil), c)
|
||||||
|
|
||||||
|
sc, err := c.(syscallConner).SyscallConn()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = sc.Control(func(fd uintptr) {
|
||||||
|
val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT)
|
||||||
|
require.NoError(t, opErr)
|
||||||
|
|
||||||
|
assert.NotEqual(t, 0, val)
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
17
internal/dnsserver/netext/listenconfig_windows.go
Normal file
17
internal/dnsserver/netext/listenconfig_windows.go
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package netext
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultListenControl is nil on Windows, because it doesn't support
|
||||||
|
// SO_REUSEPORT.
|
||||||
|
var defaultListenControl func(_, _ string, _ syscall.RawConn) (_ error)
|
||||||
|
|
||||||
|
// setIPOpts sets the IPv4 and IPv6 options on a packet connection.
|
||||||
|
func setIPOpts(c net.PacketConn) (err error) {
|
||||||
|
return nil
|
||||||
|
}
|
69
internal/dnsserver/netext/packetconn.go
Normal file
69
internal/dnsserver/netext/packetconn.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
package netext
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PacketSession contains additional information about a packet read from or
|
||||||
|
// written to a [SessionPacketConn].
|
||||||
|
type PacketSession interface {
|
||||||
|
LocalAddr() (addr net.Addr)
|
||||||
|
RemoteAddr() (addr net.Addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSimplePacketSession returns a new packet session using the given
|
||||||
|
// parameters.
|
||||||
|
func NewSimplePacketSession(laddr, raddr net.Addr) (s PacketSession) {
|
||||||
|
return &simplePacketSession{
|
||||||
|
laddr: laddr,
|
||||||
|
raddr: raddr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// simplePacketSession is a simple implementation of the [PacketSession]
|
||||||
|
// interface.
|
||||||
|
type simplePacketSession struct {
|
||||||
|
laddr net.Addr
|
||||||
|
raddr net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr implements the [PacketSession] interface for *simplePacketSession.
|
||||||
|
func (s *simplePacketSession) LocalAddr() (addr net.Addr) { return s.laddr }
|
||||||
|
|
||||||
|
// RemoteAddr implements the [PacketSession] interface for *simplePacketSession.
|
||||||
|
func (s *simplePacketSession) RemoteAddr() (addr net.Addr) { return s.raddr }
|
||||||
|
|
||||||
|
// SessionPacketConn extends [net.PacketConn] with methods for working with
|
||||||
|
// packet sessions.
|
||||||
|
type SessionPacketConn interface {
|
||||||
|
net.PacketConn
|
||||||
|
|
||||||
|
ReadFromSession(b []byte) (n int, s PacketSession, err error)
|
||||||
|
WriteToSession(b []byte, s PacketSession) (n int, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadFromSession is a convenience wrapper for types that may or may not
|
||||||
|
// implement [SessionPacketConn]. If c implements it, ReadFromSession uses
|
||||||
|
// c.ReadFromSession. Otherwise, it uses c.ReadFrom and the session is created
|
||||||
|
// by using [NewSimplePacketSession] with c.LocalAddr.
|
||||||
|
func ReadFromSession(c net.PacketConn, b []byte) (n int, s PacketSession, err error) {
|
||||||
|
if spc, ok := c.(SessionPacketConn); ok {
|
||||||
|
return spc.ReadFromSession(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, raddr, err := c.ReadFrom(b)
|
||||||
|
s = NewSimplePacketSession(c.LocalAddr(), raddr)
|
||||||
|
|
||||||
|
return n, s, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteToSession is a convenience wrapper for types that may or may not
|
||||||
|
// implement [SessionPacketConn]. If c implements it, WriteToSession uses
|
||||||
|
// c.WriteToSession. Otherwise, it uses c.WriteTo using s.RemoteAddr.
|
||||||
|
func WriteToSession(c net.PacketConn, b []byte, s PacketSession) (n int, err error) {
|
||||||
|
if spc, ok := c.(SessionPacketConn); ok {
|
||||||
|
return spc.WriteToSession(b, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.WriteTo(b, s.RemoteAddr())
|
||||||
|
}
|
159
internal/dnsserver/netext/packetconn_linux.go
Normal file
159
internal/dnsserver/netext/packetconn_linux.go
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
// TODO(a.garipov): Technically, we can expand this to other platforms, but that
|
||||||
|
// would require separate udpOOBSize constants and tests.
|
||||||
|
|
||||||
|
package netext
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
)
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ PacketSession = (*packetSession)(nil)
|
||||||
|
|
||||||
|
// packetSession contains additional information about the packet read from a
|
||||||
|
// UDP connection. It is basically an extended version of [dns.SessionUDP] that
|
||||||
|
// contains the local address as well.
|
||||||
|
type packetSession struct {
|
||||||
|
laddr *net.UDPAddr
|
||||||
|
raddr *net.UDPAddr
|
||||||
|
respOOB []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr implements the [PacketSession] interface for *packetSession.
|
||||||
|
func (s *packetSession) LocalAddr() (addr net.Addr) { return s.laddr }
|
||||||
|
|
||||||
|
// RemoteAddr implements the [PacketSession] interface for *packetSession.
|
||||||
|
func (s *packetSession) RemoteAddr() (addr net.Addr) { return s.raddr }
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ SessionPacketConn = (*sessionPacketConn)(nil)
|
||||||
|
|
||||||
|
// wrapPacketConn wraps c to make it a [SessionPacketConn], if the OS supports
|
||||||
|
// that.
|
||||||
|
func wrapPacketConn(c net.PacketConn) (wrapped net.PacketConn) {
|
||||||
|
return &sessionPacketConn{
|
||||||
|
UDPConn: *c.(*net.UDPConn),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sessionPacketConn wraps a UDP connection and implements [SessionPacketConn].
|
||||||
|
type sessionPacketConn struct {
|
||||||
|
net.UDPConn
|
||||||
|
}
|
||||||
|
|
||||||
|
// oobPool is the pool of byte slices for out-of-band data.
|
||||||
|
var oobPool = &sync.Pool{
|
||||||
|
New: func() (v any) {
|
||||||
|
b := make([]byte, IPDstOOBSize)
|
||||||
|
|
||||||
|
return &b
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPDstOOBSize is the required size of the control-message buffer for
|
||||||
|
// [net.UDPConn.ReadMsgUDP] to read the original destination on Linux.
|
||||||
|
//
|
||||||
|
// See packetconn_linux_internal_test.go.
|
||||||
|
const IPDstOOBSize = 40
|
||||||
|
|
||||||
|
// ReadFromSession implements the [SessionPacketConn] interface for *packetConn.
|
||||||
|
func (c *sessionPacketConn) ReadFromSession(b []byte) (n int, s PacketSession, err error) {
|
||||||
|
oobPtr := oobPool.Get().(*[]byte)
|
||||||
|
defer oobPool.Put(oobPtr)
|
||||||
|
|
||||||
|
var oobn int
|
||||||
|
oob := *oobPtr
|
||||||
|
ps := &packetSession{}
|
||||||
|
n, oobn, _, ps.raddr, err = c.ReadMsgUDP(b, oob)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var origDstIP net.IP
|
||||||
|
sockLAddr := c.LocalAddr().(*net.UDPAddr)
|
||||||
|
origDstIP, err = origLAddr(oob[:oobn])
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, fmt.Errorf("getting original addr: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if origDstIP == nil {
|
||||||
|
ps.laddr = sockLAddr
|
||||||
|
} else {
|
||||||
|
ps.respOOB = newRespOOB(origDstIP)
|
||||||
|
ps.laddr = &net.UDPAddr{
|
||||||
|
IP: origDstIP,
|
||||||
|
Port: sockLAddr.Port,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, ps, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// origLAddr returns the original local address from the encoded control-message
|
||||||
|
// data, if there is one. If not nil, origDst will have a protocol-appropriate
|
||||||
|
// length.
|
||||||
|
func origLAddr(oob []byte) (origDst net.IP, err error) {
|
||||||
|
ctrlMsg6 := &ipv6.ControlMessage{}
|
||||||
|
err = ctrlMsg6.Parse(oob)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing ipv6 control message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dst := ctrlMsg6.Dst; dst != nil {
|
||||||
|
// Linux maps IPv4 addresses to IPv6 ones by default, so we can get an
|
||||||
|
// IPv4 dst from an IPv6 control-message.
|
||||||
|
origDst = dst.To4()
|
||||||
|
if origDst == nil {
|
||||||
|
origDst = dst
|
||||||
|
}
|
||||||
|
|
||||||
|
return origDst, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctrlMsg4 := &ipv4.ControlMessage{}
|
||||||
|
err = ctrlMsg4.Parse(oob)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing ipv4 control message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ctrlMsg4.Dst.To4(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// newRespOOB returns an encoded control-message for the response for this IP
|
||||||
|
// address. origDst is expected to have a protocol-appropriate length.
|
||||||
|
func newRespOOB(origDst net.IP) (b []byte) {
|
||||||
|
switch len(origDst) {
|
||||||
|
case net.IPv4len:
|
||||||
|
cm := &ipv4.ControlMessage{
|
||||||
|
Src: origDst,
|
||||||
|
}
|
||||||
|
|
||||||
|
return cm.Marshal()
|
||||||
|
case net.IPv6len:
|
||||||
|
cm := &ipv6.ControlMessage{
|
||||||
|
Src: origDst,
|
||||||
|
}
|
||||||
|
|
||||||
|
return cm.Marshal()
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteToSession implements the [SessionPacketConn] interface for *packetConn.
|
||||||
|
func (c *sessionPacketConn) WriteToSession(b []byte, s PacketSession) (n int, err error) {
|
||||||
|
if ps, ok := s.(*packetSession); ok {
|
||||||
|
n, _, err = c.WriteMsgUDP(b, ps.respOOB, ps.raddr)
|
||||||
|
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.WriteTo(b, s.RemoteAddr())
|
||||||
|
}
|
25
internal/dnsserver/netext/packetconn_linux_internal_test.go
Normal file
25
internal/dnsserver/netext/packetconn_linux_internal_test.go
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package netext
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUDPOOBSize(t *testing.T) {
|
||||||
|
// See https://github.com/miekg/dns/blob/v1.1.50/udp.go.
|
||||||
|
|
||||||
|
len4 := len(ipv4.NewControlMessage(ipv4.FlagDst | ipv4.FlagInterface))
|
||||||
|
len6 := len(ipv6.NewControlMessage(ipv6.FlagDst | ipv6.FlagInterface))
|
||||||
|
|
||||||
|
max := len4
|
||||||
|
if len6 > max {
|
||||||
|
max = len6
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, max, IPDstOOBSize)
|
||||||
|
}
|
140
internal/dnsserver/netext/packetconn_linux_test.go
Normal file
140
internal/dnsserver/netext/packetconn_linux_test.go
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package netext_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"syscall"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO(a.garipov): Add IPv6 test.
|
||||||
|
func TestSessionPacketConn(t *testing.T) {
|
||||||
|
const numTries = 5
|
||||||
|
|
||||||
|
// Try the test multiple times to reduce flakiness due to UDP failures.
|
||||||
|
var success4, success6 bool
|
||||||
|
for i := 0; i < numTries; i++ {
|
||||||
|
var isTimeout4, isTimeout6 bool
|
||||||
|
success4 = t.Run(fmt.Sprintf("ipv4_%d", i), func(t *testing.T) {
|
||||||
|
isTimeout4 = testSessionPacketConn(t, "udp4", "0.0.0.0:0", net.IP{127, 0, 0, 1})
|
||||||
|
})
|
||||||
|
|
||||||
|
success6 = t.Run(fmt.Sprintf("ipv6_%d", i), func(t *testing.T) {
|
||||||
|
isTimeout6 = testSessionPacketConn(t, "udp6", "[::]:0", net.IPv6loopback)
|
||||||
|
})
|
||||||
|
|
||||||
|
if success4 && success6 {
|
||||||
|
break
|
||||||
|
} else if isTimeout4 || isTimeout6 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !success4 {
|
||||||
|
t.Errorf("ipv4 test failed after %d attempts", numTries)
|
||||||
|
} else if !success6 {
|
||||||
|
t.Errorf("ipv6 test failed after %d attempts", numTries)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testSessionPacketConn(t *testing.T, proto, addr string, dstIP net.IP) (isTimeout bool) {
|
||||||
|
lc := netext.DefaultListenConfigWithOOB()
|
||||||
|
require.NotNil(t, lc)
|
||||||
|
|
||||||
|
c, err := lc.ListenPacket(context.Background(), proto, addr)
|
||||||
|
if isTimeoutOrFail(t, err) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NotNil(t, c)
|
||||||
|
|
||||||
|
deadline := time.Now().Add(1 * time.Second)
|
||||||
|
err = c.SetDeadline(deadline)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
laddr := testutil.RequireTypeAssert[*net.UDPAddr](t, c.LocalAddr())
|
||||||
|
require.NotNil(t, laddr)
|
||||||
|
|
||||||
|
dstAddr := &net.UDPAddr{
|
||||||
|
IP: dstIP,
|
||||||
|
Port: laddr.Port,
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteConn, err := net.DialUDP(proto, nil, dstAddr)
|
||||||
|
if proto == "udp6" && errors.Is(err, syscall.EADDRNOTAVAIL) {
|
||||||
|
// Some CI machines have IPv6 disabled.
|
||||||
|
t.Skipf("ipv6 seems to not be supported: %s", err)
|
||||||
|
} else if isTimeoutOrFail(t, err) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
err = remoteConn.SetDeadline(deadline)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msg := []byte("hello")
|
||||||
|
msgLen := len(msg)
|
||||||
|
_, err = remoteConn.Write(msg)
|
||||||
|
if isTimeoutOrFail(t, err) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Implements(t, (*netext.SessionPacketConn)(nil), c)
|
||||||
|
|
||||||
|
buf := make([]byte, msgLen)
|
||||||
|
n, sess, err := netext.ReadFromSession(c, buf)
|
||||||
|
if isTimeoutOrFail(t, err) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, msgLen, n)
|
||||||
|
assert.Equal(t, net.Addr(dstAddr), sess.LocalAddr())
|
||||||
|
assert.Equal(t, remoteConn.LocalAddr(), sess.RemoteAddr())
|
||||||
|
assert.Equal(t, msg, buf)
|
||||||
|
|
||||||
|
respMsg := []byte("world")
|
||||||
|
respMsgLen := len(respMsg)
|
||||||
|
n, err = netext.WriteToSession(c, respMsg, sess)
|
||||||
|
if isTimeoutOrFail(t, err) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, respMsgLen, n)
|
||||||
|
|
||||||
|
buf = make([]byte, respMsgLen)
|
||||||
|
n, err = remoteConn.Read(buf)
|
||||||
|
if isTimeoutOrFail(t, err) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, respMsgLen, n)
|
||||||
|
assert.Equal(t, respMsg, buf)
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isTimeoutOrFail is a helper function that returns true if err is a timeout
|
||||||
|
// error and also calls require.NoError on err.
|
||||||
|
func isTimeoutOrFail(t *testing.T, err error) (ok bool) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
defer require.NoError(t, err)
|
||||||
|
|
||||||
|
return errors.Is(err, os.ErrDeadlineExceeded)
|
||||||
|
}
|
11
internal/dnsserver/netext/packetconn_others.go
Normal file
11
internal/dnsserver/netext/packetconn_others.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package netext
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
// wrapPacketConn wraps c to make it a [SessionPacketConn], if the OS supports
|
||||||
|
// that.
|
||||||
|
func wrapPacketConn(c net.PacketConn) (wrapped net.PacketConn) {
|
||||||
|
return c
|
||||||
|
}
|
@ -27,7 +27,8 @@ func TestRateLimiterMetricsListener_integration_cache(t *testing.T) {
|
|||||||
Duration: time.Minute,
|
Duration: time.Minute,
|
||||||
Count: rps,
|
Count: rps,
|
||||||
ResponseSizeEstimate: 1000,
|
ResponseSizeEstimate: 1000,
|
||||||
RPS: rps,
|
IPv4RPS: rps,
|
||||||
|
IPv6RPS: rps,
|
||||||
RefuseANY: true,
|
RefuseANY: true,
|
||||||
})
|
})
|
||||||
rlMw, err := ratelimit.NewMiddleware(rl, nil)
|
rlMw, err := ratelimit.NewMiddleware(rl, nil)
|
||||||
|
@ -37,15 +37,22 @@ type BackOffConfig struct {
|
|||||||
// as several responses.
|
// as several responses.
|
||||||
ResponseSizeEstimate int
|
ResponseSizeEstimate int
|
||||||
|
|
||||||
// RPS is the maximum number of requests per second allowed from a single
|
// IPv4RPS is the maximum number of requests per second allowed from a
|
||||||
// subnet. Any requests above this rate are counted as the client's
|
// single subnet for IPv4 addresses. Any requests above this rate are
|
||||||
// back-off count. RPS must be greater than zero.
|
// counted as the client's back-off count. RPS must be greater than
|
||||||
RPS int
|
// zero.
|
||||||
|
IPv4RPS int
|
||||||
|
|
||||||
// IPv4SubnetKeyLen is the length of the subnet prefix used to calculate
|
// IPv4SubnetKeyLen is the length of the subnet prefix used to calculate
|
||||||
// rate limiter bucket keys for IPv4 addresses. Must be greater than zero.
|
// rate limiter bucket keys for IPv4 addresses. Must be greater than zero.
|
||||||
IPv4SubnetKeyLen int
|
IPv4SubnetKeyLen int
|
||||||
|
|
||||||
|
// IPv6RPS is the maximum number of requests per second allowed from a
|
||||||
|
// single subnet for IPv6 addresses. Any requests above this rate are
|
||||||
|
// counted as the client's back-off count. RPS must be greater than
|
||||||
|
// zero.
|
||||||
|
IPv6RPS int
|
||||||
|
|
||||||
// IPv6SubnetKeyLen is the length of the subnet prefix used to calculate
|
// IPv6SubnetKeyLen is the length of the subnet prefix used to calculate
|
||||||
// rate limiter bucket keys for IPv6 addresses. Must be greater than zero.
|
// rate limiter bucket keys for IPv6 addresses. Must be greater than zero.
|
||||||
IPv6SubnetKeyLen int
|
IPv6SubnetKeyLen int
|
||||||
@ -62,14 +69,18 @@ type BackOffConfig struct {
|
|||||||
// current implementation might be too abstract. Middlewares by themselves
|
// current implementation might be too abstract. Middlewares by themselves
|
||||||
// already provide an interface that can be re-implemented by the users.
|
// already provide an interface that can be re-implemented by the users.
|
||||||
// Perhaps, another layer of abstraction is unnecessary.
|
// Perhaps, another layer of abstraction is unnecessary.
|
||||||
|
//
|
||||||
|
// TODO(ameshkov): Consider splitting rps and other properties by protocol
|
||||||
|
// family.
|
||||||
type BackOff struct {
|
type BackOff struct {
|
||||||
rpsCounters *cache.Cache
|
rpsCounters *cache.Cache
|
||||||
hitCounters *cache.Cache
|
hitCounters *cache.Cache
|
||||||
allowlist Allowlist
|
allowlist Allowlist
|
||||||
count int
|
count int
|
||||||
rps int
|
|
||||||
respSzEst int
|
respSzEst int
|
||||||
|
ipv4rps int
|
||||||
ipv4SubnetKeyLen int
|
ipv4SubnetKeyLen int
|
||||||
|
ipv6rps int
|
||||||
ipv6SubnetKeyLen int
|
ipv6SubnetKeyLen int
|
||||||
refuseANY bool
|
refuseANY bool
|
||||||
}
|
}
|
||||||
@ -84,9 +95,10 @@ func NewBackOff(c *BackOffConfig) (l *BackOff) {
|
|||||||
hitCounters: cache.New(c.Duration, c.Duration),
|
hitCounters: cache.New(c.Duration, c.Duration),
|
||||||
allowlist: c.Allowlist,
|
allowlist: c.Allowlist,
|
||||||
count: c.Count,
|
count: c.Count,
|
||||||
rps: c.RPS,
|
|
||||||
respSzEst: c.ResponseSizeEstimate,
|
respSzEst: c.ResponseSizeEstimate,
|
||||||
|
ipv4rps: c.IPv4RPS,
|
||||||
ipv4SubnetKeyLen: c.IPv4SubnetKeyLen,
|
ipv4SubnetKeyLen: c.IPv4SubnetKeyLen,
|
||||||
|
ipv6rps: c.IPv6RPS,
|
||||||
ipv6SubnetKeyLen: c.IPv6SubnetKeyLen,
|
ipv6SubnetKeyLen: c.IPv6SubnetKeyLen,
|
||||||
refuseANY: c.RefuseANY,
|
refuseANY: c.RefuseANY,
|
||||||
}
|
}
|
||||||
@ -124,7 +136,12 @@ func (l *BackOff) IsRateLimited(
|
|||||||
return true, false, nil
|
return true, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return l.hasHitRateLimit(key), false, nil
|
rps := l.ipv4rps
|
||||||
|
if ip.Is6() {
|
||||||
|
rps = l.ipv6rps
|
||||||
|
}
|
||||||
|
|
||||||
|
return l.hasHitRateLimit(key, rps), false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateAddr returns an error if addr is not a valid IPv4 or IPv6 address.
|
// validateAddr returns an error if addr is not a valid IPv4 or IPv6 address.
|
||||||
@ -184,14 +201,15 @@ func (l *BackOff) incBackOff(key string) {
|
|||||||
l.hitCounters.SetDefault(key, counter)
|
l.hitCounters.SetDefault(key, counter)
|
||||||
}
|
}
|
||||||
|
|
||||||
// hasHitRateLimit checks value for a subnet.
|
// hasHitRateLimit checks value for a subnet with rps as a maximum number
|
||||||
func (l *BackOff) hasHitRateLimit(subnetIPStr string) (ok bool) {
|
// requests per second.
|
||||||
var r *rps
|
func (l *BackOff) hasHitRateLimit(subnetIPStr string, rps int) (ok bool) {
|
||||||
|
var r *rpsCounter
|
||||||
rVal, ok := l.rpsCounters.Get(subnetIPStr)
|
rVal, ok := l.rpsCounters.Get(subnetIPStr)
|
||||||
if ok {
|
if ok {
|
||||||
r = rVal.(*rps)
|
r = rVal.(*rpsCounter)
|
||||||
} else {
|
} else {
|
||||||
r = newRPS(l.rps)
|
r = newRPSCounter(rps)
|
||||||
l.rpsCounters.SetDefault(subnetIPStr, r)
|
l.rpsCounters.SetDefault(subnetIPStr, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,8 +98,9 @@ func TestRatelimitMiddleware(t *testing.T) {
|
|||||||
Duration: time.Minute,
|
Duration: time.Minute,
|
||||||
Count: rps,
|
Count: rps,
|
||||||
ResponseSizeEstimate: 128,
|
ResponseSizeEstimate: 128,
|
||||||
RPS: rps,
|
IPv4RPS: rps,
|
||||||
IPv4SubnetKeyLen: 24,
|
IPv4SubnetKeyLen: 24,
|
||||||
|
IPv6RPS: rps,
|
||||||
IPv6SubnetKeyLen: 48,
|
IPv6SubnetKeyLen: 48,
|
||||||
RefuseANY: true,
|
RefuseANY: true,
|
||||||
})
|
})
|
||||||
|
@ -7,17 +7,17 @@ import (
|
|||||||
|
|
||||||
// Requests Per Second Counter
|
// Requests Per Second Counter
|
||||||
|
|
||||||
// rps is a single request per seconds counter.
|
// rpsCounter is a single request per seconds counter.
|
||||||
type rps struct {
|
type rpsCounter struct {
|
||||||
// mu protects all fields.
|
// mu protects all fields.
|
||||||
mu *sync.Mutex
|
mu *sync.Mutex
|
||||||
ring []int64
|
ring []int64
|
||||||
idx int
|
idx int
|
||||||
}
|
}
|
||||||
|
|
||||||
// newRPS returns a new requests per second counter. n must be above zero.
|
// newRPSCounter returns a new requests per second counter. n must be above zero.
|
||||||
func newRPS(n int) (r *rps) {
|
func newRPSCounter(n int) (r *rpsCounter) {
|
||||||
return &rps{
|
return &rpsCounter{
|
||||||
mu: &sync.Mutex{},
|
mu: &sync.Mutex{},
|
||||||
// Add one, because we need to always keep track of the previous
|
// Add one, because we need to always keep track of the previous
|
||||||
// request. For example, consider n == 1.
|
// request. For example, consider n == 1.
|
||||||
@ -28,7 +28,7 @@ func newRPS(n int) (r *rps) {
|
|||||||
|
|
||||||
// add adds another request to the counter. above is true if the request goes
|
// add adds another request to the counter. above is true if the request goes
|
||||||
// above the counter value. It is safe for concurrent use.
|
// above the counter value. It is safe for concurrent use.
|
||||||
func (r *rps) add(t time.Time) (above bool) {
|
func (r *rpsCounter) add(t time.Time) (above bool) {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
@ -31,16 +32,19 @@ type ConfigBase struct {
|
|||||||
// Handler is a handler that processes incoming DNS messages.
|
// Handler is a handler that processes incoming DNS messages.
|
||||||
// If not set, we'll use the default handler that returns error response
|
// If not set, we'll use the default handler that returns error response
|
||||||
// to any query.
|
// to any query.
|
||||||
|
|
||||||
Handler Handler
|
Handler Handler
|
||||||
|
|
||||||
// Metrics is the object we use for collecting performance metrics.
|
// Metrics is the object we use for collecting performance metrics.
|
||||||
// This field is optional.
|
// This field is optional.
|
||||||
|
|
||||||
Metrics MetricsListener
|
Metrics MetricsListener
|
||||||
|
|
||||||
// BaseContext is a function that should return the base context. If not
|
// BaseContext is a function that should return the base context. If not
|
||||||
// set, we'll be using context.Background().
|
// set, we'll be using context.Background().
|
||||||
BaseContext func() (ctx context.Context)
|
BaseContext func() (ctx context.Context)
|
||||||
|
|
||||||
|
// ListenConfig, when set, is used to set options of connections used by the
|
||||||
|
// DNS server. If nil, an appropriate default ListenConfig is used.
|
||||||
|
ListenConfig netext.ListenConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerBase implements base methods that every Server implementation uses.
|
// ServerBase implements base methods that every Server implementation uses.
|
||||||
@ -67,14 +71,19 @@ type ServerBase struct {
|
|||||||
// metrics is the object we use for collecting performance metrics.
|
// metrics is the object we use for collecting performance metrics.
|
||||||
metrics MetricsListener
|
metrics MetricsListener
|
||||||
|
|
||||||
|
// listenConfig is used to set tcpListener and udpListener.
|
||||||
|
listenConfig netext.ListenConfig
|
||||||
|
|
||||||
// Server operation
|
// Server operation
|
||||||
// --
|
// --
|
||||||
|
|
||||||
// will be nil for servers that don't use TCP.
|
// tcpListener is used to accept new TCP connections. It is nil for servers
|
||||||
|
// that don't use TCP.
|
||||||
tcpListener net.Listener
|
tcpListener net.Listener
|
||||||
|
|
||||||
// will be nil for servers that don't use UDP.
|
// udpListener is used to accept new UDP messages. It is nil for servers
|
||||||
udpListener *net.UDPConn
|
// that don't use UDP.
|
||||||
|
udpListener net.PacketConn
|
||||||
|
|
||||||
// Shutdown handling
|
// Shutdown handling
|
||||||
// --
|
// --
|
||||||
@ -94,13 +103,14 @@ var _ Server = (*ServerBase)(nil)
|
|||||||
// some of its internal properties.
|
// some of its internal properties.
|
||||||
func newServerBase(proto Protocol, conf ConfigBase) (s *ServerBase) {
|
func newServerBase(proto Protocol, conf ConfigBase) (s *ServerBase) {
|
||||||
s = &ServerBase{
|
s = &ServerBase{
|
||||||
name: conf.Name,
|
name: conf.Name,
|
||||||
addr: conf.Addr,
|
addr: conf.Addr,
|
||||||
proto: proto,
|
proto: proto,
|
||||||
network: conf.Network,
|
network: conf.Network,
|
||||||
handler: conf.Handler,
|
handler: conf.Handler,
|
||||||
metrics: conf.Metrics,
|
metrics: conf.Metrics,
|
||||||
baseContext: conf.BaseContext,
|
listenConfig: conf.ListenConfig,
|
||||||
|
baseContext: conf.BaseContext,
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.baseContext == nil {
|
if s.baseContext == nil {
|
||||||
@ -347,18 +357,17 @@ func (s *ServerBase) handlePanicAndRecover(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// listenUDP creates a UDP listener for the ServerBase.addr. This function will
|
// listenUDP initializes and starts s.udpListener using s.addr. If the TCP
|
||||||
// initialize and start ServerBase.udpListener or return an error. If the TCP
|
// listener is already running, its address is used instead to properly handle
|
||||||
// listener is already running, its address is used instead. The point of this
|
// the case when port 0 is used as both listeners should use the same port, and
|
||||||
// is to properly handle the case when port 0 is used as both listeners should
|
// we only learn it after the first one was started.
|
||||||
// use the same port, and we only learn it after the first one was started.
|
|
||||||
func (s *ServerBase) listenUDP(ctx context.Context) (err error) {
|
func (s *ServerBase) listenUDP(ctx context.Context) (err error) {
|
||||||
addr := s.addr
|
addr := s.addr
|
||||||
if s.tcpListener != nil {
|
if s.tcpListener != nil {
|
||||||
addr = s.tcpListener.Addr().String()
|
addr = s.tcpListener.Addr().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := listenUDP(ctx, addr, true)
|
conn, err := s.listenConfig.ListenPacket(ctx, "udp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -368,19 +377,17 @@ func (s *ServerBase) listenUDP(ctx context.Context) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// listenTCP creates a TCP listener for the ServerBase.addr. This function will
|
// listenTCP initializes and starts s.tcpListener using s.addr. If the UDP
|
||||||
// initialize and start ServerBase.tcpListener or return an error. If the UDP
|
// listener is already running, its address is used instead to properly handle
|
||||||
// listener is already running, its address is used instead. The point of this
|
// the case when port 0 is used as both listeners should use the same port, and
|
||||||
// is to properly handle the case when port 0 is used as both listeners should
|
// we only learn it after the first one was started.
|
||||||
// use the same port, and we only learn it after the first one was started.
|
|
||||||
func (s *ServerBase) listenTCP(ctx context.Context) (err error) {
|
func (s *ServerBase) listenTCP(ctx context.Context) (err error) {
|
||||||
addr := s.addr
|
addr := s.addr
|
||||||
if s.udpListener != nil {
|
if s.udpListener != nil {
|
||||||
addr = s.udpListener.LocalAddr().String()
|
addr = s.udpListener.LocalAddr().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
var l net.Listener
|
l, err := s.listenConfig.Listen(ctx, "tcp", addr)
|
||||||
l, err = listenTCP(ctx, addr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@ -110,6 +111,10 @@ func newServerDNS(proto Protocol, conf ConfigDNS) (s *ServerDNS) {
|
|||||||
conf.TCPSize = dns.MinMsgSize
|
conf.TCPSize = dns.MinMsgSize
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if conf.ListenConfig == nil {
|
||||||
|
conf.ListenConfig = netext.DefaultListenConfigWithOOB()
|
||||||
|
}
|
||||||
|
|
||||||
s = &ServerDNS{
|
s = &ServerDNS{
|
||||||
ServerBase: newServerBase(proto, conf.ConfigBase),
|
ServerBase: newServerBase(proto, conf.ConfigBase),
|
||||||
conf: conf,
|
conf: conf,
|
||||||
@ -262,10 +267,21 @@ func makePacketBuffer(size int) (f func() any) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writeDeadlineSetter is an interface for connections that can set write
|
||||||
|
// deadlines.
|
||||||
|
type writeDeadlineSetter interface {
|
||||||
|
SetWriteDeadline(t time.Time) (err error)
|
||||||
|
}
|
||||||
|
|
||||||
// withWriteDeadline is a helper that takes the deadline of the context and the
|
// withWriteDeadline is a helper that takes the deadline of the context and the
|
||||||
// write timeout into account. It sets the write deadline on conn before
|
// write timeout into account. It sets the write deadline on conn before
|
||||||
// calling f and resets it once f is done.
|
// calling f and resets it once f is done.
|
||||||
func withWriteDeadline(ctx context.Context, writeTimeout time.Duration, conn net.Conn, f func()) {
|
func withWriteDeadline(
|
||||||
|
ctx context.Context,
|
||||||
|
writeTimeout time.Duration,
|
||||||
|
conn writeDeadlineSetter,
|
||||||
|
f func(),
|
||||||
|
) {
|
||||||
dl, hasDeadline := ctx.Deadline()
|
dl, hasDeadline := ctx.Deadline()
|
||||||
if !hasDeadline {
|
if !hasDeadline {
|
||||||
dl = time.Now().Add(writeTimeout)
|
dl = time.Now().Add(writeTimeout)
|
||||||
|
@ -284,8 +284,8 @@ func TestServerDNS_integration_query(t *testing.T) {
|
|||||||
tc.expectedTruncated,
|
tc.expectedTruncated,
|
||||||
)
|
)
|
||||||
|
|
||||||
reqKeepAliveOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_TCP_KEEPALIVE](tc.req)
|
reqKeepAliveOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_TCP_KEEPALIVE](tc.req)
|
||||||
respKeepAliveOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_TCP_KEEPALIVE](resp)
|
respKeepAliveOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_TCP_KEEPALIVE](resp)
|
||||||
if tc.network == dnsserver.NetworkTCP && reqKeepAliveOpt != nil {
|
if tc.network == dnsserver.NetworkTCP && reqKeepAliveOpt != nil {
|
||||||
require.NotNil(t, respKeepAliveOpt)
|
require.NotNil(t, respKeepAliveOpt)
|
||||||
expectedTimeout := uint16(dnsserver.DefaultTCPIdleTimeout.Milliseconds() / 100)
|
expectedTimeout := uint16(dnsserver.DefaultTCPIdleTimeout.Milliseconds() / 100)
|
||||||
|
@ -2,7 +2,9 @@ package dnsserver
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/ameshkov/dnscrypt/v2"
|
"github.com/ameshkov/dnscrypt/v2"
|
||||||
@ -38,6 +40,10 @@ var _ Server = (*ServerDNSCrypt)(nil)
|
|||||||
|
|
||||||
// NewServerDNSCrypt creates a new instance of ServerDNSCrypt.
|
// NewServerDNSCrypt creates a new instance of ServerDNSCrypt.
|
||||||
func NewServerDNSCrypt(conf ConfigDNSCrypt) (s *ServerDNSCrypt) {
|
func NewServerDNSCrypt(conf ConfigDNSCrypt) (s *ServerDNSCrypt) {
|
||||||
|
if conf.ListenConfig == nil {
|
||||||
|
conf.ListenConfig = netext.DefaultListenConfig()
|
||||||
|
}
|
||||||
|
|
||||||
return &ServerDNSCrypt{
|
return &ServerDNSCrypt{
|
||||||
ServerBase: newServerBase(ProtoDNSCrypt, conf.ConfigBase),
|
ServerBase: newServerBase(ProtoDNSCrypt, conf.ConfigBase),
|
||||||
conf: conf,
|
conf: conf,
|
||||||
@ -140,7 +146,10 @@ func (s *ServerDNSCrypt) startServeUDP(ctx context.Context) {
|
|||||||
// TODO(ameshkov): Add context to the ServeTCP and ServeUDP methods in
|
// TODO(ameshkov): Add context to the ServeTCP and ServeUDP methods in
|
||||||
// dnscrypt/v3. Or at least add ServeTCPContext and ServeUDPContext
|
// dnscrypt/v3. Or at least add ServeTCPContext and ServeUDPContext
|
||||||
// methods for now.
|
// methods for now.
|
||||||
err := s.dnsCryptServer.ServeUDP(s.udpListener)
|
//
|
||||||
|
// TODO(ameshkov): Redo the dnscrypt module to make it not depend on
|
||||||
|
// *net.UDPConn and use net.PacketConn instead.
|
||||||
|
err := s.dnsCryptServer.ServeUDP(s.udpListener.(*net.UDPConn))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Info("[%s]: Finished listening to udp://%s due to %v", s.Name(), s.Addr(), err)
|
log.Info("[%s]: Finished listening to udp://%s due to %v", s.Name(), s.Addr(), err)
|
||||||
}
|
}
|
||||||
|
@ -4,23 +4,21 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"runtime"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
"golang.org/x/net/ipv6"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// serveUDP runs the UDP serving loop.
|
// serveUDP runs the UDP serving loop.
|
||||||
func (s *ServerDNS) serveUDP(ctx context.Context, conn *net.UDPConn) (err error) {
|
func (s *ServerDNS) serveUDP(ctx context.Context, conn net.PacketConn) (err error) {
|
||||||
defer log.OnCloserError(conn, log.DEBUG)
|
defer log.OnCloserError(conn, log.DEBUG)
|
||||||
|
|
||||||
for s.isStarted() {
|
for s.isStarted() {
|
||||||
var m []byte
|
var m []byte
|
||||||
var sess *dns.SessionUDP
|
var sess netext.PacketSession
|
||||||
m, sess, err = s.readUDPMsg(ctx, conn)
|
m, sess, err = s.readUDPMsg(ctx, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO(ameshkov): Consider the situation where the server is shut
|
// TODO(ameshkov): Consider the situation where the server is shut
|
||||||
@ -59,15 +57,15 @@ func (s *ServerDNS) serveUDP(ctx context.Context, conn *net.UDPConn) (err error)
|
|||||||
func (s *ServerDNS) serveUDPPacket(
|
func (s *ServerDNS) serveUDPPacket(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
m []byte,
|
m []byte,
|
||||||
conn *net.UDPConn,
|
conn net.PacketConn,
|
||||||
udpSession *dns.SessionUDP,
|
sess netext.PacketSession,
|
||||||
) {
|
) {
|
||||||
defer s.wg.Done()
|
defer s.wg.Done()
|
||||||
defer s.handlePanicAndRecover(ctx)
|
defer s.handlePanicAndRecover(ctx)
|
||||||
|
|
||||||
rw := &udpResponseWriter{
|
rw := &udpResponseWriter{
|
||||||
|
udpSession: sess,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
udpSession: udpSession,
|
|
||||||
writeTimeout: s.conf.WriteTimeout,
|
writeTimeout: s.conf.WriteTimeout,
|
||||||
}
|
}
|
||||||
s.serveDNS(ctx, m, rw)
|
s.serveDNS(ctx, m, rw)
|
||||||
@ -75,15 +73,18 @@ func (s *ServerDNS) serveUDPPacket(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// readUDPMsg reads the next incoming DNS message.
|
// readUDPMsg reads the next incoming DNS message.
|
||||||
func (s *ServerDNS) readUDPMsg(ctx context.Context, conn *net.UDPConn) (msg []byte, sess *dns.SessionUDP, err error) {
|
func (s *ServerDNS) readUDPMsg(
|
||||||
|
ctx context.Context,
|
||||||
|
conn net.PacketConn,
|
||||||
|
) (msg []byte, sess netext.PacketSession, err error) {
|
||||||
err = conn.SetReadDeadline(time.Now().Add(s.conf.ReadTimeout))
|
err = conn.SetReadDeadline(time.Now().Add(s.conf.ReadTimeout))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
m := s.getUDPBuffer()
|
m := s.getUDPBuffer()
|
||||||
var n int
|
|
||||||
n, sess, err = dns.ReadFromSessionUDP(conn, m)
|
n, sess, err := netext.ReadFromSession(conn, m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.putUDPBuffer(m)
|
s.putUDPBuffer(m)
|
||||||
|
|
||||||
@ -120,30 +121,10 @@ func (s *ServerDNS) putUDPBuffer(m []byte) {
|
|||||||
s.udpPool.Put(&m)
|
s.udpPool.Put(&m)
|
||||||
}
|
}
|
||||||
|
|
||||||
// setUDPSocketOptions is a function that is necessary to be able to use
|
|
||||||
// dns.ReadFromSessionUDP and dns.WriteToSessionUDP.
|
|
||||||
// TODO(ameshkov): https://github.com/AdguardTeam/AdGuardHome/issues/2807
|
|
||||||
func setUDPSocketOptions(conn *net.UDPConn) (err error) {
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection.
|
|
||||||
// Try enabling receiving of ECN and packet info for both IP versions.
|
|
||||||
// We expect at least one of those syscalls to succeed.
|
|
||||||
err6 := ipv6.NewPacketConn(conn).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true)
|
|
||||||
err4 := ipv4.NewPacketConn(conn).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true)
|
|
||||||
if err4 != nil && err6 != nil {
|
|
||||||
return errors.List("error while setting NetworkUDP socket options", err4, err6)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// udpResponseWriter is a ResponseWriter implementation for DNS-over-UDP.
|
// udpResponseWriter is a ResponseWriter implementation for DNS-over-UDP.
|
||||||
type udpResponseWriter struct {
|
type udpResponseWriter struct {
|
||||||
udpSession *dns.SessionUDP
|
udpSession netext.PacketSession
|
||||||
conn *net.UDPConn
|
conn net.PacketConn
|
||||||
writeTimeout time.Duration
|
writeTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -152,13 +133,15 @@ var _ ResponseWriter = (*udpResponseWriter)(nil)
|
|||||||
|
|
||||||
// LocalAddr implements the ResponseWriter interface for *udpResponseWriter.
|
// LocalAddr implements the ResponseWriter interface for *udpResponseWriter.
|
||||||
func (r *udpResponseWriter) LocalAddr() (addr net.Addr) {
|
func (r *udpResponseWriter) LocalAddr() (addr net.Addr) {
|
||||||
return r.conn.LocalAddr()
|
// Don't use r.conn.LocalAddr(), since udpSession may actually contain the
|
||||||
|
// decoded OOB data, including the real local (dst) address.
|
||||||
|
return r.udpSession.LocalAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoteAddr implements the ResponseWriter interface for *udpResponseWriter.
|
// RemoteAddr implements the ResponseWriter interface for *udpResponseWriter.
|
||||||
func (r *udpResponseWriter) RemoteAddr() (addr net.Addr) {
|
func (r *udpResponseWriter) RemoteAddr() (addr net.Addr) {
|
||||||
// Don't use r.conn.RemoteAddr(), since udpSession actually contains the
|
// Don't use r.conn.RemoteAddr(), since udpSession may actually contain the
|
||||||
// decoded OOB data, including the remote address.
|
// decoded OOB data, including the real remote (src) address.
|
||||||
return r.udpSession.RemoteAddr()
|
return r.udpSession.RemoteAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -173,7 +156,7 @@ func (r *udpResponseWriter) WriteMsg(ctx context.Context, req, resp *dns.Msg) (e
|
|||||||
}
|
}
|
||||||
|
|
||||||
withWriteDeadline(ctx, r.writeTimeout, r.conn, func() {
|
withWriteDeadline(ctx, r.writeTimeout, r.conn, func() {
|
||||||
_, err = dns.WriteToSessionUDP(r.conn, data, r.udpSession)
|
_, err = netext.WriteToSession(r.conn, data, r.udpSession)
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
@ -94,6 +95,12 @@ var _ Server = (*ServerHTTPS)(nil)
|
|||||||
|
|
||||||
// NewServerHTTPS creates a new ServerHTTPS instance.
|
// NewServerHTTPS creates a new ServerHTTPS instance.
|
||||||
func NewServerHTTPS(conf ConfigHTTPS) (s *ServerHTTPS) {
|
func NewServerHTTPS(conf ConfigHTTPS) (s *ServerHTTPS) {
|
||||||
|
if conf.ListenConfig == nil {
|
||||||
|
// Do not enable OOB here, because ListenPacket is only used by HTTP/3,
|
||||||
|
// and quic-go sets the necessary flags.
|
||||||
|
conf.ListenConfig = netext.DefaultListenConfig()
|
||||||
|
}
|
||||||
|
|
||||||
s = &ServerHTTPS{
|
s = &ServerHTTPS{
|
||||||
ServerBase: newServerBase(ProtoDoH, conf.ConfigBase),
|
ServerBase: newServerBase(ProtoDoH, conf.ConfigBase),
|
||||||
conf: conf,
|
conf: conf,
|
||||||
@ -500,8 +507,7 @@ func (s *ServerHTTPS) listenQUIC(ctx context.Context) (err error) {
|
|||||||
tlsConf.NextProtos = nextProtoDoH3
|
tlsConf.NextProtos = nextProtoDoH3
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do not enable OOB here as quic-go will do that on its own.
|
conn, err := s.listenConfig.ListenPacket(ctx, "udp", s.addr)
|
||||||
conn, err := listenUDP(ctx, s.addr, false)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -518,24 +524,28 @@ func (s *ServerHTTPS) listenQUIC(ctx context.Context) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// httpContextWithClientInfo adds client info to the context.
|
// httpContextWithClientInfo adds client info to the context. ctx is never nil,
|
||||||
|
// even when there is an error.
|
||||||
func httpContextWithClientInfo(
|
func httpContextWithClientInfo(
|
||||||
parent context.Context,
|
parent context.Context,
|
||||||
r *http.Request,
|
r *http.Request,
|
||||||
) (ctx context.Context, err error) {
|
) (ctx context.Context, err error) {
|
||||||
|
ctx = parent
|
||||||
|
|
||||||
ci := ClientInfo{
|
ci := ClientInfo{
|
||||||
URL: netutil.CloneURL(r.URL),
|
URL: netutil.CloneURL(r.URL),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Due to the quic-go bug we should use Host instead of r.TLS:
|
// Due to the quic-go bug we should use Host instead of r.TLS. See
|
||||||
// https://github.com/lucas-clemente/quic-go/issues/3596
|
// https://github.com/quic-go/quic-go/issues/2879 and
|
||||||
|
// https://github.com/lucas-clemente/quic-go/issues/3596.
|
||||||
//
|
//
|
||||||
// TODO(ameshkov): remove this when the bug is fixed in quic-go.
|
// TODO(ameshkov): Remove when quic-go is fixed, likely in v0.32.0.
|
||||||
if r.ProtoAtLeast(3, 0) {
|
if r.ProtoAtLeast(3, 0) {
|
||||||
var host string
|
var host string
|
||||||
host, err = netutil.SplitHost(r.Host)
|
host, err = netutil.SplitHost(r.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse Host: %w", err)
|
return ctx, fmt.Errorf("failed to parse Host: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ci.TLSServerName = host
|
ci.TLSServerName = host
|
||||||
@ -543,7 +553,7 @@ func httpContextWithClientInfo(
|
|||||||
ci.TLSServerName = strings.ToLower(r.TLS.ServerName)
|
ci.TLSServerName = strings.ToLower(r.TLS.ServerName)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ContextWithClientInfo(parent, ci), nil
|
return ContextWithClientInfo(ctx, ci), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// httpRequestToMsg reads the DNS message from http.Request.
|
// httpRequestToMsg reads the DNS message from http.Request.
|
||||||
|
@ -133,7 +133,7 @@ func TestServerHTTPS_integration_serveRequests(t *testing.T) {
|
|||||||
require.True(t, resp.Response)
|
require.True(t, resp.Response)
|
||||||
|
|
||||||
// EDNS0 padding is only present when request also has padding opt.
|
// EDNS0 padding is only present when request also has padding opt.
|
||||||
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
|
paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
|
||||||
require.Nil(t, paddingOpt)
|
require.Nil(t, paddingOpt)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -338,7 +338,7 @@ func TestServerHTTPS_integration_ENDS0Padding(t *testing.T) {
|
|||||||
resp := mustDoHReq(t, addr, tlsConfig, http.MethodGet, false, false, req)
|
resp := mustDoHReq(t, addr, tlsConfig, http.MethodGet, false, false, req)
|
||||||
require.True(t, resp.Response)
|
require.True(t, resp.Response)
|
||||||
|
|
||||||
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
|
paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
|
||||||
require.NotNil(t, paddingOpt)
|
require.NotNil(t, paddingOpt)
|
||||||
require.NotEmpty(t, paddingOpt.Padding)
|
require.NotEmpty(t, paddingOpt.Padding)
|
||||||
}
|
}
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/bluele/gcache"
|
"github.com/bluele/gcache"
|
||||||
@ -98,6 +99,11 @@ func NewServerQUIC(conf ConfigQUIC) (s *ServerQUIC) {
|
|||||||
tlsConfig.NextProtos = append([]string{nextProtoDoQ}, compatProtoDQ...)
|
tlsConfig.NextProtos = append([]string{nextProtoDoQ}, compatProtoDQ...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if conf.ListenConfig == nil {
|
||||||
|
// Do not enable OOB here as quic-go will do that on its own.
|
||||||
|
conf.ListenConfig = netext.DefaultListenConfig()
|
||||||
|
}
|
||||||
|
|
||||||
s = &ServerQUIC{
|
s = &ServerQUIC{
|
||||||
ServerBase: newServerBase(ProtoDoQ, conf.ConfigBase),
|
ServerBase: newServerBase(ProtoDoQ, conf.ConfigBase),
|
||||||
conf: conf,
|
conf: conf,
|
||||||
@ -476,8 +482,7 @@ func (s *ServerQUIC) readQUICMsg(
|
|||||||
// listenQUIC creates the UDP listener for the ServerQUIC.addr and also starts
|
// listenQUIC creates the UDP listener for the ServerQUIC.addr and also starts
|
||||||
// the QUIC listener.
|
// the QUIC listener.
|
||||||
func (s *ServerQUIC) listenQUIC(ctx context.Context) (err error) {
|
func (s *ServerQUIC) listenQUIC(ctx context.Context) (err error) {
|
||||||
// Do not enable OOB here as quic-go will do that on its own.
|
conn, err := s.listenConfig.ListenPacket(ctx, "udp", s.addr)
|
||||||
conn, err := listenUDP(ctx, s.addr, false)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -59,7 +59,7 @@ func TestServerQUIC_integration_query(t *testing.T) {
|
|||||||
assert.True(t, resp.Response)
|
assert.True(t, resp.Response)
|
||||||
|
|
||||||
// EDNS0 padding is only present when request also has padding opt.
|
// EDNS0 padding is only present when request also has padding opt.
|
||||||
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
|
paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
|
||||||
require.Nil(t, paddingOpt)
|
require.Nil(t, paddingOpt)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
@ -97,7 +97,7 @@ func TestServerQUIC_integration_ENDS0Padding(t *testing.T) {
|
|||||||
require.True(t, resp.Response)
|
require.True(t, resp.Response)
|
||||||
require.False(t, resp.Truncated)
|
require.False(t, resp.Truncated)
|
||||||
|
|
||||||
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
|
paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
|
||||||
require.NotNil(t, paddingOpt)
|
require.NotNil(t, paddingOpt)
|
||||||
require.NotEmpty(t, paddingOpt.Padding)
|
require.NotEmpty(t, paddingOpt.Padding)
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,6 @@ package dnsserver
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
@ -102,10 +101,9 @@ func (s *ServerTLS) startServeTCP(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// listenTLS creates the TLS listener for the ServerTLS.addr.
|
// listenTLS creates the TLS listener for s.addr.
|
||||||
func (s *ServerTLS) listenTLS(ctx context.Context) (err error) {
|
func (s *ServerTLS) listenTLS(ctx context.Context) (err error) {
|
||||||
var l net.Listener
|
l, err := s.listenConfig.Listen(ctx, "tcp", s.addr)
|
||||||
l, err = listenTCP(ctx, s.addr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -40,7 +40,7 @@ func TestServerTLS_integration_queryTLS(t *testing.T) {
|
|||||||
require.False(t, resp.Truncated)
|
require.False(t, resp.Truncated)
|
||||||
|
|
||||||
// EDNS0 padding is only present when request also has padding opt.
|
// EDNS0 padding is only present when request also has padding opt.
|
||||||
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
|
paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
|
||||||
require.Nil(t, paddingOpt)
|
require.Nil(t, paddingOpt)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -142,7 +142,7 @@ func TestServerTLS_integration_noTruncateQuery(t *testing.T) {
|
|||||||
require.False(t, resp.Truncated)
|
require.False(t, resp.Truncated)
|
||||||
|
|
||||||
// EDNS0 padding is only present when request also has padding opt.
|
// EDNS0 padding is only present when request also has padding opt.
|
||||||
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
|
paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
|
||||||
require.Nil(t, paddingOpt)
|
require.Nil(t, paddingOpt)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -231,7 +231,7 @@ func TestServerTLS_integration_ENDS0Padding(t *testing.T) {
|
|||||||
require.True(t, resp.Response)
|
require.True(t, resp.Response)
|
||||||
require.False(t, resp.Truncated)
|
require.False(t, resp.Truncated)
|
||||||
|
|
||||||
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
|
paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
|
||||||
require.NotNil(t, paddingOpt)
|
require.NotNil(t, paddingOpt)
|
||||||
require.NotEmpty(t, paddingOpt.Padding)
|
require.NotEmpty(t, paddingOpt.Padding)
|
||||||
}
|
}
|
||||||
|
172
internal/dnssvc/debug_internal_test.go
Normal file
172
internal/dnssvc/debug_internal_test.go
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
package dnssvc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newTXTExtra is a helper function that converts strs into DNS TXT resource
|
||||||
|
// records with Name and Txt fields set to first and second values of each
|
||||||
|
// tuple.
|
||||||
|
func newTXTExtra(strs [][2]string) (extra []dns.RR) {
|
||||||
|
for _, v := range strs {
|
||||||
|
extra = append(extra, &dns.TXT{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: v[0],
|
||||||
|
Rrtype: dns.TypeTXT,
|
||||||
|
Class: dns.ClassCHAOS,
|
||||||
|
Ttl: 1,
|
||||||
|
},
|
||||||
|
Txt: []string{v[1]},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return extra
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestService_writeDebugResponse(t *testing.T) {
|
||||||
|
svc := &Service{messages: &dnsmsg.Constructor{FilteredResponseTTL: time.Second}}
|
||||||
|
|
||||||
|
const (
|
||||||
|
fltListID1 agd.FilterListID = "fl1"
|
||||||
|
fltListID2 agd.FilterListID = "fl2"
|
||||||
|
|
||||||
|
blockRule = "||example.com^"
|
||||||
|
)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
ri *agd.RequestInfo
|
||||||
|
reqRes filter.Result
|
||||||
|
respRes filter.Result
|
||||||
|
wantExtra []dns.RR
|
||||||
|
}{{
|
||||||
|
name: "normal",
|
||||||
|
ri: &agd.RequestInfo{},
|
||||||
|
reqRes: nil,
|
||||||
|
respRes: nil,
|
||||||
|
wantExtra: newTXTExtra([][2]string{
|
||||||
|
{"client-ip.adguard-dns.com.", "1.2.3.4"},
|
||||||
|
{"resp.res-type.adguard-dns.com.", "normal"},
|
||||||
|
}),
|
||||||
|
}, {
|
||||||
|
name: "request_result_blocked",
|
||||||
|
ri: &agd.RequestInfo{},
|
||||||
|
reqRes: &filter.ResultBlocked{List: fltListID1, Rule: blockRule},
|
||||||
|
respRes: nil,
|
||||||
|
wantExtra: newTXTExtra([][2]string{
|
||||||
|
{"client-ip.adguard-dns.com.", "1.2.3.4"},
|
||||||
|
{"req.res-type.adguard-dns.com.", "blocked"},
|
||||||
|
{"req.rule.adguard-dns.com.", "||example.com^"},
|
||||||
|
{"req.rule-list-id.adguard-dns.com.", "fl1"},
|
||||||
|
}),
|
||||||
|
}, {
|
||||||
|
name: "response_result_blocked",
|
||||||
|
ri: &agd.RequestInfo{},
|
||||||
|
reqRes: nil,
|
||||||
|
respRes: &filter.ResultBlocked{List: fltListID2, Rule: blockRule},
|
||||||
|
wantExtra: newTXTExtra([][2]string{
|
||||||
|
{"client-ip.adguard-dns.com.", "1.2.3.4"},
|
||||||
|
{"resp.res-type.adguard-dns.com.", "blocked"},
|
||||||
|
{"resp.rule.adguard-dns.com.", "||example.com^"},
|
||||||
|
{"resp.rule-list-id.adguard-dns.com.", "fl2"},
|
||||||
|
}),
|
||||||
|
}, {
|
||||||
|
name: "request_result_allowed",
|
||||||
|
ri: &agd.RequestInfo{},
|
||||||
|
reqRes: &filter.ResultAllowed{},
|
||||||
|
respRes: nil,
|
||||||
|
wantExtra: newTXTExtra([][2]string{
|
||||||
|
{"client-ip.adguard-dns.com.", "1.2.3.4"},
|
||||||
|
{"req.res-type.adguard-dns.com.", "allowed"},
|
||||||
|
{"req.rule.adguard-dns.com.", ""},
|
||||||
|
{"req.rule-list-id.adguard-dns.com.", ""},
|
||||||
|
}),
|
||||||
|
}, {
|
||||||
|
name: "response_result_allowed",
|
||||||
|
ri: &agd.RequestInfo{},
|
||||||
|
reqRes: nil,
|
||||||
|
respRes: &filter.ResultAllowed{},
|
||||||
|
wantExtra: newTXTExtra([][2]string{
|
||||||
|
{"client-ip.adguard-dns.com.", "1.2.3.4"},
|
||||||
|
{"resp.res-type.adguard-dns.com.", "allowed"},
|
||||||
|
{"resp.rule.adguard-dns.com.", ""},
|
||||||
|
{"resp.rule-list-id.adguard-dns.com.", ""},
|
||||||
|
}),
|
||||||
|
}, {
|
||||||
|
name: "request_result_modified",
|
||||||
|
ri: &agd.RequestInfo{},
|
||||||
|
reqRes: &filter.ResultModified{
|
||||||
|
Rule: "||example.com^$dnsrewrite=REFUSED",
|
||||||
|
},
|
||||||
|
respRes: nil,
|
||||||
|
wantExtra: newTXTExtra([][2]string{
|
||||||
|
{"client-ip.adguard-dns.com.", "1.2.3.4"},
|
||||||
|
{"req.res-type.adguard-dns.com.", "modified"},
|
||||||
|
{"req.rule.adguard-dns.com.", "||example.com^$dnsrewrite=REFUSED"},
|
||||||
|
{"req.rule-list-id.adguard-dns.com.", ""},
|
||||||
|
}),
|
||||||
|
}, {
|
||||||
|
name: "device",
|
||||||
|
ri: &agd.RequestInfo{Device: &agd.Device{ID: "dev1234"}},
|
||||||
|
reqRes: nil,
|
||||||
|
respRes: nil,
|
||||||
|
wantExtra: newTXTExtra([][2]string{
|
||||||
|
{"client-ip.adguard-dns.com.", "1.2.3.4"},
|
||||||
|
{"device-id.adguard-dns.com.", "dev1234"},
|
||||||
|
{"resp.res-type.adguard-dns.com.", "normal"},
|
||||||
|
}),
|
||||||
|
}, {
|
||||||
|
name: "profile",
|
||||||
|
ri: &agd.RequestInfo{
|
||||||
|
Profile: &agd.Profile{ID: agd.ProfileID("some-profile-id")},
|
||||||
|
},
|
||||||
|
reqRes: nil,
|
||||||
|
respRes: nil,
|
||||||
|
wantExtra: newTXTExtra([][2]string{
|
||||||
|
{"client-ip.adguard-dns.com.", "1.2.3.4"},
|
||||||
|
{"profile-id.adguard-dns.com.", "some-profile-id"},
|
||||||
|
{"resp.res-type.adguard-dns.com.", "normal"},
|
||||||
|
}),
|
||||||
|
}, {
|
||||||
|
name: "location",
|
||||||
|
ri: &agd.RequestInfo{Location: &agd.Location{Country: agd.CountryAD}},
|
||||||
|
reqRes: nil,
|
||||||
|
respRes: nil,
|
||||||
|
wantExtra: newTXTExtra([][2]string{
|
||||||
|
{"client-ip.adguard-dns.com.", "1.2.3.4"},
|
||||||
|
{"country.adguard-dns.com.", "AD"},
|
||||||
|
{"asn.adguard-dns.com.", "0"},
|
||||||
|
{"resp.res-type.adguard-dns.com.", "normal"},
|
||||||
|
}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr)
|
||||||
|
|
||||||
|
ctx := agd.ContextWithRequestInfo(context.Background(), tc.ri)
|
||||||
|
|
||||||
|
req := dnsservertest.NewReq("example.com", dns.TypeA, dns.ClassINET)
|
||||||
|
resp := dnsservertest.NewResp(dns.RcodeSuccess, req)
|
||||||
|
|
||||||
|
err := svc.writeDebugResponse(ctx, rw, req, resp, tc.reqRes, tc.respRes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msg := rw.Msg()
|
||||||
|
require.NotNil(t, msg)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.wantExtra, msg.Extra)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -195,7 +195,15 @@ func deviceIDFromEDNS(req *dns.Msg) (id agd.DeviceID, err error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
return agd.NewDeviceID(string(o.Data))
|
id, err = agd.NewDeviceID(string(o.Data))
|
||||||
|
if err != nil {
|
||||||
|
return "", &deviceIDError{
|
||||||
|
err: err,
|
||||||
|
typ: "edns option",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", nil
|
return "", nil
|
||||||
|
@ -233,7 +233,8 @@ func TestService_Wrap_deviceIDFromEDNS(t *testing.T) {
|
|||||||
Data: []byte{},
|
Data: []byte{},
|
||||||
},
|
},
|
||||||
wantDeviceID: "",
|
wantDeviceID: "",
|
||||||
wantErrMsg: `bad device id "": too short: got 0 bytes, min 1`,
|
wantErrMsg: `edns option device id check: bad device id "": ` +
|
||||||
|
`too short: got 0 bytes, min 1`,
|
||||||
}, {
|
}, {
|
||||||
name: "bad_device_id",
|
name: "bad_device_id",
|
||||||
opt: &dns.EDNS0_LOCAL{
|
opt: &dns.EDNS0_LOCAL{
|
||||||
@ -241,7 +242,8 @@ func TestService_Wrap_deviceIDFromEDNS(t *testing.T) {
|
|||||||
Data: []byte("toolongdeviceid"),
|
Data: []byte("toolongdeviceid"),
|
||||||
},
|
},
|
||||||
wantDeviceID: "",
|
wantDeviceID: "",
|
||||||
wantErrMsg: `bad device id "toolongdeviceid": too long: got 15 bytes, max 8`,
|
wantErrMsg: `edns option device id check: bad device id "toolongdeviceid": ` +
|
||||||
|
`too long: got 15 bytes, max 8`,
|
||||||
}, {
|
}, {
|
||||||
name: "device_id",
|
name: "device_id",
|
||||||
opt: &dns.EDNS0_LOCAL{
|
opt: &dns.EDNS0_LOCAL{
|
||||||
|
@ -9,7 +9,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
|
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
|
||||||
@ -79,10 +78,6 @@ type Config struct {
|
|||||||
// Upstream defines the upstream server and the group of fallback servers.
|
// Upstream defines the upstream server and the group of fallback servers.
|
||||||
Upstream *agd.Upstream
|
Upstream *agd.Upstream
|
||||||
|
|
||||||
// RootRedirectURL is the URL to which non-DNS and non-Debug HTTP requests
|
|
||||||
// are redirected.
|
|
||||||
RootRedirectURL *url.URL
|
|
||||||
|
|
||||||
// NewListener, when set, is used instead of the package-level function
|
// NewListener, when set, is used instead of the package-level function
|
||||||
// NewListener when creating a DNS listener.
|
// NewListener when creating a DNS listener.
|
||||||
//
|
//
|
||||||
@ -119,6 +114,11 @@ type Config struct {
|
|||||||
// UseECSCache shows if the EDNS Client Subnet (ECS) aware cache should be
|
// UseECSCache shows if the EDNS Client Subnet (ECS) aware cache should be
|
||||||
// used.
|
// used.
|
||||||
UseECSCache bool
|
UseECSCache bool
|
||||||
|
|
||||||
|
// ResearchMetrics controls whether research metrics are enabled or not.
|
||||||
|
// This is a set of metrics that we may need temporary, so its collection is
|
||||||
|
// controlled by a separate setting.
|
||||||
|
ResearchMetrics bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns a new DNS service.
|
// New returns a new DNS service.
|
||||||
@ -150,7 +150,6 @@ func New(c *Config) (svc *Service, err error) {
|
|||||||
groups := make([]*serverGroup, len(c.ServerGroups))
|
groups := make([]*serverGroup, len(c.ServerGroups))
|
||||||
svc = &Service{
|
svc = &Service{
|
||||||
messages: c.Messages,
|
messages: c.Messages,
|
||||||
rootRedirectURL: c.RootRedirectURL,
|
|
||||||
billStat: c.BillStat,
|
billStat: c.BillStat,
|
||||||
errColl: c.ErrColl,
|
errColl: c.ErrColl,
|
||||||
fltStrg: c.FilterStorage,
|
fltStrg: c.FilterStorage,
|
||||||
@ -158,6 +157,7 @@ func New(c *Config) (svc *Service, err error) {
|
|||||||
queryLog: c.QueryLog,
|
queryLog: c.QueryLog,
|
||||||
ruleStat: c.RuleStat,
|
ruleStat: c.RuleStat,
|
||||||
groups: groups,
|
groups: groups,
|
||||||
|
researchMetrics: c.ResearchMetrics,
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, srvGrp := range c.ServerGroups {
|
for i, srvGrp := range c.ServerGroups {
|
||||||
@ -212,7 +212,6 @@ var _ agd.Service = (*Service)(nil)
|
|||||||
// Service is the main DNS service of AdGuard DNS.
|
// Service is the main DNS service of AdGuard DNS.
|
||||||
type Service struct {
|
type Service struct {
|
||||||
messages *dnsmsg.Constructor
|
messages *dnsmsg.Constructor
|
||||||
rootRedirectURL *url.URL
|
|
||||||
billStat billstat.Recorder
|
billStat billstat.Recorder
|
||||||
errColl agd.ErrorCollector
|
errColl agd.ErrorCollector
|
||||||
fltStrg filter.Storage
|
fltStrg filter.Storage
|
||||||
@ -220,6 +219,7 @@ type Service struct {
|
|||||||
queryLog querylog.Interface
|
queryLog querylog.Interface
|
||||||
ruleStat rulestat.Interface
|
ruleStat rulestat.Interface
|
||||||
groups []*serverGroup
|
groups []*serverGroup
|
||||||
|
researchMetrics bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// mustStartListener starts l and panics on any error.
|
// mustStartListener starts l and panics on any error.
|
||||||
|
@ -260,14 +260,6 @@ func (mh *initMwHandler) ServeDNS(
|
|||||||
// Copy middleware to the local variable to make the code simpler.
|
// Copy middleware to the local variable to make the code simpler.
|
||||||
mw := mh.mw
|
mw := mh.mw
|
||||||
|
|
||||||
if specHdlr, name := mw.noReqInfoSpecialHandler(fqdn, qt, cl); specHdlr != nil {
|
|
||||||
optlog.Debug1("init mw: got no-req-info special handler %s", name)
|
|
||||||
|
|
||||||
// Don't wrap the error, because it's informative enough as is, and
|
|
||||||
// because if handled is true, the main flow terminates here.
|
|
||||||
return specHdlr(ctx, rw, req)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the request's information, such as GeoIP data and user profiles.
|
// Get the request's information, such as GeoIP data and user profiles.
|
||||||
ri, err := mw.newRequestInfo(ctx, req, rw.RemoteAddr(), fqdn, qt, cl)
|
ri, err := mw.newRequestInfo(ctx, req, rw.RemoteAddr(), fqdn, qt, cl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -277,7 +277,7 @@ func TestInitMw_ServeDNS_ddr(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
|
func TestInitMw_ServeDNS_specialDomain(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
host string
|
host string
|
||||||
@ -287,7 +287,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
|
|||||||
profBlocked bool
|
profBlocked bool
|
||||||
wantRCode dnsmsg.RCode
|
wantRCode dnsmsg.RCode
|
||||||
}{{
|
}{{
|
||||||
name: "blocked_by_fltgrp",
|
name: "private_relay_blocked_by_fltgrp",
|
||||||
host: applePrivateRelayMaskHost,
|
host: applePrivateRelayMaskHost,
|
||||||
qtype: dns.TypeA,
|
qtype: dns.TypeA,
|
||||||
fltGrpBlocked: true,
|
fltGrpBlocked: true,
|
||||||
@ -295,7 +295,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
|
|||||||
profBlocked: false,
|
profBlocked: false,
|
||||||
wantRCode: dns.RcodeNameError,
|
wantRCode: dns.RcodeNameError,
|
||||||
}, {
|
}, {
|
||||||
name: "no_private_relay_domain",
|
name: "no_special_domain",
|
||||||
host: "www.example.com",
|
host: "www.example.com",
|
||||||
qtype: dns.TypeA,
|
qtype: dns.TypeA,
|
||||||
fltGrpBlocked: true,
|
fltGrpBlocked: true,
|
||||||
@ -311,7 +311,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
|
|||||||
profBlocked: false,
|
profBlocked: false,
|
||||||
wantRCode: dns.RcodeSuccess,
|
wantRCode: dns.RcodeSuccess,
|
||||||
}, {
|
}, {
|
||||||
name: "blocked_by_prof",
|
name: "private_relay_blocked_by_prof",
|
||||||
host: applePrivateRelayMaskHost,
|
host: applePrivateRelayMaskHost,
|
||||||
qtype: dns.TypeA,
|
qtype: dns.TypeA,
|
||||||
fltGrpBlocked: false,
|
fltGrpBlocked: false,
|
||||||
@ -319,7 +319,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
|
|||||||
profBlocked: true,
|
profBlocked: true,
|
||||||
wantRCode: dns.RcodeNameError,
|
wantRCode: dns.RcodeNameError,
|
||||||
}, {
|
}, {
|
||||||
name: "allowed_by_prof",
|
name: "private_relay_allowed_by_prof",
|
||||||
host: applePrivateRelayMaskHost,
|
host: applePrivateRelayMaskHost,
|
||||||
qtype: dns.TypeA,
|
qtype: dns.TypeA,
|
||||||
fltGrpBlocked: true,
|
fltGrpBlocked: true,
|
||||||
@ -327,7 +327,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
|
|||||||
profBlocked: false,
|
profBlocked: false,
|
||||||
wantRCode: dns.RcodeSuccess,
|
wantRCode: dns.RcodeSuccess,
|
||||||
}, {
|
}, {
|
||||||
name: "allowed_by_both",
|
name: "private_relay_allowed_by_both",
|
||||||
host: applePrivateRelayMaskHost,
|
host: applePrivateRelayMaskHost,
|
||||||
qtype: dns.TypeA,
|
qtype: dns.TypeA,
|
||||||
fltGrpBlocked: false,
|
fltGrpBlocked: false,
|
||||||
@ -335,13 +335,45 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
|
|||||||
profBlocked: false,
|
profBlocked: false,
|
||||||
wantRCode: dns.RcodeSuccess,
|
wantRCode: dns.RcodeSuccess,
|
||||||
}, {
|
}, {
|
||||||
name: "blocked_by_both",
|
name: "private_relay_blocked_by_both",
|
||||||
host: applePrivateRelayMaskHost,
|
host: applePrivateRelayMaskHost,
|
||||||
qtype: dns.TypeA,
|
qtype: dns.TypeA,
|
||||||
fltGrpBlocked: true,
|
fltGrpBlocked: true,
|
||||||
hasProf: true,
|
hasProf: true,
|
||||||
profBlocked: true,
|
profBlocked: true,
|
||||||
wantRCode: dns.RcodeNameError,
|
wantRCode: dns.RcodeNameError,
|
||||||
|
}, {
|
||||||
|
name: "firefox_canary_allowed_by_prof",
|
||||||
|
host: firefoxCanaryHost,
|
||||||
|
qtype: dns.TypeA,
|
||||||
|
fltGrpBlocked: false,
|
||||||
|
hasProf: true,
|
||||||
|
profBlocked: false,
|
||||||
|
wantRCode: dns.RcodeSuccess,
|
||||||
|
}, {
|
||||||
|
name: "firefox_canary_allowed_by_fltgrp",
|
||||||
|
host: firefoxCanaryHost,
|
||||||
|
qtype: dns.TypeA,
|
||||||
|
fltGrpBlocked: false,
|
||||||
|
hasProf: false,
|
||||||
|
profBlocked: false,
|
||||||
|
wantRCode: dns.RcodeSuccess,
|
||||||
|
}, {
|
||||||
|
name: "firefox_canary_blocked_by_prof",
|
||||||
|
host: firefoxCanaryHost,
|
||||||
|
qtype: dns.TypeA,
|
||||||
|
fltGrpBlocked: false,
|
||||||
|
hasProf: true,
|
||||||
|
profBlocked: true,
|
||||||
|
wantRCode: dns.RcodeRefused,
|
||||||
|
}, {
|
||||||
|
name: "firefox_canary_blocked_by_fltgrp",
|
||||||
|
host: firefoxCanaryHost,
|
||||||
|
qtype: dns.TypeA,
|
||||||
|
fltGrpBlocked: true,
|
||||||
|
hasProf: false,
|
||||||
|
profBlocked: false,
|
||||||
|
wantRCode: dns.RcodeRefused,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
@ -368,9 +400,12 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
|
|||||||
return nil, nil, agd.DeviceNotFoundError{}
|
return nil, nil, agd.DeviceNotFoundError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &agd.Profile{
|
prof := &agd.Profile{
|
||||||
BlockPrivateRelay: tc.profBlocked,
|
BlockPrivateRelay: tc.profBlocked,
|
||||||
}, &agd.Device{}, nil
|
BlockFirefoxCanary: tc.profBlocked,
|
||||||
|
}
|
||||||
|
|
||||||
|
return prof, &agd.Device{}, nil
|
||||||
}
|
}
|
||||||
db := &agdtest.ProfileDB{
|
db := &agdtest.ProfileDB{
|
||||||
OnProfileByDeviceID: func(
|
OnProfileByDeviceID: func(
|
||||||
@ -406,7 +441,8 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
|
|||||||
FilteredResponseTTL: 10 * time.Second,
|
FilteredResponseTTL: 10 * time.Second,
|
||||||
},
|
},
|
||||||
fltGrp: &agd.FilteringGroup{
|
fltGrp: &agd.FilteringGroup{
|
||||||
BlockPrivateRelay: tc.fltGrpBlocked,
|
BlockPrivateRelay: tc.fltGrpBlocked,
|
||||||
|
BlockFirefoxCanary: tc.fltGrpBlocked,
|
||||||
},
|
},
|
||||||
srvGrp: &agd.ServerGroup{},
|
srvGrp: &agd.ServerGroup{},
|
||||||
srv: &agd.Server{
|
srv: &agd.Server{
|
||||||
@ -436,7 +472,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
|
|||||||
resp := rw.Msg()
|
resp := rw.Msg()
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
|
|
||||||
assert.Equal(t, dnsmsg.RCode(resp.Rcode), tc.wantRCode)
|
assert.Equal(t, tc.wantRCode, dnsmsg.RCode(resp.Rcode))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -101,7 +101,7 @@ func (svc *Service) filterQuery(
|
|||||||
) (reqRes, respRes filter.Result) {
|
) (reqRes, respRes filter.Result) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
reportMetrics(ri, reqRes, respRes, time.Since(start))
|
svc.reportMetrics(ri, reqRes, respRes, time.Since(start))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
f := svc.fltStrg.FilterFromContext(ctx, ri)
|
f := svc.fltStrg.FilterFromContext(ctx, ri)
|
||||||
@ -120,7 +120,7 @@ func (svc *Service) filterQuery(
|
|||||||
|
|
||||||
// reportMetrics extracts filtering metrics data from the context and reports it
|
// reportMetrics extracts filtering metrics data from the context and reports it
|
||||||
// to Prometheus.
|
// to Prometheus.
|
||||||
func reportMetrics(
|
func (svc *Service) reportMetrics(
|
||||||
ri *agd.RequestInfo,
|
ri *agd.RequestInfo,
|
||||||
reqRes filter.Result,
|
reqRes filter.Result,
|
||||||
respRes filter.Result,
|
respRes filter.Result,
|
||||||
@ -139,7 +139,7 @@ func reportMetrics(
|
|||||||
metrics.DNSSvcRequestByCountryTotal.WithLabelValues(cont, ctry).Inc()
|
metrics.DNSSvcRequestByCountryTotal.WithLabelValues(cont, ctry).Inc()
|
||||||
metrics.DNSSvcRequestByASNTotal.WithLabelValues(ctry, asn).Inc()
|
metrics.DNSSvcRequestByASNTotal.WithLabelValues(ctry, asn).Inc()
|
||||||
|
|
||||||
id, _, _ := filteringData(reqRes, respRes)
|
id, _, blocked := filteringData(reqRes, respRes)
|
||||||
metrics.DNSSvcRequestByFilterTotal.WithLabelValues(
|
metrics.DNSSvcRequestByFilterTotal.WithLabelValues(
|
||||||
string(id),
|
string(id),
|
||||||
metrics.BoolString(ri.Profile == nil),
|
metrics.BoolString(ri.Profile == nil),
|
||||||
@ -147,6 +147,22 @@ func reportMetrics(
|
|||||||
|
|
||||||
metrics.DNSSvcFilteringDuration.Observe(elapsedFiltering.Seconds())
|
metrics.DNSSvcFilteringDuration.Observe(elapsedFiltering.Seconds())
|
||||||
metrics.DNSSvcUsersCountUpdate(ri.RemoteIP)
|
metrics.DNSSvcUsersCountUpdate(ri.RemoteIP)
|
||||||
|
|
||||||
|
if svc.researchMetrics {
|
||||||
|
anonymous := ri.Profile == nil
|
||||||
|
filteringEnabled := ri.FilteringGroup != nil &&
|
||||||
|
ri.FilteringGroup.RuleListsEnabled &&
|
||||||
|
len(ri.FilteringGroup.RuleListIDs) > 0
|
||||||
|
|
||||||
|
metrics.ReportResearchMetrics(
|
||||||
|
anonymous,
|
||||||
|
filteringEnabled,
|
||||||
|
asn,
|
||||||
|
ctry,
|
||||||
|
string(id),
|
||||||
|
blocked,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// reportf is a helper method for reporting non-critical errors.
|
// reportf is a helper method for reporting non-critical errors.
|
||||||
|
135
internal/dnssvc/presvcmw_test.go
Normal file
135
internal/dnssvc/presvcmw_test.go
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
package dnssvc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPreServiceMwHandler_ServeDNS(t *testing.T) {
|
||||||
|
const safeBrowsingHost = "scam.example.net."
|
||||||
|
|
||||||
|
var (
|
||||||
|
ip = net.IP{127, 0, 0, 1}
|
||||||
|
name = "example.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
sum := sha256.Sum256([]byte(safeBrowsingHost))
|
||||||
|
hashStr := hex.EncodeToString(sum[:])
|
||||||
|
hashes, herr := hashstorage.New(safeBrowsingHost)
|
||||||
|
require.NoError(t, herr)
|
||||||
|
|
||||||
|
srv := filter.NewSafeBrowsingServer(hashes, nil)
|
||||||
|
host := hashStr[:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{})
|
||||||
|
ctx = dnsserver.ContextWithServerInfo(ctx, dnsserver.ServerInfo{})
|
||||||
|
ctx = dnsserver.ContextWithStartTime(ctx, time.Now())
|
||||||
|
|
||||||
|
const ttl = 60
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
req *dns.Msg
|
||||||
|
dnscheckResp *dns.Msg
|
||||||
|
ri *agd.RequestInfo
|
||||||
|
wantAns []dns.RR
|
||||||
|
}{{
|
||||||
|
name: "normal",
|
||||||
|
req: dnsservertest.CreateMessage(name, dns.TypeA),
|
||||||
|
dnscheckResp: nil,
|
||||||
|
ri: &agd.RequestInfo{},
|
||||||
|
wantAns: []dns.RR{
|
||||||
|
dnsservertest.NewA(name, 100, ip),
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "dnscheck",
|
||||||
|
req: dnsservertest.CreateMessage(name, dns.TypeA),
|
||||||
|
dnscheckResp: dnsservertest.NewResp(
|
||||||
|
dns.RcodeSuccess,
|
||||||
|
dnsservertest.NewReq(name, dns.TypeA, dns.ClassINET),
|
||||||
|
dnsservertest.RRSection{
|
||||||
|
RRs: []dns.RR{dnsservertest.NewA(name, ttl, ip)},
|
||||||
|
Sec: dnsservertest.SectionAnswer,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
ri: &agd.RequestInfo{
|
||||||
|
Host: name,
|
||||||
|
QType: dns.TypeA,
|
||||||
|
QClass: dns.ClassINET,
|
||||||
|
},
|
||||||
|
wantAns: []dns.RR{
|
||||||
|
dnsservertest.NewA(name, ttl, ip),
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "with_hashes",
|
||||||
|
req: dnsservertest.CreateMessage(safeBrowsingHost, dns.TypeTXT),
|
||||||
|
dnscheckResp: nil,
|
||||||
|
ri: &agd.RequestInfo{Host: host, QType: dns.TypeTXT},
|
||||||
|
wantAns: []dns.RR{&dns.TXT{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: safeBrowsingHost,
|
||||||
|
Rrtype: dns.TypeTXT,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: ttl,
|
||||||
|
},
|
||||||
|
Txt: []string{hashStr},
|
||||||
|
}},
|
||||||
|
}, {
|
||||||
|
name: "not_matched",
|
||||||
|
req: dnsservertest.CreateMessage(name, dns.TypeTXT),
|
||||||
|
dnscheckResp: nil,
|
||||||
|
ri: &agd.RequestInfo{Host: name, QType: dns.TypeTXT},
|
||||||
|
wantAns: []dns.RR{dnsservertest.NewA(name, 100, ip)},
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr)
|
||||||
|
tctx := agd.ContextWithRequestInfo(ctx, tc.ri)
|
||||||
|
|
||||||
|
dnsCk := &agdtest.DNSCheck{
|
||||||
|
OnCheck: func(
|
||||||
|
ctx context.Context,
|
||||||
|
msg *dns.Msg,
|
||||||
|
ri *agd.RequestInfo,
|
||||||
|
) (resp *dns.Msg, err error) {
|
||||||
|
return tc.dnscheckResp, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mw := &preServiceMw{
|
||||||
|
messages: &dnsmsg.Constructor{
|
||||||
|
FilteredResponseTTL: ttl * time.Second,
|
||||||
|
},
|
||||||
|
filter: srv,
|
||||||
|
checker: dnsCk,
|
||||||
|
}
|
||||||
|
handler := dnsservertest.DefaultHandler()
|
||||||
|
h := mw.Wrap(handler)
|
||||||
|
|
||||||
|
err := h.ServeDNS(tctx, rw, tc.req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msg := rw.Msg()
|
||||||
|
require.NotNil(t, msg)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.wantAns, msg.Answer)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
264
internal/dnssvc/preupstreammw_test.go
Normal file
264
internal/dnssvc/preupstreammw_test.go
Normal file
@ -0,0 +1,264 @@
|
|||||||
|
package dnssvc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
|
||||||
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
reqHostname = "example.com."
|
||||||
|
defaultTTL = 3600
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPreUpstreamMwHandler_ServeDNS_withCache(t *testing.T) {
|
||||||
|
remoteIP := netip.MustParseAddr("1.2.3.4")
|
||||||
|
aReq := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
|
||||||
|
respIP := remoteIP.AsSlice()
|
||||||
|
|
||||||
|
resp := dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.RRSection{
|
||||||
|
RRs: []dns.RR{dnsservertest.NewA(reqHostname, defaultTTL, respIP)},
|
||||||
|
Sec: dnsservertest.SectionAnswer,
|
||||||
|
})
|
||||||
|
ctx := agd.ContextWithRequestInfo(context.Background(), &agd.RequestInfo{
|
||||||
|
Host: aReq.Question[0].Name,
|
||||||
|
})
|
||||||
|
|
||||||
|
const N = 5
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
cacheSize int
|
||||||
|
wantNumReq int
|
||||||
|
}{{
|
||||||
|
name: "no_cache",
|
||||||
|
cacheSize: 0,
|
||||||
|
wantNumReq: N,
|
||||||
|
}, {
|
||||||
|
name: "with_cache",
|
||||||
|
cacheSize: 100,
|
||||||
|
wantNumReq: 1,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
numReq := 0
|
||||||
|
handler := dnsserver.HandlerFunc(func(
|
||||||
|
ctx context.Context,
|
||||||
|
rw dnsserver.ResponseWriter,
|
||||||
|
req *dns.Msg,
|
||||||
|
) error {
|
||||||
|
numReq++
|
||||||
|
|
||||||
|
return rw.WriteMsg(ctx, req, resp)
|
||||||
|
})
|
||||||
|
|
||||||
|
mw := &preUpstreamMw{
|
||||||
|
db: dnsdb.Empty{},
|
||||||
|
cacheSize: tc.cacheSize,
|
||||||
|
}
|
||||||
|
h := mw.Wrap(handler)
|
||||||
|
|
||||||
|
for i := 0; i < N; i++ {
|
||||||
|
req := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
|
||||||
|
addr := &net.UDPAddr{IP: remoteIP.AsSlice(), Port: 53}
|
||||||
|
nrw := dnsserver.NewNonWriterResponseWriter(addr, addr)
|
||||||
|
|
||||||
|
err := h.ServeDNS(ctx, nrw, req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, tc.wantNumReq, numReq)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPreUpstreamMwHandler_ServeDNS_withECSCache(t *testing.T) {
|
||||||
|
aReq := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
|
||||||
|
remoteIP := netip.MustParseAddr("1.2.3.4")
|
||||||
|
subnet := netip.MustParsePrefix("1.2.3.4/24")
|
||||||
|
|
||||||
|
const ctry = agd.CountryAD
|
||||||
|
|
||||||
|
resp := dnsservertest.NewResp(
|
||||||
|
dns.RcodeSuccess,
|
||||||
|
aReq,
|
||||||
|
dnsservertest.RRSection{
|
||||||
|
RRs: []dns.RR{dnsservertest.NewA(
|
||||||
|
reqHostname,
|
||||||
|
defaultTTL,
|
||||||
|
net.IP{1, 2, 3, 4},
|
||||||
|
)},
|
||||||
|
Sec: dnsservertest.SectionAnswer,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
numReq := 0
|
||||||
|
handler := dnsserver.HandlerFunc(
|
||||||
|
func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) error {
|
||||||
|
numReq++
|
||||||
|
|
||||||
|
return rw.WriteMsg(ctx, req, resp)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
geoIP := &agdtest.GeoIP{
|
||||||
|
OnSubnetByLocation: func(
|
||||||
|
ctry agd.Country,
|
||||||
|
_ agd.ASN,
|
||||||
|
_ netutil.AddrFamily,
|
||||||
|
) (n netip.Prefix, err error) {
|
||||||
|
return netip.MustParsePrefix("1.2.0.0/16"), nil
|
||||||
|
},
|
||||||
|
OnData: func(_ string, _ netip.Addr) (_ *agd.Location, _ error) {
|
||||||
|
panic("not implemented")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mw := &preUpstreamMw{
|
||||||
|
db: dnsdb.Empty{},
|
||||||
|
geoIP: geoIP,
|
||||||
|
cacheSize: 100,
|
||||||
|
ecsCacheSize: 100,
|
||||||
|
useECSCache: true,
|
||||||
|
}
|
||||||
|
h := mw.Wrap(handler)
|
||||||
|
|
||||||
|
ctx := agd.ContextWithRequestInfo(context.Background(), &agd.RequestInfo{
|
||||||
|
Location: &agd.Location{
|
||||||
|
Country: ctry,
|
||||||
|
},
|
||||||
|
ECS: &agd.ECS{
|
||||||
|
Location: &agd.Location{
|
||||||
|
Country: ctry,
|
||||||
|
},
|
||||||
|
Subnet: subnet,
|
||||||
|
Scope: 0,
|
||||||
|
},
|
||||||
|
Host: aReq.Question[0].Name,
|
||||||
|
RemoteIP: remoteIP,
|
||||||
|
})
|
||||||
|
|
||||||
|
const N = 5
|
||||||
|
var nrw *dnsserver.NonWriterResponseWriter
|
||||||
|
for i := 0; i < N; i++ {
|
||||||
|
addr := &net.UDPAddr{IP: remoteIP.AsSlice(), Port: 53}
|
||||||
|
nrw = dnsserver.NewNonWriterResponseWriter(addr, addr)
|
||||||
|
req := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
|
||||||
|
|
||||||
|
err := h.ServeDNS(ctx, nrw, req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 1, numReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPreUpstreamMwHandler_ServeDNS_androidMetric(t *testing.T) {
|
||||||
|
mw := &preUpstreamMw{db: dnsdb.Empty{}}
|
||||||
|
|
||||||
|
req := dnsservertest.CreateMessage("example.com", dns.TypeA)
|
||||||
|
resp := new(dns.Msg).SetReply(req)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
ctx = dnsserver.ContextWithServerInfo(ctx, dnsserver.ServerInfo{})
|
||||||
|
ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{})
|
||||||
|
ctx = dnsserver.ContextWithStartTime(ctx, time.Now())
|
||||||
|
ctx = agd.ContextWithRequestInfo(ctx, &agd.RequestInfo{})
|
||||||
|
|
||||||
|
const ttl = 100
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
req *dns.Msg
|
||||||
|
resp *dns.Msg
|
||||||
|
wantName string
|
||||||
|
wantAns []dns.RR
|
||||||
|
}{{
|
||||||
|
name: "no_changes",
|
||||||
|
req: dnsservertest.CreateMessage("example.com.", dns.TypeA),
|
||||||
|
resp: resp,
|
||||||
|
wantName: "example.com.",
|
||||||
|
wantAns: nil,
|
||||||
|
}, {
|
||||||
|
name: "android-tls-metric",
|
||||||
|
req: dnsservertest.CreateMessage(
|
||||||
|
"12345678-dnsotls-ds.metric.gstatic.com.",
|
||||||
|
dns.TypeA,
|
||||||
|
),
|
||||||
|
resp: resp,
|
||||||
|
wantName: "00000000-dnsotls-ds.metric.gstatic.com.",
|
||||||
|
wantAns: nil,
|
||||||
|
}, {
|
||||||
|
name: "android-https-metric",
|
||||||
|
req: dnsservertest.CreateMessage(
|
||||||
|
"123456-dnsohttps-ds.metric.gstatic.com.",
|
||||||
|
dns.TypeA,
|
||||||
|
),
|
||||||
|
resp: resp,
|
||||||
|
wantName: "000000-dnsohttps-ds.metric.gstatic.com.",
|
||||||
|
wantAns: nil,
|
||||||
|
}, {
|
||||||
|
name: "multiple_answers_metric",
|
||||||
|
req: dnsservertest.CreateMessage(
|
||||||
|
"123456-dnsohttps-ds.metric.gstatic.com.",
|
||||||
|
dns.TypeA,
|
||||||
|
),
|
||||||
|
resp: dnsservertest.NewResp(
|
||||||
|
dns.RcodeSuccess,
|
||||||
|
req,
|
||||||
|
dnsservertest.RRSection{
|
||||||
|
RRs: []dns.RR{dnsservertest.NewA(
|
||||||
|
"123456-dnsohttps-ds.metric.gstatic.com.",
|
||||||
|
ttl,
|
||||||
|
net.IP{1, 2, 3, 4},
|
||||||
|
), dnsservertest.NewA(
|
||||||
|
"654321-dnsohttps-ds.metric.gstatic.com.",
|
||||||
|
ttl,
|
||||||
|
net.IP{1, 2, 3, 5},
|
||||||
|
)},
|
||||||
|
Sec: dnsservertest.SectionAnswer,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
wantName: "000000-dnsohttps-ds.metric.gstatic.com.",
|
||||||
|
wantAns: []dns.RR{
|
||||||
|
dnsservertest.NewA("123456-dnsohttps-ds.metric.gstatic.com.", ttl, net.IP{1, 2, 3, 4}),
|
||||||
|
dnsservertest.NewA("123456-dnsohttps-ds.metric.gstatic.com.", ttl, net.IP{1, 2, 3, 5}),
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
handler := dnsserver.HandlerFunc(func(
|
||||||
|
ctx context.Context,
|
||||||
|
rw dnsserver.ResponseWriter,
|
||||||
|
req *dns.Msg,
|
||||||
|
) error {
|
||||||
|
assert.Equal(t, tc.wantName, req.Question[0].Name)
|
||||||
|
|
||||||
|
return rw.WriteMsg(ctx, req, tc.resp)
|
||||||
|
})
|
||||||
|
h := mw.Wrap(handler)
|
||||||
|
|
||||||
|
rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr)
|
||||||
|
|
||||||
|
err := h.ServeDNS(ctx, rw, tc.req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msg := rw.Msg()
|
||||||
|
require.NotNil(t, msg)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.wantAns, msg.Answer)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -32,57 +32,23 @@ const (
|
|||||||
// Resolvers for querying the resolver with unknown or absent name.
|
// Resolvers for querying the resolver with unknown or absent name.
|
||||||
ddrDomain = ddrLabel + "." + resolverArpaDomain
|
ddrDomain = ddrLabel + "." + resolverArpaDomain
|
||||||
|
|
||||||
// firefoxCanaryFQDN is the fully-qualified canary domain that Firefox uses
|
// firefoxCanaryHost is the hostname that Firefox uses to check if it
|
||||||
// to check if it should use its own DNS-over-HTTPS settings.
|
// should use its own DNS-over-HTTPS settings.
|
||||||
//
|
//
|
||||||
// See https://support.mozilla.org/en-US/kb/configuring-networks-disable-dns-over-https.
|
// See https://support.mozilla.org/en-US/kb/configuring-networks-disable-dns-over-https.
|
||||||
firefoxCanaryFQDN = "use-application-dns.net."
|
firefoxCanaryHost = "use-application-dns.net"
|
||||||
|
|
||||||
// applePrivateRelayMaskHost and applePrivateRelayMaskH2Host are the
|
|
||||||
// hostnames that Apple devices use to check if Apple Private Relay can be
|
|
||||||
// enabled. Returning NXDOMAIN to queries for these domain names blocks
|
|
||||||
// Apple Private Relay.
|
|
||||||
//
|
|
||||||
// See https://developer.apple.com/support/prepare-your-network-for-icloud-private-relay.
|
|
||||||
applePrivateRelayMaskHost = "mask.icloud.com"
|
|
||||||
applePrivateRelayMaskH2Host = "mask-h2.icloud.com"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// noReqInfoSpecialHandler returns a handler that can handle a special-domain
|
// Hostnames that Apple devices use to check if Apple Private Relay can be
|
||||||
// query based only on its question type, class, and target, as well as the
|
// enabled. Returning NXDOMAIN to queries for these domain names blocks Apple
|
||||||
// handler's name for debugging.
|
// Private Relay.
|
||||||
func (mw *initMw) noReqInfoSpecialHandler(
|
//
|
||||||
fqdn string,
|
// See https://developer.apple.com/support/prepare-your-network-for-icloud-private-relay.
|
||||||
qt dnsmsg.RRType,
|
const (
|
||||||
cl dnsmsg.Class,
|
applePrivateRelayMaskHost = "mask.icloud.com"
|
||||||
) (f dnsserver.HandlerFunc, name string) {
|
applePrivateRelayMaskH2Host = "mask-h2.icloud.com"
|
||||||
if cl != dns.ClassINET {
|
applePrivateRelayMaskCanaryHost = "mask-canary.icloud.com"
|
||||||
return nil, ""
|
)
|
||||||
}
|
|
||||||
|
|
||||||
if (qt == dns.TypeA || qt == dns.TypeAAAA) && fqdn == firefoxCanaryFQDN {
|
|
||||||
return mw.handleFirefoxCanary, "firefox"
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Firefox Canary
|
|
||||||
|
|
||||||
// handleFirefoxCanary checks if the request is for the fully-qualified domain
|
|
||||||
// name that Firefox uses to check DoH settings and writes a response if needed.
|
|
||||||
func (mw *initMw) handleFirefoxCanary(
|
|
||||||
ctx context.Context,
|
|
||||||
rw dnsserver.ResponseWriter,
|
|
||||||
req *dns.Msg,
|
|
||||||
) (err error) {
|
|
||||||
metrics.DNSSvcFirefoxRequestsTotal.Inc()
|
|
||||||
|
|
||||||
resp := mw.messages.NewMsgREFUSED(req)
|
|
||||||
err = rw.WriteMsg(ctx, req, resp)
|
|
||||||
|
|
||||||
return errors.Annotate(err, "writing firefox canary resp: %w")
|
|
||||||
}
|
|
||||||
|
|
||||||
// reqInfoSpecialHandler returns a handler that can handle a special-domain
|
// reqInfoSpecialHandler returns a handler that can handle a special-domain
|
||||||
// query based on the request info, as well as the handler's name for debugging.
|
// query based on the request info, as well as the handler's name for debugging.
|
||||||
@ -99,11 +65,9 @@ func (mw *initMw) reqInfoSpecialHandler(
|
|||||||
} else if netutil.IsSubdomain(ri.Host, resolverArpaDomain) {
|
} else if netutil.IsSubdomain(ri.Host, resolverArpaDomain) {
|
||||||
// A badly formed resolver.arpa subdomain query.
|
// A badly formed resolver.arpa subdomain query.
|
||||||
return mw.handleBadResolverARPA, "bad_resolver_arpa"
|
return mw.handleBadResolverARPA, "bad_resolver_arpa"
|
||||||
} else if shouldBlockPrivateRelay(ri) {
|
|
||||||
return mw.handlePrivateRelay, "apple_private_relay"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, ""
|
return mw.specialDomainHandler(ri)
|
||||||
}
|
}
|
||||||
|
|
||||||
// reqInfoHandlerFunc is an alias for handler functions that additionally accept
|
// reqInfoHandlerFunc is an alias for handler functions that additionally accept
|
||||||
@ -222,24 +186,46 @@ func (mw *initMw) handleBadResolverARPA(
|
|||||||
return errors.Annotate(err, "writing nodata resp for %q: %w", ri.Host)
|
return errors.Annotate(err, "writing nodata resp for %q: %w", ri.Host)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// specialDomainHandler returns a handler that can handle a special-domain
|
||||||
|
// query for Apple Private Relay or Firefox canary domain based on the request
|
||||||
|
// or profile information, as well as the handler's name for debugging.
|
||||||
|
func (mw *initMw) specialDomainHandler(
|
||||||
|
ri *agd.RequestInfo,
|
||||||
|
) (f reqInfoHandlerFunc, name string) {
|
||||||
|
qt := ri.QType
|
||||||
|
if qt != dns.TypeA && qt != dns.TypeAAAA {
|
||||||
|
return nil, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
host := ri.Host
|
||||||
|
prof := ri.Profile
|
||||||
|
|
||||||
|
switch host {
|
||||||
|
case
|
||||||
|
applePrivateRelayMaskHost,
|
||||||
|
applePrivateRelayMaskH2Host,
|
||||||
|
applePrivateRelayMaskCanaryHost:
|
||||||
|
if shouldBlockPrivateRelay(ri, prof) {
|
||||||
|
return mw.handlePrivateRelay, "apple_private_relay"
|
||||||
|
}
|
||||||
|
case firefoxCanaryHost:
|
||||||
|
if shouldBlockFirefoxCanary(ri, prof) {
|
||||||
|
return mw.handleFirefoxCanary, "firefox"
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Go on.
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, ""
|
||||||
|
}
|
||||||
|
|
||||||
// Apple Private Relay
|
// Apple Private Relay
|
||||||
|
|
||||||
// shouldBlockPrivateRelay returns true if the query is for an Apple Private
|
// shouldBlockPrivateRelay returns true if the query is for an Apple Private
|
||||||
// Relay check domain and the request information indicates that Apple Private
|
// Relay check domain and the request information or profile indicates that
|
||||||
// Relay should be blocked.
|
// Apple Private Relay should be blocked.
|
||||||
func shouldBlockPrivateRelay(ri *agd.RequestInfo) (ok bool) {
|
func shouldBlockPrivateRelay(ri *agd.RequestInfo, prof *agd.Profile) (ok bool) {
|
||||||
qt := ri.QType
|
if prof != nil {
|
||||||
host := ri.Host
|
|
||||||
|
|
||||||
return (qt == dns.TypeA || qt == dns.TypeAAAA) &&
|
|
||||||
(host == applePrivateRelayMaskHost || host == applePrivateRelayMaskH2Host) &&
|
|
||||||
reqInfoShouldBlockPrivateRelay(ri)
|
|
||||||
}
|
|
||||||
|
|
||||||
// reqInfoShouldBlockPrivateRelay returns true if Apple Private Relay queries
|
|
||||||
// should be blocked based on the request information.
|
|
||||||
func reqInfoShouldBlockPrivateRelay(ri *agd.RequestInfo) (ok bool) {
|
|
||||||
if prof := ri.Profile; prof != nil {
|
|
||||||
return prof.BlockPrivateRelay
|
return prof.BlockPrivateRelay
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -260,3 +246,32 @@ func (mw *initMw) handlePrivateRelay(
|
|||||||
|
|
||||||
return errors.Annotate(err, "writing private relay resp: %w")
|
return errors.Annotate(err, "writing private relay resp: %w")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Firefox canary domain
|
||||||
|
|
||||||
|
// shouldBlockFirefoxCanary returns true if the query is for a Firefox canary
|
||||||
|
// domain and the request information or profile indicates that Firefox canary
|
||||||
|
// domain should be blocked.
|
||||||
|
func shouldBlockFirefoxCanary(ri *agd.RequestInfo, prof *agd.Profile) (ok bool) {
|
||||||
|
if prof != nil {
|
||||||
|
return prof.BlockFirefoxCanary
|
||||||
|
}
|
||||||
|
|
||||||
|
return ri.FilteringGroup.BlockFirefoxCanary
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleFirefoxCanary checks if the request is for the fully-qualified domain
|
||||||
|
// name that Firefox uses to check DoH settings and writes a response if needed.
|
||||||
|
func (mw *initMw) handleFirefoxCanary(
|
||||||
|
ctx context.Context,
|
||||||
|
rw dnsserver.ResponseWriter,
|
||||||
|
req *dns.Msg,
|
||||||
|
ri *agd.RequestInfo,
|
||||||
|
) (err error) {
|
||||||
|
metrics.DNSSvcFirefoxRequestsTotal.Inc()
|
||||||
|
|
||||||
|
resp := mw.messages.NewMsgREFUSED(req)
|
||||||
|
err = rw.WriteMsg(ctx, req, resp)
|
||||||
|
|
||||||
|
return errors.Annotate(err, "writing firefox canary resp: %w")
|
||||||
|
}
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
|
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/mathutil"
|
||||||
"github.com/bluele/gcache"
|
"github.com/bluele/gcache"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
@ -93,22 +94,16 @@ func (mw *Middleware) toCacheKey(cr *cacheRequest, hostHasECS bool) (key uint64)
|
|||||||
var buf [6]byte
|
var buf [6]byte
|
||||||
binary.LittleEndian.PutUint16(buf[:2], cr.qType)
|
binary.LittleEndian.PutUint16(buf[:2], cr.qType)
|
||||||
binary.LittleEndian.PutUint16(buf[2:4], cr.qClass)
|
binary.LittleEndian.PutUint16(buf[2:4], cr.qClass)
|
||||||
if cr.reqDO {
|
|
||||||
buf[4] = 1
|
|
||||||
} else {
|
|
||||||
buf[4] = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
if cr.subnet.Addr().Is4() {
|
buf[4] = mathutil.BoolToNumber[byte](cr.reqDO)
|
||||||
buf[5] = 0
|
|
||||||
} else {
|
addr := cr.subnet.Addr()
|
||||||
buf[5] = 1
|
buf[5] = mathutil.BoolToNumber[byte](addr.Is6())
|
||||||
}
|
|
||||||
|
|
||||||
_, _ = h.Write(buf[:])
|
_, _ = h.Write(buf[:])
|
||||||
|
|
||||||
if hostHasECS {
|
if hostHasECS {
|
||||||
_, _ = h.Write(cr.subnet.Addr().AsSlice())
|
_, _ = h.Write(addr.AsSlice())
|
||||||
_ = h.WriteByte(byte(cr.subnet.Bits()))
|
_ = h.WriteByte(byte(cr.subnet.Bits()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||||
|
"github.com/AdguardTeam/golibs/mathutil"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
@ -269,9 +270,5 @@ func getTTLIfLower(r dns.RR, ttl uint32) (res uint32) {
|
|||||||
// Go on.
|
// Go on.
|
||||||
}
|
}
|
||||||
|
|
||||||
if httl := r.Header().Ttl; httl < ttl {
|
return mathutil.Min(r.Header().Ttl, ttl)
|
||||||
return httl
|
|
||||||
}
|
|
||||||
|
|
||||||
return ttl
|
|
||||||
}
|
}
|
||||||
|
@ -67,13 +67,13 @@ type SentryReportableError interface {
|
|||||||
// TODO(a.garipov): Make sure that we use this approach everywhere.
|
// TODO(a.garipov): Make sure that we use this approach everywhere.
|
||||||
func isReportable(err error) (ok bool) {
|
func isReportable(err error) (ok bool) {
|
||||||
var (
|
var (
|
||||||
ravErr SentryReportableError
|
sentryRepErr SentryReportableError
|
||||||
fwdErr *forward.Error
|
fwdErr *forward.Error
|
||||||
dnsWErr *dnsserver.WriteError
|
dnsWErr *dnsserver.WriteError
|
||||||
)
|
)
|
||||||
|
|
||||||
if errors.As(err, &ravErr) {
|
if errors.As(err, &sentryRepErr) {
|
||||||
return ravErr.IsSentryReportable()
|
return sentryRepErr.IsSentryReportable()
|
||||||
} else if errors.As(err, &fwdErr) {
|
} else if errors.As(err, &fwdErr) {
|
||||||
return isReportableNetwork(fwdErr.Err)
|
return isReportableNetwork(fwdErr.Err)
|
||||||
} else if errors.As(err, &dnsWErr) {
|
} else if errors.As(err, &dnsWErr) {
|
||||||
|
@ -21,8 +21,8 @@ var _ Interface = (*compFilter)(nil)
|
|||||||
// compFilter is a composite filter based on several types of safe search
|
// compFilter is a composite filter based on several types of safe search
|
||||||
// filters and rule lists.
|
// filters and rule lists.
|
||||||
type compFilter struct {
|
type compFilter struct {
|
||||||
safeBrowsing *hashPrefixFilter
|
safeBrowsing *HashPrefix
|
||||||
adultBlocking *hashPrefixFilter
|
adultBlocking *HashPrefix
|
||||||
|
|
||||||
genSafeSearch *safeSearch
|
genSafeSearch *safeSearch
|
||||||
ytSafeSearch *safeSearch
|
ytSafeSearch *safeSearch
|
||||||
|
@ -18,10 +18,11 @@ import (
|
|||||||
// maxFilterSize is the maximum size of downloaded filters.
|
// maxFilterSize is the maximum size of downloaded filters.
|
||||||
const maxFilterSize = 196 * int64(datasize.MB)
|
const maxFilterSize = 196 * int64(datasize.MB)
|
||||||
|
|
||||||
// defaultTimeout is the default timeout to use when fetching filter data.
|
// defaultFilterRefreshTimeout is the default timeout to use when fetching
|
||||||
|
// filter lists data.
|
||||||
//
|
//
|
||||||
// TODO(a.garipov): Consider making timeouts where they are used configurable.
|
// TODO(a.garipov): Consider making timeouts where they are used configurable.
|
||||||
const defaultTimeout = 30 * time.Second
|
const defaultFilterRefreshTimeout = 180 * time.Second
|
||||||
|
|
||||||
// defaultResolveTimeout is the default timeout for resolving hosts for safe
|
// defaultResolveTimeout is the default timeout for resolving hosts for safe
|
||||||
// search and safe browsing filters.
|
// search and safe browsing filters.
|
||||||
|
@ -181,24 +181,18 @@ func prepareConf(t testing.TB) (c *filter.DefaultStorageConfig) {
|
|||||||
FilterIndexURL: fltsURL,
|
FilterIndexURL: fltsURL,
|
||||||
GeneralSafeSearchRulesURL: ssURL,
|
GeneralSafeSearchRulesURL: ssURL,
|
||||||
YoutubeSafeSearchRulesURL: ssURL,
|
YoutubeSafeSearchRulesURL: ssURL,
|
||||||
SafeBrowsing: &filter.HashPrefixConfig{
|
SafeBrowsing: &filter.HashPrefix{},
|
||||||
CacheTTL: 1 * time.Hour,
|
AdultBlocking: &filter.HashPrefix{},
|
||||||
CacheSize: 100,
|
Now: time.Now,
|
||||||
},
|
ErrColl: nil,
|
||||||
AdultBlocking: &filter.HashPrefixConfig{
|
Resolver: nil,
|
||||||
CacheTTL: 1 * time.Hour,
|
CacheDir: cacheDir,
|
||||||
CacheSize: 100,
|
CustomFilterCacheSize: 100,
|
||||||
},
|
SafeSearchCacheSize: 100,
|
||||||
Now: time.Now,
|
SafeSearchCacheTTL: 1 * time.Hour,
|
||||||
ErrColl: nil,
|
RuleListCacheSize: 100,
|
||||||
Resolver: nil,
|
RefreshIvl: testRefreshIvl,
|
||||||
CacheDir: cacheDir,
|
UseRuleListCache: false,
|
||||||
CustomFilterCacheSize: 100,
|
|
||||||
SafeSearchCacheSize: 100,
|
|
||||||
SafeSearchCacheTTL: 1 * time.Hour,
|
|
||||||
RuleListCacheSize: 100,
|
|
||||||
RefreshIvl: testRefreshIvl,
|
|
||||||
UseRuleListCache: false,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,13 +3,17 @@ package filter
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
|
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/resultcache"
|
"github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/resultcache"
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
||||||
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
@ -21,13 +25,32 @@ import (
|
|||||||
// HashPrefixConfig is the hash-prefix filter configuration structure.
|
// HashPrefixConfig is the hash-prefix filter configuration structure.
|
||||||
type HashPrefixConfig struct {
|
type HashPrefixConfig struct {
|
||||||
// Hashes are the hostname hashes for this filter.
|
// Hashes are the hostname hashes for this filter.
|
||||||
Hashes *HashStorage
|
Hashes *hashstorage.Storage
|
||||||
|
|
||||||
|
// URL is the URL used to update the filter.
|
||||||
|
URL *url.URL
|
||||||
|
|
||||||
|
// ErrColl is used to collect non-critical and rare errors.
|
||||||
|
ErrColl agd.ErrorCollector
|
||||||
|
|
||||||
|
// Resolver is used to resolve hosts for the hash-prefix filter.
|
||||||
|
Resolver agdnet.Resolver
|
||||||
|
|
||||||
|
// ID is the ID of this hash storage for logging and error reporting.
|
||||||
|
ID agd.FilterListID
|
||||||
|
|
||||||
|
// CachePath is the path to the file containing the cached filtered
|
||||||
|
// hostnames, one per line.
|
||||||
|
CachePath string
|
||||||
|
|
||||||
// ReplacementHost is the replacement host for this filter. Queries
|
// ReplacementHost is the replacement host for this filter. Queries
|
||||||
// matched by the filter receive a response with the IP addresses of
|
// matched by the filter receive a response with the IP addresses of
|
||||||
// this host.
|
// this host.
|
||||||
ReplacementHost string
|
ReplacementHost string
|
||||||
|
|
||||||
|
// Staleness is the time after which a file is considered stale.
|
||||||
|
Staleness time.Duration
|
||||||
|
|
||||||
// CacheTTL is the time-to-live value used to cache the results of the
|
// CacheTTL is the time-to-live value used to cache the results of the
|
||||||
// filter.
|
// filter.
|
||||||
//
|
//
|
||||||
@ -38,57 +61,58 @@ type HashPrefixConfig struct {
|
|||||||
CacheSize int
|
CacheSize int
|
||||||
}
|
}
|
||||||
|
|
||||||
// hashPrefixFilter is a filter that matches hosts by their hashes based on
|
// HashPrefix is a filter that matches hosts by their hashes based on a
|
||||||
// a hash-prefix table.
|
// hash-prefix table.
|
||||||
type hashPrefixFilter struct {
|
type HashPrefix struct {
|
||||||
hashes *HashStorage
|
hashes *hashstorage.Storage
|
||||||
|
refr *refreshableFilter
|
||||||
resCache *resultcache.Cache[*ResultModified]
|
resCache *resultcache.Cache[*ResultModified]
|
||||||
resolver agdnet.Resolver
|
resolver agdnet.Resolver
|
||||||
errColl agd.ErrorCollector
|
errColl agd.ErrorCollector
|
||||||
repHost string
|
repHost string
|
||||||
id agd.FilterListID
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// newHashPrefixFilter returns a new hash-prefix filter. c must not be nil.
|
// NewHashPrefix returns a new hash-prefix filter. c must not be nil.
|
||||||
func newHashPrefixFilter(
|
func NewHashPrefix(c *HashPrefixConfig) (f *HashPrefix, err error) {
|
||||||
c *HashPrefixConfig,
|
f = &HashPrefix{
|
||||||
resolver agdnet.Resolver,
|
hashes: c.Hashes,
|
||||||
errColl agd.ErrorCollector,
|
refr: &refreshableFilter{
|
||||||
id agd.FilterListID,
|
http: agdhttp.NewClient(&agdhttp.ClientConfig{
|
||||||
) (f *hashPrefixFilter) {
|
Timeout: defaultFilterRefreshTimeout,
|
||||||
f = &hashPrefixFilter{
|
}),
|
||||||
hashes: c.Hashes,
|
url: c.URL,
|
||||||
|
id: c.ID,
|
||||||
|
cachePath: c.CachePath,
|
||||||
|
typ: "hash storage",
|
||||||
|
staleness: c.Staleness,
|
||||||
|
},
|
||||||
resCache: resultcache.New[*ResultModified](c.CacheSize),
|
resCache: resultcache.New[*ResultModified](c.CacheSize),
|
||||||
resolver: resolver,
|
resolver: c.Resolver,
|
||||||
errColl: errColl,
|
errColl: c.ErrColl,
|
||||||
repHost: c.ReplacementHost,
|
repHost: c.ReplacementHost,
|
||||||
id: id,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Patch the refresh function of the hash storage, if there is one, to make
|
f.refr.resetRules = f.resetRules
|
||||||
// sure that we clear the result cache during every refresh.
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): Create a better way to do that than this spaghetti-style
|
|
||||||
// patching. Perhaps, lift the entire logic of hash-storage refresh into
|
|
||||||
// hashPrefixFilter.
|
|
||||||
if f.hashes != nil && f.hashes.refr != nil {
|
|
||||||
resetRules := f.hashes.refr.resetRules
|
|
||||||
f.hashes.refr.resetRules = func(text string) (err error) {
|
|
||||||
f.resCache.Clear()
|
|
||||||
|
|
||||||
return resetRules(text)
|
err = f.refresh(context.Background(), true)
|
||||||
}
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return f
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// id returns the ID of the hash storage.
|
||||||
|
func (f *HashPrefix) id() (fltID agd.FilterListID) {
|
||||||
|
return f.refr.id
|
||||||
}
|
}
|
||||||
|
|
||||||
// type check
|
// type check
|
||||||
var _ qtHostFilter = (*hashPrefixFilter)(nil)
|
var _ qtHostFilter = (*HashPrefix)(nil)
|
||||||
|
|
||||||
// filterReq implements the qtHostFilter interface for *hashPrefixFilter. It
|
// filterReq implements the qtHostFilter interface for *hashPrefixFilter. It
|
||||||
// modifies the response if host matches f.
|
// modifies the response if host matches f.
|
||||||
func (f *hashPrefixFilter) filterReq(
|
func (f *HashPrefix) filterReq(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
ri *agd.RequestInfo,
|
ri *agd.RequestInfo,
|
||||||
req *dns.Msg,
|
req *dns.Msg,
|
||||||
@ -115,7 +139,7 @@ func (f *hashPrefixFilter) filterReq(
|
|||||||
var matched string
|
var matched string
|
||||||
sub := hashableSubdomains(host)
|
sub := hashableSubdomains(host)
|
||||||
for _, s := range sub {
|
for _, s := range sub {
|
||||||
if f.hashes.hashMatches(s) {
|
if f.hashes.Matches(s) {
|
||||||
matched = s
|
matched = s
|
||||||
|
|
||||||
break
|
break
|
||||||
@ -134,19 +158,19 @@ func (f *hashPrefixFilter) filterReq(
|
|||||||
var result *dns.Msg
|
var result *dns.Msg
|
||||||
ips, err := f.resolver.LookupIP(ctx, fam, f.repHost)
|
ips, err := f.resolver.LookupIP(ctx, fam, f.repHost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
agd.Collectf(ctx, f.errColl, "filter %s: resolving: %w", f.id, err)
|
agd.Collectf(ctx, f.errColl, "filter %s: resolving: %w", f.id(), err)
|
||||||
|
|
||||||
result = ri.Messages.NewMsgSERVFAIL(req)
|
result = ri.Messages.NewMsgSERVFAIL(req)
|
||||||
} else {
|
} else {
|
||||||
result, err = ri.Messages.NewIPRespMsg(req, ips...)
|
result, err = ri.Messages.NewIPRespMsg(req, ips...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("filter %s: creating modified result: %w", f.id, err)
|
return nil, fmt.Errorf("filter %s: creating modified result: %w", f.id(), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rm = &ResultModified{
|
rm = &ResultModified{
|
||||||
Msg: result,
|
Msg: result,
|
||||||
List: f.id,
|
List: f.id(),
|
||||||
Rule: agd.FilterRuleText(matched),
|
Rule: agd.FilterRuleText(matched),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -161,21 +185,21 @@ func (f *hashPrefixFilter) filterReq(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// updateCacheSizeMetrics updates cache size metrics.
|
// updateCacheSizeMetrics updates cache size metrics.
|
||||||
func (f *hashPrefixFilter) updateCacheSizeMetrics(size int) {
|
func (f *HashPrefix) updateCacheSizeMetrics(size int) {
|
||||||
switch f.id {
|
switch id := f.id(); id {
|
||||||
case agd.FilterListIDSafeBrowsing:
|
case agd.FilterListIDSafeBrowsing:
|
||||||
metrics.HashPrefixFilterSafeBrowsingCacheSize.Set(float64(size))
|
metrics.HashPrefixFilterSafeBrowsingCacheSize.Set(float64(size))
|
||||||
case agd.FilterListIDAdultBlocking:
|
case agd.FilterListIDAdultBlocking:
|
||||||
metrics.HashPrefixFilterAdultBlockingCacheSize.Set(float64(size))
|
metrics.HashPrefixFilterAdultBlockingCacheSize.Set(float64(size))
|
||||||
default:
|
default:
|
||||||
panic(fmt.Errorf("unsupported FilterListID %s", f.id))
|
panic(fmt.Errorf("unsupported FilterListID %s", id))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateCacheLookupsMetrics updates cache lookups metrics.
|
// updateCacheLookupsMetrics updates cache lookups metrics.
|
||||||
func (f *hashPrefixFilter) updateCacheLookupsMetrics(hit bool) {
|
func (f *HashPrefix) updateCacheLookupsMetrics(hit bool) {
|
||||||
var hitsMetric, missesMetric prometheus.Counter
|
var hitsMetric, missesMetric prometheus.Counter
|
||||||
switch f.id {
|
switch id := f.id(); id {
|
||||||
case agd.FilterListIDSafeBrowsing:
|
case agd.FilterListIDSafeBrowsing:
|
||||||
hitsMetric = metrics.HashPrefixFilterCacheSafeBrowsingHits
|
hitsMetric = metrics.HashPrefixFilterCacheSafeBrowsingHits
|
||||||
missesMetric = metrics.HashPrefixFilterCacheSafeBrowsingMisses
|
missesMetric = metrics.HashPrefixFilterCacheSafeBrowsingMisses
|
||||||
@ -183,7 +207,7 @@ func (f *hashPrefixFilter) updateCacheLookupsMetrics(hit bool) {
|
|||||||
hitsMetric = metrics.HashPrefixFilterCacheAdultBlockingHits
|
hitsMetric = metrics.HashPrefixFilterCacheAdultBlockingHits
|
||||||
missesMetric = metrics.HashPrefixFilterCacheAdultBlockingMisses
|
missesMetric = metrics.HashPrefixFilterCacheAdultBlockingMisses
|
||||||
default:
|
default:
|
||||||
panic(fmt.Errorf("unsupported FilterListID %s", f.id))
|
panic(fmt.Errorf("unsupported FilterListID %s", id))
|
||||||
}
|
}
|
||||||
|
|
||||||
if hit {
|
if hit {
|
||||||
@ -194,12 +218,50 @@ func (f *hashPrefixFilter) updateCacheLookupsMetrics(hit bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// name implements the qtHostFilter interface for *hashPrefixFilter.
|
// name implements the qtHostFilter interface for *hashPrefixFilter.
|
||||||
func (f *hashPrefixFilter) name() (n string) {
|
func (f *HashPrefix) name() (n string) {
|
||||||
if f == nil {
|
if f == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
return string(f.id)
|
return string(f.id())
|
||||||
|
}
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ agd.Refresher = (*HashPrefix)(nil)
|
||||||
|
|
||||||
|
// Refresh implements the [agd.Refresher] interface for *hashPrefixFilter.
|
||||||
|
func (f *HashPrefix) Refresh(ctx context.Context) (err error) {
|
||||||
|
return f.refresh(ctx, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// refresh reloads the hash filter data. If acceptStale is true, do not try to
|
||||||
|
// load the list from its URL when there is already a file in the cache
|
||||||
|
// directory, regardless of its staleness.
|
||||||
|
func (f *HashPrefix) refresh(ctx context.Context, acceptStale bool) (err error) {
|
||||||
|
return f.refr.refresh(ctx, acceptStale)
|
||||||
|
}
|
||||||
|
|
||||||
|
// resetRules resets the hosts in the index.
|
||||||
|
func (f *HashPrefix) resetRules(text string) (err error) {
|
||||||
|
n, err := f.hashes.Reset(text)
|
||||||
|
|
||||||
|
// Report the filter update to prometheus.
|
||||||
|
promLabels := prometheus.Labels{
|
||||||
|
"filter": string(f.id()),
|
||||||
|
}
|
||||||
|
|
||||||
|
metrics.SetStatusGauge(metrics.FilterUpdatedStatus.With(promLabels), err)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
metrics.FilterUpdatedTime.With(promLabels).SetToCurrentTime()
|
||||||
|
metrics.FilterRulesTotal.With(promLabels).Set(float64(n))
|
||||||
|
|
||||||
|
log.Info("filter %s: reset %d hosts", f.id(), n)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// subDomainNum defines how many labels should be hashed to match against a hash
|
// subDomainNum defines how many labels should be hashed to match against a hash
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@ -21,8 +22,10 @@ import (
|
|||||||
func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) {
|
func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) {
|
||||||
cacheDir := t.TempDir()
|
cacheDir := t.TempDir()
|
||||||
cachePath := filepath.Join(cacheDir, string(agd.FilterListIDSafeBrowsing))
|
cachePath := filepath.Join(cacheDir, string(agd.FilterListIDSafeBrowsing))
|
||||||
hosts := "scam.example.net\n"
|
err := os.WriteFile(cachePath, []byte(safeBrowsingHost+"\n"), 0o644)
|
||||||
err := os.WriteFile(cachePath, []byte(hosts), 0o644)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
hashes, err := hashstorage.New("")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
errColl := &agdtest.ErrorCollector{
|
errColl := &agdtest.ErrorCollector{
|
||||||
@ -31,46 +34,33 @@ func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{
|
resolver := &agdtest.Resolver{
|
||||||
URL: nil,
|
OnLookupIP: func(
|
||||||
ErrColl: errColl,
|
_ context.Context,
|
||||||
ID: agd.FilterListIDSafeBrowsing,
|
_ netutil.AddrFamily,
|
||||||
CachePath: cachePath,
|
_ string,
|
||||||
RefreshIvl: testRefreshIvl,
|
) (ips []net.IP, err error) {
|
||||||
})
|
return []net.IP{safeBrowsingSafeIP4}, nil
|
||||||
require.NoError(t, err)
|
},
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
err = hashes.Start()
|
|
||||||
require.NoError(t, err)
|
|
||||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
|
||||||
return hashes.Shutdown(ctx)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Fake Data
|
|
||||||
|
|
||||||
onLookupIP := func(
|
|
||||||
_ context.Context,
|
|
||||||
_ netutil.AddrFamily,
|
|
||||||
_ string,
|
|
||||||
) (ips []net.IP, err error) {
|
|
||||||
return []net.IP{safeBrowsingSafeIP4}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c := prepareConf(t)
|
c := prepareConf(t)
|
||||||
|
|
||||||
c.SafeBrowsing = &filter.HashPrefixConfig{
|
c.SafeBrowsing, err = filter.NewHashPrefix(&filter.HashPrefixConfig{
|
||||||
Hashes: hashes,
|
Hashes: hashes,
|
||||||
|
ErrColl: errColl,
|
||||||
|
Resolver: resolver,
|
||||||
|
ID: agd.FilterListIDSafeBrowsing,
|
||||||
|
CachePath: cachePath,
|
||||||
ReplacementHost: safeBrowsingSafeHost,
|
ReplacementHost: safeBrowsingSafeHost,
|
||||||
|
Staleness: 1 * time.Hour,
|
||||||
CacheTTL: 10 * time.Second,
|
CacheTTL: 10 * time.Second,
|
||||||
CacheSize: 100,
|
CacheSize: 100,
|
||||||
}
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
c.ErrColl = errColl
|
c.ErrColl = errColl
|
||||||
|
c.Resolver = resolver
|
||||||
c.Resolver = &agdtest.Resolver{
|
|
||||||
OnLookupIP: onLookupIP,
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := filter.NewDefaultStorage(c)
|
s, err := filter.NewDefaultStorage(c)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -93,7 +83,7 @@ func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ri := newReqInfo(g, nil, safeBrowsingSubHost, clientIP, dns.TypeA)
|
ri := newReqInfo(g, nil, safeBrowsingSubHost, clientIP, dns.TypeA)
|
||||||
ctx = agd.ContextWithRequestInfo(ctx, ri)
|
ctx := agd.ContextWithRequestInfo(context.Background(), ri)
|
||||||
|
|
||||||
f := s.FilterFromContext(ctx, ri)
|
f := s.FilterFromContext(ctx, ri)
|
||||||
require.NotNil(t, f)
|
require.NotNil(t, f)
|
||||||
|
@ -1,352 +0,0 @@
|
|||||||
package filter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"context"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Hash Storage
|
|
||||||
|
|
||||||
// Hash and hash part length constants.
|
|
||||||
const (
|
|
||||||
// hashPrefixLen is the length of the prefix of the hash of the filtered
|
|
||||||
// hostname.
|
|
||||||
hashPrefixLen = 2
|
|
||||||
|
|
||||||
// HashPrefixEncLen is the encoded length of the hash prefix. Two text
|
|
||||||
// bytes per one binary byte.
|
|
||||||
HashPrefixEncLen = hashPrefixLen * 2
|
|
||||||
|
|
||||||
// hashLen is the length of the whole hash of the checked hostname.
|
|
||||||
hashLen = sha256.Size
|
|
||||||
|
|
||||||
// hashSuffixLen is the length of the suffix of the hash of the filtered
|
|
||||||
// hostname.
|
|
||||||
hashSuffixLen = hashLen - hashPrefixLen
|
|
||||||
|
|
||||||
// hashEncLen is the encoded length of the hash. Two text bytes per one
|
|
||||||
// binary byte.
|
|
||||||
hashEncLen = hashLen * 2
|
|
||||||
|
|
||||||
// legacyHashPrefixEncLen is the encoded length of a legacy hash.
|
|
||||||
legacyHashPrefixEncLen = 8
|
|
||||||
)
|
|
||||||
|
|
||||||
// hashPrefix is the type of the 2-byte prefix of a full 32-byte SHA256 hash of
|
|
||||||
// a host being checked.
|
|
||||||
type hashPrefix [hashPrefixLen]byte
|
|
||||||
|
|
||||||
// hashSuffix is the type of the 30-byte suffix of a full 32-byte SHA256 hash of
|
|
||||||
// a host being checked.
|
|
||||||
type hashSuffix [hashSuffixLen]byte
|
|
||||||
|
|
||||||
// HashStorage is a storage for hashes of the filtered hostnames.
|
|
||||||
type HashStorage struct {
|
|
||||||
// mu protects hashSuffixes.
|
|
||||||
mu *sync.RWMutex
|
|
||||||
hashSuffixes map[hashPrefix][]hashSuffix
|
|
||||||
|
|
||||||
// refr contains data for refreshing the filter.
|
|
||||||
refr *refreshableFilter
|
|
||||||
|
|
||||||
refrWorker *agd.RefreshWorker
|
|
||||||
}
|
|
||||||
|
|
||||||
// HashStorageConfig is the configuration structure for a *HashStorage.
|
|
||||||
type HashStorageConfig struct {
|
|
||||||
// URL is the URL used to update the filter.
|
|
||||||
URL *url.URL
|
|
||||||
|
|
||||||
// ErrColl is used to collect non-critical and rare errors.
|
|
||||||
ErrColl agd.ErrorCollector
|
|
||||||
|
|
||||||
// ID is the ID of this hash storage for logging and error reporting.
|
|
||||||
ID agd.FilterListID
|
|
||||||
|
|
||||||
// CachePath is the path to the file containing the cached filtered
|
|
||||||
// hostnames, one per line.
|
|
||||||
CachePath string
|
|
||||||
|
|
||||||
// RefreshIvl is the refresh interval.
|
|
||||||
RefreshIvl time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewHashStorage returns a new hash storage containing hashes of all hostnames.
|
|
||||||
func NewHashStorage(c *HashStorageConfig) (hs *HashStorage, err error) {
|
|
||||||
hs = &HashStorage{
|
|
||||||
mu: &sync.RWMutex{},
|
|
||||||
hashSuffixes: map[hashPrefix][]hashSuffix{},
|
|
||||||
refr: &refreshableFilter{
|
|
||||||
http: agdhttp.NewClient(&agdhttp.ClientConfig{
|
|
||||||
Timeout: defaultTimeout,
|
|
||||||
}),
|
|
||||||
url: c.URL,
|
|
||||||
id: c.ID,
|
|
||||||
cachePath: c.CachePath,
|
|
||||||
typ: "hash storage",
|
|
||||||
refreshIvl: c.RefreshIvl,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do not set this in the literal above, since hs is nil there.
|
|
||||||
hs.refr.resetRules = hs.resetHosts
|
|
||||||
|
|
||||||
refrWorker := agd.NewRefreshWorker(&agd.RefreshWorkerConfig{
|
|
||||||
Context: func() (ctx context.Context, cancel context.CancelFunc) {
|
|
||||||
return context.WithTimeout(context.Background(), defaultTimeout)
|
|
||||||
},
|
|
||||||
Refresher: hs,
|
|
||||||
ErrColl: c.ErrColl,
|
|
||||||
Name: string(c.ID),
|
|
||||||
Interval: c.RefreshIvl,
|
|
||||||
RefreshOnShutdown: false,
|
|
||||||
RoutineLogsAreDebug: false,
|
|
||||||
})
|
|
||||||
|
|
||||||
hs.refrWorker = refrWorker
|
|
||||||
|
|
||||||
err = hs.refresh(context.Background(), true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("initializing %s: %w", c.ID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return hs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// hashes returns all hashes starting with the given prefixes, if any. The
|
|
||||||
// resulting slice shares storage for all underlying strings.
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): This currently doesn't take duplicates into account.
|
|
||||||
func (hs *HashStorage) hashes(hps []hashPrefix) (hashes []string) {
|
|
||||||
if len(hps) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.mu.RLock()
|
|
||||||
defer hs.mu.RUnlock()
|
|
||||||
|
|
||||||
// First, calculate the number of hashes to allocate the buffer.
|
|
||||||
l := 0
|
|
||||||
for _, hp := range hps {
|
|
||||||
hashSufs := hs.hashSuffixes[hp]
|
|
||||||
l += len(hashSufs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then, allocate the buffer of the appropriate size and write all hashes
|
|
||||||
// into one big buffer and slice it into separate strings to make the
|
|
||||||
// garbage collector's work easier. This assumes that all references to
|
|
||||||
// this buffer will become unreachable at the same time.
|
|
||||||
//
|
|
||||||
// The fact that we iterate over the map twice shouldn't matter, since we
|
|
||||||
// assume that len(hps) will be below 5 most of the time.
|
|
||||||
b := &strings.Builder{}
|
|
||||||
b.Grow(l * hashEncLen)
|
|
||||||
|
|
||||||
// Use a buffer and write the resulting buffer into b directly instead of
|
|
||||||
// using hex.NewEncoder, because that seems to incur a significant
|
|
||||||
// performance hit.
|
|
||||||
var buf [hashEncLen]byte
|
|
||||||
for _, hp := range hps {
|
|
||||||
hashSufs := hs.hashSuffixes[hp]
|
|
||||||
for _, suf := range hashSufs {
|
|
||||||
// Slicing is safe here, since the contents of hp and suf are being
|
|
||||||
// encoded.
|
|
||||||
|
|
||||||
// nolint:looppointer
|
|
||||||
hex.Encode(buf[:], hp[:])
|
|
||||||
// nolint:looppointer
|
|
||||||
hex.Encode(buf[HashPrefixEncLen:], suf[:])
|
|
||||||
_, _ = b.Write(buf[:])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s := b.String()
|
|
||||||
hashes = make([]string, 0, l)
|
|
||||||
for i := 0; i < l; i++ {
|
|
||||||
hashes = append(hashes, s[i*hashEncLen:(i+1)*hashEncLen])
|
|
||||||
}
|
|
||||||
|
|
||||||
return hashes
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadHashSuffixes returns hash suffixes for the given prefix. It is safe for
|
|
||||||
// concurrent use.
|
|
||||||
func (hs *HashStorage) loadHashSuffixes(hp hashPrefix) (sufs []hashSuffix, ok bool) {
|
|
||||||
hs.mu.RLock()
|
|
||||||
defer hs.mu.RUnlock()
|
|
||||||
|
|
||||||
sufs, ok = hs.hashSuffixes[hp]
|
|
||||||
|
|
||||||
return sufs, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// hashMatches returns true if the host matches one of the hashes.
|
|
||||||
func (hs *HashStorage) hashMatches(host string) (ok bool) {
|
|
||||||
sum := sha256.Sum256([]byte(host))
|
|
||||||
hp := hashPrefix{sum[0], sum[1]}
|
|
||||||
|
|
||||||
var buf [hashLen]byte
|
|
||||||
hashSufs, ok := hs.loadHashSuffixes(hp)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
copy(buf[:], hp[:])
|
|
||||||
for _, suf := range hashSufs {
|
|
||||||
// Slicing is safe here, because we make a copy.
|
|
||||||
|
|
||||||
// nolint:looppointer
|
|
||||||
copy(buf[hashPrefixLen:], suf[:])
|
|
||||||
if buf == sum {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// hashPrefixesFromStr returns hash prefixes from a dot-separated string.
|
|
||||||
func hashPrefixesFromStr(prefixesStr string) (hashPrefixes []hashPrefix, err error) {
|
|
||||||
if prefixesStr == "" {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
prefixSet := stringutil.NewSet()
|
|
||||||
prefixStrs := strings.Split(prefixesStr, ".")
|
|
||||||
for _, s := range prefixStrs {
|
|
||||||
if len(s) != HashPrefixEncLen {
|
|
||||||
// Some legacy clients send eight-character hashes instead of
|
|
||||||
// four-character ones. For now, remove the final four characters.
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): Either remove this crutch or support such
|
|
||||||
// prefixes better.
|
|
||||||
if len(s) == legacyHashPrefixEncLen {
|
|
||||||
s = s[:HashPrefixEncLen]
|
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("bad hash len for %q", s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
prefixSet.Add(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
hashPrefixes = make([]hashPrefix, prefixSet.Len())
|
|
||||||
prefixStrs = prefixSet.Values()
|
|
||||||
for i, s := range prefixStrs {
|
|
||||||
_, err = hex.Decode(hashPrefixes[i][:], []byte(s))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("bad hash encoding for %q", s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return hashPrefixes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// type check
|
|
||||||
var _ agd.Refresher = (*HashStorage)(nil)
|
|
||||||
|
|
||||||
// Refresh implements the agd.Refresher interface for *HashStorage. If the file
|
|
||||||
// at the storage's path exists and its mtime shows that it's still fresh, it
|
|
||||||
// loads the data from the file. Otherwise, it uses the URL of the storage.
|
|
||||||
func (hs *HashStorage) Refresh(ctx context.Context) (err error) {
|
|
||||||
err = hs.refresh(ctx, false)
|
|
||||||
|
|
||||||
// Report the filter update to prometheus.
|
|
||||||
promLabels := prometheus.Labels{
|
|
||||||
"filter": string(hs.id()),
|
|
||||||
}
|
|
||||||
|
|
||||||
metrics.SetStatusGauge(metrics.FilterUpdatedStatus.With(promLabels), err)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
metrics.FilterUpdatedTime.With(promLabels).SetToCurrentTime()
|
|
||||||
|
|
||||||
// Count the total number of hashes loaded.
|
|
||||||
count := 0
|
|
||||||
for _, v := range hs.hashSuffixes {
|
|
||||||
count += len(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
metrics.FilterRulesTotal.With(promLabels).Set(float64(count))
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// id returns the ID of the hash storage.
|
|
||||||
func (hs *HashStorage) id() (fltID agd.FilterListID) {
|
|
||||||
return hs.refr.id
|
|
||||||
}
|
|
||||||
|
|
||||||
// refresh reloads the hash filter data. If acceptStale is true, do not try to
|
|
||||||
// load the list from its URL when there is already a file in the cache
|
|
||||||
// directory, regardless of its staleness.
|
|
||||||
func (hs *HashStorage) refresh(ctx context.Context, acceptStale bool) (err error) {
|
|
||||||
return hs.refr.refresh(ctx, acceptStale)
|
|
||||||
}
|
|
||||||
|
|
||||||
// resetHosts resets the hosts in the index.
|
|
||||||
func (hs *HashStorage) resetHosts(hostsStr string) (err error) {
|
|
||||||
hs.mu.Lock()
|
|
||||||
defer hs.mu.Unlock()
|
|
||||||
|
|
||||||
// Delete all elements without allocating a new map to safe space and
|
|
||||||
// performance.
|
|
||||||
//
|
|
||||||
// This is optimized, see https://github.com/golang/go/issues/20138.
|
|
||||||
for hp := range hs.hashSuffixes {
|
|
||||||
delete(hs.hashSuffixes, hp)
|
|
||||||
}
|
|
||||||
|
|
||||||
var n int
|
|
||||||
s := bufio.NewScanner(strings.NewReader(hostsStr))
|
|
||||||
for s.Scan() {
|
|
||||||
host := s.Text()
|
|
||||||
if len(host) == 0 || host[0] == '#' {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
sum := sha256.Sum256([]byte(host))
|
|
||||||
hp := hashPrefix{sum[0], sum[1]}
|
|
||||||
|
|
||||||
// TODO(a.garipov): Convert to array directly when proposal
|
|
||||||
// golang/go#46505 is implemented in Go 1.20.
|
|
||||||
suf := *(*hashSuffix)(sum[hashPrefixLen:])
|
|
||||||
hs.hashSuffixes[hp] = append(hs.hashSuffixes[hp], suf)
|
|
||||||
|
|
||||||
n++
|
|
||||||
}
|
|
||||||
|
|
||||||
err = s.Err()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("scanning hosts: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info("filter %s: reset %d hosts", hs.id(), n)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start implements the agd.Service interface for *HashStorage.
|
|
||||||
func (hs *HashStorage) Start() (err error) {
|
|
||||||
return hs.refrWorker.Start()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Shutdown implements the agd.Service interface for *HashStorage.
|
|
||||||
func (hs *HashStorage) Shutdown(ctx context.Context) (err error) {
|
|
||||||
return hs.refrWorker.Shutdown(ctx)
|
|
||||||
}
|
|
203
internal/filter/hashstorage/hashstorage.go
Normal file
203
internal/filter/hashstorage/hashstorage.go
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
// Package hashstorage defines a storage of hashes of domain names used for
|
||||||
|
// filtering.
|
||||||
|
package hashstorage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Hash and hash part length constants.
|
||||||
|
const (
|
||||||
|
// PrefixLen is the length of the prefix of the hash of the filtered
|
||||||
|
// hostname.
|
||||||
|
PrefixLen = 2
|
||||||
|
|
||||||
|
// PrefixEncLen is the encoded length of the hash prefix. Two text
|
||||||
|
// bytes per one binary byte.
|
||||||
|
PrefixEncLen = PrefixLen * 2
|
||||||
|
|
||||||
|
// hashLen is the length of the whole hash of the checked hostname.
|
||||||
|
hashLen = sha256.Size
|
||||||
|
|
||||||
|
// suffixLen is the length of the suffix of the hash of the filtered
|
||||||
|
// hostname.
|
||||||
|
suffixLen = hashLen - PrefixLen
|
||||||
|
|
||||||
|
// hashEncLen is the encoded length of the hash. Two text bytes per one
|
||||||
|
// binary byte.
|
||||||
|
hashEncLen = hashLen * 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// Prefix is the type of the 2-byte prefix of a full 32-byte SHA256 hash of a
|
||||||
|
// host being checked.
|
||||||
|
type Prefix [PrefixLen]byte
|
||||||
|
|
||||||
|
// suffix is the type of the 30-byte suffix of a full 32-byte SHA256 hash of a
|
||||||
|
// host being checked.
|
||||||
|
type suffix [suffixLen]byte
|
||||||
|
|
||||||
|
// Storage stores hashes of the filtered hostnames. All methods are safe for
|
||||||
|
// concurrent use.
|
||||||
|
type Storage struct {
|
||||||
|
// mu protects hashSuffixes.
|
||||||
|
mu *sync.RWMutex
|
||||||
|
hashSuffixes map[Prefix][]suffix
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns a new hash storage containing hashes of the domain names listed
|
||||||
|
// in hostnames, one domain name per line.
|
||||||
|
func New(hostnames string) (s *Storage, err error) {
|
||||||
|
s = &Storage{
|
||||||
|
mu: &sync.RWMutex{},
|
||||||
|
hashSuffixes: map[Prefix][]suffix{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if hostnames != "" {
|
||||||
|
_, err = s.Reset(hostnames)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hashes returns all hashes starting with the given prefixes, if any. The
|
||||||
|
// resulting slice shares storage for all underlying strings.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): This currently doesn't take duplicates into account.
|
||||||
|
func (s *Storage) Hashes(hps []Prefix) (hashes []string) {
|
||||||
|
if len(hps) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
// First, calculate the number of hashes to allocate the buffer.
|
||||||
|
l := 0
|
||||||
|
for _, hp := range hps {
|
||||||
|
hashSufs := s.hashSuffixes[hp]
|
||||||
|
l += len(hashSufs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then, allocate the buffer of the appropriate size and write all hashes
|
||||||
|
// into one big buffer and slice it into separate strings to make the
|
||||||
|
// garbage collector's work easier. This assumes that all references to
|
||||||
|
// this buffer will become unreachable at the same time.
|
||||||
|
//
|
||||||
|
// The fact that we iterate over the [s.hashSuffixes] map twice shouldn't
|
||||||
|
// matter, since we assume that len(hps) will be below 5 most of the time.
|
||||||
|
b := &strings.Builder{}
|
||||||
|
b.Grow(l * hashEncLen)
|
||||||
|
|
||||||
|
// Use a buffer and write the resulting buffer into b directly instead of
|
||||||
|
// using hex.NewEncoder, because that seems to incur a significant
|
||||||
|
// performance hit.
|
||||||
|
var buf [hashEncLen]byte
|
||||||
|
for _, hp := range hps {
|
||||||
|
hashSufs := s.hashSuffixes[hp]
|
||||||
|
for _, suf := range hashSufs {
|
||||||
|
// Slicing is safe here, since the contents of hp and suf are being
|
||||||
|
// encoded.
|
||||||
|
|
||||||
|
// nolint:looppointer
|
||||||
|
hex.Encode(buf[:], hp[:])
|
||||||
|
// nolint:looppointer
|
||||||
|
hex.Encode(buf[PrefixEncLen:], suf[:])
|
||||||
|
_, _ = b.Write(buf[:])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
str := b.String()
|
||||||
|
hashes = make([]string, 0, l)
|
||||||
|
for i := 0; i < l; i++ {
|
||||||
|
hashes = append(hashes, str[i*hashEncLen:(i+1)*hashEncLen])
|
||||||
|
}
|
||||||
|
|
||||||
|
return hashes
|
||||||
|
}
|
||||||
|
|
||||||
|
// Matches returns true if the host matches one of the hashes.
|
||||||
|
func (s *Storage) Matches(host string) (ok bool) {
|
||||||
|
sum := sha256.Sum256([]byte(host))
|
||||||
|
hp := *(*Prefix)(sum[:PrefixLen])
|
||||||
|
|
||||||
|
var buf [hashLen]byte
|
||||||
|
hashSufs, ok := s.loadHashSuffixes(hp)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(buf[:], hp[:])
|
||||||
|
for _, suf := range hashSufs {
|
||||||
|
// Slicing is safe here, because we make a copy.
|
||||||
|
|
||||||
|
// nolint:looppointer
|
||||||
|
copy(buf[PrefixLen:], suf[:])
|
||||||
|
if buf == sum {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset resets the hosts in the index using the domain names listed in
|
||||||
|
// hostnames, one domain name per line, and returns the total number of
|
||||||
|
// processed rules.
|
||||||
|
func (s *Storage) Reset(hostnames string) (n int, err error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
// Delete all elements without allocating a new map to save space and
|
||||||
|
// improve performance.
|
||||||
|
//
|
||||||
|
// This is optimized, see https://github.com/golang/go/issues/20138.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Use clear once golang/go#56351 is implemented.
|
||||||
|
for hp := range s.hashSuffixes {
|
||||||
|
delete(s.hashSuffixes, hp)
|
||||||
|
}
|
||||||
|
|
||||||
|
sc := bufio.NewScanner(strings.NewReader(hostnames))
|
||||||
|
for sc.Scan() {
|
||||||
|
host := sc.Text()
|
||||||
|
if len(host) == 0 || host[0] == '#' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
sum := sha256.Sum256([]byte(host))
|
||||||
|
hp := *(*Prefix)(sum[:PrefixLen])
|
||||||
|
|
||||||
|
// TODO(a.garipov): Here and everywhere, convert to array directly when
|
||||||
|
// proposal golang/go#46505 is implemented in Go 1.20.
|
||||||
|
suf := *(*suffix)(sum[PrefixLen:])
|
||||||
|
s.hashSuffixes[hp] = append(s.hashSuffixes[hp], suf)
|
||||||
|
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
|
||||||
|
err = sc.Err()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("scanning hosts: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadHashSuffixes returns hash suffixes for the given prefix. It is safe for
|
||||||
|
// concurrent use. sufs must not be modified.
|
||||||
|
func (s *Storage) loadHashSuffixes(hp Prefix) (sufs []suffix, ok bool) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
sufs, ok = s.hashSuffixes[hp]
|
||||||
|
|
||||||
|
return sufs, ok
|
||||||
|
}
|
142
internal/filter/hashstorage/hashstorage_test.go
Normal file
142
internal/filter/hashstorage/hashstorage_test.go
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
package hashstorage_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Common hostnames for tests.
|
||||||
|
const (
|
||||||
|
testHost = "porn.example"
|
||||||
|
otherHost = "otherporn.example"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStorage_Hashes(t *testing.T) {
|
||||||
|
s, err := hashstorage.New(testHost)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
h := sha256.Sum256([]byte(testHost))
|
||||||
|
want := []string{hex.EncodeToString(h[:])}
|
||||||
|
|
||||||
|
p := hashstorage.Prefix{h[0], h[1]}
|
||||||
|
got := s.Hashes([]hashstorage.Prefix{p})
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
|
||||||
|
wrong := s.Hashes([]hashstorage.Prefix{{}})
|
||||||
|
assert.Empty(t, wrong)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStorage_Matches(t *testing.T) {
|
||||||
|
s, err := hashstorage.New(testHost)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got := s.Matches(testHost)
|
||||||
|
assert.True(t, got)
|
||||||
|
|
||||||
|
got = s.Matches(otherHost)
|
||||||
|
assert.False(t, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStorage_Reset(t *testing.T) {
|
||||||
|
s, err := hashstorage.New(testHost)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
n, err := s.Reset(otherHost)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, 1, n)
|
||||||
|
|
||||||
|
h := sha256.Sum256([]byte(otherHost))
|
||||||
|
want := []string{hex.EncodeToString(h[:])}
|
||||||
|
|
||||||
|
p := hashstorage.Prefix{h[0], h[1]}
|
||||||
|
got := s.Hashes([]hashstorage.Prefix{p})
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
|
||||||
|
prevHash := sha256.Sum256([]byte(testHost))
|
||||||
|
prev := s.Hashes([]hashstorage.Prefix{{prevHash[0], prevHash[1]}})
|
||||||
|
assert.Empty(t, prev)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sinks for benchmarks.
|
||||||
|
var (
|
||||||
|
errSink error
|
||||||
|
strsSink []string
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkStorage_Hashes(b *testing.B) {
|
||||||
|
const N = 10_000
|
||||||
|
|
||||||
|
var hosts []string
|
||||||
|
for i := 0; i < N; i++ {
|
||||||
|
hosts = append(hosts, fmt.Sprintf("%d."+testHost, i))
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := hashstorage.New(strings.Join(hosts, "\n"))
|
||||||
|
require.NoError(b, err)
|
||||||
|
|
||||||
|
var hashPrefixes []hashstorage.Prefix
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
hashPrefixes = append(hashPrefixes, hashstorage.Prefix{hosts[i][0], hosts[i][1]})
|
||||||
|
}
|
||||||
|
|
||||||
|
for n := 1; n <= 4; n++ {
|
||||||
|
b.Run(strconv.FormatInt(int64(n), 10), func(b *testing.B) {
|
||||||
|
hps := hashPrefixes[:n]
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
strsSink = s.Hashes(hps)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
|
||||||
|
//
|
||||||
|
// goos: linux
|
||||||
|
// goarch: amd64
|
||||||
|
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage
|
||||||
|
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
|
||||||
|
// BenchmarkStorage_Hashes/1-16 29928834 41.76 ns/op 0 B/op 0 allocs/op
|
||||||
|
// BenchmarkStorage_Hashes/2-16 18693033 63.80 ns/op 0 B/op 0 allocs/op
|
||||||
|
// BenchmarkStorage_Hashes/3-16 13492526 92.22 ns/op 0 B/op 0 allocs/op
|
||||||
|
// BenchmarkStorage_Hashes/4-16 9542425 109.2 ns/op 0 B/op 0 allocs/op
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkStorage_ResetHosts(b *testing.B) {
|
||||||
|
const N = 1_000
|
||||||
|
|
||||||
|
var hosts []string
|
||||||
|
for i := 0; i < N; i++ {
|
||||||
|
hosts = append(hosts, fmt.Sprintf("%d."+testHost, i))
|
||||||
|
}
|
||||||
|
|
||||||
|
hostnames := strings.Join(hosts, "\n")
|
||||||
|
s, err := hashstorage.New(hostnames)
|
||||||
|
require.NoError(b, err)
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, errSink = s.Reset(hostnames)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(b, errSink)
|
||||||
|
|
||||||
|
// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
|
||||||
|
//
|
||||||
|
// goos: linux
|
||||||
|
// goarch: amd64
|
||||||
|
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage
|
||||||
|
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
|
||||||
|
// BenchmarkStorage_ResetHosts-16 2212 469343 ns/op 36224 B/op 1002 allocs/op
|
||||||
|
}
|
@ -1,91 +0,0 @@
|
|||||||
package filter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
var strsSink []string
|
|
||||||
|
|
||||||
func BenchmarkHashStorage_Hashes(b *testing.B) {
|
|
||||||
const N = 10_000
|
|
||||||
|
|
||||||
var hosts []string
|
|
||||||
for i := 0; i < N; i++ {
|
|
||||||
hosts = append(hosts, fmt.Sprintf("%d.porn.example.com", i))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Don't use a constructor, since we don't need the whole contents of the
|
|
||||||
// storage.
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): Think of a better way to do this.
|
|
||||||
hs := &HashStorage{
|
|
||||||
mu: &sync.RWMutex{},
|
|
||||||
hashSuffixes: map[hashPrefix][]hashSuffix{},
|
|
||||||
refr: &refreshableFilter{id: "test_filter"},
|
|
||||||
}
|
|
||||||
|
|
||||||
err := hs.resetHosts(strings.Join(hosts, "\n"))
|
|
||||||
require.NoError(b, err)
|
|
||||||
|
|
||||||
var hashPrefixes []hashPrefix
|
|
||||||
for i := 0; i < 4; i++ {
|
|
||||||
hashPrefixes = append(hashPrefixes, hashPrefix{hosts[i][0], hosts[i][1]})
|
|
||||||
}
|
|
||||||
|
|
||||||
for n := 1; n <= 4; n++ {
|
|
||||||
b.Run(strconv.FormatInt(int64(n), 10), func(b *testing.B) {
|
|
||||||
hps := hashPrefixes[:n]
|
|
||||||
|
|
||||||
b.ReportAllocs()
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
strsSink = hs.hashes(hps)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkHashStorage_resetHosts(b *testing.B) {
|
|
||||||
const N = 1_000
|
|
||||||
|
|
||||||
var hosts []string
|
|
||||||
for i := 0; i < N; i++ {
|
|
||||||
hosts = append(hosts, fmt.Sprintf("%d.porn.example.com", i))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Don't use a constructor, since we don't need the whole contents of the
|
|
||||||
// storage.
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): Think of a better way to do this.
|
|
||||||
hs := &HashStorage{
|
|
||||||
mu: &sync.RWMutex{},
|
|
||||||
hashSuffixes: map[hashPrefix][]hashSuffix{},
|
|
||||||
refr: &refreshableFilter{id: "test_filter"},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset them once to fill the initial map.
|
|
||||||
err := hs.resetHosts(strings.Join(hosts, "\n"))
|
|
||||||
require.NoError(b, err)
|
|
||||||
|
|
||||||
b.ReportAllocs()
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
err = hs.resetHosts(strings.Join(hosts, "\n"))
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NoError(b, err)
|
|
||||||
|
|
||||||
// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
|
|
||||||
//
|
|
||||||
// goos: linux
|
|
||||||
// goarch: amd64
|
|
||||||
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter
|
|
||||||
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
|
|
||||||
// BenchmarkHashStorage_resetHosts-16 1404 829505 ns/op 58289 B/op 1011 allocs/op
|
|
||||||
}
|
|
@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/mathutil"
|
||||||
"github.com/bluele/gcache"
|
"github.com/bluele/gcache"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -93,13 +94,9 @@ func DefaultKey(host string, qt dnsmsg.RRType, isAns bool) (k Key) {
|
|||||||
// Save on allocations by reusing a buffer.
|
// Save on allocations by reusing a buffer.
|
||||||
var buf [3]byte
|
var buf [3]byte
|
||||||
binary.LittleEndian.PutUint16(buf[:2], qt)
|
binary.LittleEndian.PutUint16(buf[:2], qt)
|
||||||
if isAns {
|
buf[2] = mathutil.BoolToNumber[byte](isAns)
|
||||||
buf[2] = 1
|
|
||||||
} else {
|
|
||||||
buf[2] = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _ = h.Write(buf[:3])
|
_, _ = h.Write(buf[:])
|
||||||
|
|
||||||
return Key(h.Sum64())
|
return Key(h.Sum64())
|
||||||
}
|
}
|
||||||
|
@ -45,9 +45,8 @@ type refreshableFilter struct {
|
|||||||
// typ is the type of this filter used for logging and error reporting.
|
// typ is the type of this filter used for logging and error reporting.
|
||||||
typ string
|
typ string
|
||||||
|
|
||||||
// refreshIvl is the refresh interval for this filter. It is also used to
|
// staleness is the time after which a file is considered stale.
|
||||||
// check if the cached file is fresh enough.
|
staleness time.Duration
|
||||||
refreshIvl time.Duration
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// refresh reloads the filter data. If acceptStale is true, refresh doesn't try
|
// refresh reloads the filter data. If acceptStale is true, refresh doesn't try
|
||||||
@ -118,7 +117,7 @@ func (f *refreshableFilter) refreshFromFile(
|
|||||||
return "", fmt.Errorf("reading filter file stat: %w", err)
|
return "", fmt.Errorf("reading filter file stat: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if mtime := fi.ModTime(); !mtime.Add(f.refreshIvl).After(time.Now()) {
|
if mtime := fi.ModTime(); !mtime.Add(f.staleness).After(time.Now()) {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -30,43 +30,43 @@ func TestRefreshableFilter_RefreshFromFile(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
cachePath string
|
cachePath string
|
||||||
wantText string
|
wantText string
|
||||||
refreshIvl time.Duration
|
staleness time.Duration
|
||||||
acceptStale bool
|
acceptStale bool
|
||||||
}{{
|
}{{
|
||||||
name: "no_file",
|
name: "no_file",
|
||||||
cachePath: "does_not_exist",
|
cachePath: "does_not_exist",
|
||||||
wantText: "",
|
wantText: "",
|
||||||
refreshIvl: 0,
|
staleness: 0,
|
||||||
acceptStale: true,
|
acceptStale: true,
|
||||||
}, {
|
}, {
|
||||||
name: "file",
|
name: "file",
|
||||||
cachePath: cachePath,
|
cachePath: cachePath,
|
||||||
wantText: defaultText,
|
wantText: defaultText,
|
||||||
refreshIvl: 0,
|
staleness: 0,
|
||||||
acceptStale: true,
|
acceptStale: true,
|
||||||
}, {
|
}, {
|
||||||
name: "file_stale",
|
name: "file_stale",
|
||||||
cachePath: cachePath,
|
cachePath: cachePath,
|
||||||
wantText: "",
|
wantText: "",
|
||||||
refreshIvl: -1 * time.Second,
|
staleness: -1 * time.Second,
|
||||||
acceptStale: false,
|
acceptStale: false,
|
||||||
}, {
|
}, {
|
||||||
name: "file_stale_accept",
|
name: "file_stale_accept",
|
||||||
cachePath: cachePath,
|
cachePath: cachePath,
|
||||||
wantText: defaultText,
|
wantText: defaultText,
|
||||||
refreshIvl: -1 * time.Second,
|
staleness: -1 * time.Second,
|
||||||
acceptStale: true,
|
acceptStale: true,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
f := &refreshableFilter{
|
f := &refreshableFilter{
|
||||||
http: nil,
|
http: nil,
|
||||||
url: nil,
|
url: nil,
|
||||||
id: "test_filter",
|
id: "test_filter",
|
||||||
cachePath: tc.cachePath,
|
cachePath: tc.cachePath,
|
||||||
typ: "test filter",
|
typ: "test filter",
|
||||||
refreshIvl: tc.refreshIvl,
|
staleness: tc.staleness,
|
||||||
}
|
}
|
||||||
|
|
||||||
var text string
|
var text string
|
||||||
@ -161,12 +161,12 @@ func TestRefreshableFilter_RefreshFromURL(t *testing.T) {
|
|||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
f := &refreshableFilter{
|
f := &refreshableFilter{
|
||||||
http: httpCli,
|
http: httpCli,
|
||||||
url: u,
|
url: u,
|
||||||
id: "test_filter",
|
id: "test_filter",
|
||||||
cachePath: tc.cachePath,
|
cachePath: tc.cachePath,
|
||||||
typ: "test filter",
|
typ: "test filter",
|
||||||
refreshIvl: testTimeout,
|
staleness: testTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
if tc.expectReq {
|
if tc.expectReq {
|
||||||
|
@ -72,13 +72,13 @@ func newRuleListFilter(
|
|||||||
mu: &sync.RWMutex{},
|
mu: &sync.RWMutex{},
|
||||||
refr: &refreshableFilter{
|
refr: &refreshableFilter{
|
||||||
http: agdhttp.NewClient(&agdhttp.ClientConfig{
|
http: agdhttp.NewClient(&agdhttp.ClientConfig{
|
||||||
Timeout: defaultTimeout,
|
Timeout: defaultFilterRefreshTimeout,
|
||||||
}),
|
}),
|
||||||
url: l.URL,
|
url: l.URL,
|
||||||
id: l.ID,
|
id: l.ID,
|
||||||
cachePath: filepath.Join(fileCacheDir, string(l.ID)),
|
cachePath: filepath.Join(fileCacheDir, string(l.ID)),
|
||||||
typ: "rule list",
|
typ: "rule list",
|
||||||
refreshIvl: l.RefreshIvl,
|
staleness: l.RefreshIvl,
|
||||||
},
|
},
|
||||||
urlFilterID: newURLFilterID(),
|
urlFilterID: newURLFilterID(),
|
||||||
}
|
}
|
||||||
|
@ -2,9 +2,13 @@ package filter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Safe Browsing TXT Record Server
|
// Safe Browsing TXT Record Server
|
||||||
@ -14,12 +18,12 @@ import (
|
|||||||
//
|
//
|
||||||
// TODO(a.garipov): Consider making an interface to simplify testing.
|
// TODO(a.garipov): Consider making an interface to simplify testing.
|
||||||
type SafeBrowsingServer struct {
|
type SafeBrowsingServer struct {
|
||||||
generalHashes *HashStorage
|
generalHashes *hashstorage.Storage
|
||||||
adultBlockingHashes *HashStorage
|
adultBlockingHashes *hashstorage.Storage
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSafeBrowsingServer returns a new safe browsing DNS server.
|
// NewSafeBrowsingServer returns a new safe browsing DNS server.
|
||||||
func NewSafeBrowsingServer(general, adultBlocking *HashStorage) (f *SafeBrowsingServer) {
|
func NewSafeBrowsingServer(general, adultBlocking *hashstorage.Storage) (f *SafeBrowsingServer) {
|
||||||
return &SafeBrowsingServer{
|
return &SafeBrowsingServer{
|
||||||
generalHashes: general,
|
generalHashes: general,
|
||||||
adultBlockingHashes: adultBlocking,
|
adultBlockingHashes: adultBlocking,
|
||||||
@ -48,8 +52,7 @@ func (srv *SafeBrowsingServer) Hashes(
|
|||||||
}
|
}
|
||||||
|
|
||||||
var prefixesStr string
|
var prefixesStr string
|
||||||
var strg *HashStorage
|
var strg *hashstorage.Storage
|
||||||
|
|
||||||
if strings.HasSuffix(host, GeneralTXTSuffix) {
|
if strings.HasSuffix(host, GeneralTXTSuffix) {
|
||||||
prefixesStr = host[:len(host)-len(GeneralTXTSuffix)]
|
prefixesStr = host[:len(host)-len(GeneralTXTSuffix)]
|
||||||
strg = srv.generalHashes
|
strg = srv.generalHashes
|
||||||
@ -67,5 +70,45 @@ func (srv *SafeBrowsingServer) Hashes(
|
|||||||
return nil, false, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return strg.hashes(hashPrefixes), true, nil
|
return strg.Hashes(hashPrefixes), true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// legacyPrefixEncLen is the encoded length of a legacy hash.
|
||||||
|
const legacyPrefixEncLen = 8
|
||||||
|
|
||||||
|
// hashPrefixesFromStr returns hash prefixes from a dot-separated string.
|
||||||
|
func hashPrefixesFromStr(prefixesStr string) (hashPrefixes []hashstorage.Prefix, err error) {
|
||||||
|
if prefixesStr == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
prefixSet := stringutil.NewSet()
|
||||||
|
prefixStrs := strings.Split(prefixesStr, ".")
|
||||||
|
for _, s := range prefixStrs {
|
||||||
|
if len(s) != hashstorage.PrefixEncLen {
|
||||||
|
// Some legacy clients send eight-character hashes instead of
|
||||||
|
// four-character ones. For now, remove the final four characters.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Either remove this crutch or support such
|
||||||
|
// prefixes better.
|
||||||
|
if len(s) == legacyPrefixEncLen {
|
||||||
|
s = s[:hashstorage.PrefixEncLen]
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("bad hash len for %q", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
prefixSet.Add(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
hashPrefixes = make([]hashstorage.Prefix, prefixSet.Len())
|
||||||
|
prefixStrs = prefixSet.Values()
|
||||||
|
for i, s := range prefixStrs {
|
||||||
|
_, err = hex.Decode(hashPrefixes[i][:], []byte(s))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("bad hash encoding for %q", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return hashPrefixes, nil
|
||||||
}
|
}
|
||||||
|
@ -4,17 +4,11 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@ -43,41 +37,10 @@ func TestSafeBrowsingServer(t *testing.T) {
|
|||||||
hashStrs[i] = hex.EncodeToString(sum[:])
|
hashStrs[i] = hex.EncodeToString(sum[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hash Storage
|
hashes, err := hashstorage.New(strings.Join(hosts, "\n"))
|
||||||
|
|
||||||
errColl := &agdtest.ErrorCollector{
|
|
||||||
OnCollect: func(_ context.Context, err error) {
|
|
||||||
panic("not implemented")
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
cacheDir := t.TempDir()
|
|
||||||
cachePath := filepath.Join(cacheDir, string(agd.FilterListIDSafeBrowsing))
|
|
||||||
err := os.WriteFile(cachePath, []byte(strings.Join(hosts, "\n")), 0o644)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{
|
|
||||||
URL: &url.URL{},
|
|
||||||
ErrColl: errColl,
|
|
||||||
ID: agd.FilterListIDSafeBrowsing,
|
|
||||||
CachePath: cachePath,
|
|
||||||
RefreshIvl: testRefreshIvl,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
err = hashes.Start()
|
|
||||||
require.NoError(t, err)
|
|
||||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
|
||||||
return hashes.Shutdown(ctx)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Give the storage some time to process the hashes.
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): Think of a less stupid way of doing this.
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
host string
|
host string
|
||||||
@ -90,14 +53,14 @@ func TestSafeBrowsingServer(t *testing.T) {
|
|||||||
wantMatched: false,
|
wantMatched: false,
|
||||||
}, {
|
}, {
|
||||||
name: "realistic",
|
name: "realistic",
|
||||||
host: hashStrs[realisticHostIdx][:filter.HashPrefixEncLen] + filter.GeneralTXTSuffix,
|
host: hashStrs[realisticHostIdx][:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix,
|
||||||
wantHashStrs: []string{
|
wantHashStrs: []string{
|
||||||
hashStrs[realisticHostIdx],
|
hashStrs[realisticHostIdx],
|
||||||
},
|
},
|
||||||
wantMatched: true,
|
wantMatched: true,
|
||||||
}, {
|
}, {
|
||||||
name: "same_prefix",
|
name: "same_prefix",
|
||||||
host: hashStrs[samePrefixHost1Idx][:filter.HashPrefixEncLen] + filter.GeneralTXTSuffix,
|
host: hashStrs[samePrefixHost1Idx][:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix,
|
||||||
wantHashStrs: []string{
|
wantHashStrs: []string{
|
||||||
hashStrs[samePrefixHost1Idx],
|
hashStrs[samePrefixHost1Idx],
|
||||||
hashStrs[samePrefixHost2Idx],
|
hashStrs[samePrefixHost2Idx],
|
||||||
|
@ -3,9 +3,7 @@ package filter_test
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
||||||
@ -20,45 +18,29 @@ import (
|
|||||||
|
|
||||||
func TestStorage_FilterFromContext_safeSearch(t *testing.T) {
|
func TestStorage_FilterFromContext_safeSearch(t *testing.T) {
|
||||||
numLookupIP := 0
|
numLookupIP := 0
|
||||||
onLookupIP := func(
|
resolver := &agdtest.Resolver{
|
||||||
_ context.Context,
|
OnLookupIP: func(
|
||||||
fam netutil.AddrFamily,
|
_ context.Context,
|
||||||
_ string,
|
fam netutil.AddrFamily,
|
||||||
) (ips []net.IP, err error) {
|
_ string,
|
||||||
numLookupIP++
|
) (ips []net.IP, err error) {
|
||||||
|
numLookupIP++
|
||||||
|
|
||||||
if fam == netutil.AddrFamilyIPv4 {
|
if fam == netutil.AddrFamilyIPv4 {
|
||||||
return []net.IP{safeSearchIPRespIP4}, nil
|
return []net.IP{safeSearchIPRespIP4}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return []net.IP{safeSearchIPRespIP6}, nil
|
return []net.IP{safeSearchIPRespIP6}, nil
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpFile, err := os.CreateTemp(t.TempDir(), "")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = tmpFile.Write([]byte("bad.example.com\n"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
|
|
||||||
|
|
||||||
hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{
|
|
||||||
CachePath: tmpFile.Name(),
|
|
||||||
RefreshIvl: 1 * time.Hour,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
c := prepareConf(t)
|
c := prepareConf(t)
|
||||||
|
|
||||||
c.SafeBrowsing.Hashes = hashes
|
|
||||||
c.AdultBlocking.Hashes = hashes
|
|
||||||
|
|
||||||
c.ErrColl = &agdtest.ErrorCollector{
|
c.ErrColl = &agdtest.ErrorCollector{
|
||||||
OnCollect: func(_ context.Context, err error) { panic("not implemented") },
|
OnCollect: func(_ context.Context, err error) { panic("not implemented") },
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Resolver = &agdtest.Resolver{
|
c.Resolver = resolver
|
||||||
OnLookupIP: onLookupIP,
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := filter.NewDefaultStorage(c)
|
s, err := filter.NewDefaultStorage(c)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -80,13 +62,13 @@ func TestStorage_FilterFromContext_safeSearch(t *testing.T) {
|
|||||||
host: safeSearchIPHost,
|
host: safeSearchIPHost,
|
||||||
wantIP: safeSearchIPRespIP4,
|
wantIP: safeSearchIPRespIP4,
|
||||||
rrtype: dns.TypeA,
|
rrtype: dns.TypeA,
|
||||||
wantLookups: 0,
|
wantLookups: 1,
|
||||||
}, {
|
}, {
|
||||||
name: "ip6",
|
name: "ip6",
|
||||||
host: safeSearchIPHost,
|
host: safeSearchIPHost,
|
||||||
wantIP: nil,
|
wantIP: safeSearchIPRespIP6,
|
||||||
rrtype: dns.TypeAAAA,
|
rrtype: dns.TypeAAAA,
|
||||||
wantLookups: 0,
|
wantLookups: 1,
|
||||||
}, {
|
}, {
|
||||||
name: "host_ip4",
|
name: "host_ip4",
|
||||||
host: safeSearchHost,
|
host: safeSearchHost,
|
||||||
|
@ -48,7 +48,7 @@ func newServiceBlocker(indexURL *url.URL, errColl agd.ErrorCollector) (b *servic
|
|||||||
return &serviceBlocker{
|
return &serviceBlocker{
|
||||||
url: indexURL,
|
url: indexURL,
|
||||||
http: agdhttp.NewClient(&agdhttp.ClientConfig{
|
http: agdhttp.NewClient(&agdhttp.ClientConfig{
|
||||||
Timeout: defaultTimeout,
|
Timeout: defaultFilterRefreshTimeout,
|
||||||
}),
|
}),
|
||||||
mu: &sync.RWMutex{},
|
mu: &sync.RWMutex{},
|
||||||
errColl: errColl,
|
errColl: errColl,
|
||||||
|
@ -15,7 +15,6 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/timeutil"
|
|
||||||
"github.com/bluele/gcache"
|
"github.com/bluele/gcache"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
)
|
)
|
||||||
@ -55,10 +54,10 @@ type DefaultStorage struct {
|
|||||||
services *serviceBlocker
|
services *serviceBlocker
|
||||||
|
|
||||||
// safeBrowsing is the general safe browsing filter.
|
// safeBrowsing is the general safe browsing filter.
|
||||||
safeBrowsing *hashPrefixFilter
|
safeBrowsing *HashPrefix
|
||||||
|
|
||||||
// adultBlocking is the adult content blocking safe browsing filter.
|
// adultBlocking is the adult content blocking safe browsing filter.
|
||||||
adultBlocking *hashPrefixFilter
|
adultBlocking *HashPrefix
|
||||||
|
|
||||||
// genSafeSearch is the general safe search filter.
|
// genSafeSearch is the general safe search filter.
|
||||||
genSafeSearch *safeSearch
|
genSafeSearch *safeSearch
|
||||||
@ -111,11 +110,11 @@ type DefaultStorageConfig struct {
|
|||||||
|
|
||||||
// SafeBrowsing is the configuration for the default safe browsing filter.
|
// SafeBrowsing is the configuration for the default safe browsing filter.
|
||||||
// It must not be nil.
|
// It must not be nil.
|
||||||
SafeBrowsing *HashPrefixConfig
|
SafeBrowsing *HashPrefix
|
||||||
|
|
||||||
// AdultBlocking is the configuration for the adult content blocking safe
|
// AdultBlocking is the configuration for the adult content blocking safe
|
||||||
// browsing filter. It must not be nil.
|
// browsing filter. It must not be nil.
|
||||||
AdultBlocking *HashPrefixConfig
|
AdultBlocking *HashPrefix
|
||||||
|
|
||||||
// Now is a function that returns current time.
|
// Now is a function that returns current time.
|
||||||
Now func() (now time.Time)
|
Now func() (now time.Time)
|
||||||
@ -156,25 +155,8 @@ type DefaultStorageConfig struct {
|
|||||||
|
|
||||||
// NewDefaultStorage returns a new filter storage. c must not be nil.
|
// NewDefaultStorage returns a new filter storage. c must not be nil.
|
||||||
func NewDefaultStorage(c *DefaultStorageConfig) (s *DefaultStorage, err error) {
|
func NewDefaultStorage(c *DefaultStorageConfig) (s *DefaultStorage, err error) {
|
||||||
// TODO(ameshkov): Consider making configurable.
|
|
||||||
resolver := agdnet.NewCachingResolver(c.Resolver, 1*timeutil.Day)
|
|
||||||
|
|
||||||
safeBrowsing := newHashPrefixFilter(
|
|
||||||
c.SafeBrowsing,
|
|
||||||
resolver,
|
|
||||||
c.ErrColl,
|
|
||||||
agd.FilterListIDSafeBrowsing,
|
|
||||||
)
|
|
||||||
|
|
||||||
adultBlocking := newHashPrefixFilter(
|
|
||||||
c.AdultBlocking,
|
|
||||||
resolver,
|
|
||||||
c.ErrColl,
|
|
||||||
agd.FilterListIDAdultBlocking,
|
|
||||||
)
|
|
||||||
|
|
||||||
genSafeSearch := newSafeSearch(&safeSearchConfig{
|
genSafeSearch := newSafeSearch(&safeSearchConfig{
|
||||||
resolver: resolver,
|
resolver: c.Resolver,
|
||||||
errColl: c.ErrColl,
|
errColl: c.ErrColl,
|
||||||
list: &agd.FilterList{
|
list: &agd.FilterList{
|
||||||
URL: c.GeneralSafeSearchRulesURL,
|
URL: c.GeneralSafeSearchRulesURL,
|
||||||
@ -187,7 +169,7 @@ func NewDefaultStorage(c *DefaultStorageConfig) (s *DefaultStorage, err error) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
ytSafeSearch := newSafeSearch(&safeSearchConfig{
|
ytSafeSearch := newSafeSearch(&safeSearchConfig{
|
||||||
resolver: resolver,
|
resolver: c.Resolver,
|
||||||
errColl: c.ErrColl,
|
errColl: c.ErrColl,
|
||||||
list: &agd.FilterList{
|
list: &agd.FilterList{
|
||||||
URL: c.YoutubeSafeSearchRulesURL,
|
URL: c.YoutubeSafeSearchRulesURL,
|
||||||
@ -203,11 +185,11 @@ func NewDefaultStorage(c *DefaultStorageConfig) (s *DefaultStorage, err error) {
|
|||||||
mu: &sync.RWMutex{},
|
mu: &sync.RWMutex{},
|
||||||
url: c.FilterIndexURL,
|
url: c.FilterIndexURL,
|
||||||
http: agdhttp.NewClient(&agdhttp.ClientConfig{
|
http: agdhttp.NewClient(&agdhttp.ClientConfig{
|
||||||
Timeout: defaultTimeout,
|
Timeout: defaultFilterRefreshTimeout,
|
||||||
}),
|
}),
|
||||||
services: newServiceBlocker(c.BlockedServiceIndexURL, c.ErrColl),
|
services: newServiceBlocker(c.BlockedServiceIndexURL, c.ErrColl),
|
||||||
safeBrowsing: safeBrowsing,
|
safeBrowsing: c.SafeBrowsing,
|
||||||
adultBlocking: adultBlocking,
|
adultBlocking: c.AdultBlocking,
|
||||||
genSafeSearch: genSafeSearch,
|
genSafeSearch: genSafeSearch,
|
||||||
ytSafeSearch: ytSafeSearch,
|
ytSafeSearch: ytSafeSearch,
|
||||||
now: c.Now,
|
now: c.Now,
|
||||||
@ -322,7 +304,7 @@ func (s *DefaultStorage) pcBySchedule(sch *agd.ParentalProtectionSchedule) (ok b
|
|||||||
func (s *DefaultStorage) safeBrowsingForProfile(
|
func (s *DefaultStorage) safeBrowsingForProfile(
|
||||||
p *agd.Profile,
|
p *agd.Profile,
|
||||||
parentalEnabled bool,
|
parentalEnabled bool,
|
||||||
) (safeBrowsing, adultBlocking *hashPrefixFilter) {
|
) (safeBrowsing, adultBlocking *HashPrefix) {
|
||||||
if p.SafeBrowsingEnabled {
|
if p.SafeBrowsingEnabled {
|
||||||
safeBrowsing = s.safeBrowsing
|
safeBrowsing = s.safeBrowsing
|
||||||
}
|
}
|
||||||
@ -359,7 +341,7 @@ func (s *DefaultStorage) safeSearchForProfile(
|
|||||||
// in the filtering group. g must not be nil.
|
// in the filtering group. g must not be nil.
|
||||||
func (s *DefaultStorage) safeBrowsingForGroup(
|
func (s *DefaultStorage) safeBrowsingForGroup(
|
||||||
g *agd.FilteringGroup,
|
g *agd.FilteringGroup,
|
||||||
) (safeBrowsing, adultBlocking *hashPrefixFilter) {
|
) (safeBrowsing, adultBlocking *HashPrefix) {
|
||||||
if g.SafeBrowsingEnabled {
|
if g.SafeBrowsingEnabled {
|
||||||
safeBrowsing = s.safeBrowsing
|
safeBrowsing = s.safeBrowsing
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@ -124,34 +125,11 @@ func TestStorage_FilterFromContext(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStorage_FilterFromContext_customAllow(t *testing.T) {
|
func TestStorage_FilterFromContext_customAllow(t *testing.T) {
|
||||||
// Initialize the hashes file and use it with the storage.
|
errColl := &agdtest.ErrorCollector{
|
||||||
tmpFile, err := os.CreateTemp(t.TempDir(), "")
|
|
||||||
require.NoError(t, err)
|
|
||||||
testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
|
|
||||||
|
|
||||||
_, err = io.WriteString(tmpFile, safeBrowsingHost+"\n")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{
|
|
||||||
CachePath: tmpFile.Name(),
|
|
||||||
RefreshIvl: 1 * time.Hour,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
c := prepareConf(t)
|
|
||||||
|
|
||||||
c.SafeBrowsing = &filter.HashPrefixConfig{
|
|
||||||
Hashes: hashes,
|
|
||||||
ReplacementHost: safeBrowsingSafeHost,
|
|
||||||
CacheTTL: 10 * time.Second,
|
|
||||||
CacheSize: 100,
|
|
||||||
}
|
|
||||||
|
|
||||||
c.ErrColl = &agdtest.ErrorCollector{
|
|
||||||
OnCollect: func(_ context.Context, err error) { panic("not implemented") },
|
OnCollect: func(_ context.Context, err error) { panic("not implemented") },
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Resolver = &agdtest.Resolver{
|
resolver := &agdtest.Resolver{
|
||||||
OnLookupIP: func(
|
OnLookupIP: func(
|
||||||
_ context.Context,
|
_ context.Context,
|
||||||
_ netutil.AddrFamily,
|
_ netutil.AddrFamily,
|
||||||
@ -161,6 +139,35 @@ func TestStorage_FilterFromContext_customAllow(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize the hashes file and use it with the storage.
|
||||||
|
tmpFile, err := os.CreateTemp(t.TempDir(), "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
|
||||||
|
|
||||||
|
_, err = io.WriteString(tmpFile, safeBrowsingHost+"\n")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
hashes, err := hashstorage.New(safeBrowsingHost)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
c := prepareConf(t)
|
||||||
|
|
||||||
|
c.SafeBrowsing, err = filter.NewHashPrefix(&filter.HashPrefixConfig{
|
||||||
|
Hashes: hashes,
|
||||||
|
ErrColl: errColl,
|
||||||
|
Resolver: resolver,
|
||||||
|
ID: agd.FilterListIDSafeBrowsing,
|
||||||
|
CachePath: tmpFile.Name(),
|
||||||
|
ReplacementHost: safeBrowsingSafeHost,
|
||||||
|
Staleness: 1 * time.Hour,
|
||||||
|
CacheTTL: 10 * time.Second,
|
||||||
|
CacheSize: 100,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
c.ErrColl = errColl
|
||||||
|
c.Resolver = resolver
|
||||||
|
|
||||||
s, err := filter.NewDefaultStorage(c)
|
s, err := filter.NewDefaultStorage(c)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@ -211,39 +218,11 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) {
|
|||||||
// parental protection from 11:00:00 until 12:59:59.
|
// parental protection from 11:00:00 until 12:59:59.
|
||||||
nowTime := time.Date(2021, 1, 1, 12, 0, 0, 0, time.UTC)
|
nowTime := time.Date(2021, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
// Initialize the hashes file and use it with the storage.
|
errColl := &agdtest.ErrorCollector{
|
||||||
tmpFile, err := os.CreateTemp(t.TempDir(), "")
|
|
||||||
require.NoError(t, err)
|
|
||||||
testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
|
|
||||||
|
|
||||||
_, err = io.WriteString(tmpFile, safeBrowsingHost+"\n")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{
|
|
||||||
CachePath: tmpFile.Name(),
|
|
||||||
RefreshIvl: 1 * time.Hour,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
c := prepareConf(t)
|
|
||||||
|
|
||||||
// Use AdultBlocking, because SafeBrowsing is NOT affected by the schedule.
|
|
||||||
c.AdultBlocking = &filter.HashPrefixConfig{
|
|
||||||
Hashes: hashes,
|
|
||||||
ReplacementHost: safeBrowsingSafeHost,
|
|
||||||
CacheTTL: 10 * time.Second,
|
|
||||||
CacheSize: 100,
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Now = func() (t time.Time) {
|
|
||||||
return nowTime
|
|
||||||
}
|
|
||||||
|
|
||||||
c.ErrColl = &agdtest.ErrorCollector{
|
|
||||||
OnCollect: func(_ context.Context, err error) { panic("not implemented") },
|
OnCollect: func(_ context.Context, err error) { panic("not implemented") },
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Resolver = &agdtest.Resolver{
|
resolver := &agdtest.Resolver{
|
||||||
OnLookupIP: func(
|
OnLookupIP: func(
|
||||||
_ context.Context,
|
_ context.Context,
|
||||||
_ netutil.AddrFamily,
|
_ netutil.AddrFamily,
|
||||||
@ -253,6 +232,40 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize the hashes file and use it with the storage.
|
||||||
|
tmpFile, err := os.CreateTemp(t.TempDir(), "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
|
||||||
|
|
||||||
|
_, err = io.WriteString(tmpFile, safeBrowsingHost+"\n")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
hashes, err := hashstorage.New(safeBrowsingHost)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
c := prepareConf(t)
|
||||||
|
|
||||||
|
// Use AdultBlocking, because SafeBrowsing is NOT affected by the schedule.
|
||||||
|
c.AdultBlocking, err = filter.NewHashPrefix(&filter.HashPrefixConfig{
|
||||||
|
Hashes: hashes,
|
||||||
|
ErrColl: errColl,
|
||||||
|
Resolver: resolver,
|
||||||
|
ID: agd.FilterListIDAdultBlocking,
|
||||||
|
CachePath: tmpFile.Name(),
|
||||||
|
ReplacementHost: safeBrowsingSafeHost,
|
||||||
|
Staleness: 1 * time.Hour,
|
||||||
|
CacheTTL: 10 * time.Second,
|
||||||
|
CacheSize: 100,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
c.Now = func() (t time.Time) {
|
||||||
|
return nowTime
|
||||||
|
}
|
||||||
|
|
||||||
|
c.ErrColl = errColl
|
||||||
|
c.Resolver = resolver
|
||||||
|
|
||||||
s, err := filter.NewDefaultStorage(c)
|
s, err := filter.NewDefaultStorage(c)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -20,9 +20,6 @@ import (
|
|||||||
|
|
||||||
// FileConfig is the file-based GeoIP configuration structure.
|
// FileConfig is the file-based GeoIP configuration structure.
|
||||||
type FileConfig struct {
|
type FileConfig struct {
|
||||||
// ErrColl is the error collector that is used to report errors.
|
|
||||||
ErrColl agd.ErrorCollector
|
|
||||||
|
|
||||||
// ASNPath is the path to the GeoIP database of ASNs.
|
// ASNPath is the path to the GeoIP database of ASNs.
|
||||||
ASNPath string
|
ASNPath string
|
||||||
|
|
||||||
@ -39,8 +36,6 @@ type FileConfig struct {
|
|||||||
|
|
||||||
// File is a file implementation of [geoip.Interface].
|
// File is a file implementation of [geoip.Interface].
|
||||||
type File struct {
|
type File struct {
|
||||||
errColl agd.ErrorCollector
|
|
||||||
|
|
||||||
// mu protects asn, country, country subnet maps, and caches against
|
// mu protects asn, country, country subnet maps, and caches against
|
||||||
// simultaneous access during a refresh.
|
// simultaneous access during a refresh.
|
||||||
mu *sync.RWMutex
|
mu *sync.RWMutex
|
||||||
@ -77,8 +72,6 @@ type asnSubnets map[agd.ASN]netip.Prefix
|
|||||||
// NewFile returns a new GeoIP database that reads information from a file.
|
// NewFile returns a new GeoIP database that reads information from a file.
|
||||||
func NewFile(c *FileConfig) (f *File, err error) {
|
func NewFile(c *FileConfig) (f *File, err error) {
|
||||||
f = &File{
|
f = &File{
|
||||||
errColl: c.ErrColl,
|
|
||||||
|
|
||||||
mu: &sync.RWMutex{},
|
mu: &sync.RWMutex{},
|
||||||
|
|
||||||
asnPath: c.ASNPath,
|
asnPath: c.ASNPath,
|
||||||
|
@ -1,12 +1,10 @@
|
|||||||
package geoip_test
|
package geoip_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
|
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@ -14,12 +12,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestFile_Data(t *testing.T) {
|
func TestFile_Data(t *testing.T) {
|
||||||
var ec agd.ErrorCollector = &agdtest.ErrorCollector{
|
|
||||||
OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
|
|
||||||
}
|
|
||||||
|
|
||||||
conf := &geoip.FileConfig{
|
conf := &geoip.FileConfig{
|
||||||
ErrColl: ec,
|
|
||||||
ASNPath: asnPath,
|
ASNPath: asnPath,
|
||||||
CountryPath: countryPath,
|
CountryPath: countryPath,
|
||||||
HostCacheSize: 0,
|
HostCacheSize: 0,
|
||||||
@ -42,12 +35,7 @@ func TestFile_Data(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFile_Data_hostCache(t *testing.T) {
|
func TestFile_Data_hostCache(t *testing.T) {
|
||||||
var ec agd.ErrorCollector = &agdtest.ErrorCollector{
|
|
||||||
OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
|
|
||||||
}
|
|
||||||
|
|
||||||
conf := &geoip.FileConfig{
|
conf := &geoip.FileConfig{
|
||||||
ErrColl: ec,
|
|
||||||
ASNPath: asnPath,
|
ASNPath: asnPath,
|
||||||
CountryPath: countryPath,
|
CountryPath: countryPath,
|
||||||
HostCacheSize: 1,
|
HostCacheSize: 1,
|
||||||
@ -74,12 +62,7 @@ func TestFile_Data_hostCache(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFile_SubnetByLocation(t *testing.T) {
|
func TestFile_SubnetByLocation(t *testing.T) {
|
||||||
var ec agd.ErrorCollector = &agdtest.ErrorCollector{
|
|
||||||
OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
|
|
||||||
}
|
|
||||||
|
|
||||||
conf := &geoip.FileConfig{
|
conf := &geoip.FileConfig{
|
||||||
ErrColl: ec,
|
|
||||||
ASNPath: asnPath,
|
ASNPath: asnPath,
|
||||||
CountryPath: countryPath,
|
CountryPath: countryPath,
|
||||||
HostCacheSize: 0,
|
HostCacheSize: 0,
|
||||||
@ -106,12 +89,7 @@ var locSink *agd.Location
|
|||||||
var errSink error
|
var errSink error
|
||||||
|
|
||||||
func BenchmarkFile_Data(b *testing.B) {
|
func BenchmarkFile_Data(b *testing.B) {
|
||||||
var ec agd.ErrorCollector = &agdtest.ErrorCollector{
|
|
||||||
OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
|
|
||||||
}
|
|
||||||
|
|
||||||
conf := &geoip.FileConfig{
|
conf := &geoip.FileConfig{
|
||||||
ErrColl: ec,
|
|
||||||
ASNPath: asnPath,
|
ASNPath: asnPath,
|
||||||
CountryPath: countryPath,
|
CountryPath: countryPath,
|
||||||
HostCacheSize: 0,
|
HostCacheSize: 0,
|
||||||
@ -162,12 +140,7 @@ func BenchmarkFile_Data(b *testing.B) {
|
|||||||
var fileSink *geoip.File
|
var fileSink *geoip.File
|
||||||
|
|
||||||
func BenchmarkNewFile(b *testing.B) {
|
func BenchmarkNewFile(b *testing.B) {
|
||||||
var ec agd.ErrorCollector = &agdtest.ErrorCollector{
|
|
||||||
OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
|
|
||||||
}
|
|
||||||
|
|
||||||
conf := &geoip.FileConfig{
|
conf := &geoip.FileConfig{
|
||||||
ErrColl: ec,
|
|
||||||
ASNPath: asnPath,
|
ASNPath: asnPath,
|
||||||
CountryPath: countryPath,
|
CountryPath: countryPath,
|
||||||
HostCacheSize: 0,
|
HostCacheSize: 0,
|
||||||
|
@ -11,7 +11,7 @@ var DNSSvcRequestByCountryTotal = promauto.NewCounterVec(prometheus.CounterOpts{
|
|||||||
Name: "request_per_country_total",
|
Name: "request_per_country_total",
|
||||||
Namespace: namespace,
|
Namespace: namespace,
|
||||||
Subsystem: subsystemDNSSvc,
|
Subsystem: subsystemDNSSvc,
|
||||||
Help: "The number of filtered DNS requests labeled by country and continent.",
|
Help: "The number of processed DNS requests labeled by country and continent.",
|
||||||
}, []string{"continent", "country"})
|
}, []string{"continent", "country"})
|
||||||
|
|
||||||
// DNSSvcRequestByASNTotal is a counter with the total number of queries
|
// DNSSvcRequestByASNTotal is a counter with the total number of queries
|
||||||
@ -20,13 +20,14 @@ var DNSSvcRequestByASNTotal = promauto.NewCounterVec(prometheus.CounterOpts{
|
|||||||
Name: "request_per_asn_total",
|
Name: "request_per_asn_total",
|
||||||
Namespace: namespace,
|
Namespace: namespace,
|
||||||
Subsystem: subsystemDNSSvc,
|
Subsystem: subsystemDNSSvc,
|
||||||
Help: "The number of filtered DNS requests labeled by country and ASN.",
|
Help: "The number of processed DNS requests labeled by country and ASN.",
|
||||||
}, []string{"country", "asn"})
|
}, []string{"country", "asn"})
|
||||||
|
|
||||||
// DNSSvcRequestByFilterTotal is a counter with the total number of queries
|
// DNSSvcRequestByFilterTotal is a counter with the total number of queries
|
||||||
// processed labeled by filter. "filter" contains the ID of the filter list
|
// processed labeled by filter. Processed could mean that the request was
|
||||||
// applied. "anonymous" is "0" if the request is from a AdGuard DNS customer,
|
// blocked or unblocked by a rule from that filter list. "filter" contains
|
||||||
// otherwise it is "1".
|
// the ID of the filter list applied. "anonymous" is "0" if the request is
|
||||||
|
// from a AdGuard DNS customer, otherwise it is "1".
|
||||||
var DNSSvcRequestByFilterTotal = promauto.NewCounterVec(prometheus.CounterOpts{
|
var DNSSvcRequestByFilterTotal = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||||
Name: "request_per_filter_total",
|
Name: "request_per_filter_total",
|
||||||
Namespace: namespace,
|
Namespace: namespace,
|
||||||
|
@ -26,6 +26,8 @@ const (
|
|||||||
subsystemQueryLog = "querylog"
|
subsystemQueryLog = "querylog"
|
||||||
subsystemRuleStat = "rulestat"
|
subsystemRuleStat = "rulestat"
|
||||||
subsystemTLS = "tls"
|
subsystemTLS = "tls"
|
||||||
|
subsystemResearch = "research"
|
||||||
|
subsystemWebSvc = "websvc"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SetUpGauge signals that the server has been started.
|
// SetUpGauge signals that the server has been started.
|
||||||
|
61
internal/metrics/research.go
Normal file
61
internal/metrics/research.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ResearchRequestsPerCountryTotal counts the total number of queries per
|
||||||
|
// country from anonymous users.
|
||||||
|
var ResearchRequestsPerCountryTotal = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Name: "requests_per_country_total",
|
||||||
|
Namespace: namespace,
|
||||||
|
Subsystem: subsystemResearch,
|
||||||
|
Help: "The total number of DNS queries per country from anonymous users.",
|
||||||
|
}, []string{"country"})
|
||||||
|
|
||||||
|
// ResearchBlockedRequestsPerCountryTotal counts the number of blocked queries
|
||||||
|
// per country from anonymous users.
|
||||||
|
var ResearchBlockedRequestsPerCountryTotal = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Name: "blocked_per_country_total",
|
||||||
|
Namespace: namespace,
|
||||||
|
Subsystem: subsystemResearch,
|
||||||
|
Help: "The number of blocked DNS queries per country from anonymous users.",
|
||||||
|
}, []string{"filter", "country"})
|
||||||
|
|
||||||
|
// ReportResearchMetrics reports metrics to prometheus that we may need to
|
||||||
|
// conduct researches.
|
||||||
|
//
|
||||||
|
// TODO(ameshkov): use [agd.Profile] arg when recursive dependency is resolved.
|
||||||
|
func ReportResearchMetrics(
|
||||||
|
anonymous bool,
|
||||||
|
filteringEnabled bool,
|
||||||
|
asn string,
|
||||||
|
ctry string,
|
||||||
|
filterID string,
|
||||||
|
blocked bool,
|
||||||
|
) {
|
||||||
|
// The current research metrics only count queries that come to public
|
||||||
|
// DNS servers where filtering is enabled.
|
||||||
|
if !filteringEnabled || !anonymous {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ignore AdGuard ASN specifically in order to avoid counting queries that
|
||||||
|
// come from the monitoring. This part is ugly, but since these metrics
|
||||||
|
// are a one-time deal, this is acceptable.
|
||||||
|
//
|
||||||
|
// TODO(ameshkov): think of a better way later if we need to do that again.
|
||||||
|
if asn == "212772" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if blocked {
|
||||||
|
ResearchBlockedRequestsPerCountryTotal.WithLabelValues(
|
||||||
|
filterID,
|
||||||
|
ctry,
|
||||||
|
).Inc()
|
||||||
|
}
|
||||||
|
|
||||||
|
ResearchRequestsPerCountryTotal.WithLabelValues(ctry).Inc()
|
||||||
|
}
|
@ -86,7 +86,7 @@ func (c *userCounter) record(now time.Time, ip netip.Addr, syncUpdate bool) {
|
|||||||
prevMinuteCounter := c.currentMinuteCounter
|
prevMinuteCounter := c.currentMinuteCounter
|
||||||
|
|
||||||
c.currentMinute = minuteOfTheDay
|
c.currentMinute = minuteOfTheDay
|
||||||
c.currentMinuteCounter = hyperloglog.New()
|
c.currentMinuteCounter = newHyperLogLog()
|
||||||
|
|
||||||
// If this is the first iteration and prevMinute is -1, don't update the
|
// If this is the first iteration and prevMinute is -1, don't update the
|
||||||
// counters, since there are none.
|
// counters, since there are none.
|
||||||
@ -123,7 +123,7 @@ func (c *userCounter) updateCounters(prevMinute int, prevCounter *hyperloglog.Sk
|
|||||||
// estimate uses HyperLogLog counters to estimate the hourly and daily users
|
// estimate uses HyperLogLog counters to estimate the hourly and daily users
|
||||||
// count, starting with the minute of the day m.
|
// count, starting with the minute of the day m.
|
||||||
func (c *userCounter) estimate(m int) (hourly, daily uint64) {
|
func (c *userCounter) estimate(m int) (hourly, daily uint64) {
|
||||||
hourlyCounter, dailyCounter := hyperloglog.New(), hyperloglog.New()
|
hourlyCounter, dailyCounter := newHyperLogLog(), newHyperLogLog()
|
||||||
|
|
||||||
// Go through all minutes in a day while decreasing the current minute m.
|
// Go through all minutes in a day while decreasing the current minute m.
|
||||||
// Decreasing m, as opposed to increasing it or using i as the minute, is
|
// Decreasing m, as opposed to increasing it or using i as the minute, is
|
||||||
@ -168,6 +168,11 @@ func decrMod(n, m int) (res int) {
|
|||||||
return n - 1
|
return n - 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// newHyperLogLog creates a new instance of hyperloglog.Sketch.
|
||||||
|
func newHyperLogLog() (sk *hyperloglog.Sketch) {
|
||||||
|
return hyperloglog.New16()
|
||||||
|
}
|
||||||
|
|
||||||
// defaultUserCounter is the main user statistics counter.
|
// defaultUserCounter is the main user statistics counter.
|
||||||
var defaultUserCounter = newUserCounter()
|
var defaultUserCounter = newUserCounter()
|
||||||
|
|
||||||
|
69
internal/metrics/websvc.go
Normal file
69
internal/metrics/websvc.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
webSvcRequestsTotal = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Name: "websvc_requests_total",
|
||||||
|
Namespace: namespace,
|
||||||
|
Subsystem: subsystemWebSvc,
|
||||||
|
Help: "The number of DNS requests for websvc.",
|
||||||
|
}, []string{"kind"})
|
||||||
|
|
||||||
|
// WebSvcError404RequestsTotal is a counter with total number of
|
||||||
|
// requests with error 404.
|
||||||
|
WebSvcError404RequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
|
||||||
|
"kind": "error404",
|
||||||
|
})
|
||||||
|
|
||||||
|
// WebSvcError500RequestsTotal is a counter with total number of
|
||||||
|
// requests with error 500.
|
||||||
|
WebSvcError500RequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
|
||||||
|
"kind": "error500",
|
||||||
|
})
|
||||||
|
|
||||||
|
// WebSvcStaticContentRequestsTotal is a counter with total number of
|
||||||
|
// requests for static content.
|
||||||
|
WebSvcStaticContentRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
|
||||||
|
"kind": "static_content",
|
||||||
|
})
|
||||||
|
|
||||||
|
// WebSvcDNSCheckTestRequestsTotal is a counter with total number of
|
||||||
|
// requests for dnscheck_test.
|
||||||
|
WebSvcDNSCheckTestRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
|
||||||
|
"kind": "dnscheck_test",
|
||||||
|
})
|
||||||
|
|
||||||
|
// WebSvcRobotsTxtRequestsTotal is a counter with total number of
|
||||||
|
// requests for robots_txt.
|
||||||
|
WebSvcRobotsTxtRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
|
||||||
|
"kind": "robots_txt",
|
||||||
|
})
|
||||||
|
|
||||||
|
// WebSvcRootRedirectRequestsTotal is a counter with total number of
|
||||||
|
// root redirected requests.
|
||||||
|
WebSvcRootRedirectRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
|
||||||
|
"kind": "root_redirect",
|
||||||
|
})
|
||||||
|
|
||||||
|
// WebSvcLinkedIPProxyRequestsTotal is a counter with total number of
|
||||||
|
// requests with linked ip.
|
||||||
|
WebSvcLinkedIPProxyRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
|
||||||
|
"kind": "linkip",
|
||||||
|
})
|
||||||
|
|
||||||
|
// WebSvcAdultBlockingPageRequestsTotal is a counter with total number
|
||||||
|
// of requests for adult blocking page.
|
||||||
|
WebSvcAdultBlockingPageRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
|
||||||
|
"kind": "adult_blocking_page",
|
||||||
|
})
|
||||||
|
|
||||||
|
// WebSvcSafeBrowsingPageRequestsTotal is a counter with total number
|
||||||
|
// of requests for safe browsing page.
|
||||||
|
WebSvcSafeBrowsingPageRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
|
||||||
|
"kind": "safe_browsing_page",
|
||||||
|
})
|
||||||
|
)
|
@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
|
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/mathutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// FileSystemConfig is the configuration of the file system query log.
|
// FileSystemConfig is the configuration of the file system query log.
|
||||||
@ -71,11 +72,6 @@ func (l *FileSystem) Write(_ context.Context, e *Entry) (err error) {
|
|||||||
metrics.QueryLogItemsCount.Inc()
|
metrics.QueryLogItemsCount.Inc()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var dnssec uint8 = 0
|
|
||||||
if e.DNSSEC {
|
|
||||||
dnssec = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
entBuf := l.bufferPool.Get().(*entryBuffer)
|
entBuf := l.bufferPool.Get().(*entryBuffer)
|
||||||
defer l.bufferPool.Put(entBuf)
|
defer l.bufferPool.Put(entBuf)
|
||||||
entBuf.buf.Reset()
|
entBuf.buf.Reset()
|
||||||
@ -94,7 +90,7 @@ func (l *FileSystem) Write(_ context.Context, e *Entry) (err error) {
|
|||||||
ClientASN: e.ClientASN,
|
ClientASN: e.ClientASN,
|
||||||
Elapsed: e.Elapsed,
|
Elapsed: e.Elapsed,
|
||||||
RequestType: e.RequestType,
|
RequestType: e.RequestType,
|
||||||
DNSSEC: dnssec,
|
DNSSEC: mathutil.BoolToNumber[uint8](e.DNSSEC),
|
||||||
Protocol: e.Protocol,
|
Protocol: e.Protocol,
|
||||||
ResultCode: c,
|
ResultCode: c,
|
||||||
ResponseCode: e.ResponseCode,
|
ResponseCode: e.ResponseCode,
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
|
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
|
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
@ -58,12 +59,16 @@ func (svc *Service) processRec(
|
|||||||
body = svc.error404
|
body = svc.error404
|
||||||
respHdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML)
|
respHdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
metrics.WebSvcError404RequestsTotal.Inc()
|
||||||
case http.StatusInternalServerError:
|
case http.StatusInternalServerError:
|
||||||
action = "writing 500"
|
action = "writing 500"
|
||||||
if len(svc.error500) != 0 {
|
if len(svc.error500) != 0 {
|
||||||
body = svc.error500
|
body = svc.error500
|
||||||
respHdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML)
|
respHdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
metrics.WebSvcError500RequestsTotal.Inc()
|
||||||
default:
|
default:
|
||||||
action = "writing response"
|
action = "writing response"
|
||||||
for k, v := range rec.Header() {
|
for k, v := range rec.Header() {
|
||||||
@ -87,6 +92,8 @@ func (svc *Service) serveHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
switch r.URL.Path {
|
switch r.URL.Path {
|
||||||
case "/dnscheck/test":
|
case "/dnscheck/test":
|
||||||
svc.dnsCheck.ServeHTTP(w, r)
|
svc.dnsCheck.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
metrics.WebSvcDNSCheckTestRequestsTotal.Inc()
|
||||||
case "/robots.txt":
|
case "/robots.txt":
|
||||||
serveRobotsDisallow(w.Header(), w, "handler")
|
serveRobotsDisallow(w.Header(), w, "handler")
|
||||||
case "/":
|
case "/":
|
||||||
@ -94,6 +101,8 @@ func (svc *Service) serveHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
} else {
|
} else {
|
||||||
http.Redirect(w, r, svc.rootRedirectURL, http.StatusFound)
|
http.Redirect(w, r, svc.rootRedirectURL, http.StatusFound)
|
||||||
|
|
||||||
|
metrics.WebSvcRootRedirectRequestsTotal.Inc()
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
@ -101,7 +110,7 @@ func (svc *Service) serveHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// safeBrowsingHandler returns an HTTP handler serving the block page from the
|
// safeBrowsingHandler returns an HTTP handler serving the block page from the
|
||||||
// blockPagePath. name is used for logging.
|
// blockPagePath. name is used for logging and metrics.
|
||||||
func safeBrowsingHandler(name string, blockPage []byte) (h http.Handler) {
|
func safeBrowsingHandler(name string, blockPage []byte) (h http.Handler) {
|
||||||
f := func(w http.ResponseWriter, r *http.Request) {
|
f := func(w http.ResponseWriter, r *http.Request) {
|
||||||
hdr := w.Header()
|
hdr := w.Header()
|
||||||
@ -122,6 +131,15 @@ func safeBrowsingHandler(name string, blockPage []byte) (h http.Handler) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
logErrorByType(err, "websvc: %s: writing response: %s", name, err)
|
logErrorByType(err, "websvc: %s: writing response: %s", name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case adultBlockingName:
|
||||||
|
metrics.WebSvcAdultBlockingPageRequestsTotal.Inc()
|
||||||
|
case safeBrowsingName:
|
||||||
|
metrics.WebSvcSafeBrowsingPageRequestsTotal.Inc()
|
||||||
|
default:
|
||||||
|
// Go on.
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -140,10 +158,11 @@ func (sc StaticContent) serveHTTP(w http.ResponseWriter, r *http.Request) (serve
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if f.AllowOrigin != "" {
|
h := w.Header()
|
||||||
w.Header().Set(agdhttp.HdrNameAccessControlAllowOrigin, f.AllowOrigin)
|
for k, v := range f.Headers {
|
||||||
|
h[k] = v
|
||||||
}
|
}
|
||||||
w.Header().Set(agdhttp.HdrNameContentType, f.ContentType)
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
_, err := w.Write(f.Content)
|
_, err := w.Write(f.Content)
|
||||||
@ -151,16 +170,15 @@ func (sc StaticContent) serveHTTP(w http.ResponseWriter, r *http.Request) (serve
|
|||||||
logErrorByType(err, "websvc: static content: writing %s: %s", p, err)
|
logErrorByType(err, "websvc: static content: writing %s: %s", p, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
metrics.WebSvcStaticContentRequestsTotal.Inc()
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// StaticFile is a single file in a StaticFS.
|
// StaticFile is a single file in a StaticFS.
|
||||||
type StaticFile struct {
|
type StaticFile struct {
|
||||||
// AllowOrigin is the value for the HTTP Access-Control-Allow-Origin header.
|
// Headers contains headers of the HTTP response.
|
||||||
AllowOrigin string
|
Headers http.Header
|
||||||
|
|
||||||
// ContentType is the value for the HTTP Content-Type header.
|
|
||||||
ContentType string
|
|
||||||
|
|
||||||
// Content is the file content.
|
// Content is the file content.
|
||||||
Content []byte
|
Content []byte
|
||||||
@ -174,6 +192,8 @@ func serveRobotsDisallow(hdr http.Header, w http.ResponseWriter, name string) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
logErrorByType(err, "websvc: %s: writing response: %s", name, err)
|
logErrorByType(err, "websvc: %s: writing response: %s", name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
metrics.WebSvcRobotsTxtRequestsTotal.Inc()
|
||||||
}
|
}
|
||||||
|
|
||||||
// logErrorByType writes err to the error log, unless err is a network error or
|
// logErrorByType writes err to the error log, unless err is a network error or
|
||||||
|
@ -31,8 +31,10 @@ func TestService_ServeHTTP(t *testing.T) {
|
|||||||
|
|
||||||
staticContent := map[string]*websvc.StaticFile{
|
staticContent := map[string]*websvc.StaticFile{
|
||||||
"/favicon.ico": {
|
"/favicon.ico": {
|
||||||
ContentType: "image/x-icon",
|
Content: []byte{},
|
||||||
Content: []byte{},
|
Headers: http.Header{
|
||||||
|
agdhttp.HdrNameContentType: []string{"image/x-icon"},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -56,22 +58,33 @@ func TestService_ServeHTTP(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// DNSCheck path.
|
// DNSCheck path.
|
||||||
assertPathResponse(t, svc, "/dnscheck/test", http.StatusOK)
|
assertResponse(t, svc, "/dnscheck/test", http.StatusOK)
|
||||||
|
|
||||||
// Static content path.
|
// Static content path with headers.
|
||||||
assertPathResponse(t, svc, "/favicon.ico", http.StatusOK)
|
h := http.Header{
|
||||||
|
agdhttp.HdrNameContentType: []string{"image/x-icon"},
|
||||||
|
agdhttp.HdrNameServer: []string{"AdGuardDNS/"},
|
||||||
|
}
|
||||||
|
assertResponseWithHeaders(t, svc, "/favicon.ico", http.StatusOK, h)
|
||||||
|
|
||||||
// Robots path.
|
// Robots path.
|
||||||
assertPathResponse(t, svc, "/robots.txt", http.StatusOK)
|
assertResponse(t, svc, "/robots.txt", http.StatusOK)
|
||||||
|
|
||||||
// Root redirect path.
|
// Root redirect path.
|
||||||
assertPathResponse(t, svc, "/", http.StatusFound)
|
assertResponse(t, svc, "/", http.StatusFound)
|
||||||
|
|
||||||
// Other path.
|
// Other path.
|
||||||
assertPathResponse(t, svc, "/other", http.StatusNotFound)
|
assertResponse(t, svc, "/other", http.StatusNotFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertPathResponse(t *testing.T, svc *websvc.Service, path string, statusCode int) {
|
// assertResponse is a helper function that checks status code of HTTP
|
||||||
|
// response.
|
||||||
|
func assertResponse(
|
||||||
|
t *testing.T,
|
||||||
|
svc *websvc.Service,
|
||||||
|
path string,
|
||||||
|
statusCode int,
|
||||||
|
) (rw *httptest.ResponseRecorder) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
r := httptest.NewRequest(http.MethodGet, (&url.URL{
|
r := httptest.NewRequest(http.MethodGet, (&url.URL{
|
||||||
@ -79,9 +92,27 @@ func assertPathResponse(t *testing.T, svc *websvc.Service, path string, statusCo
|
|||||||
Host: "127.0.0.1",
|
Host: "127.0.0.1",
|
||||||
Path: path,
|
Path: path,
|
||||||
}).String(), strings.NewReader(""))
|
}).String(), strings.NewReader(""))
|
||||||
rw := httptest.NewRecorder()
|
rw = httptest.NewRecorder()
|
||||||
svc.ServeHTTP(rw, r)
|
svc.ServeHTTP(rw, r)
|
||||||
|
|
||||||
assert.Equal(t, statusCode, rw.Code)
|
assert.Equal(t, statusCode, rw.Code)
|
||||||
assert.Equal(t, agdhttp.UserAgent(), rw.Header().Get(agdhttp.HdrNameServer))
|
assert.Equal(t, agdhttp.UserAgent(), rw.Header().Get(agdhttp.HdrNameServer))
|
||||||
|
|
||||||
|
return rw
|
||||||
|
}
|
||||||
|
|
||||||
|
// assertResponseWithHeaders is a helper function that checks status code and
|
||||||
|
// headers of HTTP response.
|
||||||
|
func assertResponseWithHeaders(
|
||||||
|
t *testing.T,
|
||||||
|
svc *websvc.Service,
|
||||||
|
path string,
|
||||||
|
statusCode int,
|
||||||
|
header http.Header,
|
||||||
|
) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
rw := assertResponse(t, svc, path, statusCode)
|
||||||
|
|
||||||
|
assert.Equal(t, header, rw.Header())
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
|
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
|
||||||
|
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
|
||||||
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
|
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
@ -148,6 +149,8 @@ func (prx *linkedIPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
log.Debug("%s: proxying %s %s: req %s", prx.logPrefix, m, p, reqID)
|
log.Debug("%s: proxying %s %s: req %s", prx.logPrefix, m, p, reqID)
|
||||||
|
|
||||||
prx.httpProxy.ServeHTTP(w, r)
|
prx.httpProxy.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
metrics.WebSvcLinkedIPProxyRequestsTotal.Inc()
|
||||||
} else if r.URL.Path == "/robots.txt" {
|
} else if r.URL.Path == "/robots.txt" {
|
||||||
serveRobotsDisallow(respHdr, w, prx.logPrefix)
|
serveRobotsDisallow(respHdr, w, prx.logPrefix)
|
||||||
} else {
|
} else {
|
||||||
|
@ -118,8 +118,8 @@ func New(c *Config) (svc *Service) {
|
|||||||
error404: c.Error404,
|
error404: c.Error404,
|
||||||
error500: c.Error500,
|
error500: c.Error500,
|
||||||
|
|
||||||
adultBlocking: blockPageServers(c.AdultBlocking, "adult blocking", c.Timeout),
|
adultBlocking: blockPageServers(c.AdultBlocking, adultBlockingName, c.Timeout),
|
||||||
safeBrowsing: blockPageServers(c.SafeBrowsing, "safe browsing", c.Timeout),
|
safeBrowsing: blockPageServers(c.SafeBrowsing, safeBrowsingName, c.Timeout),
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.RootRedirectURL != nil {
|
if c.RootRedirectURL != nil {
|
||||||
@ -164,6 +164,12 @@ func New(c *Config) (svc *Service) {
|
|||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Names for safeBrowsingHandler for logging and metrics.
|
||||||
|
const (
|
||||||
|
safeBrowsingName = "safe browsing"
|
||||||
|
adultBlockingName = "adult blocking"
|
||||||
|
)
|
||||||
|
|
||||||
// blockPageServers is a helper function that converts a *BlockPageServer into
|
// blockPageServers is a helper function that converts a *BlockPageServer into
|
||||||
// HTTP servers.
|
// HTTP servers.
|
||||||
func blockPageServers(
|
func blockPageServers(
|
||||||
|
@ -117,7 +117,9 @@ underscores() {
|
|||||||
git ls-files '*_*.go'\
|
git ls-files '*_*.go'\
|
||||||
| grep -F\
|
| grep -F\
|
||||||
-e '_generate.go'\
|
-e '_generate.go'\
|
||||||
|
-e '_linux.go'\
|
||||||
-e '_noreuseport.go'\
|
-e '_noreuseport.go'\
|
||||||
|
-e '_others.go'\
|
||||||
-e '_reuseport.go'\
|
-e '_reuseport.go'\
|
||||||
-e '_test.go'\
|
-e '_test.go'\
|
||||||
-e '_unix.go'\
|
-e '_unix.go'\
|
||||||
|
Loading…
x
Reference in New Issue
Block a user