Sync v2.1.5

This commit is contained in:
Andrey Meshkov 2023-02-03 15:27:58 +03:00
parent 7dec041e0f
commit f20c533cc3
97 changed files with 2865 additions and 1340 deletions

View File

@ -11,6 +11,99 @@ The format is **not** based on [Keep a Changelog][kec], since the project
## AGDNS-916 / Build 456
* `ratelimit` now defines rate of requests per second for IPv4 and IPv6
addresses separately. So replace this:
```yaml
ratelimit:
rps: 30
ipv4_subnet_key_len: 24
ipv6_subnet_key_len: 48
```
with this:
```yaml
ratelimit:
ipv4:
rps: 30
subnet_key_len: 24
ipv6:
rps: 300
subnet_key_len: 48
```
## AGDNS-907 / Build 449
* The objects within the `filtering_groups` have a new property,
`block_firefox_canary`. So replace this:
```yaml
filtering_groups:
-
id: default
# …
```
with this:
```yaml
filtering_groups:
-
id: default
# …
block_firefox_canary: true
```
The recommended default value is `true`.
## AGDNS-1308 / Build 447
* There is now a new env variable `RESEARCH_METRICS` that controls whether
collecting research metrics is enabled or not. Also, the first research
metric is added: `dns_research_blocked_per_country_total`, it counts the
number of blocked requests per country. Its default value is `0`, i.e.
research metrics collection is disabled by default.
## AGDNS-1051 / Build 443
* There are two changes in the keys of the `static_content` map. Firstly,
properties `allow_origin` and `content_type` are removed. Secondly, a new
property, called `headers`, is added. So replace this:
```yaml
static_content:
'/favicon.ico':
# …
allow_origin: '*'
content_type: 'image/x-icon'
```
with this:
```yaml
static_content:
'/favicon.ico':
# …
headers:
'Access-Control-Allow-Origin':
- '*'
'Content-Type':
- 'image/x-icon'
```
Adjust or add the values, if necessary.
## AGDNS-1278 / Build 423 ## AGDNS-1278 / Build 423
* The object `filters` has two new properties, `rule_list_cache_size` and * The object `filters` has two new properties, `rule_list_cache_size` and

View File

@ -8,8 +8,21 @@ ratelimit:
refuseany: true refuseany: true
# If response is larger than this, it is counted as several responses. # If response is larger than this, it is counted as several responses.
response_size_estimate: 1KB response_size_estimate: 1KB
# Rate of requests per second for one subnet.
rps: 30 rps: 30
# Rate limit options for IPv4 addresses.
ipv4:
# Rate of requests per second for one subnet for IPv4 addresses.
rps: 30
# The lengths of the subnet prefixes used to calculate rate limiter
# bucket keys for IPv4 addresses.
subnet_key_len: 24
# Rate limit options for IPv6 addresses.
ipv6:
# Rate of requests per second for one subnet for IPv6 addresses.
rps: 300
# The lengths of the subnet prefixes used to calculate rate limiter
# bucket keys for IPv6 addresses.
subnet_key_len: 48
# The time during which to count the number of times a client has hit the # The time during which to count the number of times a client has hit the
# rate limit for a back off. # rate limit for a back off.
# #
@ -21,10 +34,6 @@ ratelimit:
# How much a client that has hit the rate limit too often stays in the back # How much a client that has hit the rate limit too often stays in the back
# off. # off.
back_off_duration: 30m back_off_duration: 30m
# The lengths of the subnet prefixes used to calculate rate limiter bucket
# keys for IPv4 and IPv6 addresses correspondingly.
ipv4_subnet_key_len: 24
ipv6_subnet_key_len: 48
# Configuration for the allowlist. # Configuration for the allowlist.
allowlist: allowlist:
@ -156,9 +165,12 @@ web:
# servers. Paths must not cross the ones used by the DNS-over-HTTPS server. # servers. Paths must not cross the ones used by the DNS-over-HTTPS server.
static_content: static_content:
'/favicon.ico': '/favicon.ico':
allow_origin: '*'
content_type: 'image/x-icon'
content: '' content: ''
headers:
Access-Control-Allow-Origin:
- '*'
Content-Type:
- 'image/x-icon'
# If not defined, AdGuard DNS will respond with a 404 page to all such # If not defined, AdGuard DNS will respond with a 404 page to all such
# requests. # requests.
root_redirect_url: 'https://adguard-dns.com' root_redirect_url: 'https://adguard-dns.com'
@ -221,6 +233,7 @@ filtering_groups:
safe_browsing: safe_browsing:
enabled: true enabled: true
block_private_relay: false block_private_relay: false
block_firefox_canary: true
- id: 'family' - id: 'family'
parental: parental:
enabled: true enabled: true
@ -234,6 +247,7 @@ filtering_groups:
safe_browsing: safe_browsing:
enabled: true enabled: true
block_private_relay: false block_private_relay: false
block_firefox_canary: true
- id: 'non_filtering' - id: 'non_filtering'
rule_lists: rule_lists:
enabled: false enabled: false
@ -242,6 +256,7 @@ filtering_groups:
safe_browsing: safe_browsing:
enabled: false enabled: false
block_private_relay: false block_private_relay: false
block_firefox_canary: true
# Server groups and servers. # Server groups and servers.
server_groups: server_groups:

View File

@ -85,11 +85,23 @@ The `ratelimit` object has the following properties:
**Example:** `30m`. **Example:** `30m`.
* <a href="#ratelimit-rps" id="ratelimit-rps" name="ratelimit-rps">`rps`</a>: * <a href="#ratelimit-ipv4" id="ratelimit-ipv4" name="ratelimit-ipv4">`ipv4`</a>:
The rate of requests per second for one subnet. Requests above this are The ipv4 configuration object. It has the following fields:
counted in the backoff count.
**Example:** `30`. * <a href="#ratelimit-ipv4-rps" id="ratelimit-ipv4-rps" name="ratelimit-ipv4-rps">`rps`</a>:
The rate of requests per second for one subnet. Requests above this are
counted in the backoff count.
**Example:** `30`.
* <a href="#ratelimit-ipv4-subnet_key_len" id="ratelimit-ipv4-subnet_key_len" name="ratelimit-ipv4-subnet_key_len">`ipv4-subnet_key_len`</a>:
The length of the subnet prefix used to calculate rate limiter bucket keys.
**Example:** `24`.
* <a href="#ratelimit-ipv6" id="ratelimit-ipv6" name="ratelimit-ipv6">`ipv6`</a>:
The `ipv6` configuration object has the same properties as the `ipv4` one
above.
* <a href="#ratelimit-back_off_count" id="ratelimit-back_off_count" name="ratelimit-back_off_count">`back_off_count`</a>: * <a href="#ratelimit-back_off_count" id="ratelimit-back_off_count" name="ratelimit-back_off_count">`back_off_count`</a>:
Maximum number of requests a client can make above the RPS within Maximum number of requests a client can make above the RPS within
@ -120,22 +132,11 @@ The `ratelimit` object has the following properties:
**Example:** `30s`. **Example:** `30s`.
* <a href="#ratelimit-ipv4_subnet_key_len" id="ratelimit-ipv4_subnet_key_len" name="ratelimit-ipv4_subnet_key_len">`ipv4_subnet_key_len`</a>: For example, if `back_off_period` is `1m`, `back_off_count` is `10`, and
The length of the subnet prefix used to calculate rate limiter bucket keys `ipv4-rps` is `5`, a client (meaning all IP addresses within the subnet defined
for IPv4 addresses. by `ipv4-subnet_key_len`) that made 15 requests in one second or 6 requests
(one above `rps`) every second for 10 seconds within one minute, the client is
**Example:** `24`. blocked for `back_off_duration`.
* <a href="#ratelimit-ipv6_subnet_key_len" id="ratelimit-ipv6_subnet_key_len" name="ratelimit-ipv6_subnet_key_len">`ipv6_subnet_key_len`</a>:
Same as `ipv4_subnet_key_len` above but for IPv6 addresses.
**Example:** `48`.
For example, if `back_off_period` is `1m`, `back_off_count` is `10`, and `rps`
is `5`, a client (meaning all IP addresses within the subnet defined by
`ipv4_subnet_key_len` and `ipv6_subnet_key_len`) that made 15 requests in one
second or 6 requests (one above `rps`) every second for 10 seconds within one
minute, the client is blocked for `back_off_duration`.
[env-consul_allowlist_url]: environment.md#CONSUL_ALLOWLIST_URL [env-consul_allowlist_url]: environment.md#CONSUL_ALLOWLIST_URL
@ -454,13 +455,17 @@ The optional `web` object has the following properties:
`safe_browsing` and `adult_blocking` servers. Paths must not duplicate the `safe_browsing` and `adult_blocking` servers. Paths must not duplicate the
ones used by the DNS-over-HTTPS server. ones used by the DNS-over-HTTPS server.
Inside of the `headers` map, the header `Content-Type` is required.
**Property example:** **Property example:**
```yaml ```yaml
'static_content': static_content:
'/favicon.ico': '/favicon.ico':
'content_type': 'image/x-icon' content: 'base64content'
'content': 'base64content' headers:
'Content-Type':
- 'image/x-icon'
``` ```
* <a href="#web-root_redirect_url" id="web-root_redirect_url" name="web-root_redirect_url">`root_redirect_url`</a>: * <a href="#web-root_redirect_url" id="web-root_redirect_url" name="web-root_redirect_url">`root_redirect_url`</a>:
@ -647,6 +652,12 @@ The items of the `filtering_groups` array have the following properties:
**Example:** `false`. **Example:** `false`.
* <a href="#fg-*-block_firefox_canary" id="fg-*-block_firefox_canary" name="fg-*-block_firefox_canary">`block_firefox_canary`</a>:
If true, Firefox canary domain queries are blocked for requests using this
filtering group.
**Example:** `true`.
## <a href="#server_groups" id="server_groups" name="server_groups">Server groups</a> ## <a href="#server_groups" id="server_groups" name="server_groups">Server groups</a>

View File

@ -22,6 +22,7 @@ sensitive configuration. All other configuration is stored in the
* [`LISTEN_PORT`](#LISTEN_PORT) * [`LISTEN_PORT`](#LISTEN_PORT)
* [`LOG_TIMESTAMP`](#LOG_TIMESTAMP) * [`LOG_TIMESTAMP`](#LOG_TIMESTAMP)
* [`QUERYLOG_PATH`](#QUERYLOG_PATH) * [`QUERYLOG_PATH`](#QUERYLOG_PATH)
* [`RESEARCH_METRICS`](#RESEARCH_METRICS)
* [`RULESTAT_URL`](#RULESTAT_URL) * [`RULESTAT_URL`](#RULESTAT_URL)
* [`SENTRY_DSN`](#SENTRY_DSN) * [`SENTRY_DSN`](#SENTRY_DSN)
* [`SSL_KEY_LOG_FILE`](#SSL_KEY_LOG_FILE) * [`SSL_KEY_LOG_FILE`](#SSL_KEY_LOG_FILE)
@ -198,6 +199,15 @@ The path to the file into which the query log is going to be written.
## <a href="#RESEARCH_METRICS" id="RESEARCH_METRICS" name="RESEARCH_METRICS">`RESEARCH_METRICS`</a>
If `1`, enable collection of a set of special prometheus metrics (prefix is
`dns_research`). If `0`, disable collection of those metrics.
**Default:** `0`.
## <a href="#RULESTAT_URL" id="RULESTAT_URL" name="RULESTAT_URL">`RULESTAT_URL`</a> ## <a href="#RULESTAT_URL" id="RULESTAT_URL" name="RULESTAT_URL">`RULESTAT_URL`</a>
The URL to send filtering rule list statistics to. If empty or unset, the The URL to send filtering rule list statistics to. If empty or unset, the

2
go.mod
View File

@ -4,7 +4,7 @@ go 1.19
require ( require (
github.com/AdguardTeam/AdGuardDNS/internal/dnsserver v0.100.0 github.com/AdguardTeam/AdGuardDNS/internal/dnsserver v0.100.0
github.com/AdguardTeam/golibs v0.11.3 github.com/AdguardTeam/golibs v0.11.4
github.com/AdguardTeam/urlfilter v0.16.1 github.com/AdguardTeam/urlfilter v0.16.1
github.com/ameshkov/dnscrypt/v2 v2.2.5 github.com/ameshkov/dnscrypt/v2 v2.2.5
github.com/axiomhq/hyperloglog v0.0.0-20220105174342-98591331716a github.com/axiomhq/hyperloglog v0.0.0-20220105174342-98591331716a

4
go.sum
View File

@ -33,8 +33,8 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw= github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
github.com/AdguardTeam/golibs v0.11.3 h1:Oif+REq2WLycQ2Xm3ZPmJdfftptss0HbGWbxdFaC310= github.com/AdguardTeam/golibs v0.11.4 h1:IltyvxwCTN+xxJF5sh6VadF8Zfbf8elgCm9dgijSVzM=
github.com/AdguardTeam/golibs v0.11.3/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA= github.com/AdguardTeam/golibs v0.11.4/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA=
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU= github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
github.com/AdguardTeam/urlfilter v0.16.1 h1:ZPi0rjqo8cQf2FVdzo6cqumNoHZx2KPXj2yZa1A5BBw= github.com/AdguardTeam/urlfilter v0.16.1 h1:ZPi0rjqo8cQf2FVdzo6cqumNoHZx2KPXj2yZa1A5BBw=
github.com/AdguardTeam/urlfilter v0.16.1/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI= github.com/AdguardTeam/urlfilter v0.16.1/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI=

View File

@ -9,7 +9,11 @@ dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0 h1:SPOUaucgtVl
dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412 h1:GvWw74lx5noHocd+f6HBMXK6DuggBB1dhVkuGZbv7qM= dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412 h1:GvWw74lx5noHocd+f6HBMXK6DuggBB1dhVkuGZbv7qM=
dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c h1:ivON6cwHK1OH26MZyWDCnbTRZZf0IhNsENoNAKFS1g4= dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c h1:ivON6cwHK1OH26MZyWDCnbTRZZf0IhNsENoNAKFS1g4=
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999 h1:OR8VhtwhcAI3U48/rzBsVOuHi0zDPzYI1xASVcdSgR8= git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999 h1:OR8VhtwhcAI3U48/rzBsVOuHi0zDPzYI1xASVcdSgR8=
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
github.com/AdguardTeam/golibs v0.10.7/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw= github.com/AdguardTeam/golibs v0.10.7/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
github.com/AdguardTeam/golibs v0.11.4 h1:IltyvxwCTN+xxJF5sh6VadF8Zfbf8elgCm9dgijSVzM=
github.com/AdguardTeam/golibs v0.11.4/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA=
github.com/AdguardTeam/gomitmproxy v0.2.0 h1:rvCOf17pd1/CnMyMQW891zrEiIQBpQ8cIGjKN9pinUU= github.com/AdguardTeam/gomitmproxy v0.2.0 h1:rvCOf17pd1/CnMyMQW891zrEiIQBpQ8cIGjKN9pinUU=
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v1.2.0 h1:Rt8g24XnyGTyglgET/PRUNlrUeu9F5L+7FilkXfZgs0= github.com/BurntSushi/toml v1.2.0 h1:Rt8g24XnyGTyglgET/PRUNlrUeu9F5L+7FilkXfZgs0=
@ -74,7 +78,7 @@ github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18h
github.com/golang/lint v0.0.0-20180702182130-06c8688daad7 h1:2hRPrmiwPrp3fQX967rNJIhQPtiGXdlQWAxKbKw3VHA= github.com/golang/lint v0.0.0-20180702182130-06c8688daad7 h1:2hRPrmiwPrp3fQX967rNJIhQPtiGXdlQWAxKbKw3VHA=
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo= github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY= github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY=
github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk= github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=
github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw= github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw=
@ -198,41 +202,36 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh
github.com/yosssi/ace v0.0.5 h1:tUkIP/BLdKqrlrPwcmH0shwEEhTRHoGnc1wFIWmaBUA= github.com/yosssi/ace v0.0.5 h1:tUkIP/BLdKqrlrPwcmH0shwEEhTRHoGnc1wFIWmaBUA=
github.com/yuin/goldmark v1.4.1 h1:/vn0k+RBvwlxEmP5E7SZMqNxPhfMVFEJiykr15/0XKM= github.com/yuin/goldmark v1.4.1 h1:/vn0k+RBvwlxEmP5E7SZMqNxPhfMVFEJiykr15/0XKM=
github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.opencensus.io v0.22.4 h1:LYy1Hy3MJdrCdMwwzxA/dRok4ejH+RwNGbuoD9fCjto= go.opencensus.io v0.22.4 h1:LYy1Hy3MJdrCdMwwzxA/dRok4ejH+RwNGbuoD9fCjto=
go4.org v0.0.0-20180809161055-417644f6feb5 h1:+hE86LblG4AyDgwMCLTE6FOlM9+qjHSYS+rKqxUVdsM= go4.org v0.0.0-20180809161055-417644f6feb5 h1:+hE86LblG4AyDgwMCLTE6FOlM9+qjHSYS+rKqxUVdsM=
golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d h1:E2M5QgjZ/Jg+ObCQAudsXxuTsLj7Nl5RV/lZcQZmKSo= golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d h1:E2M5QgjZ/Jg+ObCQAudsXxuTsLj7Nl5RV/lZcQZmKSo=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b h1:+qEpEAPhDZ1o0x3tHzZTQDArnOixOzGD9HUJfcg0mb4= golang.org/x/image v0.0.0-20190802002840-cff245a6509b h1:+qEpEAPhDZ1o0x3tHzZTQDArnOixOzGD9HUJfcg0mb4=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k= golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k=
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 h1:4+4C/Iv2U4fMZBiMCc98MG1In4gJY5YRhtpDNeDeHWs= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 h1:4+4C/Iv2U4fMZBiMCc98MG1In4gJY5YRhtpDNeDeHWs=
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220516155154-20f960328961/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220516155154-20f960328961/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b h1:clP8eMhB30EHdc0bd2Twtq6kgU7yl5ub2cQLSdrv1Dg= golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b h1:clP8eMhB30EHdc0bd2Twtq6kgU7yl5ub2cQLSdrv1Dg=
golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852 h1:xYq6+9AtI+xP3M4r0N1hCkHrInHDBohhquRgx9Kk6gI= golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852 h1:xYq6+9AtI+xP3M4r0N1hCkHrInHDBohhquRgx9Kk6gI=
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 h1:v6hYoSR9T5oet+pMXwUWkbiVqx/63mlHjefrHmxwfeY= golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 h1:v6hYoSR9T5oet+pMXwUWkbiVqx/63mlHjefrHmxwfeY=
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220829200755-d48e67d00261/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw= golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw=
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc=
golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI= golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI=
golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA=
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.30.0 h1:yfrXXP61wVuLb0vBcG6qaOoIoqYEzOQS8jum51jkv2w= google.golang.org/api v0.30.0 h1:yfrXXP61wVuLb0vBcG6qaOoIoqYEzOQS8jum51jkv2w=
@ -246,6 +245,7 @@ gopkg.in/errgo.v2 v2.1.0 h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8=
gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
grpc.go4.org v0.0.0-20170609214715-11d0a25b4919 h1:tmXTu+dfa+d9Evp8NpJdgOy6+rt8/x4yG7qPBrtNfLY= grpc.go4.org v0.0.0-20170609214715-11d0a25b4919 h1:tmXTu+dfa+d9Evp8NpJdgOy6+rt8/x4yG7qPBrtNfLY=
honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8= honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8=
rsc.io/binaryregexp v0.2.0 h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE= rsc.io/binaryregexp v0.2.0 h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE=

View File

@ -1,11 +1,6 @@
package agd package agd
import ( import "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"context"
"net"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
)
// Common DNS Message Constants, Types, And Utilities // Common DNS Message Constants, Types, And Utilities
@ -27,10 +22,3 @@ const (
ProtoDoT = dnsserver.ProtoDoT ProtoDoT = dnsserver.ProtoDoT
ProtoDNSCrypt = dnsserver.ProtoDNSCrypt ProtoDNSCrypt = dnsserver.ProtoDNSCrypt
) )
// Resolver is the DNS resolver interface.
//
// See go doc net.Resolver.
type Resolver interface {
LookupIP(ctx context.Context, network, host string) (ips []net.IP, err error)
}

View File

@ -146,6 +146,10 @@ type FilteringGroup struct {
// BlockPrivateRelay shows if Apple Private Relay is blocked for requests // BlockPrivateRelay shows if Apple Private Relay is blocked for requests
// using this filtering group. // using this filtering group.
BlockPrivateRelay bool BlockPrivateRelay bool
// BlockFirefoxCanary shows if Firefox canary domain is blocked for
// requests using this filtering group.
BlockFirefoxCanary bool
} }
// FilteringGroupID is the ID of a filter group. It is an opaque string. // FilteringGroupID is the ID of a filter group. It is an opaque string.

View File

@ -73,6 +73,10 @@ type Profile struct {
// BlockPrivateRelay shows if Apple Private Relay queries are blocked for // BlockPrivateRelay shows if Apple Private Relay queries are blocked for
// requests from all devices in this profile. // requests from all devices in this profile.
BlockPrivateRelay bool BlockPrivateRelay bool
// BlockFirefoxCanary shows if Firefox canary domain is blocked for
// requests from all devices in this profile.
BlockFirefoxCanary bool
} }
// ProfileID is the ID of a profile. It is an opaque string. // ProfileID is the ID of a profile. It is an opaque string.

View File

@ -17,6 +17,8 @@ import (
// Data Storage // Data Storage
// ProfileDB is the local database of profiles and other data. // ProfileDB is the local database of profiles and other data.
//
// TODO(a.garipov): move this logic to the backend package.
type ProfileDB interface { type ProfileDB interface {
ProfileByDeviceID(ctx context.Context, id DeviceID) (p *Profile, d *Device, err error) ProfileByDeviceID(ctx context.Context, id DeviceID) (p *Profile, d *Device, err error)
ProfileByIP(ctx context.Context, ip netip.Addr) (p *Profile, d *Device, err error) ProfileByIP(ctx context.Context, ip netip.Addr) (p *Profile, d *Device, err error)

View File

@ -7,6 +7,8 @@ package agdio
import ( import (
"fmt" "fmt"
"io" "io"
"github.com/AdguardTeam/golibs/mathutil"
) )
// LimitError is returned when the Limit is reached. // LimitError is returned when the Limit is reached.
@ -35,9 +37,8 @@ func (lr *limitedReader) Read(p []byte) (n int, err error) {
} }
} }
if int64(len(p)) > lr.n { l := mathutil.Min(int64(len(p)), lr.n)
p = p[0:lr.n] p = p[:l]
}
n, err = lr.r.Read(p) n, err = lr.r.Read(p)
lr.n -= int64(n) lr.n -= int64(n)

View File

@ -165,6 +165,31 @@ type v1SettingsRespSettings struct {
FilteringEnabled bool `json:"filtering_enabled"` FilteringEnabled bool `json:"filtering_enabled"`
Deleted bool `json:"deleted"` Deleted bool `json:"deleted"`
BlockPrivateRelay bool `json:"block_private_relay"` BlockPrivateRelay bool `json:"block_private_relay"`
BlockFirefoxCanary bool `json:"block_firefox_canary"`
}
// type check
var _ json.Unmarshaler = (*v1SettingsRespSettings)(nil)
// UnmarshalJSON implements the [json.Unmarshaler] interface for
// *v1SettingsRespSettings. It puts default value into BlockFirefoxCanary
// field while it is not implemented on the backend side.
//
// TODO(a.garipov): Remove once the backend starts to always send it.
func (rs *v1SettingsRespSettings) UnmarshalJSON(b []byte) (err error) {
type defaultDec v1SettingsRespSettings
s := defaultDec{
BlockFirefoxCanary: true,
}
if err = json.Unmarshal(b, &s); err != nil {
return err
}
*rs = v1SettingsRespSettings(s)
return nil
} }
// v1SettingsRespRuleLists is the structure for decoding filtering rule lists // v1SettingsRespRuleLists is the structure for decoding filtering rule lists
@ -414,10 +439,11 @@ const maxFltRespTTL = 1 * time.Hour
func fltRespTTLToInternal(respTTL uint32) (ttl time.Duration, err error) { func fltRespTTLToInternal(respTTL uint32) (ttl time.Duration, err error) {
ttl = time.Duration(respTTL) * time.Second ttl = time.Duration(respTTL) * time.Second
if ttl > maxFltRespTTL { if ttl > maxFltRespTTL {
return ttl, fmt.Errorf("too high: got %d, max %d", respTTL, maxFltRespTTL) ttl = maxFltRespTTL
err = fmt.Errorf("too high: got %s, max %s", ttl, maxFltRespTTL)
} }
return ttl, nil return ttl, err
} }
// toInternal converts r to an [agd.DSProfilesResponse] instance. // toInternal converts r to an [agd.DSProfilesResponse] instance.
@ -461,6 +487,9 @@ func (r *v1SettingsResp) toInternal(
reportf(ctx, errColl, "settings at index %d: filtered resp ttl: %w", i, err) reportf(ctx, errColl, "settings at index %d: filtered resp ttl: %w", i, err)
// Go on and use the fixed value. // Go on and use the fixed value.
//
// TODO(ameshkov, a.garipov): Consider continuing, like with all
// other validation errors.
} }
sbEnabled := s.SafeBrowsing != nil && s.SafeBrowsing.Enabled sbEnabled := s.SafeBrowsing != nil && s.SafeBrowsing.Enabled
@ -479,6 +508,7 @@ func (r *v1SettingsResp) toInternal(
QueryLogEnabled: s.QueryLogEnabled, QueryLogEnabled: s.QueryLogEnabled,
Deleted: s.Deleted, Deleted: s.Deleted,
BlockPrivateRelay: s.BlockPrivateRelay, BlockPrivateRelay: s.BlockPrivateRelay,
BlockFirefoxCanary: s.BlockFirefoxCanary,
}) })
} }

View File

@ -131,6 +131,7 @@ func testProfileResp(t *testing.T) *agd.PSProfilesResponse {
QueryLogEnabled: true, QueryLogEnabled: true,
Deleted: false, Deleted: false,
BlockPrivateRelay: true, BlockPrivateRelay: true,
BlockFirefoxCanary: true,
}, { }, {
Parental: wantParental, Parental: wantParental,
ID: "83f3ea8f", ID: "83f3ea8f",
@ -162,6 +163,7 @@ func testProfileResp(t *testing.T) *agd.PSProfilesResponse {
QueryLogEnabled: true, QueryLogEnabled: true,
Deleted: true, Deleted: true,
BlockPrivateRelay: false, BlockPrivateRelay: false,
BlockFirefoxCanary: false,
}}, }},
} }

View File

@ -11,6 +11,7 @@
}, },
"deleted": false, "deleted": false,
"block_private_relay": true, "block_private_relay": true,
"block_firefox_canary": true,
"devices": [ "devices": [
{ {
"id": "118ffe93", "id": "118ffe93",
@ -43,6 +44,7 @@
}, },
"deleted": true, "deleted": true,
"block_private_relay": false, "block_private_relay": false,
"block_firefox_canary": false,
"devices": [ "devices": [
{ {
"id": "0d7724fa", "id": "0d7724fa",

View File

@ -28,6 +28,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/websvc" "github.com/AdguardTeam/AdGuardDNS/internal/websvc"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil"
) )
// Main is the entry point of application. // Main is the entry point of application.
@ -89,52 +90,75 @@ func Main() {
go func() { go func() {
defer geoIPMu.Unlock() defer geoIPMu.Unlock()
geoIP, geoIPErr = envs.geoIP(c.GeoIP, errColl) geoIP, geoIPErr = envs.geoIP(c.GeoIP)
}() }()
// Safe Browsing Hosts // Safe-browsing and adult-blocking filters
// TODO(ameshkov): Consider making configurable.
filteringResolver := agdnet.NewCachingResolver(
agdnet.DefaultResolver{},
1*timeutil.Day,
)
err = os.MkdirAll(envs.FilterCachePath, agd.DefaultDirPerm) err = os.MkdirAll(envs.FilterCachePath, agd.DefaultDirPerm)
check(err) check(err)
safeBrowsingConf := c.SafeBrowsing.toInternal( safeBrowsingConf, err := c.SafeBrowsing.toInternal(
errColl,
filteringResolver,
agd.FilterListIDSafeBrowsing, agd.FilterListIDSafeBrowsing,
envs.FilterCachePath, envs.FilterCachePath,
errColl,
) )
safeBrowsingHashes, err := filter.NewHashStorage(safeBrowsingConf)
check(err) check(err)
err = safeBrowsingHashes.Start() safeBrowsingFilter, err := filter.NewHashPrefix(safeBrowsingConf)
check(err) check(err)
adultBlockingConf := c.AdultBlocking.toInternal( safeBrowsingUpd := agd.NewRefreshWorker(&agd.RefreshWorkerConfig{
Context: ctxWithDefaultTimeout,
Refresher: safeBrowsingFilter,
ErrColl: errColl,
Name: string(agd.FilterListIDSafeBrowsing),
Interval: safeBrowsingConf.Staleness,
RefreshOnShutdown: false,
RoutineLogsAreDebug: false,
})
err = safeBrowsingUpd.Start()
check(err)
adultBlockingConf, err := c.AdultBlocking.toInternal(
errColl,
filteringResolver,
agd.FilterListIDAdultBlocking, agd.FilterListIDAdultBlocking,
envs.FilterCachePath, envs.FilterCachePath,
errColl,
) )
adultBlockingHashes, err := filter.NewHashStorage(adultBlockingConf)
check(err) check(err)
err = adultBlockingHashes.Start() adultBlockingFilter, err := filter.NewHashPrefix(adultBlockingConf)
check(err) check(err)
// Filters And Filtering Groups adultBlockingUpd := agd.NewRefreshWorker(&agd.RefreshWorkerConfig{
Context: ctxWithDefaultTimeout,
Refresher: adultBlockingFilter,
ErrColl: errColl,
Name: string(agd.FilterListIDAdultBlocking),
Interval: adultBlockingConf.Staleness,
RefreshOnShutdown: false,
RoutineLogsAreDebug: false,
})
err = adultBlockingUpd.Start()
check(err)
fltStrgConf := c.Filters.toInternal(errColl, envs) // Filter storage and filtering groups
fltStrgConf.SafeBrowsing = &filter.HashPrefixConfig{
Hashes: safeBrowsingHashes,
ReplacementHost: c.SafeBrowsing.BlockHost,
CacheTTL: c.SafeBrowsing.CacheTTL.Duration,
CacheSize: c.SafeBrowsing.CacheSize,
}
fltStrgConf.AdultBlocking = &filter.HashPrefixConfig{ fltStrgConf := c.Filters.toInternal(
Hashes: adultBlockingHashes, errColl,
ReplacementHost: c.AdultBlocking.BlockHost, filteringResolver,
CacheTTL: c.AdultBlocking.CacheTTL.Duration, envs,
CacheSize: c.AdultBlocking.CacheSize, safeBrowsingFilter,
} adultBlockingFilter,
)
fltStrg, err := filter.NewDefaultStorage(fltStrgConf) fltStrg, err := filter.NewDefaultStorage(fltStrgConf)
check(err) check(err)
@ -153,8 +177,6 @@ func Main() {
err = fltStrgUpd.Start() err = fltStrgUpd.Start()
check(err) check(err)
safeBrowsing := filter.NewSafeBrowsingServer(safeBrowsingHashes, adultBlockingHashes)
// Server Groups // Server Groups
fltGroups, err := c.FilteringGroups.toInternal(fltStrg) fltGroups, err := c.FilteringGroups.toInternal(fltStrg)
@ -329,8 +351,11 @@ func Main() {
}, c.Upstream.Healthcheck.Enabled) }, c.Upstream.Healthcheck.Enabled)
dnsConf := &dnssvc.Config{ dnsConf := &dnssvc.Config{
Messages: messages, Messages: messages,
SafeBrowsing: safeBrowsing, SafeBrowsing: filter.NewSafeBrowsingServer(
safeBrowsingConf.Hashes,
adultBlockingConf.Hashes,
),
BillStat: billStatRec, BillStat: billStatRec,
ProfileDB: profDB, ProfileDB: profDB,
DNSCheck: dnsCk, DNSCheck: dnsCk,
@ -349,6 +374,7 @@ func Main() {
CacheSize: c.Cache.Size, CacheSize: c.Cache.Size,
ECSCacheSize: c.Cache.ECSSize, ECSCacheSize: c.Cache.ECSSize,
UseECSCache: c.Cache.Type == cacheTypeECS, UseECSCache: c.Cache.Type == cacheTypeECS,
ResearchMetrics: bool(envs.ResearchMetrics),
} }
dnsSvc, err := dnssvc.New(dnsConf) dnsSvc, err := dnssvc.New(dnsConf)
@ -383,11 +409,11 @@ func Main() {
) )
h := newSignalHandler( h := newSignalHandler(
adultBlockingHashes,
safeBrowsingHashes,
debugSvc, debugSvc,
webSvc, webSvc,
dnsSvc, dnsSvc,
safeBrowsingUpd,
adultBlockingUpd,
profDBUpd, profDBUpd,
dnsDBUpd, dnsDBUpd,
geoIPUpd, geoIPUpd,

View File

@ -3,14 +3,9 @@ package cmd
import ( import (
"fmt" "fmt"
"os" "os"
"path/filepath"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog" "github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
@ -213,74 +208,6 @@ func (c *geoIPConfig) validate() (err error) {
} }
} }
// allowListConfig is the consul allow list configuration.
type allowListConfig struct {
// List contains IPs and CIDRs.
List []string `yaml:"list"`
// RefreshIvl time between two updates of allow list from the Consul URL.
RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
}
// safeBrowsingConfig is the configuration for one of the safe browsing filters.
type safeBrowsingConfig struct {
// URL is the URL used to update the filter.
URL *agdhttp.URL `yaml:"url"`
// BlockHost is the hostname with which to respond to any requests that
// match the filter.
//
// TODO(a.garipov): Consider replacing with a list of IPv4 and IPv6
// addresses.
BlockHost string `yaml:"block_host"`
// CacheSize is the size of the response cache, in entries.
CacheSize int `yaml:"cache_size"`
// CacheTTL is the TTL of the response cache.
CacheTTL timeutil.Duration `yaml:"cache_ttl"`
// RefreshIvl defines how often AdGuard DNS refreshes the filter.
RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
}
// toInternal converts c to the safe browsing filter configuration for the
// filter storage of the DNS server. c is assumed to be valid.
func (c *safeBrowsingConfig) toInternal(
id agd.FilterListID,
cacheDir string,
errColl agd.ErrorCollector,
) (conf *filter.HashStorageConfig) {
return &filter.HashStorageConfig{
URL: netutil.CloneURL(&c.URL.URL),
ErrColl: errColl,
ID: id,
CachePath: filepath.Join(cacheDir, string(id)),
RefreshIvl: c.RefreshIvl.Duration,
}
}
// validate returns an error if the safe browsing filter configuration is
// invalid.
func (c *safeBrowsingConfig) validate() (err error) {
switch {
case c == nil:
return errNilConfig
case c.URL == nil:
return errors.Error("no url")
case c.BlockHost == "":
return errors.Error("no block_host")
case c.CacheSize <= 0:
return newMustBePositiveError("cache_size", c.CacheSize)
case c.CacheTTL.Duration <= 0:
return newMustBePositiveError("cache_ttl", c.CacheTTL)
case c.RefreshIvl.Duration <= 0:
return newMustBePositiveError("refresh_interval", c.RefreshIvl)
default:
return nil
}
}
// readConfig reads the configuration. // readConfig reads the configuration.
func readConfig(confPath string) (c *configuration, err error) { func readConfig(confPath string) (c *configuration, err error) {
// #nosec G304 -- Trust the path to the configuration file that is given // #nosec G304 -- Trust the path to the configuration file that is given

View File

@ -48,8 +48,9 @@ type environments struct {
ListenPort int `env:"LISTEN_PORT" envDefault:"8181"` ListenPort int `env:"LISTEN_PORT" envDefault:"8181"`
LogTimestamp strictBool `env:"LOG_TIMESTAMP" envDefault:"1"` LogTimestamp strictBool `env:"LOG_TIMESTAMP" envDefault:"1"`
LogVerbose strictBool `env:"VERBOSE" envDefault:"0"` LogVerbose strictBool `env:"VERBOSE" envDefault:"0"`
ResearchMetrics strictBool `env:"RESEARCH_METRICS" envDefault:"0"`
} }
// readEnvs reads the configuration. // readEnvs reads the configuration.
@ -127,12 +128,10 @@ func (envs *environments) buildDNSDB(
// geoIP returns an GeoIP database implementation from environment. // geoIP returns an GeoIP database implementation from environment.
func (envs *environments) geoIP( func (envs *environments) geoIP(
c *geoIPConfig, c *geoIPConfig,
errColl agd.ErrorCollector,
) (g *geoip.File, err error) { ) (g *geoip.File, err error) {
log.Debug("using geoip files %q and %q", envs.GeoIPASNPath, envs.GeoIPCountryPath) log.Debug("using geoip files %q and %q", envs.GeoIPASNPath, envs.GeoIPCountryPath)
g, err = geoip.NewFile(&geoip.FileConfig{ g, err = geoip.NewFile(&geoip.FileConfig{
ErrColl: errColl,
ASNPath: envs.GeoIPASNPath, ASNPath: envs.GeoIPASNPath,
CountryPath: envs.GeoIPCountryPath, CountryPath: envs.GeoIPCountryPath,
HostCacheSize: c.HostCacheSize, HostCacheSize: c.HostCacheSize,

View File

@ -47,16 +47,21 @@ type filtersConfig struct {
// cacheDir must exist. c is assumed to be valid. // cacheDir must exist. c is assumed to be valid.
func (c *filtersConfig) toInternal( func (c *filtersConfig) toInternal(
errColl agd.ErrorCollector, errColl agd.ErrorCollector,
resolver agdnet.Resolver,
envs *environments, envs *environments,
safeBrowsing *filter.HashPrefix,
adultBlocking *filter.HashPrefix,
) (conf *filter.DefaultStorageConfig) { ) (conf *filter.DefaultStorageConfig) {
return &filter.DefaultStorageConfig{ return &filter.DefaultStorageConfig{
FilterIndexURL: netutil.CloneURL(&envs.FilterIndexURL.URL), FilterIndexURL: netutil.CloneURL(&envs.FilterIndexURL.URL),
BlockedServiceIndexURL: netutil.CloneURL(&envs.BlockedServiceIndexURL.URL), BlockedServiceIndexURL: netutil.CloneURL(&envs.BlockedServiceIndexURL.URL),
GeneralSafeSearchRulesURL: netutil.CloneURL(&envs.GeneralSafeSearchURL.URL), GeneralSafeSearchRulesURL: netutil.CloneURL(&envs.GeneralSafeSearchURL.URL),
YoutubeSafeSearchRulesURL: netutil.CloneURL(&envs.YoutubeSafeSearchURL.URL), YoutubeSafeSearchRulesURL: netutil.CloneURL(&envs.YoutubeSafeSearchURL.URL),
SafeBrowsing: safeBrowsing,
AdultBlocking: adultBlocking,
Now: time.Now, Now: time.Now,
ErrColl: errColl, ErrColl: errColl,
Resolver: agdnet.DefaultResolver{}, Resolver: resolver,
CacheDir: envs.FilterCachePath, CacheDir: envs.FilterCachePath,
CustomFilterCacheSize: c.CustomFilterCacheSize, CustomFilterCacheSize: c.CustomFilterCacheSize,
SafeSearchCacheSize: c.SafeSearchCacheSize, SafeSearchCacheSize: c.SafeSearchCacheSize,

View File

@ -29,6 +29,10 @@ type filteringGroup struct {
// BlockPrivateRelay shows if Apple Private Relay queries are blocked for // BlockPrivateRelay shows if Apple Private Relay queries are blocked for
// requests using this filtering group. // requests using this filtering group.
BlockPrivateRelay bool `yaml:"block_private_relay"` BlockPrivateRelay bool `yaml:"block_private_relay"`
// BlockFirefoxCanary shows if Firefox canary domain is blocked for
// requests using this filtering group.
BlockFirefoxCanary bool `yaml:"block_firefox_canary"`
} }
// fltGrpRuleLists contains filter rule lists configuration for a filtering // fltGrpRuleLists contains filter rule lists configuration for a filtering
@ -133,6 +137,7 @@ func (groups filteringGroups) toInternal(
GeneralSafeSearch: g.Parental.GeneralSafeSearch, GeneralSafeSearch: g.Parental.GeneralSafeSearch,
YoutubeSafeSearch: g.Parental.YoutubeSafeSearch, YoutubeSafeSearch: g.Parental.YoutubeSafeSearch,
BlockPrivateRelay: g.BlockPrivateRelay, BlockPrivateRelay: g.BlockPrivateRelay,
BlockFirefoxCanary: g.BlockFirefoxCanary,
} }
} }

View File

@ -15,14 +15,17 @@ type rateLimitConfig struct {
// AllowList is the allowlist of clients. // AllowList is the allowlist of clients.
Allowlist *allowListConfig `yaml:"allowlist"` Allowlist *allowListConfig `yaml:"allowlist"`
// Rate limit options for IPv4 addresses.
IPv4 *rateLimitOptions `yaml:"ipv4"`
// Rate limit options for IPv6 addresses.
IPv6 *rateLimitOptions `yaml:"ipv6"`
// ResponseSizeEstimate is the size of the estimate of the size of one DNS // ResponseSizeEstimate is the size of the estimate of the size of one DNS
// response for the purposes of rate limiting. Responses over this estimate // response for the purposes of rate limiting. Responses over this estimate
// are counted as several responses. // are counted as several responses.
ResponseSizeEstimate datasize.ByteSize `yaml:"response_size_estimate"` ResponseSizeEstimate datasize.ByteSize `yaml:"response_size_estimate"`
// RPS is the maximum number of requests per second.
RPS int `yaml:"rps"`
// BackOffCount helps with repeated offenders. It defines, how many times // BackOffCount helps with repeated offenders. It defines, how many times
// a client hits the rate limit before being held in the back off. // a client hits the rate limit before being held in the back off.
BackOffCount int `yaml:"back_off_count"` BackOffCount int `yaml:"back_off_count"`
@ -35,18 +38,42 @@ type rateLimitConfig struct {
// a client has hit the rate limit for a back off. // a client has hit the rate limit for a back off.
BackOffPeriod timeutil.Duration `yaml:"back_off_period"` BackOffPeriod timeutil.Duration `yaml:"back_off_period"`
// IPv4SubnetKeyLen is the length of the subnet prefix used to calculate
// rate limiter bucket keys for IPv4 addresses.
IPv4SubnetKeyLen int `yaml:"ipv4_subnet_key_len"`
// IPv6SubnetKeyLen is the length of the subnet prefix used to calculate
// rate limiter bucket keys for IPv6 addresses.
IPv6SubnetKeyLen int `yaml:"ipv6_subnet_key_len"`
// RefuseANY, if true, makes the server refuse DNS * queries. // RefuseANY, if true, makes the server refuse DNS * queries.
RefuseANY bool `yaml:"refuse_any"` RefuseANY bool `yaml:"refuse_any"`
} }
// allowListConfig is the consul allow list configuration.
type allowListConfig struct {
	// List contains IPs and CIDRs.
	List []string `yaml:"list"`

	// RefreshIvl is the time between two updates of the allow list from the
	// Consul URL.
	RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
}
// rateLimitOptions define the maximum number of requests per second for IPv4
// or IPv6 addresses.
type rateLimitOptions struct {
	// RPS is the maximum number of requests per second.
	RPS int `yaml:"rps"`

	// SubnetKeyLen is the length of the subnet prefix used to calculate
	// rate limiter bucket keys.
	SubnetKeyLen int `yaml:"subnet_key_len"`
}

// validate returns an error if rate limit options are invalid.
func (o *rateLimitOptions) validate() (err error) {
	if o == nil {
		return errNilConfig
	}

	// Both values must be strictly positive; the first failing check wins.
	return coalesceError(
		validatePositive("rps", o.RPS),
		validatePositive("subnet_key_len", o.SubnetKeyLen),
	)
}
// toInternal converts c to the rate limiting configuration for the DNS server. // toInternal converts c to the rate limiting configuration for the DNS server.
// c is assumed to be valid. // c is assumed to be valid.
func (c *rateLimitConfig) toInternal(al ratelimit.Allowlist) (conf *ratelimit.BackOffConfig) { func (c *rateLimitConfig) toInternal(al ratelimit.Allowlist) (conf *ratelimit.BackOffConfig) {
@ -55,10 +82,11 @@ func (c *rateLimitConfig) toInternal(al ratelimit.Allowlist) (conf *ratelimit.Ba
ResponseSizeEstimate: int(c.ResponseSizeEstimate.Bytes()), ResponseSizeEstimate: int(c.ResponseSizeEstimate.Bytes()),
Duration: c.BackOffDuration.Duration, Duration: c.BackOffDuration.Duration,
Period: c.BackOffPeriod.Duration, Period: c.BackOffPeriod.Duration,
RPS: c.RPS, IPv4RPS: c.IPv4.RPS,
IPv4SubnetKeyLen: c.IPv4.SubnetKeyLen,
IPv6RPS: c.IPv6.RPS,
IPv6SubnetKeyLen: c.IPv6.SubnetKeyLen,
Count: c.BackOffCount, Count: c.BackOffCount,
IPv4SubnetKeyLen: c.IPv4SubnetKeyLen,
IPv6SubnetKeyLen: c.IPv6SubnetKeyLen,
RefuseANY: c.RefuseANY, RefuseANY: c.RefuseANY,
} }
} }
@ -71,14 +99,21 @@ func (c *rateLimitConfig) validate() (err error) {
return fmt.Errorf("allowlist: %w", errNilConfig) return fmt.Errorf("allowlist: %w", errNilConfig)
} }
err = c.IPv4.validate()
if err != nil {
return fmt.Errorf("ipv4: %w", err)
}
err = c.IPv6.validate()
if err != nil {
return fmt.Errorf("ipv6: %w", err)
}
return coalesceError( return coalesceError(
validatePositive("rps", c.RPS),
validatePositive("back_off_count", c.BackOffCount), validatePositive("back_off_count", c.BackOffCount),
validatePositive("back_off_duration", c.BackOffDuration), validatePositive("back_off_duration", c.BackOffDuration),
validatePositive("back_off_period", c.BackOffPeriod), validatePositive("back_off_period", c.BackOffPeriod),
validatePositive("response_size_estimate", c.ResponseSizeEstimate), validatePositive("response_size_estimate", c.ResponseSizeEstimate),
validatePositive("allowlist.refresh_interval", c.Allowlist.RefreshIvl), validatePositive("allowlist.refresh_interval", c.Allowlist.RefreshIvl),
validatePositive("ipv4_subnet_key_len", c.IPv4SubnetKeyLen),
validatePositive("ipv6_subnet_key_len", c.IPv6SubnetKeyLen),
) )
} }

View File

@ -0,0 +1,86 @@
package cmd
import (
"path/filepath"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil"
)
// Safe-browsing and adult-blocking configuration
// safeBrowsingConfig is the configuration for one of the safe browsing filters.
type safeBrowsingConfig struct {
	// URL is the URL used to update the filter.
	URL *agdhttp.URL `yaml:"url"`

	// BlockHost is the hostname with which to respond to any requests that
	// match the filter.
	//
	// TODO(a.garipov): Consider replacing with a list of IPv4 and IPv6
	// addresses.
	BlockHost string `yaml:"block_host"`

	// CacheSize is the size of the response cache, in entries.
	CacheSize int `yaml:"cache_size"`

	// CacheTTL is the TTL of the response cache.
	CacheTTL timeutil.Duration `yaml:"cache_ttl"`

	// RefreshIvl defines how often AdGuard DNS refreshes the filter.  It is
	// also used as the staleness limit of the filter (see toInternal).
	RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
}
// toInternal converts c to the hash-prefix filter configuration for the filter
// storage of the DNS server.  c is assumed to be valid.
func (c *safeBrowsingConfig) toInternal(
	errColl agd.ErrorCollector,
	resolver agdnet.Resolver,
	id agd.FilterListID,
	cacheDir string,
) (fltConf *filter.HashPrefixConfig, err error) {
	// NOTE(review): hashstorage.New is called with an empty string, so the
	// storage presumably starts out empty and is filled on refresh — confirm.
	strg, err := hashstorage.New("")
	if err != nil {
		return nil, err
	}

	fltConf = &filter.HashPrefixConfig{
		Hashes:          strg,
		URL:             netutil.CloneURL(&c.URL.URL),
		ErrColl:         errColl,
		Resolver:        resolver,
		ID:              id,
		CachePath:       filepath.Join(cacheDir, string(id)),
		ReplacementHost: c.BlockHost,
		Staleness:       c.RefreshIvl.Duration,
		CacheTTL:        c.CacheTTL.Duration,
		CacheSize:       c.CacheSize,
	}

	return fltConf, nil
}
// validate returns an error if the safe browsing filter configuration is
// invalid.
func (c *safeBrowsingConfig) validate() (err error) {
	if c == nil {
		return errNilConfig
	}

	if c.URL == nil {
		return errors.Error("no url")
	}

	if c.BlockHost == "" {
		return errors.Error("no block_host")
	}

	if c.CacheSize <= 0 {
		return newMustBePositiveError("cache_size", c.CacheSize)
	}

	if c.CacheTTL.Duration <= 0 {
		return newMustBePositiveError("cache_ttl", c.CacheTTL)
	}

	if c.RefreshIvl.Duration <= 0 {
		return newMustBePositiveError("refresh_interval", c.RefreshIvl)
	}

	return nil
}

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/netip" "net/netip"
"net/textproto"
"os" "os"
"path" "path"
@ -386,11 +387,8 @@ func (sc staticContent) validate() (err error) {
// staticFile is a single file in a static content mapping. // staticFile is a single file in a static content mapping.
type staticFile struct { type staticFile struct {
// AllowOrigin is the value for the HTTP Access-Control-Allow-Origin header. // Headers contains headers of the HTTP response.
AllowOrigin string `yaml:"allow_origin"` Headers http.Header `yaml:"headers"`
// ContentType is the value for the HTTP Content-Type header.
ContentType string `yaml:"content_type"`
// Content is the file content. // Content is the file content.
Content string `yaml:"content"` Content string `yaml:"content"`
@ -400,8 +398,12 @@ type staticFile struct {
// assumed to be valid. // assumed to be valid.
func (f *staticFile) toInternal() (file *websvc.StaticFile, err error) { func (f *staticFile) toInternal() (file *websvc.StaticFile, err error) {
file = &websvc.StaticFile{ file = &websvc.StaticFile{
AllowOrigin: f.AllowOrigin, Headers: http.Header{},
ContentType: f.ContentType, }
for k, vs := range f.Headers {
ck := textproto.CanonicalMIMEHeaderKey(k)
file.Headers[ck] = vs
} }
file.Content, err = base64.StdEncoding.DecodeString(f.Content) file.Content, err = base64.StdEncoding.DecodeString(f.Content)
@ -409,17 +411,20 @@ func (f *staticFile) toInternal() (file *websvc.StaticFile, err error) {
return nil, fmt.Errorf("content: %w", err) return nil, fmt.Errorf("content: %w", err)
} }
// Check Content-Type here as opposed to in validate, because we need
// all keys to be canonicalized first.
if file.Headers.Get(agdhttp.HdrNameContentType) == "" {
return nil, errors.Error("content: " + agdhttp.HdrNameContentType + " header is required")
}
return file, nil return file, nil
} }
// validate returns an error if the static content file is invalid. // validate returns an error if the static content file is invalid.
func (f *staticFile) validate() (err error) { func (f *staticFile) validate() (err error) {
switch { if f == nil {
case f == nil:
return errors.Error("no file") return errors.Error("no file")
case f.ContentType == "":
return errors.Error("no content_type")
default:
return nil
} }
return nil
} }

View File

@ -222,9 +222,9 @@ func NewEDNS0Padding(msgLen int, UDPBufferSize uint16) (extra dns.RR) {
} }
} }
// FindENDS0Option searches for the specified EDNS0 option in the OPT resource // FindEDNS0Option searches for the specified EDNS0 option in the OPT resource
// record of the msg and returns it or nil if it's not present. // record of the msg and returns it or nil if it's not present.
func FindENDS0Option[T dns.EDNS0](msg *dns.Msg) (o T) { func FindEDNS0Option[T dns.EDNS0](msg *dns.Msg) (o T) {
rr := msg.IsEdns0() rr := msg.IsEdns0()
if rr == nil { if rr == nil {
return o return o

View File

@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardDNS/internal/dnsserver
go 1.19 go 1.19
require ( require (
github.com/AdguardTeam/golibs v0.11.3 github.com/AdguardTeam/golibs v0.11.4
github.com/ameshkov/dnscrypt/v2 v2.2.5 github.com/ameshkov/dnscrypt/v2 v2.2.5
github.com/ameshkov/dnsstamps v1.0.3 github.com/ameshkov/dnsstamps v1.0.3
github.com/bluele/gcache v0.0.2 github.com/bluele/gcache v0.0.2

View File

@ -31,8 +31,7 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl
cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs=
cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0=
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/AdguardTeam/golibs v0.11.3 h1:Oif+REq2WLycQ2Xm3ZPmJdfftptss0HbGWbxdFaC310= github.com/AdguardTeam/golibs v0.11.4 h1:IltyvxwCTN+xxJF5sh6VadF8Zfbf8elgCm9dgijSVzM=
github.com/AdguardTeam/golibs v0.11.3/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=

View File

@ -1,36 +0,0 @@
//go:build !(aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd)
package dnsserver
import (
"context"
"fmt"
"net"
"github.com/AdguardTeam/golibs/errors"
)
// listenUDP listens to the specified address on UDP.  On this platform the
// OOB flag (third parameter) is ignored, since out-of-band control-messages
// are not configured here.
func listenUDP(_ context.Context, addr string, _ bool) (conn *net.UDPConn, err error) {
	defer func() { err = errors.Annotate(err, "opening packet listener: %w") }()

	c, err := net.ListenPacket("udp", addr)
	if err != nil {
		return nil, err
	}

	// net.ListenPacket with the "udp" network always returns a *net.UDPConn.
	conn, ok := c.(*net.UDPConn)
	if !ok {
		// TODO(ameshkov): should not happen, consider panic here.
		err = fmt.Errorf("expected conn of type %T, got %T", conn, c)

		return nil, err
	}

	return conn, nil
}
// listenTCP opens a TCP listener on addr.  The context parameter is unused on
// this platform.
func listenTCP(_ context.Context, addr string) (conn net.Listener, err error) {
	conn, err = net.Listen("tcp", addr)

	return conn, err
}

View File

@ -1,63 +0,0 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd
package dnsserver
import (
"context"
"fmt"
"net"
"syscall"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/sys/unix"
)
// reuseportControl is a [net.ListenConfig.Control] function that sets the
// SO_REUSEPORT socket option on the socket before it is bound.
func reuseportControl(_, _ string, c syscall.RawConn) (err error) {
	// opErr captures the result of the setsockopt call made inside Control.
	var opErr error
	err = c.Control(func(fd uintptr) {
		opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
	})
	if err != nil {
		return err
	}

	return opErr
}
// listenUDP listens to the specified address on UDP.  If oob flag is set to
// true this method also enables OOB for the listen socket that enables using of
// ReadMsgUDP/WriteMsgUDP.  Doing it this way is necessary to correctly discover
// the source address when it listens to 0.0.0.0.
func listenUDP(ctx context.Context, addr string, oob bool) (conn *net.UDPConn, err error) {
	defer func() { err = errors.Annotate(err, "opening packet listener: %w") }()

	// Enable SO_REUSEPORT on the socket before it is bound.
	var lc net.ListenConfig
	lc.Control = reuseportControl

	c, err := lc.ListenPacket(ctx, "udp", addr)
	if err != nil {
		return nil, err
	}

	// ListenPacket with the "udp" network always returns a *net.UDPConn.
	conn, ok := c.(*net.UDPConn)
	if !ok {
		// TODO(ameshkov): should not happen, consider panic here.
		err = fmt.Errorf("expected conn of type %T, got %T", conn, c)

		return nil, err
	}

	if oob {
		// setUDPSocketOptions is defined elsewhere in this package;
		// presumably it enables the control-message options required by
		// ReadMsgUDP/WriteMsgUDP — confirm.
		if err = setUDPSocketOptions(conn); err != nil {
			return nil, fmt.Errorf("failed to set socket options: %w", err)
		}
	}

	return conn, err
}
// listenTCP opens a TCP listener on addr with the SO_REUSEPORT socket option
// enabled.
func listenTCP(ctx context.Context, addr string) (l net.Listener, err error) {
	lc := net.ListenConfig{
		Control: reuseportControl,
	}

	l, err = lc.Listen(ctx, "tcp", addr)

	return l, err
}

View File

@ -0,0 +1,73 @@
// Package netext contains extensions of package net in the Go standard library.
package netext
import (
"context"
"fmt"
"net"
)
// ListenConfig is the interface that allows controlling options of connections
// used by the DNS servers defined in this module.  Default ListenConfigs are
// the ones returned by [DefaultListenConfigWithOOB] for plain DNS and
// [DefaultListenConfig] for others.
//
// This interface is modeled after [net.ListenConfig].
type ListenConfig interface {
	// Listen announces on the local network address.
	Listen(ctx context.Context, network, address string) (l net.Listener, err error)

	// ListenPacket announces on the local network address.
	ListenPacket(ctx context.Context, network, address string) (c net.PacketConn, err error)
}
// DefaultListenConfig returns the [ListenConfig] used by every server in this
// module except the plain-DNS ones, which use [DefaultListenConfigWithOOB].
func DefaultListenConfig() (lc ListenConfig) {
	conf := &net.ListenConfig{
		Control: defaultListenControl,
	}

	return conf
}
// DefaultListenConfigWithOOB returns the [ListenConfig] used by the plain-DNS
// servers in this module.  Packet connections it creates get the additional
// socket flags set and have their control-messages processed.
func DefaultListenConfigWithOOB() (lc ListenConfig) {
	inner := net.ListenConfig{
		Control: defaultListenControl,
	}

	return &listenConfigOOB{
		ListenConfig: inner,
	}
}
// type check
var _ ListenConfig = (*listenConfigOOB)(nil)

// listenConfigOOB is a wrapper around [net.ListenConfig] with modifications
// that set the control-message options on packet conns.  Listen is inherited
// from the embedded net.ListenConfig; only ListenPacket is overridden.
type listenConfigOOB struct {
	net.ListenConfig
}
// ListenPacket implements the [ListenConfig] interface for *listenConfigOOB.
// It sets the control-message flags to receive additional out-of-band data to
// correctly discover the source address when it listens to 0.0.0.0 as well as
// in situations when SO_BINDTODEVICE is used.
//
// network must be "udp", "udp4", or "udp6".
func (lc *listenConfigOOB) ListenPacket(
	ctx context.Context,
	network string,
	address string,
) (c net.PacketConn, err error) {
	c, err = lc.ListenConfig.ListenPacket(ctx, network, address)
	if err != nil {
		return nil, err
	}

	// Enable receiving of the destination-address and interface
	// control-messages on the new connection.
	err = setIPOpts(c)
	if err != nil {
		return nil, fmt.Errorf("setting socket options: %w", err)
	}

	// wrapPacketConn is defined elsewhere in this package; presumably it wraps
	// c into a control-message-aware connection — confirm.
	return wrapPacketConn(c), nil
}

View File

@ -0,0 +1,44 @@
//go:build unix
package netext
import (
"net"
"syscall"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
)
// defaultListenControl is used as a [net.ListenConfig.Control] function to set
// the SO_REUSEPORT socket option on all sockets used by the DNS servers in this
// package.
func defaultListenControl(_, _ string, c syscall.RawConn) (err error) {
	// opErr captures the result of the setsockopt call, which runs inside
	// Control on the raw file descriptor.
	var opErr error
	err = c.Control(func(fd uintptr) {
		opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
	})
	if err != nil {
		return err
	}

	// err is always nil at this point, so wrapping opErr with
	// errors.WithDeferred(opErr, err) was a no-op; return opErr directly.
	return opErr
}
// setIPOpts sets the IPv4 and IPv6 options on a packet connection.
func setIPOpts(c net.PacketConn) (err error) {
// TODO(a.garipov): Returning an error only if both functions return one
// (which is what module dns does as well) seems rather fragile. Depending
// on the OS, the valid errors are ENOPROTOOPT, EINVAL, and maybe others.
// Investigate and make OS-specific versions to make sure we don't miss any
// real errors.
err6 := ipv6.NewPacketConn(c).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true)
err4 := ipv4.NewPacketConn(c).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true)
if err4 != nil && err6 != nil {
return errors.List("setting ipv4 and ipv6 options", err4, err6)
}
return nil
}

View File

@ -0,0 +1,67 @@
//go:build unix
package netext_test
import (
"context"
"syscall"
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
// TestDefaultListenConfigWithOOB checks that the OOB-enabled listen config
// sets SO_REUSEPORT on the sockets of the packet connections it creates.
func TestDefaultListenConfigWithOOB(t *testing.T) {
	lc := netext.DefaultListenConfigWithOOB()
	require.NotNil(t, lc)

	t.Run("ipv4", func(t *testing.T) {
		c, err := lc.ListenPacket(context.Background(), "udp4", "127.0.0.1:0")
		require.NoError(t, err)
		require.NotNil(t, c)

		requireReusePort(t, c)
	})

	t.Run("ipv6", func(t *testing.T) {
		c, err := lc.ListenPacket(context.Background(), "udp6", "[::1]:0")
		if errors.Is(err, syscall.EADDRNOTAVAIL) {
			// Some CI machines have IPv6 disabled.
			t.Skipf("ipv6 seems to not be supported: %s", err)
		}

		require.NoError(t, err)
		require.NotNil(t, c)

		requireReusePort(t, c)
	})
}

// requireReusePort fails the test if the SO_REUSEPORT socket option is not set
// on the underlying socket of c.
func requireReusePort(t *testing.T, c any) {
	t.Helper()

	type syscallConner interface {
		SyscallConn() (c syscall.RawConn, err error)
	}

	require.Implements(t, (*syscallConner)(nil), c)

	sc, err := c.(syscallConner).SyscallConn()
	require.NoError(t, err)

	err = sc.Control(func(fd uintptr) {
		val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT)
		require.NoError(t, opErr)

		// TODO(a.garipov): Rewrite this to use actual expected values for
		// each OS.
		assert.NotEqual(t, 0, val)
	})
	require.NoError(t, err)
}

View File

@ -0,0 +1,17 @@
//go:build windows
package netext
import (
"net"
"syscall"
)
// defaultListenControl is nil on Windows, because it doesn't support
// SO_REUSEPORT.  A nil [net.ListenConfig.Control] function means that no extra
// socket options are set before binding.
var defaultListenControl func(_, _ string, _ syscall.RawConn) (_ error)
// setIPOpts is a no-op on Windows: the IPv4 and IPv6 control-message options
// set on Unix systems are not configured here.
func setIPOpts(_ net.PacketConn) (err error) {
	return nil
}

View File

@ -0,0 +1,69 @@
package netext
import (
"net"
)
// PacketSession contains additional information about a packet read from or
// written to a [SessionPacketConn].
type PacketSession interface {
	// LocalAddr returns the local address of the session.  Implementations
	// may report the packet's original destination address here rather than
	// the listening socket's address.
	LocalAddr() (addr net.Addr)

	// RemoteAddr returns the remote address of the session.
	RemoteAddr() (addr net.Addr)
}
// NewSimplePacketSession returns a new packet session that reports laddr and
// raddr as its local and remote addresses.
func NewSimplePacketSession(laddr, raddr net.Addr) (s PacketSession) {
	sess := &simplePacketSession{}
	sess.laddr = laddr
	sess.raddr = raddr

	return sess
}
// simplePacketSession is a simple implementation of the [PacketSession]
// interface that only stores and reports the two addresses.
type simplePacketSession struct {
	// laddr is the address returned by LocalAddr.
	laddr net.Addr

	// raddr is the address returned by RemoteAddr.
	raddr net.Addr
}

// LocalAddr implements the [PacketSession] interface for *simplePacketSession.
func (s *simplePacketSession) LocalAddr() (addr net.Addr) { return s.laddr }

// RemoteAddr implements the [PacketSession] interface for *simplePacketSession.
func (s *simplePacketSession) RemoteAddr() (addr net.Addr) { return s.raddr }
// SessionPacketConn extends [net.PacketConn] with methods for working with
// packet sessions.
type SessionPacketConn interface {
	net.PacketConn

	// ReadFromSession reads a packet into b, returning the number of bytes
	// read and the session describing the packet's addresses.
	ReadFromSession(b []byte) (n int, s PacketSession, err error)

	// WriteToSession writes b to the remote address described by s.
	WriteToSession(b []byte, s PacketSession) (n int, err error)
}
// ReadFromSession is a convenience wrapper for types that may or may not
// implement [SessionPacketConn].  If c implements it, ReadFromSession uses
// c.ReadFromSession.  Otherwise, it uses c.ReadFrom and the session is
// created by using [NewSimplePacketSession] with c.LocalAddr.
func ReadFromSession(c net.PacketConn, b []byte) (n int, s PacketSession, err error) {
	spc, ok := c.(SessionPacketConn)
	if ok {
		return spc.ReadFromSession(b)
	}

	var raddr net.Addr
	n, raddr, err = c.ReadFrom(b)
	s = NewSimplePacketSession(c.LocalAddr(), raddr)

	return n, s, err
}
// WriteToSession is a convenience wrapper for types that may or may not
// implement [SessionPacketConn].  If c implements it, WriteToSession uses
// c.WriteToSession.  Otherwise, it uses c.WriteTo using s.RemoteAddr.
func WriteToSession(c net.PacketConn, b []byte, s PacketSession) (n int, err error) {
	switch conn := c.(type) {
	case SessionPacketConn:
		return conn.WriteToSession(b, s)
	default:
		return conn.WriteTo(b, s.RemoteAddr())
	}
}

View File

@ -0,0 +1,159 @@
//go:build linux
// TODO(a.garipov): Technically, we can expand this to other platforms, but that
// would require separate udpOOBSize constants and tests.
package netext
import (
"fmt"
"net"
"sync"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
// type check
var _ PacketSession = (*packetSession)(nil)

// packetSession contains additional information about the packet read from a
// UDP connection.  It is basically an extended version of [dns.SessionUDP]
// that contains the local address as well.
type packetSession struct {
	// laddr is the local address; it may carry the original destination
	// address recovered from the OOB data instead of the socket's address.
	laddr *net.UDPAddr

	// raddr is the remote peer's address.
	raddr *net.UDPAddr

	// respOOB is the encoded control-message to attach to the response, or
	// nil when none is needed.
	respOOB []byte
}

// LocalAddr implements the [PacketSession] interface for *packetSession.
func (s *packetSession) LocalAddr() (addr net.Addr) { return s.laddr }

// RemoteAddr implements the [PacketSession] interface for *packetSession.
func (s *packetSession) RemoteAddr() (addr net.Addr) { return s.raddr }
// type check
var _ SessionPacketConn = (*sessionPacketConn)(nil)

// wrapPacketConn wraps c to make it a [SessionPacketConn], if the OS supports
// that.  c is expected to be a [*net.UDPConn]; any other implementation is
// returned unchanged instead of panicking, which mirrors the behavior of the
// non-Linux implementation of this function.
func wrapPacketConn(c net.PacketConn) (wrapped net.PacketConn) {
	udpConn, ok := c.(*net.UDPConn)
	if !ok {
		return c
	}

	return &sessionPacketConn{
		UDPConn: *udpConn,
	}
}
// sessionPacketConn wraps a UDP connection and implements [SessionPacketConn].
type sessionPacketConn struct {
	// UDPConn provides the base packet-connection behavior; the
	// ReadFromSession and WriteToSession methods add OOB-based session
	// handling on top of it.
	net.UDPConn
}
// oobPool is the pool of byte slices for out-of-band data.  Pointers to
// slices are pooled to avoid allocating on Put.
var oobPool = &sync.Pool{
	New: func() (v any) {
		b := make([]byte, IPDstOOBSize)

		return &b
	},
}

// IPDstOOBSize is the required size of the control-message buffer for
// [net.UDPConn.ReadMsgUDP] to read the original destination on Linux.
//
// See packetconn_linux_internal_test.go.
const IPDstOOBSize = 40
// ReadFromSession implements the [SessionPacketConn] interface for
// *sessionPacketConn.  Besides reading the packet into b, it decodes the
// control-message (OOB) data to recover the packet's original destination
// address, which is stored in the returned session.
func (c *sessionPacketConn) ReadFromSession(b []byte) (n int, s PacketSession, err error) {
	// Reuse OOB buffers between reads to avoid an allocation per packet.
	oobPtr := oobPool.Get().(*[]byte)
	defer oobPool.Put(oobPtr)

	var oobn int
	oob := *oobPtr
	ps := &packetSession{}
	n, oobn, _, ps.raddr, err = c.ReadMsgUDP(b, oob)
	if err != nil {
		return 0, nil, err
	}

	var origDstIP net.IP
	sockLAddr := c.LocalAddr().(*net.UDPAddr)
	origDstIP, err = origLAddr(oob[:oobn])
	if err != nil {
		return 0, nil, fmt.Errorf("getting original addr: %w", err)
	}

	if origDstIP == nil {
		// No original destination in the OOB data; fall back to the socket's
		// own local address.
		ps.laddr = sockLAddr
	} else {
		// Prepare the response control-message up front so that replies are
		// sent from the address the client actually contacted.
		ps.respOOB = newRespOOB(origDstIP)
		ps.laddr = &net.UDPAddr{
			IP:   origDstIP,
			Port: sockLAddr.Port,
		}
	}

	return n, ps, nil
}
// origLAddr returns the original local address from the encoded
// control-message data, if there is one.  If not nil, origDst will have a
// protocol-appropriate length.
func origLAddr(oob []byte) (origDst net.IP, err error) {
	// Try parsing the data as an IPv6 control-message first.
	ctrlMsg6 := &ipv6.ControlMessage{}
	err = ctrlMsg6.Parse(oob)
	if err != nil {
		return nil, fmt.Errorf("parsing ipv6 control message: %w", err)
	}

	if dst := ctrlMsg6.Dst; dst != nil {
		// Linux maps IPv4 addresses to IPv6 ones by default, so we can get an
		// IPv4 dst from an IPv6 control-message.  Normalize it to the 4-byte
		// form when possible.
		origDst = dst.To4()
		if origDst == nil {
			origDst = dst
		}

		return origDst, nil
	}

	// No IPv6 destination; fall back to parsing the data as an IPv4
	// control-message.
	ctrlMsg4 := &ipv4.ControlMessage{}
	err = ctrlMsg4.Parse(oob)
	if err != nil {
		return nil, fmt.Errorf("parsing ipv4 control message: %w", err)
	}

	// Dst here may be nil as well; To4 on a nil IP returns nil, which the
	// caller treats as "no original address".
	return ctrlMsg4.Dst.To4(), nil
}
// newRespOOB returns an encoded control-message for the response for this IP
// address.  origDst is expected to have a protocol-appropriate length; for
// any other length, nil is returned.
func newRespOOB(origDst net.IP) (b []byte) {
	ipLen := len(origDst)
	if ipLen == net.IPv4len {
		return (&ipv4.ControlMessage{
			Src: origDst,
		}).Marshal()
	}

	if ipLen == net.IPv6len {
		return (&ipv6.ControlMessage{
			Src: origDst,
		}).Marshal()
	}

	return nil
}
// WriteToSession implements the [SessionPacketConn] interface for
// *sessionPacketConn.
func (c *sessionPacketConn) WriteToSession(b []byte, s PacketSession) (n int, err error) {
	ps, ok := s.(*packetSession)
	if !ok {
		return c.WriteTo(b, s.RemoteAddr())
	}

	n, _, err = c.WriteMsgUDP(b, ps.respOOB, ps.raddr)

	return n, err
}

View File

@ -0,0 +1,25 @@
//go:build linux
package netext
import (
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
// TestUDPOOBSize checks that IPDstOOBSize is the larger of the IPv4 and IPv6
// control-message sizes.
func TestUDPOOBSize(t *testing.T) {
	// See https://github.com/miekg/dns/blob/v1.1.50/udp.go.
	sizes := []int{
		len(ipv4.NewControlMessage(ipv4.FlagDst | ipv4.FlagInterface)),
		len(ipv6.NewControlMessage(ipv6.FlagDst | ipv6.FlagInterface)),
	}

	want := sizes[0]
	for _, sz := range sizes[1:] {
		if sz > want {
			want = sz
		}
	}

	assert.Equal(t, want, IPDstOOBSize)
}

View File

@ -0,0 +1,140 @@
//go:build linux
package netext_test
import (
"context"
"fmt"
"net"
"os"
"syscall"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestSessionPacketConn checks that connections opened through
// [netext.DefaultListenConfigWithOOB] support session-based reads and writes
// for both IPv4 and IPv6.
func TestSessionPacketConn(t *testing.T) {
	const numTries = 5

	// Try the test multiple times to reduce flakiness due to UDP failures.
	var success4, success6 bool
	for i := 0; i < numTries; i++ {
		var isTimeout4, isTimeout6 bool
		success4 = t.Run(fmt.Sprintf("ipv4_%d", i), func(t *testing.T) {
			isTimeout4 = testSessionPacketConn(t, "udp4", "0.0.0.0:0", net.IP{127, 0, 0, 1})
		})

		success6 = t.Run(fmt.Sprintf("ipv6_%d", i), func(t *testing.T) {
			isTimeout6 = testSessionPacketConn(t, "udp6", "[::]:0", net.IPv6loopback)
		})

		if success4 && success6 {
			break
		} else if isTimeout4 || isTimeout6 {
			continue
		}

		t.Fail()
	}

	// Report each address family independently, so that a failure of both is
	// not reported as an IPv4-only failure.
	if !success4 {
		t.Errorf("ipv4 test failed after %d attempts", numTries)
	}

	if !success6 {
		t.Errorf("ipv6 test failed after %d attempts", numTries)
	}
}
// testSessionPacketConn is a single attempt at exercising a session packet
// connection: it listens on addr with the given proto ("udp4" or "udp6"),
// dials it from a client socket at dstIP, and checks that a request and a
// response round trip through [netext.ReadFromSession] and
// [netext.WriteToSession] report the expected session addresses.  It is
// intended to return true when an operation failed with a deadline timeout,
// so that the caller can retry.
func testSessionPacketConn(t *testing.T, proto, addr string, dstIP net.IP) (isTimeout bool) {
	lc := netext.DefaultListenConfigWithOOB()
	require.NotNil(t, lc)

	c, err := lc.ListenPacket(context.Background(), proto, addr)
	if isTimeoutOrFail(t, err) {
		return true
	}

	require.NotNil(t, c)

	// Bound every network operation below by a single shared deadline.
	deadline := time.Now().Add(1 * time.Second)
	err = c.SetDeadline(deadline)
	require.NoError(t, err)

	laddr := testutil.RequireTypeAssert[*net.UDPAddr](t, c.LocalAddr())
	require.NotNil(t, laddr)

	// Dial the loopback destination at the port the listener got, so that the
	// OOB data carries dstAddr as the original destination.
	dstAddr := &net.UDPAddr{
		IP:   dstIP,
		Port: laddr.Port,
	}
	remoteConn, err := net.DialUDP(proto, nil, dstAddr)
	if proto == "udp6" && errors.Is(err, syscall.EADDRNOTAVAIL) {
		// Some CI machines have IPv6 disabled.
		t.Skipf("ipv6 seems to not be supported: %s", err)
	} else if isTimeoutOrFail(t, err) {
		return true
	}

	err = remoteConn.SetDeadline(deadline)
	require.NoError(t, err)

	// Send the request from the client side.
	msg := []byte("hello")
	msgLen := len(msg)
	_, err = remoteConn.Write(msg)
	if isTimeoutOrFail(t, err) {
		return true
	}

	require.Implements(t, (*netext.SessionPacketConn)(nil), c)

	// Read it back on the server side and check the session addresses: the
	// local address must be the dialed destination, not the wildcard socket
	// address.
	buf := make([]byte, msgLen)
	n, sess, err := netext.ReadFromSession(c, buf)
	if isTimeoutOrFail(t, err) {
		return true
	}

	assert.Equal(t, msgLen, n)
	assert.Equal(t, net.Addr(dstAddr), sess.LocalAddr())
	assert.Equal(t, remoteConn.LocalAddr(), sess.RemoteAddr())
	assert.Equal(t, msg, buf)

	// Write the response through the session and check that the client
	// receives it intact.
	respMsg := []byte("world")
	respMsgLen := len(respMsg)
	n, err = netext.WriteToSession(c, respMsg, sess)
	if isTimeoutOrFail(t, err) {
		return true
	}

	assert.Equal(t, respMsgLen, n)

	buf = make([]byte, respMsgLen)
	n, err = remoteConn.Read(buf)
	if isTimeoutOrFail(t, err) {
		return true
	}

	assert.Equal(t, respMsgLen, n)
	assert.Equal(t, respMsg, buf)

	return false
}
// isTimeoutOrFail is a helper function that returns true if err is a timeout
// error and also calls require.NoError on err.
//
// NOTE(review): require.NoError fails the test via t.FailNow for any non-nil
// err, and FailNow stops the calling goroutine with runtime.Goexit.  Since it
// is deferred here, it runs after the return value is computed but before the
// caller resumes, so the caller never actually observes a true return for a
// timeout — confirm this matches the intended retry semantics in
// TestSessionPacketConn.
func isTimeoutOrFail(t *testing.T, err error) (ok bool) {
	t.Helper()
	if err == nil {
		return false
	}

	// Deferred so that it runs after the named return value is set.
	defer require.NoError(t, err)

	return errors.Is(err, os.ErrDeadlineExceeded)
}

View File

@ -0,0 +1,11 @@
//go:build !linux
package netext
import "net"
// wrapPacketConn wraps c to make it a [SessionPacketConn], if the OS supports
// that.  On non-Linux systems there is no such support, so c is returned
// unchanged.
func wrapPacketConn(c net.PacketConn) (wrapped net.PacketConn) {
	return c
}

View File

@ -27,7 +27,8 @@ func TestRateLimiterMetricsListener_integration_cache(t *testing.T) {
Duration: time.Minute, Duration: time.Minute,
Count: rps, Count: rps,
ResponseSizeEstimate: 1000, ResponseSizeEstimate: 1000,
RPS: rps, IPv4RPS: rps,
IPv6RPS: rps,
RefuseANY: true, RefuseANY: true,
}) })
rlMw, err := ratelimit.NewMiddleware(rl, nil) rlMw, err := ratelimit.NewMiddleware(rl, nil)

View File

@ -37,15 +37,22 @@ type BackOffConfig struct {
// as several responses. // as several responses.
ResponseSizeEstimate int ResponseSizeEstimate int
// RPS is the maximum number of requests per second allowed from a single // IPv4RPS is the maximum number of requests per second allowed from a
// subnet. Any requests above this rate are counted as the client's // single subnet for IPv4 addresses. Any requests above this rate are
// back-off count. RPS must be greater than zero. // counted as the client's back-off count. RPS must be greater than
RPS int // zero.
IPv4RPS int
// IPv4SubnetKeyLen is the length of the subnet prefix used to calculate // IPv4SubnetKeyLen is the length of the subnet prefix used to calculate
// rate limiter bucket keys for IPv4 addresses. Must be greater than zero. // rate limiter bucket keys for IPv4 addresses. Must be greater than zero.
IPv4SubnetKeyLen int IPv4SubnetKeyLen int
// IPv6RPS is the maximum number of requests per second allowed from a
// single subnet for IPv6 addresses. Any requests above this rate are
// counted as the client's back-off count. RPS must be greater than
// zero.
IPv6RPS int
// IPv6SubnetKeyLen is the length of the subnet prefix used to calculate // IPv6SubnetKeyLen is the length of the subnet prefix used to calculate
// rate limiter bucket keys for IPv6 addresses. Must be greater than zero. // rate limiter bucket keys for IPv6 addresses. Must be greater than zero.
IPv6SubnetKeyLen int IPv6SubnetKeyLen int
@ -62,14 +69,18 @@ type BackOffConfig struct {
// current implementation might be too abstract. Middlewares by themselves // current implementation might be too abstract. Middlewares by themselves
// already provide an interface that can be re-implemented by the users. // already provide an interface that can be re-implemented by the users.
// Perhaps, another layer of abstraction is unnecessary. // Perhaps, another layer of abstraction is unnecessary.
//
// TODO(ameshkov): Consider splitting rps and other properties by protocol
// family.
type BackOff struct { type BackOff struct {
rpsCounters *cache.Cache rpsCounters *cache.Cache
hitCounters *cache.Cache hitCounters *cache.Cache
allowlist Allowlist allowlist Allowlist
count int count int
rps int
respSzEst int respSzEst int
ipv4rps int
ipv4SubnetKeyLen int ipv4SubnetKeyLen int
ipv6rps int
ipv6SubnetKeyLen int ipv6SubnetKeyLen int
refuseANY bool refuseANY bool
} }
@ -84,9 +95,10 @@ func NewBackOff(c *BackOffConfig) (l *BackOff) {
hitCounters: cache.New(c.Duration, c.Duration), hitCounters: cache.New(c.Duration, c.Duration),
allowlist: c.Allowlist, allowlist: c.Allowlist,
count: c.Count, count: c.Count,
rps: c.RPS,
respSzEst: c.ResponseSizeEstimate, respSzEst: c.ResponseSizeEstimate,
ipv4rps: c.IPv4RPS,
ipv4SubnetKeyLen: c.IPv4SubnetKeyLen, ipv4SubnetKeyLen: c.IPv4SubnetKeyLen,
ipv6rps: c.IPv6RPS,
ipv6SubnetKeyLen: c.IPv6SubnetKeyLen, ipv6SubnetKeyLen: c.IPv6SubnetKeyLen,
refuseANY: c.RefuseANY, refuseANY: c.RefuseANY,
} }
@ -124,7 +136,12 @@ func (l *BackOff) IsRateLimited(
return true, false, nil return true, false, nil
} }
return l.hasHitRateLimit(key), false, nil rps := l.ipv4rps
if ip.Is6() {
rps = l.ipv6rps
}
return l.hasHitRateLimit(key, rps), false, nil
} }
// validateAddr returns an error if addr is not a valid IPv4 or IPv6 address. // validateAddr returns an error if addr is not a valid IPv4 or IPv6 address.
@ -184,14 +201,15 @@ func (l *BackOff) incBackOff(key string) {
l.hitCounters.SetDefault(key, counter) l.hitCounters.SetDefault(key, counter)
} }
// hasHitRateLimit checks value for a subnet. // hasHitRateLimit checks value for a subnet with rps as a maximum number
func (l *BackOff) hasHitRateLimit(subnetIPStr string) (ok bool) { // requests per second.
var r *rps func (l *BackOff) hasHitRateLimit(subnetIPStr string, rps int) (ok bool) {
var r *rpsCounter
rVal, ok := l.rpsCounters.Get(subnetIPStr) rVal, ok := l.rpsCounters.Get(subnetIPStr)
if ok { if ok {
r = rVal.(*rps) r = rVal.(*rpsCounter)
} else { } else {
r = newRPS(l.rps) r = newRPSCounter(rps)
l.rpsCounters.SetDefault(subnetIPStr, r) l.rpsCounters.SetDefault(subnetIPStr, r)
} }

View File

@ -98,8 +98,9 @@ func TestRatelimitMiddleware(t *testing.T) {
Duration: time.Minute, Duration: time.Minute,
Count: rps, Count: rps,
ResponseSizeEstimate: 128, ResponseSizeEstimate: 128,
RPS: rps, IPv4RPS: rps,
IPv4SubnetKeyLen: 24, IPv4SubnetKeyLen: 24,
IPv6RPS: rps,
IPv6SubnetKeyLen: 48, IPv6SubnetKeyLen: 48,
RefuseANY: true, RefuseANY: true,
}) })

View File

@ -7,17 +7,17 @@ import (
// Requests Per Second Counter // Requests Per Second Counter
// rps is a single request per seconds counter. // rpsCounter is a single request per seconds counter.
type rps struct { type rpsCounter struct {
// mu protects all fields. // mu protects all fields.
mu *sync.Mutex mu *sync.Mutex
ring []int64 ring []int64
idx int idx int
} }
// newRPS returns a new requests per second counter. n must be above zero. // newRPSCounter returns a new requests per second counter. n must be above zero.
func newRPS(n int) (r *rps) { func newRPSCounter(n int) (r *rpsCounter) {
return &rps{ return &rpsCounter{
mu: &sync.Mutex{}, mu: &sync.Mutex{},
// Add one, because we need to always keep track of the previous // Add one, because we need to always keep track of the previous
// request. For example, consider n == 1. // request. For example, consider n == 1.
@ -28,7 +28,7 @@ func newRPS(n int) (r *rps) {
// add adds another request to the counter. above is true if the request goes // add adds another request to the counter. above is true if the request goes
// above the counter value. It is safe for concurrent use. // above the counter value. It is safe for concurrent use.
func (r *rps) add(t time.Time) (above bool) { func (r *rpsCounter) add(t time.Time) (above bool) {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()

View File

@ -8,6 +8,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -31,16 +32,19 @@ type ConfigBase struct {
// Handler is a handler that processes incoming DNS messages. // Handler is a handler that processes incoming DNS messages.
// If not set, we'll use the default handler that returns error response // If not set, we'll use the default handler that returns error response
// to any query. // to any query.
Handler Handler Handler Handler
// Metrics is the object we use for collecting performance metrics. // Metrics is the object we use for collecting performance metrics.
// This field is optional. // This field is optional.
Metrics MetricsListener Metrics MetricsListener
// BaseContext is a function that should return the base context. If not // BaseContext is a function that should return the base context. If not
// set, we'll be using context.Background(). // set, we'll be using context.Background().
BaseContext func() (ctx context.Context) BaseContext func() (ctx context.Context)
// ListenConfig, when set, is used to set options of connections used by the
// DNS server. If nil, an appropriate default ListenConfig is used.
ListenConfig netext.ListenConfig
} }
// ServerBase implements base methods that every Server implementation uses. // ServerBase implements base methods that every Server implementation uses.
@ -67,14 +71,19 @@ type ServerBase struct {
// metrics is the object we use for collecting performance metrics. // metrics is the object we use for collecting performance metrics.
metrics MetricsListener metrics MetricsListener
// listenConfig is used to set tcpListener and udpListener.
listenConfig netext.ListenConfig
// Server operation // Server operation
// -- // --
// will be nil for servers that don't use TCP. // tcpListener is used to accept new TCP connections. It is nil for servers
// that don't use TCP.
tcpListener net.Listener tcpListener net.Listener
// will be nil for servers that don't use UDP. // udpListener is used to accept new UDP messages. It is nil for servers
udpListener *net.UDPConn // that don't use UDP.
udpListener net.PacketConn
// Shutdown handling // Shutdown handling
// -- // --
@ -94,13 +103,14 @@ var _ Server = (*ServerBase)(nil)
// some of its internal properties. // some of its internal properties.
func newServerBase(proto Protocol, conf ConfigBase) (s *ServerBase) { func newServerBase(proto Protocol, conf ConfigBase) (s *ServerBase) {
s = &ServerBase{ s = &ServerBase{
name: conf.Name, name: conf.Name,
addr: conf.Addr, addr: conf.Addr,
proto: proto, proto: proto,
network: conf.Network, network: conf.Network,
handler: conf.Handler, handler: conf.Handler,
metrics: conf.Metrics, metrics: conf.Metrics,
baseContext: conf.BaseContext, listenConfig: conf.ListenConfig,
baseContext: conf.BaseContext,
} }
if s.baseContext == nil { if s.baseContext == nil {
@ -347,18 +357,17 @@ func (s *ServerBase) handlePanicAndRecover(ctx context.Context) {
} }
} }
// listenUDP creates a UDP listener for the ServerBase.addr. This function will // listenUDP initializes and starts s.udpListener using s.addr. If the TCP
// initialize and start ServerBase.udpListener or return an error. If the TCP // listener is already running, its address is used instead to properly handle
// listener is already running, its address is used instead. The point of this // the case when port 0 is used as both listeners should use the same port, and
// is to properly handle the case when port 0 is used as both listeners should // we only learn it after the first one was started.
// use the same port, and we only learn it after the first one was started.
func (s *ServerBase) listenUDP(ctx context.Context) (err error) { func (s *ServerBase) listenUDP(ctx context.Context) (err error) {
addr := s.addr addr := s.addr
if s.tcpListener != nil { if s.tcpListener != nil {
addr = s.tcpListener.Addr().String() addr = s.tcpListener.Addr().String()
} }
conn, err := listenUDP(ctx, addr, true) conn, err := s.listenConfig.ListenPacket(ctx, "udp", addr)
if err != nil { if err != nil {
return err return err
} }
@ -368,19 +377,17 @@ func (s *ServerBase) listenUDP(ctx context.Context) (err error) {
return nil return nil
} }
// listenTCP creates a TCP listener for the ServerBase.addr. This function will // listenTCP initializes and starts s.tcpListener using s.addr. If the UDP
// initialize and start ServerBase.tcpListener or return an error. If the UDP // listener is already running, its address is used instead to properly handle
// listener is already running, its address is used instead. The point of this // the case when port 0 is used as both listeners should use the same port, and
// is to properly handle the case when port 0 is used as both listeners should // we only learn it after the first one was started.
// use the same port, and we only learn it after the first one was started.
func (s *ServerBase) listenTCP(ctx context.Context) (err error) { func (s *ServerBase) listenTCP(ctx context.Context) (err error) {
addr := s.addr addr := s.addr
if s.udpListener != nil { if s.udpListener != nil {
addr = s.udpListener.LocalAddr().String() addr = s.udpListener.LocalAddr().String()
} }
var l net.Listener l, err := s.listenConfig.Listen(ctx, "tcp", addr)
l, err = listenTCP(ctx, addr)
if err != nil { if err != nil {
return err return err
} }

View File

@ -6,6 +6,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -110,6 +111,10 @@ func newServerDNS(proto Protocol, conf ConfigDNS) (s *ServerDNS) {
conf.TCPSize = dns.MinMsgSize conf.TCPSize = dns.MinMsgSize
} }
if conf.ListenConfig == nil {
conf.ListenConfig = netext.DefaultListenConfigWithOOB()
}
s = &ServerDNS{ s = &ServerDNS{
ServerBase: newServerBase(proto, conf.ConfigBase), ServerBase: newServerBase(proto, conf.ConfigBase),
conf: conf, conf: conf,
@ -262,10 +267,21 @@ func makePacketBuffer(size int) (f func() any) {
} }
} }
// writeDeadlineSetter is an interface for connections that can set write
// deadlines.
type writeDeadlineSetter interface {
SetWriteDeadline(t time.Time) (err error)
}
// withWriteDeadline is a helper that takes the deadline of the context and the // withWriteDeadline is a helper that takes the deadline of the context and the
// write timeout into account. It sets the write deadline on conn before // write timeout into account. It sets the write deadline on conn before
// calling f and resets it once f is done. // calling f and resets it once f is done.
func withWriteDeadline(ctx context.Context, writeTimeout time.Duration, conn net.Conn, f func()) { func withWriteDeadline(
ctx context.Context,
writeTimeout time.Duration,
conn writeDeadlineSetter,
f func(),
) {
dl, hasDeadline := ctx.Deadline() dl, hasDeadline := ctx.Deadline()
if !hasDeadline { if !hasDeadline {
dl = time.Now().Add(writeTimeout) dl = time.Now().Add(writeTimeout)

View File

@ -284,8 +284,8 @@ func TestServerDNS_integration_query(t *testing.T) {
tc.expectedTruncated, tc.expectedTruncated,
) )
reqKeepAliveOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_TCP_KEEPALIVE](tc.req) reqKeepAliveOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_TCP_KEEPALIVE](tc.req)
respKeepAliveOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_TCP_KEEPALIVE](resp) respKeepAliveOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_TCP_KEEPALIVE](resp)
if tc.network == dnsserver.NetworkTCP && reqKeepAliveOpt != nil { if tc.network == dnsserver.NetworkTCP && reqKeepAliveOpt != nil {
require.NotNil(t, respKeepAliveOpt) require.NotNil(t, respKeepAliveOpt)
expectedTimeout := uint16(dnsserver.DefaultTCPIdleTimeout.Milliseconds() / 100) expectedTimeout := uint16(dnsserver.DefaultTCPIdleTimeout.Milliseconds() / 100)

View File

@ -2,7 +2,9 @@ package dnsserver
import ( import (
"context" "context"
"net"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/ameshkov/dnscrypt/v2" "github.com/ameshkov/dnscrypt/v2"
@ -38,6 +40,10 @@ var _ Server = (*ServerDNSCrypt)(nil)
// NewServerDNSCrypt creates a new instance of ServerDNSCrypt. // NewServerDNSCrypt creates a new instance of ServerDNSCrypt.
func NewServerDNSCrypt(conf ConfigDNSCrypt) (s *ServerDNSCrypt) { func NewServerDNSCrypt(conf ConfigDNSCrypt) (s *ServerDNSCrypt) {
if conf.ListenConfig == nil {
conf.ListenConfig = netext.DefaultListenConfig()
}
return &ServerDNSCrypt{ return &ServerDNSCrypt{
ServerBase: newServerBase(ProtoDNSCrypt, conf.ConfigBase), ServerBase: newServerBase(ProtoDNSCrypt, conf.ConfigBase),
conf: conf, conf: conf,
@ -140,7 +146,10 @@ func (s *ServerDNSCrypt) startServeUDP(ctx context.Context) {
// TODO(ameshkov): Add context to the ServeTCP and ServeUDP methods in // TODO(ameshkov): Add context to the ServeTCP and ServeUDP methods in
// dnscrypt/v3. Or at least add ServeTCPContext and ServeUDPContext // dnscrypt/v3. Or at least add ServeTCPContext and ServeUDPContext
// methods for now. // methods for now.
err := s.dnsCryptServer.ServeUDP(s.udpListener) //
// TODO(ameshkov): Redo the dnscrypt module to make it not depend on
// *net.UDPConn and use net.PacketConn instead.
err := s.dnsCryptServer.ServeUDP(s.udpListener.(*net.UDPConn))
if err != nil { if err != nil {
log.Info("[%s]: Finished listening to udp://%s due to %v", s.Name(), s.Addr(), err) log.Info("[%s]: Finished listening to udp://%s due to %v", s.Name(), s.Addr(), err)
} }

View File

@ -4,23 +4,21 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"runtime"
"time" "time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
) )
// serveUDP runs the UDP serving loop. // serveUDP runs the UDP serving loop.
func (s *ServerDNS) serveUDP(ctx context.Context, conn *net.UDPConn) (err error) { func (s *ServerDNS) serveUDP(ctx context.Context, conn net.PacketConn) (err error) {
defer log.OnCloserError(conn, log.DEBUG) defer log.OnCloserError(conn, log.DEBUG)
for s.isStarted() { for s.isStarted() {
var m []byte var m []byte
var sess *dns.SessionUDP var sess netext.PacketSession
m, sess, err = s.readUDPMsg(ctx, conn) m, sess, err = s.readUDPMsg(ctx, conn)
if err != nil { if err != nil {
// TODO(ameshkov): Consider the situation where the server is shut // TODO(ameshkov): Consider the situation where the server is shut
@ -59,15 +57,15 @@ func (s *ServerDNS) serveUDP(ctx context.Context, conn *net.UDPConn) (err error)
func (s *ServerDNS) serveUDPPacket( func (s *ServerDNS) serveUDPPacket(
ctx context.Context, ctx context.Context,
m []byte, m []byte,
conn *net.UDPConn, conn net.PacketConn,
udpSession *dns.SessionUDP, sess netext.PacketSession,
) { ) {
defer s.wg.Done() defer s.wg.Done()
defer s.handlePanicAndRecover(ctx) defer s.handlePanicAndRecover(ctx)
rw := &udpResponseWriter{ rw := &udpResponseWriter{
udpSession: sess,
conn: conn, conn: conn,
udpSession: udpSession,
writeTimeout: s.conf.WriteTimeout, writeTimeout: s.conf.WriteTimeout,
} }
s.serveDNS(ctx, m, rw) s.serveDNS(ctx, m, rw)
@ -75,15 +73,18 @@ func (s *ServerDNS) serveUDPPacket(
} }
// readUDPMsg reads the next incoming DNS message. // readUDPMsg reads the next incoming DNS message.
func (s *ServerDNS) readUDPMsg(ctx context.Context, conn *net.UDPConn) (msg []byte, sess *dns.SessionUDP, err error) { func (s *ServerDNS) readUDPMsg(
ctx context.Context,
conn net.PacketConn,
) (msg []byte, sess netext.PacketSession, err error) {
err = conn.SetReadDeadline(time.Now().Add(s.conf.ReadTimeout)) err = conn.SetReadDeadline(time.Now().Add(s.conf.ReadTimeout))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
m := s.getUDPBuffer() m := s.getUDPBuffer()
var n int
n, sess, err = dns.ReadFromSessionUDP(conn, m) n, sess, err := netext.ReadFromSession(conn, m)
if err != nil { if err != nil {
s.putUDPBuffer(m) s.putUDPBuffer(m)
@ -120,30 +121,10 @@ func (s *ServerDNS) putUDPBuffer(m []byte) {
s.udpPool.Put(&m) s.udpPool.Put(&m)
} }
// setUDPSocketOptions is a function that is necessary to be able to use
// dns.ReadFromSessionUDP and dns.WriteToSessionUDP.
// TODO(ameshkov): https://github.com/AdguardTeam/AdGuardHome/issues/2807
func setUDPSocketOptions(conn *net.UDPConn) (err error) {
if runtime.GOOS == "windows" {
return nil
}
// We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection.
// Try enabling receiving of ECN and packet info for both IP versions.
// We expect at least one of those syscalls to succeed.
err6 := ipv6.NewPacketConn(conn).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true)
err4 := ipv4.NewPacketConn(conn).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true)
if err4 != nil && err6 != nil {
return errors.List("error while setting NetworkUDP socket options", err4, err6)
}
return nil
}
// udpResponseWriter is a ResponseWriter implementation for DNS-over-UDP. // udpResponseWriter is a ResponseWriter implementation for DNS-over-UDP.
type udpResponseWriter struct { type udpResponseWriter struct {
udpSession *dns.SessionUDP udpSession netext.PacketSession
conn *net.UDPConn conn net.PacketConn
writeTimeout time.Duration writeTimeout time.Duration
} }
@ -152,13 +133,15 @@ var _ ResponseWriter = (*udpResponseWriter)(nil)
// LocalAddr implements the ResponseWriter interface for *udpResponseWriter. // LocalAddr implements the ResponseWriter interface for *udpResponseWriter.
func (r *udpResponseWriter) LocalAddr() (addr net.Addr) { func (r *udpResponseWriter) LocalAddr() (addr net.Addr) {
return r.conn.LocalAddr() // Don't use r.conn.LocalAddr(), since udpSession may actually contain the
// decoded OOB data, including the real local (dst) address.
return r.udpSession.LocalAddr()
} }
// RemoteAddr implements the ResponseWriter interface for *udpResponseWriter. // RemoteAddr implements the ResponseWriter interface for *udpResponseWriter.
func (r *udpResponseWriter) RemoteAddr() (addr net.Addr) { func (r *udpResponseWriter) RemoteAddr() (addr net.Addr) {
// Don't use r.conn.RemoteAddr(), since udpSession actually contains the // Don't use r.conn.RemoteAddr(), since udpSession may actually contain the
// decoded OOB data, including the remote address. // decoded OOB data, including the real remote (src) address.
return r.udpSession.RemoteAddr() return r.udpSession.RemoteAddr()
} }
@ -173,7 +156,7 @@ func (r *udpResponseWriter) WriteMsg(ctx context.Context, req, resp *dns.Msg) (e
} }
withWriteDeadline(ctx, r.writeTimeout, r.conn, func() { withWriteDeadline(ctx, r.writeTimeout, r.conn, func() {
_, err = dns.WriteToSessionUDP(r.conn, data, r.udpSession) _, err = netext.WriteToSession(r.conn, data, r.udpSession)
}) })
if err != nil { if err != nil {

View File

@ -14,6 +14,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
@ -94,6 +95,12 @@ var _ Server = (*ServerHTTPS)(nil)
// NewServerHTTPS creates a new ServerHTTPS instance. // NewServerHTTPS creates a new ServerHTTPS instance.
func NewServerHTTPS(conf ConfigHTTPS) (s *ServerHTTPS) { func NewServerHTTPS(conf ConfigHTTPS) (s *ServerHTTPS) {
if conf.ListenConfig == nil {
// Do not enable OOB here, because ListenPacket is only used by HTTP/3,
// and quic-go sets the necessary flags.
conf.ListenConfig = netext.DefaultListenConfig()
}
s = &ServerHTTPS{ s = &ServerHTTPS{
ServerBase: newServerBase(ProtoDoH, conf.ConfigBase), ServerBase: newServerBase(ProtoDoH, conf.ConfigBase),
conf: conf, conf: conf,
@ -500,8 +507,7 @@ func (s *ServerHTTPS) listenQUIC(ctx context.Context) (err error) {
tlsConf.NextProtos = nextProtoDoH3 tlsConf.NextProtos = nextProtoDoH3
} }
// Do not enable OOB here as quic-go will do that on its own. conn, err := s.listenConfig.ListenPacket(ctx, "udp", s.addr)
conn, err := listenUDP(ctx, s.addr, false)
if err != nil { if err != nil {
return err return err
} }
@ -518,24 +524,28 @@ func (s *ServerHTTPS) listenQUIC(ctx context.Context) (err error) {
return nil return nil
} }
// httpContextWithClientInfo adds client info to the context. // httpContextWithClientInfo adds client info to the context. ctx is never nil,
// even when there is an error.
func httpContextWithClientInfo( func httpContextWithClientInfo(
parent context.Context, parent context.Context,
r *http.Request, r *http.Request,
) (ctx context.Context, err error) { ) (ctx context.Context, err error) {
ctx = parent
ci := ClientInfo{ ci := ClientInfo{
URL: netutil.CloneURL(r.URL), URL: netutil.CloneURL(r.URL),
} }
// Due to the quic-go bug we should use Host instead of r.TLS: // Due to the quic-go bug we should use Host instead of r.TLS. See
// https://github.com/lucas-clemente/quic-go/issues/3596 // https://github.com/quic-go/quic-go/issues/2879 and
// https://github.com/lucas-clemente/quic-go/issues/3596.
// //
// TODO(ameshkov): remove this when the bug is fixed in quic-go. // TODO(ameshkov): Remove when quic-go is fixed, likely in v0.32.0.
if r.ProtoAtLeast(3, 0) { if r.ProtoAtLeast(3, 0) {
var host string var host string
host, err = netutil.SplitHost(r.Host) host, err = netutil.SplitHost(r.Host)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse Host: %w", err) return ctx, fmt.Errorf("failed to parse Host: %w", err)
} }
ci.TLSServerName = host ci.TLSServerName = host
@ -543,7 +553,7 @@ func httpContextWithClientInfo(
ci.TLSServerName = strings.ToLower(r.TLS.ServerName) ci.TLSServerName = strings.ToLower(r.TLS.ServerName)
} }
return ContextWithClientInfo(parent, ci), nil return ContextWithClientInfo(ctx, ci), nil
} }
// httpRequestToMsg reads the DNS message from http.Request. // httpRequestToMsg reads the DNS message from http.Request.

View File

@ -133,7 +133,7 @@ func TestServerHTTPS_integration_serveRequests(t *testing.T) {
require.True(t, resp.Response) require.True(t, resp.Response)
// EDNS0 padding is only present when request also has padding opt. // EDNS0 padding is only present when request also has padding opt.
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp) paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.Nil(t, paddingOpt) require.Nil(t, paddingOpt)
}) })
} }
@ -338,7 +338,7 @@ func TestServerHTTPS_integration_ENDS0Padding(t *testing.T) {
resp := mustDoHReq(t, addr, tlsConfig, http.MethodGet, false, false, req) resp := mustDoHReq(t, addr, tlsConfig, http.MethodGet, false, false, req)
require.True(t, resp.Response) require.True(t, resp.Response)
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp) paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.NotNil(t, paddingOpt) require.NotNil(t, paddingOpt)
require.NotEmpty(t, paddingOpt.Padding) require.NotEmpty(t, paddingOpt.Padding)
} }

View File

@ -12,6 +12,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/bluele/gcache" "github.com/bluele/gcache"
@ -98,6 +99,11 @@ func NewServerQUIC(conf ConfigQUIC) (s *ServerQUIC) {
tlsConfig.NextProtos = append([]string{nextProtoDoQ}, compatProtoDQ...) tlsConfig.NextProtos = append([]string{nextProtoDoQ}, compatProtoDQ...)
} }
if conf.ListenConfig == nil {
// Do not enable OOB here as quic-go will do that on its own.
conf.ListenConfig = netext.DefaultListenConfig()
}
s = &ServerQUIC{ s = &ServerQUIC{
ServerBase: newServerBase(ProtoDoQ, conf.ConfigBase), ServerBase: newServerBase(ProtoDoQ, conf.ConfigBase),
conf: conf, conf: conf,
@ -476,8 +482,7 @@ func (s *ServerQUIC) readQUICMsg(
// listenQUIC creates the UDP listener for the ServerQUIC.addr and also starts // listenQUIC creates the UDP listener for the ServerQUIC.addr and also starts
// the QUIC listener. // the QUIC listener.
func (s *ServerQUIC) listenQUIC(ctx context.Context) (err error) { func (s *ServerQUIC) listenQUIC(ctx context.Context) (err error) {
// Do not enable OOB here as quic-go will do that on its own. conn, err := s.listenConfig.ListenPacket(ctx, "udp", s.addr)
conn, err := listenUDP(ctx, s.addr, false)
if err != nil { if err != nil {
return err return err
} }

View File

@ -59,7 +59,7 @@ func TestServerQUIC_integration_query(t *testing.T) {
assert.True(t, resp.Response) assert.True(t, resp.Response)
// EDNS0 padding is only present when request also has padding opt. // EDNS0 padding is only present when request also has padding opt.
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp) paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.Nil(t, paddingOpt) require.Nil(t, paddingOpt)
}() }()
} }
@ -97,7 +97,7 @@ func TestServerQUIC_integration_ENDS0Padding(t *testing.T) {
require.True(t, resp.Response) require.True(t, resp.Response)
require.False(t, resp.Truncated) require.False(t, resp.Truncated)
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp) paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.NotNil(t, paddingOpt) require.NotNil(t, paddingOpt)
require.NotEmpty(t, paddingOpt.Padding) require.NotEmpty(t, paddingOpt.Padding)
} }

View File

@ -3,7 +3,6 @@ package dnsserver
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"net"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -102,10 +101,9 @@ func (s *ServerTLS) startServeTCP(ctx context.Context) {
} }
} }
// listenTLS creates the TLS listener for the ServerTLS.addr. // listenTLS creates the TLS listener for s.addr.
func (s *ServerTLS) listenTLS(ctx context.Context) (err error) { func (s *ServerTLS) listenTLS(ctx context.Context) (err error) {
var l net.Listener l, err := s.listenConfig.Listen(ctx, "tcp", s.addr)
l, err = listenTCP(ctx, s.addr)
if err != nil { if err != nil {
return err return err
} }

View File

@ -40,7 +40,7 @@ func TestServerTLS_integration_queryTLS(t *testing.T) {
require.False(t, resp.Truncated) require.False(t, resp.Truncated)
// EDNS0 padding is only present when request also has padding opt. // EDNS0 padding is only present when request also has padding opt.
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp) paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.Nil(t, paddingOpt) require.Nil(t, paddingOpt)
} }
@ -142,7 +142,7 @@ func TestServerTLS_integration_noTruncateQuery(t *testing.T) {
require.False(t, resp.Truncated) require.False(t, resp.Truncated)
// EDNS0 padding is only present when request also has padding opt. // EDNS0 padding is only present when request also has padding opt.
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp) paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.Nil(t, paddingOpt) require.Nil(t, paddingOpt)
} }
@ -231,7 +231,7 @@ func TestServerTLS_integration_ENDS0Padding(t *testing.T) {
require.True(t, resp.Response) require.True(t, resp.Response)
require.False(t, resp.Truncated) require.False(t, resp.Truncated)
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp) paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.NotNil(t, paddingOpt) require.NotNil(t, paddingOpt)
require.NotEmpty(t, paddingOpt.Padding) require.NotEmpty(t, paddingOpt.Padding)
} }

View File

@ -0,0 +1,172 @@
package dnssvc
import (
"context"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newTXTExtra converts strs into TXT resource records, using the first
// element of each tuple as the record name and the second one as the TXT
// data.  The records are returned in the CHAOS class with a TTL of one
// second, matching the debug-response convention.
func newTXTExtra(strs [][2]string) (extra []dns.RR) {
	for _, pair := range strs {
		hdr := dns.RR_Header{
			Name:   pair[0],
			Rrtype: dns.TypeTXT,
			Class:  dns.ClassCHAOS,
			Ttl:    1,
		}

		extra = append(extra, &dns.TXT{
			Hdr: hdr,
			Txt: []string{pair[1]},
		})
	}

	return extra
}
// TestService_writeDebugResponse checks that the debug response contains the
// expected CHAOS TXT records for each combination of request info and
// request/response filtering results.
func TestService_writeDebugResponse(t *testing.T) {
	// Only the fields that writeDebugResponse reads are required here.
	svc := &Service{messages: &dnsmsg.Constructor{FilteredResponseTTL: time.Second}}

	const (
		fltListID1 agd.FilterListID = "fl1"
		fltListID2 agd.FilterListID = "fl2"

		blockRule = "||example.com^"
	)

	testCases := []struct {
		name      string
		ri        *agd.RequestInfo
		reqRes    filter.Result
		respRes   filter.Result
		wantExtra []dns.RR
	}{{
		// No filtering at all: only the client IP and the "normal"
		// response-result marker are reported.
		name:    "normal",
		ri:      &agd.RequestInfo{},
		reqRes:  nil,
		respRes: nil,
		wantExtra: newTXTExtra([][2]string{
			{"client-ip.adguard-dns.com.", "1.2.3.4"},
			{"resp.res-type.adguard-dns.com.", "normal"},
		}),
	}, {
		// Blocking the request adds the rule and the rule-list ID under the
		// "req." prefix.
		name:    "request_result_blocked",
		ri:      &agd.RequestInfo{},
		reqRes:  &filter.ResultBlocked{List: fltListID1, Rule: blockRule},
		respRes: nil,
		wantExtra: newTXTExtra([][2]string{
			{"client-ip.adguard-dns.com.", "1.2.3.4"},
			{"req.res-type.adguard-dns.com.", "blocked"},
			{"req.rule.adguard-dns.com.", "||example.com^"},
			{"req.rule-list-id.adguard-dns.com.", "fl1"},
		}),
	}, {
		// Blocking the response reports the same data under the "resp."
		// prefix instead.
		name:    "response_result_blocked",
		ri:      &agd.RequestInfo{},
		reqRes:  nil,
		respRes: &filter.ResultBlocked{List: fltListID2, Rule: blockRule},
		wantExtra: newTXTExtra([][2]string{
			{"client-ip.adguard-dns.com.", "1.2.3.4"},
			{"resp.res-type.adguard-dns.com.", "blocked"},
			{"resp.rule.adguard-dns.com.", "||example.com^"},
			{"resp.rule-list-id.adguard-dns.com.", "fl2"},
		}),
	}, {
		// An "allowed" result with no rule data still emits the rule and
		// rule-list-id records, just with empty values.
		name:    "request_result_allowed",
		ri:      &agd.RequestInfo{},
		reqRes:  &filter.ResultAllowed{},
		respRes: nil,
		wantExtra: newTXTExtra([][2]string{
			{"client-ip.adguard-dns.com.", "1.2.3.4"},
			{"req.res-type.adguard-dns.com.", "allowed"},
			{"req.rule.adguard-dns.com.", ""},
			{"req.rule-list-id.adguard-dns.com.", ""},
		}),
	}, {
		name:    "response_result_allowed",
		ri:      &agd.RequestInfo{},
		reqRes:  nil,
		respRes: &filter.ResultAllowed{},
		wantExtra: newTXTExtra([][2]string{
			{"client-ip.adguard-dns.com.", "1.2.3.4"},
			{"resp.res-type.adguard-dns.com.", "allowed"},
			{"resp.rule.adguard-dns.com.", ""},
			{"resp.rule-list-id.adguard-dns.com.", ""},
		}),
	}, {
		// A $dnsrewrite modification reports the full rewrite rule.
		name: "request_result_modified",
		ri:   &agd.RequestInfo{},
		reqRes: &filter.ResultModified{
			Rule: "||example.com^$dnsrewrite=REFUSED",
		},
		respRes: nil,
		wantExtra: newTXTExtra([][2]string{
			{"client-ip.adguard-dns.com.", "1.2.3.4"},
			{"req.res-type.adguard-dns.com.", "modified"},
			{"req.rule.adguard-dns.com.", "||example.com^$dnsrewrite=REFUSED"},
			{"req.rule-list-id.adguard-dns.com.", ""},
		}),
	}, {
		// Device, profile, and location data, when present in the request
		// info, each add their own record.
		name:    "device",
		ri:      &agd.RequestInfo{Device: &agd.Device{ID: "dev1234"}},
		reqRes:  nil,
		respRes: nil,
		wantExtra: newTXTExtra([][2]string{
			{"client-ip.adguard-dns.com.", "1.2.3.4"},
			{"device-id.adguard-dns.com.", "dev1234"},
			{"resp.res-type.adguard-dns.com.", "normal"},
		}),
	}, {
		name: "profile",
		ri: &agd.RequestInfo{
			Profile: &agd.Profile{ID: agd.ProfileID("some-profile-id")},
		},
		reqRes:  nil,
		respRes: nil,
		wantExtra: newTXTExtra([][2]string{
			{"client-ip.adguard-dns.com.", "1.2.3.4"},
			{"profile-id.adguard-dns.com.", "some-profile-id"},
			{"resp.res-type.adguard-dns.com.", "normal"},
		}),
	}, {
		name:    "location",
		ri:      &agd.RequestInfo{Location: &agd.Location{Country: agd.CountryAD}},
		reqRes:  nil,
		respRes: nil,
		wantExtra: newTXTExtra([][2]string{
			{"client-ip.adguard-dns.com.", "1.2.3.4"},
			{"country.adguard-dns.com.", "AD"},
			{"asn.adguard-dns.com.", "0"},
			{"resp.res-type.adguard-dns.com.", "normal"},
		}),
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Capture the written message without any real network I/O.
			rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr)
			ctx := agd.ContextWithRequestInfo(context.Background(), tc.ri)

			req := dnsservertest.NewReq("example.com", dns.TypeA, dns.ClassINET)
			resp := dnsservertest.NewResp(dns.RcodeSuccess, req)

			err := svc.writeDebugResponse(ctx, rw, req, resp, tc.reqRes, tc.respRes)
			require.NoError(t, err)

			msg := rw.Msg()
			require.NotNil(t, msg)

			// The debug data is carried in the additional section.
			assert.Equal(t, tc.wantExtra, msg.Extra)
		})
	}
}

View File

@ -195,7 +195,15 @@ func deviceIDFromEDNS(req *dns.Msg) (id agd.DeviceID, err error) {
continue continue
} }
return agd.NewDeviceID(string(o.Data)) id, err = agd.NewDeviceID(string(o.Data))
if err != nil {
return "", &deviceIDError{
err: err,
typ: "edns option",
}
}
return id, nil
} }
return "", nil return "", nil

View File

@ -233,7 +233,8 @@ func TestService_Wrap_deviceIDFromEDNS(t *testing.T) {
Data: []byte{}, Data: []byte{},
}, },
wantDeviceID: "", wantDeviceID: "",
wantErrMsg: `bad device id "": too short: got 0 bytes, min 1`, wantErrMsg: `edns option device id check: bad device id "": ` +
`too short: got 0 bytes, min 1`,
}, { }, {
name: "bad_device_id", name: "bad_device_id",
opt: &dns.EDNS0_LOCAL{ opt: &dns.EDNS0_LOCAL{
@ -241,7 +242,8 @@ func TestService_Wrap_deviceIDFromEDNS(t *testing.T) {
Data: []byte("toolongdeviceid"), Data: []byte("toolongdeviceid"),
}, },
wantDeviceID: "", wantDeviceID: "",
wantErrMsg: `bad device id "toolongdeviceid": too long: got 15 bytes, max 8`, wantErrMsg: `edns option device id check: bad device id "toolongdeviceid": ` +
`too long: got 15 bytes, max 8`,
}, { }, {
name: "device_id", name: "device_id",
opt: &dns.EDNS0_LOCAL{ opt: &dns.EDNS0_LOCAL{

View File

@ -9,7 +9,6 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/netip" "net/netip"
"net/url"
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat" "github.com/AdguardTeam/AdGuardDNS/internal/billstat"
@ -79,10 +78,6 @@ type Config struct {
// Upstream defines the upstream server and the group of fallback servers. // Upstream defines the upstream server and the group of fallback servers.
Upstream *agd.Upstream Upstream *agd.Upstream
// RootRedirectURL is the URL to which non-DNS and non-Debug HTTP requests
// are redirected.
RootRedirectURL *url.URL
// NewListener, when set, is used instead of the package-level function // NewListener, when set, is used instead of the package-level function
// NewListener when creating a DNS listener. // NewListener when creating a DNS listener.
// //
@ -119,6 +114,11 @@ type Config struct {
// UseECSCache shows if the EDNS Client Subnet (ECS) aware cache should be // UseECSCache shows if the EDNS Client Subnet (ECS) aware cache should be
// used. // used.
UseECSCache bool UseECSCache bool
// ResearchMetrics controls whether research metrics are enabled or not.
// This is a set of metrics that we may need temporary, so its collection is
// controlled by a separate setting.
ResearchMetrics bool
} }
// New returns a new DNS service. // New returns a new DNS service.
@ -150,7 +150,6 @@ func New(c *Config) (svc *Service, err error) {
groups := make([]*serverGroup, len(c.ServerGroups)) groups := make([]*serverGroup, len(c.ServerGroups))
svc = &Service{ svc = &Service{
messages: c.Messages, messages: c.Messages,
rootRedirectURL: c.RootRedirectURL,
billStat: c.BillStat, billStat: c.BillStat,
errColl: c.ErrColl, errColl: c.ErrColl,
fltStrg: c.FilterStorage, fltStrg: c.FilterStorage,
@ -158,6 +157,7 @@ func New(c *Config) (svc *Service, err error) {
queryLog: c.QueryLog, queryLog: c.QueryLog,
ruleStat: c.RuleStat, ruleStat: c.RuleStat,
groups: groups, groups: groups,
researchMetrics: c.ResearchMetrics,
} }
for i, srvGrp := range c.ServerGroups { for i, srvGrp := range c.ServerGroups {
@ -212,7 +212,6 @@ var _ agd.Service = (*Service)(nil)
// Service is the main DNS service of AdGuard DNS. // Service is the main DNS service of AdGuard DNS.
type Service struct { type Service struct {
messages *dnsmsg.Constructor messages *dnsmsg.Constructor
rootRedirectURL *url.URL
billStat billstat.Recorder billStat billstat.Recorder
errColl agd.ErrorCollector errColl agd.ErrorCollector
fltStrg filter.Storage fltStrg filter.Storage
@ -220,6 +219,7 @@ type Service struct {
queryLog querylog.Interface queryLog querylog.Interface
ruleStat rulestat.Interface ruleStat rulestat.Interface
groups []*serverGroup groups []*serverGroup
researchMetrics bool
} }
// mustStartListener starts l and panics on any error. // mustStartListener starts l and panics on any error.

View File

@ -260,14 +260,6 @@ func (mh *initMwHandler) ServeDNS(
// Copy middleware to the local variable to make the code simpler. // Copy middleware to the local variable to make the code simpler.
mw := mh.mw mw := mh.mw
if specHdlr, name := mw.noReqInfoSpecialHandler(fqdn, qt, cl); specHdlr != nil {
optlog.Debug1("init mw: got no-req-info special handler %s", name)
// Don't wrap the error, because it's informative enough as is, and
// because if handled is true, the main flow terminates here.
return specHdlr(ctx, rw, req)
}
// Get the request's information, such as GeoIP data and user profiles. // Get the request's information, such as GeoIP data and user profiles.
ri, err := mw.newRequestInfo(ctx, req, rw.RemoteAddr(), fqdn, qt, cl) ri, err := mw.newRequestInfo(ctx, req, rw.RemoteAddr(), fqdn, qt, cl)
if err != nil { if err != nil {

View File

@ -277,7 +277,7 @@ func TestInitMw_ServeDNS_ddr(t *testing.T) {
} }
} }
func TestInitMw_ServeDNS_privateRelay(t *testing.T) { func TestInitMw_ServeDNS_specialDomain(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
host string host string
@ -287,7 +287,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
profBlocked bool profBlocked bool
wantRCode dnsmsg.RCode wantRCode dnsmsg.RCode
}{{ }{{
name: "blocked_by_fltgrp", name: "private_relay_blocked_by_fltgrp",
host: applePrivateRelayMaskHost, host: applePrivateRelayMaskHost,
qtype: dns.TypeA, qtype: dns.TypeA,
fltGrpBlocked: true, fltGrpBlocked: true,
@ -295,7 +295,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
profBlocked: false, profBlocked: false,
wantRCode: dns.RcodeNameError, wantRCode: dns.RcodeNameError,
}, { }, {
name: "no_private_relay_domain", name: "no_special_domain",
host: "www.example.com", host: "www.example.com",
qtype: dns.TypeA, qtype: dns.TypeA,
fltGrpBlocked: true, fltGrpBlocked: true,
@ -311,7 +311,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
profBlocked: false, profBlocked: false,
wantRCode: dns.RcodeSuccess, wantRCode: dns.RcodeSuccess,
}, { }, {
name: "blocked_by_prof", name: "private_relay_blocked_by_prof",
host: applePrivateRelayMaskHost, host: applePrivateRelayMaskHost,
qtype: dns.TypeA, qtype: dns.TypeA,
fltGrpBlocked: false, fltGrpBlocked: false,
@ -319,7 +319,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
profBlocked: true, profBlocked: true,
wantRCode: dns.RcodeNameError, wantRCode: dns.RcodeNameError,
}, { }, {
name: "allowed_by_prof", name: "private_relay_allowed_by_prof",
host: applePrivateRelayMaskHost, host: applePrivateRelayMaskHost,
qtype: dns.TypeA, qtype: dns.TypeA,
fltGrpBlocked: true, fltGrpBlocked: true,
@ -327,7 +327,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
profBlocked: false, profBlocked: false,
wantRCode: dns.RcodeSuccess, wantRCode: dns.RcodeSuccess,
}, { }, {
name: "allowed_by_both", name: "private_relay_allowed_by_both",
host: applePrivateRelayMaskHost, host: applePrivateRelayMaskHost,
qtype: dns.TypeA, qtype: dns.TypeA,
fltGrpBlocked: false, fltGrpBlocked: false,
@ -335,13 +335,45 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
profBlocked: false, profBlocked: false,
wantRCode: dns.RcodeSuccess, wantRCode: dns.RcodeSuccess,
}, { }, {
name: "blocked_by_both", name: "private_relay_blocked_by_both",
host: applePrivateRelayMaskHost, host: applePrivateRelayMaskHost,
qtype: dns.TypeA, qtype: dns.TypeA,
fltGrpBlocked: true, fltGrpBlocked: true,
hasProf: true, hasProf: true,
profBlocked: true, profBlocked: true,
wantRCode: dns.RcodeNameError, wantRCode: dns.RcodeNameError,
}, {
name: "firefox_canary_allowed_by_prof",
host: firefoxCanaryHost,
qtype: dns.TypeA,
fltGrpBlocked: false,
hasProf: true,
profBlocked: false,
wantRCode: dns.RcodeSuccess,
}, {
name: "firefox_canary_allowed_by_fltgrp",
host: firefoxCanaryHost,
qtype: dns.TypeA,
fltGrpBlocked: false,
hasProf: false,
profBlocked: false,
wantRCode: dns.RcodeSuccess,
}, {
name: "firefox_canary_blocked_by_prof",
host: firefoxCanaryHost,
qtype: dns.TypeA,
fltGrpBlocked: false,
hasProf: true,
profBlocked: true,
wantRCode: dns.RcodeRefused,
}, {
name: "firefox_canary_blocked_by_fltgrp",
host: firefoxCanaryHost,
qtype: dns.TypeA,
fltGrpBlocked: true,
hasProf: false,
profBlocked: false,
wantRCode: dns.RcodeRefused,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
@ -368,9 +400,12 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
return nil, nil, agd.DeviceNotFoundError{} return nil, nil, agd.DeviceNotFoundError{}
} }
return &agd.Profile{ prof := &agd.Profile{
BlockPrivateRelay: tc.profBlocked, BlockPrivateRelay: tc.profBlocked,
}, &agd.Device{}, nil BlockFirefoxCanary: tc.profBlocked,
}
return prof, &agd.Device{}, nil
} }
db := &agdtest.ProfileDB{ db := &agdtest.ProfileDB{
OnProfileByDeviceID: func( OnProfileByDeviceID: func(
@ -406,7 +441,8 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
FilteredResponseTTL: 10 * time.Second, FilteredResponseTTL: 10 * time.Second,
}, },
fltGrp: &agd.FilteringGroup{ fltGrp: &agd.FilteringGroup{
BlockPrivateRelay: tc.fltGrpBlocked, BlockPrivateRelay: tc.fltGrpBlocked,
BlockFirefoxCanary: tc.fltGrpBlocked,
}, },
srvGrp: &agd.ServerGroup{}, srvGrp: &agd.ServerGroup{},
srv: &agd.Server{ srv: &agd.Server{
@ -436,7 +472,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
resp := rw.Msg() resp := rw.Msg()
require.NotNil(t, resp) require.NotNil(t, resp)
assert.Equal(t, dnsmsg.RCode(resp.Rcode), tc.wantRCode) assert.Equal(t, tc.wantRCode, dnsmsg.RCode(resp.Rcode))
}) })
} }
} }

View File

@ -101,7 +101,7 @@ func (svc *Service) filterQuery(
) (reqRes, respRes filter.Result) { ) (reqRes, respRes filter.Result) {
start := time.Now() start := time.Now()
defer func() { defer func() {
reportMetrics(ri, reqRes, respRes, time.Since(start)) svc.reportMetrics(ri, reqRes, respRes, time.Since(start))
}() }()
f := svc.fltStrg.FilterFromContext(ctx, ri) f := svc.fltStrg.FilterFromContext(ctx, ri)
@ -120,7 +120,7 @@ func (svc *Service) filterQuery(
// reportMetrics extracts filtering metrics data from the context and reports it // reportMetrics extracts filtering metrics data from the context and reports it
// to Prometheus. // to Prometheus.
func reportMetrics( func (svc *Service) reportMetrics(
ri *agd.RequestInfo, ri *agd.RequestInfo,
reqRes filter.Result, reqRes filter.Result,
respRes filter.Result, respRes filter.Result,
@ -139,7 +139,7 @@ func reportMetrics(
metrics.DNSSvcRequestByCountryTotal.WithLabelValues(cont, ctry).Inc() metrics.DNSSvcRequestByCountryTotal.WithLabelValues(cont, ctry).Inc()
metrics.DNSSvcRequestByASNTotal.WithLabelValues(ctry, asn).Inc() metrics.DNSSvcRequestByASNTotal.WithLabelValues(ctry, asn).Inc()
id, _, _ := filteringData(reqRes, respRes) id, _, blocked := filteringData(reqRes, respRes)
metrics.DNSSvcRequestByFilterTotal.WithLabelValues( metrics.DNSSvcRequestByFilterTotal.WithLabelValues(
string(id), string(id),
metrics.BoolString(ri.Profile == nil), metrics.BoolString(ri.Profile == nil),
@ -147,6 +147,22 @@ func reportMetrics(
metrics.DNSSvcFilteringDuration.Observe(elapsedFiltering.Seconds()) metrics.DNSSvcFilteringDuration.Observe(elapsedFiltering.Seconds())
metrics.DNSSvcUsersCountUpdate(ri.RemoteIP) metrics.DNSSvcUsersCountUpdate(ri.RemoteIP)
if svc.researchMetrics {
anonymous := ri.Profile == nil
filteringEnabled := ri.FilteringGroup != nil &&
ri.FilteringGroup.RuleListsEnabled &&
len(ri.FilteringGroup.RuleListIDs) > 0
metrics.ReportResearchMetrics(
anonymous,
filteringEnabled,
asn,
ctry,
string(id),
blocked,
)
}
} }
// reportf is a helper method for reporting non-critical errors. // reportf is a helper method for reporting non-critical errors.

View File

@ -0,0 +1,135 @@
package dnssvc
import (
"context"
"crypto/sha256"
"encoding/hex"
"net"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestPreServiceMwHandler_ServeDNS checks the pre-service middleware: plain
// queries pass through to the handler, DNSCheck responses short-circuit the
// chain, and safe-browsing hash-prefix TXT queries are answered from the
// hash storage.
func TestPreServiceMwHandler_ServeDNS(t *testing.T) {
	const safeBrowsingHost = "scam.example.net."

	var (
		ip   = net.IP{127, 0, 0, 1}
		name = "example.com"
	)

	// Build the hash-prefix TXT query host for safeBrowsingHost the same way
	// a client would: SHA-256 the hostname, hex-encode it, and use the
	// prefix plus the general TXT suffix.
	sum := sha256.Sum256([]byte(safeBrowsingHost))
	hashStr := hex.EncodeToString(sum[:])

	hashes, herr := hashstorage.New(safeBrowsingHost)
	require.NoError(t, herr)

	srv := filter.NewSafeBrowsingServer(hashes, nil)
	host := hashStr[:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix

	// The middleware requires client info, server info, and a start time in
	// the context.
	ctx := context.Background()
	ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{})
	ctx = dnsserver.ContextWithServerInfo(ctx, dnsserver.ServerInfo{})
	ctx = dnsserver.ContextWithStartTime(ctx, time.Now())

	const ttl = 60

	testCases := []struct {
		name         string
		req          *dns.Msg
		dnscheckResp *dns.Msg
		ri           *agd.RequestInfo
		wantAns      []dns.RR
	}{{
		// No DNSCheck response and no hash match: the query reaches the
		// default handler, which answers with a 100-second TTL A record.
		name:         "normal",
		req:          dnsservertest.CreateMessage(name, dns.TypeA),
		dnscheckResp: nil,
		ri:           &agd.RequestInfo{},
		wantAns: []dns.RR{
			dnsservertest.NewA(name, 100, ip),
		},
	}, {
		// A non-nil DNSCheck response is returned directly, bypassing the
		// wrapped handler.
		name: "dnscheck",
		req:  dnsservertest.CreateMessage(name, dns.TypeA),
		dnscheckResp: dnsservertest.NewResp(
			dns.RcodeSuccess,
			dnsservertest.NewReq(name, dns.TypeA, dns.ClassINET),
			dnsservertest.RRSection{
				RRs: []dns.RR{dnsservertest.NewA(name, ttl, ip)},
				Sec: dnsservertest.SectionAnswer,
			},
		),
		ri: &agd.RequestInfo{
			Host:   name,
			QType:  dns.TypeA,
			QClass: dns.ClassINET,
		},
		wantAns: []dns.RR{
			dnsservertest.NewA(name, ttl, ip),
		},
	}, {
		// A TXT query for a known hash prefix is answered with the full
		// hash string from the storage, using the middleware's TTL.
		name:         "with_hashes",
		req:          dnsservertest.CreateMessage(safeBrowsingHost, dns.TypeTXT),
		dnscheckResp: nil,
		ri:           &agd.RequestInfo{Host: host, QType: dns.TypeTXT},
		wantAns: []dns.RR{&dns.TXT{
			Hdr: dns.RR_Header{
				Name:   safeBrowsingHost,
				Rrtype: dns.TypeTXT,
				Class:  dns.ClassINET,
				Ttl:    ttl,
			},
			Txt: []string{hashStr},
		}},
	}, {
		// A TXT query whose host doesn't match any hash prefix falls
		// through to the default handler.
		name:         "not_matched",
		req:          dnsservertest.CreateMessage(name, dns.TypeTXT),
		dnscheckResp: nil,
		ri:           &agd.RequestInfo{Host: name, QType: dns.TypeTXT},
		wantAns:      []dns.RR{dnsservertest.NewA(name, 100, ip)},
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr)
			tctx := agd.ContextWithRequestInfo(ctx, tc.ri)

			// Stub the DNSCheck service with the per-case canned response.
			dnsCk := &agdtest.DNSCheck{
				OnCheck: func(
					ctx context.Context,
					msg *dns.Msg,
					ri *agd.RequestInfo,
				) (resp *dns.Msg, err error) {
					return tc.dnscheckResp, nil
				},
			}

			mw := &preServiceMw{
				messages: &dnsmsg.Constructor{
					FilteredResponseTTL: ttl * time.Second,
				},
				filter:  srv,
				checker: dnsCk,
			}
			handler := dnsservertest.DefaultHandler()
			h := mw.Wrap(handler)

			err := h.ServeDNS(tctx, rw, tc.req)
			require.NoError(t, err)

			msg := rw.Msg()
			require.NotNil(t, msg)
			assert.Equal(t, tc.wantAns, msg.Answer)
		})
	}
}

View File

@ -0,0 +1,264 @@
package dnssvc
import (
"context"
"net"
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
	// reqHostname is the fully-qualified hostname used in queries throughout
	// these tests.
	reqHostname = "example.com."

	// defaultTTL is the TTL, in seconds, of the answers produced by the test
	// handlers.
	defaultTTL = 3600
)
// TestPreUpstreamMwHandler_ServeDNS_withCache checks that the pre-upstream
// middleware's plain cache prevents repeated identical queries from reaching
// the wrapped handler, and that a zero cache size disables caching.
func TestPreUpstreamMwHandler_ServeDNS_withCache(t *testing.T) {
	remoteIP := netip.MustParseAddr("1.2.3.4")
	aReq := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
	respIP := remoteIP.AsSlice()
	resp := dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.RRSection{
		RRs: []dns.RR{dnsservertest.NewA(reqHostname, defaultTTL, respIP)},
		Sec: dnsservertest.SectionAnswer,
	})

	ctx := agd.ContextWithRequestInfo(context.Background(), &agd.RequestInfo{
		Host: aReq.Question[0].Name,
	})

	// N is the number of identical queries sent in each case.
	const N = 5

	testCases := []struct {
		name       string
		cacheSize  int
		wantNumReq int
	}{{
		// With caching disabled every query must hit the handler.
		name:       "no_cache",
		cacheSize:  0,
		wantNumReq: N,
	}, {
		// With caching enabled only the first query hits the handler.
		name:       "with_cache",
		cacheSize:  100,
		wantNumReq: 1,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// numReq counts how many queries reached the wrapped handler.
			numReq := 0
			handler := dnsserver.HandlerFunc(func(
				ctx context.Context,
				rw dnsserver.ResponseWriter,
				req *dns.Msg,
			) error {
				numReq++

				return rw.WriteMsg(ctx, req, resp)
			})

			mw := &preUpstreamMw{
				db:        dnsdb.Empty{},
				cacheSize: tc.cacheSize,
			}
			h := mw.Wrap(handler)

			for i := 0; i < N; i++ {
				req := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
				addr := &net.UDPAddr{IP: remoteIP.AsSlice(), Port: 53}

				nrw := dnsserver.NewNonWriterResponseWriter(addr, addr)

				err := h.ServeDNS(ctx, nrw, req)
				require.NoError(t, err)
			}

			assert.Equal(t, tc.wantNumReq, numReq)
		})
	}
}
// TestPreUpstreamMwHandler_ServeDNS_withECSCache checks that the ECS-aware
// cache also deduplicates identical queries: with a populated ECS in the
// request info, N identical queries should result in a single handler call.
func TestPreUpstreamMwHandler_ServeDNS_withECSCache(t *testing.T) {
	aReq := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
	remoteIP := netip.MustParseAddr("1.2.3.4")
	subnet := netip.MustParsePrefix("1.2.3.4/24")

	const ctry = agd.CountryAD

	resp := dnsservertest.NewResp(
		dns.RcodeSuccess,
		aReq,
		dnsservertest.RRSection{
			RRs: []dns.RR{dnsservertest.NewA(
				reqHostname,
				defaultTTL,
				net.IP{1, 2, 3, 4},
			)},
			Sec: dnsservertest.SectionAnswer,
		},
	)

	// numReq counts how many queries reached the wrapped handler.
	numReq := 0
	handler := dnsserver.HandlerFunc(
		func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) error {
			numReq++

			return rw.WriteMsg(ctx, req, resp)
		},
	)

	// The ECS cache looks up the location's subnet through the GeoIP
	// database; only OnSubnetByLocation is expected to be called.
	geoIP := &agdtest.GeoIP{
		OnSubnetByLocation: func(
			ctry agd.Country,
			_ agd.ASN,
			_ netutil.AddrFamily,
		) (n netip.Prefix, err error) {
			return netip.MustParsePrefix("1.2.0.0/16"), nil
		},
		OnData: func(_ string, _ netip.Addr) (_ *agd.Location, _ error) {
			panic("not implemented")
		},
	}

	mw := &preUpstreamMw{
		db:           dnsdb.Empty{},
		geoIP:        geoIP,
		cacheSize:    100,
		ecsCacheSize: 100,
		useECSCache:  true,
	}
	h := mw.Wrap(handler)

	ctx := agd.ContextWithRequestInfo(context.Background(), &agd.RequestInfo{
		Location: &agd.Location{
			Country: ctry,
		},
		ECS: &agd.ECS{
			Location: &agd.Location{
				Country: ctry,
			},
			Subnet: subnet,
			Scope:  0,
		},
		Host:     aReq.Question[0].Name,
		RemoteIP: remoteIP,
	})

	const N = 5
	var nrw *dnsserver.NonWriterResponseWriter
	for i := 0; i < N; i++ {
		addr := &net.UDPAddr{IP: remoteIP.AsSlice(), Port: 53}
		nrw = dnsserver.NewNonWriterResponseWriter(addr, addr)
		req := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)

		err := h.ServeDNS(ctx, nrw, req)
		require.NoError(t, err)
	}

	// Only the first query should have reached the handler.
	assert.Equal(t, 1, numReq)
}
// TestPreUpstreamMwHandler_ServeDNS_androidMetric checks the rewriting of
// Android connectivity-check ("DoT/DoH metric") hostnames: the random digits
// in *-dnsotls-ds.metric.gstatic.com and *-dnsohttps-ds.metric.gstatic.com
// queries are zeroed before the query reaches the handler.
func TestPreUpstreamMwHandler_ServeDNS_androidMetric(t *testing.T) {
	mw := &preUpstreamMw{db: dnsdb.Empty{}}

	req := dnsservertest.CreateMessage("example.com", dns.TypeA)
	resp := new(dns.Msg).SetReply(req)

	ctx := context.Background()
	ctx = dnsserver.ContextWithServerInfo(ctx, dnsserver.ServerInfo{})
	ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{})
	ctx = dnsserver.ContextWithStartTime(ctx, time.Now())
	ctx = agd.ContextWithRequestInfo(ctx, &agd.RequestInfo{})

	const ttl = 100

	testCases := []struct {
		name     string
		req      *dns.Msg
		resp     *dns.Msg
		wantName string
		wantAns  []dns.RR
	}{{
		// Ordinary hostnames are passed through unchanged.
		name:     "no_changes",
		req:      dnsservertest.CreateMessage("example.com.", dns.TypeA),
		resp:     resp,
		wantName: "example.com.",
		wantAns:  nil,
	}, {
		// The random DoT metric prefix is zeroed.
		name: "android-tls-metric",
		req: dnsservertest.CreateMessage(
			"12345678-dnsotls-ds.metric.gstatic.com.",
			dns.TypeA,
		),
		resp:     resp,
		wantName: "00000000-dnsotls-ds.metric.gstatic.com.",
		wantAns:  nil,
	}, {
		// The random DoH metric prefix is zeroed as well.
		name: "android-https-metric",
		req: dnsservertest.CreateMessage(
			"123456-dnsohttps-ds.metric.gstatic.com.",
			dns.TypeA,
		),
		resp:     resp,
		wantName: "000000-dnsohttps-ds.metric.gstatic.com.",
		wantAns:  nil,
	}, {
		// NOTE(review): the second crafted answer is named "654321-…" but
		// the expectation below has both answers renamed to the request's
		// original "123456-…" name — presumably the middleware rewrites all
		// answer names back to the original query name; confirm against the
		// middleware implementation.
		name: "multiple_answers_metric",
		req: dnsservertest.CreateMessage(
			"123456-dnsohttps-ds.metric.gstatic.com.",
			dns.TypeA,
		),
		resp: dnsservertest.NewResp(
			dns.RcodeSuccess,
			req,
			dnsservertest.RRSection{
				RRs: []dns.RR{dnsservertest.NewA(
					"123456-dnsohttps-ds.metric.gstatic.com.",
					ttl,
					net.IP{1, 2, 3, 4},
				), dnsservertest.NewA(
					"654321-dnsohttps-ds.metric.gstatic.com.",
					ttl,
					net.IP{1, 2, 3, 5},
				)},
				Sec: dnsservertest.SectionAnswer,
			},
		),
		wantName: "000000-dnsohttps-ds.metric.gstatic.com.",
		wantAns: []dns.RR{
			dnsservertest.NewA("123456-dnsohttps-ds.metric.gstatic.com.", ttl, net.IP{1, 2, 3, 4}),
			dnsservertest.NewA("123456-dnsohttps-ds.metric.gstatic.com.", ttl, net.IP{1, 2, 3, 5}),
		},
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// The inner handler asserts that the question name seen by the
			// upstream side has already been rewritten.
			handler := dnsserver.HandlerFunc(func(
				ctx context.Context,
				rw dnsserver.ResponseWriter,
				req *dns.Msg,
			) error {
				assert.Equal(t, tc.wantName, req.Question[0].Name)

				return rw.WriteMsg(ctx, req, tc.resp)
			})
			h := mw.Wrap(handler)

			rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr)

			err := h.ServeDNS(ctx, rw, tc.req)
			require.NoError(t, err)

			msg := rw.Msg()
			require.NotNil(t, msg)
			assert.Equal(t, tc.wantAns, msg.Answer)
		})
	}
}

View File

@ -32,57 +32,23 @@ const (
// Resolvers for querying the resolver with unknown or absent name. // Resolvers for querying the resolver with unknown or absent name.
ddrDomain = ddrLabel + "." + resolverArpaDomain ddrDomain = ddrLabel + "." + resolverArpaDomain
// firefoxCanaryFQDN is the fully-qualified canary domain that Firefox uses // firefoxCanaryHost is the hostname that Firefox uses to check if it
// to check if it should use its own DNS-over-HTTPS settings. // should use its own DNS-over-HTTPS settings.
// //
// See https://support.mozilla.org/en-US/kb/configuring-networks-disable-dns-over-https. // See https://support.mozilla.org/en-US/kb/configuring-networks-disable-dns-over-https.
firefoxCanaryFQDN = "use-application-dns.net." firefoxCanaryHost = "use-application-dns.net"
// applePrivateRelayMaskHost and applePrivateRelayMaskH2Host are the
// hostnames that Apple devices use to check if Apple Private Relay can be
// enabled. Returning NXDOMAIN to queries for these domain names blocks
// Apple Private Relay.
//
// See https://developer.apple.com/support/prepare-your-network-for-icloud-private-relay.
applePrivateRelayMaskHost = "mask.icloud.com"
applePrivateRelayMaskH2Host = "mask-h2.icloud.com"
) )
// noReqInfoSpecialHandler returns a handler that can handle a special-domain // Hostnames that Apple devices use to check if Apple Private Relay can be
// query based only on its question type, class, and target, as well as the // enabled. Returning NXDOMAIN to queries for these domain names blocks Apple
// handler's name for debugging. // Private Relay.
func (mw *initMw) noReqInfoSpecialHandler( //
fqdn string, // See https://developer.apple.com/support/prepare-your-network-for-icloud-private-relay.
qt dnsmsg.RRType, const (
cl dnsmsg.Class, applePrivateRelayMaskHost = "mask.icloud.com"
) (f dnsserver.HandlerFunc, name string) { applePrivateRelayMaskH2Host = "mask-h2.icloud.com"
if cl != dns.ClassINET { applePrivateRelayMaskCanaryHost = "mask-canary.icloud.com"
return nil, "" )
}
if (qt == dns.TypeA || qt == dns.TypeAAAA) && fqdn == firefoxCanaryFQDN {
return mw.handleFirefoxCanary, "firefox"
}
return nil, ""
}
// Firefox Canary
// handleFirefoxCanary checks if the request is for the fully-qualified domain
// name that Firefox uses to check DoH settings and writes a response if needed.
func (mw *initMw) handleFirefoxCanary(
ctx context.Context,
rw dnsserver.ResponseWriter,
req *dns.Msg,
) (err error) {
metrics.DNSSvcFirefoxRequestsTotal.Inc()
resp := mw.messages.NewMsgREFUSED(req)
err = rw.WriteMsg(ctx, req, resp)
return errors.Annotate(err, "writing firefox canary resp: %w")
}
// reqInfoSpecialHandler returns a handler that can handle a special-domain // reqInfoSpecialHandler returns a handler that can handle a special-domain
// query based on the request info, as well as the handler's name for debugging. // query based on the request info, as well as the handler's name for debugging.
@ -99,11 +65,9 @@ func (mw *initMw) reqInfoSpecialHandler(
} else if netutil.IsSubdomain(ri.Host, resolverArpaDomain) { } else if netutil.IsSubdomain(ri.Host, resolverArpaDomain) {
// A badly formed resolver.arpa subdomain query. // A badly formed resolver.arpa subdomain query.
return mw.handleBadResolverARPA, "bad_resolver_arpa" return mw.handleBadResolverARPA, "bad_resolver_arpa"
} else if shouldBlockPrivateRelay(ri) {
return mw.handlePrivateRelay, "apple_private_relay"
} }
return nil, "" return mw.specialDomainHandler(ri)
} }
// reqInfoHandlerFunc is an alias for handler functions that additionally accept // reqInfoHandlerFunc is an alias for handler functions that additionally accept
@ -222,24 +186,46 @@ func (mw *initMw) handleBadResolverARPA(
return errors.Annotate(err, "writing nodata resp for %q: %w", ri.Host) return errors.Annotate(err, "writing nodata resp for %q: %w", ri.Host)
} }
// specialDomainHandler returns a handler that can handle a special-domain
// query for Apple Private Relay or Firefox canary domain based on the request
// or profile information, as well as the handler's name for debugging.
func (mw *initMw) specialDomainHandler(
ri *agd.RequestInfo,
) (f reqInfoHandlerFunc, name string) {
qt := ri.QType
if qt != dns.TypeA && qt != dns.TypeAAAA {
return nil, ""
}
host := ri.Host
prof := ri.Profile
switch host {
case
applePrivateRelayMaskHost,
applePrivateRelayMaskH2Host,
applePrivateRelayMaskCanaryHost:
if shouldBlockPrivateRelay(ri, prof) {
return mw.handlePrivateRelay, "apple_private_relay"
}
case firefoxCanaryHost:
if shouldBlockFirefoxCanary(ri, prof) {
return mw.handleFirefoxCanary, "firefox"
}
default:
// Go on.
}
return nil, ""
}
// Apple Private Relay // Apple Private Relay
// shouldBlockPrivateRelay returns true if the query is for an Apple Private // shouldBlockPrivateRelay returns true if the query is for an Apple Private
// Relay check domain and the request information indicates that Apple Private // Relay check domain and the request information or profile indicates that
// Relay should be blocked. // Apple Private Relay should be blocked.
func shouldBlockPrivateRelay(ri *agd.RequestInfo) (ok bool) { func shouldBlockPrivateRelay(ri *agd.RequestInfo, prof *agd.Profile) (ok bool) {
qt := ri.QType if prof != nil {
host := ri.Host
return (qt == dns.TypeA || qt == dns.TypeAAAA) &&
(host == applePrivateRelayMaskHost || host == applePrivateRelayMaskH2Host) &&
reqInfoShouldBlockPrivateRelay(ri)
}
// reqInfoShouldBlockPrivateRelay returns true if Apple Private Relay queries
// should be blocked based on the request information.
func reqInfoShouldBlockPrivateRelay(ri *agd.RequestInfo) (ok bool) {
if prof := ri.Profile; prof != nil {
return prof.BlockPrivateRelay return prof.BlockPrivateRelay
} }
@ -260,3 +246,32 @@ func (mw *initMw) handlePrivateRelay(
return errors.Annotate(err, "writing private relay resp: %w") return errors.Annotate(err, "writing private relay resp: %w")
} }
// Firefox canary domain
// shouldBlockFirefoxCanary returns true if the query is for a Firefox canary
// domain and the request information or profile indicates that Firefox canary
// domain should be blocked.
func shouldBlockFirefoxCanary(ri *agd.RequestInfo, prof *agd.Profile) (ok bool) {
if prof != nil {
return prof.BlockFirefoxCanary
}
return ri.FilteringGroup.BlockFirefoxCanary
}
// handleFirefoxCanary checks if the request is for the fully-qualified domain
// name that Firefox uses to check DoH settings and writes a response if needed.
func (mw *initMw) handleFirefoxCanary(
ctx context.Context,
rw dnsserver.ResponseWriter,
req *dns.Msg,
ri *agd.RequestInfo,
) (err error) {
metrics.DNSSvcFirefoxRequestsTotal.Inc()
resp := mw.messages.NewMsgREFUSED(req)
err = rw.WriteMsg(ctx, req, resp)
return errors.Annotate(err, "writing firefox canary resp: %w")
}

View File

@ -10,6 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog" "github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/mathutil"
"github.com/bluele/gcache" "github.com/bluele/gcache"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -93,22 +94,16 @@ func (mw *Middleware) toCacheKey(cr *cacheRequest, hostHasECS bool) (key uint64)
var buf [6]byte var buf [6]byte
binary.LittleEndian.PutUint16(buf[:2], cr.qType) binary.LittleEndian.PutUint16(buf[:2], cr.qType)
binary.LittleEndian.PutUint16(buf[2:4], cr.qClass) binary.LittleEndian.PutUint16(buf[2:4], cr.qClass)
if cr.reqDO {
buf[4] = 1
} else {
buf[4] = 0
}
if cr.subnet.Addr().Is4() { buf[4] = mathutil.BoolToNumber[byte](cr.reqDO)
buf[5] = 0
} else { addr := cr.subnet.Addr()
buf[5] = 1 buf[5] = mathutil.BoolToNumber[byte](addr.Is6())
}
_, _ = h.Write(buf[:]) _, _ = h.Write(buf[:])
if hostHasECS { if hostHasECS {
_, _ = h.Write(cr.subnet.Addr().AsSlice()) _, _ = h.Write(addr.AsSlice())
_ = h.WriteByte(byte(cr.subnet.Bits())) _ = h.WriteByte(byte(cr.subnet.Bits()))
} }

View File

@ -8,6 +8,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/golibs/mathutil"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -269,9 +270,5 @@ func getTTLIfLower(r dns.RR, ttl uint32) (res uint32) {
// Go on. // Go on.
} }
if httl := r.Header().Ttl; httl < ttl { return mathutil.Min(r.Header().Ttl, ttl)
return httl
}
return ttl
} }

View File

@ -67,13 +67,13 @@ type SentryReportableError interface {
// TODO(a.garipov): Make sure that we use this approach everywhere. // TODO(a.garipov): Make sure that we use this approach everywhere.
func isReportable(err error) (ok bool) { func isReportable(err error) (ok bool) {
var ( var (
ravErr SentryReportableError sentryRepErr SentryReportableError
fwdErr *forward.Error fwdErr *forward.Error
dnsWErr *dnsserver.WriteError dnsWErr *dnsserver.WriteError
) )
if errors.As(err, &ravErr) { if errors.As(err, &sentryRepErr) {
return ravErr.IsSentryReportable() return sentryRepErr.IsSentryReportable()
} else if errors.As(err, &fwdErr) { } else if errors.As(err, &fwdErr) {
return isReportableNetwork(fwdErr.Err) return isReportableNetwork(fwdErr.Err)
} else if errors.As(err, &dnsWErr) { } else if errors.As(err, &dnsWErr) {

View File

@ -21,8 +21,8 @@ var _ Interface = (*compFilter)(nil)
// compFilter is a composite filter based on several types of safe search // compFilter is a composite filter based on several types of safe search
// filters and rule lists. // filters and rule lists.
type compFilter struct { type compFilter struct {
safeBrowsing *hashPrefixFilter safeBrowsing *HashPrefix
adultBlocking *hashPrefixFilter adultBlocking *HashPrefix
genSafeSearch *safeSearch genSafeSearch *safeSearch
ytSafeSearch *safeSearch ytSafeSearch *safeSearch

View File

@ -18,10 +18,11 @@ import (
// maxFilterSize is the maximum size of downloaded filters. // maxFilterSize is the maximum size of downloaded filters.
const maxFilterSize = 196 * int64(datasize.MB) const maxFilterSize = 196 * int64(datasize.MB)
// defaultTimeout is the default timeout to use when fetching filter data. // defaultFilterRefreshTimeout is the default timeout to use when fetching
// filter lists data.
// //
// TODO(a.garipov): Consider making timeouts where they are used configurable. // TODO(a.garipov): Consider making timeouts where they are used configurable.
const defaultTimeout = 30 * time.Second const defaultFilterRefreshTimeout = 180 * time.Second
// defaultResolveTimeout is the default timeout for resolving hosts for safe // defaultResolveTimeout is the default timeout for resolving hosts for safe
// search and safe browsing filters. // search and safe browsing filters.

View File

@ -181,24 +181,18 @@ func prepareConf(t testing.TB) (c *filter.DefaultStorageConfig) {
FilterIndexURL: fltsURL, FilterIndexURL: fltsURL,
GeneralSafeSearchRulesURL: ssURL, GeneralSafeSearchRulesURL: ssURL,
YoutubeSafeSearchRulesURL: ssURL, YoutubeSafeSearchRulesURL: ssURL,
SafeBrowsing: &filter.HashPrefixConfig{ SafeBrowsing: &filter.HashPrefix{},
CacheTTL: 1 * time.Hour, AdultBlocking: &filter.HashPrefix{},
CacheSize: 100, Now: time.Now,
}, ErrColl: nil,
AdultBlocking: &filter.HashPrefixConfig{ Resolver: nil,
CacheTTL: 1 * time.Hour, CacheDir: cacheDir,
CacheSize: 100, CustomFilterCacheSize: 100,
}, SafeSearchCacheSize: 100,
Now: time.Now, SafeSearchCacheTTL: 1 * time.Hour,
ErrColl: nil, RuleListCacheSize: 100,
Resolver: nil, RefreshIvl: testRefreshIvl,
CacheDir: cacheDir, UseRuleListCache: false,
CustomFilterCacheSize: 100,
SafeSearchCacheSize: 100,
SafeSearchCacheTTL: 1 * time.Hour,
RuleListCacheSize: 100,
RefreshIvl: testRefreshIvl,
UseRuleListCache: false,
} }
} }

View File

@ -3,13 +3,17 @@ package filter
import ( import (
"context" "context"
"fmt" "fmt"
"net/url"
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet" "github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/resultcache" "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/resultcache"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@ -21,13 +25,32 @@ import (
// HashPrefixConfig is the hash-prefix filter configuration structure. // HashPrefixConfig is the hash-prefix filter configuration structure.
type HashPrefixConfig struct { type HashPrefixConfig struct {
// Hashes are the hostname hashes for this filter. // Hashes are the hostname hashes for this filter.
Hashes *HashStorage Hashes *hashstorage.Storage
// URL is the URL used to update the filter.
URL *url.URL
// ErrColl is used to collect non-critical and rare errors.
ErrColl agd.ErrorCollector
// Resolver is used to resolve hosts for the hash-prefix filter.
Resolver agdnet.Resolver
// ID is the ID of this hash storage for logging and error reporting.
ID agd.FilterListID
// CachePath is the path to the file containing the cached filtered
// hostnames, one per line.
CachePath string
// ReplacementHost is the replacement host for this filter. Queries // ReplacementHost is the replacement host for this filter. Queries
// matched by the filter receive a response with the IP addresses of // matched by the filter receive a response with the IP addresses of
// this host. // this host.
ReplacementHost string ReplacementHost string
// Staleness is the time after which a file is considered stale.
Staleness time.Duration
// CacheTTL is the time-to-live value used to cache the results of the // CacheTTL is the time-to-live value used to cache the results of the
// filter. // filter.
// //
@ -38,57 +61,58 @@ type HashPrefixConfig struct {
CacheSize int CacheSize int
} }
// hashPrefixFilter is a filter that matches hosts by their hashes based on // HashPrefix is a filter that matches hosts by their hashes based on a
// a hash-prefix table. // hash-prefix table.
type hashPrefixFilter struct { type HashPrefix struct {
hashes *HashStorage hashes *hashstorage.Storage
refr *refreshableFilter
resCache *resultcache.Cache[*ResultModified] resCache *resultcache.Cache[*ResultModified]
resolver agdnet.Resolver resolver agdnet.Resolver
errColl agd.ErrorCollector errColl agd.ErrorCollector
repHost string repHost string
id agd.FilterListID
} }
// newHashPrefixFilter returns a new hash-prefix filter. c must not be nil. // NewHashPrefix returns a new hash-prefix filter. c must not be nil.
func newHashPrefixFilter( func NewHashPrefix(c *HashPrefixConfig) (f *HashPrefix, err error) {
c *HashPrefixConfig, f = &HashPrefix{
resolver agdnet.Resolver, hashes: c.Hashes,
errColl agd.ErrorCollector, refr: &refreshableFilter{
id agd.FilterListID, http: agdhttp.NewClient(&agdhttp.ClientConfig{
) (f *hashPrefixFilter) { Timeout: defaultFilterRefreshTimeout,
f = &hashPrefixFilter{ }),
hashes: c.Hashes, url: c.URL,
id: c.ID,
cachePath: c.CachePath,
typ: "hash storage",
staleness: c.Staleness,
},
resCache: resultcache.New[*ResultModified](c.CacheSize), resCache: resultcache.New[*ResultModified](c.CacheSize),
resolver: resolver, resolver: c.Resolver,
errColl: errColl, errColl: c.ErrColl,
repHost: c.ReplacementHost, repHost: c.ReplacementHost,
id: id,
} }
// Patch the refresh function of the hash storage, if there is one, to make f.refr.resetRules = f.resetRules
// sure that we clear the result cache during every refresh.
//
// TODO(a.garipov): Create a better way to do that than this spaghetti-style
// patching. Perhaps, lift the entire logic of hash-storage refresh into
// hashPrefixFilter.
if f.hashes != nil && f.hashes.refr != nil {
resetRules := f.hashes.refr.resetRules
f.hashes.refr.resetRules = func(text string) (err error) {
f.resCache.Clear()
return resetRules(text) err = f.refresh(context.Background(), true)
} if err != nil {
return nil, err
} }
return f return f, nil
}
// id returns the ID of the hash storage.
func (f *HashPrefix) id() (fltID agd.FilterListID) {
return f.refr.id
} }
// type check // type check
var _ qtHostFilter = (*hashPrefixFilter)(nil) var _ qtHostFilter = (*HashPrefix)(nil)
// filterReq implements the qtHostFilter interface for *hashPrefixFilter. It // filterReq implements the qtHostFilter interface for *hashPrefixFilter. It
// modifies the response if host matches f. // modifies the response if host matches f.
func (f *hashPrefixFilter) filterReq( func (f *HashPrefix) filterReq(
ctx context.Context, ctx context.Context,
ri *agd.RequestInfo, ri *agd.RequestInfo,
req *dns.Msg, req *dns.Msg,
@ -115,7 +139,7 @@ func (f *hashPrefixFilter) filterReq(
var matched string var matched string
sub := hashableSubdomains(host) sub := hashableSubdomains(host)
for _, s := range sub { for _, s := range sub {
if f.hashes.hashMatches(s) { if f.hashes.Matches(s) {
matched = s matched = s
break break
@ -134,19 +158,19 @@ func (f *hashPrefixFilter) filterReq(
var result *dns.Msg var result *dns.Msg
ips, err := f.resolver.LookupIP(ctx, fam, f.repHost) ips, err := f.resolver.LookupIP(ctx, fam, f.repHost)
if err != nil { if err != nil {
agd.Collectf(ctx, f.errColl, "filter %s: resolving: %w", f.id, err) agd.Collectf(ctx, f.errColl, "filter %s: resolving: %w", f.id(), err)
result = ri.Messages.NewMsgSERVFAIL(req) result = ri.Messages.NewMsgSERVFAIL(req)
} else { } else {
result, err = ri.Messages.NewIPRespMsg(req, ips...) result, err = ri.Messages.NewIPRespMsg(req, ips...)
if err != nil { if err != nil {
return nil, fmt.Errorf("filter %s: creating modified result: %w", f.id, err) return nil, fmt.Errorf("filter %s: creating modified result: %w", f.id(), err)
} }
} }
rm = &ResultModified{ rm = &ResultModified{
Msg: result, Msg: result,
List: f.id, List: f.id(),
Rule: agd.FilterRuleText(matched), Rule: agd.FilterRuleText(matched),
} }
@ -161,21 +185,21 @@ func (f *hashPrefixFilter) filterReq(
} }
// updateCacheSizeMetrics updates cache size metrics. // updateCacheSizeMetrics updates cache size metrics.
func (f *hashPrefixFilter) updateCacheSizeMetrics(size int) { func (f *HashPrefix) updateCacheSizeMetrics(size int) {
switch f.id { switch id := f.id(); id {
case agd.FilterListIDSafeBrowsing: case agd.FilterListIDSafeBrowsing:
metrics.HashPrefixFilterSafeBrowsingCacheSize.Set(float64(size)) metrics.HashPrefixFilterSafeBrowsingCacheSize.Set(float64(size))
case agd.FilterListIDAdultBlocking: case agd.FilterListIDAdultBlocking:
metrics.HashPrefixFilterAdultBlockingCacheSize.Set(float64(size)) metrics.HashPrefixFilterAdultBlockingCacheSize.Set(float64(size))
default: default:
panic(fmt.Errorf("unsupported FilterListID %s", f.id)) panic(fmt.Errorf("unsupported FilterListID %s", id))
} }
} }
// updateCacheLookupsMetrics updates cache lookups metrics. // updateCacheLookupsMetrics updates cache lookups metrics.
func (f *hashPrefixFilter) updateCacheLookupsMetrics(hit bool) { func (f *HashPrefix) updateCacheLookupsMetrics(hit bool) {
var hitsMetric, missesMetric prometheus.Counter var hitsMetric, missesMetric prometheus.Counter
switch f.id { switch id := f.id(); id {
case agd.FilterListIDSafeBrowsing: case agd.FilterListIDSafeBrowsing:
hitsMetric = metrics.HashPrefixFilterCacheSafeBrowsingHits hitsMetric = metrics.HashPrefixFilterCacheSafeBrowsingHits
missesMetric = metrics.HashPrefixFilterCacheSafeBrowsingMisses missesMetric = metrics.HashPrefixFilterCacheSafeBrowsingMisses
@ -183,7 +207,7 @@ func (f *hashPrefixFilter) updateCacheLookupsMetrics(hit bool) {
hitsMetric = metrics.HashPrefixFilterCacheAdultBlockingHits hitsMetric = metrics.HashPrefixFilterCacheAdultBlockingHits
missesMetric = metrics.HashPrefixFilterCacheAdultBlockingMisses missesMetric = metrics.HashPrefixFilterCacheAdultBlockingMisses
default: default:
panic(fmt.Errorf("unsupported FilterListID %s", f.id)) panic(fmt.Errorf("unsupported FilterListID %s", id))
} }
if hit { if hit {
@ -194,12 +218,50 @@ func (f *hashPrefixFilter) updateCacheLookupsMetrics(hit bool) {
} }
// name implements the qtHostFilter interface for *hashPrefixFilter. // name implements the qtHostFilter interface for *hashPrefixFilter.
func (f *hashPrefixFilter) name() (n string) { func (f *HashPrefix) name() (n string) {
if f == nil { if f == nil {
return "" return ""
} }
return string(f.id) return string(f.id())
}
// type check
var _ agd.Refresher = (*HashPrefix)(nil)
// Refresh implements the [agd.Refresher] interface for *hashPrefixFilter.
func (f *HashPrefix) Refresh(ctx context.Context) (err error) {
return f.refresh(ctx, false)
}
// refresh reloads the hash filter data. If acceptStale is true, do not try to
// load the list from its URL when there is already a file in the cache
// directory, regardless of its staleness.
func (f *HashPrefix) refresh(ctx context.Context, acceptStale bool) (err error) {
return f.refr.refresh(ctx, acceptStale)
}
// resetRules resets the hosts in the index.
func (f *HashPrefix) resetRules(text string) (err error) {
n, err := f.hashes.Reset(text)
// Report the filter update to prometheus.
promLabels := prometheus.Labels{
"filter": string(f.id()),
}
metrics.SetStatusGauge(metrics.FilterUpdatedStatus.With(promLabels), err)
if err != nil {
return err
}
metrics.FilterUpdatedTime.With(promLabels).SetToCurrentTime()
metrics.FilterRulesTotal.With(promLabels).Set(float64(n))
log.Info("filter %s: reset %d hosts", f.id(), n)
return nil
} }
// subDomainNum defines how many labels should be hashed to match against a hash // subDomainNum defines how many labels should be hashed to match against a hash

View File

@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest" "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter" "github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -21,8 +22,10 @@ import (
func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) { func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) {
cacheDir := t.TempDir() cacheDir := t.TempDir()
cachePath := filepath.Join(cacheDir, string(agd.FilterListIDSafeBrowsing)) cachePath := filepath.Join(cacheDir, string(agd.FilterListIDSafeBrowsing))
hosts := "scam.example.net\n" err := os.WriteFile(cachePath, []byte(safeBrowsingHost+"\n"), 0o644)
err := os.WriteFile(cachePath, []byte(hosts), 0o644) require.NoError(t, err)
hashes, err := hashstorage.New("")
require.NoError(t, err) require.NoError(t, err)
errColl := &agdtest.ErrorCollector{ errColl := &agdtest.ErrorCollector{
@ -31,46 +34,33 @@ func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) {
}, },
} }
hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{ resolver := &agdtest.Resolver{
URL: nil, OnLookupIP: func(
ErrColl: errColl, _ context.Context,
ID: agd.FilterListIDSafeBrowsing, _ netutil.AddrFamily,
CachePath: cachePath, _ string,
RefreshIvl: testRefreshIvl, ) (ips []net.IP, err error) {
}) return []net.IP{safeBrowsingSafeIP4}, nil
require.NoError(t, err) },
ctx := context.Background()
err = hashes.Start()
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return hashes.Shutdown(ctx)
})
// Fake Data
onLookupIP := func(
_ context.Context,
_ netutil.AddrFamily,
_ string,
) (ips []net.IP, err error) {
return []net.IP{safeBrowsingSafeIP4}, nil
} }
c := prepareConf(t) c := prepareConf(t)
c.SafeBrowsing = &filter.HashPrefixConfig{ c.SafeBrowsing, err = filter.NewHashPrefix(&filter.HashPrefixConfig{
Hashes: hashes, Hashes: hashes,
ErrColl: errColl,
Resolver: resolver,
ID: agd.FilterListIDSafeBrowsing,
CachePath: cachePath,
ReplacementHost: safeBrowsingSafeHost, ReplacementHost: safeBrowsingSafeHost,
Staleness: 1 * time.Hour,
CacheTTL: 10 * time.Second, CacheTTL: 10 * time.Second,
CacheSize: 100, CacheSize: 100,
} })
require.NoError(t, err)
c.ErrColl = errColl c.ErrColl = errColl
c.Resolver = resolver
c.Resolver = &agdtest.Resolver{
OnLookupIP: onLookupIP,
}
s, err := filter.NewDefaultStorage(c) s, err := filter.NewDefaultStorage(c)
require.NoError(t, err) require.NoError(t, err)
@ -93,7 +83,7 @@ func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) {
} }
ri := newReqInfo(g, nil, safeBrowsingSubHost, clientIP, dns.TypeA) ri := newReqInfo(g, nil, safeBrowsingSubHost, clientIP, dns.TypeA)
ctx = agd.ContextWithRequestInfo(ctx, ri) ctx := agd.ContextWithRequestInfo(context.Background(), ri)
f := s.FilterFromContext(ctx, ri) f := s.FilterFromContext(ctx, ri)
require.NotNil(t, f) require.NotNil(t, f)

View File

@ -1,352 +0,0 @@
package filter
import (
"bufio"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/url"
"strings"
"sync"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/prometheus/client_golang/prometheus"
)
// Hash Storage
// Hash and hash part length constants.
const (
// hashPrefixLen is the length of the prefix of the hash of the filtered
// hostname.
hashPrefixLen = 2
// HashPrefixEncLen is the encoded length of the hash prefix. Two text
// bytes per one binary byte.
HashPrefixEncLen = hashPrefixLen * 2
// hashLen is the length of the whole hash of the checked hostname.
hashLen = sha256.Size
// hashSuffixLen is the length of the suffix of the hash of the filtered
// hostname.
hashSuffixLen = hashLen - hashPrefixLen
// hashEncLen is the encoded length of the hash. Two text bytes per one
// binary byte.
hashEncLen = hashLen * 2
// legacyHashPrefixEncLen is the encoded length of a legacy hash.
legacyHashPrefixEncLen = 8
)
// hashPrefix is the type of the 2-byte prefix of a full 32-byte SHA256 hash of
// a host being checked.
type hashPrefix [hashPrefixLen]byte
// hashSuffix is the type of the 30-byte suffix of a full 32-byte SHA256 hash of
// a host being checked.
type hashSuffix [hashSuffixLen]byte
// HashStorage is a storage for hashes of the filtered hostnames.
type HashStorage struct {
// mu protects hashSuffixes.
mu *sync.RWMutex
hashSuffixes map[hashPrefix][]hashSuffix
// refr contains data for refreshing the filter.
refr *refreshableFilter
refrWorker *agd.RefreshWorker
}
// HashStorageConfig is the configuration structure for a *HashStorage.
type HashStorageConfig struct {
// URL is the URL used to update the filter.
URL *url.URL
// ErrColl is used to collect non-critical and rare errors.
ErrColl agd.ErrorCollector
// ID is the ID of this hash storage for logging and error reporting.
ID agd.FilterListID
// CachePath is the path to the file containing the cached filtered
// hostnames, one per line.
CachePath string
// RefreshIvl is the refresh interval.
RefreshIvl time.Duration
}
// NewHashStorage returns a new hash storage containing hashes of all hostnames.
func NewHashStorage(c *HashStorageConfig) (hs *HashStorage, err error) {
hs = &HashStorage{
mu: &sync.RWMutex{},
hashSuffixes: map[hashPrefix][]hashSuffix{},
refr: &refreshableFilter{
http: agdhttp.NewClient(&agdhttp.ClientConfig{
Timeout: defaultTimeout,
}),
url: c.URL,
id: c.ID,
cachePath: c.CachePath,
typ: "hash storage",
refreshIvl: c.RefreshIvl,
},
}
// Do not set this in the literal above, since hs is nil there.
hs.refr.resetRules = hs.resetHosts
refrWorker := agd.NewRefreshWorker(&agd.RefreshWorkerConfig{
Context: func() (ctx context.Context, cancel context.CancelFunc) {
return context.WithTimeout(context.Background(), defaultTimeout)
},
Refresher: hs,
ErrColl: c.ErrColl,
Name: string(c.ID),
Interval: c.RefreshIvl,
RefreshOnShutdown: false,
RoutineLogsAreDebug: false,
})
hs.refrWorker = refrWorker
err = hs.refresh(context.Background(), true)
if err != nil {
return nil, fmt.Errorf("initializing %s: %w", c.ID, err)
}
return hs, nil
}
// hashes returns all hashes starting with the given prefixes, if any. The
// resulting slice shares storage for all underlying strings.
//
// TODO(a.garipov): This currently doesn't take duplicates into account.
func (hs *HashStorage) hashes(hps []hashPrefix) (hashes []string) {
if len(hps) == 0 {
return nil
}
hs.mu.RLock()
defer hs.mu.RUnlock()
// First, calculate the number of hashes to allocate the buffer.
l := 0
for _, hp := range hps {
hashSufs := hs.hashSuffixes[hp]
l += len(hashSufs)
}
// Then, allocate the buffer of the appropriate size and write all hashes
// into one big buffer and slice it into separate strings to make the
// garbage collector's work easier. This assumes that all references to
// this buffer will become unreachable at the same time.
//
// The fact that we iterate over the map twice shouldn't matter, since we
// assume that len(hps) will be below 5 most of the time.
b := &strings.Builder{}
b.Grow(l * hashEncLen)
// Use a buffer and write the resulting buffer into b directly instead of
// using hex.NewEncoder, because that seems to incur a significant
// performance hit.
var buf [hashEncLen]byte
for _, hp := range hps {
hashSufs := hs.hashSuffixes[hp]
for _, suf := range hashSufs {
// Slicing is safe here, since the contents of hp and suf are being
// encoded.
// nolint:looppointer
hex.Encode(buf[:], hp[:])
// nolint:looppointer
hex.Encode(buf[HashPrefixEncLen:], suf[:])
_, _ = b.Write(buf[:])
}
}
s := b.String()
hashes = make([]string, 0, l)
for i := 0; i < l; i++ {
hashes = append(hashes, s[i*hashEncLen:(i+1)*hashEncLen])
}
return hashes
}
// loadHashSuffixes returns the hash suffixes stored for the given prefix.  It
// is safe for concurrent use.
func (hs *HashStorage) loadHashSuffixes(hp hashPrefix) (sufs []hashSuffix, ok bool) {
	hs.mu.RLock()
	defer hs.mu.RUnlock()

	suffixes, found := hs.hashSuffixes[hp]

	return suffixes, found
}
// hashMatches returns true if the host matches one of the hashes.
func (hs *HashStorage) hashMatches(host string) (ok bool) {
	sum := sha256.Sum256([]byte(host))
	hp := hashPrefix{sum[0], sum[1]}

	sufs, ok := hs.loadHashSuffixes(hp)
	if !ok {
		return false
	}

	// The prefix is derived from sum itself, so it is enough to compare the
	// suffixes.
	want := *(*hashSuffix)(sum[hashPrefixLen:])
	for _, suf := range sufs {
		if suf == want {
			return true
		}
	}

	return false
}
// hashPrefixesFromStr returns hash prefixes from a dot-separated string.
func hashPrefixesFromStr(prefixesStr string) (hashPrefixes []hashPrefix, err error) {
	if prefixesStr == "" {
		return nil, nil
	}

	set := stringutil.NewSet()
	for _, p := range strings.Split(prefixesStr, ".") {
		switch len(p) {
		case HashPrefixEncLen:
			// A correctly-sized prefix; go on.
		case legacyHashPrefixEncLen:
			// Some legacy clients send eight-character hashes instead of
			// four-character ones.  For now, remove the final four characters.
			//
			// TODO(a.garipov): Either remove this crutch or support such
			// prefixes better.
			p = p[:HashPrefixEncLen]
		default:
			return nil, fmt.Errorf("bad hash len for %q", p)
		}

		set.Add(p)
	}

	hashPrefixes = make([]hashPrefix, set.Len())
	for i, p := range set.Values() {
		_, err = hex.Decode(hashPrefixes[i][:], []byte(p))
		if err != nil {
			return nil, fmt.Errorf("bad hash encoding for %q", p)
		}
	}

	return hashPrefixes, nil
}
// type check: make sure that *HashStorage implements the agd.Refresher
// interface.
var _ agd.Refresher = (*HashStorage)(nil)
// Refresh implements the agd.Refresher interface for *HashStorage.  If the
// file at the storage's path exists and its mtime shows that it's still
// fresh, it loads the data from the file.  Otherwise, it uses the URL of the
// storage.
func (hs *HashStorage) Refresh(ctx context.Context) (err error) {
	err = hs.refresh(ctx, false)

	// Report the filter update to prometheus.
	promLabels := prometheus.Labels{
		"filter": string(hs.id()),
	}

	metrics.SetStatusGauge(metrics.FilterUpdatedStatus.With(promLabels), err)

	if err == nil {
		metrics.FilterUpdatedTime.With(promLabels).SetToCurrentTime()

		// Count the total number of hashes loaded.  Hold the read lock while
		// iterating, since resetHosts can mutate hs.hashSuffixes concurrently
		// under the write lock, and an unguarded map read during a write is a
		// data race.
		hs.mu.RLock()
		count := 0
		for _, v := range hs.hashSuffixes {
			count += len(v)
		}
		hs.mu.RUnlock()

		metrics.FilterRulesTotal.With(promLabels).Set(float64(count))
	}

	return err
}
// id returns the filter-list ID of the hash storage's underlying refreshable
// filter.
func (hs *HashStorage) id() (fltID agd.FilterListID) {
	return hs.refr.id
}
// refresh reloads the hash filter data by delegating to the embedded
// refreshable filter.  If acceptStale is true, do not try to load the list
// from its URL when there is already a file in the cache directory,
// regardless of its staleness.
func (hs *HashStorage) refresh(ctx context.Context, acceptStale bool) (err error) {
	return hs.refr.refresh(ctx, acceptStale)
}
// resetHosts resets the hosts in the index.
func (hs *HashStorage) resetHosts(hostsStr string) (err error) {
	hs.mu.Lock()
	defer hs.mu.Unlock()

	// Remove all entries but keep the allocated map to save space and improve
	// performance.
	//
	// This is optimized, see https://github.com/golang/go/issues/20138.
	for hp := range hs.hashSuffixes {
		delete(hs.hashSuffixes, hp)
	}

	n := 0
	sc := bufio.NewScanner(strings.NewReader(hostsStr))
	for sc.Scan() {
		host := sc.Text()
		if host == "" || host[0] == '#' {
			// Skip empty lines and comments.
			continue
		}

		sum := sha256.Sum256([]byte(host))
		hp := hashPrefix{sum[0], sum[1]}

		// TODO(a.garipov): Convert to array directly when proposal
		// golang/go#46505 is implemented in Go 1.20.
		suf := *(*hashSuffix)(sum[hashPrefixLen:])
		hs.hashSuffixes[hp] = append(hs.hashSuffixes[hp], suf)

		n++
	}

	if err = sc.Err(); err != nil {
		return fmt.Errorf("scanning hosts: %w", err)
	}

	log.Info("filter %s: reset %d hosts", hs.id(), n)

	return nil
}
// Start implements the agd.Service interface for *HashStorage.  It starts the
// periodic refresh worker.
func (hs *HashStorage) Start() (err error) {
	return hs.refrWorker.Start()
}
// Shutdown implements the agd.Service interface for *HashStorage.  It shuts
// down the periodic refresh worker.
func (hs *HashStorage) Shutdown(ctx context.Context) (err error) {
	return hs.refrWorker.Shutdown(ctx)
}

View File

@ -0,0 +1,203 @@
// Package hashstorage defines a storage of hashes of domain names used for
// filtering.
package hashstorage
import (
"bufio"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"sync"
)
// Hash and hash part length constants.
const (
	// PrefixLen is the length, in bytes, of the prefix of the hash of a
	// filtered hostname.
	PrefixLen = 2

	// PrefixEncLen is the length of the hex-encoded hash prefix: two text
	// bytes per one binary byte.
	PrefixEncLen = PrefixLen * 2

	// hashLen is the length of the whole hash of a checked hostname.
	hashLen = sha256.Size

	// suffixLen is the length of the suffix of the hash of a filtered
	// hostname.
	suffixLen = hashLen - PrefixLen

	// hashEncLen is the length of the whole hex-encoded hash: two text bytes
	// per one binary byte.
	hashEncLen = hashLen * 2
)

// Prefix is the type of the 2-byte prefix of a full 32-byte SHA256 hash of a
// host being checked.
type Prefix [PrefixLen]byte

// suffix is the type of the 30-byte suffix of a full 32-byte SHA256 hash of a
// host being checked.
type suffix [suffixLen]byte

// Storage stores hashes of the filtered hostnames.  All methods are safe for
// concurrent use.
type Storage struct {
	// mu protects hashSuffixes.
	mu           *sync.RWMutex
	hashSuffixes map[Prefix][]suffix
}

// New returns a new hash storage containing hashes of the domain names listed
// in hostnames, one domain name per line.
func New(hostnames string) (s *Storage, err error) {
	s = &Storage{
		mu:           &sync.RWMutex{},
		hashSuffixes: map[Prefix][]suffix{},
	}

	if hostnames == "" {
		return s, nil
	}

	if _, err = s.Reset(hostnames); err != nil {
		return nil, err
	}

	return s, nil
}

// Hashes returns all hashes starting with the given prefixes, if any.  The
// resulting slice shares storage for all underlying strings.
//
// TODO(a.garipov): This currently doesn't take duplicates into account.
func (s *Storage) Hashes(hps []Prefix) (hashes []string) {
	if len(hps) == 0 {
		return nil
	}

	s.mu.RLock()
	defer s.mu.RUnlock()

	// Count the stored suffixes first to size the buffer precisely.
	num := 0
	for _, hp := range hps {
		num += len(s.hashSuffixes[hp])
	}

	// Encode every hash into one contiguous buffer and then slice that buffer
	// into the separate strings to make the garbage collector's work easier.
	// This assumes that all references to the buffer become unreachable at
	// the same time.
	//
	// Looking the prefixes up in [s.hashSuffixes] twice shouldn't matter,
	// since len(hps) is assumed to be below 5 most of the time.
	sb := &strings.Builder{}
	sb.Grow(num * hashEncLen)

	// Encode into a local array and write it into sb directly instead of
	// using hex.NewEncoder, because the latter seems to incur a significant
	// performance hit.
	var enc [hashEncLen]byte
	for _, hp := range hps {
		for _, suf := range s.hashSuffixes[hp] {
			// Slicing is safe here, since the contents of hp and suf are
			// only being encoded.
			//
			// nolint:looppointer
			hex.Encode(enc[:], hp[:])
			// nolint:looppointer
			hex.Encode(enc[PrefixEncLen:], suf[:])

			_, _ = sb.Write(enc[:])
		}
	}

	all := sb.String()
	hashes = make([]string, 0, num)
	for i := 0; i < num; i++ {
		hashes = append(hashes, all[i*hashEncLen:(i+1)*hashEncLen])
	}

	return hashes
}

// Matches returns true if the host matches one of the hashes.
func (s *Storage) Matches(host string) (ok bool) {
	sum := sha256.Sum256([]byte(host))
	hp := *(*Prefix)(sum[:PrefixLen])

	sufs, ok := s.loadHashSuffixes(hp)
	if !ok {
		return false
	}

	// The prefix is derived from sum itself, so it is enough to compare the
	// suffixes.
	want := *(*suffix)(sum[PrefixLen:])
	for _, suf := range sufs {
		if suf == want {
			return true
		}
	}

	return false
}

// Reset resets the hosts in the index using the domain names listed in
// hostnames, one domain name per line, and returns the total number of
// processed rules.
func (s *Storage) Reset(hostnames string) (n int, err error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Remove all entries but keep the allocated map to save space and improve
	// performance.
	//
	// This is optimized, see https://github.com/golang/go/issues/20138.
	//
	// TODO(a.garipov): Use clear once golang/go#56351 is implemented.
	for hp := range s.hashSuffixes {
		delete(s.hashSuffixes, hp)
	}

	sc := bufio.NewScanner(strings.NewReader(hostnames))
	for sc.Scan() {
		host := sc.Text()
		if host == "" || host[0] == '#' {
			// Skip empty lines and comments.
			continue
		}

		sum := sha256.Sum256([]byte(host))
		hp := *(*Prefix)(sum[:PrefixLen])

		// TODO(a.garipov): Here and everywhere, convert to array directly
		// when proposal golang/go#46505 is implemented in Go 1.20.
		suf := *(*suffix)(sum[PrefixLen:])
		s.hashSuffixes[hp] = append(s.hashSuffixes[hp], suf)

		n++
	}

	if err = sc.Err(); err != nil {
		return 0, fmt.Errorf("scanning hosts: %w", err)
	}

	return n, nil
}

// loadHashSuffixes returns hash suffixes for the given prefix.  It is safe
// for concurrent use.  sufs must not be modified.
func (s *Storage) loadHashSuffixes(hp Prefix) (sufs []suffix, ok bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	sufs, ok = s.hashSuffixes[hp]

	return sufs, ok
}

View File

@ -0,0 +1,142 @@
package hashstorage_test
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"strconv"
"strings"
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Common hostnames for tests.
const (
	// testHost is the host that the test storages are initialized with.
	testHost = "porn.example"

	// otherHost is a host that is not initially present in the test storages.
	otherHost = "otherporn.example"
)
// TestStorage_Hashes checks that Hashes returns the full hex-encoded hash for
// a known prefix and nothing for an unknown one.
func TestStorage_Hashes(t *testing.T) {
	s, err := hashstorage.New(testHost)
	require.NoError(t, err)

	sum := sha256.Sum256([]byte(testHost))
	wantHashes := []string{hex.EncodeToString(sum[:])}

	gotHashes := s.Hashes([]hashstorage.Prefix{{sum[0], sum[1]}})
	assert.Equal(t, wantHashes, gotHashes)

	assert.Empty(t, s.Hashes([]hashstorage.Prefix{{}}))
}
// TestStorage_Matches checks that Matches reports stored hosts and only them.
func TestStorage_Matches(t *testing.T) {
	s, err := hashstorage.New(testHost)
	require.NoError(t, err)

	assert.True(t, s.Matches(testHost))
	assert.False(t, s.Matches(otherHost))
}
// TestStorage_Reset checks that Reset replaces the previous contents of the
// storage with the new hosts.
func TestStorage_Reset(t *testing.T) {
	s, err := hashstorage.New(testHost)
	require.NoError(t, err)

	n, err := s.Reset(otherHost)
	require.NoError(t, err)
	assert.Equal(t, 1, n)

	sum := sha256.Sum256([]byte(otherHost))
	wantHashes := []string{hex.EncodeToString(sum[:])}

	gotHashes := s.Hashes([]hashstorage.Prefix{{sum[0], sum[1]}})
	assert.Equal(t, wantHashes, gotHashes)

	// The hashes of the previous contents must be gone.
	prevSum := sha256.Sum256([]byte(testHost))
	assert.Empty(t, s.Hashes([]hashstorage.Prefix{{prevSum[0], prevSum[1]}}))
}
// Sinks for benchmarks.  Assigning results to package-level variables keeps
// the compiler from optimizing the benchmarked calls away.
var (
	errSink  error
	strsSink []string
)
// BenchmarkStorage_Hashes measures the lookup performance of Storage.Hashes
// for one to four prefixes against a storage of 10 000 hosts.
func BenchmarkStorage_Hashes(b *testing.B) {
	const N = 10_000

	var hosts []string
	for i := 0; i < N; i++ {
		hosts = append(hosts, fmt.Sprintf("%d."+testHost, i))
	}

	s, err := hashstorage.New(strings.Join(hosts, "\n"))
	require.NoError(b, err)

	// NOTE(review): these prefixes are built from the first two bytes of the
	// hostname strings, not from the first two bytes of their SHA256 hashes,
	// so most lookups presumably miss — confirm that this is intended.
	var hashPrefixes []hashstorage.Prefix
	for i := 0; i < 4; i++ {
		hashPrefixes = append(hashPrefixes, hashstorage.Prefix{hosts[i][0], hosts[i][1]})
	}

	for n := 1; n <= 4; n++ {
		b.Run(strconv.FormatInt(int64(n), 10), func(b *testing.B) {
			hps := hashPrefixes[:n]

			b.ReportAllocs()
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				strsSink = s.Hashes(hps)
			}
		})
	}

	// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
	//
	// goos: linux
	// goarch: amd64
	// pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage
	// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
	// BenchmarkStorage_Hashes/1-16   29928834   41.76 ns/op   0 B/op   0 allocs/op
	// BenchmarkStorage_Hashes/2-16   18693033   63.80 ns/op   0 B/op   0 allocs/op
	// BenchmarkStorage_Hashes/3-16   13492526   92.22 ns/op   0 B/op   0 allocs/op
	// BenchmarkStorage_Hashes/4-16    9542425   109.2 ns/op   0 B/op   0 allocs/op
}
// BenchmarkStorage_ResetHosts measures the performance of Storage.Reset for a
// list of 1 000 hosts.
func BenchmarkStorage_ResetHosts(b *testing.B) {
	const N = 1_000

	var hosts []string
	for i := 0; i < N; i++ {
		hosts = append(hosts, fmt.Sprintf("%d."+testHost, i))
	}

	// Join once outside the timed loop so only Reset itself is measured.
	hostnames := strings.Join(hosts, "\n")
	s, err := hashstorage.New(hostnames)
	require.NoError(b, err)

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, errSink = s.Reset(hostnames)
	}

	require.NoError(b, errSink)

	// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
	//
	// goos: linux
	// goarch: amd64
	// pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage
	// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
	// BenchmarkStorage_ResetHosts-16   2212   469343 ns/op   36224 B/op   1002 allocs/op
}

View File

@ -1,91 +0,0 @@
package filter
import (
"fmt"
"strconv"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/require"
)
// strsSink is a sink for benchmark results that keeps the compiler from
// optimizing the benchmarked calls away.
var strsSink []string
// BenchmarkHashStorage_Hashes measures the lookup performance of
// HashStorage.hashes for one to four prefixes against a storage of 10 000
// hosts.
func BenchmarkHashStorage_Hashes(b *testing.B) {
	const N = 10_000

	var hosts []string
	for i := 0; i < N; i++ {
		hosts = append(hosts, fmt.Sprintf("%d.porn.example.com", i))
	}

	// Don't use a constructor, since we don't need the whole contents of the
	// storage.
	//
	// TODO(a.garipov): Think of a better way to do this.
	hs := &HashStorage{
		mu:           &sync.RWMutex{},
		hashSuffixes: map[hashPrefix][]hashSuffix{},
		refr:         &refreshableFilter{id: "test_filter"},
	}

	err := hs.resetHosts(strings.Join(hosts, "\n"))
	require.NoError(b, err)

	// NOTE(review): these prefixes are built from the first two bytes of the
	// hostname strings, not from the first two bytes of their SHA256 hashes,
	// so most lookups presumably miss — confirm that this is intended.
	var hashPrefixes []hashPrefix
	for i := 0; i < 4; i++ {
		hashPrefixes = append(hashPrefixes, hashPrefix{hosts[i][0], hosts[i][1]})
	}

	for n := 1; n <= 4; n++ {
		b.Run(strconv.FormatInt(int64(n), 10), func(b *testing.B) {
			hps := hashPrefixes[:n]

			b.ReportAllocs()
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				strsSink = hs.hashes(hps)
			}
		})
	}
}
// BenchmarkHashStorage_resetHosts measures the performance of
// HashStorage.resetHosts for a list of 1 000 hosts.
func BenchmarkHashStorage_resetHosts(b *testing.B) {
	const N = 1_000

	var hosts []string
	for i := 0; i < N; i++ {
		hosts = append(hosts, fmt.Sprintf("%d.porn.example.com", i))
	}

	// Don't use a constructor, since we don't need the whole contents of the
	// storage.
	//
	// TODO(a.garipov): Think of a better way to do this.
	hs := &HashStorage{
		mu:           &sync.RWMutex{},
		hashSuffixes: map[hashPrefix][]hashSuffix{},
		refr:         &refreshableFilter{id: "test_filter"},
	}

	// Reset them once to fill the initial map.
	err := hs.resetHosts(strings.Join(hosts, "\n"))
	require.NoError(b, err)

	b.ReportAllocs()
	b.ResetTimer()
	// NOTE(review): strings.Join runs inside the timed loop, so its
	// allocations are included in the measurement — confirm this is intended.
	for i := 0; i < b.N; i++ {
		err = hs.resetHosts(strings.Join(hosts, "\n"))
	}

	require.NoError(b, err)

	// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
	//
	// goos: linux
	// goarch: amd64
	// pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter
	// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
	// BenchmarkHashStorage_resetHosts-16   1404   829505 ns/op   58289 B/op   1011 allocs/op
}

View File

@ -8,6 +8,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/mathutil"
"github.com/bluele/gcache" "github.com/bluele/gcache"
) )
@ -93,13 +94,9 @@ func DefaultKey(host string, qt dnsmsg.RRType, isAns bool) (k Key) {
// Save on allocations by reusing a buffer. // Save on allocations by reusing a buffer.
var buf [3]byte var buf [3]byte
binary.LittleEndian.PutUint16(buf[:2], qt) binary.LittleEndian.PutUint16(buf[:2], qt)
if isAns { buf[2] = mathutil.BoolToNumber[byte](isAns)
buf[2] = 1
} else {
buf[2] = 0
}
_, _ = h.Write(buf[:3]) _, _ = h.Write(buf[:])
return Key(h.Sum64()) return Key(h.Sum64())
} }

View File

@ -45,9 +45,8 @@ type refreshableFilter struct {
// typ is the type of this filter used for logging and error reporting. // typ is the type of this filter used for logging and error reporting.
typ string typ string
// refreshIvl is the refresh interval for this filter. It is also used to // staleness is the time after which a file is considered stale.
// check if the cached file is fresh enough. staleness time.Duration
refreshIvl time.Duration
} }
// refresh reloads the filter data. If acceptStale is true, refresh doesn't try // refresh reloads the filter data. If acceptStale is true, refresh doesn't try
@ -118,7 +117,7 @@ func (f *refreshableFilter) refreshFromFile(
return "", fmt.Errorf("reading filter file stat: %w", err) return "", fmt.Errorf("reading filter file stat: %w", err)
} }
if mtime := fi.ModTime(); !mtime.Add(f.refreshIvl).After(time.Now()) { if mtime := fi.ModTime(); !mtime.Add(f.staleness).After(time.Now()) {
return "", nil return "", nil
} }
} }

View File

@ -30,43 +30,43 @@ func TestRefreshableFilter_RefreshFromFile(t *testing.T) {
name string name string
cachePath string cachePath string
wantText string wantText string
refreshIvl time.Duration staleness time.Duration
acceptStale bool acceptStale bool
}{{ }{{
name: "no_file", name: "no_file",
cachePath: "does_not_exist", cachePath: "does_not_exist",
wantText: "", wantText: "",
refreshIvl: 0, staleness: 0,
acceptStale: true, acceptStale: true,
}, { }, {
name: "file", name: "file",
cachePath: cachePath, cachePath: cachePath,
wantText: defaultText, wantText: defaultText,
refreshIvl: 0, staleness: 0,
acceptStale: true, acceptStale: true,
}, { }, {
name: "file_stale", name: "file_stale",
cachePath: cachePath, cachePath: cachePath,
wantText: "", wantText: "",
refreshIvl: -1 * time.Second, staleness: -1 * time.Second,
acceptStale: false, acceptStale: false,
}, { }, {
name: "file_stale_accept", name: "file_stale_accept",
cachePath: cachePath, cachePath: cachePath,
wantText: defaultText, wantText: defaultText,
refreshIvl: -1 * time.Second, staleness: -1 * time.Second,
acceptStale: true, acceptStale: true,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
f := &refreshableFilter{ f := &refreshableFilter{
http: nil, http: nil,
url: nil, url: nil,
id: "test_filter", id: "test_filter",
cachePath: tc.cachePath, cachePath: tc.cachePath,
typ: "test filter", typ: "test filter",
refreshIvl: tc.refreshIvl, staleness: tc.staleness,
} }
var text string var text string
@ -161,12 +161,12 @@ func TestRefreshableFilter_RefreshFromURL(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
f := &refreshableFilter{ f := &refreshableFilter{
http: httpCli, http: httpCli,
url: u, url: u,
id: "test_filter", id: "test_filter",
cachePath: tc.cachePath, cachePath: tc.cachePath,
typ: "test filter", typ: "test filter",
refreshIvl: testTimeout, staleness: testTimeout,
} }
if tc.expectReq { if tc.expectReq {

View File

@ -72,13 +72,13 @@ func newRuleListFilter(
mu: &sync.RWMutex{}, mu: &sync.RWMutex{},
refr: &refreshableFilter{ refr: &refreshableFilter{
http: agdhttp.NewClient(&agdhttp.ClientConfig{ http: agdhttp.NewClient(&agdhttp.ClientConfig{
Timeout: defaultTimeout, Timeout: defaultFilterRefreshTimeout,
}), }),
url: l.URL, url: l.URL,
id: l.ID, id: l.ID,
cachePath: filepath.Join(fileCacheDir, string(l.ID)), cachePath: filepath.Join(fileCacheDir, string(l.ID)),
typ: "rule list", typ: "rule list",
refreshIvl: l.RefreshIvl, staleness: l.RefreshIvl,
}, },
urlFilterID: newURLFilterID(), urlFilterID: newURLFilterID(),
} }

View File

@ -2,9 +2,13 @@ package filter
import ( import (
"context" "context"
"encoding/hex"
"fmt"
"strings" "strings"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
) )
// Safe Browsing TXT Record Server // Safe Browsing TXT Record Server
@ -14,12 +18,12 @@ import (
// //
// TODO(a.garipov): Consider making an interface to simplify testing. // TODO(a.garipov): Consider making an interface to simplify testing.
type SafeBrowsingServer struct { type SafeBrowsingServer struct {
generalHashes *HashStorage generalHashes *hashstorage.Storage
adultBlockingHashes *HashStorage adultBlockingHashes *hashstorage.Storage
} }
// NewSafeBrowsingServer returns a new safe browsing DNS server. // NewSafeBrowsingServer returns a new safe browsing DNS server.
func NewSafeBrowsingServer(general, adultBlocking *HashStorage) (f *SafeBrowsingServer) { func NewSafeBrowsingServer(general, adultBlocking *hashstorage.Storage) (f *SafeBrowsingServer) {
return &SafeBrowsingServer{ return &SafeBrowsingServer{
generalHashes: general, generalHashes: general,
adultBlockingHashes: adultBlocking, adultBlockingHashes: adultBlocking,
@ -48,8 +52,7 @@ func (srv *SafeBrowsingServer) Hashes(
} }
var prefixesStr string var prefixesStr string
var strg *HashStorage var strg *hashstorage.Storage
if strings.HasSuffix(host, GeneralTXTSuffix) { if strings.HasSuffix(host, GeneralTXTSuffix) {
prefixesStr = host[:len(host)-len(GeneralTXTSuffix)] prefixesStr = host[:len(host)-len(GeneralTXTSuffix)]
strg = srv.generalHashes strg = srv.generalHashes
@ -67,5 +70,45 @@ func (srv *SafeBrowsingServer) Hashes(
return nil, false, err return nil, false, err
} }
return strg.hashes(hashPrefixes), true, nil return strg.Hashes(hashPrefixes), true, nil
}
// legacyPrefixEncLen is the encoded length of a legacy hash.
const legacyPrefixEncLen = 8
// hashPrefixesFromStr returns hash prefixes from a dot-separated string.
func hashPrefixesFromStr(prefixesStr string) (hashPrefixes []hashstorage.Prefix, err error) {
if prefixesStr == "" {
return nil, nil
}
prefixSet := stringutil.NewSet()
prefixStrs := strings.Split(prefixesStr, ".")
for _, s := range prefixStrs {
if len(s) != hashstorage.PrefixEncLen {
// Some legacy clients send eight-character hashes instead of
// four-character ones. For now, remove the final four characters.
//
// TODO(a.garipov): Either remove this crutch or support such
// prefixes better.
if len(s) == legacyPrefixEncLen {
s = s[:hashstorage.PrefixEncLen]
} else {
return nil, fmt.Errorf("bad hash len for %q", s)
}
}
prefixSet.Add(s)
}
hashPrefixes = make([]hashstorage.Prefix, prefixSet.Len())
prefixStrs = prefixSet.Values()
for i, s := range prefixStrs {
_, err = hex.Decode(hashPrefixes[i][:], []byte(s))
if err != nil {
return nil, fmt.Errorf("bad hash encoding for %q", s)
}
}
return hashPrefixes, nil
} }

View File

@ -4,17 +4,11 @@ import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"net/url"
"os"
"path/filepath"
"strings" "strings"
"testing" "testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter" "github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -43,41 +37,10 @@ func TestSafeBrowsingServer(t *testing.T) {
hashStrs[i] = hex.EncodeToString(sum[:]) hashStrs[i] = hex.EncodeToString(sum[:])
} }
// Hash Storage hashes, err := hashstorage.New(strings.Join(hosts, "\n"))
errColl := &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, err error) {
panic("not implemented")
},
}
cacheDir := t.TempDir()
cachePath := filepath.Join(cacheDir, string(agd.FilterListIDSafeBrowsing))
err := os.WriteFile(cachePath, []byte(strings.Join(hosts, "\n")), 0o644)
require.NoError(t, err)
hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{
URL: &url.URL{},
ErrColl: errColl,
ID: agd.FilterListIDSafeBrowsing,
CachePath: cachePath,
RefreshIvl: testRefreshIvl,
})
require.NoError(t, err) require.NoError(t, err)
ctx := context.Background() ctx := context.Background()
err = hashes.Start()
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return hashes.Shutdown(ctx)
})
// Give the storage some time to process the hashes.
//
// TODO(a.garipov): Think of a less stupid way of doing this.
time.Sleep(100 * time.Millisecond)
testCases := []struct { testCases := []struct {
name string name string
host string host string
@ -90,14 +53,14 @@ func TestSafeBrowsingServer(t *testing.T) {
wantMatched: false, wantMatched: false,
}, { }, {
name: "realistic", name: "realistic",
host: hashStrs[realisticHostIdx][:filter.HashPrefixEncLen] + filter.GeneralTXTSuffix, host: hashStrs[realisticHostIdx][:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix,
wantHashStrs: []string{ wantHashStrs: []string{
hashStrs[realisticHostIdx], hashStrs[realisticHostIdx],
}, },
wantMatched: true, wantMatched: true,
}, { }, {
name: "same_prefix", name: "same_prefix",
host: hashStrs[samePrefixHost1Idx][:filter.HashPrefixEncLen] + filter.GeneralTXTSuffix, host: hashStrs[samePrefixHost1Idx][:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix,
wantHashStrs: []string{ wantHashStrs: []string{
hashStrs[samePrefixHost1Idx], hashStrs[samePrefixHost1Idx],
hashStrs[samePrefixHost2Idx], hashStrs[samePrefixHost2Idx],

View File

@ -3,9 +3,7 @@ package filter_test
import ( import (
"context" "context"
"net" "net"
"os"
"testing" "testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest" "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
@ -20,45 +18,29 @@ import (
func TestStorage_FilterFromContext_safeSearch(t *testing.T) { func TestStorage_FilterFromContext_safeSearch(t *testing.T) {
numLookupIP := 0 numLookupIP := 0
onLookupIP := func( resolver := &agdtest.Resolver{
_ context.Context, OnLookupIP: func(
fam netutil.AddrFamily, _ context.Context,
_ string, fam netutil.AddrFamily,
) (ips []net.IP, err error) { _ string,
numLookupIP++ ) (ips []net.IP, err error) {
numLookupIP++
if fam == netutil.AddrFamilyIPv4 { if fam == netutil.AddrFamilyIPv4 {
return []net.IP{safeSearchIPRespIP4}, nil return []net.IP{safeSearchIPRespIP4}, nil
} }
return []net.IP{safeSearchIPRespIP6}, nil return []net.IP{safeSearchIPRespIP6}, nil
},
} }
tmpFile, err := os.CreateTemp(t.TempDir(), "")
require.NoError(t, err)
_, err = tmpFile.Write([]byte("bad.example.com\n"))
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{
CachePath: tmpFile.Name(),
RefreshIvl: 1 * time.Hour,
})
require.NoError(t, err)
c := prepareConf(t) c := prepareConf(t)
c.SafeBrowsing.Hashes = hashes
c.AdultBlocking.Hashes = hashes
c.ErrColl = &agdtest.ErrorCollector{ c.ErrColl = &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, err error) { panic("not implemented") }, OnCollect: func(_ context.Context, err error) { panic("not implemented") },
} }
c.Resolver = &agdtest.Resolver{ c.Resolver = resolver
OnLookupIP: onLookupIP,
}
s, err := filter.NewDefaultStorage(c) s, err := filter.NewDefaultStorage(c)
require.NoError(t, err) require.NoError(t, err)
@ -80,13 +62,13 @@ func TestStorage_FilterFromContext_safeSearch(t *testing.T) {
host: safeSearchIPHost, host: safeSearchIPHost,
wantIP: safeSearchIPRespIP4, wantIP: safeSearchIPRespIP4,
rrtype: dns.TypeA, rrtype: dns.TypeA,
wantLookups: 0, wantLookups: 1,
}, { }, {
name: "ip6", name: "ip6",
host: safeSearchIPHost, host: safeSearchIPHost,
wantIP: nil, wantIP: safeSearchIPRespIP6,
rrtype: dns.TypeAAAA, rrtype: dns.TypeAAAA,
wantLookups: 0, wantLookups: 1,
}, { }, {
name: "host_ip4", name: "host_ip4",
host: safeSearchHost, host: safeSearchHost,

View File

@ -48,7 +48,7 @@ func newServiceBlocker(indexURL *url.URL, errColl agd.ErrorCollector) (b *servic
return &serviceBlocker{ return &serviceBlocker{
url: indexURL, url: indexURL,
http: agdhttp.NewClient(&agdhttp.ClientConfig{ http: agdhttp.NewClient(&agdhttp.ClientConfig{
Timeout: defaultTimeout, Timeout: defaultFilterRefreshTimeout,
}), }),
mu: &sync.RWMutex{}, mu: &sync.RWMutex{},
errColl: errColl, errColl: errColl,

View File

@ -15,7 +15,6 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/bluele/gcache" "github.com/bluele/gcache"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
) )
@ -55,10 +54,10 @@ type DefaultStorage struct {
services *serviceBlocker services *serviceBlocker
// safeBrowsing is the general safe browsing filter. // safeBrowsing is the general safe browsing filter.
safeBrowsing *hashPrefixFilter safeBrowsing *HashPrefix
// adultBlocking is the adult content blocking safe browsing filter. // adultBlocking is the adult content blocking safe browsing filter.
adultBlocking *hashPrefixFilter adultBlocking *HashPrefix
// genSafeSearch is the general safe search filter. // genSafeSearch is the general safe search filter.
genSafeSearch *safeSearch genSafeSearch *safeSearch
@ -111,11 +110,11 @@ type DefaultStorageConfig struct {
// SafeBrowsing is the configuration for the default safe browsing filter. // SafeBrowsing is the configuration for the default safe browsing filter.
// It must not be nil. // It must not be nil.
SafeBrowsing *HashPrefixConfig SafeBrowsing *HashPrefix
// AdultBlocking is the configuration for the adult content blocking safe // AdultBlocking is the configuration for the adult content blocking safe
// browsing filter. It must not be nil. // browsing filter. It must not be nil.
AdultBlocking *HashPrefixConfig AdultBlocking *HashPrefix
// Now is a function that returns current time. // Now is a function that returns current time.
Now func() (now time.Time) Now func() (now time.Time)
@ -156,25 +155,8 @@ type DefaultStorageConfig struct {
// NewDefaultStorage returns a new filter storage. c must not be nil. // NewDefaultStorage returns a new filter storage. c must not be nil.
func NewDefaultStorage(c *DefaultStorageConfig) (s *DefaultStorage, err error) { func NewDefaultStorage(c *DefaultStorageConfig) (s *DefaultStorage, err error) {
// TODO(ameshkov): Consider making configurable.
resolver := agdnet.NewCachingResolver(c.Resolver, 1*timeutil.Day)
safeBrowsing := newHashPrefixFilter(
c.SafeBrowsing,
resolver,
c.ErrColl,
agd.FilterListIDSafeBrowsing,
)
adultBlocking := newHashPrefixFilter(
c.AdultBlocking,
resolver,
c.ErrColl,
agd.FilterListIDAdultBlocking,
)
genSafeSearch := newSafeSearch(&safeSearchConfig{ genSafeSearch := newSafeSearch(&safeSearchConfig{
resolver: resolver, resolver: c.Resolver,
errColl: c.ErrColl, errColl: c.ErrColl,
list: &agd.FilterList{ list: &agd.FilterList{
URL: c.GeneralSafeSearchRulesURL, URL: c.GeneralSafeSearchRulesURL,
@ -187,7 +169,7 @@ func NewDefaultStorage(c *DefaultStorageConfig) (s *DefaultStorage, err error) {
}) })
ytSafeSearch := newSafeSearch(&safeSearchConfig{ ytSafeSearch := newSafeSearch(&safeSearchConfig{
resolver: resolver, resolver: c.Resolver,
errColl: c.ErrColl, errColl: c.ErrColl,
list: &agd.FilterList{ list: &agd.FilterList{
URL: c.YoutubeSafeSearchRulesURL, URL: c.YoutubeSafeSearchRulesURL,
@ -203,11 +185,11 @@ func NewDefaultStorage(c *DefaultStorageConfig) (s *DefaultStorage, err error) {
mu: &sync.RWMutex{}, mu: &sync.RWMutex{},
url: c.FilterIndexURL, url: c.FilterIndexURL,
http: agdhttp.NewClient(&agdhttp.ClientConfig{ http: agdhttp.NewClient(&agdhttp.ClientConfig{
Timeout: defaultTimeout, Timeout: defaultFilterRefreshTimeout,
}), }),
services: newServiceBlocker(c.BlockedServiceIndexURL, c.ErrColl), services: newServiceBlocker(c.BlockedServiceIndexURL, c.ErrColl),
safeBrowsing: safeBrowsing, safeBrowsing: c.SafeBrowsing,
adultBlocking: adultBlocking, adultBlocking: c.AdultBlocking,
genSafeSearch: genSafeSearch, genSafeSearch: genSafeSearch,
ytSafeSearch: ytSafeSearch, ytSafeSearch: ytSafeSearch,
now: c.Now, now: c.Now,
@ -322,7 +304,7 @@ func (s *DefaultStorage) pcBySchedule(sch *agd.ParentalProtectionSchedule) (ok b
func (s *DefaultStorage) safeBrowsingForProfile( func (s *DefaultStorage) safeBrowsingForProfile(
p *agd.Profile, p *agd.Profile,
parentalEnabled bool, parentalEnabled bool,
) (safeBrowsing, adultBlocking *hashPrefixFilter) { ) (safeBrowsing, adultBlocking *HashPrefix) {
if p.SafeBrowsingEnabled { if p.SafeBrowsingEnabled {
safeBrowsing = s.safeBrowsing safeBrowsing = s.safeBrowsing
} }
@ -359,7 +341,7 @@ func (s *DefaultStorage) safeSearchForProfile(
// in the filtering group. g must not be nil. // in the filtering group. g must not be nil.
func (s *DefaultStorage) safeBrowsingForGroup( func (s *DefaultStorage) safeBrowsingForGroup(
g *agd.FilteringGroup, g *agd.FilteringGroup,
) (safeBrowsing, adultBlocking *hashPrefixFilter) { ) (safeBrowsing, adultBlocking *HashPrefix) {
if g.SafeBrowsingEnabled { if g.SafeBrowsingEnabled {
safeBrowsing = s.safeBrowsing safeBrowsing = s.safeBrowsing
} }

View File

@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest" "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter" "github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -124,34 +125,11 @@ func TestStorage_FilterFromContext(t *testing.T) {
} }
func TestStorage_FilterFromContext_customAllow(t *testing.T) { func TestStorage_FilterFromContext_customAllow(t *testing.T) {
// Initialize the hashes file and use it with the storage. errColl := &agdtest.ErrorCollector{
tmpFile, err := os.CreateTemp(t.TempDir(), "")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
_, err = io.WriteString(tmpFile, safeBrowsingHost+"\n")
require.NoError(t, err)
hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{
CachePath: tmpFile.Name(),
RefreshIvl: 1 * time.Hour,
})
require.NoError(t, err)
c := prepareConf(t)
c.SafeBrowsing = &filter.HashPrefixConfig{
Hashes: hashes,
ReplacementHost: safeBrowsingSafeHost,
CacheTTL: 10 * time.Second,
CacheSize: 100,
}
c.ErrColl = &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, err error) { panic("not implemented") }, OnCollect: func(_ context.Context, err error) { panic("not implemented") },
} }
c.Resolver = &agdtest.Resolver{ resolver := &agdtest.Resolver{
OnLookupIP: func( OnLookupIP: func(
_ context.Context, _ context.Context,
_ netutil.AddrFamily, _ netutil.AddrFamily,
@ -161,6 +139,35 @@ func TestStorage_FilterFromContext_customAllow(t *testing.T) {
}, },
} }
// Initialize the hashes file and use it with the storage.
tmpFile, err := os.CreateTemp(t.TempDir(), "")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
_, err = io.WriteString(tmpFile, safeBrowsingHost+"\n")
require.NoError(t, err)
hashes, err := hashstorage.New(safeBrowsingHost)
require.NoError(t, err)
c := prepareConf(t)
c.SafeBrowsing, err = filter.NewHashPrefix(&filter.HashPrefixConfig{
Hashes: hashes,
ErrColl: errColl,
Resolver: resolver,
ID: agd.FilterListIDSafeBrowsing,
CachePath: tmpFile.Name(),
ReplacementHost: safeBrowsingSafeHost,
Staleness: 1 * time.Hour,
CacheTTL: 10 * time.Second,
CacheSize: 100,
})
require.NoError(t, err)
c.ErrColl = errColl
c.Resolver = resolver
s, err := filter.NewDefaultStorage(c) s, err := filter.NewDefaultStorage(c)
require.NoError(t, err) require.NoError(t, err)
@ -211,39 +218,11 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) {
// parental protection from 11:00:00 until 12:59:59. // parental protection from 11:00:00 until 12:59:59.
nowTime := time.Date(2021, 1, 1, 12, 0, 0, 0, time.UTC) nowTime := time.Date(2021, 1, 1, 12, 0, 0, 0, time.UTC)
// Initialize the hashes file and use it with the storage. errColl := &agdtest.ErrorCollector{
tmpFile, err := os.CreateTemp(t.TempDir(), "")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
_, err = io.WriteString(tmpFile, safeBrowsingHost+"\n")
require.NoError(t, err)
hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{
CachePath: tmpFile.Name(),
RefreshIvl: 1 * time.Hour,
})
require.NoError(t, err)
c := prepareConf(t)
// Use AdultBlocking, because SafeBrowsing is NOT affected by the schedule.
c.AdultBlocking = &filter.HashPrefixConfig{
Hashes: hashes,
ReplacementHost: safeBrowsingSafeHost,
CacheTTL: 10 * time.Second,
CacheSize: 100,
}
c.Now = func() (t time.Time) {
return nowTime
}
c.ErrColl = &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, err error) { panic("not implemented") }, OnCollect: func(_ context.Context, err error) { panic("not implemented") },
} }
c.Resolver = &agdtest.Resolver{ resolver := &agdtest.Resolver{
OnLookupIP: func( OnLookupIP: func(
_ context.Context, _ context.Context,
_ netutil.AddrFamily, _ netutil.AddrFamily,
@ -253,6 +232,40 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) {
}, },
} }
// Initialize the hashes file and use it with the storage.
tmpFile, err := os.CreateTemp(t.TempDir(), "")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
_, err = io.WriteString(tmpFile, safeBrowsingHost+"\n")
require.NoError(t, err)
hashes, err := hashstorage.New(safeBrowsingHost)
require.NoError(t, err)
c := prepareConf(t)
// Use AdultBlocking, because SafeBrowsing is NOT affected by the schedule.
c.AdultBlocking, err = filter.NewHashPrefix(&filter.HashPrefixConfig{
Hashes: hashes,
ErrColl: errColl,
Resolver: resolver,
ID: agd.FilterListIDAdultBlocking,
CachePath: tmpFile.Name(),
ReplacementHost: safeBrowsingSafeHost,
Staleness: 1 * time.Hour,
CacheTTL: 10 * time.Second,
CacheSize: 100,
})
require.NoError(t, err)
c.Now = func() (t time.Time) {
return nowTime
}
c.ErrColl = errColl
c.Resolver = resolver
s, err := filter.NewDefaultStorage(c) s, err := filter.NewDefaultStorage(c)
require.NoError(t, err) require.NoError(t, err)

View File

@ -20,9 +20,6 @@ import (
// FileConfig is the file-based GeoIP configuration structure. // FileConfig is the file-based GeoIP configuration structure.
type FileConfig struct { type FileConfig struct {
// ErrColl is the error collector that is used to report errors.
ErrColl agd.ErrorCollector
// ASNPath is the path to the GeoIP database of ASNs. // ASNPath is the path to the GeoIP database of ASNs.
ASNPath string ASNPath string
@ -39,8 +36,6 @@ type FileConfig struct {
// File is a file implementation of [geoip.Interface]. // File is a file implementation of [geoip.Interface].
type File struct { type File struct {
errColl agd.ErrorCollector
// mu protects asn, country, country subnet maps, and caches against // mu protects asn, country, country subnet maps, and caches against
// simultaneous access during a refresh. // simultaneous access during a refresh.
mu *sync.RWMutex mu *sync.RWMutex
@ -77,8 +72,6 @@ type asnSubnets map[agd.ASN]netip.Prefix
// NewFile returns a new GeoIP database that reads information from a file. // NewFile returns a new GeoIP database that reads information from a file.
func NewFile(c *FileConfig) (f *File, err error) { func NewFile(c *FileConfig) (f *File, err error) {
f = &File{ f = &File{
errColl: c.ErrColl,
mu: &sync.RWMutex{}, mu: &sync.RWMutex{},
asnPath: c.ASNPath, asnPath: c.ASNPath,

View File

@ -1,12 +1,10 @@
package geoip_test package geoip_test
import ( import (
"context"
"net/netip" "net/netip"
"testing" "testing"
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip" "github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -14,12 +12,7 @@ import (
) )
func TestFile_Data(t *testing.T) { func TestFile_Data(t *testing.T) {
var ec agd.ErrorCollector = &agdtest.ErrorCollector{
OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
}
conf := &geoip.FileConfig{ conf := &geoip.FileConfig{
ErrColl: ec,
ASNPath: asnPath, ASNPath: asnPath,
CountryPath: countryPath, CountryPath: countryPath,
HostCacheSize: 0, HostCacheSize: 0,
@ -42,12 +35,7 @@ func TestFile_Data(t *testing.T) {
} }
func TestFile_Data_hostCache(t *testing.T) { func TestFile_Data_hostCache(t *testing.T) {
var ec agd.ErrorCollector = &agdtest.ErrorCollector{
OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
}
conf := &geoip.FileConfig{ conf := &geoip.FileConfig{
ErrColl: ec,
ASNPath: asnPath, ASNPath: asnPath,
CountryPath: countryPath, CountryPath: countryPath,
HostCacheSize: 1, HostCacheSize: 1,
@ -74,12 +62,7 @@ func TestFile_Data_hostCache(t *testing.T) {
} }
func TestFile_SubnetByLocation(t *testing.T) { func TestFile_SubnetByLocation(t *testing.T) {
var ec agd.ErrorCollector = &agdtest.ErrorCollector{
OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
}
conf := &geoip.FileConfig{ conf := &geoip.FileConfig{
ErrColl: ec,
ASNPath: asnPath, ASNPath: asnPath,
CountryPath: countryPath, CountryPath: countryPath,
HostCacheSize: 0, HostCacheSize: 0,
@ -106,12 +89,7 @@ var locSink *agd.Location
var errSink error var errSink error
func BenchmarkFile_Data(b *testing.B) { func BenchmarkFile_Data(b *testing.B) {
var ec agd.ErrorCollector = &agdtest.ErrorCollector{
OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
}
conf := &geoip.FileConfig{ conf := &geoip.FileConfig{
ErrColl: ec,
ASNPath: asnPath, ASNPath: asnPath,
CountryPath: countryPath, CountryPath: countryPath,
HostCacheSize: 0, HostCacheSize: 0,
@ -162,12 +140,7 @@ func BenchmarkFile_Data(b *testing.B) {
var fileSink *geoip.File var fileSink *geoip.File
func BenchmarkNewFile(b *testing.B) { func BenchmarkNewFile(b *testing.B) {
var ec agd.ErrorCollector = &agdtest.ErrorCollector{
OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
}
conf := &geoip.FileConfig{ conf := &geoip.FileConfig{
ErrColl: ec,
ASNPath: asnPath, ASNPath: asnPath,
CountryPath: countryPath, CountryPath: countryPath,
HostCacheSize: 0, HostCacheSize: 0,

View File

@ -11,7 +11,7 @@ var DNSSvcRequestByCountryTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "request_per_country_total", Name: "request_per_country_total",
Namespace: namespace, Namespace: namespace,
Subsystem: subsystemDNSSvc, Subsystem: subsystemDNSSvc,
Help: "The number of filtered DNS requests labeled by country and continent.", Help: "The number of processed DNS requests labeled by country and continent.",
}, []string{"continent", "country"}) }, []string{"continent", "country"})
// DNSSvcRequestByASNTotal is a counter with the total number of queries // DNSSvcRequestByASNTotal is a counter with the total number of queries
@ -20,13 +20,14 @@ var DNSSvcRequestByASNTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "request_per_asn_total", Name: "request_per_asn_total",
Namespace: namespace, Namespace: namespace,
Subsystem: subsystemDNSSvc, Subsystem: subsystemDNSSvc,
Help: "The number of filtered DNS requests labeled by country and ASN.", Help: "The number of processed DNS requests labeled by country and ASN.",
}, []string{"country", "asn"}) }, []string{"country", "asn"})
// DNSSvcRequestByFilterTotal is a counter with the total number of queries // DNSSvcRequestByFilterTotal is a counter with the total number of queries
// processed labeled by filter. "filter" contains the ID of the filter list // processed labeled by filter. Processed could mean that the request was
// applied. "anonymous" is "0" if the request is from a AdGuard DNS customer, // blocked or unblocked by a rule from that filter list. "filter" contains
// otherwise it is "1". // the ID of the filter list applied. "anonymous" is "0" if the request is
// from a AdGuard DNS customer, otherwise it is "1".
var DNSSvcRequestByFilterTotal = promauto.NewCounterVec(prometheus.CounterOpts{ var DNSSvcRequestByFilterTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "request_per_filter_total", Name: "request_per_filter_total",
Namespace: namespace, Namespace: namespace,

View File

@ -26,6 +26,8 @@ const (
subsystemQueryLog = "querylog" subsystemQueryLog = "querylog"
subsystemRuleStat = "rulestat" subsystemRuleStat = "rulestat"
subsystemTLS = "tls" subsystemTLS = "tls"
subsystemResearch = "research"
subsystemWebSvc = "websvc"
) )
// SetUpGauge signals that the server has been started. // SetUpGauge signals that the server has been started.

View File

@ -0,0 +1,61 @@
package metrics
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
// ResearchRequestsPerCountryTotal counts the total number of queries per
// country from anonymous users.
var ResearchRequestsPerCountryTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "requests_per_country_total",
Namespace: namespace,
Subsystem: subsystemResearch,
Help: "The total number of DNS queries per country from anonymous users.",
}, []string{"country"})
// ResearchBlockedRequestsPerCountryTotal counts the number of blocked queries
// per country from anonymous users.
var ResearchBlockedRequestsPerCountryTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "blocked_per_country_total",
Namespace: namespace,
Subsystem: subsystemResearch,
Help: "The number of blocked DNS queries per country from anonymous users.",
}, []string{"filter", "country"})
// ReportResearchMetrics reports metrics to prometheus that we may need to
// conduct researches.
//
// TODO(ameshkov): use [agd.Profile] arg when recursive dependency is resolved.
func ReportResearchMetrics(
anonymous bool,
filteringEnabled bool,
asn string,
ctry string,
filterID string,
blocked bool,
) {
// The current research metrics only count queries that come to public
// DNS servers where filtering is enabled.
if !filteringEnabled || !anonymous {
return
}
// Ignore AdGuard ASN specifically in order to avoid counting queries that
// come from the monitoring. This part is ugly, but since these metrics
// are a one-time deal, this is acceptable.
//
// TODO(ameshkov): think of a better way later if we need to do that again.
if asn == "212772" {
return
}
if blocked {
ResearchBlockedRequestsPerCountryTotal.WithLabelValues(
filterID,
ctry,
).Inc()
}
ResearchRequestsPerCountryTotal.WithLabelValues(ctry).Inc()
}

View File

@ -86,7 +86,7 @@ func (c *userCounter) record(now time.Time, ip netip.Addr, syncUpdate bool) {
prevMinuteCounter := c.currentMinuteCounter prevMinuteCounter := c.currentMinuteCounter
c.currentMinute = minuteOfTheDay c.currentMinute = minuteOfTheDay
c.currentMinuteCounter = hyperloglog.New() c.currentMinuteCounter = newHyperLogLog()
// If this is the first iteration and prevMinute is -1, don't update the // If this is the first iteration and prevMinute is -1, don't update the
// counters, since there are none. // counters, since there are none.
@ -123,7 +123,7 @@ func (c *userCounter) updateCounters(prevMinute int, prevCounter *hyperloglog.Sk
// estimate uses HyperLogLog counters to estimate the hourly and daily users // estimate uses HyperLogLog counters to estimate the hourly and daily users
// count, starting with the minute of the day m. // count, starting with the minute of the day m.
func (c *userCounter) estimate(m int) (hourly, daily uint64) { func (c *userCounter) estimate(m int) (hourly, daily uint64) {
hourlyCounter, dailyCounter := hyperloglog.New(), hyperloglog.New() hourlyCounter, dailyCounter := newHyperLogLog(), newHyperLogLog()
// Go through all minutes in a day while decreasing the current minute m. // Go through all minutes in a day while decreasing the current minute m.
// Decreasing m, as opposed to increasing it or using i as the minute, is // Decreasing m, as opposed to increasing it or using i as the minute, is
@ -168,6 +168,11 @@ func decrMod(n, m int) (res int) {
return n - 1 return n - 1
} }
// newHyperLogLog creates a new instance of hyperloglog.Sketch.
func newHyperLogLog() (sk *hyperloglog.Sketch) {
return hyperloglog.New16()
}
// defaultUserCounter is the main user statistics counter. // defaultUserCounter is the main user statistics counter.
var defaultUserCounter = newUserCounter() var defaultUserCounter = newUserCounter()

View File

@ -0,0 +1,69 @@
package metrics
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
var (
webSvcRequestsTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "websvc_requests_total",
Namespace: namespace,
Subsystem: subsystemWebSvc,
Help: "The number of DNS requests for websvc.",
}, []string{"kind"})
// WebSvcError404RequestsTotal is a counter with total number of
// requests with error 404.
WebSvcError404RequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
"kind": "error404",
})
// WebSvcError500RequestsTotal is a counter with total number of
// requests with error 500.
WebSvcError500RequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
"kind": "error500",
})
// WebSvcStaticContentRequestsTotal is a counter with total number of
// requests for static content.
WebSvcStaticContentRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
"kind": "static_content",
})
// WebSvcDNSCheckTestRequestsTotal is a counter with total number of
// requests for dnscheck_test.
WebSvcDNSCheckTestRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
"kind": "dnscheck_test",
})
// WebSvcRobotsTxtRequestsTotal is a counter with total number of
// requests for robots_txt.
WebSvcRobotsTxtRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
"kind": "robots_txt",
})
// WebSvcRootRedirectRequestsTotal is a counter with total number of
// root redirected requests.
WebSvcRootRedirectRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
"kind": "root_redirect",
})
// WebSvcLinkedIPProxyRequestsTotal is a counter with total number of
// requests with linked ip.
WebSvcLinkedIPProxyRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
"kind": "linkip",
})
// WebSvcAdultBlockingPageRequestsTotal is a counter with total number
// of requests for adult blocking page.
WebSvcAdultBlockingPageRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
"kind": "adult_blocking_page",
})
// WebSvcSafeBrowsingPageRequestsTotal is a counter with total number
// of requests for safe browsing page.
WebSvcSafeBrowsingPageRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
"kind": "safe_browsing_page",
})
)

View File

@ -13,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog" "github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/mathutil"
) )
// FileSystemConfig is the configuration of the file system query log. // FileSystemConfig is the configuration of the file system query log.
@ -71,11 +72,6 @@ func (l *FileSystem) Write(_ context.Context, e *Entry) (err error) {
metrics.QueryLogItemsCount.Inc() metrics.QueryLogItemsCount.Inc()
}() }()
var dnssec uint8 = 0
if e.DNSSEC {
dnssec = 1
}
entBuf := l.bufferPool.Get().(*entryBuffer) entBuf := l.bufferPool.Get().(*entryBuffer)
defer l.bufferPool.Put(entBuf) defer l.bufferPool.Put(entBuf)
entBuf.buf.Reset() entBuf.buf.Reset()
@ -94,7 +90,7 @@ func (l *FileSystem) Write(_ context.Context, e *Entry) (err error) {
ClientASN: e.ClientASN, ClientASN: e.ClientASN,
Elapsed: e.Elapsed, Elapsed: e.Elapsed,
RequestType: e.RequestType, RequestType: e.RequestType,
DNSSEC: dnssec, DNSSEC: mathutil.BoolToNumber[uint8](e.DNSSEC),
Protocol: e.Protocol, Protocol: e.Protocol,
ResultCode: c, ResultCode: c,
ResponseCode: e.ResponseCode, ResponseCode: e.ResponseCode,

View File

@ -7,6 +7,7 @@ import (
"os" "os"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog" "github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -58,12 +59,16 @@ func (svc *Service) processRec(
body = svc.error404 body = svc.error404
respHdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML) respHdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML)
} }
metrics.WebSvcError404RequestsTotal.Inc()
case http.StatusInternalServerError: case http.StatusInternalServerError:
action = "writing 500" action = "writing 500"
if len(svc.error500) != 0 { if len(svc.error500) != 0 {
body = svc.error500 body = svc.error500
respHdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML) respHdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML)
} }
metrics.WebSvcError500RequestsTotal.Inc()
default: default:
action = "writing response" action = "writing response"
for k, v := range rec.Header() { for k, v := range rec.Header() {
@ -87,6 +92,8 @@ func (svc *Service) serveHTTP(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path { switch r.URL.Path {
case "/dnscheck/test": case "/dnscheck/test":
svc.dnsCheck.ServeHTTP(w, r) svc.dnsCheck.ServeHTTP(w, r)
metrics.WebSvcDNSCheckTestRequestsTotal.Inc()
case "/robots.txt": case "/robots.txt":
serveRobotsDisallow(w.Header(), w, "handler") serveRobotsDisallow(w.Header(), w, "handler")
case "/": case "/":
@ -94,6 +101,8 @@ func (svc *Service) serveHTTP(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r) http.NotFound(w, r)
} else { } else {
http.Redirect(w, r, svc.rootRedirectURL, http.StatusFound) http.Redirect(w, r, svc.rootRedirectURL, http.StatusFound)
metrics.WebSvcRootRedirectRequestsTotal.Inc()
} }
default: default:
http.NotFound(w, r) http.NotFound(w, r)
@ -101,7 +110,7 @@ func (svc *Service) serveHTTP(w http.ResponseWriter, r *http.Request) {
} }
// safeBrowsingHandler returns an HTTP handler serving the block page from the // safeBrowsingHandler returns an HTTP handler serving the block page from the
// blockPagePath. name is used for logging. // blockPagePath. name is used for logging and metrics.
func safeBrowsingHandler(name string, blockPage []byte) (h http.Handler) { func safeBrowsingHandler(name string, blockPage []byte) (h http.Handler) {
f := func(w http.ResponseWriter, r *http.Request) { f := func(w http.ResponseWriter, r *http.Request) {
hdr := w.Header() hdr := w.Header()
@ -122,6 +131,15 @@ func safeBrowsingHandler(name string, blockPage []byte) (h http.Handler) {
if err != nil { if err != nil {
logErrorByType(err, "websvc: %s: writing response: %s", name, err) logErrorByType(err, "websvc: %s: writing response: %s", name, err)
} }
switch name {
case adultBlockingName:
metrics.WebSvcAdultBlockingPageRequestsTotal.Inc()
case safeBrowsingName:
metrics.WebSvcSafeBrowsingPageRequestsTotal.Inc()
default:
// Go on.
}
} }
} }
@ -140,10 +158,11 @@ func (sc StaticContent) serveHTTP(w http.ResponseWriter, r *http.Request) (serve
return false return false
} }
if f.AllowOrigin != "" { h := w.Header()
w.Header().Set(agdhttp.HdrNameAccessControlAllowOrigin, f.AllowOrigin) for k, v := range f.Headers {
h[k] = v
} }
w.Header().Set(agdhttp.HdrNameContentType, f.ContentType)
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
_, err := w.Write(f.Content) _, err := w.Write(f.Content)
@ -151,16 +170,15 @@ func (sc StaticContent) serveHTTP(w http.ResponseWriter, r *http.Request) (serve
logErrorByType(err, "websvc: static content: writing %s: %s", p, err) logErrorByType(err, "websvc: static content: writing %s: %s", p, err)
} }
metrics.WebSvcStaticContentRequestsTotal.Inc()
return true return true
} }
// StaticFile is a single file in a StaticFS. // StaticFile is a single file in a StaticFS.
type StaticFile struct { type StaticFile struct {
// AllowOrigin is the value for the HTTP Access-Control-Allow-Origin header. // Headers contains headers of the HTTP response.
AllowOrigin string Headers http.Header
// ContentType is the value for the HTTP Content-Type header.
ContentType string
// Content is the file content. // Content is the file content.
Content []byte Content []byte
@ -174,6 +192,8 @@ func serveRobotsDisallow(hdr http.Header, w http.ResponseWriter, name string) {
if err != nil { if err != nil {
logErrorByType(err, "websvc: %s: writing response: %s", name, err) logErrorByType(err, "websvc: %s: writing response: %s", name, err)
} }
metrics.WebSvcRobotsTxtRequestsTotal.Inc()
} }
// logErrorByType writes err to the error log, unless err is a network error or // logErrorByType writes err to the error log, unless err is a network error or

View File

@ -31,8 +31,10 @@ func TestService_ServeHTTP(t *testing.T) {
staticContent := map[string]*websvc.StaticFile{ staticContent := map[string]*websvc.StaticFile{
"/favicon.ico": { "/favicon.ico": {
ContentType: "image/x-icon", Content: []byte{},
Content: []byte{}, Headers: http.Header{
agdhttp.HdrNameContentType: []string{"image/x-icon"},
},
}, },
} }
@ -56,22 +58,33 @@ func TestService_ServeHTTP(t *testing.T) {
}) })
// DNSCheck path. // DNSCheck path.
assertPathResponse(t, svc, "/dnscheck/test", http.StatusOK) assertResponse(t, svc, "/dnscheck/test", http.StatusOK)
// Static content path. // Static content path with headers.
assertPathResponse(t, svc, "/favicon.ico", http.StatusOK) h := http.Header{
agdhttp.HdrNameContentType: []string{"image/x-icon"},
agdhttp.HdrNameServer: []string{"AdGuardDNS/"},
}
assertResponseWithHeaders(t, svc, "/favicon.ico", http.StatusOK, h)
// Robots path. // Robots path.
assertPathResponse(t, svc, "/robots.txt", http.StatusOK) assertResponse(t, svc, "/robots.txt", http.StatusOK)
// Root redirect path. // Root redirect path.
assertPathResponse(t, svc, "/", http.StatusFound) assertResponse(t, svc, "/", http.StatusFound)
// Other path. // Other path.
assertPathResponse(t, svc, "/other", http.StatusNotFound) assertResponse(t, svc, "/other", http.StatusNotFound)
} }
func assertPathResponse(t *testing.T, svc *websvc.Service, path string, statusCode int) { // assertResponse is a helper function that checks status code of HTTP
// response.
func assertResponse(
t *testing.T,
svc *websvc.Service,
path string,
statusCode int,
) (rw *httptest.ResponseRecorder) {
t.Helper() t.Helper()
r := httptest.NewRequest(http.MethodGet, (&url.URL{ r := httptest.NewRequest(http.MethodGet, (&url.URL{
@ -79,9 +92,27 @@ func assertPathResponse(t *testing.T, svc *websvc.Service, path string, statusCo
Host: "127.0.0.1", Host: "127.0.0.1",
Path: path, Path: path,
}).String(), strings.NewReader("")) }).String(), strings.NewReader(""))
rw := httptest.NewRecorder() rw = httptest.NewRecorder()
svc.ServeHTTP(rw, r) svc.ServeHTTP(rw, r)
assert.Equal(t, statusCode, rw.Code) assert.Equal(t, statusCode, rw.Code)
assert.Equal(t, agdhttp.UserAgent(), rw.Header().Get(agdhttp.HdrNameServer)) assert.Equal(t, agdhttp.UserAgent(), rw.Header().Get(agdhttp.HdrNameServer))
return rw
}
// assertResponseWithHeaders is a helper function that checks status code and
// headers of HTTP response.
func assertResponseWithHeaders(
t *testing.T,
svc *websvc.Service,
path string,
statusCode int,
header http.Header,
) {
t.Helper()
rw := assertResponse(t, svc, path, statusCode)
assert.Equal(t, header, rw.Header())
} }

View File

@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog" "github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
@ -148,6 +149,8 @@ func (prx *linkedIPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Debug("%s: proxying %s %s: req %s", prx.logPrefix, m, p, reqID) log.Debug("%s: proxying %s %s: req %s", prx.logPrefix, m, p, reqID)
prx.httpProxy.ServeHTTP(w, r) prx.httpProxy.ServeHTTP(w, r)
metrics.WebSvcLinkedIPProxyRequestsTotal.Inc()
} else if r.URL.Path == "/robots.txt" { } else if r.URL.Path == "/robots.txt" {
serveRobotsDisallow(respHdr, w, prx.logPrefix) serveRobotsDisallow(respHdr, w, prx.logPrefix)
} else { } else {

View File

@ -118,8 +118,8 @@ func New(c *Config) (svc *Service) {
error404: c.Error404, error404: c.Error404,
error500: c.Error500, error500: c.Error500,
adultBlocking: blockPageServers(c.AdultBlocking, "adult blocking", c.Timeout), adultBlocking: blockPageServers(c.AdultBlocking, adultBlockingName, c.Timeout),
safeBrowsing: blockPageServers(c.SafeBrowsing, "safe browsing", c.Timeout), safeBrowsing: blockPageServers(c.SafeBrowsing, safeBrowsingName, c.Timeout),
} }
if c.RootRedirectURL != nil { if c.RootRedirectURL != nil {
@ -164,6 +164,12 @@ func New(c *Config) (svc *Service) {
return svc return svc
} }
// Names for safeBrowsingHandler for logging and metrics.
const (
safeBrowsingName = "safe browsing"
adultBlockingName = "adult blocking"
)
// blockPageServers is a helper function that converts a *BlockPageServer into // blockPageServers is a helper function that converts a *BlockPageServer into
// HTTP servers. // HTTP servers.
func blockPageServers( func blockPageServers(

View File

@ -117,7 +117,9 @@ underscores() {
git ls-files '*_*.go'\ git ls-files '*_*.go'\
| grep -F\ | grep -F\
-e '_generate.go'\ -e '_generate.go'\
-e '_linux.go'\
-e '_noreuseport.go'\ -e '_noreuseport.go'\
-e '_others.go'\
-e '_reuseport.go'\ -e '_reuseport.go'\
-e '_test.go'\ -e '_test.go'\
-e '_unix.go'\ -e '_unix.go'\