Sync v2.1.5

This commit is contained in:
Andrey Meshkov 2023-02-03 15:27:58 +03:00
parent 7dec041e0f
commit f20c533cc3
97 changed files with 2865 additions and 1340 deletions

View File

@ -11,6 +11,99 @@ The format is **not** based on [Keep a Changelog][kec], since the project
## AGDNS-916 / Build 456
* `ratelimit` now defines rate of requests per second for IPv4 and IPv6
addresses separately. So replace this:
```yaml
ratelimit:
rps: 30
ipv4_subnet_key_len: 24
ipv6_subnet_key_len: 48
```
with this:
```yaml
ratelimit:
ipv4:
rps: 30
subnet_key_len: 24
ipv6:
rps: 300
subnet_key_len: 48
```
## AGDNS-907 / Build 449
* The objects within the `filtering_groups` have a new property,
`block_firefox_canary`. So replace this:
```yaml
filtering_groups:
-
id: default
# …
```
with this:
```yaml
filtering_groups:
-
id: default
# …
block_firefox_canary: true
```
The recommended default value is `true`.
## AGDNS-1308 / Build 447
* There is now a new env variable `RESEARCH_METRICS` that controls whether
collecting research metrics is enabled or not. Also, the first research
metric is added: `dns_research_blocked_per_country_total`, it counts the
number of blocked requests per country. Its default value is `0`, i.e.
research metrics collection is disabled by default.
## AGDNS-1051 / Build 443
* There are two changes in the keys of the `static_content` map. Firstly,
properties `allow_origin` and `content_type` are removed. Secondly, a new
property, called `headers`, is added. So replace this:
```yaml
static_content:
'/favicon.ico':
# …
allow_origin: '*'
content_type: 'image/x-icon'
```
with this:
```yaml
static_content:
'/favicon.ico':
# …
headers:
'Access-Control-Allow-Origin':
- '*'
'Content-Type':
- 'image/x-icon'
```
Adjust or add the values, if necessary.
## AGDNS-1278 / Build 423
* The object `filters` has two new properties, `rule_list_cache_size` and

View File

@ -8,8 +8,21 @@ ratelimit:
refuseany: true
# If response is larger than this, it is counted as several responses.
response_size_estimate: 1KB
# Rate of requests per second for one subnet.
rps: 30
# Rate limit options for IPv4 addresses.
ipv4:
# Rate of requests per second for one subnet for IPv4 addresses.
rps: 30
# The lengths of the subnet prefixes used to calculate rate limiter
# bucket keys for IPv4 addresses.
subnet_key_len: 24
# Rate limit options for IPv6 addresses.
ipv6:
# Rate of requests per second for one subnet for IPv6 addresses.
rps: 300
# The lengths of the subnet prefixes used to calculate rate limiter
# bucket keys for IPv6 addresses.
subnet_key_len: 48
# The time during which to count the number of times a client has hit the
# rate limit for a back off.
#
@ -21,10 +34,6 @@ ratelimit:
# How much a client that has hit the rate limit too often stays in the back
# off.
back_off_duration: 30m
# The lengths of the subnet prefixes used to calculate rate limiter bucket
# keys for IPv4 and IPv6 addresses correspondingly.
ipv4_subnet_key_len: 24
ipv6_subnet_key_len: 48
# Configuration for the allowlist.
allowlist:
@ -156,9 +165,12 @@ web:
# servers. Paths must not cross the ones used by the DNS-over-HTTPS server.
static_content:
'/favicon.ico':
allow_origin: '*'
content_type: 'image/x-icon'
content: ''
headers:
Access-Control-Allow-Origin:
- '*'
Content-Type:
- 'image/x-icon'
# If not defined, AdGuard DNS will respond with a 404 page to all such
# requests.
root_redirect_url: 'https://adguard-dns.com'
@ -221,6 +233,7 @@ filtering_groups:
safe_browsing:
enabled: true
block_private_relay: false
block_firefox_canary: true
- id: 'family'
parental:
enabled: true
@ -234,6 +247,7 @@ filtering_groups:
safe_browsing:
enabled: true
block_private_relay: false
block_firefox_canary: true
- id: 'non_filtering'
rule_lists:
enabled: false
@ -242,6 +256,7 @@ filtering_groups:
safe_browsing:
enabled: false
block_private_relay: false
block_firefox_canary: true
# Server groups and servers.
server_groups:

View File

@ -85,12 +85,24 @@ The `ratelimit` object has the following properties:
**Example:** `30m`.
* <a href="#ratelimit-rps" id="ratelimit-rps" name="ratelimit-rps">`rps`</a>:
* <a href="#ratelimit-ipv4" id="ratelimit-ipv4" name="ratelimit-ipv4">`ipv4`</a>:
The ipv4 configuration object. It has the following fields:
* <a href="#ratelimit-ipv4-rps" id="ratelimit-ipv4-rps" name="ratelimit-ipv4-rps">`rps`</a>:
The rate of requests per second for one subnet. Requests above this are
counted in the backoff count.
**Example:** `30`.
* <a href="#ratelimit-ipv4-subnet_key_len" id="ratelimit-ipv4-subnet_key_len" name="ratelimit-ipv4-subnet_key_len">`ipv4-subnet_key_len`</a>:
The length of the subnet prefix used to calculate rate limiter bucket keys.
**Example:** `24`.
* <a href="#ratelimit-ipv6" id="ratelimit-ipv6" name="ratelimit-ipv6">`ipv6`</a>:
The `ipv6` configuration object has the same properties as the `ipv4` one
above.
* <a href="#ratelimit-back_off_count" id="ratelimit-back_off_count" name="ratelimit-back_off_count">`back_off_count`</a>:
Maximum number of requests a client can make above the RPS within
a `back_off_period`. When a client exceeds this limit, requests aren't
@ -120,22 +132,11 @@ The `ratelimit` object has the following properties:
**Example:** `30s`.
* <a href="#ratelimit-ipv4_subnet_key_len" id="ratelimit-ipv4_subnet_key_len" name="ratelimit-ipv4_subnet_key_len">`ipv4_subnet_key_len`</a>:
The length of the subnet prefix used to calculate rate limiter bucket keys
for IPv4 addresses.
**Example:** `24`.
* <a href="#ratelimit-ipv6_subnet_key_len" id="ratelimit-ipv6_subnet_key_len" name="ratelimit-ipv6_subnet_key_len">`ipv6_subnet_key_len`</a>:
Same as `ipv4_subnet_key_len` above but for IPv6 addresses.
**Example:** `48`.
For example, if `back_off_period` is `1m`, `back_off_count` is `10`, and `rps`
is `5`, a client (meaning all IP addresses within the subnet defined by
`ipv4_subnet_key_len` and `ipv6_subnet_key_len`) that made 15 requests in one
second or 6 requests (one above `rps`) every second for 10 seconds within one
minute, the client is blocked for `back_off_duration`.
For example, if `back_off_period` is `1m`, `back_off_count` is `10`, and
`ipv4-rps` is `5`, a client (meaning all IP addresses within the subnet defined
by `ipv4-subnet_key_len`) that made 15 requests in one second or 6 requests
(one above `rps`) every second for 10 seconds within one minute, the client is
blocked for `back_off_duration`.
[env-consul_allowlist_url]: environment.md#CONSUL_ALLOWLIST_URL
@ -454,13 +455,17 @@ The optional `web` object has the following properties:
`safe_browsing` and `adult_blocking` servers. Paths must not duplicate the
ones used by the DNS-over-HTTPS server.
Inside of the `headers` map, the header `Content-Type` is required.
**Property example:**
```yaml
'static_content':
static_content:
'/favicon.ico':
'content_type': 'image/x-icon'
'content': 'base64content'
content: 'base64content'
headers:
'Content-Type':
- 'image/x-icon'
```
* <a href="#web-root_redirect_url" id="web-root_redirect_url" name="web-root_redirect_url">`root_redirect_url`</a>:
@ -647,6 +652,12 @@ The items of the `filtering_groups` array have the following properties:
**Example:** `false`.
* <a href="#fg-*-block_firefox_canary" id="fg-*-block_firefox_canary" name="fg-*-block_firefox_canary">`block_firefox_canary`</a>:
If true, Firefox canary domain queries are blocked for requests using this
filtering group.
**Example:** `true`.
## <a href="#server_groups" id="server_groups" name="server_groups">Server groups</a>

View File

@ -22,6 +22,7 @@ sensitive configuration. All other configuration is stored in the
* [`LISTEN_PORT`](#LISTEN_PORT)
* [`LOG_TIMESTAMP`](#LOG_TIMESTAMP)
* [`QUERYLOG_PATH`](#QUERYLOG_PATH)
* [`RESEARCH_METRICS`](#RESEARCH_METRICS)
* [`RULESTAT_URL`](#RULESTAT_URL)
* [`SENTRY_DSN`](#SENTRY_DSN)
* [`SSL_KEY_LOG_FILE`](#SSL_KEY_LOG_FILE)
@ -198,6 +199,15 @@ The path to the file into which the query log is going to be written.
## <a href="#RESEARCH_METRICS" id="RESEARCH_METRICS" name="RESEARCH_METRICS">`RESEARCH_METRICS`</a>
If `1`, enable collection of a set of special prometheus metrics (prefix is
`dns_research`). If `0`, disable collection of those metrics.
**Default:** `0`.
## <a href="#RULESTAT_URL" id="RULESTAT_URL" name="RULESTAT_URL">`RULESTAT_URL`</a>
The URL to send filtering rule list statistics to. If empty or unset, the

2
go.mod
View File

@ -4,7 +4,7 @@ go 1.19
require (
github.com/AdguardTeam/AdGuardDNS/internal/dnsserver v0.100.0
github.com/AdguardTeam/golibs v0.11.3
github.com/AdguardTeam/golibs v0.11.4
github.com/AdguardTeam/urlfilter v0.16.1
github.com/ameshkov/dnscrypt/v2 v2.2.5
github.com/axiomhq/hyperloglog v0.0.0-20220105174342-98591331716a

4
go.sum
View File

@ -33,8 +33,8 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
github.com/AdguardTeam/golibs v0.11.3 h1:Oif+REq2WLycQ2Xm3ZPmJdfftptss0HbGWbxdFaC310=
github.com/AdguardTeam/golibs v0.11.3/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA=
github.com/AdguardTeam/golibs v0.11.4 h1:IltyvxwCTN+xxJF5sh6VadF8Zfbf8elgCm9dgijSVzM=
github.com/AdguardTeam/golibs v0.11.4/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA=
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
github.com/AdguardTeam/urlfilter v0.16.1 h1:ZPi0rjqo8cQf2FVdzo6cqumNoHZx2KPXj2yZa1A5BBw=
github.com/AdguardTeam/urlfilter v0.16.1/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI=

View File

@ -9,7 +9,11 @@ dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0 h1:SPOUaucgtVl
dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412 h1:GvWw74lx5noHocd+f6HBMXK6DuggBB1dhVkuGZbv7qM=
dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c h1:ivON6cwHK1OH26MZyWDCnbTRZZf0IhNsENoNAKFS1g4=
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999 h1:OR8VhtwhcAI3U48/rzBsVOuHi0zDPzYI1xASVcdSgR8=
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
github.com/AdguardTeam/golibs v0.10.7/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
github.com/AdguardTeam/golibs v0.11.4 h1:IltyvxwCTN+xxJF5sh6VadF8Zfbf8elgCm9dgijSVzM=
github.com/AdguardTeam/golibs v0.11.4/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA=
github.com/AdguardTeam/gomitmproxy v0.2.0 h1:rvCOf17pd1/CnMyMQW891zrEiIQBpQ8cIGjKN9pinUU=
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v1.2.0 h1:Rt8g24XnyGTyglgET/PRUNlrUeu9F5L+7FilkXfZgs0=
@ -74,7 +78,7 @@ github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18h
github.com/golang/lint v0.0.0-20180702182130-06c8688daad7 h1:2hRPrmiwPrp3fQX967rNJIhQPtiGXdlQWAxKbKw3VHA=
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY=
github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=
github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw=
@ -198,41 +202,36 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh
github.com/yosssi/ace v0.0.5 h1:tUkIP/BLdKqrlrPwcmH0shwEEhTRHoGnc1wFIWmaBUA=
github.com/yuin/goldmark v1.4.1 h1:/vn0k+RBvwlxEmP5E7SZMqNxPhfMVFEJiykr15/0XKM=
github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.opencensus.io v0.22.4 h1:LYy1Hy3MJdrCdMwwzxA/dRok4ejH+RwNGbuoD9fCjto=
go4.org v0.0.0-20180809161055-417644f6feb5 h1:+hE86LblG4AyDgwMCLTE6FOlM9+qjHSYS+rKqxUVdsM=
golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d h1:E2M5QgjZ/Jg+ObCQAudsXxuTsLj7Nl5RV/lZcQZmKSo=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b h1:+qEpEAPhDZ1o0x3tHzZTQDArnOixOzGD9HUJfcg0mb4=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k=
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 h1:4+4C/Iv2U4fMZBiMCc98MG1In4gJY5YRhtpDNeDeHWs=
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220516155154-20f960328961/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b h1:clP8eMhB30EHdc0bd2Twtq6kgU7yl5ub2cQLSdrv1Dg=
golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852 h1:xYq6+9AtI+xP3M4r0N1hCkHrInHDBohhquRgx9Kk6gI=
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 h1:v6hYoSR9T5oet+pMXwUWkbiVqx/63mlHjefrHmxwfeY=
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw=
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc=
golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI=
golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA=
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.30.0 h1:yfrXXP61wVuLb0vBcG6qaOoIoqYEzOQS8jum51jkv2w=
@ -246,6 +245,7 @@ gopkg.in/errgo.v2 v2.1.0 h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8=
gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
grpc.go4.org v0.0.0-20170609214715-11d0a25b4919 h1:tmXTu+dfa+d9Evp8NpJdgOy6+rt8/x4yG7qPBrtNfLY=
honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8=
rsc.io/binaryregexp v0.2.0 h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE=

View File

@ -1,11 +1,6 @@
package agd
import (
"context"
"net"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
)
import "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
// Common DNS Message Constants, Types, And Utilities
@ -27,10 +22,3 @@ const (
ProtoDoT = dnsserver.ProtoDoT
ProtoDNSCrypt = dnsserver.ProtoDNSCrypt
)
// Resolver is the DNS resolver interface.
//
// See go doc net.Resolver.
type Resolver interface {
LookupIP(ctx context.Context, network, host string) (ips []net.IP, err error)
}

View File

@ -146,6 +146,10 @@ type FilteringGroup struct {
// BlockPrivateRelay shows if Apple Private Relay is blocked for requests
// using this filtering group.
BlockPrivateRelay bool
// BlockFirefoxCanary shows if Firefox canary domain is blocked for
// requests using this filtering group.
BlockFirefoxCanary bool
}
// FilteringGroupID is the ID of a filter group. It is an opaque string.

View File

@ -73,6 +73,10 @@ type Profile struct {
// BlockPrivateRelay shows if Apple Private Relay queries are blocked for
// requests from all devices in this profile.
BlockPrivateRelay bool
// BlockFirefoxCanary shows if Firefox canary domain is blocked for
// requests from all devices in this profile.
BlockFirefoxCanary bool
}
// ProfileID is the ID of a profile. It is an opaque string.

View File

@ -17,6 +17,8 @@ import (
// Data Storage
// ProfileDB is the local database of profiles and other data.
//
// TODO(a.garipov): move this logic to the backend package.
type ProfileDB interface {
ProfileByDeviceID(ctx context.Context, id DeviceID) (p *Profile, d *Device, err error)
ProfileByIP(ctx context.Context, ip netip.Addr) (p *Profile, d *Device, err error)

View File

@ -7,6 +7,8 @@ package agdio
import (
"fmt"
"io"
"github.com/AdguardTeam/golibs/mathutil"
)
// LimitError is returned when the Limit is reached.
@ -35,9 +37,8 @@ func (lr *limitedReader) Read(p []byte) (n int, err error) {
}
}
if int64(len(p)) > lr.n {
p = p[0:lr.n]
}
l := mathutil.Min(int64(len(p)), lr.n)
p = p[:l]
n, err = lr.r.Read(p)
lr.n -= int64(n)

View File

@ -165,6 +165,31 @@ type v1SettingsRespSettings struct {
FilteringEnabled bool `json:"filtering_enabled"`
Deleted bool `json:"deleted"`
BlockPrivateRelay bool `json:"block_private_relay"`
BlockFirefoxCanary bool `json:"block_firefox_canary"`
}
// type check
var _ json.Unmarshaler = (*v1SettingsRespSettings)(nil)
// UnmarshalJSON implements the [json.Unmarshaler] interface for
// *v1SettingsRespSettings. It puts default value into BlockFirefoxCanary
// field while it is not implemented on the backend side.
//
// TODO(a.garipov): Remove once the backend starts to always send it.
func (rs *v1SettingsRespSettings) UnmarshalJSON(b []byte) (err error) {
type defaultDec v1SettingsRespSettings
s := defaultDec{
BlockFirefoxCanary: true,
}
if err = json.Unmarshal(b, &s); err != nil {
return err
}
*rs = v1SettingsRespSettings(s)
return nil
}
// v1SettingsRespRuleLists is the structure for decoding filtering rule lists
@ -414,10 +439,11 @@ const maxFltRespTTL = 1 * time.Hour
func fltRespTTLToInternal(respTTL uint32) (ttl time.Duration, err error) {
ttl = time.Duration(respTTL) * time.Second
if ttl > maxFltRespTTL {
return ttl, fmt.Errorf("too high: got %d, max %d", respTTL, maxFltRespTTL)
ttl = maxFltRespTTL
err = fmt.Errorf("too high: got %s, max %s", ttl, maxFltRespTTL)
}
return ttl, nil
return ttl, err
}
// toInternal converts r to an [agd.DSProfilesResponse] instance.
@ -461,6 +487,9 @@ func (r *v1SettingsResp) toInternal(
reportf(ctx, errColl, "settings at index %d: filtered resp ttl: %w", i, err)
// Go on and use the fixed value.
//
// TODO(ameshkov, a.garipov): Consider continuing, like with all
// other validation errors.
}
sbEnabled := s.SafeBrowsing != nil && s.SafeBrowsing.Enabled
@ -479,6 +508,7 @@ func (r *v1SettingsResp) toInternal(
QueryLogEnabled: s.QueryLogEnabled,
Deleted: s.Deleted,
BlockPrivateRelay: s.BlockPrivateRelay,
BlockFirefoxCanary: s.BlockFirefoxCanary,
})
}

View File

@ -131,6 +131,7 @@ func testProfileResp(t *testing.T) *agd.PSProfilesResponse {
QueryLogEnabled: true,
Deleted: false,
BlockPrivateRelay: true,
BlockFirefoxCanary: true,
}, {
Parental: wantParental,
ID: "83f3ea8f",
@ -162,6 +163,7 @@ func testProfileResp(t *testing.T) *agd.PSProfilesResponse {
QueryLogEnabled: true,
Deleted: true,
BlockPrivateRelay: false,
BlockFirefoxCanary: false,
}},
}

View File

@ -11,6 +11,7 @@
},
"deleted": false,
"block_private_relay": true,
"block_firefox_canary": true,
"devices": [
{
"id": "118ffe93",
@ -43,6 +44,7 @@
},
"deleted": true,
"block_private_relay": false,
"block_firefox_canary": false,
"devices": [
{
"id": "0d7724fa",

View File

@ -28,6 +28,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/websvc"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil"
)
// Main is the entry point of application.
@ -89,52 +90,75 @@ func Main() {
go func() {
defer geoIPMu.Unlock()
geoIP, geoIPErr = envs.geoIP(c.GeoIP, errColl)
geoIP, geoIPErr = envs.geoIP(c.GeoIP)
}()
// Safe Browsing Hosts
// Safe-browsing and adult-blocking filters
// TODO(ameshkov): Consider making configurable.
filteringResolver := agdnet.NewCachingResolver(
agdnet.DefaultResolver{},
1*timeutil.Day,
)
err = os.MkdirAll(envs.FilterCachePath, agd.DefaultDirPerm)
check(err)
safeBrowsingConf := c.SafeBrowsing.toInternal(
safeBrowsingConf, err := c.SafeBrowsing.toInternal(
errColl,
filteringResolver,
agd.FilterListIDSafeBrowsing,
envs.FilterCachePath,
errColl,
)
safeBrowsingHashes, err := filter.NewHashStorage(safeBrowsingConf)
check(err)
err = safeBrowsingHashes.Start()
safeBrowsingFilter, err := filter.NewHashPrefix(safeBrowsingConf)
check(err)
adultBlockingConf := c.AdultBlocking.toInternal(
safeBrowsingUpd := agd.NewRefreshWorker(&agd.RefreshWorkerConfig{
Context: ctxWithDefaultTimeout,
Refresher: safeBrowsingFilter,
ErrColl: errColl,
Name: string(agd.FilterListIDSafeBrowsing),
Interval: safeBrowsingConf.Staleness,
RefreshOnShutdown: false,
RoutineLogsAreDebug: false,
})
err = safeBrowsingUpd.Start()
check(err)
adultBlockingConf, err := c.AdultBlocking.toInternal(
errColl,
filteringResolver,
agd.FilterListIDAdultBlocking,
envs.FilterCachePath,
errColl,
)
adultBlockingHashes, err := filter.NewHashStorage(adultBlockingConf)
check(err)
err = adultBlockingHashes.Start()
adultBlockingFilter, err := filter.NewHashPrefix(adultBlockingConf)
check(err)
// Filters And Filtering Groups
adultBlockingUpd := agd.NewRefreshWorker(&agd.RefreshWorkerConfig{
Context: ctxWithDefaultTimeout,
Refresher: adultBlockingFilter,
ErrColl: errColl,
Name: string(agd.FilterListIDAdultBlocking),
Interval: adultBlockingConf.Staleness,
RefreshOnShutdown: false,
RoutineLogsAreDebug: false,
})
err = adultBlockingUpd.Start()
check(err)
fltStrgConf := c.Filters.toInternal(errColl, envs)
fltStrgConf.SafeBrowsing = &filter.HashPrefixConfig{
Hashes: safeBrowsingHashes,
ReplacementHost: c.SafeBrowsing.BlockHost,
CacheTTL: c.SafeBrowsing.CacheTTL.Duration,
CacheSize: c.SafeBrowsing.CacheSize,
}
// Filter storage and filtering groups
fltStrgConf.AdultBlocking = &filter.HashPrefixConfig{
Hashes: adultBlockingHashes,
ReplacementHost: c.AdultBlocking.BlockHost,
CacheTTL: c.AdultBlocking.CacheTTL.Duration,
CacheSize: c.AdultBlocking.CacheSize,
}
fltStrgConf := c.Filters.toInternal(
errColl,
filteringResolver,
envs,
safeBrowsingFilter,
adultBlockingFilter,
)
fltStrg, err := filter.NewDefaultStorage(fltStrgConf)
check(err)
@ -153,8 +177,6 @@ func Main() {
err = fltStrgUpd.Start()
check(err)
safeBrowsing := filter.NewSafeBrowsingServer(safeBrowsingHashes, adultBlockingHashes)
// Server Groups
fltGroups, err := c.FilteringGroups.toInternal(fltStrg)
@ -330,7 +352,10 @@ func Main() {
dnsConf := &dnssvc.Config{
Messages: messages,
SafeBrowsing: safeBrowsing,
SafeBrowsing: filter.NewSafeBrowsingServer(
safeBrowsingConf.Hashes,
adultBlockingConf.Hashes,
),
BillStat: billStatRec,
ProfileDB: profDB,
DNSCheck: dnsCk,
@ -349,6 +374,7 @@ func Main() {
CacheSize: c.Cache.Size,
ECSCacheSize: c.Cache.ECSSize,
UseECSCache: c.Cache.Type == cacheTypeECS,
ResearchMetrics: bool(envs.ResearchMetrics),
}
dnsSvc, err := dnssvc.New(dnsConf)
@ -383,11 +409,11 @@ func Main() {
)
h := newSignalHandler(
adultBlockingHashes,
safeBrowsingHashes,
debugSvc,
webSvc,
dnsSvc,
safeBrowsingUpd,
adultBlockingUpd,
profDBUpd,
dnsDBUpd,
geoIPUpd,

View File

@ -3,14 +3,9 @@ package cmd
import (
"fmt"
"os"
"path/filepath"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil"
"gopkg.in/yaml.v2"
)
@ -213,74 +208,6 @@ func (c *geoIPConfig) validate() (err error) {
}
}
// allowListConfig is the consul allow list configuration.
type allowListConfig struct {
// List contains IPs and CIDRs.
List []string `yaml:"list"`
// RefreshIvl time between two updates of allow list from the Consul URL.
RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
}
// safeBrowsingConfig is the configuration for one of the safe browsing filters.
type safeBrowsingConfig struct {
// URL is the URL used to update the filter.
URL *agdhttp.URL `yaml:"url"`
// BlockHost is the hostname with which to respond to any requests that
// match the filter.
//
// TODO(a.garipov): Consider replacing with a list of IPv4 and IPv6
// addresses.
BlockHost string `yaml:"block_host"`
// CacheSize is the size of the response cache, in entries.
CacheSize int `yaml:"cache_size"`
// CacheTTL is the TTL of the response cache.
CacheTTL timeutil.Duration `yaml:"cache_ttl"`
// RefreshIvl defines how often AdGuard DNS refreshes the filter.
RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
}
// toInternal converts c to the safe browsing filter configuration for the
// filter storage of the DNS server. c is assumed to be valid.
func (c *safeBrowsingConfig) toInternal(
id agd.FilterListID,
cacheDir string,
errColl agd.ErrorCollector,
) (conf *filter.HashStorageConfig) {
return &filter.HashStorageConfig{
URL: netutil.CloneURL(&c.URL.URL),
ErrColl: errColl,
ID: id,
CachePath: filepath.Join(cacheDir, string(id)),
RefreshIvl: c.RefreshIvl.Duration,
}
}
// validate returns an error if the safe browsing filter configuration is
// invalid.
func (c *safeBrowsingConfig) validate() (err error) {
switch {
case c == nil:
return errNilConfig
case c.URL == nil:
return errors.Error("no url")
case c.BlockHost == "":
return errors.Error("no block_host")
case c.CacheSize <= 0:
return newMustBePositiveError("cache_size", c.CacheSize)
case c.CacheTTL.Duration <= 0:
return newMustBePositiveError("cache_ttl", c.CacheTTL)
case c.RefreshIvl.Duration <= 0:
return newMustBePositiveError("refresh_interval", c.RefreshIvl)
default:
return nil
}
}
// readConfig reads the configuration.
func readConfig(confPath string) (c *configuration, err error) {
// #nosec G304 -- Trust the path to the configuration file that is given

View File

@ -50,6 +50,7 @@ type environments struct {
LogTimestamp strictBool `env:"LOG_TIMESTAMP" envDefault:"1"`
LogVerbose strictBool `env:"VERBOSE" envDefault:"0"`
ResearchMetrics strictBool `env:"RESEARCH_METRICS" envDefault:"0"`
}
// readEnvs reads the configuration.
@ -127,12 +128,10 @@ func (envs *environments) buildDNSDB(
// geoIP returns an GeoIP database implementation from environment.
func (envs *environments) geoIP(
c *geoIPConfig,
errColl agd.ErrorCollector,
) (g *geoip.File, err error) {
log.Debug("using geoip files %q and %q", envs.GeoIPASNPath, envs.GeoIPCountryPath)
g, err = geoip.NewFile(&geoip.FileConfig{
ErrColl: errColl,
ASNPath: envs.GeoIPASNPath,
CountryPath: envs.GeoIPCountryPath,
HostCacheSize: c.HostCacheSize,

View File

@ -47,16 +47,21 @@ type filtersConfig struct {
// cacheDir must exist. c is assumed to be valid.
func (c *filtersConfig) toInternal(
errColl agd.ErrorCollector,
resolver agdnet.Resolver,
envs *environments,
safeBrowsing *filter.HashPrefix,
adultBlocking *filter.HashPrefix,
) (conf *filter.DefaultStorageConfig) {
return &filter.DefaultStorageConfig{
FilterIndexURL: netutil.CloneURL(&envs.FilterIndexURL.URL),
BlockedServiceIndexURL: netutil.CloneURL(&envs.BlockedServiceIndexURL.URL),
GeneralSafeSearchRulesURL: netutil.CloneURL(&envs.GeneralSafeSearchURL.URL),
YoutubeSafeSearchRulesURL: netutil.CloneURL(&envs.YoutubeSafeSearchURL.URL),
SafeBrowsing: safeBrowsing,
AdultBlocking: adultBlocking,
Now: time.Now,
ErrColl: errColl,
Resolver: agdnet.DefaultResolver{},
Resolver: resolver,
CacheDir: envs.FilterCachePath,
CustomFilterCacheSize: c.CustomFilterCacheSize,
SafeSearchCacheSize: c.SafeSearchCacheSize,

View File

@ -29,6 +29,10 @@ type filteringGroup struct {
// BlockPrivateRelay shows if Apple Private Relay queries are blocked for
// requests using this filtering group.
BlockPrivateRelay bool `yaml:"block_private_relay"`
// BlockFirefoxCanary shows if Firefox canary domain is blocked for
// requests using this filtering group.
BlockFirefoxCanary bool `yaml:"block_firefox_canary"`
}
// fltGrpRuleLists contains filter rule lists configuration for a filtering
@ -133,6 +137,7 @@ func (groups filteringGroups) toInternal(
GeneralSafeSearch: g.Parental.GeneralSafeSearch,
YoutubeSafeSearch: g.Parental.YoutubeSafeSearch,
BlockPrivateRelay: g.BlockPrivateRelay,
BlockFirefoxCanary: g.BlockFirefoxCanary,
}
}

View File

@ -15,14 +15,17 @@ type rateLimitConfig struct {
// AllowList is the allowlist of clients.
Allowlist *allowListConfig `yaml:"allowlist"`
// Rate limit options for IPv4 addresses.
IPv4 *rateLimitOptions `yaml:"ipv4"`
// Rate limit options for IPv6 addresses.
IPv6 *rateLimitOptions `yaml:"ipv6"`
// ResponseSizeEstimate is the size of the estimate of the size of one DNS
// response for the purposes of rate limiting. Responses over this estimate
// are counted as several responses.
ResponseSizeEstimate datasize.ByteSize `yaml:"response_size_estimate"`
// RPS is the maximum number of requests per second.
RPS int `yaml:"rps"`
// BackOffCount helps with repeated offenders. It defines, how many times
// a client hits the rate limit before being held in the back off.
BackOffCount int `yaml:"back_off_count"`
@ -35,18 +38,42 @@ type rateLimitConfig struct {
// a client has hit the rate limit for a back off.
BackOffPeriod timeutil.Duration `yaml:"back_off_period"`
// IPv4SubnetKeyLen is the length of the subnet prefix used to calculate
// rate limiter bucket keys for IPv4 addresses.
IPv4SubnetKeyLen int `yaml:"ipv4_subnet_key_len"`
// IPv6SubnetKeyLen is the length of the subnet prefix used to calculate
// rate limiter bucket keys for IPv6 addresses.
IPv6SubnetKeyLen int `yaml:"ipv6_subnet_key_len"`
// RefuseANY, if true, makes the server refuse DNS * queries.
RefuseANY bool `yaml:"refuse_any"`
}
// allowListConfig is the consul allow list configuration.
type allowListConfig struct {
// List contains IPs and CIDRs.
List []string `yaml:"list"`
// RefreshIvl time between two updates of allow list from the Consul URL.
RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
}
// rateLimitOptions allows define maximum number of requests for IPv4 or IPv6
// addresses.
type rateLimitOptions struct {
// RPS is the maximum number of requests per second.
RPS int `yaml:"rps"`
// SubnetKeyLen is the length of the subnet prefix used to calculate
// rate limiter bucket keys.
SubnetKeyLen int `yaml:"subnet_key_len"`
}
// validate returns an error if rate limit options are invalid.
func (o *rateLimitOptions) validate() (err error) {
if o == nil {
return errNilConfig
}
return coalesceError(
validatePositive("rps", o.RPS),
validatePositive("subnet_key_len", o.SubnetKeyLen),
)
}
// toInternal converts c to the rate limiting configuration for the DNS server.
// c is assumed to be valid.
func (c *rateLimitConfig) toInternal(al ratelimit.Allowlist) (conf *ratelimit.BackOffConfig) {
@ -55,10 +82,11 @@ func (c *rateLimitConfig) toInternal(al ratelimit.Allowlist) (conf *ratelimit.Ba
ResponseSizeEstimate: int(c.ResponseSizeEstimate.Bytes()),
Duration: c.BackOffDuration.Duration,
Period: c.BackOffPeriod.Duration,
RPS: c.RPS,
IPv4RPS: c.IPv4.RPS,
IPv4SubnetKeyLen: c.IPv4.SubnetKeyLen,
IPv6RPS: c.IPv6.RPS,
IPv6SubnetKeyLen: c.IPv6.SubnetKeyLen,
Count: c.BackOffCount,
IPv4SubnetKeyLen: c.IPv4SubnetKeyLen,
IPv6SubnetKeyLen: c.IPv6SubnetKeyLen,
RefuseANY: c.RefuseANY,
}
}
@ -71,14 +99,21 @@ func (c *rateLimitConfig) validate() (err error) {
return fmt.Errorf("allowlist: %w", errNilConfig)
}
err = c.IPv4.validate()
if err != nil {
return fmt.Errorf("ipv4: %w", err)
}
err = c.IPv6.validate()
if err != nil {
return fmt.Errorf("ipv6: %w", err)
}
return coalesceError(
validatePositive("rps", c.RPS),
validatePositive("back_off_count", c.BackOffCount),
validatePositive("back_off_duration", c.BackOffDuration),
validatePositive("back_off_period", c.BackOffPeriod),
validatePositive("response_size_estimate", c.ResponseSizeEstimate),
validatePositive("allowlist.refresh_interval", c.Allowlist.RefreshIvl),
validatePositive("ipv4_subnet_key_len", c.IPv4SubnetKeyLen),
validatePositive("ipv6_subnet_key_len", c.IPv6SubnetKeyLen),
)
}

View File

@ -0,0 +1,86 @@
package cmd
import (
"path/filepath"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil"
)
// Safe-browsing and adult-blocking configuration
// safeBrowsingConfig is the configuration for one of the safe browsing filters.
type safeBrowsingConfig struct {
// URL is the URL used to update the filter.
URL *agdhttp.URL `yaml:"url"`
// BlockHost is the hostname with which to respond to any requests that
// match the filter.
//
// TODO(a.garipov): Consider replacing with a list of IPv4 and IPv6
// addresses.
BlockHost string `yaml:"block_host"`
// CacheSize is the size of the response cache, in entries.
CacheSize int `yaml:"cache_size"`
// CacheTTL is the TTL of the response cache.
CacheTTL timeutil.Duration `yaml:"cache_ttl"`
// RefreshIvl defines how often AdGuard DNS refreshes the filter.
RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
}
// toInternal converts c to the safe browsing filter configuration for the
// filter storage of the DNS server. c is assumed to be valid.
func (c *safeBrowsingConfig) toInternal(
errColl agd.ErrorCollector,
resolver agdnet.Resolver,
id agd.FilterListID,
cacheDir string,
) (fltConf *filter.HashPrefixConfig, err error) {
hashes, err := hashstorage.New("")
if err != nil {
return nil, err
}
return &filter.HashPrefixConfig{
Hashes: hashes,
URL: netutil.CloneURL(&c.URL.URL),
ErrColl: errColl,
Resolver: resolver,
ID: id,
CachePath: filepath.Join(cacheDir, string(id)),
ReplacementHost: c.BlockHost,
Staleness: c.RefreshIvl.Duration,
CacheTTL: c.CacheTTL.Duration,
CacheSize: c.CacheSize,
}, nil
}
// validate returns an error if the safe browsing filter configuration is
// invalid.
func (c *safeBrowsingConfig) validate() (err error) {
switch {
case c == nil:
return errNilConfig
case c.URL == nil:
return errors.Error("no url")
case c.BlockHost == "":
return errors.Error("no block_host")
case c.CacheSize <= 0:
return newMustBePositiveError("cache_size", c.CacheSize)
case c.CacheTTL.Duration <= 0:
return newMustBePositiveError("cache_ttl", c.CacheTTL)
case c.RefreshIvl.Duration <= 0:
return newMustBePositiveError("refresh_interval", c.RefreshIvl)
default:
return nil
}
}

View File

@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"net/netip"
"net/textproto"
"os"
"path"
@ -386,11 +387,8 @@ func (sc staticContent) validate() (err error) {
// staticFile is a single file in a static content mapping.
type staticFile struct {
// AllowOrigin is the value for the HTTP Access-Control-Allow-Origin header.
AllowOrigin string `yaml:"allow_origin"`
// ContentType is the value for the HTTP Content-Type header.
ContentType string `yaml:"content_type"`
// Headers contains headers of the HTTP response.
Headers http.Header `yaml:"headers"`
// Content is the file content.
Content string `yaml:"content"`
@ -400,8 +398,12 @@ type staticFile struct {
// assumed to be valid.
func (f *staticFile) toInternal() (file *websvc.StaticFile, err error) {
file = &websvc.StaticFile{
AllowOrigin: f.AllowOrigin,
ContentType: f.ContentType,
Headers: http.Header{},
}
for k, vs := range f.Headers {
ck := textproto.CanonicalMIMEHeaderKey(k)
file.Headers[ck] = vs
}
file.Content, err = base64.StdEncoding.DecodeString(f.Content)
@ -409,17 +411,20 @@ func (f *staticFile) toInternal() (file *websvc.StaticFile, err error) {
return nil, fmt.Errorf("content: %w", err)
}
// Check Content-Type here as opposed to in validate, because we need
// all keys to be canonicalized first.
if file.Headers.Get(agdhttp.HdrNameContentType) == "" {
return nil, errors.Error("content: " + agdhttp.HdrNameContentType + " header is required")
}
return file, nil
}
// validate returns an error if the static content file is invalid.
func (f *staticFile) validate() (err error) {
switch {
case f == nil:
if f == nil {
return errors.Error("no file")
case f.ContentType == "":
return errors.Error("no content_type")
default:
}
return nil
}
}

View File

@ -222,9 +222,9 @@ func NewEDNS0Padding(msgLen int, UDPBufferSize uint16) (extra dns.RR) {
}
}
// FindENDS0Option searches for the specified EDNS0 option in the OPT resource
// FindEDNS0Option searches for the specified EDNS0 option in the OPT resource
// record of the msg and returns it or nil if it's not present.
func FindENDS0Option[T dns.EDNS0](msg *dns.Msg) (o T) {
func FindEDNS0Option[T dns.EDNS0](msg *dns.Msg) (o T) {
rr := msg.IsEdns0()
if rr == nil {
return o

View File

@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardDNS/internal/dnsserver
go 1.19
require (
github.com/AdguardTeam/golibs v0.11.3
github.com/AdguardTeam/golibs v0.11.4
github.com/ameshkov/dnscrypt/v2 v2.2.5
github.com/ameshkov/dnsstamps v1.0.3
github.com/bluele/gcache v0.0.2

View File

@ -31,8 +31,7 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl
cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs=
cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0=
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/AdguardTeam/golibs v0.11.3 h1:Oif+REq2WLycQ2Xm3ZPmJdfftptss0HbGWbxdFaC310=
github.com/AdguardTeam/golibs v0.11.3/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA=
github.com/AdguardTeam/golibs v0.11.4 h1:IltyvxwCTN+xxJF5sh6VadF8Zfbf8elgCm9dgijSVzM=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=

View File

@ -1,36 +0,0 @@
//go:build !(aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd)
package dnsserver
import (
"context"
"fmt"
"net"
"github.com/AdguardTeam/golibs/errors"
)
// listenUDP listens to the specified address on UDP.
func listenUDP(_ context.Context, addr string, _ bool) (conn *net.UDPConn, err error) {
defer func() { err = errors.Annotate(err, "opening packet listener: %w") }()
c, err := net.ListenPacket("udp", addr)
if err != nil {
return nil, err
}
conn, ok := c.(*net.UDPConn)
if !ok {
// TODO(ameshkov): should not happen, consider panic here.
err = fmt.Errorf("expected conn of type %T, got %T", conn, c)
return nil, err
}
return conn, nil
}
// listenTCP listens to the specified address on TCP.
func listenTCP(_ context.Context, addr string) (conn net.Listener, err error) {
return net.Listen("tcp", addr)
}

View File

@ -1,63 +0,0 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd
package dnsserver
import (
"context"
"fmt"
"net"
"syscall"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/sys/unix"
)
func reuseportControl(_, _ string, c syscall.RawConn) (err error) {
var opErr error
err = c.Control(func(fd uintptr) {
opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
})
if err != nil {
return err
}
return opErr
}
// listenUDP listens to the specified address on UDP. If oob flag is set to
// true this method also enables OOB for the listen socket that enables using of
// ReadMsgUDP/WriteMsgUDP. Doing it this way is necessary to correctly discover
// the source address when it listens to 0.0.0.0.
func listenUDP(ctx context.Context, addr string, oob bool) (conn *net.UDPConn, err error) {
defer func() { err = errors.Annotate(err, "opening packet listener: %w") }()
var lc net.ListenConfig
lc.Control = reuseportControl
c, err := lc.ListenPacket(ctx, "udp", addr)
if err != nil {
return nil, err
}
conn, ok := c.(*net.UDPConn)
if !ok {
// TODO(ameshkov): should not happen, consider panic here.
err = fmt.Errorf("expected conn of type %T, got %T", conn, c)
return nil, err
}
if oob {
if err = setUDPSocketOptions(conn); err != nil {
return nil, fmt.Errorf("failed to set socket options: %w", err)
}
}
return conn, err
}
// listenTCP listens to the specified address on TCP.
func listenTCP(ctx context.Context, addr string) (l net.Listener, err error) {
var lc net.ListenConfig
lc.Control = reuseportControl
return lc.Listen(ctx, "tcp", addr)
}

View File

@ -0,0 +1,73 @@
// Package netext contains extensions of package net in the Go standard library.
package netext
import (
"context"
"fmt"
"net"
)
// ListenConfig is the interface that allows controlling options of connections
// used by the DNS servers defined in this module. Default ListenConfigs are
// the ones returned by [DefaultListenConfigWithOOB] for plain DNS and
// [DefaultListenConfig] for others.
//
// This interface is modeled after [net.ListenConfig].
type ListenConfig interface {
Listen(ctx context.Context, network, address string) (l net.Listener, err error)
ListenPacket(ctx context.Context, network, address string) (c net.PacketConn, err error)
}
// DefaultListenConfig returns the default [ListenConfig] used by the servers in
// this module except for the plain-DNS ones, which use
// [DefaultListenConfigWithOOB].
func DefaultListenConfig() (lc ListenConfig) {
return &net.ListenConfig{
Control: defaultListenControl,
}
}
// DefaultListenConfigWithOOB returns the default [ListenConfig] used by the
// plain-DNS servers in this module. The resulting ListenConfig sets additional
// socket flags and processes the control-messages of connections created with
// ListenPacket.
func DefaultListenConfigWithOOB() (lc ListenConfig) {
return &listenConfigOOB{
ListenConfig: net.ListenConfig{
Control: defaultListenControl,
},
}
}
// type check
var _ ListenConfig = (*listenConfigOOB)(nil)
// listenConfigOOB is a wrapper around [net.ListenConfig] with modifications
// that set the control-message options on packet conns.
type listenConfigOOB struct {
net.ListenConfig
}
// ListenPacket implements the [ListenConfig] interface for *listenConfigOOB.
// It sets the control-message flags to receive additional out-of-band data to
// correctly discover the source address when it listens to 0.0.0.0 as well as
// in situations when SO_BINDTODEVICE is used.
//
// network must be "udp", "udp4", or "udp6".
func (lc *listenConfigOOB) ListenPacket(
ctx context.Context,
network string,
address string,
) (c net.PacketConn, err error) {
c, err = lc.ListenConfig.ListenPacket(ctx, network, address)
if err != nil {
return nil, err
}
err = setIPOpts(c)
if err != nil {
return nil, fmt.Errorf("setting socket options: %w", err)
}
return wrapPacketConn(c), nil
}

View File

@ -0,0 +1,44 @@
//go:build unix
package netext
import (
"net"
"syscall"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
)
// defaultListenControl is used as a [net.ListenConfig.Control] function to set
// the SO_REUSEPORT socket option on all sockets used by the DNS servers in this
// package.
func defaultListenControl(_, _ string, c syscall.RawConn) (err error) {
var opErr error
err = c.Control(func(fd uintptr) {
opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
})
if err != nil {
return err
}
return errors.WithDeferred(opErr, err)
}
// setIPOpts sets the IPv4 and IPv6 options on a packet connection.
func setIPOpts(c net.PacketConn) (err error) {
// TODO(a.garipov): Returning an error only if both functions return one
// (which is what module dns does as well) seems rather fragile. Depending
// on the OS, the valid errors are ENOPROTOOPT, EINVAL, and maybe others.
// Investigate and make OS-specific versions to make sure we don't miss any
// real errors.
err6 := ipv6.NewPacketConn(c).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true)
err4 := ipv4.NewPacketConn(c).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true)
if err4 != nil && err6 != nil {
return errors.List("setting ipv4 and ipv6 options", err4, err6)
}
return nil
}

View File

@ -0,0 +1,67 @@
//go:build unix
package netext_test
import (
"context"
"syscall"
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
func TestDefaultListenConfigWithOOB(t *testing.T) {
lc := netext.DefaultListenConfigWithOOB()
require.NotNil(t, lc)
type syscallConner interface {
SyscallConn() (c syscall.RawConn, err error)
}
t.Run("ipv4", func(t *testing.T) {
c, err := lc.ListenPacket(context.Background(), "udp4", "127.0.0.1:0")
require.NoError(t, err)
require.NotNil(t, c)
require.Implements(t, (*syscallConner)(nil), c)
sc, err := c.(syscallConner).SyscallConn()
require.NoError(t, err)
err = sc.Control(func(fd uintptr) {
val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT)
require.NoError(t, opErr)
// TODO(a.garipov): Rewrite this to use actual expected values for
// each OS.
assert.NotEqual(t, 0, val)
})
require.NoError(t, err)
})
t.Run("ipv6", func(t *testing.T) {
c, err := lc.ListenPacket(context.Background(), "udp6", "[::1]:0")
if errors.Is(err, syscall.EADDRNOTAVAIL) {
// Some CI machines have IPv6 disabled.
t.Skipf("ipv6 seems to not be supported: %s", err)
}
require.NoError(t, err)
require.NotNil(t, c)
require.Implements(t, (*syscallConner)(nil), c)
sc, err := c.(syscallConner).SyscallConn()
require.NoError(t, err)
err = sc.Control(func(fd uintptr) {
val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT)
require.NoError(t, opErr)
assert.NotEqual(t, 0, val)
})
require.NoError(t, err)
})
}

View File

@ -0,0 +1,17 @@
//go:build windows
package netext
import (
"net"
"syscall"
)
// defaultListenControl is nil on Windows, because it doesn't support
// SO_REUSEPORT.
var defaultListenControl func(_, _ string, _ syscall.RawConn) (_ error)
// setIPOpts sets the IPv4 and IPv6 options on a packet connection.
func setIPOpts(c net.PacketConn) (err error) {
return nil
}

View File

@ -0,0 +1,69 @@
package netext
import (
"net"
)
// PacketSession contains additional information about a packet read from or
// written to a [SessionPacketConn].
type PacketSession interface {
LocalAddr() (addr net.Addr)
RemoteAddr() (addr net.Addr)
}
// NewSimplePacketSession returns a new packet session using the given
// parameters.
func NewSimplePacketSession(laddr, raddr net.Addr) (s PacketSession) {
return &simplePacketSession{
laddr: laddr,
raddr: raddr,
}
}
// simplePacketSession is a simple implementation of the [PacketSession]
// interface.
type simplePacketSession struct {
laddr net.Addr
raddr net.Addr
}
// LocalAddr implements the [PacketSession] interface for *simplePacketSession.
func (s *simplePacketSession) LocalAddr() (addr net.Addr) { return s.laddr }
// RemoteAddr implements the [PacketSession] interface for *simplePacketSession.
func (s *simplePacketSession) RemoteAddr() (addr net.Addr) { return s.raddr }
// SessionPacketConn extends [net.PacketConn] with methods for working with
// packet sessions.
type SessionPacketConn interface {
net.PacketConn
ReadFromSession(b []byte) (n int, s PacketSession, err error)
WriteToSession(b []byte, s PacketSession) (n int, err error)
}
// ReadFromSession is a convenience wrapper for types that may or may not
// implement [SessionPacketConn]. If c implements it, ReadFromSession uses
// c.ReadFromSession. Otherwise, it uses c.ReadFrom and the session is created
// by using [NewSimplePacketSession] with c.LocalAddr.
func ReadFromSession(c net.PacketConn, b []byte) (n int, s PacketSession, err error) {
if spc, ok := c.(SessionPacketConn); ok {
return spc.ReadFromSession(b)
}
n, raddr, err := c.ReadFrom(b)
s = NewSimplePacketSession(c.LocalAddr(), raddr)
return n, s, err
}
// WriteToSession is a convenience wrapper for types that may or may not
// implement [SessionPacketConn]. If c implements it, WriteToSession uses
// c.WriteToSession. Otherwise, it uses c.WriteTo using s.RemoteAddr.
func WriteToSession(c net.PacketConn, b []byte, s PacketSession) (n int, err error) {
if spc, ok := c.(SessionPacketConn); ok {
return spc.WriteToSession(b, s)
}
return c.WriteTo(b, s.RemoteAddr())
}

View File

@ -0,0 +1,159 @@
//go:build linux
// TODO(a.garipov): Technically, we can expand this to other platforms, but that
// would require separate udpOOBSize constants and tests.
package netext
import (
"fmt"
"net"
"sync"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
// type check
var _ PacketSession = (*packetSession)(nil)
// packetSession contains additional information about the packet read from a
// UDP connection. It is basically an extended version of [dns.SessionUDP] that
// contains the local address as well.
type packetSession struct {
laddr *net.UDPAddr
raddr *net.UDPAddr
respOOB []byte
}
// LocalAddr implements the [PacketSession] interface for *packetSession.
func (s *packetSession) LocalAddr() (addr net.Addr) { return s.laddr }
// RemoteAddr implements the [PacketSession] interface for *packetSession.
func (s *packetSession) RemoteAddr() (addr net.Addr) { return s.raddr }
// type check
var _ SessionPacketConn = (*sessionPacketConn)(nil)
// wrapPacketConn wraps c to make it a [SessionPacketConn], if the OS supports
// that.
func wrapPacketConn(c net.PacketConn) (wrapped net.PacketConn) {
return &sessionPacketConn{
UDPConn: *c.(*net.UDPConn),
}
}
// sessionPacketConn wraps a UDP connection and implements [SessionPacketConn].
type sessionPacketConn struct {
net.UDPConn
}
// oobPool is the pool of byte slices for out-of-band data.
var oobPool = &sync.Pool{
New: func() (v any) {
b := make([]byte, IPDstOOBSize)
return &b
},
}
// IPDstOOBSize is the required size of the control-message buffer for
// [net.UDPConn.ReadMsgUDP] to read the original destination on Linux.
//
// See packetconn_linux_internal_test.go.
const IPDstOOBSize = 40
// ReadFromSession implements the [SessionPacketConn] interface for *packetConn.
func (c *sessionPacketConn) ReadFromSession(b []byte) (n int, s PacketSession, err error) {
oobPtr := oobPool.Get().(*[]byte)
defer oobPool.Put(oobPtr)
var oobn int
oob := *oobPtr
ps := &packetSession{}
n, oobn, _, ps.raddr, err = c.ReadMsgUDP(b, oob)
if err != nil {
return 0, nil, err
}
var origDstIP net.IP
sockLAddr := c.LocalAddr().(*net.UDPAddr)
origDstIP, err = origLAddr(oob[:oobn])
if err != nil {
return 0, nil, fmt.Errorf("getting original addr: %w", err)
}
if origDstIP == nil {
ps.laddr = sockLAddr
} else {
ps.respOOB = newRespOOB(origDstIP)
ps.laddr = &net.UDPAddr{
IP: origDstIP,
Port: sockLAddr.Port,
}
}
return n, ps, nil
}
// origLAddr returns the original local address from the encoded control-message
// data, if there is one. If not nil, origDst will have a protocol-appropriate
// length.
func origLAddr(oob []byte) (origDst net.IP, err error) {
ctrlMsg6 := &ipv6.ControlMessage{}
err = ctrlMsg6.Parse(oob)
if err != nil {
return nil, fmt.Errorf("parsing ipv6 control message: %w", err)
}
if dst := ctrlMsg6.Dst; dst != nil {
// Linux maps IPv4 addresses to IPv6 ones by default, so we can get an
// IPv4 dst from an IPv6 control-message.
origDst = dst.To4()
if origDst == nil {
origDst = dst
}
return origDst, nil
}
ctrlMsg4 := &ipv4.ControlMessage{}
err = ctrlMsg4.Parse(oob)
if err != nil {
return nil, fmt.Errorf("parsing ipv4 control message: %w", err)
}
return ctrlMsg4.Dst.To4(), nil
}
// newRespOOB returns an encoded control-message for the response for this IP
// address. origDst is expected to have a protocol-appropriate length.
func newRespOOB(origDst net.IP) (b []byte) {
switch len(origDst) {
case net.IPv4len:
cm := &ipv4.ControlMessage{
Src: origDst,
}
return cm.Marshal()
case net.IPv6len:
cm := &ipv6.ControlMessage{
Src: origDst,
}
return cm.Marshal()
default:
return nil
}
}
// WriteToSession implements the [SessionPacketConn] interface for *packetConn.
func (c *sessionPacketConn) WriteToSession(b []byte, s PacketSession) (n int, err error) {
if ps, ok := s.(*packetSession); ok {
n, _, err = c.WriteMsgUDP(b, ps.respOOB, ps.raddr)
return n, err
}
return c.WriteTo(b, s.RemoteAddr())
}

View File

@ -0,0 +1,25 @@
//go:build linux
package netext
import (
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
func TestUDPOOBSize(t *testing.T) {
// See https://github.com/miekg/dns/blob/v1.1.50/udp.go.
len4 := len(ipv4.NewControlMessage(ipv4.FlagDst | ipv4.FlagInterface))
len6 := len(ipv6.NewControlMessage(ipv6.FlagDst | ipv6.FlagInterface))
max := len4
if len6 > max {
max = len6
}
assert.Equal(t, max, IPDstOOBSize)
}

View File

@ -0,0 +1,140 @@
//go:build linux
package netext_test
import (
"context"
"fmt"
"net"
"os"
"syscall"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TODO(a.garipov): Add IPv6 test.
func TestSessionPacketConn(t *testing.T) {
const numTries = 5
// Try the test multiple times to reduce flakiness due to UDP failures.
var success4, success6 bool
for i := 0; i < numTries; i++ {
var isTimeout4, isTimeout6 bool
success4 = t.Run(fmt.Sprintf("ipv4_%d", i), func(t *testing.T) {
isTimeout4 = testSessionPacketConn(t, "udp4", "0.0.0.0:0", net.IP{127, 0, 0, 1})
})
success6 = t.Run(fmt.Sprintf("ipv6_%d", i), func(t *testing.T) {
isTimeout6 = testSessionPacketConn(t, "udp6", "[::]:0", net.IPv6loopback)
})
if success4 && success6 {
break
} else if isTimeout4 || isTimeout6 {
continue
}
t.Fail()
}
if !success4 {
t.Errorf("ipv4 test failed after %d attempts", numTries)
} else if !success6 {
t.Errorf("ipv6 test failed after %d attempts", numTries)
}
}
func testSessionPacketConn(t *testing.T, proto, addr string, dstIP net.IP) (isTimeout bool) {
lc := netext.DefaultListenConfigWithOOB()
require.NotNil(t, lc)
c, err := lc.ListenPacket(context.Background(), proto, addr)
if isTimeoutOrFail(t, err) {
return true
}
require.NotNil(t, c)
deadline := time.Now().Add(1 * time.Second)
err = c.SetDeadline(deadline)
require.NoError(t, err)
laddr := testutil.RequireTypeAssert[*net.UDPAddr](t, c.LocalAddr())
require.NotNil(t, laddr)
dstAddr := &net.UDPAddr{
IP: dstIP,
Port: laddr.Port,
}
remoteConn, err := net.DialUDP(proto, nil, dstAddr)
if proto == "udp6" && errors.Is(err, syscall.EADDRNOTAVAIL) {
// Some CI machines have IPv6 disabled.
t.Skipf("ipv6 seems to not be supported: %s", err)
} else if isTimeoutOrFail(t, err) {
return true
}
err = remoteConn.SetDeadline(deadline)
require.NoError(t, err)
msg := []byte("hello")
msgLen := len(msg)
_, err = remoteConn.Write(msg)
if isTimeoutOrFail(t, err) {
return true
}
require.Implements(t, (*netext.SessionPacketConn)(nil), c)
buf := make([]byte, msgLen)
n, sess, err := netext.ReadFromSession(c, buf)
if isTimeoutOrFail(t, err) {
return true
}
assert.Equal(t, msgLen, n)
assert.Equal(t, net.Addr(dstAddr), sess.LocalAddr())
assert.Equal(t, remoteConn.LocalAddr(), sess.RemoteAddr())
assert.Equal(t, msg, buf)
respMsg := []byte("world")
respMsgLen := len(respMsg)
n, err = netext.WriteToSession(c, respMsg, sess)
if isTimeoutOrFail(t, err) {
return true
}
assert.Equal(t, respMsgLen, n)
buf = make([]byte, respMsgLen)
n, err = remoteConn.Read(buf)
if isTimeoutOrFail(t, err) {
return true
}
assert.Equal(t, respMsgLen, n)
assert.Equal(t, respMsg, buf)
return false
}
// isTimeoutOrFail is a helper function that returns true if err is a timeout
// error and also calls require.NoError on err.
func isTimeoutOrFail(t *testing.T, err error) (ok bool) {
t.Helper()
if err == nil {
return false
}
defer require.NoError(t, err)
return errors.Is(err, os.ErrDeadlineExceeded)
}

View File

@ -0,0 +1,11 @@
//go:build !linux
package netext
import "net"
// wrapPacketConn wraps c to make it a [SessionPacketConn], if the OS supports
// that.
func wrapPacketConn(c net.PacketConn) (wrapped net.PacketConn) {
return c
}

View File

@ -27,7 +27,8 @@ func TestRateLimiterMetricsListener_integration_cache(t *testing.T) {
Duration: time.Minute,
Count: rps,
ResponseSizeEstimate: 1000,
RPS: rps,
IPv4RPS: rps,
IPv6RPS: rps,
RefuseANY: true,
})
rlMw, err := ratelimit.NewMiddleware(rl, nil)

View File

@ -37,15 +37,22 @@ type BackOffConfig struct {
// as several responses.
ResponseSizeEstimate int
// RPS is the maximum number of requests per second allowed from a single
// subnet. Any requests above this rate are counted as the client's
// back-off count. RPS must be greater than zero.
RPS int
// IPv4RPS is the maximum number of requests per second allowed from a
// single subnet for IPv4 addresses. Any requests above this rate are
// counted as the client's back-off count. RPS must be greater than
// zero.
IPv4RPS int
// IPv4SubnetKeyLen is the length of the subnet prefix used to calculate
// rate limiter bucket keys for IPv4 addresses. Must be greater than zero.
IPv4SubnetKeyLen int
// IPv6RPS is the maximum number of requests per second allowed from a
// single subnet for IPv6 addresses. Any requests above this rate are
// counted as the client's back-off count. RPS must be greater than
// zero.
IPv6RPS int
// IPv6SubnetKeyLen is the length of the subnet prefix used to calculate
// rate limiter bucket keys for IPv6 addresses. Must be greater than zero.
IPv6SubnetKeyLen int
@ -62,14 +69,18 @@ type BackOffConfig struct {
// current implementation might be too abstract. Middlewares by themselves
// already provide an interface that can be re-implemented by the users.
// Perhaps, another layer of abstraction is unnecessary.
//
// TODO(ameshkov): Consider splitting rps and other properties by protocol
// family.
type BackOff struct {
rpsCounters *cache.Cache
hitCounters *cache.Cache
allowlist Allowlist
count int
rps int
respSzEst int
ipv4rps int
ipv4SubnetKeyLen int
ipv6rps int
ipv6SubnetKeyLen int
refuseANY bool
}
@ -84,9 +95,10 @@ func NewBackOff(c *BackOffConfig) (l *BackOff) {
hitCounters: cache.New(c.Duration, c.Duration),
allowlist: c.Allowlist,
count: c.Count,
rps: c.RPS,
respSzEst: c.ResponseSizeEstimate,
ipv4rps: c.IPv4RPS,
ipv4SubnetKeyLen: c.IPv4SubnetKeyLen,
ipv6rps: c.IPv6RPS,
ipv6SubnetKeyLen: c.IPv6SubnetKeyLen,
refuseANY: c.RefuseANY,
}
@ -124,7 +136,12 @@ func (l *BackOff) IsRateLimited(
return true, false, nil
}
return l.hasHitRateLimit(key), false, nil
rps := l.ipv4rps
if ip.Is6() {
rps = l.ipv6rps
}
return l.hasHitRateLimit(key, rps), false, nil
}
// validateAddr returns an error if addr is not a valid IPv4 or IPv6 address.
@ -184,14 +201,15 @@ func (l *BackOff) incBackOff(key string) {
l.hitCounters.SetDefault(key, counter)
}
// hasHitRateLimit checks value for a subnet.
func (l *BackOff) hasHitRateLimit(subnetIPStr string) (ok bool) {
var r *rps
// hasHitRateLimit checks value for a subnet with rps as a maximum number
// requests per second.
func (l *BackOff) hasHitRateLimit(subnetIPStr string, rps int) (ok bool) {
var r *rpsCounter
rVal, ok := l.rpsCounters.Get(subnetIPStr)
if ok {
r = rVal.(*rps)
r = rVal.(*rpsCounter)
} else {
r = newRPS(l.rps)
r = newRPSCounter(rps)
l.rpsCounters.SetDefault(subnetIPStr, r)
}

View File

@ -98,8 +98,9 @@ func TestRatelimitMiddleware(t *testing.T) {
Duration: time.Minute,
Count: rps,
ResponseSizeEstimate: 128,
RPS: rps,
IPv4RPS: rps,
IPv4SubnetKeyLen: 24,
IPv6RPS: rps,
IPv6SubnetKeyLen: 48,
RefuseANY: true,
})

View File

@ -7,17 +7,17 @@ import (
// Requests Per Second Counter
// rps is a single request per seconds counter.
type rps struct {
// rpsCounter is a single request per seconds counter.
type rpsCounter struct {
// mu protects all fields.
mu *sync.Mutex
ring []int64
idx int
}
// newRPS returns a new requests per second counter. n must be above zero.
func newRPS(n int) (r *rps) {
return &rps{
// newRPSCounter returns a new requests per second counter. n must be above zero.
func newRPSCounter(n int) (r *rpsCounter) {
return &rpsCounter{
mu: &sync.Mutex{},
// Add one, because we need to always keep track of the previous
// request. For example, consider n == 1.
@ -28,7 +28,7 @@ func newRPS(n int) (r *rps) {
// add adds another request to the counter. above is true if the request goes
// above the counter value. It is safe for concurrent use.
func (r *rps) add(t time.Time) (above bool) {
func (r *rpsCounter) add(t time.Time) (above bool) {
r.mu.Lock()
defer r.mu.Unlock()

View File

@ -8,6 +8,7 @@ import (
"sync"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
@ -31,16 +32,19 @@ type ConfigBase struct {
// Handler is a handler that processes incoming DNS messages.
// If not set, we'll use the default handler that returns error response
// to any query.
Handler Handler
// Metrics is the object we use for collecting performance metrics.
// This field is optional.
Metrics MetricsListener
// BaseContext is a function that should return the base context. If not
// set, we'll be using context.Background().
BaseContext func() (ctx context.Context)
// ListenConfig, when set, is used to set options of connections used by the
// DNS server. If nil, an appropriate default ListenConfig is used.
ListenConfig netext.ListenConfig
}
// ServerBase implements base methods that every Server implementation uses.
@ -67,14 +71,19 @@ type ServerBase struct {
// metrics is the object we use for collecting performance metrics.
metrics MetricsListener
// listenConfig is used to set tcpListener and udpListener.
listenConfig netext.ListenConfig
// Server operation
// --
// will be nil for servers that don't use TCP.
// tcpListener is used to accept new TCP connections. It is nil for servers
// that don't use TCP.
tcpListener net.Listener
// will be nil for servers that don't use UDP.
udpListener *net.UDPConn
// udpListener is used to accept new UDP messages. It is nil for servers
// that don't use UDP.
udpListener net.PacketConn
// Shutdown handling
// --
@ -100,6 +109,7 @@ func newServerBase(proto Protocol, conf ConfigBase) (s *ServerBase) {
network: conf.Network,
handler: conf.Handler,
metrics: conf.Metrics,
listenConfig: conf.ListenConfig,
baseContext: conf.BaseContext,
}
@ -347,18 +357,17 @@ func (s *ServerBase) handlePanicAndRecover(ctx context.Context) {
}
}
// listenUDP creates a UDP listener for the ServerBase.addr. This function will
// initialize and start ServerBase.udpListener or return an error. If the TCP
// listener is already running, its address is used instead. The point of this
// is to properly handle the case when port 0 is used as both listeners should
// use the same port, and we only learn it after the first one was started.
// listenUDP initializes and starts s.udpListener using s.addr. If the TCP
// listener is already running, its address is used instead to properly handle
// the case when port 0 is used as both listeners should use the same port, and
// we only learn it after the first one was started.
func (s *ServerBase) listenUDP(ctx context.Context) (err error) {
addr := s.addr
if s.tcpListener != nil {
addr = s.tcpListener.Addr().String()
}
conn, err := listenUDP(ctx, addr, true)
conn, err := s.listenConfig.ListenPacket(ctx, "udp", addr)
if err != nil {
return err
}
@ -368,19 +377,17 @@ func (s *ServerBase) listenUDP(ctx context.Context) (err error) {
return nil
}
// listenTCP creates a TCP listener for the ServerBase.addr. This function will
// initialize and start ServerBase.tcpListener or return an error. If the UDP
// listener is already running, its address is used instead. The point of this
// is to properly handle the case when port 0 is used as both listeners should
// use the same port, and we only learn it after the first one was started.
// listenTCP initializes and starts s.tcpListener using s.addr. If the UDP
// listener is already running, its address is used instead to properly handle
// the case when port 0 is used as both listeners should use the same port, and
// we only learn it after the first one was started.
func (s *ServerBase) listenTCP(ctx context.Context) (err error) {
addr := s.addr
if s.udpListener != nil {
addr = s.udpListener.LocalAddr().String()
}
var l net.Listener
l, err = listenTCP(ctx, addr)
l, err := s.listenConfig.Listen(ctx, "tcp", addr)
if err != nil {
return err
}

View File

@ -6,6 +6,7 @@ import (
"sync"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
@ -110,6 +111,10 @@ func newServerDNS(proto Protocol, conf ConfigDNS) (s *ServerDNS) {
conf.TCPSize = dns.MinMsgSize
}
if conf.ListenConfig == nil {
conf.ListenConfig = netext.DefaultListenConfigWithOOB()
}
s = &ServerDNS{
ServerBase: newServerBase(proto, conf.ConfigBase),
conf: conf,
@ -262,10 +267,21 @@ func makePacketBuffer(size int) (f func() any) {
}
}
// writeDeadlineSetter is an interface for connections that can set write
// deadlines.
type writeDeadlineSetter interface {
SetWriteDeadline(t time.Time) (err error)
}
// withWriteDeadline is a helper that takes the deadline of the context and the
// write timeout into account. It sets the write deadline on conn before
// calling f and resets it once f is done.
func withWriteDeadline(ctx context.Context, writeTimeout time.Duration, conn net.Conn, f func()) {
func withWriteDeadline(
ctx context.Context,
writeTimeout time.Duration,
conn writeDeadlineSetter,
f func(),
) {
dl, hasDeadline := ctx.Deadline()
if !hasDeadline {
dl = time.Now().Add(writeTimeout)

View File

@ -284,8 +284,8 @@ func TestServerDNS_integration_query(t *testing.T) {
tc.expectedTruncated,
)
reqKeepAliveOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_TCP_KEEPALIVE](tc.req)
respKeepAliveOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_TCP_KEEPALIVE](resp)
reqKeepAliveOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_TCP_KEEPALIVE](tc.req)
respKeepAliveOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_TCP_KEEPALIVE](resp)
if tc.network == dnsserver.NetworkTCP && reqKeepAliveOpt != nil {
require.NotNil(t, respKeepAliveOpt)
expectedTimeout := uint16(dnsserver.DefaultTCPIdleTimeout.Milliseconds() / 100)

View File

@ -2,7 +2,9 @@ package dnsserver
import (
"context"
"net"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/ameshkov/dnscrypt/v2"
@ -38,6 +40,10 @@ var _ Server = (*ServerDNSCrypt)(nil)
// NewServerDNSCrypt creates a new instance of ServerDNSCrypt.
func NewServerDNSCrypt(conf ConfigDNSCrypt) (s *ServerDNSCrypt) {
if conf.ListenConfig == nil {
conf.ListenConfig = netext.DefaultListenConfig()
}
return &ServerDNSCrypt{
ServerBase: newServerBase(ProtoDNSCrypt, conf.ConfigBase),
conf: conf,
@ -140,7 +146,10 @@ func (s *ServerDNSCrypt) startServeUDP(ctx context.Context) {
// TODO(ameshkov): Add context to the ServeTCP and ServeUDP methods in
// dnscrypt/v3. Or at least add ServeTCPContext and ServeUDPContext
// methods for now.
err := s.dnsCryptServer.ServeUDP(s.udpListener)
//
// TODO(ameshkov): Redo the dnscrypt module to make it not depend on
// *net.UDPConn and use net.PacketConn instead.
err := s.dnsCryptServer.ServeUDP(s.udpListener.(*net.UDPConn))
if err != nil {
log.Info("[%s]: Finished listening to udp://%s due to %v", s.Name(), s.Addr(), err)
}

View File

@ -4,23 +4,21 @@ import (
"context"
"fmt"
"net"
"runtime"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
// serveUDP runs the UDP serving loop.
func (s *ServerDNS) serveUDP(ctx context.Context, conn *net.UDPConn) (err error) {
func (s *ServerDNS) serveUDP(ctx context.Context, conn net.PacketConn) (err error) {
defer log.OnCloserError(conn, log.DEBUG)
for s.isStarted() {
var m []byte
var sess *dns.SessionUDP
var sess netext.PacketSession
m, sess, err = s.readUDPMsg(ctx, conn)
if err != nil {
// TODO(ameshkov): Consider the situation where the server is shut
@ -59,15 +57,15 @@ func (s *ServerDNS) serveUDP(ctx context.Context, conn *net.UDPConn) (err error)
func (s *ServerDNS) serveUDPPacket(
ctx context.Context,
m []byte,
conn *net.UDPConn,
udpSession *dns.SessionUDP,
conn net.PacketConn,
sess netext.PacketSession,
) {
defer s.wg.Done()
defer s.handlePanicAndRecover(ctx)
rw := &udpResponseWriter{
udpSession: sess,
conn: conn,
udpSession: udpSession,
writeTimeout: s.conf.WriteTimeout,
}
s.serveDNS(ctx, m, rw)
@ -75,15 +73,18 @@ func (s *ServerDNS) serveUDPPacket(
}
// readUDPMsg reads the next incoming DNS message.
func (s *ServerDNS) readUDPMsg(ctx context.Context, conn *net.UDPConn) (msg []byte, sess *dns.SessionUDP, err error) {
func (s *ServerDNS) readUDPMsg(
ctx context.Context,
conn net.PacketConn,
) (msg []byte, sess netext.PacketSession, err error) {
err = conn.SetReadDeadline(time.Now().Add(s.conf.ReadTimeout))
if err != nil {
return nil, nil, err
}
m := s.getUDPBuffer()
var n int
n, sess, err = dns.ReadFromSessionUDP(conn, m)
n, sess, err := netext.ReadFromSession(conn, m)
if err != nil {
s.putUDPBuffer(m)
@ -120,30 +121,10 @@ func (s *ServerDNS) putUDPBuffer(m []byte) {
s.udpPool.Put(&m)
}
// setUDPSocketOptions is a function that is necessary to be able to use
// dns.ReadFromSessionUDP and dns.WriteToSessionUDP.
// TODO(ameshkov): https://github.com/AdguardTeam/AdGuardHome/issues/2807
func setUDPSocketOptions(conn *net.UDPConn) (err error) {
if runtime.GOOS == "windows" {
return nil
}
// We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection.
// Try enabling receiving of ECN and packet info for both IP versions.
// We expect at least one of those syscalls to succeed.
err6 := ipv6.NewPacketConn(conn).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true)
err4 := ipv4.NewPacketConn(conn).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true)
if err4 != nil && err6 != nil {
return errors.List("error while setting NetworkUDP socket options", err4, err6)
}
return nil
}
// udpResponseWriter is a ResponseWriter implementation for DNS-over-UDP.
type udpResponseWriter struct {
udpSession *dns.SessionUDP
conn *net.UDPConn
udpSession netext.PacketSession
conn net.PacketConn
writeTimeout time.Duration
}
@ -152,13 +133,15 @@ var _ ResponseWriter = (*udpResponseWriter)(nil)
// LocalAddr implements the ResponseWriter interface for *udpResponseWriter.
func (r *udpResponseWriter) LocalAddr() (addr net.Addr) {
return r.conn.LocalAddr()
// Don't use r.conn.LocalAddr(), since udpSession may actually contain the
// decoded OOB data, including the real local (dst) address.
return r.udpSession.LocalAddr()
}
// RemoteAddr implements the ResponseWriter interface for *udpResponseWriter.
func (r *udpResponseWriter) RemoteAddr() (addr net.Addr) {
// Don't use r.conn.RemoteAddr(), since udpSession actually contains the
// decoded OOB data, including the remote address.
// Don't use r.conn.RemoteAddr(), since udpSession may actually contain the
// decoded OOB data, including the real remote (src) address.
return r.udpSession.RemoteAddr()
}
@ -173,7 +156,7 @@ func (r *udpResponseWriter) WriteMsg(ctx context.Context, req, resp *dns.Msg) (e
}
withWriteDeadline(ctx, r.writeTimeout, r.conn, func() {
_, err = dns.WriteToSessionUDP(r.conn, data, r.udpSession)
_, err = netext.WriteToSession(r.conn, data, r.udpSession)
})
if err != nil {

View File

@ -14,6 +14,7 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
@ -94,6 +95,12 @@ var _ Server = (*ServerHTTPS)(nil)
// NewServerHTTPS creates a new ServerHTTPS instance.
func NewServerHTTPS(conf ConfigHTTPS) (s *ServerHTTPS) {
if conf.ListenConfig == nil {
// Do not enable OOB here, because ListenPacket is only used by HTTP/3,
// and quic-go sets the necessary flags.
conf.ListenConfig = netext.DefaultListenConfig()
}
s = &ServerHTTPS{
ServerBase: newServerBase(ProtoDoH, conf.ConfigBase),
conf: conf,
@ -500,8 +507,7 @@ func (s *ServerHTTPS) listenQUIC(ctx context.Context) (err error) {
tlsConf.NextProtos = nextProtoDoH3
}
// Do not enable OOB here as quic-go will do that on its own.
conn, err := listenUDP(ctx, s.addr, false)
conn, err := s.listenConfig.ListenPacket(ctx, "udp", s.addr)
if err != nil {
return err
}
@ -518,24 +524,28 @@ func (s *ServerHTTPS) listenQUIC(ctx context.Context) (err error) {
return nil
}
// httpContextWithClientInfo adds client info to the context.
// httpContextWithClientInfo adds client info to the context. ctx is never nil,
// even when there is an error.
func httpContextWithClientInfo(
parent context.Context,
r *http.Request,
) (ctx context.Context, err error) {
ctx = parent
ci := ClientInfo{
URL: netutil.CloneURL(r.URL),
}
// Due to the quic-go bug we should use Host instead of r.TLS:
// https://github.com/lucas-clemente/quic-go/issues/3596
// Due to the quic-go bug we should use Host instead of r.TLS. See
// https://github.com/quic-go/quic-go/issues/2879 and
// https://github.com/lucas-clemente/quic-go/issues/3596.
//
// TODO(ameshkov): remove this when the bug is fixed in quic-go.
// TODO(ameshkov): Remove when quic-go is fixed, likely in v0.32.0.
if r.ProtoAtLeast(3, 0) {
var host string
host, err = netutil.SplitHost(r.Host)
if err != nil {
return nil, fmt.Errorf("failed to parse Host: %w", err)
return ctx, fmt.Errorf("failed to parse Host: %w", err)
}
ci.TLSServerName = host
@ -543,7 +553,7 @@ func httpContextWithClientInfo(
ci.TLSServerName = strings.ToLower(r.TLS.ServerName)
}
return ContextWithClientInfo(parent, ci), nil
return ContextWithClientInfo(ctx, ci), nil
}
// httpRequestToMsg reads the DNS message from http.Request.

View File

@ -133,7 +133,7 @@ func TestServerHTTPS_integration_serveRequests(t *testing.T) {
require.True(t, resp.Response)
// EDNS0 padding is only present when request also has padding opt.
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.Nil(t, paddingOpt)
})
}
@ -338,7 +338,7 @@ func TestServerHTTPS_integration_ENDS0Padding(t *testing.T) {
resp := mustDoHReq(t, addr, tlsConfig, http.MethodGet, false, false, req)
require.True(t, resp.Response)
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.NotNil(t, paddingOpt)
require.NotEmpty(t, paddingOpt.Padding)
}

View File

@ -12,6 +12,7 @@ import (
"sync"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/bluele/gcache"
@ -98,6 +99,11 @@ func NewServerQUIC(conf ConfigQUIC) (s *ServerQUIC) {
tlsConfig.NextProtos = append([]string{nextProtoDoQ}, compatProtoDQ...)
}
if conf.ListenConfig == nil {
// Do not enable OOB here as quic-go will do that on its own.
conf.ListenConfig = netext.DefaultListenConfig()
}
s = &ServerQUIC{
ServerBase: newServerBase(ProtoDoQ, conf.ConfigBase),
conf: conf,
@ -476,8 +482,7 @@ func (s *ServerQUIC) readQUICMsg(
// listenQUIC creates the UDP listener for the ServerQUIC.addr and also starts
// the QUIC listener.
func (s *ServerQUIC) listenQUIC(ctx context.Context) (err error) {
// Do not enable OOB here as quic-go will do that on its own.
conn, err := listenUDP(ctx, s.addr, false)
conn, err := s.listenConfig.ListenPacket(ctx, "udp", s.addr)
if err != nil {
return err
}

View File

@ -59,7 +59,7 @@ func TestServerQUIC_integration_query(t *testing.T) {
assert.True(t, resp.Response)
// EDNS0 padding is only present when request also has padding opt.
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.Nil(t, paddingOpt)
}()
}
@ -97,7 +97,7 @@ func TestServerQUIC_integration_ENDS0Padding(t *testing.T) {
require.True(t, resp.Response)
require.False(t, resp.Truncated)
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.NotNil(t, paddingOpt)
require.NotEmpty(t, paddingOpt.Padding)
}

View File

@ -3,7 +3,6 @@ package dnsserver
import (
"context"
"crypto/tls"
"net"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
@ -102,10 +101,9 @@ func (s *ServerTLS) startServeTCP(ctx context.Context) {
}
}
// listenTLS creates the TLS listener for the ServerTLS.addr.
// listenTLS creates the TLS listener for s.addr.
func (s *ServerTLS) listenTLS(ctx context.Context) (err error) {
var l net.Listener
l, err = listenTCP(ctx, s.addr)
l, err := s.listenConfig.Listen(ctx, "tcp", s.addr)
if err != nil {
return err
}

View File

@ -40,7 +40,7 @@ func TestServerTLS_integration_queryTLS(t *testing.T) {
require.False(t, resp.Truncated)
// EDNS0 padding is only present when request also has padding opt.
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.Nil(t, paddingOpt)
}
@ -142,7 +142,7 @@ func TestServerTLS_integration_noTruncateQuery(t *testing.T) {
require.False(t, resp.Truncated)
// EDNS0 padding is only present when request also has padding opt.
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.Nil(t, paddingOpt)
}
@ -231,7 +231,7 @@ func TestServerTLS_integration_ENDS0Padding(t *testing.T) {
require.True(t, resp.Response)
require.False(t, resp.Truncated)
paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.NotNil(t, paddingOpt)
require.NotEmpty(t, paddingOpt.Padding)
}

View File

@ -0,0 +1,172 @@
package dnssvc
import (
"context"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newTXTExtra is a helper function that converts strs into DNS TXT resource
// records with Name and Txt fields set to first and second values of each
// tuple.
func newTXTExtra(strs [][2]string) (extra []dns.RR) {
for _, v := range strs {
extra = append(extra, &dns.TXT{
Hdr: dns.RR_Header{
Name: v[0],
Rrtype: dns.TypeTXT,
Class: dns.ClassCHAOS,
Ttl: 1,
},
Txt: []string{v[1]},
})
}
return extra
}
func TestService_writeDebugResponse(t *testing.T) {
svc := &Service{messages: &dnsmsg.Constructor{FilteredResponseTTL: time.Second}}
const (
fltListID1 agd.FilterListID = "fl1"
fltListID2 agd.FilterListID = "fl2"
blockRule = "||example.com^"
)
testCases := []struct {
name string
ri *agd.RequestInfo
reqRes filter.Result
respRes filter.Result
wantExtra []dns.RR
}{{
name: "normal",
ri: &agd.RequestInfo{},
reqRes: nil,
respRes: nil,
wantExtra: newTXTExtra([][2]string{
{"client-ip.adguard-dns.com.", "1.2.3.4"},
{"resp.res-type.adguard-dns.com.", "normal"},
}),
}, {
name: "request_result_blocked",
ri: &agd.RequestInfo{},
reqRes: &filter.ResultBlocked{List: fltListID1, Rule: blockRule},
respRes: nil,
wantExtra: newTXTExtra([][2]string{
{"client-ip.adguard-dns.com.", "1.2.3.4"},
{"req.res-type.adguard-dns.com.", "blocked"},
{"req.rule.adguard-dns.com.", "||example.com^"},
{"req.rule-list-id.adguard-dns.com.", "fl1"},
}),
}, {
name: "response_result_blocked",
ri: &agd.RequestInfo{},
reqRes: nil,
respRes: &filter.ResultBlocked{List: fltListID2, Rule: blockRule},
wantExtra: newTXTExtra([][2]string{
{"client-ip.adguard-dns.com.", "1.2.3.4"},
{"resp.res-type.adguard-dns.com.", "blocked"},
{"resp.rule.adguard-dns.com.", "||example.com^"},
{"resp.rule-list-id.adguard-dns.com.", "fl2"},
}),
}, {
name: "request_result_allowed",
ri: &agd.RequestInfo{},
reqRes: &filter.ResultAllowed{},
respRes: nil,
wantExtra: newTXTExtra([][2]string{
{"client-ip.adguard-dns.com.", "1.2.3.4"},
{"req.res-type.adguard-dns.com.", "allowed"},
{"req.rule.adguard-dns.com.", ""},
{"req.rule-list-id.adguard-dns.com.", ""},
}),
}, {
name: "response_result_allowed",
ri: &agd.RequestInfo{},
reqRes: nil,
respRes: &filter.ResultAllowed{},
wantExtra: newTXTExtra([][2]string{
{"client-ip.adguard-dns.com.", "1.2.3.4"},
{"resp.res-type.adguard-dns.com.", "allowed"},
{"resp.rule.adguard-dns.com.", ""},
{"resp.rule-list-id.adguard-dns.com.", ""},
}),
}, {
name: "request_result_modified",
ri: &agd.RequestInfo{},
reqRes: &filter.ResultModified{
Rule: "||example.com^$dnsrewrite=REFUSED",
},
respRes: nil,
wantExtra: newTXTExtra([][2]string{
{"client-ip.adguard-dns.com.", "1.2.3.4"},
{"req.res-type.adguard-dns.com.", "modified"},
{"req.rule.adguard-dns.com.", "||example.com^$dnsrewrite=REFUSED"},
{"req.rule-list-id.adguard-dns.com.", ""},
}),
}, {
name: "device",
ri: &agd.RequestInfo{Device: &agd.Device{ID: "dev1234"}},
reqRes: nil,
respRes: nil,
wantExtra: newTXTExtra([][2]string{
{"client-ip.adguard-dns.com.", "1.2.3.4"},
{"device-id.adguard-dns.com.", "dev1234"},
{"resp.res-type.adguard-dns.com.", "normal"},
}),
}, {
name: "profile",
ri: &agd.RequestInfo{
Profile: &agd.Profile{ID: agd.ProfileID("some-profile-id")},
},
reqRes: nil,
respRes: nil,
wantExtra: newTXTExtra([][2]string{
{"client-ip.adguard-dns.com.", "1.2.3.4"},
{"profile-id.adguard-dns.com.", "some-profile-id"},
{"resp.res-type.adguard-dns.com.", "normal"},
}),
}, {
name: "location",
ri: &agd.RequestInfo{Location: &agd.Location{Country: agd.CountryAD}},
reqRes: nil,
respRes: nil,
wantExtra: newTXTExtra([][2]string{
{"client-ip.adguard-dns.com.", "1.2.3.4"},
{"country.adguard-dns.com.", "AD"},
{"asn.adguard-dns.com.", "0"},
{"resp.res-type.adguard-dns.com.", "normal"},
}),
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr)
ctx := agd.ContextWithRequestInfo(context.Background(), tc.ri)
req := dnsservertest.NewReq("example.com", dns.TypeA, dns.ClassINET)
resp := dnsservertest.NewResp(dns.RcodeSuccess, req)
err := svc.writeDebugResponse(ctx, rw, req, resp, tc.reqRes, tc.respRes)
require.NoError(t, err)
msg := rw.Msg()
require.NotNil(t, msg)
assert.Equal(t, tc.wantExtra, msg.Extra)
})
}
}

View File

@ -195,7 +195,15 @@ func deviceIDFromEDNS(req *dns.Msg) (id agd.DeviceID, err error) {
continue
}
return agd.NewDeviceID(string(o.Data))
id, err = agd.NewDeviceID(string(o.Data))
if err != nil {
return "", &deviceIDError{
err: err,
typ: "edns option",
}
}
return id, nil
}
return "", nil

View File

@ -233,7 +233,8 @@ func TestService_Wrap_deviceIDFromEDNS(t *testing.T) {
Data: []byte{},
},
wantDeviceID: "",
wantErrMsg: `bad device id "": too short: got 0 bytes, min 1`,
wantErrMsg: `edns option device id check: bad device id "": ` +
`too short: got 0 bytes, min 1`,
}, {
name: "bad_device_id",
opt: &dns.EDNS0_LOCAL{
@ -241,7 +242,8 @@ func TestService_Wrap_deviceIDFromEDNS(t *testing.T) {
Data: []byte("toolongdeviceid"),
},
wantDeviceID: "",
wantErrMsg: `bad device id "toolongdeviceid": too long: got 15 bytes, max 8`,
wantErrMsg: `edns option device id check: bad device id "toolongdeviceid": ` +
`too long: got 15 bytes, max 8`,
}, {
name: "device_id",
opt: &dns.EDNS0_LOCAL{

View File

@ -9,7 +9,6 @@ import (
"fmt"
"net/http"
"net/netip"
"net/url"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
@ -79,10 +78,6 @@ type Config struct {
// Upstream defines the upstream server and the group of fallback servers.
Upstream *agd.Upstream
// RootRedirectURL is the URL to which non-DNS and non-Debug HTTP requests
// are redirected.
RootRedirectURL *url.URL
// NewListener, when set, is used instead of the package-level function
// NewListener when creating a DNS listener.
//
@ -119,6 +114,11 @@ type Config struct {
// UseECSCache shows if the EDNS Client Subnet (ECS) aware cache should be
// used.
UseECSCache bool
// ResearchMetrics controls whether research metrics are enabled or not.
// This is a set of metrics that we may need temporary, so its collection is
// controlled by a separate setting.
ResearchMetrics bool
}
// New returns a new DNS service.
@ -150,7 +150,6 @@ func New(c *Config) (svc *Service, err error) {
groups := make([]*serverGroup, len(c.ServerGroups))
svc = &Service{
messages: c.Messages,
rootRedirectURL: c.RootRedirectURL,
billStat: c.BillStat,
errColl: c.ErrColl,
fltStrg: c.FilterStorage,
@ -158,6 +157,7 @@ func New(c *Config) (svc *Service, err error) {
queryLog: c.QueryLog,
ruleStat: c.RuleStat,
groups: groups,
researchMetrics: c.ResearchMetrics,
}
for i, srvGrp := range c.ServerGroups {
@ -212,7 +212,6 @@ var _ agd.Service = (*Service)(nil)
// Service is the main DNS service of AdGuard DNS.
type Service struct {
messages *dnsmsg.Constructor
rootRedirectURL *url.URL
billStat billstat.Recorder
errColl agd.ErrorCollector
fltStrg filter.Storage
@ -220,6 +219,7 @@ type Service struct {
queryLog querylog.Interface
ruleStat rulestat.Interface
groups []*serverGroup
researchMetrics bool
}
// mustStartListener starts l and panics on any error.

View File

@ -260,14 +260,6 @@ func (mh *initMwHandler) ServeDNS(
// Copy middleware to the local variable to make the code simpler.
mw := mh.mw
if specHdlr, name := mw.noReqInfoSpecialHandler(fqdn, qt, cl); specHdlr != nil {
optlog.Debug1("init mw: got no-req-info special handler %s", name)
// Don't wrap the error, because it's informative enough as is, and
// because if handled is true, the main flow terminates here.
return specHdlr(ctx, rw, req)
}
// Get the request's information, such as GeoIP data and user profiles.
ri, err := mw.newRequestInfo(ctx, req, rw.RemoteAddr(), fqdn, qt, cl)
if err != nil {

View File

@ -277,7 +277,7 @@ func TestInitMw_ServeDNS_ddr(t *testing.T) {
}
}
func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
func TestInitMw_ServeDNS_specialDomain(t *testing.T) {
testCases := []struct {
name string
host string
@ -287,7 +287,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
profBlocked bool
wantRCode dnsmsg.RCode
}{{
name: "blocked_by_fltgrp",
name: "private_relay_blocked_by_fltgrp",
host: applePrivateRelayMaskHost,
qtype: dns.TypeA,
fltGrpBlocked: true,
@ -295,7 +295,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
profBlocked: false,
wantRCode: dns.RcodeNameError,
}, {
name: "no_private_relay_domain",
name: "no_special_domain",
host: "www.example.com",
qtype: dns.TypeA,
fltGrpBlocked: true,
@ -311,7 +311,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
profBlocked: false,
wantRCode: dns.RcodeSuccess,
}, {
name: "blocked_by_prof",
name: "private_relay_blocked_by_prof",
host: applePrivateRelayMaskHost,
qtype: dns.TypeA,
fltGrpBlocked: false,
@ -319,7 +319,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
profBlocked: true,
wantRCode: dns.RcodeNameError,
}, {
name: "allowed_by_prof",
name: "private_relay_allowed_by_prof",
host: applePrivateRelayMaskHost,
qtype: dns.TypeA,
fltGrpBlocked: true,
@ -327,7 +327,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
profBlocked: false,
wantRCode: dns.RcodeSuccess,
}, {
name: "allowed_by_both",
name: "private_relay_allowed_by_both",
host: applePrivateRelayMaskHost,
qtype: dns.TypeA,
fltGrpBlocked: false,
@ -335,13 +335,45 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
profBlocked: false,
wantRCode: dns.RcodeSuccess,
}, {
name: "blocked_by_both",
name: "private_relay_blocked_by_both",
host: applePrivateRelayMaskHost,
qtype: dns.TypeA,
fltGrpBlocked: true,
hasProf: true,
profBlocked: true,
wantRCode: dns.RcodeNameError,
}, {
name: "firefox_canary_allowed_by_prof",
host: firefoxCanaryHost,
qtype: dns.TypeA,
fltGrpBlocked: false,
hasProf: true,
profBlocked: false,
wantRCode: dns.RcodeSuccess,
}, {
name: "firefox_canary_allowed_by_fltgrp",
host: firefoxCanaryHost,
qtype: dns.TypeA,
fltGrpBlocked: false,
hasProf: false,
profBlocked: false,
wantRCode: dns.RcodeSuccess,
}, {
name: "firefox_canary_blocked_by_prof",
host: firefoxCanaryHost,
qtype: dns.TypeA,
fltGrpBlocked: false,
hasProf: true,
profBlocked: true,
wantRCode: dns.RcodeRefused,
}, {
name: "firefox_canary_blocked_by_fltgrp",
host: firefoxCanaryHost,
qtype: dns.TypeA,
fltGrpBlocked: true,
hasProf: false,
profBlocked: false,
wantRCode: dns.RcodeRefused,
}}
for _, tc := range testCases {
@ -368,9 +400,12 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
return nil, nil, agd.DeviceNotFoundError{}
}
return &agd.Profile{
prof := &agd.Profile{
BlockPrivateRelay: tc.profBlocked,
}, &agd.Device{}, nil
BlockFirefoxCanary: tc.profBlocked,
}
return prof, &agd.Device{}, nil
}
db := &agdtest.ProfileDB{
OnProfileByDeviceID: func(
@ -407,6 +442,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
},
fltGrp: &agd.FilteringGroup{
BlockPrivateRelay: tc.fltGrpBlocked,
BlockFirefoxCanary: tc.fltGrpBlocked,
},
srvGrp: &agd.ServerGroup{},
srv: &agd.Server{
@ -436,7 +472,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
resp := rw.Msg()
require.NotNil(t, resp)
assert.Equal(t, dnsmsg.RCode(resp.Rcode), tc.wantRCode)
assert.Equal(t, tc.wantRCode, dnsmsg.RCode(resp.Rcode))
})
}
}

View File

@ -101,7 +101,7 @@ func (svc *Service) filterQuery(
) (reqRes, respRes filter.Result) {
start := time.Now()
defer func() {
reportMetrics(ri, reqRes, respRes, time.Since(start))
svc.reportMetrics(ri, reqRes, respRes, time.Since(start))
}()
f := svc.fltStrg.FilterFromContext(ctx, ri)
@ -120,7 +120,7 @@ func (svc *Service) filterQuery(
// reportMetrics extracts filtering metrics data from the context and reports it
// to Prometheus.
func reportMetrics(
func (svc *Service) reportMetrics(
ri *agd.RequestInfo,
reqRes filter.Result,
respRes filter.Result,
@ -139,7 +139,7 @@ func reportMetrics(
metrics.DNSSvcRequestByCountryTotal.WithLabelValues(cont, ctry).Inc()
metrics.DNSSvcRequestByASNTotal.WithLabelValues(ctry, asn).Inc()
id, _, _ := filteringData(reqRes, respRes)
id, _, blocked := filteringData(reqRes, respRes)
metrics.DNSSvcRequestByFilterTotal.WithLabelValues(
string(id),
metrics.BoolString(ri.Profile == nil),
@ -147,6 +147,22 @@ func reportMetrics(
metrics.DNSSvcFilteringDuration.Observe(elapsedFiltering.Seconds())
metrics.DNSSvcUsersCountUpdate(ri.RemoteIP)
if svc.researchMetrics {
anonymous := ri.Profile == nil
filteringEnabled := ri.FilteringGroup != nil &&
ri.FilteringGroup.RuleListsEnabled &&
len(ri.FilteringGroup.RuleListIDs) > 0
metrics.ReportResearchMetrics(
anonymous,
filteringEnabled,
asn,
ctry,
string(id),
blocked,
)
}
}
// reportf is a helper method for reporting non-critical errors.

View File

@ -0,0 +1,135 @@
package dnssvc
import (
"context"
"crypto/sha256"
"encoding/hex"
"net"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPreServiceMwHandler_ServeDNS(t *testing.T) {
const safeBrowsingHost = "scam.example.net."
var (
ip = net.IP{127, 0, 0, 1}
name = "example.com"
)
sum := sha256.Sum256([]byte(safeBrowsingHost))
hashStr := hex.EncodeToString(sum[:])
hashes, herr := hashstorage.New(safeBrowsingHost)
require.NoError(t, herr)
srv := filter.NewSafeBrowsingServer(hashes, nil)
host := hashStr[:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix
ctx := context.Background()
ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{})
ctx = dnsserver.ContextWithServerInfo(ctx, dnsserver.ServerInfo{})
ctx = dnsserver.ContextWithStartTime(ctx, time.Now())
const ttl = 60
testCases := []struct {
name string
req *dns.Msg
dnscheckResp *dns.Msg
ri *agd.RequestInfo
wantAns []dns.RR
}{{
name: "normal",
req: dnsservertest.CreateMessage(name, dns.TypeA),
dnscheckResp: nil,
ri: &agd.RequestInfo{},
wantAns: []dns.RR{
dnsservertest.NewA(name, 100, ip),
},
}, {
name: "dnscheck",
req: dnsservertest.CreateMessage(name, dns.TypeA),
dnscheckResp: dnsservertest.NewResp(
dns.RcodeSuccess,
dnsservertest.NewReq(name, dns.TypeA, dns.ClassINET),
dnsservertest.RRSection{
RRs: []dns.RR{dnsservertest.NewA(name, ttl, ip)},
Sec: dnsservertest.SectionAnswer,
},
),
ri: &agd.RequestInfo{
Host: name,
QType: dns.TypeA,
QClass: dns.ClassINET,
},
wantAns: []dns.RR{
dnsservertest.NewA(name, ttl, ip),
},
}, {
name: "with_hashes",
req: dnsservertest.CreateMessage(safeBrowsingHost, dns.TypeTXT),
dnscheckResp: nil,
ri: &agd.RequestInfo{Host: host, QType: dns.TypeTXT},
wantAns: []dns.RR{&dns.TXT{
Hdr: dns.RR_Header{
Name: safeBrowsingHost,
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: ttl,
},
Txt: []string{hashStr},
}},
}, {
name: "not_matched",
req: dnsservertest.CreateMessage(name, dns.TypeTXT),
dnscheckResp: nil,
ri: &agd.RequestInfo{Host: name, QType: dns.TypeTXT},
wantAns: []dns.RR{dnsservertest.NewA(name, 100, ip)},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr)
tctx := agd.ContextWithRequestInfo(ctx, tc.ri)
dnsCk := &agdtest.DNSCheck{
OnCheck: func(
ctx context.Context,
msg *dns.Msg,
ri *agd.RequestInfo,
) (resp *dns.Msg, err error) {
return tc.dnscheckResp, nil
},
}
mw := &preServiceMw{
messages: &dnsmsg.Constructor{
FilteredResponseTTL: ttl * time.Second,
},
filter: srv,
checker: dnsCk,
}
handler := dnsservertest.DefaultHandler()
h := mw.Wrap(handler)
err := h.ServeDNS(tctx, rw, tc.req)
require.NoError(t, err)
msg := rw.Msg()
require.NotNil(t, msg)
assert.Equal(t, tc.wantAns, msg.Answer)
})
}
}

View File

@ -0,0 +1,264 @@
package dnssvc
import (
"context"
"net"
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
reqHostname = "example.com."
defaultTTL = 3600
)
func TestPreUpstreamMwHandler_ServeDNS_withCache(t *testing.T) {
remoteIP := netip.MustParseAddr("1.2.3.4")
aReq := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
respIP := remoteIP.AsSlice()
resp := dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.RRSection{
RRs: []dns.RR{dnsservertest.NewA(reqHostname, defaultTTL, respIP)},
Sec: dnsservertest.SectionAnswer,
})
ctx := agd.ContextWithRequestInfo(context.Background(), &agd.RequestInfo{
Host: aReq.Question[0].Name,
})
const N = 5
testCases := []struct {
name string
cacheSize int
wantNumReq int
}{{
name: "no_cache",
cacheSize: 0,
wantNumReq: N,
}, {
name: "with_cache",
cacheSize: 100,
wantNumReq: 1,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
numReq := 0
handler := dnsserver.HandlerFunc(func(
ctx context.Context,
rw dnsserver.ResponseWriter,
req *dns.Msg,
) error {
numReq++
return rw.WriteMsg(ctx, req, resp)
})
mw := &preUpstreamMw{
db: dnsdb.Empty{},
cacheSize: tc.cacheSize,
}
h := mw.Wrap(handler)
for i := 0; i < N; i++ {
req := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
addr := &net.UDPAddr{IP: remoteIP.AsSlice(), Port: 53}
nrw := dnsserver.NewNonWriterResponseWriter(addr, addr)
err := h.ServeDNS(ctx, nrw, req)
require.NoError(t, err)
}
assert.Equal(t, tc.wantNumReq, numReq)
})
}
}
func TestPreUpstreamMwHandler_ServeDNS_withECSCache(t *testing.T) {
aReq := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
remoteIP := netip.MustParseAddr("1.2.3.4")
subnet := netip.MustParsePrefix("1.2.3.4/24")
const ctry = agd.CountryAD
resp := dnsservertest.NewResp(
dns.RcodeSuccess,
aReq,
dnsservertest.RRSection{
RRs: []dns.RR{dnsservertest.NewA(
reqHostname,
defaultTTL,
net.IP{1, 2, 3, 4},
)},
Sec: dnsservertest.SectionAnswer,
},
)
numReq := 0
handler := dnsserver.HandlerFunc(
func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) error {
numReq++
return rw.WriteMsg(ctx, req, resp)
},
)
geoIP := &agdtest.GeoIP{
OnSubnetByLocation: func(
ctry agd.Country,
_ agd.ASN,
_ netutil.AddrFamily,
) (n netip.Prefix, err error) {
return netip.MustParsePrefix("1.2.0.0/16"), nil
},
OnData: func(_ string, _ netip.Addr) (_ *agd.Location, _ error) {
panic("not implemented")
},
}
mw := &preUpstreamMw{
db: dnsdb.Empty{},
geoIP: geoIP,
cacheSize: 100,
ecsCacheSize: 100,
useECSCache: true,
}
h := mw.Wrap(handler)
ctx := agd.ContextWithRequestInfo(context.Background(), &agd.RequestInfo{
Location: &agd.Location{
Country: ctry,
},
ECS: &agd.ECS{
Location: &agd.Location{
Country: ctry,
},
Subnet: subnet,
Scope: 0,
},
Host: aReq.Question[0].Name,
RemoteIP: remoteIP,
})
const N = 5
var nrw *dnsserver.NonWriterResponseWriter
for i := 0; i < N; i++ {
addr := &net.UDPAddr{IP: remoteIP.AsSlice(), Port: 53}
nrw = dnsserver.NewNonWriterResponseWriter(addr, addr)
req := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
err := h.ServeDNS(ctx, nrw, req)
require.NoError(t, err)
}
assert.Equal(t, 1, numReq)
}
func TestPreUpstreamMwHandler_ServeDNS_androidMetric(t *testing.T) {
mw := &preUpstreamMw{db: dnsdb.Empty{}}
req := dnsservertest.CreateMessage("example.com", dns.TypeA)
resp := new(dns.Msg).SetReply(req)
ctx := context.Background()
ctx = dnsserver.ContextWithServerInfo(ctx, dnsserver.ServerInfo{})
ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{})
ctx = dnsserver.ContextWithStartTime(ctx, time.Now())
ctx = agd.ContextWithRequestInfo(ctx, &agd.RequestInfo{})
const ttl = 100
testCases := []struct {
name string
req *dns.Msg
resp *dns.Msg
wantName string
wantAns []dns.RR
}{{
name: "no_changes",
req: dnsservertest.CreateMessage("example.com.", dns.TypeA),
resp: resp,
wantName: "example.com.",
wantAns: nil,
}, {
name: "android-tls-metric",
req: dnsservertest.CreateMessage(
"12345678-dnsotls-ds.metric.gstatic.com.",
dns.TypeA,
),
resp: resp,
wantName: "00000000-dnsotls-ds.metric.gstatic.com.",
wantAns: nil,
}, {
name: "android-https-metric",
req: dnsservertest.CreateMessage(
"123456-dnsohttps-ds.metric.gstatic.com.",
dns.TypeA,
),
resp: resp,
wantName: "000000-dnsohttps-ds.metric.gstatic.com.",
wantAns: nil,
}, {
name: "multiple_answers_metric",
req: dnsservertest.CreateMessage(
"123456-dnsohttps-ds.metric.gstatic.com.",
dns.TypeA,
),
resp: dnsservertest.NewResp(
dns.RcodeSuccess,
req,
dnsservertest.RRSection{
RRs: []dns.RR{dnsservertest.NewA(
"123456-dnsohttps-ds.metric.gstatic.com.",
ttl,
net.IP{1, 2, 3, 4},
), dnsservertest.NewA(
"654321-dnsohttps-ds.metric.gstatic.com.",
ttl,
net.IP{1, 2, 3, 5},
)},
Sec: dnsservertest.SectionAnswer,
},
),
wantName: "000000-dnsohttps-ds.metric.gstatic.com.",
wantAns: []dns.RR{
dnsservertest.NewA("123456-dnsohttps-ds.metric.gstatic.com.", ttl, net.IP{1, 2, 3, 4}),
dnsservertest.NewA("123456-dnsohttps-ds.metric.gstatic.com.", ttl, net.IP{1, 2, 3, 5}),
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
handler := dnsserver.HandlerFunc(func(
ctx context.Context,
rw dnsserver.ResponseWriter,
req *dns.Msg,
) error {
assert.Equal(t, tc.wantName, req.Question[0].Name)
return rw.WriteMsg(ctx, req, tc.resp)
})
h := mw.Wrap(handler)
rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr)
err := h.ServeDNS(ctx, rw, tc.req)
require.NoError(t, err)
msg := rw.Msg()
require.NotNil(t, msg)
assert.Equal(t, tc.wantAns, msg.Answer)
})
}
}

View File

@ -32,57 +32,23 @@ const (
// Resolvers for querying the resolver with unknown or absent name.
ddrDomain = ddrLabel + "." + resolverArpaDomain
// firefoxCanaryFQDN is the fully-qualified canary domain that Firefox uses
// to check if it should use its own DNS-over-HTTPS settings.
// firefoxCanaryHost is the hostname that Firefox uses to check if it
// should use its own DNS-over-HTTPS settings.
//
// See https://support.mozilla.org/en-US/kb/configuring-networks-disable-dns-over-https.
firefoxCanaryFQDN = "use-application-dns.net."
// applePrivateRelayMaskHost and applePrivateRelayMaskH2Host are the
// hostnames that Apple devices use to check if Apple Private Relay can be
// enabled. Returning NXDOMAIN to queries for these domain names blocks
// Apple Private Relay.
//
// See https://developer.apple.com/support/prepare-your-network-for-icloud-private-relay.
applePrivateRelayMaskHost = "mask.icloud.com"
applePrivateRelayMaskH2Host = "mask-h2.icloud.com"
firefoxCanaryHost = "use-application-dns.net"
)
// noReqInfoSpecialHandler returns a handler that can handle a special-domain
// query based only on its question type, class, and target, as well as the
// handler's name for debugging.
func (mw *initMw) noReqInfoSpecialHandler(
fqdn string,
qt dnsmsg.RRType,
cl dnsmsg.Class,
) (f dnsserver.HandlerFunc, name string) {
if cl != dns.ClassINET {
return nil, ""
}
if (qt == dns.TypeA || qt == dns.TypeAAAA) && fqdn == firefoxCanaryFQDN {
return mw.handleFirefoxCanary, "firefox"
}
return nil, ""
}
// Firefox Canary
// handleFirefoxCanary checks if the request is for the fully-qualified domain
// name that Firefox uses to check DoH settings and writes a response if needed.
func (mw *initMw) handleFirefoxCanary(
ctx context.Context,
rw dnsserver.ResponseWriter,
req *dns.Msg,
) (err error) {
metrics.DNSSvcFirefoxRequestsTotal.Inc()
resp := mw.messages.NewMsgREFUSED(req)
err = rw.WriteMsg(ctx, req, resp)
return errors.Annotate(err, "writing firefox canary resp: %w")
}
// Hostnames that Apple devices use to check if Apple Private Relay can be
// enabled. Returning NXDOMAIN to queries for these domain names blocks Apple
// Private Relay.
//
// See https://developer.apple.com/support/prepare-your-network-for-icloud-private-relay.
const (
applePrivateRelayMaskHost = "mask.icloud.com"
applePrivateRelayMaskH2Host = "mask-h2.icloud.com"
applePrivateRelayMaskCanaryHost = "mask-canary.icloud.com"
)
// reqInfoSpecialHandler returns a handler that can handle a special-domain
// query based on the request info, as well as the handler's name for debugging.
@ -99,11 +65,9 @@ func (mw *initMw) reqInfoSpecialHandler(
} else if netutil.IsSubdomain(ri.Host, resolverArpaDomain) {
// A badly formed resolver.arpa subdomain query.
return mw.handleBadResolverARPA, "bad_resolver_arpa"
} else if shouldBlockPrivateRelay(ri) {
return mw.handlePrivateRelay, "apple_private_relay"
}
return nil, ""
return mw.specialDomainHandler(ri)
}
// reqInfoHandlerFunc is an alias for handler functions that additionally accept
@ -222,24 +186,46 @@ func (mw *initMw) handleBadResolverARPA(
return errors.Annotate(err, "writing nodata resp for %q: %w", ri.Host)
}
// specialDomainHandler returns a handler that can handle a special-domain
// query for Apple Private Relay or Firefox canary domain based on the request
// or profile information, as well as the handler's name for debugging.
func (mw *initMw) specialDomainHandler(
ri *agd.RequestInfo,
) (f reqInfoHandlerFunc, name string) {
qt := ri.QType
if qt != dns.TypeA && qt != dns.TypeAAAA {
return nil, ""
}
host := ri.Host
prof := ri.Profile
switch host {
case
applePrivateRelayMaskHost,
applePrivateRelayMaskH2Host,
applePrivateRelayMaskCanaryHost:
if shouldBlockPrivateRelay(ri, prof) {
return mw.handlePrivateRelay, "apple_private_relay"
}
case firefoxCanaryHost:
if shouldBlockFirefoxCanary(ri, prof) {
return mw.handleFirefoxCanary, "firefox"
}
default:
// Go on.
}
return nil, ""
}
// Apple Private Relay
// shouldBlockPrivateRelay returns true if the query is for an Apple Private
// Relay check domain and the request information indicates that Apple Private
// Relay should be blocked.
func shouldBlockPrivateRelay(ri *agd.RequestInfo) (ok bool) {
qt := ri.QType
host := ri.Host
return (qt == dns.TypeA || qt == dns.TypeAAAA) &&
(host == applePrivateRelayMaskHost || host == applePrivateRelayMaskH2Host) &&
reqInfoShouldBlockPrivateRelay(ri)
}
// reqInfoShouldBlockPrivateRelay returns true if Apple Private Relay queries
// should be blocked based on the request information.
func reqInfoShouldBlockPrivateRelay(ri *agd.RequestInfo) (ok bool) {
if prof := ri.Profile; prof != nil {
// Relay check domain and the request information or profile indicates that
// Apple Private Relay should be blocked.
func shouldBlockPrivateRelay(ri *agd.RequestInfo, prof *agd.Profile) (ok bool) {
if prof != nil {
return prof.BlockPrivateRelay
}
@ -260,3 +246,32 @@ func (mw *initMw) handlePrivateRelay(
return errors.Annotate(err, "writing private relay resp: %w")
}
// Firefox canary domain
// shouldBlockFirefoxCanary returns true if the query is for a Firefox canary
// domain and the request information or profile indicates that Firefox canary
// domain should be blocked.
func shouldBlockFirefoxCanary(ri *agd.RequestInfo, prof *agd.Profile) (ok bool) {
if prof != nil {
return prof.BlockFirefoxCanary
}
return ri.FilteringGroup.BlockFirefoxCanary
}
// handleFirefoxCanary checks if the request is for the fully-qualified domain
// name that Firefox uses to check DoH settings and writes a response if needed.
func (mw *initMw) handleFirefoxCanary(
ctx context.Context,
rw dnsserver.ResponseWriter,
req *dns.Msg,
ri *agd.RequestInfo,
) (err error) {
metrics.DNSSvcFirefoxRequestsTotal.Inc()
resp := mw.messages.NewMsgREFUSED(req)
err = rw.WriteMsg(ctx, req, resp)
return errors.Annotate(err, "writing firefox canary resp: %w")
}

View File

@ -10,6 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/mathutil"
"github.com/bluele/gcache"
"github.com/miekg/dns"
)
@ -93,22 +94,16 @@ func (mw *Middleware) toCacheKey(cr *cacheRequest, hostHasECS bool) (key uint64)
var buf [6]byte
binary.LittleEndian.PutUint16(buf[:2], cr.qType)
binary.LittleEndian.PutUint16(buf[2:4], cr.qClass)
if cr.reqDO {
buf[4] = 1
} else {
buf[4] = 0
}
if cr.subnet.Addr().Is4() {
buf[5] = 0
} else {
buf[5] = 1
}
buf[4] = mathutil.BoolToNumber[byte](cr.reqDO)
addr := cr.subnet.Addr()
buf[5] = mathutil.BoolToNumber[byte](addr.Is6())
_, _ = h.Write(buf[:])
if hostHasECS {
_, _ = h.Write(cr.subnet.Addr().AsSlice())
_, _ = h.Write(addr.AsSlice())
_ = h.WriteByte(byte(cr.subnet.Bits()))
}

View File

@ -8,6 +8,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/golibs/mathutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
@ -269,9 +270,5 @@ func getTTLIfLower(r dns.RR, ttl uint32) (res uint32) {
// Go on.
}
if httl := r.Header().Ttl; httl < ttl {
return httl
}
return ttl
return mathutil.Min(r.Header().Ttl, ttl)
}

View File

@ -67,13 +67,13 @@ type SentryReportableError interface {
// TODO(a.garipov): Make sure that we use this approach everywhere.
func isReportable(err error) (ok bool) {
var (
ravErr SentryReportableError
sentryRepErr SentryReportableError
fwdErr *forward.Error
dnsWErr *dnsserver.WriteError
)
if errors.As(err, &ravErr) {
return ravErr.IsSentryReportable()
if errors.As(err, &sentryRepErr) {
return sentryRepErr.IsSentryReportable()
} else if errors.As(err, &fwdErr) {
return isReportableNetwork(fwdErr.Err)
} else if errors.As(err, &dnsWErr) {

View File

@ -21,8 +21,8 @@ var _ Interface = (*compFilter)(nil)
// compFilter is a composite filter based on several types of safe search
// filters and rule lists.
type compFilter struct {
safeBrowsing *hashPrefixFilter
adultBlocking *hashPrefixFilter
safeBrowsing *HashPrefix
adultBlocking *HashPrefix
genSafeSearch *safeSearch
ytSafeSearch *safeSearch

View File

@ -18,10 +18,11 @@ import (
// maxFilterSize is the maximum size of downloaded filters.
const maxFilterSize = 196 * int64(datasize.MB)
// defaultTimeout is the default timeout to use when fetching filter data.
// defaultFilterRefreshTimeout is the default timeout to use when fetching
// filter lists data.
//
// TODO(a.garipov): Consider making timeouts where they are used configurable.
const defaultTimeout = 30 * time.Second
const defaultFilterRefreshTimeout = 180 * time.Second
// defaultResolveTimeout is the default timeout for resolving hosts for safe
// search and safe browsing filters.

View File

@ -181,14 +181,8 @@ func prepareConf(t testing.TB) (c *filter.DefaultStorageConfig) {
FilterIndexURL: fltsURL,
GeneralSafeSearchRulesURL: ssURL,
YoutubeSafeSearchRulesURL: ssURL,
SafeBrowsing: &filter.HashPrefixConfig{
CacheTTL: 1 * time.Hour,
CacheSize: 100,
},
AdultBlocking: &filter.HashPrefixConfig{
CacheTTL: 1 * time.Hour,
CacheSize: 100,
},
SafeBrowsing: &filter.HashPrefix{},
AdultBlocking: &filter.HashPrefix{},
Now: time.Now,
ErrColl: nil,
Resolver: nil,

View File

@ -3,13 +3,17 @@ package filter
import (
"context"
"fmt"
"net/url"
"strings"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/resultcache"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
@ -21,13 +25,32 @@ import (
// HashPrefixConfig is the hash-prefix filter configuration structure.
type HashPrefixConfig struct {
// Hashes are the hostname hashes for this filter.
Hashes *HashStorage
Hashes *hashstorage.Storage
// URL is the URL used to update the filter.
URL *url.URL
// ErrColl is used to collect non-critical and rare errors.
ErrColl agd.ErrorCollector
// Resolver is used to resolve hosts for the hash-prefix filter.
Resolver agdnet.Resolver
// ID is the ID of this hash storage for logging and error reporting.
ID agd.FilterListID
// CachePath is the path to the file containing the cached filtered
// hostnames, one per line.
CachePath string
// ReplacementHost is the replacement host for this filter. Queries
// matched by the filter receive a response with the IP addresses of
// this host.
ReplacementHost string
// Staleness is the time after which a file is considered stale.
Staleness time.Duration
// CacheTTL is the time-to-live value used to cache the results of the
// filter.
//
@ -38,57 +61,58 @@ type HashPrefixConfig struct {
CacheSize int
}
// hashPrefixFilter is a filter that matches hosts by their hashes based on
// a hash-prefix table.
type hashPrefixFilter struct {
hashes *HashStorage
// HashPrefix is a filter that matches hosts by their hashes based on a
// hash-prefix table.
type HashPrefix struct {
hashes *hashstorage.Storage
refr *refreshableFilter
resCache *resultcache.Cache[*ResultModified]
resolver agdnet.Resolver
errColl agd.ErrorCollector
repHost string
id agd.FilterListID
}
// newHashPrefixFilter returns a new hash-prefix filter. c must not be nil.
func newHashPrefixFilter(
c *HashPrefixConfig,
resolver agdnet.Resolver,
errColl agd.ErrorCollector,
id agd.FilterListID,
) (f *hashPrefixFilter) {
f = &hashPrefixFilter{
// NewHashPrefix returns a new hash-prefix filter. c must not be nil.
func NewHashPrefix(c *HashPrefixConfig) (f *HashPrefix, err error) {
f = &HashPrefix{
hashes: c.Hashes,
refr: &refreshableFilter{
http: agdhttp.NewClient(&agdhttp.ClientConfig{
Timeout: defaultFilterRefreshTimeout,
}),
url: c.URL,
id: c.ID,
cachePath: c.CachePath,
typ: "hash storage",
staleness: c.Staleness,
},
resCache: resultcache.New[*ResultModified](c.CacheSize),
resolver: resolver,
errColl: errColl,
resolver: c.Resolver,
errColl: c.ErrColl,
repHost: c.ReplacementHost,
id: id,
}
// Patch the refresh function of the hash storage, if there is one, to make
// sure that we clear the result cache during every refresh.
//
// TODO(a.garipov): Create a better way to do that than this spaghetti-style
// patching. Perhaps, lift the entire logic of hash-storage refresh into
// hashPrefixFilter.
if f.hashes != nil && f.hashes.refr != nil {
resetRules := f.hashes.refr.resetRules
f.hashes.refr.resetRules = func(text string) (err error) {
f.resCache.Clear()
f.refr.resetRules = f.resetRules
return resetRules(text)
}
err = f.refresh(context.Background(), true)
if err != nil {
return nil, err
}
return f
return f, nil
}
// id returns the ID of the hash storage.
func (f *HashPrefix) id() (fltID agd.FilterListID) {
return f.refr.id
}
// type check
var _ qtHostFilter = (*hashPrefixFilter)(nil)
var _ qtHostFilter = (*HashPrefix)(nil)
// filterReq implements the qtHostFilter interface for *hashPrefixFilter. It
// modifies the response if host matches f.
func (f *hashPrefixFilter) filterReq(
func (f *HashPrefix) filterReq(
ctx context.Context,
ri *agd.RequestInfo,
req *dns.Msg,
@ -115,7 +139,7 @@ func (f *hashPrefixFilter) filterReq(
var matched string
sub := hashableSubdomains(host)
for _, s := range sub {
if f.hashes.hashMatches(s) {
if f.hashes.Matches(s) {
matched = s
break
@ -134,19 +158,19 @@ func (f *hashPrefixFilter) filterReq(
var result *dns.Msg
ips, err := f.resolver.LookupIP(ctx, fam, f.repHost)
if err != nil {
agd.Collectf(ctx, f.errColl, "filter %s: resolving: %w", f.id, err)
agd.Collectf(ctx, f.errColl, "filter %s: resolving: %w", f.id(), err)
result = ri.Messages.NewMsgSERVFAIL(req)
} else {
result, err = ri.Messages.NewIPRespMsg(req, ips...)
if err != nil {
return nil, fmt.Errorf("filter %s: creating modified result: %w", f.id, err)
return nil, fmt.Errorf("filter %s: creating modified result: %w", f.id(), err)
}
}
rm = &ResultModified{
Msg: result,
List: f.id,
List: f.id(),
Rule: agd.FilterRuleText(matched),
}
@ -161,21 +185,21 @@ func (f *hashPrefixFilter) filterReq(
}
// updateCacheSizeMetrics updates cache size metrics.
func (f *hashPrefixFilter) updateCacheSizeMetrics(size int) {
switch f.id {
func (f *HashPrefix) updateCacheSizeMetrics(size int) {
switch id := f.id(); id {
case agd.FilterListIDSafeBrowsing:
metrics.HashPrefixFilterSafeBrowsingCacheSize.Set(float64(size))
case agd.FilterListIDAdultBlocking:
metrics.HashPrefixFilterAdultBlockingCacheSize.Set(float64(size))
default:
panic(fmt.Errorf("unsupported FilterListID %s", f.id))
panic(fmt.Errorf("unsupported FilterListID %s", id))
}
}
// updateCacheLookupsMetrics updates cache lookups metrics.
func (f *hashPrefixFilter) updateCacheLookupsMetrics(hit bool) {
func (f *HashPrefix) updateCacheLookupsMetrics(hit bool) {
var hitsMetric, missesMetric prometheus.Counter
switch f.id {
switch id := f.id(); id {
case agd.FilterListIDSafeBrowsing:
hitsMetric = metrics.HashPrefixFilterCacheSafeBrowsingHits
missesMetric = metrics.HashPrefixFilterCacheSafeBrowsingMisses
@ -183,7 +207,7 @@ func (f *hashPrefixFilter) updateCacheLookupsMetrics(hit bool) {
hitsMetric = metrics.HashPrefixFilterCacheAdultBlockingHits
missesMetric = metrics.HashPrefixFilterCacheAdultBlockingMisses
default:
panic(fmt.Errorf("unsupported FilterListID %s", f.id))
panic(fmt.Errorf("unsupported FilterListID %s", id))
}
if hit {
@ -194,12 +218,50 @@ func (f *hashPrefixFilter) updateCacheLookupsMetrics(hit bool) {
}
// name implements the qtHostFilter interface for *hashPrefixFilter.
func (f *hashPrefixFilter) name() (n string) {
func (f *HashPrefix) name() (n string) {
if f == nil {
return ""
}
return string(f.id)
return string(f.id())
}
// type check
var _ agd.Refresher = (*HashPrefix)(nil)
// Refresh implements the [agd.Refresher] interface for *hashPrefixFilter.
func (f *HashPrefix) Refresh(ctx context.Context) (err error) {
return f.refresh(ctx, false)
}
// refresh reloads the hash filter data. If acceptStale is true, do not try to
// load the list from its URL when there is already a file in the cache
// directory, regardless of its staleness.
func (f *HashPrefix) refresh(ctx context.Context, acceptStale bool) (err error) {
return f.refr.refresh(ctx, acceptStale)
}
// resetRules resets the hosts in the index.
func (f *HashPrefix) resetRules(text string) (err error) {
n, err := f.hashes.Reset(text)
// Report the filter update to prometheus.
promLabels := prometheus.Labels{
"filter": string(f.id()),
}
metrics.SetStatusGauge(metrics.FilterUpdatedStatus.With(promLabels), err)
if err != nil {
return err
}
metrics.FilterUpdatedTime.With(promLabels).SetToCurrentTime()
metrics.FilterRulesTotal.With(promLabels).Set(float64(n))
log.Info("filter %s: reset %d hosts", f.id(), n)
return nil
}
// subDomainNum defines how many labels should be hashed to match against a hash

View File

@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
@ -21,8 +22,10 @@ import (
func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) {
cacheDir := t.TempDir()
cachePath := filepath.Join(cacheDir, string(agd.FilterListIDSafeBrowsing))
hosts := "scam.example.net\n"
err := os.WriteFile(cachePath, []byte(hosts), 0o644)
err := os.WriteFile(cachePath, []byte(safeBrowsingHost+"\n"), 0o644)
require.NoError(t, err)
hashes, err := hashstorage.New("")
require.NoError(t, err)
errColl := &agdtest.ErrorCollector{
@ -31,46 +34,33 @@ func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) {
},
}
hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{
URL: nil,
ErrColl: errColl,
ID: agd.FilterListIDSafeBrowsing,
CachePath: cachePath,
RefreshIvl: testRefreshIvl,
})
require.NoError(t, err)
ctx := context.Background()
err = hashes.Start()
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return hashes.Shutdown(ctx)
})
// Fake Data
onLookupIP := func(
resolver := &agdtest.Resolver{
OnLookupIP: func(
_ context.Context,
_ netutil.AddrFamily,
_ string,
) (ips []net.IP, err error) {
return []net.IP{safeBrowsingSafeIP4}, nil
},
}
c := prepareConf(t)
c.SafeBrowsing = &filter.HashPrefixConfig{
c.SafeBrowsing, err = filter.NewHashPrefix(&filter.HashPrefixConfig{
Hashes: hashes,
ErrColl: errColl,
Resolver: resolver,
ID: agd.FilterListIDSafeBrowsing,
CachePath: cachePath,
ReplacementHost: safeBrowsingSafeHost,
Staleness: 1 * time.Hour,
CacheTTL: 10 * time.Second,
CacheSize: 100,
}
})
require.NoError(t, err)
c.ErrColl = errColl
c.Resolver = &agdtest.Resolver{
OnLookupIP: onLookupIP,
}
c.Resolver = resolver
s, err := filter.NewDefaultStorage(c)
require.NoError(t, err)
@ -93,7 +83,7 @@ func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) {
}
ri := newReqInfo(g, nil, safeBrowsingSubHost, clientIP, dns.TypeA)
ctx = agd.ContextWithRequestInfo(ctx, ri)
ctx := agd.ContextWithRequestInfo(context.Background(), ri)
f := s.FilterFromContext(ctx, ri)
require.NotNil(t, f)

View File

@ -1,352 +0,0 @@
package filter
import (
"bufio"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/url"
"strings"
"sync"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/prometheus/client_golang/prometheus"
)
// Hash Storage
// Hash and hash part length constants.
const (
// hashPrefixLen is the length of the prefix of the hash of the filtered
// hostname.
hashPrefixLen = 2
// HashPrefixEncLen is the encoded length of the hash prefix. Two text
// bytes per one binary byte.
HashPrefixEncLen = hashPrefixLen * 2
// hashLen is the length of the whole hash of the checked hostname.
hashLen = sha256.Size
// hashSuffixLen is the length of the suffix of the hash of the filtered
// hostname.
hashSuffixLen = hashLen - hashPrefixLen
// hashEncLen is the encoded length of the hash. Two text bytes per one
// binary byte.
hashEncLen = hashLen * 2
// legacyHashPrefixEncLen is the encoded length of a legacy hash.
legacyHashPrefixEncLen = 8
)
// hashPrefix is the type of the 2-byte prefix of a full 32-byte SHA256 hash of
// a host being checked.
type hashPrefix [hashPrefixLen]byte
// hashSuffix is the type of the 30-byte suffix of a full 32-byte SHA256 hash of
// a host being checked.
type hashSuffix [hashSuffixLen]byte
// HashStorage is a storage for hashes of the filtered hostnames.
type HashStorage struct {
// mu protects hashSuffixes.
mu *sync.RWMutex
hashSuffixes map[hashPrefix][]hashSuffix
// refr contains data for refreshing the filter.
refr *refreshableFilter
refrWorker *agd.RefreshWorker
}
// HashStorageConfig is the configuration structure for a *HashStorage.
type HashStorageConfig struct {
// URL is the URL used to update the filter.
URL *url.URL
// ErrColl is used to collect non-critical and rare errors.
ErrColl agd.ErrorCollector
// ID is the ID of this hash storage for logging and error reporting.
ID agd.FilterListID
// CachePath is the path to the file containing the cached filtered
// hostnames, one per line.
CachePath string
// RefreshIvl is the refresh interval.
RefreshIvl time.Duration
}
// NewHashStorage returns a new hash storage containing hashes of all hostnames.
func NewHashStorage(c *HashStorageConfig) (hs *HashStorage, err error) {
hs = &HashStorage{
mu: &sync.RWMutex{},
hashSuffixes: map[hashPrefix][]hashSuffix{},
refr: &refreshableFilter{
http: agdhttp.NewClient(&agdhttp.ClientConfig{
Timeout: defaultTimeout,
}),
url: c.URL,
id: c.ID,
cachePath: c.CachePath,
typ: "hash storage",
refreshIvl: c.RefreshIvl,
},
}
// Do not set this in the literal above, since hs is nil there.
hs.refr.resetRules = hs.resetHosts
refrWorker := agd.NewRefreshWorker(&agd.RefreshWorkerConfig{
Context: func() (ctx context.Context, cancel context.CancelFunc) {
return context.WithTimeout(context.Background(), defaultTimeout)
},
Refresher: hs,
ErrColl: c.ErrColl,
Name: string(c.ID),
Interval: c.RefreshIvl,
RefreshOnShutdown: false,
RoutineLogsAreDebug: false,
})
hs.refrWorker = refrWorker
err = hs.refresh(context.Background(), true)
if err != nil {
return nil, fmt.Errorf("initializing %s: %w", c.ID, err)
}
return hs, nil
}
// hashes returns all hashes starting with the given prefixes, if any. The
// resulting slice shares storage for all underlying strings.
//
// TODO(a.garipov): This currently doesn't take duplicates into account.
func (hs *HashStorage) hashes(hps []hashPrefix) (hashes []string) {
if len(hps) == 0 {
return nil
}
hs.mu.RLock()
defer hs.mu.RUnlock()
// First, calculate the number of hashes to allocate the buffer.
l := 0
for _, hp := range hps {
hashSufs := hs.hashSuffixes[hp]
l += len(hashSufs)
}
// Then, allocate the buffer of the appropriate size and write all hashes
// into one big buffer and slice it into separate strings to make the
// garbage collector's work easier. This assumes that all references to
// this buffer will become unreachable at the same time.
//
// The fact that we iterate over the map twice shouldn't matter, since we
// assume that len(hps) will be below 5 most of the time.
b := &strings.Builder{}
b.Grow(l * hashEncLen)
// Use a buffer and write the resulting buffer into b directly instead of
// using hex.NewEncoder, because that seems to incur a significant
// performance hit.
var buf [hashEncLen]byte
for _, hp := range hps {
hashSufs := hs.hashSuffixes[hp]
for _, suf := range hashSufs {
// Slicing is safe here, since the contents of hp and suf are being
// encoded.
// nolint:looppointer
hex.Encode(buf[:], hp[:])
// nolint:looppointer
hex.Encode(buf[HashPrefixEncLen:], suf[:])
_, _ = b.Write(buf[:])
}
}
s := b.String()
hashes = make([]string, 0, l)
for i := 0; i < l; i++ {
hashes = append(hashes, s[i*hashEncLen:(i+1)*hashEncLen])
}
return hashes
}
// loadHashSuffixes returns hash suffixes for the given prefix. It is safe for
// concurrent use.
func (hs *HashStorage) loadHashSuffixes(hp hashPrefix) (sufs []hashSuffix, ok bool) {
hs.mu.RLock()
defer hs.mu.RUnlock()
sufs, ok = hs.hashSuffixes[hp]
return sufs, ok
}
// hashMatches returns true if the host matches one of the hashes.
func (hs *HashStorage) hashMatches(host string) (ok bool) {
sum := sha256.Sum256([]byte(host))
hp := hashPrefix{sum[0], sum[1]}
var buf [hashLen]byte
hashSufs, ok := hs.loadHashSuffixes(hp)
if !ok {
return false
}
copy(buf[:], hp[:])
for _, suf := range hashSufs {
// Slicing is safe here, because we make a copy.
// nolint:looppointer
copy(buf[hashPrefixLen:], suf[:])
if buf == sum {
return true
}
}
return false
}
// hashPrefixesFromStr returns hash prefixes from a dot-separated string.
func hashPrefixesFromStr(prefixesStr string) (hashPrefixes []hashPrefix, err error) {
if prefixesStr == "" {
return nil, nil
}
prefixSet := stringutil.NewSet()
prefixStrs := strings.Split(prefixesStr, ".")
for _, s := range prefixStrs {
if len(s) != HashPrefixEncLen {
// Some legacy clients send eight-character hashes instead of
// four-character ones. For now, remove the final four characters.
//
// TODO(a.garipov): Either remove this crutch or support such
// prefixes better.
if len(s) == legacyHashPrefixEncLen {
s = s[:HashPrefixEncLen]
} else {
return nil, fmt.Errorf("bad hash len for %q", s)
}
}
prefixSet.Add(s)
}
hashPrefixes = make([]hashPrefix, prefixSet.Len())
prefixStrs = prefixSet.Values()
for i, s := range prefixStrs {
_, err = hex.Decode(hashPrefixes[i][:], []byte(s))
if err != nil {
return nil, fmt.Errorf("bad hash encoding for %q", s)
}
}
return hashPrefixes, nil
}
// type check
var _ agd.Refresher = (*HashStorage)(nil)
// Refresh implements the agd.Refresher interface for *HashStorage. If the file
// at the storage's path exists and its mtime shows that it's still fresh, it
// loads the data from the file. Otherwise, it uses the URL of the storage.
func (hs *HashStorage) Refresh(ctx context.Context) (err error) {
err = hs.refresh(ctx, false)
// Report the filter update to prometheus.
promLabels := prometheus.Labels{
"filter": string(hs.id()),
}
metrics.SetStatusGauge(metrics.FilterUpdatedStatus.With(promLabels), err)
if err == nil {
metrics.FilterUpdatedTime.With(promLabels).SetToCurrentTime()
// Count the total number of hashes loaded.
count := 0
for _, v := range hs.hashSuffixes {
count += len(v)
}
metrics.FilterRulesTotal.With(promLabels).Set(float64(count))
}
return err
}
// id returns the ID of the hash storage.
func (hs *HashStorage) id() (fltID agd.FilterListID) {
return hs.refr.id
}
// refresh reloads the hash filter data. If acceptStale is true, do not try to
// load the list from its URL when there is already a file in the cache
// directory, regardless of its staleness.
func (hs *HashStorage) refresh(ctx context.Context, acceptStale bool) (err error) {
return hs.refr.refresh(ctx, acceptStale)
}
// resetHosts resets the hosts in the index.
func (hs *HashStorage) resetHosts(hostsStr string) (err error) {
hs.mu.Lock()
defer hs.mu.Unlock()
// Delete all elements without allocating a new map to safe space and
// performance.
//
// This is optimized, see https://github.com/golang/go/issues/20138.
for hp := range hs.hashSuffixes {
delete(hs.hashSuffixes, hp)
}
var n int
s := bufio.NewScanner(strings.NewReader(hostsStr))
for s.Scan() {
host := s.Text()
if len(host) == 0 || host[0] == '#' {
continue
}
sum := sha256.Sum256([]byte(host))
hp := hashPrefix{sum[0], sum[1]}
// TODO(a.garipov): Convert to array directly when proposal
// golang/go#46505 is implemented in Go 1.20.
suf := *(*hashSuffix)(sum[hashPrefixLen:])
hs.hashSuffixes[hp] = append(hs.hashSuffixes[hp], suf)
n++
}
err = s.Err()
if err != nil {
return fmt.Errorf("scanning hosts: %w", err)
}
log.Info("filter %s: reset %d hosts", hs.id(), n)
return nil
}
// Start implements the agd.Service interface for *HashStorage.
func (hs *HashStorage) Start() (err error) {
return hs.refrWorker.Start()
}
// Shutdown implements the agd.Service interface for *HashStorage.
func (hs *HashStorage) Shutdown(ctx context.Context) (err error) {
return hs.refrWorker.Shutdown(ctx)
}

View File

@ -0,0 +1,203 @@
// Package hashstorage defines a storage of hashes of domain names used for
// filtering.
package hashstorage
import (
"bufio"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"sync"
)
// Hash and hash part length constants.
const (
// PrefixLen is the length of the prefix of the hash of the filtered
// hostname.
PrefixLen = 2
// PrefixEncLen is the encoded length of the hash prefix. Two text
// bytes per one binary byte.
PrefixEncLen = PrefixLen * 2
// hashLen is the length of the whole hash of the checked hostname.
hashLen = sha256.Size
// suffixLen is the length of the suffix of the hash of the filtered
// hostname.
suffixLen = hashLen - PrefixLen
// hashEncLen is the encoded length of the hash. Two text bytes per one
// binary byte.
hashEncLen = hashLen * 2
)
// Prefix is the type of the 2-byte prefix of a full 32-byte SHA256 hash of a
// host being checked.
type Prefix [PrefixLen]byte
// suffix is the type of the 30-byte suffix of a full 32-byte SHA256 hash of a
// host being checked.
type suffix [suffixLen]byte
// Storage stores hashes of the filtered hostnames. All methods are safe for
// concurrent use.
type Storage struct {
// mu protects hashSuffixes.
mu *sync.RWMutex
hashSuffixes map[Prefix][]suffix
}
// New returns a new hash storage containing hashes of the domain names listed
// in hostnames, one domain name per line.
func New(hostnames string) (s *Storage, err error) {
s = &Storage{
mu: &sync.RWMutex{},
hashSuffixes: map[Prefix][]suffix{},
}
if hostnames != "" {
_, err = s.Reset(hostnames)
if err != nil {
return nil, err
}
}
return s, nil
}
// Hashes returns all hashes starting with the given prefixes, if any. The
// resulting slice shares storage for all underlying strings.
//
// TODO(a.garipov): This currently doesn't take duplicates into account.
func (s *Storage) Hashes(hps []Prefix) (hashes []string) {
if len(hps) == 0 {
return nil
}
s.mu.RLock()
defer s.mu.RUnlock()
// First, calculate the number of hashes to allocate the buffer.
l := 0
for _, hp := range hps {
hashSufs := s.hashSuffixes[hp]
l += len(hashSufs)
}
// Then, allocate the buffer of the appropriate size and write all hashes
// into one big buffer and slice it into separate strings to make the
// garbage collector's work easier. This assumes that all references to
// this buffer will become unreachable at the same time.
//
// The fact that we iterate over the [s.hashSuffixes] map twice shouldn't
// matter, since we assume that len(hps) will be below 5 most of the time.
b := &strings.Builder{}
b.Grow(l * hashEncLen)
// Use a buffer and write the resulting buffer into b directly instead of
// using hex.NewEncoder, because that seems to incur a significant
// performance hit.
var buf [hashEncLen]byte
for _, hp := range hps {
hashSufs := s.hashSuffixes[hp]
for _, suf := range hashSufs {
// Slicing is safe here, since the contents of hp and suf are being
// encoded.
// nolint:looppointer
hex.Encode(buf[:], hp[:])
// nolint:looppointer
hex.Encode(buf[PrefixEncLen:], suf[:])
_, _ = b.Write(buf[:])
}
}
str := b.String()
hashes = make([]string, 0, l)
for i := 0; i < l; i++ {
hashes = append(hashes, str[i*hashEncLen:(i+1)*hashEncLen])
}
return hashes
}
// Matches returns true if the host matches one of the hashes.
func (s *Storage) Matches(host string) (ok bool) {
sum := sha256.Sum256([]byte(host))
hp := *(*Prefix)(sum[:PrefixLen])
var buf [hashLen]byte
hashSufs, ok := s.loadHashSuffixes(hp)
if !ok {
return false
}
copy(buf[:], hp[:])
for _, suf := range hashSufs {
// Slicing is safe here, because we make a copy.
// nolint:looppointer
copy(buf[PrefixLen:], suf[:])
if buf == sum {
return true
}
}
return false
}
// Reset resets the hosts in the index using the domain names listed in
// hostnames, one domain name per line, and returns the total number of
// processed rules.
func (s *Storage) Reset(hostnames string) (n int, err error) {
s.mu.Lock()
defer s.mu.Unlock()
// Delete all elements without allocating a new map to save space and
// improve performance.
//
// This is optimized, see https://github.com/golang/go/issues/20138.
//
// TODO(a.garipov): Use clear once golang/go#56351 is implemented.
for hp := range s.hashSuffixes {
delete(s.hashSuffixes, hp)
}
sc := bufio.NewScanner(strings.NewReader(hostnames))
for sc.Scan() {
host := sc.Text()
if len(host) == 0 || host[0] == '#' {
continue
}
sum := sha256.Sum256([]byte(host))
hp := *(*Prefix)(sum[:PrefixLen])
// TODO(a.garipov): Here and everywhere, convert to array directly when
// proposal golang/go#46505 is implemented in Go 1.20.
suf := *(*suffix)(sum[PrefixLen:])
s.hashSuffixes[hp] = append(s.hashSuffixes[hp], suf)
n++
}
err = sc.Err()
if err != nil {
return 0, fmt.Errorf("scanning hosts: %w", err)
}
return n, nil
}
// loadHashSuffixes returns hash suffixes for the given prefix. It is safe for
// concurrent use. sufs must not be modified.
func (s *Storage) loadHashSuffixes(hp Prefix) (sufs []suffix, ok bool) {
s.mu.RLock()
defer s.mu.RUnlock()
sufs, ok = s.hashSuffixes[hp]
return sufs, ok
}

View File

@ -0,0 +1,142 @@
package hashstorage_test
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"strconv"
"strings"
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Common hostnames for tests.
const (
testHost = "porn.example"
otherHost = "otherporn.example"
)
func TestStorage_Hashes(t *testing.T) {
s, err := hashstorage.New(testHost)
require.NoError(t, err)
h := sha256.Sum256([]byte(testHost))
want := []string{hex.EncodeToString(h[:])}
p := hashstorage.Prefix{h[0], h[1]}
got := s.Hashes([]hashstorage.Prefix{p})
assert.Equal(t, want, got)
wrong := s.Hashes([]hashstorage.Prefix{{}})
assert.Empty(t, wrong)
}
func TestStorage_Matches(t *testing.T) {
s, err := hashstorage.New(testHost)
require.NoError(t, err)
got := s.Matches(testHost)
assert.True(t, got)
got = s.Matches(otherHost)
assert.False(t, got)
}
func TestStorage_Reset(t *testing.T) {
s, err := hashstorage.New(testHost)
require.NoError(t, err)
n, err := s.Reset(otherHost)
require.NoError(t, err)
assert.Equal(t, 1, n)
h := sha256.Sum256([]byte(otherHost))
want := []string{hex.EncodeToString(h[:])}
p := hashstorage.Prefix{h[0], h[1]}
got := s.Hashes([]hashstorage.Prefix{p})
assert.Equal(t, want, got)
prevHash := sha256.Sum256([]byte(testHost))
prev := s.Hashes([]hashstorage.Prefix{{prevHash[0], prevHash[1]}})
assert.Empty(t, prev)
}
// Sinks for benchmarks.
var (
errSink error
strsSink []string
)
func BenchmarkStorage_Hashes(b *testing.B) {
const N = 10_000
var hosts []string
for i := 0; i < N; i++ {
hosts = append(hosts, fmt.Sprintf("%d."+testHost, i))
}
s, err := hashstorage.New(strings.Join(hosts, "\n"))
require.NoError(b, err)
var hashPrefixes []hashstorage.Prefix
for i := 0; i < 4; i++ {
hashPrefixes = append(hashPrefixes, hashstorage.Prefix{hosts[i][0], hosts[i][1]})
}
for n := 1; n <= 4; n++ {
b.Run(strconv.FormatInt(int64(n), 10), func(b *testing.B) {
hps := hashPrefixes[:n]
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
strsSink = s.Hashes(hps)
}
})
}
// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
//
// goos: linux
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
// BenchmarkStorage_Hashes/1-16 29928834 41.76 ns/op 0 B/op 0 allocs/op
// BenchmarkStorage_Hashes/2-16 18693033 63.80 ns/op 0 B/op 0 allocs/op
// BenchmarkStorage_Hashes/3-16 13492526 92.22 ns/op 0 B/op 0 allocs/op
// BenchmarkStorage_Hashes/4-16 9542425 109.2 ns/op 0 B/op 0 allocs/op
}
func BenchmarkStorage_ResetHosts(b *testing.B) {
const N = 1_000
var hosts []string
for i := 0; i < N; i++ {
hosts = append(hosts, fmt.Sprintf("%d."+testHost, i))
}
hostnames := strings.Join(hosts, "\n")
s, err := hashstorage.New(hostnames)
require.NoError(b, err)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, errSink = s.Reset(hostnames)
}
require.NoError(b, errSink)
// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
//
// goos: linux
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
// BenchmarkStorage_ResetHosts-16 2212 469343 ns/op 36224 B/op 1002 allocs/op
}

View File

@ -1,91 +0,0 @@
package filter
import (
"fmt"
"strconv"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/require"
)
var strsSink []string
func BenchmarkHashStorage_Hashes(b *testing.B) {
const N = 10_000
var hosts []string
for i := 0; i < N; i++ {
hosts = append(hosts, fmt.Sprintf("%d.porn.example.com", i))
}
// Don't use a constructor, since we don't need the whole contents of the
// storage.
//
// TODO(a.garipov): Think of a better way to do this.
hs := &HashStorage{
mu: &sync.RWMutex{},
hashSuffixes: map[hashPrefix][]hashSuffix{},
refr: &refreshableFilter{id: "test_filter"},
}
err := hs.resetHosts(strings.Join(hosts, "\n"))
require.NoError(b, err)
var hashPrefixes []hashPrefix
for i := 0; i < 4; i++ {
hashPrefixes = append(hashPrefixes, hashPrefix{hosts[i][0], hosts[i][1]})
}
for n := 1; n <= 4; n++ {
b.Run(strconv.FormatInt(int64(n), 10), func(b *testing.B) {
hps := hashPrefixes[:n]
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
strsSink = hs.hashes(hps)
}
})
}
}
func BenchmarkHashStorage_resetHosts(b *testing.B) {
const N = 1_000
var hosts []string
for i := 0; i < N; i++ {
hosts = append(hosts, fmt.Sprintf("%d.porn.example.com", i))
}
// Don't use a constructor, since we don't need the whole contents of the
// storage.
//
// TODO(a.garipov): Think of a better way to do this.
hs := &HashStorage{
mu: &sync.RWMutex{},
hashSuffixes: map[hashPrefix][]hashSuffix{},
refr: &refreshableFilter{id: "test_filter"},
}
// Reset them once to fill the initial map.
err := hs.resetHosts(strings.Join(hosts, "\n"))
require.NoError(b, err)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
err = hs.resetHosts(strings.Join(hosts, "\n"))
}
require.NoError(b, err)
// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
//
// goos: linux
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
// BenchmarkHashStorage_resetHosts-16 1404 829505 ns/op 58289 B/op 1011 allocs/op
}

View File

@ -8,6 +8,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/mathutil"
"github.com/bluele/gcache"
)
@ -93,13 +94,9 @@ func DefaultKey(host string, qt dnsmsg.RRType, isAns bool) (k Key) {
// Save on allocations by reusing a buffer.
var buf [3]byte
binary.LittleEndian.PutUint16(buf[:2], qt)
if isAns {
buf[2] = 1
} else {
buf[2] = 0
}
buf[2] = mathutil.BoolToNumber[byte](isAns)
_, _ = h.Write(buf[:3])
_, _ = h.Write(buf[:])
return Key(h.Sum64())
}

View File

@ -45,9 +45,8 @@ type refreshableFilter struct {
// typ is the type of this filter used for logging and error reporting.
typ string
// refreshIvl is the refresh interval for this filter. It is also used to
// check if the cached file is fresh enough.
refreshIvl time.Duration
// staleness is the time after which a file is considered stale.
staleness time.Duration
}
// refresh reloads the filter data. If acceptStale is true, refresh doesn't try
@ -118,7 +117,7 @@ func (f *refreshableFilter) refreshFromFile(
return "", fmt.Errorf("reading filter file stat: %w", err)
}
if mtime := fi.ModTime(); !mtime.Add(f.refreshIvl).After(time.Now()) {
if mtime := fi.ModTime(); !mtime.Add(f.staleness).After(time.Now()) {
return "", nil
}
}

View File

@ -30,31 +30,31 @@ func TestRefreshableFilter_RefreshFromFile(t *testing.T) {
name string
cachePath string
wantText string
refreshIvl time.Duration
staleness time.Duration
acceptStale bool
}{{
name: "no_file",
cachePath: "does_not_exist",
wantText: "",
refreshIvl: 0,
staleness: 0,
acceptStale: true,
}, {
name: "file",
cachePath: cachePath,
wantText: defaultText,
refreshIvl: 0,
staleness: 0,
acceptStale: true,
}, {
name: "file_stale",
cachePath: cachePath,
wantText: "",
refreshIvl: -1 * time.Second,
staleness: -1 * time.Second,
acceptStale: false,
}, {
name: "file_stale_accept",
cachePath: cachePath,
wantText: defaultText,
refreshIvl: -1 * time.Second,
staleness: -1 * time.Second,
acceptStale: true,
}}
@ -66,7 +66,7 @@ func TestRefreshableFilter_RefreshFromFile(t *testing.T) {
id: "test_filter",
cachePath: tc.cachePath,
typ: "test filter",
refreshIvl: tc.refreshIvl,
staleness: tc.staleness,
}
var text string
@ -166,7 +166,7 @@ func TestRefreshableFilter_RefreshFromURL(t *testing.T) {
id: "test_filter",
cachePath: tc.cachePath,
typ: "test filter",
refreshIvl: testTimeout,
staleness: testTimeout,
}
if tc.expectReq {

View File

@ -72,13 +72,13 @@ func newRuleListFilter(
mu: &sync.RWMutex{},
refr: &refreshableFilter{
http: agdhttp.NewClient(&agdhttp.ClientConfig{
Timeout: defaultTimeout,
Timeout: defaultFilterRefreshTimeout,
}),
url: l.URL,
id: l.ID,
cachePath: filepath.Join(fileCacheDir, string(l.ID)),
typ: "rule list",
refreshIvl: l.RefreshIvl,
staleness: l.RefreshIvl,
},
urlFilterID: newURLFilterID(),
}

View File

@ -2,9 +2,13 @@ package filter
import (
"context"
"encoding/hex"
"fmt"
"strings"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
)
// Safe Browsing TXT Record Server
@ -14,12 +18,12 @@ import (
//
// TODO(a.garipov): Consider making an interface to simplify testing.
type SafeBrowsingServer struct {
generalHashes *HashStorage
adultBlockingHashes *HashStorage
generalHashes *hashstorage.Storage
adultBlockingHashes *hashstorage.Storage
}
// NewSafeBrowsingServer returns a new safe browsing DNS server.
func NewSafeBrowsingServer(general, adultBlocking *HashStorage) (f *SafeBrowsingServer) {
func NewSafeBrowsingServer(general, adultBlocking *hashstorage.Storage) (f *SafeBrowsingServer) {
return &SafeBrowsingServer{
generalHashes: general,
adultBlockingHashes: adultBlocking,
@ -48,8 +52,7 @@ func (srv *SafeBrowsingServer) Hashes(
}
var prefixesStr string
var strg *HashStorage
var strg *hashstorage.Storage
if strings.HasSuffix(host, GeneralTXTSuffix) {
prefixesStr = host[:len(host)-len(GeneralTXTSuffix)]
strg = srv.generalHashes
@ -67,5 +70,45 @@ func (srv *SafeBrowsingServer) Hashes(
return nil, false, err
}
return strg.hashes(hashPrefixes), true, nil
return strg.Hashes(hashPrefixes), true, nil
}
// legacyPrefixEncLen is the encoded length of a legacy hash.
const legacyPrefixEncLen = 8
// hashPrefixesFromStr returns hash prefixes from a dot-separated string.
func hashPrefixesFromStr(prefixesStr string) (hashPrefixes []hashstorage.Prefix, err error) {
if prefixesStr == "" {
return nil, nil
}
prefixSet := stringutil.NewSet()
prefixStrs := strings.Split(prefixesStr, ".")
for _, s := range prefixStrs {
if len(s) != hashstorage.PrefixEncLen {
// Some legacy clients send eight-character hashes instead of
// four-character ones. For now, remove the final four characters.
//
// TODO(a.garipov): Either remove this crutch or support such
// prefixes better.
if len(s) == legacyPrefixEncLen {
s = s[:hashstorage.PrefixEncLen]
} else {
return nil, fmt.Errorf("bad hash len for %q", s)
}
}
prefixSet.Add(s)
}
hashPrefixes = make([]hashstorage.Prefix, prefixSet.Len())
prefixStrs = prefixSet.Values()
for i, s := range prefixStrs {
_, err = hex.Decode(hashPrefixes[i][:], []byte(s))
if err != nil {
return nil, fmt.Errorf("bad hash encoding for %q", s)
}
}
return hashPrefixes, nil
}

View File

@ -4,17 +4,11 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"net/url"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -43,41 +37,10 @@ func TestSafeBrowsingServer(t *testing.T) {
hashStrs[i] = hex.EncodeToString(sum[:])
}
// Hash Storage
errColl := &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, err error) {
panic("not implemented")
},
}
cacheDir := t.TempDir()
cachePath := filepath.Join(cacheDir, string(agd.FilterListIDSafeBrowsing))
err := os.WriteFile(cachePath, []byte(strings.Join(hosts, "\n")), 0o644)
require.NoError(t, err)
hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{
URL: &url.URL{},
ErrColl: errColl,
ID: agd.FilterListIDSafeBrowsing,
CachePath: cachePath,
RefreshIvl: testRefreshIvl,
})
hashes, err := hashstorage.New(strings.Join(hosts, "\n"))
require.NoError(t, err)
ctx := context.Background()
err = hashes.Start()
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return hashes.Shutdown(ctx)
})
// Give the storage some time to process the hashes.
//
// TODO(a.garipov): Think of a less stupid way of doing this.
time.Sleep(100 * time.Millisecond)
testCases := []struct {
name string
host string
@ -90,14 +53,14 @@ func TestSafeBrowsingServer(t *testing.T) {
wantMatched: false,
}, {
name: "realistic",
host: hashStrs[realisticHostIdx][:filter.HashPrefixEncLen] + filter.GeneralTXTSuffix,
host: hashStrs[realisticHostIdx][:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix,
wantHashStrs: []string{
hashStrs[realisticHostIdx],
},
wantMatched: true,
}, {
name: "same_prefix",
host: hashStrs[samePrefixHost1Idx][:filter.HashPrefixEncLen] + filter.GeneralTXTSuffix,
host: hashStrs[samePrefixHost1Idx][:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix,
wantHashStrs: []string{
hashStrs[samePrefixHost1Idx],
hashStrs[samePrefixHost2Idx],

View File

@ -3,9 +3,7 @@ package filter_test
import (
"context"
"net"
"os"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
@ -20,7 +18,8 @@ import (
func TestStorage_FilterFromContext_safeSearch(t *testing.T) {
numLookupIP := 0
onLookupIP := func(
resolver := &agdtest.Resolver{
OnLookupIP: func(
_ context.Context,
fam netutil.AddrFamily,
_ string,
@ -32,33 +31,16 @@ func TestStorage_FilterFromContext_safeSearch(t *testing.T) {
}
return []net.IP{safeSearchIPRespIP6}, nil
},
}
tmpFile, err := os.CreateTemp(t.TempDir(), "")
require.NoError(t, err)
_, err = tmpFile.Write([]byte("bad.example.com\n"))
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{
CachePath: tmpFile.Name(),
RefreshIvl: 1 * time.Hour,
})
require.NoError(t, err)
c := prepareConf(t)
c.SafeBrowsing.Hashes = hashes
c.AdultBlocking.Hashes = hashes
c.ErrColl = &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, err error) { panic("not implemented") },
}
c.Resolver = &agdtest.Resolver{
OnLookupIP: onLookupIP,
}
c.Resolver = resolver
s, err := filter.NewDefaultStorage(c)
require.NoError(t, err)
@ -80,13 +62,13 @@ func TestStorage_FilterFromContext_safeSearch(t *testing.T) {
host: safeSearchIPHost,
wantIP: safeSearchIPRespIP4,
rrtype: dns.TypeA,
wantLookups: 0,
wantLookups: 1,
}, {
name: "ip6",
host: safeSearchIPHost,
wantIP: nil,
wantIP: safeSearchIPRespIP6,
rrtype: dns.TypeAAAA,
wantLookups: 0,
wantLookups: 1,
}, {
name: "host_ip4",
host: safeSearchHost,

View File

@ -48,7 +48,7 @@ func newServiceBlocker(indexURL *url.URL, errColl agd.ErrorCollector) (b *servic
return &serviceBlocker{
url: indexURL,
http: agdhttp.NewClient(&agdhttp.ClientConfig{
Timeout: defaultTimeout,
Timeout: defaultFilterRefreshTimeout,
}),
mu: &sync.RWMutex{},
errColl: errColl,

View File

@ -15,7 +15,6 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/bluele/gcache"
"github.com/prometheus/client_golang/prometheus"
)
@ -55,10 +54,10 @@ type DefaultStorage struct {
services *serviceBlocker
// safeBrowsing is the general safe browsing filter.
safeBrowsing *hashPrefixFilter
safeBrowsing *HashPrefix
// adultBlocking is the adult content blocking safe browsing filter.
adultBlocking *hashPrefixFilter
adultBlocking *HashPrefix
// genSafeSearch is the general safe search filter.
genSafeSearch *safeSearch
@ -111,11 +110,11 @@ type DefaultStorageConfig struct {
// SafeBrowsing is the configuration for the default safe browsing filter.
// It must not be nil.
SafeBrowsing *HashPrefixConfig
SafeBrowsing *HashPrefix
// AdultBlocking is the configuration for the adult content blocking safe
// browsing filter. It must not be nil.
AdultBlocking *HashPrefixConfig
AdultBlocking *HashPrefix
// Now is a function that returns current time.
Now func() (now time.Time)
@ -156,25 +155,8 @@ type DefaultStorageConfig struct {
// NewDefaultStorage returns a new filter storage. c must not be nil.
func NewDefaultStorage(c *DefaultStorageConfig) (s *DefaultStorage, err error) {
// TODO(ameshkov): Consider making configurable.
resolver := agdnet.NewCachingResolver(c.Resolver, 1*timeutil.Day)
safeBrowsing := newHashPrefixFilter(
c.SafeBrowsing,
resolver,
c.ErrColl,
agd.FilterListIDSafeBrowsing,
)
adultBlocking := newHashPrefixFilter(
c.AdultBlocking,
resolver,
c.ErrColl,
agd.FilterListIDAdultBlocking,
)
genSafeSearch := newSafeSearch(&safeSearchConfig{
resolver: resolver,
resolver: c.Resolver,
errColl: c.ErrColl,
list: &agd.FilterList{
URL: c.GeneralSafeSearchRulesURL,
@ -187,7 +169,7 @@ func NewDefaultStorage(c *DefaultStorageConfig) (s *DefaultStorage, err error) {
})
ytSafeSearch := newSafeSearch(&safeSearchConfig{
resolver: resolver,
resolver: c.Resolver,
errColl: c.ErrColl,
list: &agd.FilterList{
URL: c.YoutubeSafeSearchRulesURL,
@ -203,11 +185,11 @@ func NewDefaultStorage(c *DefaultStorageConfig) (s *DefaultStorage, err error) {
mu: &sync.RWMutex{},
url: c.FilterIndexURL,
http: agdhttp.NewClient(&agdhttp.ClientConfig{
Timeout: defaultTimeout,
Timeout: defaultFilterRefreshTimeout,
}),
services: newServiceBlocker(c.BlockedServiceIndexURL, c.ErrColl),
safeBrowsing: safeBrowsing,
adultBlocking: adultBlocking,
safeBrowsing: c.SafeBrowsing,
adultBlocking: c.AdultBlocking,
genSafeSearch: genSafeSearch,
ytSafeSearch: ytSafeSearch,
now: c.Now,
@ -322,7 +304,7 @@ func (s *DefaultStorage) pcBySchedule(sch *agd.ParentalProtectionSchedule) (ok b
func (s *DefaultStorage) safeBrowsingForProfile(
p *agd.Profile,
parentalEnabled bool,
) (safeBrowsing, adultBlocking *hashPrefixFilter) {
) (safeBrowsing, adultBlocking *HashPrefix) {
if p.SafeBrowsingEnabled {
safeBrowsing = s.safeBrowsing
}
@ -359,7 +341,7 @@ func (s *DefaultStorage) safeSearchForProfile(
// in the filtering group. g must not be nil.
func (s *DefaultStorage) safeBrowsingForGroup(
g *agd.FilteringGroup,
) (safeBrowsing, adultBlocking *hashPrefixFilter) {
) (safeBrowsing, adultBlocking *HashPrefix) {
if g.SafeBrowsingEnabled {
safeBrowsing = s.safeBrowsing
}

View File

@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
@ -124,34 +125,11 @@ func TestStorage_FilterFromContext(t *testing.T) {
}
func TestStorage_FilterFromContext_customAllow(t *testing.T) {
// Initialize the hashes file and use it with the storage.
tmpFile, err := os.CreateTemp(t.TempDir(), "")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
_, err = io.WriteString(tmpFile, safeBrowsingHost+"\n")
require.NoError(t, err)
hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{
CachePath: tmpFile.Name(),
RefreshIvl: 1 * time.Hour,
})
require.NoError(t, err)
c := prepareConf(t)
c.SafeBrowsing = &filter.HashPrefixConfig{
Hashes: hashes,
ReplacementHost: safeBrowsingSafeHost,
CacheTTL: 10 * time.Second,
CacheSize: 100,
}
c.ErrColl = &agdtest.ErrorCollector{
errColl := &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, err error) { panic("not implemented") },
}
c.Resolver = &agdtest.Resolver{
resolver := &agdtest.Resolver{
OnLookupIP: func(
_ context.Context,
_ netutil.AddrFamily,
@ -161,6 +139,35 @@ func TestStorage_FilterFromContext_customAllow(t *testing.T) {
},
}
// Initialize the hashes file and use it with the storage.
tmpFile, err := os.CreateTemp(t.TempDir(), "")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
_, err = io.WriteString(tmpFile, safeBrowsingHost+"\n")
require.NoError(t, err)
hashes, err := hashstorage.New(safeBrowsingHost)
require.NoError(t, err)
c := prepareConf(t)
c.SafeBrowsing, err = filter.NewHashPrefix(&filter.HashPrefixConfig{
Hashes: hashes,
ErrColl: errColl,
Resolver: resolver,
ID: agd.FilterListIDSafeBrowsing,
CachePath: tmpFile.Name(),
ReplacementHost: safeBrowsingSafeHost,
Staleness: 1 * time.Hour,
CacheTTL: 10 * time.Second,
CacheSize: 100,
})
require.NoError(t, err)
c.ErrColl = errColl
c.Resolver = resolver
s, err := filter.NewDefaultStorage(c)
require.NoError(t, err)
@ -211,39 +218,11 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) {
// parental protection from 11:00:00 until 12:59:59.
nowTime := time.Date(2021, 1, 1, 12, 0, 0, 0, time.UTC)
// Initialize the hashes file and use it with the storage.
tmpFile, err := os.CreateTemp(t.TempDir(), "")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
_, err = io.WriteString(tmpFile, safeBrowsingHost+"\n")
require.NoError(t, err)
hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{
CachePath: tmpFile.Name(),
RefreshIvl: 1 * time.Hour,
})
require.NoError(t, err)
c := prepareConf(t)
// Use AdultBlocking, because SafeBrowsing is NOT affected by the schedule.
c.AdultBlocking = &filter.HashPrefixConfig{
Hashes: hashes,
ReplacementHost: safeBrowsingSafeHost,
CacheTTL: 10 * time.Second,
CacheSize: 100,
}
c.Now = func() (t time.Time) {
return nowTime
}
c.ErrColl = &agdtest.ErrorCollector{
errColl := &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, err error) { panic("not implemented") },
}
c.Resolver = &agdtest.Resolver{
resolver := &agdtest.Resolver{
OnLookupIP: func(
_ context.Context,
_ netutil.AddrFamily,
@ -253,6 +232,40 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) {
},
}
// Initialize the hashes file and use it with the storage.
tmpFile, err := os.CreateTemp(t.TempDir(), "")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
_, err = io.WriteString(tmpFile, safeBrowsingHost+"\n")
require.NoError(t, err)
hashes, err := hashstorage.New(safeBrowsingHost)
require.NoError(t, err)
c := prepareConf(t)
// Use AdultBlocking, because SafeBrowsing is NOT affected by the schedule.
c.AdultBlocking, err = filter.NewHashPrefix(&filter.HashPrefixConfig{
Hashes: hashes,
ErrColl: errColl,
Resolver: resolver,
ID: agd.FilterListIDAdultBlocking,
CachePath: tmpFile.Name(),
ReplacementHost: safeBrowsingSafeHost,
Staleness: 1 * time.Hour,
CacheTTL: 10 * time.Second,
CacheSize: 100,
})
require.NoError(t, err)
c.Now = func() (t time.Time) {
return nowTime
}
c.ErrColl = errColl
c.Resolver = resolver
s, err := filter.NewDefaultStorage(c)
require.NoError(t, err)

View File

@ -20,9 +20,6 @@ import (
// FileConfig is the file-based GeoIP configuration structure.
type FileConfig struct {
// ErrColl is the error collector that is used to report errors.
ErrColl agd.ErrorCollector
// ASNPath is the path to the GeoIP database of ASNs.
ASNPath string
@ -39,8 +36,6 @@ type FileConfig struct {
// File is a file implementation of [geoip.Interface].
type File struct {
errColl agd.ErrorCollector
// mu protects asn, country, country subnet maps, and caches against
// simultaneous access during a refresh.
mu *sync.RWMutex
@ -77,8 +72,6 @@ type asnSubnets map[agd.ASN]netip.Prefix
// NewFile returns a new GeoIP database that reads information from a file.
func NewFile(c *FileConfig) (f *File, err error) {
f = &File{
errColl: c.ErrColl,
mu: &sync.RWMutex{},
asnPath: c.ASNPath,

View File

@ -1,12 +1,10 @@
package geoip_test
import (
"context"
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/golibs/netutil"
"github.com/stretchr/testify/assert"
@ -14,12 +12,7 @@ import (
)
func TestFile_Data(t *testing.T) {
var ec agd.ErrorCollector = &agdtest.ErrorCollector{
OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
}
conf := &geoip.FileConfig{
ErrColl: ec,
ASNPath: asnPath,
CountryPath: countryPath,
HostCacheSize: 0,
@ -42,12 +35,7 @@ func TestFile_Data(t *testing.T) {
}
func TestFile_Data_hostCache(t *testing.T) {
var ec agd.ErrorCollector = &agdtest.ErrorCollector{
OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
}
conf := &geoip.FileConfig{
ErrColl: ec,
ASNPath: asnPath,
CountryPath: countryPath,
HostCacheSize: 1,
@ -74,12 +62,7 @@ func TestFile_Data_hostCache(t *testing.T) {
}
func TestFile_SubnetByLocation(t *testing.T) {
var ec agd.ErrorCollector = &agdtest.ErrorCollector{
OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
}
conf := &geoip.FileConfig{
ErrColl: ec,
ASNPath: asnPath,
CountryPath: countryPath,
HostCacheSize: 0,
@ -106,12 +89,7 @@ var locSink *agd.Location
var errSink error
func BenchmarkFile_Data(b *testing.B) {
var ec agd.ErrorCollector = &agdtest.ErrorCollector{
OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
}
conf := &geoip.FileConfig{
ErrColl: ec,
ASNPath: asnPath,
CountryPath: countryPath,
HostCacheSize: 0,
@ -162,12 +140,7 @@ func BenchmarkFile_Data(b *testing.B) {
var fileSink *geoip.File
func BenchmarkNewFile(b *testing.B) {
var ec agd.ErrorCollector = &agdtest.ErrorCollector{
OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
}
conf := &geoip.FileConfig{
ErrColl: ec,
ASNPath: asnPath,
CountryPath: countryPath,
HostCacheSize: 0,

View File

@ -11,7 +11,7 @@ var DNSSvcRequestByCountryTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "request_per_country_total",
Namespace: namespace,
Subsystem: subsystemDNSSvc,
Help: "The number of filtered DNS requests labeled by country and continent.",
Help: "The number of processed DNS requests labeled by country and continent.",
}, []string{"continent", "country"})
// DNSSvcRequestByASNTotal is a counter with the total number of queries
@ -20,13 +20,14 @@ var DNSSvcRequestByASNTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "request_per_asn_total",
Namespace: namespace,
Subsystem: subsystemDNSSvc,
Help: "The number of filtered DNS requests labeled by country and ASN.",
Help: "The number of processed DNS requests labeled by country and ASN.",
}, []string{"country", "asn"})
// DNSSvcRequestByFilterTotal is a counter with the total number of queries
// processed labeled by filter. "filter" contains the ID of the filter list
// applied. "anonymous" is "0" if the request is from a AdGuard DNS customer,
// otherwise it is "1".
// processed labeled by filter. Processed could mean that the request was
// blocked or unblocked by a rule from that filter list. "filter" contains
// the ID of the filter list applied. "anonymous" is "0" if the request is
// from a AdGuard DNS customer, otherwise it is "1".
var DNSSvcRequestByFilterTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "request_per_filter_total",
Namespace: namespace,

View File

@ -26,6 +26,8 @@ const (
subsystemQueryLog = "querylog"
subsystemRuleStat = "rulestat"
subsystemTLS = "tls"
subsystemResearch = "research"
subsystemWebSvc = "websvc"
)
// SetUpGauge signals that the server has been started.

View File

@ -0,0 +1,61 @@
package metrics
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
// ResearchRequestsPerCountryTotal counts the total number of queries per
// country from anonymous users.
var ResearchRequestsPerCountryTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "requests_per_country_total",
Namespace: namespace,
Subsystem: subsystemResearch,
Help: "The total number of DNS queries per country from anonymous users.",
}, []string{"country"})
// ResearchBlockedRequestsPerCountryTotal counts the number of blocked queries
// per country from anonymous users.
var ResearchBlockedRequestsPerCountryTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "blocked_per_country_total",
Namespace: namespace,
Subsystem: subsystemResearch,
Help: "The number of blocked DNS queries per country from anonymous users.",
}, []string{"filter", "country"})
// ReportResearchMetrics reports metrics to prometheus that we may need to
// conduct researches.
//
// TODO(ameshkov): use [agd.Profile] arg when recursive dependency is resolved.
func ReportResearchMetrics(
anonymous bool,
filteringEnabled bool,
asn string,
ctry string,
filterID string,
blocked bool,
) {
// The current research metrics only count queries that come to public
// DNS servers where filtering is enabled.
if !filteringEnabled || !anonymous {
return
}
// Ignore AdGuard ASN specifically in order to avoid counting queries that
// come from the monitoring. This part is ugly, but since these metrics
// are a one-time deal, this is acceptable.
//
// TODO(ameshkov): think of a better way later if we need to do that again.
if asn == "212772" {
return
}
if blocked {
ResearchBlockedRequestsPerCountryTotal.WithLabelValues(
filterID,
ctry,
).Inc()
}
ResearchRequestsPerCountryTotal.WithLabelValues(ctry).Inc()
}

View File

@ -86,7 +86,7 @@ func (c *userCounter) record(now time.Time, ip netip.Addr, syncUpdate bool) {
prevMinuteCounter := c.currentMinuteCounter
c.currentMinute = minuteOfTheDay
c.currentMinuteCounter = hyperloglog.New()
c.currentMinuteCounter = newHyperLogLog()
// If this is the first iteration and prevMinute is -1, don't update the
// counters, since there are none.
@ -123,7 +123,7 @@ func (c *userCounter) updateCounters(prevMinute int, prevCounter *hyperloglog.Sk
// estimate uses HyperLogLog counters to estimate the hourly and daily users
// count, starting with the minute of the day m.
func (c *userCounter) estimate(m int) (hourly, daily uint64) {
hourlyCounter, dailyCounter := hyperloglog.New(), hyperloglog.New()
hourlyCounter, dailyCounter := newHyperLogLog(), newHyperLogLog()
// Go through all minutes in a day while decreasing the current minute m.
// Decreasing m, as opposed to increasing it or using i as the minute, is
@ -168,6 +168,11 @@ func decrMod(n, m int) (res int) {
return n - 1
}
// newHyperLogLog creates a new instance of hyperloglog.Sketch.
func newHyperLogLog() (sk *hyperloglog.Sketch) {
return hyperloglog.New16()
}
// defaultUserCounter is the main user statistics counter.
var defaultUserCounter = newUserCounter()

View File

@ -0,0 +1,69 @@
package metrics
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
var (
webSvcRequestsTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "websvc_requests_total",
Namespace: namespace,
Subsystem: subsystemWebSvc,
Help: "The number of DNS requests for websvc.",
}, []string{"kind"})
// WebSvcError404RequestsTotal is a counter with total number of
// requests with error 404.
WebSvcError404RequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
"kind": "error404",
})
// WebSvcError500RequestsTotal is a counter with total number of
// requests with error 500.
WebSvcError500RequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
"kind": "error500",
})
// WebSvcStaticContentRequestsTotal is a counter with total number of
// requests for static content.
WebSvcStaticContentRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
"kind": "static_content",
})
// WebSvcDNSCheckTestRequestsTotal is a counter with total number of
// requests for dnscheck_test.
WebSvcDNSCheckTestRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
"kind": "dnscheck_test",
})
// WebSvcRobotsTxtRequestsTotal is a counter with total number of
// requests for robots_txt.
WebSvcRobotsTxtRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
"kind": "robots_txt",
})
// WebSvcRootRedirectRequestsTotal is a counter with total number of
// root redirected requests.
WebSvcRootRedirectRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
"kind": "root_redirect",
})
// WebSvcLinkedIPProxyRequestsTotal is a counter with total number of
// requests with linked ip.
WebSvcLinkedIPProxyRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
"kind": "linkip",
})
// WebSvcAdultBlockingPageRequestsTotal is a counter with total number
// of requests for adult blocking page.
WebSvcAdultBlockingPageRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
"kind": "adult_blocking_page",
})
// WebSvcSafeBrowsingPageRequestsTotal is a counter with total number
// of requests for safe browsing page.
WebSvcSafeBrowsingPageRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
"kind": "safe_browsing_page",
})
)

View File

@ -13,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/mathutil"
)
// FileSystemConfig is the configuration of the file system query log.
@ -71,11 +72,6 @@ func (l *FileSystem) Write(_ context.Context, e *Entry) (err error) {
metrics.QueryLogItemsCount.Inc()
}()
var dnssec uint8 = 0
if e.DNSSEC {
dnssec = 1
}
entBuf := l.bufferPool.Get().(*entryBuffer)
defer l.bufferPool.Put(entBuf)
entBuf.buf.Reset()
@ -94,7 +90,7 @@ func (l *FileSystem) Write(_ context.Context, e *Entry) (err error) {
ClientASN: e.ClientASN,
Elapsed: e.Elapsed,
RequestType: e.RequestType,
DNSSEC: dnssec,
DNSSEC: mathutil.BoolToNumber[uint8](e.DNSSEC),
Protocol: e.Protocol,
ResultCode: c,
ResponseCode: e.ResponseCode,

View File

@ -7,6 +7,7 @@ import (
"os"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
@ -58,12 +59,16 @@ func (svc *Service) processRec(
body = svc.error404
respHdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML)
}
metrics.WebSvcError404RequestsTotal.Inc()
case http.StatusInternalServerError:
action = "writing 500"
if len(svc.error500) != 0 {
body = svc.error500
respHdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML)
}
metrics.WebSvcError500RequestsTotal.Inc()
default:
action = "writing response"
for k, v := range rec.Header() {
@ -87,6 +92,8 @@ func (svc *Service) serveHTTP(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/dnscheck/test":
svc.dnsCheck.ServeHTTP(w, r)
metrics.WebSvcDNSCheckTestRequestsTotal.Inc()
case "/robots.txt":
serveRobotsDisallow(w.Header(), w, "handler")
case "/":
@ -94,6 +101,8 @@ func (svc *Service) serveHTTP(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
} else {
http.Redirect(w, r, svc.rootRedirectURL, http.StatusFound)
metrics.WebSvcRootRedirectRequestsTotal.Inc()
}
default:
http.NotFound(w, r)
@ -101,7 +110,7 @@ func (svc *Service) serveHTTP(w http.ResponseWriter, r *http.Request) {
}
// safeBrowsingHandler returns an HTTP handler serving the block page from the
// blockPagePath. name is used for logging.
// blockPagePath. name is used for logging and metrics.
func safeBrowsingHandler(name string, blockPage []byte) (h http.Handler) {
f := func(w http.ResponseWriter, r *http.Request) {
hdr := w.Header()
@ -122,6 +131,15 @@ func safeBrowsingHandler(name string, blockPage []byte) (h http.Handler) {
if err != nil {
logErrorByType(err, "websvc: %s: writing response: %s", name, err)
}
switch name {
case adultBlockingName:
metrics.WebSvcAdultBlockingPageRequestsTotal.Inc()
case safeBrowsingName:
metrics.WebSvcSafeBrowsingPageRequestsTotal.Inc()
default:
// Go on.
}
}
}
@ -140,10 +158,11 @@ func (sc StaticContent) serveHTTP(w http.ResponseWriter, r *http.Request) (serve
return false
}
if f.AllowOrigin != "" {
w.Header().Set(agdhttp.HdrNameAccessControlAllowOrigin, f.AllowOrigin)
h := w.Header()
for k, v := range f.Headers {
h[k] = v
}
w.Header().Set(agdhttp.HdrNameContentType, f.ContentType)
w.WriteHeader(http.StatusOK)
_, err := w.Write(f.Content)
@ -151,16 +170,15 @@ func (sc StaticContent) serveHTTP(w http.ResponseWriter, r *http.Request) (serve
logErrorByType(err, "websvc: static content: writing %s: %s", p, err)
}
metrics.WebSvcStaticContentRequestsTotal.Inc()
return true
}
// StaticFile is a single file in a StaticFS.
type StaticFile struct {
// AllowOrigin is the value for the HTTP Access-Control-Allow-Origin header.
AllowOrigin string
// ContentType is the value for the HTTP Content-Type header.
ContentType string
// Headers contains headers of the HTTP response.
Headers http.Header
// Content is the file content.
Content []byte
@ -174,6 +192,8 @@ func serveRobotsDisallow(hdr http.Header, w http.ResponseWriter, name string) {
if err != nil {
logErrorByType(err, "websvc: %s: writing response: %s", name, err)
}
metrics.WebSvcRobotsTxtRequestsTotal.Inc()
}
// logErrorByType writes err to the error log, unless err is a network error or

View File

@ -31,8 +31,10 @@ func TestService_ServeHTTP(t *testing.T) {
staticContent := map[string]*websvc.StaticFile{
"/favicon.ico": {
ContentType: "image/x-icon",
Content: []byte{},
Headers: http.Header{
agdhttp.HdrNameContentType: []string{"image/x-icon"},
},
},
}
@ -56,22 +58,33 @@ func TestService_ServeHTTP(t *testing.T) {
})
// DNSCheck path.
assertPathResponse(t, svc, "/dnscheck/test", http.StatusOK)
assertResponse(t, svc, "/dnscheck/test", http.StatusOK)
// Static content path.
assertPathResponse(t, svc, "/favicon.ico", http.StatusOK)
// Static content path with headers.
h := http.Header{
agdhttp.HdrNameContentType: []string{"image/x-icon"},
agdhttp.HdrNameServer: []string{"AdGuardDNS/"},
}
assertResponseWithHeaders(t, svc, "/favicon.ico", http.StatusOK, h)
// Robots path.
assertPathResponse(t, svc, "/robots.txt", http.StatusOK)
assertResponse(t, svc, "/robots.txt", http.StatusOK)
// Root redirect path.
assertPathResponse(t, svc, "/", http.StatusFound)
assertResponse(t, svc, "/", http.StatusFound)
// Other path.
assertPathResponse(t, svc, "/other", http.StatusNotFound)
assertResponse(t, svc, "/other", http.StatusNotFound)
}
func assertPathResponse(t *testing.T, svc *websvc.Service, path string, statusCode int) {
// assertResponse is a helper function that checks status code of HTTP
// response.
func assertResponse(
t *testing.T,
svc *websvc.Service,
path string,
statusCode int,
) (rw *httptest.ResponseRecorder) {
t.Helper()
r := httptest.NewRequest(http.MethodGet, (&url.URL{
@ -79,9 +92,27 @@ func assertPathResponse(t *testing.T, svc *websvc.Service, path string, statusCo
Host: "127.0.0.1",
Path: path,
}).String(), strings.NewReader(""))
rw := httptest.NewRecorder()
rw = httptest.NewRecorder()
svc.ServeHTTP(rw, r)
assert.Equal(t, statusCode, rw.Code)
assert.Equal(t, agdhttp.UserAgent(), rw.Header().Get(agdhttp.HdrNameServer))
return rw
}
// assertResponseWithHeaders is a helper function that checks status code and
// headers of HTTP response.
func assertResponseWithHeaders(
t *testing.T,
svc *websvc.Service,
path string,
statusCode int,
header http.Header,
) {
t.Helper()
rw := assertResponse(t, svc, path, statusCode)
assert.Equal(t, header, rw.Header())
}

View File

@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
@ -148,6 +149,8 @@ func (prx *linkedIPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Debug("%s: proxying %s %s: req %s", prx.logPrefix, m, p, reqID)
prx.httpProxy.ServeHTTP(w, r)
metrics.WebSvcLinkedIPProxyRequestsTotal.Inc()
} else if r.URL.Path == "/robots.txt" {
serveRobotsDisallow(respHdr, w, prx.logPrefix)
} else {

View File

@ -118,8 +118,8 @@ func New(c *Config) (svc *Service) {
error404: c.Error404,
error500: c.Error500,
adultBlocking: blockPageServers(c.AdultBlocking, "adult blocking", c.Timeout),
safeBrowsing: blockPageServers(c.SafeBrowsing, "safe browsing", c.Timeout),
adultBlocking: blockPageServers(c.AdultBlocking, adultBlockingName, c.Timeout),
safeBrowsing: blockPageServers(c.SafeBrowsing, safeBrowsingName, c.Timeout),
}
if c.RootRedirectURL != nil {
@ -164,6 +164,12 @@ func New(c *Config) (svc *Service) {
return svc
}
// Names for safeBrowsingHandler for logging and metrics.
const (
safeBrowsingName = "safe browsing"
adultBlockingName = "adult blocking"
)
// blockPageServers is a helper function that converts a *BlockPageServer into
// HTTP servers.
func blockPageServers(

View File

@ -117,7 +117,9 @@ underscores() {
git ls-files '*_*.go'\
| grep -F\
-e '_generate.go'\
-e '_linux.go'\
-e '_noreuseport.go'\
-e '_others.go'\
-e '_reuseport.go'\
-e '_test.go'\
-e '_unix.go'\