diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3763adc..a222fcf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,99 @@ The format is **not** based on [Keep a Changelog][kec], since the project
+## AGDNS-916 / Build 456
+
+ * `ratelimit` now defines rate of requests per second for IPv4 and IPv6
+ addresses separately. So replace this:
+
+ ```yaml
+ ratelimit:
+ rps: 30
+ ipv4_subnet_key_len: 24
+ ipv6_subnet_key_len: 48
+ ```
+
+ with this:
+
+ ```yaml
+ ratelimit:
+ ipv4:
+ rps: 30
+ subnet_key_len: 24
+ ipv6:
+ rps: 300
+ subnet_key_len: 48
+ ```
+
+
+
+## AGDNS-907 / Build 449
+
+ * The objects within the `filtering_groups` have a new property,
+ `block_firefox_canary`. So replace this:
+
+ ```yaml
+ filtering_groups:
+ -
+ id: default
+ # …
+ ```
+
+ with this:
+
+ ```yaml
+ filtering_groups:
+ -
+ id: default
+ # …
+ block_firefox_canary: true
+ ```
+
+ The recommended default value is `true`.
+
+
+
+## AGDNS-1308 / Build 447
+
+ * There is now a new env variable `RESEARCH_METRICS` that controls whether
+ collecting research metrics is enabled or not. Also, the first research
+ metric is added: `dns_research_blocked_per_country_total`, it counts the
+ number of blocked requests per country. Its default value is `0`, i.e.
+ research metrics collection is disabled by default.
+
+
+
+## AGDNS-1051 / Build 443
+
+ * There are two changes in the keys of the `static_content` map. Firstly,
+ properties `allow_origin` and `content_type` are removed. Secondly, a new
+ property, called `headers`, is added. So replace this:
+
+ ```yaml
+ static_content:
+ '/favicon.ico':
+ # …
+ allow_origin: '*'
+ content_type: 'image/x-icon'
+ ```
+
+ with this:
+
+ ```yaml
+ static_content:
+ '/favicon.ico':
+ # …
+ headers:
+ 'Access-Control-Allow-Origin':
+ - '*'
+ 'Content-Type':
+ - 'image/x-icon'
+ ```
+
+ Adjust or add the values, if necessary.
+
+
+
## AGDNS-1278 / Build 423
* The object `filters` has two new properties, `rule_list_cache_size` and
diff --git a/config.dist.yml b/config.dist.yml
index 14a99f2..9550cfc 100644
--- a/config.dist.yml
+++ b/config.dist.yml
@@ -8,8 +8,21 @@ ratelimit:
refuseany: true
# If response is larger than this, it is counted as several responses.
response_size_estimate: 1KB
- # Rate of requests per second for one subnet.
rps: 30
+ # Rate limit options for IPv4 addresses.
+ ipv4:
+ # Rate of requests per second for one subnet for IPv4 addresses.
+ rps: 30
+ # The lengths of the subnet prefixes used to calculate rate limiter
+ # bucket keys for IPv4 addresses.
+ subnet_key_len: 24
+ # Rate limit options for IPv6 addresses.
+ ipv6:
+ # Rate of requests per second for one subnet for IPv6 addresses.
+ rps: 300
+ # The lengths of the subnet prefixes used to calculate rate limiter
+ # bucket keys for IPv6 addresses.
+ subnet_key_len: 48
# The time during which to count the number of times a client has hit the
# rate limit for a back off.
#
@@ -21,10 +34,6 @@ ratelimit:
# How much a client that has hit the rate limit too often stays in the back
# off.
back_off_duration: 30m
- # The lengths of the subnet prefixes used to calculate rate limiter bucket
- # keys for IPv4 and IPv6 addresses correspondingly.
- ipv4_subnet_key_len: 24
- ipv6_subnet_key_len: 48
# Configuration for the allowlist.
allowlist:
@@ -156,9 +165,12 @@ web:
# servers. Paths must not cross the ones used by the DNS-over-HTTPS server.
static_content:
'/favicon.ico':
- allow_origin: '*'
- content_type: 'image/x-icon'
content: ''
+ headers:
+ Access-Control-Allow-Origin:
+ - '*'
+ Content-Type:
+ - 'image/x-icon'
# If not defined, AdGuard DNS will respond with a 404 page to all such
# requests.
root_redirect_url: 'https://adguard-dns.com'
@@ -221,6 +233,7 @@ filtering_groups:
safe_browsing:
enabled: true
block_private_relay: false
+ block_firefox_canary: true
- id: 'family'
parental:
enabled: true
@@ -234,6 +247,7 @@ filtering_groups:
safe_browsing:
enabled: true
block_private_relay: false
+ block_firefox_canary: true
- id: 'non_filtering'
rule_lists:
enabled: false
@@ -242,6 +256,7 @@ filtering_groups:
safe_browsing:
enabled: false
block_private_relay: false
+ block_firefox_canary: true
# Server groups and servers.
server_groups:
diff --git a/doc/configuration.md b/doc/configuration.md
index e68825a..49fdc69 100644
--- a/doc/configuration.md
+++ b/doc/configuration.md
@@ -85,11 +85,23 @@ The `ratelimit` object has the following properties:
**Example:** `30m`.
- * `rps`:
- The rate of requests per second for one subnet. Requests above this are
- counted in the backoff count.
+ * `ipv4`:
+ The ipv4 configuration object. It has the following fields:
- **Example:** `30`.
+ * `rps`:
+ The rate of requests per second for one subnet. Requests above this are
+ counted in the backoff count.
+
+ **Example:** `30`.
+
+ * `ipv4-subnet_key_len`:
+ The length of the subnet prefix used to calculate rate limiter bucket keys.
+
+ **Example:** `24`.
+
+ * `ipv6`:
+ The `ipv6` configuration object has the same properties as the `ipv4` one
+ above.
* `back_off_count`:
Maximum number of requests a client can make above the RPS within
@@ -120,22 +132,11 @@ The `ratelimit` object has the following properties:
**Example:** `30s`.
- * `ipv4_subnet_key_len`:
- The length of the subnet prefix used to calculate rate limiter bucket keys
- for IPv4 addresses.
-
- **Example:** `24`.
-
- * `ipv6_subnet_key_len`:
- Same as `ipv4_subnet_key_len` above but for IPv6 addresses.
-
- **Example:** `48`.
-
-For example, if `back_off_period` is `1m`, `back_off_count` is `10`, and `rps`
-is `5`, a client (meaning all IP addresses within the subnet defined by
-`ipv4_subnet_key_len` and `ipv6_subnet_key_len`) that made 15 requests in one
-second or 6 requests (one above `rps`) every second for 10 seconds within one
-minute, the client is blocked for `back_off_duration`.
+For example, if `back_off_period` is `1m`, `back_off_count` is `10`, and
+`ipv4-rps` is `5`, a client (meaning all IP addresses within the subnet defined
+by `ipv4-subnet_key_len`) that made 15 requests in one second or 6 requests
+(one above `rps`) every second for 10 seconds within one minute, the client is
+blocked for `back_off_duration`.
[env-consul_allowlist_url]: environment.md#CONSUL_ALLOWLIST_URL
@@ -454,13 +455,17 @@ The optional `web` object has the following properties:
`safe_browsing` and `adult_blocking` servers. Paths must not duplicate the
ones used by the DNS-over-HTTPS server.
+ Inside of the `headers` map, the header `Content-Type` is required.
+
**Property example:**
```yaml
- 'static_content':
+ static_content:
'/favicon.ico':
- 'content_type': 'image/x-icon'
- 'content': 'base64content'
+ content: 'base64content'
+ headers:
+ 'Content-Type':
+ - 'image/x-icon'
```
* `root_redirect_url`:
@@ -647,6 +652,12 @@ The items of the `filtering_groups` array have the following properties:
**Example:** `false`.
+ * `block_firefox_canary`:
+ If true, Firefox canary domain queries are blocked for requests using this
+ filtering group.
+
+ **Example:** `true`.
+
## Server groups
diff --git a/doc/environment.md b/doc/environment.md
index 48e2b6b..ffc39e3 100644
--- a/doc/environment.md
+++ b/doc/environment.md
@@ -22,6 +22,7 @@ sensitive configuration. All other configuration is stored in the
* [`LISTEN_PORT`](#LISTEN_PORT)
* [`LOG_TIMESTAMP`](#LOG_TIMESTAMP)
* [`QUERYLOG_PATH`](#QUERYLOG_PATH)
+ * [`RESEARCH_METRICS`](#RESEARCH_METRICS)
* [`RULESTAT_URL`](#RULESTAT_URL)
* [`SENTRY_DSN`](#SENTRY_DSN)
* [`SSL_KEY_LOG_FILE`](#SSL_KEY_LOG_FILE)
@@ -198,6 +199,15 @@ The path to the file into which the query log is going to be written.
+## `RESEARCH_METRICS`
+
+If `1`, enable collection of a set of special prometheus metrics (prefix is
+`dns_research`). If `0`, disable collection of those metrics.
+
+**Default:** `0`.
+
+
+
## `RULESTAT_URL`
The URL to send filtering rule list statistics to. If empty or unset, the
diff --git a/go.mod b/go.mod
index d26f926..65ee670 100644
--- a/go.mod
+++ b/go.mod
@@ -4,7 +4,7 @@ go 1.19
require (
github.com/AdguardTeam/AdGuardDNS/internal/dnsserver v0.100.0
- github.com/AdguardTeam/golibs v0.11.3
+ github.com/AdguardTeam/golibs v0.11.4
github.com/AdguardTeam/urlfilter v0.16.1
github.com/ameshkov/dnscrypt/v2 v2.2.5
github.com/axiomhq/hyperloglog v0.0.0-20220105174342-98591331716a
diff --git a/go.sum b/go.sum
index a2cbbd8..2ba411c 100644
--- a/go.sum
+++ b/go.sum
@@ -33,8 +33,8 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
-github.com/AdguardTeam/golibs v0.11.3 h1:Oif+REq2WLycQ2Xm3ZPmJdfftptss0HbGWbxdFaC310=
-github.com/AdguardTeam/golibs v0.11.3/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA=
+github.com/AdguardTeam/golibs v0.11.4 h1:IltyvxwCTN+xxJF5sh6VadF8Zfbf8elgCm9dgijSVzM=
+github.com/AdguardTeam/golibs v0.11.4/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA=
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
github.com/AdguardTeam/urlfilter v0.16.1 h1:ZPi0rjqo8cQf2FVdzo6cqumNoHZx2KPXj2yZa1A5BBw=
github.com/AdguardTeam/urlfilter v0.16.1/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI=
diff --git a/go.work.sum b/go.work.sum
index 444de9c..918589d 100644
--- a/go.work.sum
+++ b/go.work.sum
@@ -9,7 +9,11 @@ dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0 h1:SPOUaucgtVl
dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412 h1:GvWw74lx5noHocd+f6HBMXK6DuggBB1dhVkuGZbv7qM=
dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c h1:ivON6cwHK1OH26MZyWDCnbTRZZf0IhNsENoNAKFS1g4=
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999 h1:OR8VhtwhcAI3U48/rzBsVOuHi0zDPzYI1xASVcdSgR8=
+github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
+github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
github.com/AdguardTeam/golibs v0.10.7/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
+github.com/AdguardTeam/golibs v0.11.4 h1:IltyvxwCTN+xxJF5sh6VadF8Zfbf8elgCm9dgijSVzM=
+github.com/AdguardTeam/golibs v0.11.4/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA=
github.com/AdguardTeam/gomitmproxy v0.2.0 h1:rvCOf17pd1/CnMyMQW891zrEiIQBpQ8cIGjKN9pinUU=
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v1.2.0 h1:Rt8g24XnyGTyglgET/PRUNlrUeu9F5L+7FilkXfZgs0=
@@ -74,7 +78,7 @@ github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18h
github.com/golang/lint v0.0.0-20180702182130-06c8688daad7 h1:2hRPrmiwPrp3fQX967rNJIhQPtiGXdlQWAxKbKw3VHA=
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
-github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY=
github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=
github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw=
@@ -198,41 +202,36 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh
github.com/yosssi/ace v0.0.5 h1:tUkIP/BLdKqrlrPwcmH0shwEEhTRHoGnc1wFIWmaBUA=
github.com/yuin/goldmark v1.4.1 h1:/vn0k+RBvwlxEmP5E7SZMqNxPhfMVFEJiykr15/0XKM=
github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE=
-github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.opencensus.io v0.22.4 h1:LYy1Hy3MJdrCdMwwzxA/dRok4ejH+RwNGbuoD9fCjto=
go4.org v0.0.0-20180809161055-417644f6feb5 h1:+hE86LblG4AyDgwMCLTE6FOlM9+qjHSYS+rKqxUVdsM=
golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d h1:E2M5QgjZ/Jg+ObCQAudsXxuTsLj7Nl5RV/lZcQZmKSo=
-golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b h1:+qEpEAPhDZ1o0x3tHzZTQDArnOixOzGD9HUJfcg0mb4=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k=
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 h1:4+4C/Iv2U4fMZBiMCc98MG1In4gJY5YRhtpDNeDeHWs=
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
-golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
+golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220516155154-20f960328961/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
-golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
+golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b h1:clP8eMhB30EHdc0bd2Twtq6kgU7yl5ub2cQLSdrv1Dg=
golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852 h1:xYq6+9AtI+xP3M4r0N1hCkHrInHDBohhquRgx9Kk6gI=
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 h1:v6hYoSR9T5oet+pMXwUWkbiVqx/63mlHjefrHmxwfeY=
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw=
-golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc=
golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI=
+golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA=
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
-golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
+golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.30.0 h1:yfrXXP61wVuLb0vBcG6qaOoIoqYEzOQS8jum51jkv2w=
@@ -246,6 +245,7 @@ gopkg.in/errgo.v2 v2.1.0 h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8=
gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
+gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
grpc.go4.org v0.0.0-20170609214715-11d0a25b4919 h1:tmXTu+dfa+d9Evp8NpJdgOy6+rt8/x4yG7qPBrtNfLY=
honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8=
rsc.io/binaryregexp v0.2.0 h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE=
diff --git a/internal/agd/dns.go b/internal/agd/dns.go
index 39fa7cd..65839af 100644
--- a/internal/agd/dns.go
+++ b/internal/agd/dns.go
@@ -1,11 +1,6 @@
package agd
-import (
- "context"
- "net"
-
- "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
-)
+import "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
// Common DNS Message Constants, Types, And Utilities
@@ -27,10 +22,3 @@ const (
ProtoDoT = dnsserver.ProtoDoT
ProtoDNSCrypt = dnsserver.ProtoDNSCrypt
)
-
-// Resolver is the DNS resolver interface.
-//
-// See go doc net.Resolver.
-type Resolver interface {
- LookupIP(ctx context.Context, network, host string) (ips []net.IP, err error)
-}
diff --git a/internal/agd/filterlist.go b/internal/agd/filterlist.go
index c5b0213..0fbe429 100644
--- a/internal/agd/filterlist.go
+++ b/internal/agd/filterlist.go
@@ -146,6 +146,10 @@ type FilteringGroup struct {
// BlockPrivateRelay shows if Apple Private Relay is blocked for requests
// using this filtering group.
BlockPrivateRelay bool
+
+ // BlockFirefoxCanary shows if Firefox canary domain is blocked for
+ // requests using this filtering group.
+ BlockFirefoxCanary bool
}
// FilteringGroupID is the ID of a filter group. It is an opaque string.
diff --git a/internal/agd/profile.go b/internal/agd/profile.go
index 58d8e42..672f2c6 100644
--- a/internal/agd/profile.go
+++ b/internal/agd/profile.go
@@ -73,6 +73,10 @@ type Profile struct {
// BlockPrivateRelay shows if Apple Private Relay queries are blocked for
// requests from all devices in this profile.
BlockPrivateRelay bool
+
+ // BlockFirefoxCanary shows if Firefox canary domain is blocked for
+ // requests from all devices in this profile.
+ BlockFirefoxCanary bool
}
// ProfileID is the ID of a profile. It is an opaque string.
diff --git a/internal/agd/profiledb.go b/internal/agd/profiledb.go
index 41899ae..39c3b10 100644
--- a/internal/agd/profiledb.go
+++ b/internal/agd/profiledb.go
@@ -17,6 +17,8 @@ import (
// Data Storage
// ProfileDB is the local database of profiles and other data.
+//
+// TODO(a.garipov): move this logic to the backend package.
type ProfileDB interface {
ProfileByDeviceID(ctx context.Context, id DeviceID) (p *Profile, d *Device, err error)
ProfileByIP(ctx context.Context, ip netip.Addr) (p *Profile, d *Device, err error)
diff --git a/internal/agdio/agdio.go b/internal/agdio/agdio.go
index 4799202..bfd3300 100644
--- a/internal/agdio/agdio.go
+++ b/internal/agdio/agdio.go
@@ -7,6 +7,8 @@ package agdio
import (
"fmt"
"io"
+
+ "github.com/AdguardTeam/golibs/mathutil"
)
// LimitError is returned when the Limit is reached.
@@ -35,9 +37,8 @@ func (lr *limitedReader) Read(p []byte) (n int, err error) {
}
}
- if int64(len(p)) > lr.n {
- p = p[0:lr.n]
- }
+ l := mathutil.Min(int64(len(p)), lr.n)
+ p = p[:l]
n, err = lr.r.Read(p)
lr.n -= int64(n)
diff --git a/internal/backend/profiledb.go b/internal/backend/profiledb.go
index c0dde37..4e089b2 100644
--- a/internal/backend/profiledb.go
+++ b/internal/backend/profiledb.go
@@ -165,6 +165,31 @@ type v1SettingsRespSettings struct {
FilteringEnabled bool `json:"filtering_enabled"`
Deleted bool `json:"deleted"`
BlockPrivateRelay bool `json:"block_private_relay"`
+ BlockFirefoxCanary bool `json:"block_firefox_canary"`
+}
+
+// type check
+var _ json.Unmarshaler = (*v1SettingsRespSettings)(nil)
+
+// UnmarshalJSON implements the [json.Unmarshaler] interface for
+// *v1SettingsRespSettings. It puts default value into BlockFirefoxCanary
+// field while it is not implemented on the backend side.
+//
+// TODO(a.garipov): Remove once the backend starts to always send it.
+func (rs *v1SettingsRespSettings) UnmarshalJSON(b []byte) (err error) {
+ type defaultDec v1SettingsRespSettings
+
+ s := defaultDec{
+ BlockFirefoxCanary: true,
+ }
+
+ if err = json.Unmarshal(b, &s); err != nil {
+ return err
+ }
+
+ *rs = v1SettingsRespSettings(s)
+
+ return nil
}
// v1SettingsRespRuleLists is the structure for decoding filtering rule lists
@@ -414,10 +439,11 @@ const maxFltRespTTL = 1 * time.Hour
func fltRespTTLToInternal(respTTL uint32) (ttl time.Duration, err error) {
ttl = time.Duration(respTTL) * time.Second
if ttl > maxFltRespTTL {
- return ttl, fmt.Errorf("too high: got %d, max %d", respTTL, maxFltRespTTL)
+ ttl = maxFltRespTTL
+ err = fmt.Errorf("too high: got %s, max %s", ttl, maxFltRespTTL)
}
- return ttl, nil
+ return ttl, err
}
// toInternal converts r to an [agd.DSProfilesResponse] instance.
@@ -461,6 +487,9 @@ func (r *v1SettingsResp) toInternal(
reportf(ctx, errColl, "settings at index %d: filtered resp ttl: %w", i, err)
// Go on and use the fixed value.
+ //
+ // TODO(ameshkov, a.garipov): Consider continuing, like with all
+ // other validation errors.
}
sbEnabled := s.SafeBrowsing != nil && s.SafeBrowsing.Enabled
@@ -479,6 +508,7 @@ func (r *v1SettingsResp) toInternal(
QueryLogEnabled: s.QueryLogEnabled,
Deleted: s.Deleted,
BlockPrivateRelay: s.BlockPrivateRelay,
+ BlockFirefoxCanary: s.BlockFirefoxCanary,
})
}
diff --git a/internal/backend/profiledb_test.go b/internal/backend/profiledb_test.go
index 268de3a..2e16fa0 100644
--- a/internal/backend/profiledb_test.go
+++ b/internal/backend/profiledb_test.go
@@ -131,6 +131,7 @@ func testProfileResp(t *testing.T) *agd.PSProfilesResponse {
QueryLogEnabled: true,
Deleted: false,
BlockPrivateRelay: true,
+ BlockFirefoxCanary: true,
}, {
Parental: wantParental,
ID: "83f3ea8f",
@@ -162,6 +163,7 @@ func testProfileResp(t *testing.T) *agd.PSProfilesResponse {
QueryLogEnabled: true,
Deleted: true,
BlockPrivateRelay: false,
+ BlockFirefoxCanary: false,
}},
}
diff --git a/internal/backend/testdata/profiles.json b/internal/backend/testdata/profiles.json
index a44e55d..11af374 100644
--- a/internal/backend/testdata/profiles.json
+++ b/internal/backend/testdata/profiles.json
@@ -11,6 +11,7 @@
},
"deleted": false,
"block_private_relay": true,
+ "block_firefox_canary": true,
"devices": [
{
"id": "118ffe93",
@@ -43,6 +44,7 @@
},
"deleted": true,
"block_private_relay": false,
+ "block_firefox_canary": false,
"devices": [
{
"id": "0d7724fa",
diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go
index 46e3b92..064b47c 100644
--- a/internal/cmd/cmd.go
+++ b/internal/cmd/cmd.go
@@ -28,6 +28,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/websvc"
"github.com/AdguardTeam/golibs/log"
+ "github.com/AdguardTeam/golibs/timeutil"
)
// Main is the entry point of application.
@@ -89,52 +90,75 @@ func Main() {
go func() {
defer geoIPMu.Unlock()
- geoIP, geoIPErr = envs.geoIP(c.GeoIP, errColl)
+ geoIP, geoIPErr = envs.geoIP(c.GeoIP)
}()
- // Safe Browsing Hosts
+ // Safe-browsing and adult-blocking filters
+
+ // TODO(ameshkov): Consider making configurable.
+ filteringResolver := agdnet.NewCachingResolver(
+ agdnet.DefaultResolver{},
+ 1*timeutil.Day,
+ )
err = os.MkdirAll(envs.FilterCachePath, agd.DefaultDirPerm)
check(err)
- safeBrowsingConf := c.SafeBrowsing.toInternal(
+ safeBrowsingConf, err := c.SafeBrowsing.toInternal(
+ errColl,
+ filteringResolver,
agd.FilterListIDSafeBrowsing,
envs.FilterCachePath,
- errColl,
)
- safeBrowsingHashes, err := filter.NewHashStorage(safeBrowsingConf)
check(err)
- err = safeBrowsingHashes.Start()
+ safeBrowsingFilter, err := filter.NewHashPrefix(safeBrowsingConf)
check(err)
- adultBlockingConf := c.AdultBlocking.toInternal(
+ safeBrowsingUpd := agd.NewRefreshWorker(&agd.RefreshWorkerConfig{
+ Context: ctxWithDefaultTimeout,
+ Refresher: safeBrowsingFilter,
+ ErrColl: errColl,
+ Name: string(agd.FilterListIDSafeBrowsing),
+ Interval: safeBrowsingConf.Staleness,
+ RefreshOnShutdown: false,
+ RoutineLogsAreDebug: false,
+ })
+ err = safeBrowsingUpd.Start()
+ check(err)
+
+ adultBlockingConf, err := c.AdultBlocking.toInternal(
+ errColl,
+ filteringResolver,
agd.FilterListIDAdultBlocking,
envs.FilterCachePath,
- errColl,
)
- adultBlockingHashes, err := filter.NewHashStorage(adultBlockingConf)
check(err)
- err = adultBlockingHashes.Start()
+ adultBlockingFilter, err := filter.NewHashPrefix(adultBlockingConf)
check(err)
- // Filters And Filtering Groups
+ adultBlockingUpd := agd.NewRefreshWorker(&agd.RefreshWorkerConfig{
+ Context: ctxWithDefaultTimeout,
+ Refresher: adultBlockingFilter,
+ ErrColl: errColl,
+ Name: string(agd.FilterListIDAdultBlocking),
+ Interval: adultBlockingConf.Staleness,
+ RefreshOnShutdown: false,
+ RoutineLogsAreDebug: false,
+ })
+ err = adultBlockingUpd.Start()
+ check(err)
- fltStrgConf := c.Filters.toInternal(errColl, envs)
- fltStrgConf.SafeBrowsing = &filter.HashPrefixConfig{
- Hashes: safeBrowsingHashes,
- ReplacementHost: c.SafeBrowsing.BlockHost,
- CacheTTL: c.SafeBrowsing.CacheTTL.Duration,
- CacheSize: c.SafeBrowsing.CacheSize,
- }
+ // Filter storage and filtering groups
- fltStrgConf.AdultBlocking = &filter.HashPrefixConfig{
- Hashes: adultBlockingHashes,
- ReplacementHost: c.AdultBlocking.BlockHost,
- CacheTTL: c.AdultBlocking.CacheTTL.Duration,
- CacheSize: c.AdultBlocking.CacheSize,
- }
+ fltStrgConf := c.Filters.toInternal(
+ errColl,
+ filteringResolver,
+ envs,
+ safeBrowsingFilter,
+ adultBlockingFilter,
+ )
fltStrg, err := filter.NewDefaultStorage(fltStrgConf)
check(err)
@@ -153,8 +177,6 @@ func Main() {
err = fltStrgUpd.Start()
check(err)
- safeBrowsing := filter.NewSafeBrowsingServer(safeBrowsingHashes, adultBlockingHashes)
-
// Server Groups
fltGroups, err := c.FilteringGroups.toInternal(fltStrg)
@@ -329,8 +351,11 @@ func Main() {
}, c.Upstream.Healthcheck.Enabled)
dnsConf := &dnssvc.Config{
- Messages: messages,
- SafeBrowsing: safeBrowsing,
+ Messages: messages,
+ SafeBrowsing: filter.NewSafeBrowsingServer(
+ safeBrowsingConf.Hashes,
+ adultBlockingConf.Hashes,
+ ),
BillStat: billStatRec,
ProfileDB: profDB,
DNSCheck: dnsCk,
@@ -349,6 +374,7 @@ func Main() {
CacheSize: c.Cache.Size,
ECSCacheSize: c.Cache.ECSSize,
UseECSCache: c.Cache.Type == cacheTypeECS,
+ ResearchMetrics: bool(envs.ResearchMetrics),
}
dnsSvc, err := dnssvc.New(dnsConf)
@@ -383,11 +409,11 @@ func Main() {
)
h := newSignalHandler(
- adultBlockingHashes,
- safeBrowsingHashes,
debugSvc,
webSvc,
dnsSvc,
+ safeBrowsingUpd,
+ adultBlockingUpd,
profDBUpd,
dnsDBUpd,
geoIPUpd,
diff --git a/internal/cmd/config.go b/internal/cmd/config.go
index 34dddfc..9bed65f 100644
--- a/internal/cmd/config.go
+++ b/internal/cmd/config.go
@@ -3,14 +3,9 @@ package cmd
import (
"fmt"
"os"
- "path/filepath"
- "github.com/AdguardTeam/AdGuardDNS/internal/agd"
- "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
- "github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/golibs/errors"
- "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil"
"gopkg.in/yaml.v2"
)
@@ -213,74 +208,6 @@ func (c *geoIPConfig) validate() (err error) {
}
}
-// allowListConfig is the consul allow list configuration.
-type allowListConfig struct {
- // List contains IPs and CIDRs.
- List []string `yaml:"list"`
-
- // RefreshIvl time between two updates of allow list from the Consul URL.
- RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
-}
-
-// safeBrowsingConfig is the configuration for one of the safe browsing filters.
-type safeBrowsingConfig struct {
- // URL is the URL used to update the filter.
- URL *agdhttp.URL `yaml:"url"`
-
- // BlockHost is the hostname with which to respond to any requests that
- // match the filter.
- //
- // TODO(a.garipov): Consider replacing with a list of IPv4 and IPv6
- // addresses.
- BlockHost string `yaml:"block_host"`
-
- // CacheSize is the size of the response cache, in entries.
- CacheSize int `yaml:"cache_size"`
-
- // CacheTTL is the TTL of the response cache.
- CacheTTL timeutil.Duration `yaml:"cache_ttl"`
-
- // RefreshIvl defines how often AdGuard DNS refreshes the filter.
- RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
-}
-
-// toInternal converts c to the safe browsing filter configuration for the
-// filter storage of the DNS server. c is assumed to be valid.
-func (c *safeBrowsingConfig) toInternal(
- id agd.FilterListID,
- cacheDir string,
- errColl agd.ErrorCollector,
-) (conf *filter.HashStorageConfig) {
- return &filter.HashStorageConfig{
- URL: netutil.CloneURL(&c.URL.URL),
- ErrColl: errColl,
- ID: id,
- CachePath: filepath.Join(cacheDir, string(id)),
- RefreshIvl: c.RefreshIvl.Duration,
- }
-}
-
-// validate returns an error if the safe browsing filter configuration is
-// invalid.
-func (c *safeBrowsingConfig) validate() (err error) {
- switch {
- case c == nil:
- return errNilConfig
- case c.URL == nil:
- return errors.Error("no url")
- case c.BlockHost == "":
- return errors.Error("no block_host")
- case c.CacheSize <= 0:
- return newMustBePositiveError("cache_size", c.CacheSize)
- case c.CacheTTL.Duration <= 0:
- return newMustBePositiveError("cache_ttl", c.CacheTTL)
- case c.RefreshIvl.Duration <= 0:
- return newMustBePositiveError("refresh_interval", c.RefreshIvl)
- default:
- return nil
- }
-}
-
// readConfig reads the configuration.
func readConfig(confPath string) (c *configuration, err error) {
// #nosec G304 -- Trust the path to the configuration file that is given
diff --git a/internal/cmd/env.go b/internal/cmd/env.go
index fe0f94c..5454692 100644
--- a/internal/cmd/env.go
+++ b/internal/cmd/env.go
@@ -48,8 +48,9 @@ type environments struct {
ListenPort int `env:"LISTEN_PORT" envDefault:"8181"`
- LogTimestamp strictBool `env:"LOG_TIMESTAMP" envDefault:"1"`
- LogVerbose strictBool `env:"VERBOSE" envDefault:"0"`
+ LogTimestamp strictBool `env:"LOG_TIMESTAMP" envDefault:"1"`
+ LogVerbose strictBool `env:"VERBOSE" envDefault:"0"`
+ ResearchMetrics strictBool `env:"RESEARCH_METRICS" envDefault:"0"`
}
// readEnvs reads the configuration.
@@ -127,12 +128,10 @@ func (envs *environments) buildDNSDB(
// geoIP returns an GeoIP database implementation from environment.
func (envs *environments) geoIP(
c *geoIPConfig,
- errColl agd.ErrorCollector,
) (g *geoip.File, err error) {
log.Debug("using geoip files %q and %q", envs.GeoIPASNPath, envs.GeoIPCountryPath)
g, err = geoip.NewFile(&geoip.FileConfig{
- ErrColl: errColl,
ASNPath: envs.GeoIPASNPath,
CountryPath: envs.GeoIPCountryPath,
HostCacheSize: c.HostCacheSize,
diff --git a/internal/cmd/filter.go b/internal/cmd/filter.go
index da77a78..678d276 100644
--- a/internal/cmd/filter.go
+++ b/internal/cmd/filter.go
@@ -47,16 +47,21 @@ type filtersConfig struct {
// cacheDir must exist. c is assumed to be valid.
func (c *filtersConfig) toInternal(
errColl agd.ErrorCollector,
+ resolver agdnet.Resolver,
envs *environments,
+ safeBrowsing *filter.HashPrefix,
+ adultBlocking *filter.HashPrefix,
) (conf *filter.DefaultStorageConfig) {
return &filter.DefaultStorageConfig{
FilterIndexURL: netutil.CloneURL(&envs.FilterIndexURL.URL),
BlockedServiceIndexURL: netutil.CloneURL(&envs.BlockedServiceIndexURL.URL),
GeneralSafeSearchRulesURL: netutil.CloneURL(&envs.GeneralSafeSearchURL.URL),
YoutubeSafeSearchRulesURL: netutil.CloneURL(&envs.YoutubeSafeSearchURL.URL),
+ SafeBrowsing: safeBrowsing,
+ AdultBlocking: adultBlocking,
Now: time.Now,
ErrColl: errColl,
- Resolver: agdnet.DefaultResolver{},
+ Resolver: resolver,
CacheDir: envs.FilterCachePath,
CustomFilterCacheSize: c.CustomFilterCacheSize,
SafeSearchCacheSize: c.SafeSearchCacheSize,
diff --git a/internal/cmd/filteringgroup.go b/internal/cmd/filteringgroup.go
index f202354..e1d03cf 100644
--- a/internal/cmd/filteringgroup.go
+++ b/internal/cmd/filteringgroup.go
@@ -29,6 +29,10 @@ type filteringGroup struct {
// BlockPrivateRelay shows if Apple Private Relay queries are blocked for
// requests using this filtering group.
BlockPrivateRelay bool `yaml:"block_private_relay"`
+
+ // BlockFirefoxCanary shows if Firefox canary domain is blocked for
+ // requests using this filtering group.
+ BlockFirefoxCanary bool `yaml:"block_firefox_canary"`
}
// fltGrpRuleLists contains filter rule lists configuration for a filtering
@@ -133,6 +137,7 @@ func (groups filteringGroups) toInternal(
GeneralSafeSearch: g.Parental.GeneralSafeSearch,
YoutubeSafeSearch: g.Parental.YoutubeSafeSearch,
BlockPrivateRelay: g.BlockPrivateRelay,
+ BlockFirefoxCanary: g.BlockFirefoxCanary,
}
}
diff --git a/internal/cmd/ratelimit.go b/internal/cmd/ratelimit.go
index 677eec3..1c25d6b 100644
--- a/internal/cmd/ratelimit.go
+++ b/internal/cmd/ratelimit.go
@@ -15,14 +15,17 @@ type rateLimitConfig struct {
// AllowList is the allowlist of clients.
Allowlist *allowListConfig `yaml:"allowlist"`
+ // Rate limit options for IPv4 addresses.
+ IPv4 *rateLimitOptions `yaml:"ipv4"`
+
+ // Rate limit options for IPv6 addresses.
+ IPv6 *rateLimitOptions `yaml:"ipv6"`
+
// ResponseSizeEstimate is the size of the estimate of the size of one DNS
// response for the purposes of rate limiting. Responses over this estimate
// are counted as several responses.
ResponseSizeEstimate datasize.ByteSize `yaml:"response_size_estimate"`
- // RPS is the maximum number of requests per second.
- RPS int `yaml:"rps"`
-
// BackOffCount helps with repeated offenders. It defines, how many times
// a client hits the rate limit before being held in the back off.
BackOffCount int `yaml:"back_off_count"`
@@ -35,18 +38,42 @@ type rateLimitConfig struct {
// a client has hit the rate limit for a back off.
BackOffPeriod timeutil.Duration `yaml:"back_off_period"`
- // IPv4SubnetKeyLen is the length of the subnet prefix used to calculate
- // rate limiter bucket keys for IPv4 addresses.
- IPv4SubnetKeyLen int `yaml:"ipv4_subnet_key_len"`
-
- // IPv6SubnetKeyLen is the length of the subnet prefix used to calculate
- // rate limiter bucket keys for IPv6 addresses.
- IPv6SubnetKeyLen int `yaml:"ipv6_subnet_key_len"`
-
// RefuseANY, if true, makes the server refuse DNS * queries.
RefuseANY bool `yaml:"refuse_any"`
}
+// allowListConfig is the consul allow list configuration.
+type allowListConfig struct {
+ // List contains IPs and CIDRs.
+ List []string `yaml:"list"`
+
+ // RefreshIvl time between two updates of allow list from the Consul URL.
+ RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
+}
+
+// rateLimitOptions allows define maximum number of requests for IPv4 or IPv6
+// addresses.
+type rateLimitOptions struct {
+ // RPS is the maximum number of requests per second.
+ RPS int `yaml:"rps"`
+
+ // SubnetKeyLen is the length of the subnet prefix used to calculate
+ // rate limiter bucket keys.
+ SubnetKeyLen int `yaml:"subnet_key_len"`
+}
+
+// validate returns an error if rate limit options are invalid.
+func (o *rateLimitOptions) validate() (err error) {
+ if o == nil {
+ return errNilConfig
+ }
+
+ return coalesceError(
+ validatePositive("rps", o.RPS),
+ validatePositive("subnet_key_len", o.SubnetKeyLen),
+ )
+}
+
// toInternal converts c to the rate limiting configuration for the DNS server.
// c is assumed to be valid.
func (c *rateLimitConfig) toInternal(al ratelimit.Allowlist) (conf *ratelimit.BackOffConfig) {
@@ -55,10 +82,11 @@ func (c *rateLimitConfig) toInternal(al ratelimit.Allowlist) (conf *ratelimit.Ba
ResponseSizeEstimate: int(c.ResponseSizeEstimate.Bytes()),
Duration: c.BackOffDuration.Duration,
Period: c.BackOffPeriod.Duration,
- RPS: c.RPS,
+ IPv4RPS: c.IPv4.RPS,
+ IPv4SubnetKeyLen: c.IPv4.SubnetKeyLen,
+ IPv6RPS: c.IPv6.RPS,
+ IPv6SubnetKeyLen: c.IPv6.SubnetKeyLen,
Count: c.BackOffCount,
- IPv4SubnetKeyLen: c.IPv4SubnetKeyLen,
- IPv6SubnetKeyLen: c.IPv6SubnetKeyLen,
RefuseANY: c.RefuseANY,
}
}
@@ -71,14 +99,21 @@ func (c *rateLimitConfig) validate() (err error) {
return fmt.Errorf("allowlist: %w", errNilConfig)
}
+ err = c.IPv4.validate()
+ if err != nil {
+ return fmt.Errorf("ipv4: %w", err)
+ }
+
+ err = c.IPv6.validate()
+ if err != nil {
+ return fmt.Errorf("ipv6: %w", err)
+ }
+
return coalesceError(
- validatePositive("rps", c.RPS),
validatePositive("back_off_count", c.BackOffCount),
validatePositive("back_off_duration", c.BackOffDuration),
validatePositive("back_off_period", c.BackOffPeriod),
validatePositive("response_size_estimate", c.ResponseSizeEstimate),
validatePositive("allowlist.refresh_interval", c.Allowlist.RefreshIvl),
- validatePositive("ipv4_subnet_key_len", c.IPv4SubnetKeyLen),
- validatePositive("ipv6_subnet_key_len", c.IPv6SubnetKeyLen),
)
}
diff --git a/internal/cmd/safebrowsing.go b/internal/cmd/safebrowsing.go
new file mode 100644
index 0000000..a16976b
--- /dev/null
+++ b/internal/cmd/safebrowsing.go
@@ -0,0 +1,86 @@
+package cmd
+
+import (
+ "path/filepath"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
+ "github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/netutil"
+ "github.com/AdguardTeam/golibs/timeutil"
+)
+
+// Safe-browsing and adult-blocking configuration
+
+// safeBrowsingConfig is the configuration for one of the safe browsing filters.
+type safeBrowsingConfig struct {
+ // URL is the URL used to update the filter.
+ URL *agdhttp.URL `yaml:"url"`
+
+ // BlockHost is the hostname with which to respond to any requests that
+ // match the filter.
+ //
+ // TODO(a.garipov): Consider replacing with a list of IPv4 and IPv6
+ // addresses.
+ BlockHost string `yaml:"block_host"`
+
+ // CacheSize is the size of the response cache, in entries.
+ CacheSize int `yaml:"cache_size"`
+
+ // CacheTTL is the TTL of the response cache.
+ CacheTTL timeutil.Duration `yaml:"cache_ttl"`
+
+ // RefreshIvl defines how often AdGuard DNS refreshes the filter.
+ RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
+}
+
+// toInternal converts c to the safe browsing filter configuration for the
+// filter storage of the DNS server. c is assumed to be valid.
+func (c *safeBrowsingConfig) toInternal(
+ errColl agd.ErrorCollector,
+ resolver agdnet.Resolver,
+ id agd.FilterListID,
+ cacheDir string,
+) (fltConf *filter.HashPrefixConfig, err error) {
+ hashes, err := hashstorage.New("")
+ if err != nil {
+ return nil, err
+ }
+
+ return &filter.HashPrefixConfig{
+ Hashes: hashes,
+ URL: netutil.CloneURL(&c.URL.URL),
+ ErrColl: errColl,
+ Resolver: resolver,
+ ID: id,
+ CachePath: filepath.Join(cacheDir, string(id)),
+ ReplacementHost: c.BlockHost,
+ Staleness: c.RefreshIvl.Duration,
+ CacheTTL: c.CacheTTL.Duration,
+ CacheSize: c.CacheSize,
+ }, nil
+}
+
+// validate returns an error if the safe browsing filter configuration is
+// invalid.
+func (c *safeBrowsingConfig) validate() (err error) {
+ switch {
+ case c == nil:
+ return errNilConfig
+ case c.URL == nil:
+ return errors.Error("no url")
+ case c.BlockHost == "":
+ return errors.Error("no block_host")
+ case c.CacheSize <= 0:
+ return newMustBePositiveError("cache_size", c.CacheSize)
+ case c.CacheTTL.Duration <= 0:
+ return newMustBePositiveError("cache_ttl", c.CacheTTL)
+ case c.RefreshIvl.Duration <= 0:
+ return newMustBePositiveError("refresh_interval", c.RefreshIvl)
+ default:
+ return nil
+ }
+}
diff --git a/internal/cmd/websvc.go b/internal/cmd/websvc.go
index c72a2ba..5e6834f 100644
--- a/internal/cmd/websvc.go
+++ b/internal/cmd/websvc.go
@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"net/netip"
+ "net/textproto"
"os"
"path"
@@ -386,11 +387,8 @@ func (sc staticContent) validate() (err error) {
// staticFile is a single file in a static content mapping.
type staticFile struct {
- // AllowOrigin is the value for the HTTP Access-Control-Allow-Origin header.
- AllowOrigin string `yaml:"allow_origin"`
-
- // ContentType is the value for the HTTP Content-Type header.
- ContentType string `yaml:"content_type"`
+ // Headers contains headers of the HTTP response.
+ Headers http.Header `yaml:"headers"`
// Content is the file content.
Content string `yaml:"content"`
@@ -400,8 +398,12 @@ type staticFile struct {
// assumed to be valid.
func (f *staticFile) toInternal() (file *websvc.StaticFile, err error) {
file = &websvc.StaticFile{
- AllowOrigin: f.AllowOrigin,
- ContentType: f.ContentType,
+ Headers: http.Header{},
+ }
+
+ for k, vs := range f.Headers {
+ ck := textproto.CanonicalMIMEHeaderKey(k)
+ file.Headers[ck] = vs
}
file.Content, err = base64.StdEncoding.DecodeString(f.Content)
@@ -409,17 +411,20 @@ func (f *staticFile) toInternal() (file *websvc.StaticFile, err error) {
return nil, fmt.Errorf("content: %w", err)
}
+ // Check Content-Type here as opposed to in validate, because we need
+ // all keys to be canonicalized first.
+ if file.Headers.Get(agdhttp.HdrNameContentType) == "" {
+ return nil, errors.Error("content: " + agdhttp.HdrNameContentType + " header is required")
+ }
+
return file, nil
}
// validate returns an error if the static content file is invalid.
func (f *staticFile) validate() (err error) {
- switch {
- case f == nil:
+ if f == nil {
return errors.Error("no file")
- case f.ContentType == "":
- return errors.Error("no content_type")
- default:
- return nil
}
+
+ return nil
}
diff --git a/internal/dnsserver/dnsservertest/msg.go b/internal/dnsserver/dnsservertest/msg.go
index f88ac7e..b9aba02 100644
--- a/internal/dnsserver/dnsservertest/msg.go
+++ b/internal/dnsserver/dnsservertest/msg.go
@@ -222,9 +222,9 @@ func NewEDNS0Padding(msgLen int, UDPBufferSize uint16) (extra dns.RR) {
}
}
-// FindENDS0Option searches for the specified EDNS0 option in the OPT resource
+// FindEDNS0Option searches for the specified EDNS0 option in the OPT resource
// record of the msg and returns it or nil if it's not present.
-func FindENDS0Option[T dns.EDNS0](msg *dns.Msg) (o T) {
+func FindEDNS0Option[T dns.EDNS0](msg *dns.Msg) (o T) {
rr := msg.IsEdns0()
if rr == nil {
return o
diff --git a/internal/dnsserver/go.mod b/internal/dnsserver/go.mod
index 9dc1562..af68597 100644
--- a/internal/dnsserver/go.mod
+++ b/internal/dnsserver/go.mod
@@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardDNS/internal/dnsserver
go 1.19
require (
- github.com/AdguardTeam/golibs v0.11.3
+ github.com/AdguardTeam/golibs v0.11.4
github.com/ameshkov/dnscrypt/v2 v2.2.5
github.com/ameshkov/dnsstamps v1.0.3
github.com/bluele/gcache v0.0.2
diff --git a/internal/dnsserver/go.sum b/internal/dnsserver/go.sum
index 356e549..01acb1f 100644
--- a/internal/dnsserver/go.sum
+++ b/internal/dnsserver/go.sum
@@ -31,8 +31,7 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl
cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs=
cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0=
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
-github.com/AdguardTeam/golibs v0.11.3 h1:Oif+REq2WLycQ2Xm3ZPmJdfftptss0HbGWbxdFaC310=
-github.com/AdguardTeam/golibs v0.11.3/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA=
+github.com/AdguardTeam/golibs v0.11.4 h1:IltyvxwCTN+xxJF5sh6VadF8Zfbf8elgCm9dgijSVzM=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
diff --git a/internal/dnsserver/listen_noreuseport.go b/internal/dnsserver/listen_noreuseport.go
deleted file mode 100644
index 8cf56e6..0000000
--- a/internal/dnsserver/listen_noreuseport.go
+++ /dev/null
@@ -1,36 +0,0 @@
-//go:build !(aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd)
-
-package dnsserver
-
-import (
- "context"
- "fmt"
- "net"
-
- "github.com/AdguardTeam/golibs/errors"
-)
-
-// listenUDP listens to the specified address on UDP.
-func listenUDP(_ context.Context, addr string, _ bool) (conn *net.UDPConn, err error) {
- defer func() { err = errors.Annotate(err, "opening packet listener: %w") }()
-
- c, err := net.ListenPacket("udp", addr)
- if err != nil {
- return nil, err
- }
-
- conn, ok := c.(*net.UDPConn)
- if !ok {
- // TODO(ameshkov): should not happen, consider panic here.
- err = fmt.Errorf("expected conn of type %T, got %T", conn, c)
-
- return nil, err
- }
-
- return conn, nil
-}
-
-// listenTCP listens to the specified address on TCP.
-func listenTCP(_ context.Context, addr string) (conn net.Listener, err error) {
- return net.Listen("tcp", addr)
-}
diff --git a/internal/dnsserver/listen_reuseport.go b/internal/dnsserver/listen_reuseport.go
deleted file mode 100644
index 552782f..0000000
--- a/internal/dnsserver/listen_reuseport.go
+++ /dev/null
@@ -1,63 +0,0 @@
-//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd
-
-package dnsserver
-
-import (
- "context"
- "fmt"
- "net"
- "syscall"
-
- "github.com/AdguardTeam/golibs/errors"
- "golang.org/x/sys/unix"
-)
-
-func reuseportControl(_, _ string, c syscall.RawConn) (err error) {
- var opErr error
- err = c.Control(func(fd uintptr) {
- opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
- })
- if err != nil {
- return err
- }
-
- return opErr
-}
-
-// listenUDP listens to the specified address on UDP. If oob flag is set to
-// true this method also enables OOB for the listen socket that enables using of
-// ReadMsgUDP/WriteMsgUDP. Doing it this way is necessary to correctly discover
-// the source address when it listens to 0.0.0.0.
-func listenUDP(ctx context.Context, addr string, oob bool) (conn *net.UDPConn, err error) {
- defer func() { err = errors.Annotate(err, "opening packet listener: %w") }()
-
- var lc net.ListenConfig
- lc.Control = reuseportControl
- c, err := lc.ListenPacket(ctx, "udp", addr)
- if err != nil {
- return nil, err
- }
-
- conn, ok := c.(*net.UDPConn)
- if !ok {
- // TODO(ameshkov): should not happen, consider panic here.
- err = fmt.Errorf("expected conn of type %T, got %T", conn, c)
-
- return nil, err
- }
-
- if oob {
- if err = setUDPSocketOptions(conn); err != nil {
- return nil, fmt.Errorf("failed to set socket options: %w", err)
- }
- }
-
- return conn, err
-}
-
-// listenTCP listens to the specified address on TCP.
-func listenTCP(ctx context.Context, addr string) (l net.Listener, err error) {
- var lc net.ListenConfig
- lc.Control = reuseportControl
- return lc.Listen(ctx, "tcp", addr)
-}
diff --git a/internal/dnsserver/netext/listenconfig.go b/internal/dnsserver/netext/listenconfig.go
new file mode 100644
index 0000000..19a47d3
--- /dev/null
+++ b/internal/dnsserver/netext/listenconfig.go
@@ -0,0 +1,73 @@
+// Package netext contains extensions of package net in the Go standard library.
+package netext
+
+import (
+ "context"
+ "fmt"
+ "net"
+)
+
+// ListenConfig is the interface that allows controlling options of connections
+// used by the DNS servers defined in this module. Default ListenConfigs are
+// the ones returned by [DefaultListenConfigWithOOB] for plain DNS and
+// [DefaultListenConfig] for others.
+//
+// This interface is modeled after [net.ListenConfig].
+type ListenConfig interface {
+ Listen(ctx context.Context, network, address string) (l net.Listener, err error)
+ ListenPacket(ctx context.Context, network, address string) (c net.PacketConn, err error)
+}
+
+// DefaultListenConfig returns the default [ListenConfig] used by the servers in
+// this module except for the plain-DNS ones, which use
+// [DefaultListenConfigWithOOB].
+func DefaultListenConfig() (lc ListenConfig) {
+ return &net.ListenConfig{
+ Control: defaultListenControl,
+ }
+}
+
+// DefaultListenConfigWithOOB returns the default [ListenConfig] used by the
+// plain-DNS servers in this module. The resulting ListenConfig sets additional
+// socket flags and processes the control-messages of connections created with
+// ListenPacket.
+func DefaultListenConfigWithOOB() (lc ListenConfig) {
+ return &listenConfigOOB{
+ ListenConfig: net.ListenConfig{
+ Control: defaultListenControl,
+ },
+ }
+}
+
+// type check
+var _ ListenConfig = (*listenConfigOOB)(nil)
+
+// listenConfigOOB is a wrapper around [net.ListenConfig] with modifications
+// that set the control-message options on packet conns.
+type listenConfigOOB struct {
+ net.ListenConfig
+}
+
+// ListenPacket implements the [ListenConfig] interface for *listenConfigOOB.
+// It sets the control-message flags to receive additional out-of-band data to
+// correctly discover the source address when it listens to 0.0.0.0 as well as
+// in situations when SO_BINDTODEVICE is used.
+//
+// network must be "udp", "udp4", or "udp6".
+func (lc *listenConfigOOB) ListenPacket(
+ ctx context.Context,
+ network string,
+ address string,
+) (c net.PacketConn, err error) {
+ c, err = lc.ListenConfig.ListenPacket(ctx, network, address)
+ if err != nil {
+ return nil, err
+ }
+
+ err = setIPOpts(c)
+ if err != nil {
+ return nil, fmt.Errorf("setting socket options: %w", err)
+ }
+
+ return wrapPacketConn(c), nil
+}
diff --git a/internal/dnsserver/netext/listenconfig_unix.go b/internal/dnsserver/netext/listenconfig_unix.go
new file mode 100644
index 0000000..09acf9b
--- /dev/null
+++ b/internal/dnsserver/netext/listenconfig_unix.go
@@ -0,0 +1,44 @@
+//go:build unix
+
+package netext
+
+import (
+ "net"
+ "syscall"
+
+ "github.com/AdguardTeam/golibs/errors"
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+ "golang.org/x/sys/unix"
+)
+
+// defaultListenControl is used as a [net.ListenConfig.Control] function to set
+// the SO_REUSEPORT socket option on all sockets used by the DNS servers in this
+// package.
+func defaultListenControl(_, _ string, c syscall.RawConn) (err error) {
+ var opErr error
+ err = c.Control(func(fd uintptr) {
+ opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
+ })
+ if err != nil {
+ return err
+ }
+
+ return errors.WithDeferred(opErr, err)
+}
+
+// setIPOpts sets the IPv4 and IPv6 options on a packet connection.
+func setIPOpts(c net.PacketConn) (err error) {
+ // TODO(a.garipov): Returning an error only if both functions return one
+ // (which is what module dns does as well) seems rather fragile. Depending
+ // on the OS, the valid errors are ENOPROTOOPT, EINVAL, and maybe others.
+ // Investigate and make OS-specific versions to make sure we don't miss any
+ // real errors.
+ err6 := ipv6.NewPacketConn(c).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true)
+ err4 := ipv4.NewPacketConn(c).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true)
+ if err4 != nil && err6 != nil {
+ return errors.List("setting ipv4 and ipv6 options", err4, err6)
+ }
+
+ return nil
+}
diff --git a/internal/dnsserver/netext/listenconfig_unix_test.go b/internal/dnsserver/netext/listenconfig_unix_test.go
new file mode 100644
index 0000000..caa4f6f
--- /dev/null
+++ b/internal/dnsserver/netext/listenconfig_unix_test.go
@@ -0,0 +1,67 @@
+//go:build unix
+
+package netext_test
+
+import (
+ "context"
+ "syscall"
+ "testing"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
+ "github.com/AdguardTeam/golibs/errors"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "golang.org/x/sys/unix"
+)
+
+func TestDefaultListenConfigWithOOB(t *testing.T) {
+ lc := netext.DefaultListenConfigWithOOB()
+ require.NotNil(t, lc)
+
+ type syscallConner interface {
+ SyscallConn() (c syscall.RawConn, err error)
+ }
+
+ t.Run("ipv4", func(t *testing.T) {
+ c, err := lc.ListenPacket(context.Background(), "udp4", "127.0.0.1:0")
+ require.NoError(t, err)
+ require.NotNil(t, c)
+ require.Implements(t, (*syscallConner)(nil), c)
+
+ sc, err := c.(syscallConner).SyscallConn()
+ require.NoError(t, err)
+
+ err = sc.Control(func(fd uintptr) {
+ val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT)
+ require.NoError(t, opErr)
+
+ // TODO(a.garipov): Rewrite this to use actual expected values for
+ // each OS.
+ assert.NotEqual(t, 0, val)
+ })
+ require.NoError(t, err)
+ })
+
+ t.Run("ipv6", func(t *testing.T) {
+ c, err := lc.ListenPacket(context.Background(), "udp6", "[::1]:0")
+ if errors.Is(err, syscall.EADDRNOTAVAIL) {
+ // Some CI machines have IPv6 disabled.
+ t.Skipf("ipv6 seems to not be supported: %s", err)
+ }
+
+ require.NoError(t, err)
+ require.NotNil(t, c)
+ require.Implements(t, (*syscallConner)(nil), c)
+
+ sc, err := c.(syscallConner).SyscallConn()
+ require.NoError(t, err)
+
+ err = sc.Control(func(fd uintptr) {
+ val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT)
+ require.NoError(t, opErr)
+
+ assert.NotEqual(t, 0, val)
+ })
+ require.NoError(t, err)
+ })
+}
diff --git a/internal/dnsserver/netext/listenconfig_windows.go b/internal/dnsserver/netext/listenconfig_windows.go
new file mode 100644
index 0000000..20e11bf
--- /dev/null
+++ b/internal/dnsserver/netext/listenconfig_windows.go
@@ -0,0 +1,17 @@
+//go:build windows
+
+package netext
+
+import (
+ "net"
+ "syscall"
+)
+
+// defaultListenControl is nil on Windows, because it doesn't support
+// SO_REUSEPORT.
+var defaultListenControl func(_, _ string, _ syscall.RawConn) (_ error)
+
+// setIPOpts sets the IPv4 and IPv6 options on a packet connection.
+func setIPOpts(c net.PacketConn) (err error) {
+ return nil
+}
diff --git a/internal/dnsserver/netext/packetconn.go b/internal/dnsserver/netext/packetconn.go
new file mode 100644
index 0000000..180f5df
--- /dev/null
+++ b/internal/dnsserver/netext/packetconn.go
@@ -0,0 +1,69 @@
+package netext
+
+import (
+ "net"
+)
+
+// PacketSession contains additional information about a packet read from or
+// written to a [SessionPacketConn].
+type PacketSession interface {
+ LocalAddr() (addr net.Addr)
+ RemoteAddr() (addr net.Addr)
+}
+
+// NewSimplePacketSession returns a new packet session using the given
+// parameters.
+func NewSimplePacketSession(laddr, raddr net.Addr) (s PacketSession) {
+ return &simplePacketSession{
+ laddr: laddr,
+ raddr: raddr,
+ }
+}
+
+// simplePacketSession is a simple implementation of the [PacketSession]
+// interface.
+type simplePacketSession struct {
+ laddr net.Addr
+ raddr net.Addr
+}
+
+// LocalAddr implements the [PacketSession] interface for *simplePacketSession.
+func (s *simplePacketSession) LocalAddr() (addr net.Addr) { return s.laddr }
+
+// RemoteAddr implements the [PacketSession] interface for *simplePacketSession.
+func (s *simplePacketSession) RemoteAddr() (addr net.Addr) { return s.raddr }
+
+// SessionPacketConn extends [net.PacketConn] with methods for working with
+// packet sessions.
+type SessionPacketConn interface {
+ net.PacketConn
+
+ ReadFromSession(b []byte) (n int, s PacketSession, err error)
+ WriteToSession(b []byte, s PacketSession) (n int, err error)
+}
+
+// ReadFromSession is a convenience wrapper for types that may or may not
+// implement [SessionPacketConn]. If c implements it, ReadFromSession uses
+// c.ReadFromSession. Otherwise, it uses c.ReadFrom and the session is created
+// by using [NewSimplePacketSession] with c.LocalAddr.
+func ReadFromSession(c net.PacketConn, b []byte) (n int, s PacketSession, err error) {
+ if spc, ok := c.(SessionPacketConn); ok {
+ return spc.ReadFromSession(b)
+ }
+
+ n, raddr, err := c.ReadFrom(b)
+ s = NewSimplePacketSession(c.LocalAddr(), raddr)
+
+ return n, s, err
+}
+
+// WriteToSession is a convenience wrapper for types that may or may not
+// implement [SessionPacketConn]. If c implements it, WriteToSession uses
+// c.WriteToSession. Otherwise, it uses c.WriteTo using s.RemoteAddr.
+func WriteToSession(c net.PacketConn, b []byte, s PacketSession) (n int, err error) {
+ if spc, ok := c.(SessionPacketConn); ok {
+ return spc.WriteToSession(b, s)
+ }
+
+ return c.WriteTo(b, s.RemoteAddr())
+}
diff --git a/internal/dnsserver/netext/packetconn_linux.go b/internal/dnsserver/netext/packetconn_linux.go
new file mode 100644
index 0000000..8ffa687
--- /dev/null
+++ b/internal/dnsserver/netext/packetconn_linux.go
@@ -0,0 +1,159 @@
+//go:build linux
+
+// TODO(a.garipov): Technically, we can expand this to other platforms, but that
+// would require separate udpOOBSize constants and tests.
+
+package netext
+
+import (
+ "fmt"
+ "net"
+ "sync"
+
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+)
+
+// type check
+var _ PacketSession = (*packetSession)(nil)
+
+// packetSession contains additional information about the packet read from a
+// UDP connection. It is basically an extended version of [dns.SessionUDP] that
+// contains the local address as well.
+type packetSession struct {
+ laddr *net.UDPAddr
+ raddr *net.UDPAddr
+ respOOB []byte
+}
+
+// LocalAddr implements the [PacketSession] interface for *packetSession.
+func (s *packetSession) LocalAddr() (addr net.Addr) { return s.laddr }
+
+// RemoteAddr implements the [PacketSession] interface for *packetSession.
+func (s *packetSession) RemoteAddr() (addr net.Addr) { return s.raddr }
+
+// type check
+var _ SessionPacketConn = (*sessionPacketConn)(nil)
+
+// wrapPacketConn wraps c to make it a [SessionPacketConn], if the OS supports
+// that.
+func wrapPacketConn(c net.PacketConn) (wrapped net.PacketConn) {
+ return &sessionPacketConn{
+ UDPConn: *c.(*net.UDPConn),
+ }
+}
+
+// sessionPacketConn wraps a UDP connection and implements [SessionPacketConn].
+type sessionPacketConn struct {
+ net.UDPConn
+}
+
+// oobPool is the pool of byte slices for out-of-band data.
+var oobPool = &sync.Pool{
+ New: func() (v any) {
+ b := make([]byte, IPDstOOBSize)
+
+ return &b
+ },
+}
+
+// IPDstOOBSize is the required size of the control-message buffer for
+// [net.UDPConn.ReadMsgUDP] to read the original destination on Linux.
+//
+// See packetconn_linux_internal_test.go.
+const IPDstOOBSize = 40
+
+// ReadFromSession implements the [SessionPacketConn] interface for *packetConn.
+func (c *sessionPacketConn) ReadFromSession(b []byte) (n int, s PacketSession, err error) {
+ oobPtr := oobPool.Get().(*[]byte)
+ defer oobPool.Put(oobPtr)
+
+ var oobn int
+ oob := *oobPtr
+ ps := &packetSession{}
+ n, oobn, _, ps.raddr, err = c.ReadMsgUDP(b, oob)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var origDstIP net.IP
+ sockLAddr := c.LocalAddr().(*net.UDPAddr)
+ origDstIP, err = origLAddr(oob[:oobn])
+ if err != nil {
+ return 0, nil, fmt.Errorf("getting original addr: %w", err)
+ }
+
+ if origDstIP == nil {
+ ps.laddr = sockLAddr
+ } else {
+ ps.respOOB = newRespOOB(origDstIP)
+ ps.laddr = &net.UDPAddr{
+ IP: origDstIP,
+ Port: sockLAddr.Port,
+ }
+ }
+
+ return n, ps, nil
+}
+
+// origLAddr returns the original local address from the encoded control-message
+// data, if there is one. If not nil, origDst will have a protocol-appropriate
+// length.
+func origLAddr(oob []byte) (origDst net.IP, err error) {
+ ctrlMsg6 := &ipv6.ControlMessage{}
+ err = ctrlMsg6.Parse(oob)
+ if err != nil {
+ return nil, fmt.Errorf("parsing ipv6 control message: %w", err)
+ }
+
+ if dst := ctrlMsg6.Dst; dst != nil {
+ // Linux maps IPv4 addresses to IPv6 ones by default, so we can get an
+ // IPv4 dst from an IPv6 control-message.
+ origDst = dst.To4()
+ if origDst == nil {
+ origDst = dst
+ }
+
+ return origDst, nil
+ }
+
+ ctrlMsg4 := &ipv4.ControlMessage{}
+ err = ctrlMsg4.Parse(oob)
+ if err != nil {
+ return nil, fmt.Errorf("parsing ipv4 control message: %w", err)
+ }
+
+ return ctrlMsg4.Dst.To4(), nil
+}
+
+// newRespOOB returns an encoded control-message for the response for this IP
+// address. origDst is expected to have a protocol-appropriate length.
+func newRespOOB(origDst net.IP) (b []byte) {
+ switch len(origDst) {
+ case net.IPv4len:
+ cm := &ipv4.ControlMessage{
+ Src: origDst,
+ }
+
+ return cm.Marshal()
+ case net.IPv6len:
+ cm := &ipv6.ControlMessage{
+ Src: origDst,
+ }
+
+ return cm.Marshal()
+ default:
+ return nil
+ }
+}
+
+// WriteToSession implements the [SessionPacketConn] interface for *packetConn.
+func (c *sessionPacketConn) WriteToSession(b []byte, s PacketSession) (n int, err error) {
+ if ps, ok := s.(*packetSession); ok {
+ n, _, err = c.WriteMsgUDP(b, ps.respOOB, ps.raddr)
+
+ return n, err
+ }
+
+ return c.WriteTo(b, s.RemoteAddr())
+}
diff --git a/internal/dnsserver/netext/packetconn_linux_internal_test.go b/internal/dnsserver/netext/packetconn_linux_internal_test.go
new file mode 100644
index 0000000..ad31abf
--- /dev/null
+++ b/internal/dnsserver/netext/packetconn_linux_internal_test.go
@@ -0,0 +1,25 @@
+//go:build linux
+
+package netext
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+)
+
+func TestUDPOOBSize(t *testing.T) {
+ // See https://github.com/miekg/dns/blob/v1.1.50/udp.go.
+
+ len4 := len(ipv4.NewControlMessage(ipv4.FlagDst | ipv4.FlagInterface))
+ len6 := len(ipv6.NewControlMessage(ipv6.FlagDst | ipv6.FlagInterface))
+
+ max := len4
+ if len6 > max {
+ max = len6
+ }
+
+ assert.Equal(t, max, IPDstOOBSize)
+}
diff --git a/internal/dnsserver/netext/packetconn_linux_test.go b/internal/dnsserver/netext/packetconn_linux_test.go
new file mode 100644
index 0000000..6999445
--- /dev/null
+++ b/internal/dnsserver/netext/packetconn_linux_test.go
@@ -0,0 +1,140 @@
+//go:build linux
+
+package netext_test
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "os"
+ "syscall"
+ "testing"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
+ "github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/testutil"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// TODO(a.garipov): Add IPv6 test.
+func TestSessionPacketConn(t *testing.T) {
+ const numTries = 5
+
+ // Try the test multiple times to reduce flakiness due to UDP failures.
+ var success4, success6 bool
+ for i := 0; i < numTries; i++ {
+ var isTimeout4, isTimeout6 bool
+ success4 = t.Run(fmt.Sprintf("ipv4_%d", i), func(t *testing.T) {
+ isTimeout4 = testSessionPacketConn(t, "udp4", "0.0.0.0:0", net.IP{127, 0, 0, 1})
+ })
+
+ success6 = t.Run(fmt.Sprintf("ipv6_%d", i), func(t *testing.T) {
+ isTimeout6 = testSessionPacketConn(t, "udp6", "[::]:0", net.IPv6loopback)
+ })
+
+ if success4 && success6 {
+ break
+ } else if isTimeout4 || isTimeout6 {
+ continue
+ }
+
+ t.Fail()
+ }
+
+ if !success4 {
+ t.Errorf("ipv4 test failed after %d attempts", numTries)
+ } else if !success6 {
+ t.Errorf("ipv6 test failed after %d attempts", numTries)
+ }
+}
+
+func testSessionPacketConn(t *testing.T, proto, addr string, dstIP net.IP) (isTimeout bool) {
+ lc := netext.DefaultListenConfigWithOOB()
+ require.NotNil(t, lc)
+
+ c, err := lc.ListenPacket(context.Background(), proto, addr)
+ if isTimeoutOrFail(t, err) {
+ return true
+ }
+
+ require.NotNil(t, c)
+
+ deadline := time.Now().Add(1 * time.Second)
+ err = c.SetDeadline(deadline)
+ require.NoError(t, err)
+
+ laddr := testutil.RequireTypeAssert[*net.UDPAddr](t, c.LocalAddr())
+ require.NotNil(t, laddr)
+
+ dstAddr := &net.UDPAddr{
+ IP: dstIP,
+ Port: laddr.Port,
+ }
+
+ remoteConn, err := net.DialUDP(proto, nil, dstAddr)
+ if proto == "udp6" && errors.Is(err, syscall.EADDRNOTAVAIL) {
+ // Some CI machines have IPv6 disabled.
+ t.Skipf("ipv6 seems to not be supported: %s", err)
+ } else if isTimeoutOrFail(t, err) {
+ return true
+ }
+
+ err = remoteConn.SetDeadline(deadline)
+ require.NoError(t, err)
+
+ msg := []byte("hello")
+ msgLen := len(msg)
+ _, err = remoteConn.Write(msg)
+ if isTimeoutOrFail(t, err) {
+ return true
+ }
+
+ require.Implements(t, (*netext.SessionPacketConn)(nil), c)
+
+ buf := make([]byte, msgLen)
+ n, sess, err := netext.ReadFromSession(c, buf)
+ if isTimeoutOrFail(t, err) {
+ return true
+ }
+
+ assert.Equal(t, msgLen, n)
+ assert.Equal(t, net.Addr(dstAddr), sess.LocalAddr())
+ assert.Equal(t, remoteConn.LocalAddr(), sess.RemoteAddr())
+ assert.Equal(t, msg, buf)
+
+ respMsg := []byte("world")
+ respMsgLen := len(respMsg)
+ n, err = netext.WriteToSession(c, respMsg, sess)
+ if isTimeoutOrFail(t, err) {
+ return true
+ }
+
+ assert.Equal(t, respMsgLen, n)
+
+ buf = make([]byte, respMsgLen)
+ n, err = remoteConn.Read(buf)
+ if isTimeoutOrFail(t, err) {
+ return true
+ }
+
+ assert.Equal(t, respMsgLen, n)
+ assert.Equal(t, respMsg, buf)
+
+ return false
+}
+
+// isTimeoutOrFail is a helper function that returns true if err is a timeout
+// error and also calls require.NoError on err.
+func isTimeoutOrFail(t *testing.T, err error) (ok bool) {
+ t.Helper()
+
+ if err == nil {
+ return false
+ }
+
+ defer require.NoError(t, err)
+
+ return errors.Is(err, os.ErrDeadlineExceeded)
+}
diff --git a/internal/dnsserver/netext/packetconn_others.go b/internal/dnsserver/netext/packetconn_others.go
new file mode 100644
index 0000000..d64c8f4
--- /dev/null
+++ b/internal/dnsserver/netext/packetconn_others.go
@@ -0,0 +1,11 @@
+//go:build !linux
+
+package netext
+
+import "net"
+
+// wrapPacketConn wraps c to make it a [SessionPacketConn], if the OS supports
+// that.
+func wrapPacketConn(c net.PacketConn) (wrapped net.PacketConn) {
+ return c
+}
diff --git a/internal/dnsserver/prometheus/ratelimit_test.go b/internal/dnsserver/prometheus/ratelimit_test.go
index 15d634f..1cc1157 100644
--- a/internal/dnsserver/prometheus/ratelimit_test.go
+++ b/internal/dnsserver/prometheus/ratelimit_test.go
@@ -27,7 +27,8 @@ func TestRateLimiterMetricsListener_integration_cache(t *testing.T) {
Duration: time.Minute,
Count: rps,
ResponseSizeEstimate: 1000,
- RPS: rps,
+ IPv4RPS: rps,
+ IPv6RPS: rps,
RefuseANY: true,
})
rlMw, err := ratelimit.NewMiddleware(rl, nil)
diff --git a/internal/dnsserver/ratelimit/backoff.go b/internal/dnsserver/ratelimit/backoff.go
index e1ab0de..c2cc6ce 100644
--- a/internal/dnsserver/ratelimit/backoff.go
+++ b/internal/dnsserver/ratelimit/backoff.go
@@ -37,15 +37,22 @@ type BackOffConfig struct {
// as several responses.
ResponseSizeEstimate int
- // RPS is the maximum number of requests per second allowed from a single
- // subnet. Any requests above this rate are counted as the client's
- // back-off count. RPS must be greater than zero.
- RPS int
+ // IPv4RPS is the maximum number of requests per second allowed from a
+ // single subnet for IPv4 addresses. Any requests above this rate are
+ // counted as the client's back-off count. RPS must be greater than
+ // zero.
+ IPv4RPS int
// IPv4SubnetKeyLen is the length of the subnet prefix used to calculate
// rate limiter bucket keys for IPv4 addresses. Must be greater than zero.
IPv4SubnetKeyLen int
+ // IPv6RPS is the maximum number of requests per second allowed from a
+ // single subnet for IPv6 addresses. Any requests above this rate are
+ // counted as the client's back-off count. RPS must be greater than
+ // zero.
+ IPv6RPS int
+
// IPv6SubnetKeyLen is the length of the subnet prefix used to calculate
// rate limiter bucket keys for IPv6 addresses. Must be greater than zero.
IPv6SubnetKeyLen int
@@ -62,14 +69,18 @@ type BackOffConfig struct {
// current implementation might be too abstract. Middlewares by themselves
// already provide an interface that can be re-implemented by the users.
// Perhaps, another layer of abstraction is unnecessary.
+//
+// TODO(ameshkov): Consider splitting rps and other properties by protocol
+// family.
type BackOff struct {
rpsCounters *cache.Cache
hitCounters *cache.Cache
allowlist Allowlist
count int
- rps int
respSzEst int
+ ipv4rps int
ipv4SubnetKeyLen int
+ ipv6rps int
ipv6SubnetKeyLen int
refuseANY bool
}
@@ -84,9 +95,10 @@ func NewBackOff(c *BackOffConfig) (l *BackOff) {
hitCounters: cache.New(c.Duration, c.Duration),
allowlist: c.Allowlist,
count: c.Count,
- rps: c.RPS,
respSzEst: c.ResponseSizeEstimate,
+ ipv4rps: c.IPv4RPS,
ipv4SubnetKeyLen: c.IPv4SubnetKeyLen,
+ ipv6rps: c.IPv6RPS,
ipv6SubnetKeyLen: c.IPv6SubnetKeyLen,
refuseANY: c.RefuseANY,
}
@@ -124,7 +136,12 @@ func (l *BackOff) IsRateLimited(
return true, false, nil
}
- return l.hasHitRateLimit(key), false, nil
+ rps := l.ipv4rps
+ if ip.Is6() {
+ rps = l.ipv6rps
+ }
+
+ return l.hasHitRateLimit(key, rps), false, nil
}
// validateAddr returns an error if addr is not a valid IPv4 or IPv6 address.
@@ -184,14 +201,15 @@ func (l *BackOff) incBackOff(key string) {
l.hitCounters.SetDefault(key, counter)
}
-// hasHitRateLimit checks value for a subnet.
-func (l *BackOff) hasHitRateLimit(subnetIPStr string) (ok bool) {
- var r *rps
+// hasHitRateLimit checks value for a subnet with rps as a maximum number
+// requests per second.
+func (l *BackOff) hasHitRateLimit(subnetIPStr string, rps int) (ok bool) {
+ var r *rpsCounter
rVal, ok := l.rpsCounters.Get(subnetIPStr)
if ok {
- r = rVal.(*rps)
+ r = rVal.(*rpsCounter)
} else {
- r = newRPS(l.rps)
+ r = newRPSCounter(rps)
l.rpsCounters.SetDefault(subnetIPStr, r)
}
diff --git a/internal/dnsserver/ratelimit/ratelimit_test.go b/internal/dnsserver/ratelimit/ratelimit_test.go
index 8be231d..efa3ba4 100644
--- a/internal/dnsserver/ratelimit/ratelimit_test.go
+++ b/internal/dnsserver/ratelimit/ratelimit_test.go
@@ -98,8 +98,9 @@ func TestRatelimitMiddleware(t *testing.T) {
Duration: time.Minute,
Count: rps,
ResponseSizeEstimate: 128,
- RPS: rps,
+ IPv4RPS: rps,
IPv4SubnetKeyLen: 24,
+ IPv6RPS: rps,
IPv6SubnetKeyLen: 48,
RefuseANY: true,
})
diff --git a/internal/dnsserver/ratelimit/rps.go b/internal/dnsserver/ratelimit/rps.go
index 823db8f..dcc63dd 100644
--- a/internal/dnsserver/ratelimit/rps.go
+++ b/internal/dnsserver/ratelimit/rps.go
@@ -7,17 +7,17 @@ import (
// Requests Per Second Counter
-// rps is a single request per seconds counter.
-type rps struct {
+// rpsCounter is a single request per seconds counter.
+type rpsCounter struct {
// mu protects all fields.
mu *sync.Mutex
ring []int64
idx int
}
-// newRPS returns a new requests per second counter. n must be above zero.
-func newRPS(n int) (r *rps) {
- return &rps{
+// newRPSCounter returns a new requests per second counter. n must be above zero.
+func newRPSCounter(n int) (r *rpsCounter) {
+ return &rpsCounter{
mu: &sync.Mutex{},
// Add one, because we need to always keep track of the previous
// request. For example, consider n == 1.
@@ -28,7 +28,7 @@ func newRPS(n int) (r *rps) {
// add adds another request to the counter. above is true if the request goes
// above the counter value. It is safe for concurrent use.
-func (r *rps) add(t time.Time) (above bool) {
+func (r *rpsCounter) add(t time.Time) (above bool) {
r.mu.Lock()
defer r.mu.Unlock()
diff --git a/internal/dnsserver/serverbase.go b/internal/dnsserver/serverbase.go
index 42512f5..5ea8206 100644
--- a/internal/dnsserver/serverbase.go
+++ b/internal/dnsserver/serverbase.go
@@ -8,6 +8,7 @@ import (
"sync"
"time"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
@@ -31,16 +32,19 @@ type ConfigBase struct {
// Handler is a handler that processes incoming DNS messages.
// If not set, we'll use the default handler that returns error response
// to any query.
-
Handler Handler
+
// Metrics is the object we use for collecting performance metrics.
// This field is optional.
-
Metrics MetricsListener
// BaseContext is a function that should return the base context. If not
// set, we'll be using context.Background().
BaseContext func() (ctx context.Context)
+
+ // ListenConfig, when set, is used to set options of connections used by the
+ // DNS server. If nil, an appropriate default ListenConfig is used.
+ ListenConfig netext.ListenConfig
}
// ServerBase implements base methods that every Server implementation uses.
@@ -67,14 +71,19 @@ type ServerBase struct {
// metrics is the object we use for collecting performance metrics.
metrics MetricsListener
+ // listenConfig is used to set tcpListener and udpListener.
+ listenConfig netext.ListenConfig
+
// Server operation
// --
- // will be nil for servers that don't use TCP.
+ // tcpListener is used to accept new TCP connections. It is nil for servers
+ // that don't use TCP.
tcpListener net.Listener
- // will be nil for servers that don't use UDP.
- udpListener *net.UDPConn
+ // udpListener is used to accept new UDP messages. It is nil for servers
+ // that don't use UDP.
+ udpListener net.PacketConn
// Shutdown handling
// --
@@ -94,13 +103,14 @@ var _ Server = (*ServerBase)(nil)
// some of its internal properties.
func newServerBase(proto Protocol, conf ConfigBase) (s *ServerBase) {
s = &ServerBase{
- name: conf.Name,
- addr: conf.Addr,
- proto: proto,
- network: conf.Network,
- handler: conf.Handler,
- metrics: conf.Metrics,
- baseContext: conf.BaseContext,
+ name: conf.Name,
+ addr: conf.Addr,
+ proto: proto,
+ network: conf.Network,
+ handler: conf.Handler,
+ metrics: conf.Metrics,
+ listenConfig: conf.ListenConfig,
+ baseContext: conf.BaseContext,
}
if s.baseContext == nil {
@@ -347,18 +357,17 @@ func (s *ServerBase) handlePanicAndRecover(ctx context.Context) {
}
}
-// listenUDP creates a UDP listener for the ServerBase.addr. This function will
-// initialize and start ServerBase.udpListener or return an error. If the TCP
-// listener is already running, its address is used instead. The point of this
-// is to properly handle the case when port 0 is used as both listeners should
-// use the same port, and we only learn it after the first one was started.
+// listenUDP initializes and starts s.udpListener using s.addr. If the TCP
+// listener is already running, its address is used instead to properly handle
+// the case when port 0 is used as both listeners should use the same port, and
+// we only learn it after the first one was started.
func (s *ServerBase) listenUDP(ctx context.Context) (err error) {
addr := s.addr
if s.tcpListener != nil {
addr = s.tcpListener.Addr().String()
}
- conn, err := listenUDP(ctx, addr, true)
+ conn, err := s.listenConfig.ListenPacket(ctx, "udp", addr)
if err != nil {
return err
}
@@ -368,19 +377,17 @@ func (s *ServerBase) listenUDP(ctx context.Context) (err error) {
return nil
}
-// listenTCP creates a TCP listener for the ServerBase.addr. This function will
-// initialize and start ServerBase.tcpListener or return an error. If the UDP
-// listener is already running, its address is used instead. The point of this
-// is to properly handle the case when port 0 is used as both listeners should
-// use the same port, and we only learn it after the first one was started.
+// listenTCP initializes and starts s.tcpListener using s.addr. If the UDP
+// listener is already running, its address is used instead to properly handle
+// the case when port 0 is used as both listeners should use the same port, and
+// we only learn it after the first one was started.
func (s *ServerBase) listenTCP(ctx context.Context) (err error) {
addr := s.addr
if s.udpListener != nil {
addr = s.udpListener.LocalAddr().String()
}
- var l net.Listener
- l, err = listenTCP(ctx, addr)
+ l, err := s.listenConfig.Listen(ctx, "tcp", addr)
if err != nil {
return err
}
diff --git a/internal/dnsserver/serverdns.go b/internal/dnsserver/serverdns.go
index c54beb4..88ee67e 100644
--- a/internal/dnsserver/serverdns.go
+++ b/internal/dnsserver/serverdns.go
@@ -6,6 +6,7 @@ import (
"sync"
"time"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
@@ -110,6 +111,10 @@ func newServerDNS(proto Protocol, conf ConfigDNS) (s *ServerDNS) {
conf.TCPSize = dns.MinMsgSize
}
+ if conf.ListenConfig == nil {
+ conf.ListenConfig = netext.DefaultListenConfigWithOOB()
+ }
+
s = &ServerDNS{
ServerBase: newServerBase(proto, conf.ConfigBase),
conf: conf,
@@ -262,10 +267,21 @@ func makePacketBuffer(size int) (f func() any) {
}
}
+// writeDeadlineSetter is an interface for connections that can set write
+// deadlines.
+type writeDeadlineSetter interface {
+ SetWriteDeadline(t time.Time) (err error)
+}
+
// withWriteDeadline is a helper that takes the deadline of the context and the
// write timeout into account. It sets the write deadline on conn before
// calling f and resets it once f is done.
-func withWriteDeadline(ctx context.Context, writeTimeout time.Duration, conn net.Conn, f func()) {
+func withWriteDeadline(
+ ctx context.Context,
+ writeTimeout time.Duration,
+ conn writeDeadlineSetter,
+ f func(),
+) {
dl, hasDeadline := ctx.Deadline()
if !hasDeadline {
dl = time.Now().Add(writeTimeout)
diff --git a/internal/dnsserver/serverdns_test.go b/internal/dnsserver/serverdns_test.go
index 1ca02f8..0e6f3d6 100644
--- a/internal/dnsserver/serverdns_test.go
+++ b/internal/dnsserver/serverdns_test.go
@@ -284,8 +284,8 @@ func TestServerDNS_integration_query(t *testing.T) {
tc.expectedTruncated,
)
- reqKeepAliveOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_TCP_KEEPALIVE](tc.req)
- respKeepAliveOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_TCP_KEEPALIVE](resp)
+ reqKeepAliveOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_TCP_KEEPALIVE](tc.req)
+ respKeepAliveOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_TCP_KEEPALIVE](resp)
if tc.network == dnsserver.NetworkTCP && reqKeepAliveOpt != nil {
require.NotNil(t, respKeepAliveOpt)
expectedTimeout := uint16(dnsserver.DefaultTCPIdleTimeout.Milliseconds() / 100)
diff --git a/internal/dnsserver/serverdnscrypt.go b/internal/dnsserver/serverdnscrypt.go
index dfb94a1..84e30d0 100644
--- a/internal/dnsserver/serverdnscrypt.go
+++ b/internal/dnsserver/serverdnscrypt.go
@@ -2,7 +2,9 @@ package dnsserver
import (
"context"
+ "net"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/ameshkov/dnscrypt/v2"
@@ -38,6 +40,10 @@ var _ Server = (*ServerDNSCrypt)(nil)
// NewServerDNSCrypt creates a new instance of ServerDNSCrypt.
func NewServerDNSCrypt(conf ConfigDNSCrypt) (s *ServerDNSCrypt) {
+ if conf.ListenConfig == nil {
+ conf.ListenConfig = netext.DefaultListenConfig()
+ }
+
return &ServerDNSCrypt{
ServerBase: newServerBase(ProtoDNSCrypt, conf.ConfigBase),
conf: conf,
@@ -140,7 +146,10 @@ func (s *ServerDNSCrypt) startServeUDP(ctx context.Context) {
// TODO(ameshkov): Add context to the ServeTCP and ServeUDP methods in
// dnscrypt/v3. Or at least add ServeTCPContext and ServeUDPContext
// methods for now.
- err := s.dnsCryptServer.ServeUDP(s.udpListener)
+ //
+ // TODO(ameshkov): Redo the dnscrypt module to make it not depend on
+ // *net.UDPConn and use net.PacketConn instead.
+ err := s.dnsCryptServer.ServeUDP(s.udpListener.(*net.UDPConn))
if err != nil {
log.Info("[%s]: Finished listening to udp://%s due to %v", s.Name(), s.Addr(), err)
}
diff --git a/internal/dnsserver/serverdnsudp.go b/internal/dnsserver/serverdnsudp.go
index 2dc4ce2..ea0aecc 100644
--- a/internal/dnsserver/serverdnsudp.go
+++ b/internal/dnsserver/serverdnsudp.go
@@ -4,23 +4,21 @@ import (
"context"
"fmt"
"net"
- "runtime"
"time"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
- "golang.org/x/net/ipv4"
- "golang.org/x/net/ipv6"
)
// serveUDP runs the UDP serving loop.
-func (s *ServerDNS) serveUDP(ctx context.Context, conn *net.UDPConn) (err error) {
+func (s *ServerDNS) serveUDP(ctx context.Context, conn net.PacketConn) (err error) {
defer log.OnCloserError(conn, log.DEBUG)
for s.isStarted() {
var m []byte
- var sess *dns.SessionUDP
+ var sess netext.PacketSession
m, sess, err = s.readUDPMsg(ctx, conn)
if err != nil {
// TODO(ameshkov): Consider the situation where the server is shut
@@ -59,15 +57,15 @@ func (s *ServerDNS) serveUDP(ctx context.Context, conn *net.UDPConn) (err error)
func (s *ServerDNS) serveUDPPacket(
ctx context.Context,
m []byte,
- conn *net.UDPConn,
- udpSession *dns.SessionUDP,
+ conn net.PacketConn,
+ sess netext.PacketSession,
) {
defer s.wg.Done()
defer s.handlePanicAndRecover(ctx)
rw := &udpResponseWriter{
+ udpSession: sess,
conn: conn,
- udpSession: udpSession,
writeTimeout: s.conf.WriteTimeout,
}
s.serveDNS(ctx, m, rw)
@@ -75,15 +73,18 @@ func (s *ServerDNS) serveUDPPacket(
}
// readUDPMsg reads the next incoming DNS message.
-func (s *ServerDNS) readUDPMsg(ctx context.Context, conn *net.UDPConn) (msg []byte, sess *dns.SessionUDP, err error) {
+func (s *ServerDNS) readUDPMsg(
+ ctx context.Context,
+ conn net.PacketConn,
+) (msg []byte, sess netext.PacketSession, err error) {
err = conn.SetReadDeadline(time.Now().Add(s.conf.ReadTimeout))
if err != nil {
return nil, nil, err
}
m := s.getUDPBuffer()
- var n int
- n, sess, err = dns.ReadFromSessionUDP(conn, m)
+
+ n, sess, err := netext.ReadFromSession(conn, m)
if err != nil {
s.putUDPBuffer(m)
@@ -120,30 +121,10 @@ func (s *ServerDNS) putUDPBuffer(m []byte) {
s.udpPool.Put(&m)
}
-// setUDPSocketOptions is a function that is necessary to be able to use
-// dns.ReadFromSessionUDP and dns.WriteToSessionUDP.
-// TODO(ameshkov): https://github.com/AdguardTeam/AdGuardHome/issues/2807
-func setUDPSocketOptions(conn *net.UDPConn) (err error) {
- if runtime.GOOS == "windows" {
- return nil
- }
-
- // We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection.
- // Try enabling receiving of ECN and packet info for both IP versions.
- // We expect at least one of those syscalls to succeed.
- err6 := ipv6.NewPacketConn(conn).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true)
- err4 := ipv4.NewPacketConn(conn).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true)
- if err4 != nil && err6 != nil {
- return errors.List("error while setting NetworkUDP socket options", err4, err6)
- }
-
- return nil
-}
-
// udpResponseWriter is a ResponseWriter implementation for DNS-over-UDP.
type udpResponseWriter struct {
- udpSession *dns.SessionUDP
- conn *net.UDPConn
+ udpSession netext.PacketSession
+ conn net.PacketConn
writeTimeout time.Duration
}
@@ -152,13 +133,15 @@ var _ ResponseWriter = (*udpResponseWriter)(nil)
// LocalAddr implements the ResponseWriter interface for *udpResponseWriter.
func (r *udpResponseWriter) LocalAddr() (addr net.Addr) {
- return r.conn.LocalAddr()
+ // Don't use r.conn.LocalAddr(), since udpSession may actually contain the
+ // decoded OOB data, including the real local (dst) address.
+ return r.udpSession.LocalAddr()
}
// RemoteAddr implements the ResponseWriter interface for *udpResponseWriter.
func (r *udpResponseWriter) RemoteAddr() (addr net.Addr) {
- // Don't use r.conn.RemoteAddr(), since udpSession actually contains the
- // decoded OOB data, including the remote address.
+ // Don't use r.conn.RemoteAddr(), since udpSession may actually contain the
+ // decoded OOB data, including the real remote (src) address.
return r.udpSession.RemoteAddr()
}
@@ -173,7 +156,7 @@ func (r *udpResponseWriter) WriteMsg(ctx context.Context, req, resp *dns.Msg) (e
}
withWriteDeadline(ctx, r.writeTimeout, r.conn, func() {
- _, err = dns.WriteToSessionUDP(r.conn, data, r.udpSession)
+ _, err = netext.WriteToSession(r.conn, data, r.udpSession)
})
if err != nil {
diff --git a/internal/dnsserver/serverhttps.go b/internal/dnsserver/serverhttps.go
index 28bf24c..ba9c5d5 100644
--- a/internal/dnsserver/serverhttps.go
+++ b/internal/dnsserver/serverhttps.go
@@ -14,6 +14,7 @@ import (
"strings"
"time"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
@@ -94,6 +95,12 @@ var _ Server = (*ServerHTTPS)(nil)
// NewServerHTTPS creates a new ServerHTTPS instance.
func NewServerHTTPS(conf ConfigHTTPS) (s *ServerHTTPS) {
+ if conf.ListenConfig == nil {
+ // Do not enable OOB here, because ListenPacket is only used by HTTP/3,
+ // and quic-go sets the necessary flags.
+ conf.ListenConfig = netext.DefaultListenConfig()
+ }
+
s = &ServerHTTPS{
ServerBase: newServerBase(ProtoDoH, conf.ConfigBase),
conf: conf,
@@ -500,8 +507,7 @@ func (s *ServerHTTPS) listenQUIC(ctx context.Context) (err error) {
tlsConf.NextProtos = nextProtoDoH3
}
- // Do not enable OOB here as quic-go will do that on its own.
- conn, err := listenUDP(ctx, s.addr, false)
+ conn, err := s.listenConfig.ListenPacket(ctx, "udp", s.addr)
if err != nil {
return err
}
@@ -518,24 +524,28 @@ func (s *ServerHTTPS) listenQUIC(ctx context.Context) (err error) {
return nil
}
-// httpContextWithClientInfo adds client info to the context.
+// httpContextWithClientInfo adds client info to the context. ctx is never nil,
+// even when there is an error.
func httpContextWithClientInfo(
parent context.Context,
r *http.Request,
) (ctx context.Context, err error) {
+ ctx = parent
+
ci := ClientInfo{
URL: netutil.CloneURL(r.URL),
}
- // Due to the quic-go bug we should use Host instead of r.TLS:
- // https://github.com/lucas-clemente/quic-go/issues/3596
+ // Due to the quic-go bug we should use Host instead of r.TLS. See
+ // https://github.com/quic-go/quic-go/issues/2879 and
+ // https://github.com/lucas-clemente/quic-go/issues/3596.
//
- // TODO(ameshkov): remove this when the bug is fixed in quic-go.
+ // TODO(ameshkov): Remove when quic-go is fixed, likely in v0.32.0.
if r.ProtoAtLeast(3, 0) {
var host string
host, err = netutil.SplitHost(r.Host)
if err != nil {
- return nil, fmt.Errorf("failed to parse Host: %w", err)
+ return ctx, fmt.Errorf("failed to parse Host: %w", err)
}
ci.TLSServerName = host
@@ -543,7 +553,7 @@ func httpContextWithClientInfo(
ci.TLSServerName = strings.ToLower(r.TLS.ServerName)
}
- return ContextWithClientInfo(parent, ci), nil
+ return ContextWithClientInfo(ctx, ci), nil
}
// httpRequestToMsg reads the DNS message from http.Request.
diff --git a/internal/dnsserver/serverhttps_test.go b/internal/dnsserver/serverhttps_test.go
index 3a8918a..f927798 100644
--- a/internal/dnsserver/serverhttps_test.go
+++ b/internal/dnsserver/serverhttps_test.go
@@ -133,7 +133,7 @@ func TestServerHTTPS_integration_serveRequests(t *testing.T) {
require.True(t, resp.Response)
// EDNS0 padding is only present when request also has padding opt.
- paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
+ paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.Nil(t, paddingOpt)
})
}
@@ -338,7 +338,7 @@ func TestServerHTTPS_integration_ENDS0Padding(t *testing.T) {
resp := mustDoHReq(t, addr, tlsConfig, http.MethodGet, false, false, req)
require.True(t, resp.Response)
- paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
+ paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.NotNil(t, paddingOpt)
require.NotEmpty(t, paddingOpt.Padding)
}
diff --git a/internal/dnsserver/serverquic.go b/internal/dnsserver/serverquic.go
index 609d3e9..fbca4ee 100644
--- a/internal/dnsserver/serverquic.go
+++ b/internal/dnsserver/serverquic.go
@@ -12,6 +12,7 @@ import (
"sync"
"time"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/bluele/gcache"
@@ -98,6 +99,11 @@ func NewServerQUIC(conf ConfigQUIC) (s *ServerQUIC) {
tlsConfig.NextProtos = append([]string{nextProtoDoQ}, compatProtoDQ...)
}
+ if conf.ListenConfig == nil {
+ // Do not enable OOB here as quic-go will do that on its own.
+ conf.ListenConfig = netext.DefaultListenConfig()
+ }
+
s = &ServerQUIC{
ServerBase: newServerBase(ProtoDoQ, conf.ConfigBase),
conf: conf,
@@ -476,8 +482,7 @@ func (s *ServerQUIC) readQUICMsg(
// listenQUIC creates the UDP listener for the ServerQUIC.addr and also starts
// the QUIC listener.
func (s *ServerQUIC) listenQUIC(ctx context.Context) (err error) {
- // Do not enable OOB here as quic-go will do that on its own.
- conn, err := listenUDP(ctx, s.addr, false)
+ conn, err := s.listenConfig.ListenPacket(ctx, "udp", s.addr)
if err != nil {
return err
}
diff --git a/internal/dnsserver/serverquic_test.go b/internal/dnsserver/serverquic_test.go
index 88d09a1..600c975 100644
--- a/internal/dnsserver/serverquic_test.go
+++ b/internal/dnsserver/serverquic_test.go
@@ -59,7 +59,7 @@ func TestServerQUIC_integration_query(t *testing.T) {
assert.True(t, resp.Response)
// EDNS0 padding is only present when request also has padding opt.
- paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
+ paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.Nil(t, paddingOpt)
}()
}
@@ -97,7 +97,7 @@ func TestServerQUIC_integration_ENDS0Padding(t *testing.T) {
require.True(t, resp.Response)
require.False(t, resp.Truncated)
- paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
+ paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.NotNil(t, paddingOpt)
require.NotEmpty(t, paddingOpt.Padding)
}
diff --git a/internal/dnsserver/servertls.go b/internal/dnsserver/servertls.go
index 0a8baaa..2659f5a 100644
--- a/internal/dnsserver/servertls.go
+++ b/internal/dnsserver/servertls.go
@@ -3,7 +3,6 @@ package dnsserver
import (
"context"
"crypto/tls"
- "net"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
@@ -102,10 +101,9 @@ func (s *ServerTLS) startServeTCP(ctx context.Context) {
}
}
-// listenTLS creates the TLS listener for the ServerTLS.addr.
+// listenTLS creates the TLS listener for s.addr.
func (s *ServerTLS) listenTLS(ctx context.Context) (err error) {
- var l net.Listener
- l, err = listenTCP(ctx, s.addr)
+ l, err := s.listenConfig.Listen(ctx, "tcp", s.addr)
if err != nil {
return err
}
diff --git a/internal/dnsserver/servertls_test.go b/internal/dnsserver/servertls_test.go
index d33390a..aa37a3c 100644
--- a/internal/dnsserver/servertls_test.go
+++ b/internal/dnsserver/servertls_test.go
@@ -40,7 +40,7 @@ func TestServerTLS_integration_queryTLS(t *testing.T) {
require.False(t, resp.Truncated)
// EDNS0 padding is only present when request also has padding opt.
- paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
+ paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.Nil(t, paddingOpt)
}
@@ -142,7 +142,7 @@ func TestServerTLS_integration_noTruncateQuery(t *testing.T) {
require.False(t, resp.Truncated)
// EDNS0 padding is only present when request also has padding opt.
- paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
+ paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.Nil(t, paddingOpt)
}
@@ -231,7 +231,7 @@ func TestServerTLS_integration_ENDS0Padding(t *testing.T) {
require.True(t, resp.Response)
require.False(t, resp.Truncated)
- paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp)
+ paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp)
require.NotNil(t, paddingOpt)
require.NotEmpty(t, paddingOpt.Padding)
}
diff --git a/internal/dnssvc/debug_internal_test.go b/internal/dnssvc/debug_internal_test.go
new file mode 100644
index 0000000..24d1c20
--- /dev/null
+++ b/internal/dnssvc/debug_internal_test.go
@@ -0,0 +1,172 @@
+package dnssvc
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter"
+ "github.com/miekg/dns"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// newTXTExtra is a helper function that converts strs into DNS TXT resource
+// records with Name and Txt fields set to first and second values of each
+// tuple.
+func newTXTExtra(strs [][2]string) (extra []dns.RR) {
+ for _, v := range strs {
+ extra = append(extra, &dns.TXT{
+ Hdr: dns.RR_Header{
+ Name: v[0],
+ Rrtype: dns.TypeTXT,
+ Class: dns.ClassCHAOS,
+ Ttl: 1,
+ },
+ Txt: []string{v[1]},
+ })
+ }
+
+ return extra
+}
+
+func TestService_writeDebugResponse(t *testing.T) {
+ svc := &Service{messages: &dnsmsg.Constructor{FilteredResponseTTL: time.Second}}
+
+ const (
+ fltListID1 agd.FilterListID = "fl1"
+ fltListID2 agd.FilterListID = "fl2"
+
+ blockRule = "||example.com^"
+ )
+
+ testCases := []struct {
+ name string
+ ri *agd.RequestInfo
+ reqRes filter.Result
+ respRes filter.Result
+ wantExtra []dns.RR
+ }{{
+ name: "normal",
+ ri: &agd.RequestInfo{},
+ reqRes: nil,
+ respRes: nil,
+ wantExtra: newTXTExtra([][2]string{
+ {"client-ip.adguard-dns.com.", "1.2.3.4"},
+ {"resp.res-type.adguard-dns.com.", "normal"},
+ }),
+ }, {
+ name: "request_result_blocked",
+ ri: &agd.RequestInfo{},
+ reqRes: &filter.ResultBlocked{List: fltListID1, Rule: blockRule},
+ respRes: nil,
+ wantExtra: newTXTExtra([][2]string{
+ {"client-ip.adguard-dns.com.", "1.2.3.4"},
+ {"req.res-type.adguard-dns.com.", "blocked"},
+ {"req.rule.adguard-dns.com.", "||example.com^"},
+ {"req.rule-list-id.adguard-dns.com.", "fl1"},
+ }),
+ }, {
+ name: "response_result_blocked",
+ ri: &agd.RequestInfo{},
+ reqRes: nil,
+ respRes: &filter.ResultBlocked{List: fltListID2, Rule: blockRule},
+ wantExtra: newTXTExtra([][2]string{
+ {"client-ip.adguard-dns.com.", "1.2.3.4"},
+ {"resp.res-type.adguard-dns.com.", "blocked"},
+ {"resp.rule.adguard-dns.com.", "||example.com^"},
+ {"resp.rule-list-id.adguard-dns.com.", "fl2"},
+ }),
+ }, {
+ name: "request_result_allowed",
+ ri: &agd.RequestInfo{},
+ reqRes: &filter.ResultAllowed{},
+ respRes: nil,
+ wantExtra: newTXTExtra([][2]string{
+ {"client-ip.adguard-dns.com.", "1.2.3.4"},
+ {"req.res-type.adguard-dns.com.", "allowed"},
+ {"req.rule.adguard-dns.com.", ""},
+ {"req.rule-list-id.adguard-dns.com.", ""},
+ }),
+ }, {
+ name: "response_result_allowed",
+ ri: &agd.RequestInfo{},
+ reqRes: nil,
+ respRes: &filter.ResultAllowed{},
+ wantExtra: newTXTExtra([][2]string{
+ {"client-ip.adguard-dns.com.", "1.2.3.4"},
+ {"resp.res-type.adguard-dns.com.", "allowed"},
+ {"resp.rule.adguard-dns.com.", ""},
+ {"resp.rule-list-id.adguard-dns.com.", ""},
+ }),
+ }, {
+ name: "request_result_modified",
+ ri: &agd.RequestInfo{},
+ reqRes: &filter.ResultModified{
+ Rule: "||example.com^$dnsrewrite=REFUSED",
+ },
+ respRes: nil,
+ wantExtra: newTXTExtra([][2]string{
+ {"client-ip.adguard-dns.com.", "1.2.3.4"},
+ {"req.res-type.adguard-dns.com.", "modified"},
+ {"req.rule.adguard-dns.com.", "||example.com^$dnsrewrite=REFUSED"},
+ {"req.rule-list-id.adguard-dns.com.", ""},
+ }),
+ }, {
+ name: "device",
+ ri: &agd.RequestInfo{Device: &agd.Device{ID: "dev1234"}},
+ reqRes: nil,
+ respRes: nil,
+ wantExtra: newTXTExtra([][2]string{
+ {"client-ip.adguard-dns.com.", "1.2.3.4"},
+ {"device-id.adguard-dns.com.", "dev1234"},
+ {"resp.res-type.adguard-dns.com.", "normal"},
+ }),
+ }, {
+ name: "profile",
+ ri: &agd.RequestInfo{
+ Profile: &agd.Profile{ID: agd.ProfileID("some-profile-id")},
+ },
+ reqRes: nil,
+ respRes: nil,
+ wantExtra: newTXTExtra([][2]string{
+ {"client-ip.adguard-dns.com.", "1.2.3.4"},
+ {"profile-id.adguard-dns.com.", "some-profile-id"},
+ {"resp.res-type.adguard-dns.com.", "normal"},
+ }),
+ }, {
+ name: "location",
+ ri: &agd.RequestInfo{Location: &agd.Location{Country: agd.CountryAD}},
+ reqRes: nil,
+ respRes: nil,
+ wantExtra: newTXTExtra([][2]string{
+ {"client-ip.adguard-dns.com.", "1.2.3.4"},
+ {"country.adguard-dns.com.", "AD"},
+ {"asn.adguard-dns.com.", "0"},
+ {"resp.res-type.adguard-dns.com.", "normal"},
+ }),
+ }}
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr)
+
+ ctx := agd.ContextWithRequestInfo(context.Background(), tc.ri)
+
+ req := dnsservertest.NewReq("example.com", dns.TypeA, dns.ClassINET)
+ resp := dnsservertest.NewResp(dns.RcodeSuccess, req)
+
+ err := svc.writeDebugResponse(ctx, rw, req, resp, tc.reqRes, tc.respRes)
+ require.NoError(t, err)
+
+ msg := rw.Msg()
+ require.NotNil(t, msg)
+
+ assert.Equal(t, tc.wantExtra, msg.Extra)
+ })
+ }
+}
diff --git a/internal/dnssvc/deviceid.go b/internal/dnssvc/deviceid.go
index 60b67f9..9710664 100644
--- a/internal/dnssvc/deviceid.go
+++ b/internal/dnssvc/deviceid.go
@@ -195,7 +195,15 @@ func deviceIDFromEDNS(req *dns.Msg) (id agd.DeviceID, err error) {
continue
}
- return agd.NewDeviceID(string(o.Data))
+ id, err = agd.NewDeviceID(string(o.Data))
+ if err != nil {
+ return "", &deviceIDError{
+ err: err,
+ typ: "edns option",
+ }
+ }
+
+ return id, nil
}
return "", nil
diff --git a/internal/dnssvc/deviceid_internal_test.go b/internal/dnssvc/deviceid_internal_test.go
index 30f7b50..2058ae8 100644
--- a/internal/dnssvc/deviceid_internal_test.go
+++ b/internal/dnssvc/deviceid_internal_test.go
@@ -233,7 +233,8 @@ func TestService_Wrap_deviceIDFromEDNS(t *testing.T) {
Data: []byte{},
},
wantDeviceID: "",
- wantErrMsg: `bad device id "": too short: got 0 bytes, min 1`,
+ wantErrMsg: `edns option device id check: bad device id "": ` +
+ `too short: got 0 bytes, min 1`,
}, {
name: "bad_device_id",
opt: &dns.EDNS0_LOCAL{
@@ -241,7 +242,8 @@ func TestService_Wrap_deviceIDFromEDNS(t *testing.T) {
Data: []byte("toolongdeviceid"),
},
wantDeviceID: "",
- wantErrMsg: `bad device id "toolongdeviceid": too long: got 15 bytes, max 8`,
+ wantErrMsg: `edns option device id check: bad device id "toolongdeviceid": ` +
+ `too long: got 15 bytes, max 8`,
}, {
name: "device_id",
opt: &dns.EDNS0_LOCAL{
diff --git a/internal/dnssvc/dnssvc.go b/internal/dnssvc/dnssvc.go
index f5a6a30..d3ba3da 100644
--- a/internal/dnssvc/dnssvc.go
+++ b/internal/dnssvc/dnssvc.go
@@ -9,7 +9,6 @@ import (
"fmt"
"net/http"
"net/netip"
- "net/url"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
@@ -79,10 +78,6 @@ type Config struct {
// Upstream defines the upstream server and the group of fallback servers.
Upstream *agd.Upstream
- // RootRedirectURL is the URL to which non-DNS and non-Debug HTTP requests
- // are redirected.
- RootRedirectURL *url.URL
-
// NewListener, when set, is used instead of the package-level function
// NewListener when creating a DNS listener.
//
@@ -119,6 +114,11 @@ type Config struct {
// UseECSCache shows if the EDNS Client Subnet (ECS) aware cache should be
// used.
UseECSCache bool
+
+ // ResearchMetrics controls whether research metrics are enabled or not.
+ // This is a set of metrics that we may need temporary, so its collection is
+ // controlled by a separate setting.
+ ResearchMetrics bool
}
// New returns a new DNS service.
@@ -150,7 +150,6 @@ func New(c *Config) (svc *Service, err error) {
groups := make([]*serverGroup, len(c.ServerGroups))
svc = &Service{
messages: c.Messages,
- rootRedirectURL: c.RootRedirectURL,
billStat: c.BillStat,
errColl: c.ErrColl,
fltStrg: c.FilterStorage,
@@ -158,6 +157,7 @@ func New(c *Config) (svc *Service, err error) {
queryLog: c.QueryLog,
ruleStat: c.RuleStat,
groups: groups,
+ researchMetrics: c.ResearchMetrics,
}
for i, srvGrp := range c.ServerGroups {
@@ -212,7 +212,6 @@ var _ agd.Service = (*Service)(nil)
// Service is the main DNS service of AdGuard DNS.
type Service struct {
messages *dnsmsg.Constructor
- rootRedirectURL *url.URL
billStat billstat.Recorder
errColl agd.ErrorCollector
fltStrg filter.Storage
@@ -220,6 +219,7 @@ type Service struct {
queryLog querylog.Interface
ruleStat rulestat.Interface
groups []*serverGroup
+ researchMetrics bool
}
// mustStartListener starts l and panics on any error.
diff --git a/internal/dnssvc/initmw.go b/internal/dnssvc/initmw.go
index e7561bf..87d5c27 100644
--- a/internal/dnssvc/initmw.go
+++ b/internal/dnssvc/initmw.go
@@ -260,14 +260,6 @@ func (mh *initMwHandler) ServeDNS(
// Copy middleware to the local variable to make the code simpler.
mw := mh.mw
- if specHdlr, name := mw.noReqInfoSpecialHandler(fqdn, qt, cl); specHdlr != nil {
- optlog.Debug1("init mw: got no-req-info special handler %s", name)
-
- // Don't wrap the error, because it's informative enough as is, and
- // because if handled is true, the main flow terminates here.
- return specHdlr(ctx, rw, req)
- }
-
// Get the request's information, such as GeoIP data and user profiles.
ri, err := mw.newRequestInfo(ctx, req, rw.RemoteAddr(), fqdn, qt, cl)
if err != nil {
diff --git a/internal/dnssvc/initmw_internal_test.go b/internal/dnssvc/initmw_internal_test.go
index ab05594..f5cb8b7 100644
--- a/internal/dnssvc/initmw_internal_test.go
+++ b/internal/dnssvc/initmw_internal_test.go
@@ -277,7 +277,7 @@ func TestInitMw_ServeDNS_ddr(t *testing.T) {
}
}
-func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
+func TestInitMw_ServeDNS_specialDomain(t *testing.T) {
testCases := []struct {
name string
host string
@@ -287,7 +287,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
profBlocked bool
wantRCode dnsmsg.RCode
}{{
- name: "blocked_by_fltgrp",
+ name: "private_relay_blocked_by_fltgrp",
host: applePrivateRelayMaskHost,
qtype: dns.TypeA,
fltGrpBlocked: true,
@@ -295,7 +295,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
profBlocked: false,
wantRCode: dns.RcodeNameError,
}, {
- name: "no_private_relay_domain",
+ name: "no_special_domain",
host: "www.example.com",
qtype: dns.TypeA,
fltGrpBlocked: true,
@@ -311,7 +311,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
profBlocked: false,
wantRCode: dns.RcodeSuccess,
}, {
- name: "blocked_by_prof",
+ name: "private_relay_blocked_by_prof",
host: applePrivateRelayMaskHost,
qtype: dns.TypeA,
fltGrpBlocked: false,
@@ -319,7 +319,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
profBlocked: true,
wantRCode: dns.RcodeNameError,
}, {
- name: "allowed_by_prof",
+ name: "private_relay_allowed_by_prof",
host: applePrivateRelayMaskHost,
qtype: dns.TypeA,
fltGrpBlocked: true,
@@ -327,7 +327,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
profBlocked: false,
wantRCode: dns.RcodeSuccess,
}, {
- name: "allowed_by_both",
+ name: "private_relay_allowed_by_both",
host: applePrivateRelayMaskHost,
qtype: dns.TypeA,
fltGrpBlocked: false,
@@ -335,13 +335,45 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
profBlocked: false,
wantRCode: dns.RcodeSuccess,
}, {
- name: "blocked_by_both",
+ name: "private_relay_blocked_by_both",
host: applePrivateRelayMaskHost,
qtype: dns.TypeA,
fltGrpBlocked: true,
hasProf: true,
profBlocked: true,
wantRCode: dns.RcodeNameError,
+ }, {
+ name: "firefox_canary_allowed_by_prof",
+ host: firefoxCanaryHost,
+ qtype: dns.TypeA,
+ fltGrpBlocked: false,
+ hasProf: true,
+ profBlocked: false,
+ wantRCode: dns.RcodeSuccess,
+ }, {
+ name: "firefox_canary_allowed_by_fltgrp",
+ host: firefoxCanaryHost,
+ qtype: dns.TypeA,
+ fltGrpBlocked: false,
+ hasProf: false,
+ profBlocked: false,
+ wantRCode: dns.RcodeSuccess,
+ }, {
+ name: "firefox_canary_blocked_by_prof",
+ host: firefoxCanaryHost,
+ qtype: dns.TypeA,
+ fltGrpBlocked: false,
+ hasProf: true,
+ profBlocked: true,
+ wantRCode: dns.RcodeRefused,
+ }, {
+ name: "firefox_canary_blocked_by_fltgrp",
+ host: firefoxCanaryHost,
+ qtype: dns.TypeA,
+ fltGrpBlocked: true,
+ hasProf: false,
+ profBlocked: false,
+ wantRCode: dns.RcodeRefused,
}}
for _, tc := range testCases {
@@ -368,9 +400,12 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
return nil, nil, agd.DeviceNotFoundError{}
}
- return &agd.Profile{
- BlockPrivateRelay: tc.profBlocked,
- }, &agd.Device{}, nil
+ prof := &agd.Profile{
+ BlockPrivateRelay: tc.profBlocked,
+ BlockFirefoxCanary: tc.profBlocked,
+ }
+
+ return prof, &agd.Device{}, nil
}
db := &agdtest.ProfileDB{
OnProfileByDeviceID: func(
@@ -406,7 +441,8 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
FilteredResponseTTL: 10 * time.Second,
},
fltGrp: &agd.FilteringGroup{
- BlockPrivateRelay: tc.fltGrpBlocked,
+ BlockPrivateRelay: tc.fltGrpBlocked,
+ BlockFirefoxCanary: tc.fltGrpBlocked,
},
srvGrp: &agd.ServerGroup{},
srv: &agd.Server{
@@ -436,7 +472,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) {
resp := rw.Msg()
require.NotNil(t, resp)
- assert.Equal(t, dnsmsg.RCode(resp.Rcode), tc.wantRCode)
+ assert.Equal(t, tc.wantRCode, dnsmsg.RCode(resp.Rcode))
})
}
}
diff --git a/internal/dnssvc/middleware.go b/internal/dnssvc/middleware.go
index 628dd68..5d0719b 100644
--- a/internal/dnssvc/middleware.go
+++ b/internal/dnssvc/middleware.go
@@ -101,7 +101,7 @@ func (svc *Service) filterQuery(
) (reqRes, respRes filter.Result) {
start := time.Now()
defer func() {
- reportMetrics(ri, reqRes, respRes, time.Since(start))
+ svc.reportMetrics(ri, reqRes, respRes, time.Since(start))
}()
f := svc.fltStrg.FilterFromContext(ctx, ri)
@@ -120,7 +120,7 @@ func (svc *Service) filterQuery(
// reportMetrics extracts filtering metrics data from the context and reports it
// to Prometheus.
-func reportMetrics(
+func (svc *Service) reportMetrics(
ri *agd.RequestInfo,
reqRes filter.Result,
respRes filter.Result,
@@ -139,7 +139,7 @@ func reportMetrics(
metrics.DNSSvcRequestByCountryTotal.WithLabelValues(cont, ctry).Inc()
metrics.DNSSvcRequestByASNTotal.WithLabelValues(ctry, asn).Inc()
- id, _, _ := filteringData(reqRes, respRes)
+ id, _, blocked := filteringData(reqRes, respRes)
metrics.DNSSvcRequestByFilterTotal.WithLabelValues(
string(id),
metrics.BoolString(ri.Profile == nil),
@@ -147,6 +147,22 @@ func reportMetrics(
metrics.DNSSvcFilteringDuration.Observe(elapsedFiltering.Seconds())
metrics.DNSSvcUsersCountUpdate(ri.RemoteIP)
+
+ if svc.researchMetrics {
+ anonymous := ri.Profile == nil
+ filteringEnabled := ri.FilteringGroup != nil &&
+ ri.FilteringGroup.RuleListsEnabled &&
+ len(ri.FilteringGroup.RuleListIDs) > 0
+
+ metrics.ReportResearchMetrics(
+ anonymous,
+ filteringEnabled,
+ asn,
+ ctry,
+ string(id),
+ blocked,
+ )
+ }
}
// reportf is a helper method for reporting non-critical errors.
diff --git a/internal/dnssvc/presvcmw_test.go b/internal/dnssvc/presvcmw_test.go
new file mode 100644
index 0000000..345dac6
--- /dev/null
+++ b/internal/dnssvc/presvcmw_test.go
@@ -0,0 +1,135 @@
+package dnssvc
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
+ "github.com/miekg/dns"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestPreServiceMwHandler_ServeDNS(t *testing.T) {
+ const safeBrowsingHost = "scam.example.net."
+
+ var (
+ ip = net.IP{127, 0, 0, 1}
+ name = "example.com"
+ )
+
+ sum := sha256.Sum256([]byte(safeBrowsingHost))
+ hashStr := hex.EncodeToString(sum[:])
+ hashes, herr := hashstorage.New(safeBrowsingHost)
+ require.NoError(t, herr)
+
+ srv := filter.NewSafeBrowsingServer(hashes, nil)
+ host := hashStr[:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix
+
+ ctx := context.Background()
+ ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{})
+ ctx = dnsserver.ContextWithServerInfo(ctx, dnsserver.ServerInfo{})
+ ctx = dnsserver.ContextWithStartTime(ctx, time.Now())
+
+ const ttl = 60
+
+ testCases := []struct {
+ name string
+ req *dns.Msg
+ dnscheckResp *dns.Msg
+ ri *agd.RequestInfo
+ wantAns []dns.RR
+ }{{
+ name: "normal",
+ req: dnsservertest.CreateMessage(name, dns.TypeA),
+ dnscheckResp: nil,
+ ri: &agd.RequestInfo{},
+ wantAns: []dns.RR{
+ dnsservertest.NewA(name, 100, ip),
+ },
+ }, {
+ name: "dnscheck",
+ req: dnsservertest.CreateMessage(name, dns.TypeA),
+ dnscheckResp: dnsservertest.NewResp(
+ dns.RcodeSuccess,
+ dnsservertest.NewReq(name, dns.TypeA, dns.ClassINET),
+ dnsservertest.RRSection{
+ RRs: []dns.RR{dnsservertest.NewA(name, ttl, ip)},
+ Sec: dnsservertest.SectionAnswer,
+ },
+ ),
+ ri: &agd.RequestInfo{
+ Host: name,
+ QType: dns.TypeA,
+ QClass: dns.ClassINET,
+ },
+ wantAns: []dns.RR{
+ dnsservertest.NewA(name, ttl, ip),
+ },
+ }, {
+ name: "with_hashes",
+ req: dnsservertest.CreateMessage(safeBrowsingHost, dns.TypeTXT),
+ dnscheckResp: nil,
+ ri: &agd.RequestInfo{Host: host, QType: dns.TypeTXT},
+ wantAns: []dns.RR{&dns.TXT{
+ Hdr: dns.RR_Header{
+ Name: safeBrowsingHost,
+ Rrtype: dns.TypeTXT,
+ Class: dns.ClassINET,
+ Ttl: ttl,
+ },
+ Txt: []string{hashStr},
+ }},
+ }, {
+ name: "not_matched",
+ req: dnsservertest.CreateMessage(name, dns.TypeTXT),
+ dnscheckResp: nil,
+ ri: &agd.RequestInfo{Host: name, QType: dns.TypeTXT},
+ wantAns: []dns.RR{dnsservertest.NewA(name, 100, ip)},
+ }}
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr)
+ tctx := agd.ContextWithRequestInfo(ctx, tc.ri)
+
+ dnsCk := &agdtest.DNSCheck{
+ OnCheck: func(
+ ctx context.Context,
+ msg *dns.Msg,
+ ri *agd.RequestInfo,
+ ) (resp *dns.Msg, err error) {
+ return tc.dnscheckResp, nil
+ },
+ }
+
+ mw := &preServiceMw{
+ messages: &dnsmsg.Constructor{
+ FilteredResponseTTL: ttl * time.Second,
+ },
+ filter: srv,
+ checker: dnsCk,
+ }
+ handler := dnsservertest.DefaultHandler()
+ h := mw.Wrap(handler)
+
+ err := h.ServeDNS(tctx, rw, tc.req)
+ require.NoError(t, err)
+
+ msg := rw.Msg()
+ require.NotNil(t, msg)
+
+ assert.Equal(t, tc.wantAns, msg.Answer)
+ })
+ }
+}
diff --git a/internal/dnssvc/preupstreammw_test.go b/internal/dnssvc/preupstreammw_test.go
new file mode 100644
index 0000000..58a0c91
--- /dev/null
+++ b/internal/dnssvc/preupstreammw_test.go
@@ -0,0 +1,264 @@
+package dnssvc
+
+import (
+ "context"
+ "net"
+ "net/netip"
+ "testing"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
+ "github.com/AdguardTeam/golibs/netutil"
+ "github.com/miekg/dns"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+const (
+ reqHostname = "example.com."
+ defaultTTL = 3600
+)
+
+func TestPreUpstreamMwHandler_ServeDNS_withCache(t *testing.T) {
+ remoteIP := netip.MustParseAddr("1.2.3.4")
+ aReq := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
+ respIP := remoteIP.AsSlice()
+
+ resp := dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.RRSection{
+ RRs: []dns.RR{dnsservertest.NewA(reqHostname, defaultTTL, respIP)},
+ Sec: dnsservertest.SectionAnswer,
+ })
+ ctx := agd.ContextWithRequestInfo(context.Background(), &agd.RequestInfo{
+ Host: aReq.Question[0].Name,
+ })
+
+ const N = 5
+ testCases := []struct {
+ name string
+ cacheSize int
+ wantNumReq int
+ }{{
+ name: "no_cache",
+ cacheSize: 0,
+ wantNumReq: N,
+ }, {
+ name: "with_cache",
+ cacheSize: 100,
+ wantNumReq: 1,
+ }}
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ numReq := 0
+ handler := dnsserver.HandlerFunc(func(
+ ctx context.Context,
+ rw dnsserver.ResponseWriter,
+ req *dns.Msg,
+ ) error {
+ numReq++
+
+ return rw.WriteMsg(ctx, req, resp)
+ })
+
+ mw := &preUpstreamMw{
+ db: dnsdb.Empty{},
+ cacheSize: tc.cacheSize,
+ }
+ h := mw.Wrap(handler)
+
+ for i := 0; i < N; i++ {
+ req := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
+ addr := &net.UDPAddr{IP: remoteIP.AsSlice(), Port: 53}
+ nrw := dnsserver.NewNonWriterResponseWriter(addr, addr)
+
+ err := h.ServeDNS(ctx, nrw, req)
+ require.NoError(t, err)
+ }
+
+ assert.Equal(t, tc.wantNumReq, numReq)
+ })
+ }
+}
+
+func TestPreUpstreamMwHandler_ServeDNS_withECSCache(t *testing.T) {
+ aReq := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
+ remoteIP := netip.MustParseAddr("1.2.3.4")
+ subnet := netip.MustParsePrefix("1.2.3.4/24")
+
+ const ctry = agd.CountryAD
+
+ resp := dnsservertest.NewResp(
+ dns.RcodeSuccess,
+ aReq,
+ dnsservertest.RRSection{
+ RRs: []dns.RR{dnsservertest.NewA(
+ reqHostname,
+ defaultTTL,
+ net.IP{1, 2, 3, 4},
+ )},
+ Sec: dnsservertest.SectionAnswer,
+ },
+ )
+
+ numReq := 0
+ handler := dnsserver.HandlerFunc(
+ func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) error {
+ numReq++
+
+ return rw.WriteMsg(ctx, req, resp)
+ },
+ )
+
+ geoIP := &agdtest.GeoIP{
+ OnSubnetByLocation: func(
+ ctry agd.Country,
+ _ agd.ASN,
+ _ netutil.AddrFamily,
+ ) (n netip.Prefix, err error) {
+ return netip.MustParsePrefix("1.2.0.0/16"), nil
+ },
+ OnData: func(_ string, _ netip.Addr) (_ *agd.Location, _ error) {
+ panic("not implemented")
+ },
+ }
+
+ mw := &preUpstreamMw{
+ db: dnsdb.Empty{},
+ geoIP: geoIP,
+ cacheSize: 100,
+ ecsCacheSize: 100,
+ useECSCache: true,
+ }
+ h := mw.Wrap(handler)
+
+ ctx := agd.ContextWithRequestInfo(context.Background(), &agd.RequestInfo{
+ Location: &agd.Location{
+ Country: ctry,
+ },
+ ECS: &agd.ECS{
+ Location: &agd.Location{
+ Country: ctry,
+ },
+ Subnet: subnet,
+ Scope: 0,
+ },
+ Host: aReq.Question[0].Name,
+ RemoteIP: remoteIP,
+ })
+
+ const N = 5
+ var nrw *dnsserver.NonWriterResponseWriter
+ for i := 0; i < N; i++ {
+ addr := &net.UDPAddr{IP: remoteIP.AsSlice(), Port: 53}
+ nrw = dnsserver.NewNonWriterResponseWriter(addr, addr)
+ req := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
+
+ err := h.ServeDNS(ctx, nrw, req)
+ require.NoError(t, err)
+ }
+
+ assert.Equal(t, 1, numReq)
+}
+
+func TestPreUpstreamMwHandler_ServeDNS_androidMetric(t *testing.T) {
+ mw := &preUpstreamMw{db: dnsdb.Empty{}}
+
+ req := dnsservertest.CreateMessage("example.com", dns.TypeA)
+ resp := new(dns.Msg).SetReply(req)
+
+ ctx := context.Background()
+ ctx = dnsserver.ContextWithServerInfo(ctx, dnsserver.ServerInfo{})
+ ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{})
+ ctx = dnsserver.ContextWithStartTime(ctx, time.Now())
+ ctx = agd.ContextWithRequestInfo(ctx, &agd.RequestInfo{})
+
+ const ttl = 100
+
+ testCases := []struct {
+ name string
+ req *dns.Msg
+ resp *dns.Msg
+ wantName string
+ wantAns []dns.RR
+ }{{
+ name: "no_changes",
+ req: dnsservertest.CreateMessage("example.com.", dns.TypeA),
+ resp: resp,
+ wantName: "example.com.",
+ wantAns: nil,
+ }, {
+ name: "android-tls-metric",
+ req: dnsservertest.CreateMessage(
+ "12345678-dnsotls-ds.metric.gstatic.com.",
+ dns.TypeA,
+ ),
+ resp: resp,
+ wantName: "00000000-dnsotls-ds.metric.gstatic.com.",
+ wantAns: nil,
+ }, {
+ name: "android-https-metric",
+ req: dnsservertest.CreateMessage(
+ "123456-dnsohttps-ds.metric.gstatic.com.",
+ dns.TypeA,
+ ),
+ resp: resp,
+ wantName: "000000-dnsohttps-ds.metric.gstatic.com.",
+ wantAns: nil,
+ }, {
+ name: "multiple_answers_metric",
+ req: dnsservertest.CreateMessage(
+ "123456-dnsohttps-ds.metric.gstatic.com.",
+ dns.TypeA,
+ ),
+ resp: dnsservertest.NewResp(
+ dns.RcodeSuccess,
+ req,
+ dnsservertest.RRSection{
+ RRs: []dns.RR{dnsservertest.NewA(
+ "123456-dnsohttps-ds.metric.gstatic.com.",
+ ttl,
+ net.IP{1, 2, 3, 4},
+ ), dnsservertest.NewA(
+ "654321-dnsohttps-ds.metric.gstatic.com.",
+ ttl,
+ net.IP{1, 2, 3, 5},
+ )},
+ Sec: dnsservertest.SectionAnswer,
+ },
+ ),
+ wantName: "000000-dnsohttps-ds.metric.gstatic.com.",
+ wantAns: []dns.RR{
+ dnsservertest.NewA("123456-dnsohttps-ds.metric.gstatic.com.", ttl, net.IP{1, 2, 3, 4}),
+ dnsservertest.NewA("123456-dnsohttps-ds.metric.gstatic.com.", ttl, net.IP{1, 2, 3, 5}),
+ },
+ }}
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ handler := dnsserver.HandlerFunc(func(
+ ctx context.Context,
+ rw dnsserver.ResponseWriter,
+ req *dns.Msg,
+ ) error {
+ assert.Equal(t, tc.wantName, req.Question[0].Name)
+
+ return rw.WriteMsg(ctx, req, tc.resp)
+ })
+ h := mw.Wrap(handler)
+
+ rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr)
+
+ err := h.ServeDNS(ctx, rw, tc.req)
+ require.NoError(t, err)
+
+ msg := rw.Msg()
+ require.NotNil(t, msg)
+
+ assert.Equal(t, tc.wantAns, msg.Answer)
+ })
+ }
+}
diff --git a/internal/dnssvc/specialdomain.go b/internal/dnssvc/specialdomain.go
index 2c5d605..2803626 100644
--- a/internal/dnssvc/specialdomain.go
+++ b/internal/dnssvc/specialdomain.go
@@ -32,57 +32,23 @@ const (
// Resolvers for querying the resolver with unknown or absent name.
ddrDomain = ddrLabel + "." + resolverArpaDomain
- // firefoxCanaryFQDN is the fully-qualified canary domain that Firefox uses
- // to check if it should use its own DNS-over-HTTPS settings.
+ // firefoxCanaryHost is the hostname that Firefox uses to check if it
+ // should use its own DNS-over-HTTPS settings.
//
// See https://support.mozilla.org/en-US/kb/configuring-networks-disable-dns-over-https.
- firefoxCanaryFQDN = "use-application-dns.net."
-
- // applePrivateRelayMaskHost and applePrivateRelayMaskH2Host are the
- // hostnames that Apple devices use to check if Apple Private Relay can be
- // enabled. Returning NXDOMAIN to queries for these domain names blocks
- // Apple Private Relay.
- //
- // See https://developer.apple.com/support/prepare-your-network-for-icloud-private-relay.
- applePrivateRelayMaskHost = "mask.icloud.com"
- applePrivateRelayMaskH2Host = "mask-h2.icloud.com"
+ firefoxCanaryHost = "use-application-dns.net"
)
-// noReqInfoSpecialHandler returns a handler that can handle a special-domain
-// query based only on its question type, class, and target, as well as the
-// handler's name for debugging.
-func (mw *initMw) noReqInfoSpecialHandler(
- fqdn string,
- qt dnsmsg.RRType,
- cl dnsmsg.Class,
-) (f dnsserver.HandlerFunc, name string) {
- if cl != dns.ClassINET {
- return nil, ""
- }
-
- if (qt == dns.TypeA || qt == dns.TypeAAAA) && fqdn == firefoxCanaryFQDN {
- return mw.handleFirefoxCanary, "firefox"
- }
-
- return nil, ""
-}
-
-// Firefox Canary
-
-// handleFirefoxCanary checks if the request is for the fully-qualified domain
-// name that Firefox uses to check DoH settings and writes a response if needed.
-func (mw *initMw) handleFirefoxCanary(
- ctx context.Context,
- rw dnsserver.ResponseWriter,
- req *dns.Msg,
-) (err error) {
- metrics.DNSSvcFirefoxRequestsTotal.Inc()
-
- resp := mw.messages.NewMsgREFUSED(req)
- err = rw.WriteMsg(ctx, req, resp)
-
- return errors.Annotate(err, "writing firefox canary resp: %w")
-}
+// Hostnames that Apple devices use to check if Apple Private Relay can be
+// enabled. Returning NXDOMAIN to queries for these domain names blocks Apple
+// Private Relay.
+//
+// See https://developer.apple.com/support/prepare-your-network-for-icloud-private-relay.
+const (
+ applePrivateRelayMaskHost = "mask.icloud.com"
+ applePrivateRelayMaskH2Host = "mask-h2.icloud.com"
+ applePrivateRelayMaskCanaryHost = "mask-canary.icloud.com"
+)
// reqInfoSpecialHandler returns a handler that can handle a special-domain
// query based on the request info, as well as the handler's name for debugging.
@@ -99,11 +65,9 @@ func (mw *initMw) reqInfoSpecialHandler(
} else if netutil.IsSubdomain(ri.Host, resolverArpaDomain) {
// A badly formed resolver.arpa subdomain query.
return mw.handleBadResolverARPA, "bad_resolver_arpa"
- } else if shouldBlockPrivateRelay(ri) {
- return mw.handlePrivateRelay, "apple_private_relay"
}
- return nil, ""
+ return mw.specialDomainHandler(ri)
}
// reqInfoHandlerFunc is an alias for handler functions that additionally accept
@@ -222,24 +186,46 @@ func (mw *initMw) handleBadResolverARPA(
return errors.Annotate(err, "writing nodata resp for %q: %w", ri.Host)
}
+// specialDomainHandler returns a handler that can handle a special-domain
+// query for Apple Private Relay or Firefox canary domain based on the request
+// or profile information, as well as the handler's name for debugging.
+func (mw *initMw) specialDomainHandler(
+ ri *agd.RequestInfo,
+) (f reqInfoHandlerFunc, name string) {
+ qt := ri.QType
+ if qt != dns.TypeA && qt != dns.TypeAAAA {
+ return nil, ""
+ }
+
+ host := ri.Host
+ prof := ri.Profile
+
+ switch host {
+ case
+ applePrivateRelayMaskHost,
+ applePrivateRelayMaskH2Host,
+ applePrivateRelayMaskCanaryHost:
+ if shouldBlockPrivateRelay(ri, prof) {
+ return mw.handlePrivateRelay, "apple_private_relay"
+ }
+ case firefoxCanaryHost:
+ if shouldBlockFirefoxCanary(ri, prof) {
+ return mw.handleFirefoxCanary, "firefox"
+ }
+ default:
+ // Go on.
+ }
+
+ return nil, ""
+}
+
// Apple Private Relay
// shouldBlockPrivateRelay returns true if the query is for an Apple Private
-// Relay check domain and the request information indicates that Apple Private
-// Relay should be blocked.
-func shouldBlockPrivateRelay(ri *agd.RequestInfo) (ok bool) {
- qt := ri.QType
- host := ri.Host
-
- return (qt == dns.TypeA || qt == dns.TypeAAAA) &&
- (host == applePrivateRelayMaskHost || host == applePrivateRelayMaskH2Host) &&
- reqInfoShouldBlockPrivateRelay(ri)
-}
-
-// reqInfoShouldBlockPrivateRelay returns true if Apple Private Relay queries
-// should be blocked based on the request information.
-func reqInfoShouldBlockPrivateRelay(ri *agd.RequestInfo) (ok bool) {
- if prof := ri.Profile; prof != nil {
+// Relay check domain and the request information or profile indicates that
+// Apple Private Relay should be blocked.
+func shouldBlockPrivateRelay(ri *agd.RequestInfo, prof *agd.Profile) (ok bool) {
+ if prof != nil {
return prof.BlockPrivateRelay
}
@@ -260,3 +246,32 @@ func (mw *initMw) handlePrivateRelay(
return errors.Annotate(err, "writing private relay resp: %w")
}
+
+// Firefox canary domain
+
+// shouldBlockFirefoxCanary returns true if the query is for a Firefox canary
+// domain and the request information or profile indicates that Firefox canary
+// domain should be blocked.
+func shouldBlockFirefoxCanary(ri *agd.RequestInfo, prof *agd.Profile) (ok bool) {
+ if prof != nil {
+ return prof.BlockFirefoxCanary
+ }
+
+ return ri.FilteringGroup.BlockFirefoxCanary
+}
+
+// handleFirefoxCanary checks if the request is for the fully-qualified domain
+// name that Firefox uses to check DoH settings and writes a response if needed.
+func (mw *initMw) handleFirefoxCanary(
+ ctx context.Context,
+ rw dnsserver.ResponseWriter,
+ req *dns.Msg,
+ ri *agd.RequestInfo,
+) (err error) {
+ metrics.DNSSvcFirefoxRequestsTotal.Inc()
+
+ resp := mw.messages.NewMsgREFUSED(req)
+ err = rw.WriteMsg(ctx, req, resp)
+
+ return errors.Annotate(err, "writing firefox canary resp: %w")
+}
diff --git a/internal/ecscache/cache.go b/internal/ecscache/cache.go
index e59c900..860a504 100644
--- a/internal/ecscache/cache.go
+++ b/internal/ecscache/cache.go
@@ -10,6 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/mathutil"
"github.com/bluele/gcache"
"github.com/miekg/dns"
)
@@ -93,22 +94,16 @@ func (mw *Middleware) toCacheKey(cr *cacheRequest, hostHasECS bool) (key uint64)
var buf [6]byte
binary.LittleEndian.PutUint16(buf[:2], cr.qType)
binary.LittleEndian.PutUint16(buf[2:4], cr.qClass)
- if cr.reqDO {
- buf[4] = 1
- } else {
- buf[4] = 0
- }
- if cr.subnet.Addr().Is4() {
- buf[5] = 0
- } else {
- buf[5] = 1
- }
+ buf[4] = mathutil.BoolToNumber[byte](cr.reqDO)
+
+ addr := cr.subnet.Addr()
+ buf[5] = mathutil.BoolToNumber[byte](addr.Is6())
_, _ = h.Write(buf[:])
if hostHasECS {
- _, _ = h.Write(cr.subnet.Addr().AsSlice())
+ _, _ = h.Write(addr.AsSlice())
_ = h.WriteByte(byte(cr.subnet.Bits()))
}
diff --git a/internal/ecscache/msg.go b/internal/ecscache/msg.go
index 1730302..5e7611d 100644
--- a/internal/ecscache/msg.go
+++ b/internal/ecscache/msg.go
@@ -8,6 +8,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
+ "github.com/AdguardTeam/golibs/mathutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
@@ -269,9 +270,5 @@ func getTTLIfLower(r dns.RR, ttl uint32) (res uint32) {
// Go on.
}
- if httl := r.Header().Ttl; httl < ttl {
- return httl
- }
-
- return ttl
+ return mathutil.Min(r.Header().Ttl, ttl)
}
diff --git a/internal/errcoll/sentry.go b/internal/errcoll/sentry.go
index 26f7fdf..e4eb994 100644
--- a/internal/errcoll/sentry.go
+++ b/internal/errcoll/sentry.go
@@ -67,13 +67,13 @@ type SentryReportableError interface {
// TODO(a.garipov): Make sure that we use this approach everywhere.
func isReportable(err error) (ok bool) {
var (
- ravErr SentryReportableError
- fwdErr *forward.Error
- dnsWErr *dnsserver.WriteError
+ sentryRepErr SentryReportableError
+ fwdErr *forward.Error
+ dnsWErr *dnsserver.WriteError
)
- if errors.As(err, &ravErr) {
- return ravErr.IsSentryReportable()
+ if errors.As(err, &sentryRepErr) {
+ return sentryRepErr.IsSentryReportable()
} else if errors.As(err, &fwdErr) {
return isReportableNetwork(fwdErr.Err)
} else if errors.As(err, &dnsWErr) {
diff --git a/internal/filter/compfilter.go b/internal/filter/compfilter.go
index 090def6..cbfa521 100644
--- a/internal/filter/compfilter.go
+++ b/internal/filter/compfilter.go
@@ -21,8 +21,8 @@ var _ Interface = (*compFilter)(nil)
// compFilter is a composite filter based on several types of safe search
// filters and rule lists.
type compFilter struct {
- safeBrowsing *hashPrefixFilter
- adultBlocking *hashPrefixFilter
+ safeBrowsing *HashPrefix
+ adultBlocking *HashPrefix
genSafeSearch *safeSearch
ytSafeSearch *safeSearch
diff --git a/internal/filter/filter.go b/internal/filter/filter.go
index 6b8dfab..cdc402a 100644
--- a/internal/filter/filter.go
+++ b/internal/filter/filter.go
@@ -18,10 +18,11 @@ import (
// maxFilterSize is the maximum size of downloaded filters.
const maxFilterSize = 196 * int64(datasize.MB)
-// defaultTimeout is the default timeout to use when fetching filter data.
+// defaultFilterRefreshTimeout is the default timeout to use when fetching
+// filter lists data.
//
// TODO(a.garipov): Consider making timeouts where they are used configurable.
-const defaultTimeout = 30 * time.Second
+const defaultFilterRefreshTimeout = 180 * time.Second
// defaultResolveTimeout is the default timeout for resolving hosts for safe
// search and safe browsing filters.
diff --git a/internal/filter/filter_test.go b/internal/filter/filter_test.go
index 300c3d7..39383a0 100644
--- a/internal/filter/filter_test.go
+++ b/internal/filter/filter_test.go
@@ -181,24 +181,18 @@ func prepareConf(t testing.TB) (c *filter.DefaultStorageConfig) {
FilterIndexURL: fltsURL,
GeneralSafeSearchRulesURL: ssURL,
YoutubeSafeSearchRulesURL: ssURL,
- SafeBrowsing: &filter.HashPrefixConfig{
- CacheTTL: 1 * time.Hour,
- CacheSize: 100,
- },
- AdultBlocking: &filter.HashPrefixConfig{
- CacheTTL: 1 * time.Hour,
- CacheSize: 100,
- },
- Now: time.Now,
- ErrColl: nil,
- Resolver: nil,
- CacheDir: cacheDir,
- CustomFilterCacheSize: 100,
- SafeSearchCacheSize: 100,
- SafeSearchCacheTTL: 1 * time.Hour,
- RuleListCacheSize: 100,
- RefreshIvl: testRefreshIvl,
- UseRuleListCache: false,
+ SafeBrowsing: &filter.HashPrefix{},
+ AdultBlocking: &filter.HashPrefix{},
+ Now: time.Now,
+ ErrColl: nil,
+ Resolver: nil,
+ CacheDir: cacheDir,
+ CustomFilterCacheSize: 100,
+ SafeSearchCacheSize: 100,
+ SafeSearchCacheTTL: 1 * time.Hour,
+ RuleListCacheSize: 100,
+ RefreshIvl: testRefreshIvl,
+ UseRuleListCache: false,
}
}
diff --git a/internal/filter/hashprefix.go b/internal/filter/hashprefix.go
index 3bfea1c..3d0dd48 100644
--- a/internal/filter/hashprefix.go
+++ b/internal/filter/hashprefix.go
@@ -3,13 +3,17 @@ package filter
import (
"context"
"fmt"
+ "net/url"
"strings"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/resultcache"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
+ "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
@@ -21,13 +25,32 @@ import (
// HashPrefixConfig is the hash-prefix filter configuration structure.
type HashPrefixConfig struct {
// Hashes are the hostname hashes for this filter.
- Hashes *HashStorage
+ Hashes *hashstorage.Storage
+
+ // URL is the URL used to update the filter.
+ URL *url.URL
+
+ // ErrColl is used to collect non-critical and rare errors.
+ ErrColl agd.ErrorCollector
+
+ // Resolver is used to resolve hosts for the hash-prefix filter.
+ Resolver agdnet.Resolver
+
+ // ID is the ID of this hash storage for logging and error reporting.
+ ID agd.FilterListID
+
+ // CachePath is the path to the file containing the cached filtered
+ // hostnames, one per line.
+ CachePath string
// ReplacementHost is the replacement host for this filter. Queries
// matched by the filter receive a response with the IP addresses of
// this host.
ReplacementHost string
+ // Staleness is the time after which a file is considered stale.
+ Staleness time.Duration
+
// CacheTTL is the time-to-live value used to cache the results of the
// filter.
//
@@ -38,57 +61,58 @@ type HashPrefixConfig struct {
CacheSize int
}
-// hashPrefixFilter is a filter that matches hosts by their hashes based on
-// a hash-prefix table.
-type hashPrefixFilter struct {
- hashes *HashStorage
+// HashPrefix is a filter that matches hosts by their hashes based on a
+// hash-prefix table.
+type HashPrefix struct {
+ hashes *hashstorage.Storage
+ refr *refreshableFilter
resCache *resultcache.Cache[*ResultModified]
resolver agdnet.Resolver
errColl agd.ErrorCollector
repHost string
- id agd.FilterListID
}
-// newHashPrefixFilter returns a new hash-prefix filter. c must not be nil.
-func newHashPrefixFilter(
- c *HashPrefixConfig,
- resolver agdnet.Resolver,
- errColl agd.ErrorCollector,
- id agd.FilterListID,
-) (f *hashPrefixFilter) {
- f = &hashPrefixFilter{
- hashes: c.Hashes,
+// NewHashPrefix returns a new hash-prefix filter. c must not be nil.
+func NewHashPrefix(c *HashPrefixConfig) (f *HashPrefix, err error) {
+ f = &HashPrefix{
+ hashes: c.Hashes,
+ refr: &refreshableFilter{
+ http: agdhttp.NewClient(&agdhttp.ClientConfig{
+ Timeout: defaultFilterRefreshTimeout,
+ }),
+ url: c.URL,
+ id: c.ID,
+ cachePath: c.CachePath,
+ typ: "hash storage",
+ staleness: c.Staleness,
+ },
resCache: resultcache.New[*ResultModified](c.CacheSize),
- resolver: resolver,
- errColl: errColl,
+ resolver: c.Resolver,
+ errColl: c.ErrColl,
repHost: c.ReplacementHost,
- id: id,
}
- // Patch the refresh function of the hash storage, if there is one, to make
- // sure that we clear the result cache during every refresh.
- //
- // TODO(a.garipov): Create a better way to do that than this spaghetti-style
- // patching. Perhaps, lift the entire logic of hash-storage refresh into
- // hashPrefixFilter.
- if f.hashes != nil && f.hashes.refr != nil {
- resetRules := f.hashes.refr.resetRules
- f.hashes.refr.resetRules = func(text string) (err error) {
- f.resCache.Clear()
+ f.refr.resetRules = f.resetRules
- return resetRules(text)
- }
+ err = f.refresh(context.Background(), true)
+ if err != nil {
+ return nil, err
}
- return f
+ return f, nil
+}
+
+// id returns the ID of the hash storage.
+func (f *HashPrefix) id() (fltID agd.FilterListID) {
+ return f.refr.id
}
// type check
-var _ qtHostFilter = (*hashPrefixFilter)(nil)
+var _ qtHostFilter = (*HashPrefix)(nil)
// filterReq implements the qtHostFilter interface for *hashPrefixFilter. It
// modifies the response if host matches f.
-func (f *hashPrefixFilter) filterReq(
+func (f *HashPrefix) filterReq(
ctx context.Context,
ri *agd.RequestInfo,
req *dns.Msg,
@@ -115,7 +139,7 @@ func (f *hashPrefixFilter) filterReq(
var matched string
sub := hashableSubdomains(host)
for _, s := range sub {
- if f.hashes.hashMatches(s) {
+ if f.hashes.Matches(s) {
matched = s
break
@@ -134,19 +158,19 @@ func (f *hashPrefixFilter) filterReq(
var result *dns.Msg
ips, err := f.resolver.LookupIP(ctx, fam, f.repHost)
if err != nil {
- agd.Collectf(ctx, f.errColl, "filter %s: resolving: %w", f.id, err)
+ agd.Collectf(ctx, f.errColl, "filter %s: resolving: %w", f.id(), err)
result = ri.Messages.NewMsgSERVFAIL(req)
} else {
result, err = ri.Messages.NewIPRespMsg(req, ips...)
if err != nil {
- return nil, fmt.Errorf("filter %s: creating modified result: %w", f.id, err)
+ return nil, fmt.Errorf("filter %s: creating modified result: %w", f.id(), err)
}
}
rm = &ResultModified{
Msg: result,
- List: f.id,
+ List: f.id(),
Rule: agd.FilterRuleText(matched),
}
@@ -161,21 +185,21 @@ func (f *hashPrefixFilter) filterReq(
}
// updateCacheSizeMetrics updates cache size metrics.
-func (f *hashPrefixFilter) updateCacheSizeMetrics(size int) {
- switch f.id {
+func (f *HashPrefix) updateCacheSizeMetrics(size int) {
+ switch id := f.id(); id {
case agd.FilterListIDSafeBrowsing:
metrics.HashPrefixFilterSafeBrowsingCacheSize.Set(float64(size))
case agd.FilterListIDAdultBlocking:
metrics.HashPrefixFilterAdultBlockingCacheSize.Set(float64(size))
default:
- panic(fmt.Errorf("unsupported FilterListID %s", f.id))
+ panic(fmt.Errorf("unsupported FilterListID %s", id))
}
}
// updateCacheLookupsMetrics updates cache lookups metrics.
-func (f *hashPrefixFilter) updateCacheLookupsMetrics(hit bool) {
+func (f *HashPrefix) updateCacheLookupsMetrics(hit bool) {
var hitsMetric, missesMetric prometheus.Counter
- switch f.id {
+ switch id := f.id(); id {
case agd.FilterListIDSafeBrowsing:
hitsMetric = metrics.HashPrefixFilterCacheSafeBrowsingHits
missesMetric = metrics.HashPrefixFilterCacheSafeBrowsingMisses
@@ -183,7 +207,7 @@ func (f *hashPrefixFilter) updateCacheLookupsMetrics(hit bool) {
hitsMetric = metrics.HashPrefixFilterCacheAdultBlockingHits
missesMetric = metrics.HashPrefixFilterCacheAdultBlockingMisses
default:
- panic(fmt.Errorf("unsupported FilterListID %s", f.id))
+ panic(fmt.Errorf("unsupported FilterListID %s", id))
}
if hit {
@@ -194,12 +218,50 @@ func (f *hashPrefixFilter) updateCacheLookupsMetrics(hit bool) {
}
// name implements the qtHostFilter interface for *hashPrefixFilter.
-func (f *hashPrefixFilter) name() (n string) {
+func (f *HashPrefix) name() (n string) {
if f == nil {
return ""
}
- return string(f.id)
+ return string(f.id())
+}
+
+// type check
+var _ agd.Refresher = (*HashPrefix)(nil)
+
+// Refresh implements the [agd.Refresher] interface for *hashPrefixFilter.
+func (f *HashPrefix) Refresh(ctx context.Context) (err error) {
+ return f.refresh(ctx, false)
+}
+
+// refresh reloads the hash filter data. If acceptStale is true, do not try to
+// load the list from its URL when there is already a file in the cache
+// directory, regardless of its staleness.
+func (f *HashPrefix) refresh(ctx context.Context, acceptStale bool) (err error) {
+ return f.refr.refresh(ctx, acceptStale)
+}
+
+// resetRules resets the hosts in the index.
+func (f *HashPrefix) resetRules(text string) (err error) {
+ n, err := f.hashes.Reset(text)
+
+ // Report the filter update to prometheus.
+ promLabels := prometheus.Labels{
+ "filter": string(f.id()),
+ }
+
+ metrics.SetStatusGauge(metrics.FilterUpdatedStatus.With(promLabels), err)
+
+ if err != nil {
+ return err
+ }
+
+ metrics.FilterUpdatedTime.With(promLabels).SetToCurrentTime()
+ metrics.FilterRulesTotal.With(promLabels).Set(float64(n))
+
+ log.Info("filter %s: reset %d hosts", f.id(), n)
+
+ return nil
}
// subDomainNum defines how many labels should be hashed to match against a hash
diff --git a/internal/filter/hashprefix_test.go b/internal/filter/hashprefix_test.go
index e5e2c86..fac4d93 100644
--- a/internal/filter/hashprefix_test.go
+++ b/internal/filter/hashprefix_test.go
@@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
@@ -21,8 +22,10 @@ import (
func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) {
cacheDir := t.TempDir()
cachePath := filepath.Join(cacheDir, string(agd.FilterListIDSafeBrowsing))
- hosts := "scam.example.net\n"
- err := os.WriteFile(cachePath, []byte(hosts), 0o644)
+ err := os.WriteFile(cachePath, []byte(safeBrowsingHost+"\n"), 0o644)
+ require.NoError(t, err)
+
+ hashes, err := hashstorage.New("")
require.NoError(t, err)
errColl := &agdtest.ErrorCollector{
@@ -31,46 +34,33 @@ func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) {
},
}
- hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{
- URL: nil,
- ErrColl: errColl,
- ID: agd.FilterListIDSafeBrowsing,
- CachePath: cachePath,
- RefreshIvl: testRefreshIvl,
- })
- require.NoError(t, err)
-
- ctx := context.Background()
- err = hashes.Start()
- require.NoError(t, err)
- testutil.CleanupAndRequireSuccess(t, func() (err error) {
- return hashes.Shutdown(ctx)
- })
-
- // Fake Data
-
- onLookupIP := func(
- _ context.Context,
- _ netutil.AddrFamily,
- _ string,
- ) (ips []net.IP, err error) {
- return []net.IP{safeBrowsingSafeIP4}, nil
+ resolver := &agdtest.Resolver{
+ OnLookupIP: func(
+ _ context.Context,
+ _ netutil.AddrFamily,
+ _ string,
+ ) (ips []net.IP, err error) {
+ return []net.IP{safeBrowsingSafeIP4}, nil
+ },
}
c := prepareConf(t)
- c.SafeBrowsing = &filter.HashPrefixConfig{
+ c.SafeBrowsing, err = filter.NewHashPrefix(&filter.HashPrefixConfig{
Hashes: hashes,
+ ErrColl: errColl,
+ Resolver: resolver,
+ ID: agd.FilterListIDSafeBrowsing,
+ CachePath: cachePath,
ReplacementHost: safeBrowsingSafeHost,
+ Staleness: 1 * time.Hour,
CacheTTL: 10 * time.Second,
CacheSize: 100,
- }
+ })
+ require.NoError(t, err)
c.ErrColl = errColl
-
- c.Resolver = &agdtest.Resolver{
- OnLookupIP: onLookupIP,
- }
+ c.Resolver = resolver
s, err := filter.NewDefaultStorage(c)
require.NoError(t, err)
@@ -93,7 +83,7 @@ func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) {
}
ri := newReqInfo(g, nil, safeBrowsingSubHost, clientIP, dns.TypeA)
- ctx = agd.ContextWithRequestInfo(ctx, ri)
+ ctx := agd.ContextWithRequestInfo(context.Background(), ri)
f := s.FilterFromContext(ctx, ri)
require.NotNil(t, f)
diff --git a/internal/filter/hashstorage.go b/internal/filter/hashstorage.go
deleted file mode 100644
index 9586b45..0000000
--- a/internal/filter/hashstorage.go
+++ /dev/null
@@ -1,352 +0,0 @@
-package filter
-
-import (
- "bufio"
- "context"
- "crypto/sha256"
- "encoding/hex"
- "fmt"
- "net/url"
- "strings"
- "sync"
- "time"
-
- "github.com/AdguardTeam/AdGuardDNS/internal/agd"
- "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
- "github.com/AdguardTeam/AdGuardDNS/internal/metrics"
- "github.com/AdguardTeam/golibs/log"
- "github.com/AdguardTeam/golibs/stringutil"
- "github.com/prometheus/client_golang/prometheus"
-)
-
-// Hash Storage
-
-// Hash and hash part length constants.
-const (
- // hashPrefixLen is the length of the prefix of the hash of the filtered
- // hostname.
- hashPrefixLen = 2
-
- // HashPrefixEncLen is the encoded length of the hash prefix. Two text
- // bytes per one binary byte.
- HashPrefixEncLen = hashPrefixLen * 2
-
- // hashLen is the length of the whole hash of the checked hostname.
- hashLen = sha256.Size
-
- // hashSuffixLen is the length of the suffix of the hash of the filtered
- // hostname.
- hashSuffixLen = hashLen - hashPrefixLen
-
- // hashEncLen is the encoded length of the hash. Two text bytes per one
- // binary byte.
- hashEncLen = hashLen * 2
-
- // legacyHashPrefixEncLen is the encoded length of a legacy hash.
- legacyHashPrefixEncLen = 8
-)
-
-// hashPrefix is the type of the 2-byte prefix of a full 32-byte SHA256 hash of
-// a host being checked.
-type hashPrefix [hashPrefixLen]byte
-
-// hashSuffix is the type of the 30-byte suffix of a full 32-byte SHA256 hash of
-// a host being checked.
-type hashSuffix [hashSuffixLen]byte
-
-// HashStorage is a storage for hashes of the filtered hostnames.
-type HashStorage struct {
- // mu protects hashSuffixes.
- mu *sync.RWMutex
- hashSuffixes map[hashPrefix][]hashSuffix
-
- // refr contains data for refreshing the filter.
- refr *refreshableFilter
-
- refrWorker *agd.RefreshWorker
-}
-
-// HashStorageConfig is the configuration structure for a *HashStorage.
-type HashStorageConfig struct {
- // URL is the URL used to update the filter.
- URL *url.URL
-
- // ErrColl is used to collect non-critical and rare errors.
- ErrColl agd.ErrorCollector
-
- // ID is the ID of this hash storage for logging and error reporting.
- ID agd.FilterListID
-
- // CachePath is the path to the file containing the cached filtered
- // hostnames, one per line.
- CachePath string
-
- // RefreshIvl is the refresh interval.
- RefreshIvl time.Duration
-}
-
-// NewHashStorage returns a new hash storage containing hashes of all hostnames.
-func NewHashStorage(c *HashStorageConfig) (hs *HashStorage, err error) {
- hs = &HashStorage{
- mu: &sync.RWMutex{},
- hashSuffixes: map[hashPrefix][]hashSuffix{},
- refr: &refreshableFilter{
- http: agdhttp.NewClient(&agdhttp.ClientConfig{
- Timeout: defaultTimeout,
- }),
- url: c.URL,
- id: c.ID,
- cachePath: c.CachePath,
- typ: "hash storage",
- refreshIvl: c.RefreshIvl,
- },
- }
-
- // Do not set this in the literal above, since hs is nil there.
- hs.refr.resetRules = hs.resetHosts
-
- refrWorker := agd.NewRefreshWorker(&agd.RefreshWorkerConfig{
- Context: func() (ctx context.Context, cancel context.CancelFunc) {
- return context.WithTimeout(context.Background(), defaultTimeout)
- },
- Refresher: hs,
- ErrColl: c.ErrColl,
- Name: string(c.ID),
- Interval: c.RefreshIvl,
- RefreshOnShutdown: false,
- RoutineLogsAreDebug: false,
- })
-
- hs.refrWorker = refrWorker
-
- err = hs.refresh(context.Background(), true)
- if err != nil {
- return nil, fmt.Errorf("initializing %s: %w", c.ID, err)
- }
-
- return hs, nil
-}
-
-// hashes returns all hashes starting with the given prefixes, if any. The
-// resulting slice shares storage for all underlying strings.
-//
-// TODO(a.garipov): This currently doesn't take duplicates into account.
-func (hs *HashStorage) hashes(hps []hashPrefix) (hashes []string) {
- if len(hps) == 0 {
- return nil
- }
-
- hs.mu.RLock()
- defer hs.mu.RUnlock()
-
- // First, calculate the number of hashes to allocate the buffer.
- l := 0
- for _, hp := range hps {
- hashSufs := hs.hashSuffixes[hp]
- l += len(hashSufs)
- }
-
- // Then, allocate the buffer of the appropriate size and write all hashes
- // into one big buffer and slice it into separate strings to make the
- // garbage collector's work easier. This assumes that all references to
- // this buffer will become unreachable at the same time.
- //
- // The fact that we iterate over the map twice shouldn't matter, since we
- // assume that len(hps) will be below 5 most of the time.
- b := &strings.Builder{}
- b.Grow(l * hashEncLen)
-
- // Use a buffer and write the resulting buffer into b directly instead of
- // using hex.NewEncoder, because that seems to incur a significant
- // performance hit.
- var buf [hashEncLen]byte
- for _, hp := range hps {
- hashSufs := hs.hashSuffixes[hp]
- for _, suf := range hashSufs {
- // Slicing is safe here, since the contents of hp and suf are being
- // encoded.
-
- // nolint:looppointer
- hex.Encode(buf[:], hp[:])
- // nolint:looppointer
- hex.Encode(buf[HashPrefixEncLen:], suf[:])
- _, _ = b.Write(buf[:])
- }
- }
-
- s := b.String()
- hashes = make([]string, 0, l)
- for i := 0; i < l; i++ {
- hashes = append(hashes, s[i*hashEncLen:(i+1)*hashEncLen])
- }
-
- return hashes
-}
-
-// loadHashSuffixes returns hash suffixes for the given prefix. It is safe for
-// concurrent use.
-func (hs *HashStorage) loadHashSuffixes(hp hashPrefix) (sufs []hashSuffix, ok bool) {
- hs.mu.RLock()
- defer hs.mu.RUnlock()
-
- sufs, ok = hs.hashSuffixes[hp]
-
- return sufs, ok
-}
-
-// hashMatches returns true if the host matches one of the hashes.
-func (hs *HashStorage) hashMatches(host string) (ok bool) {
- sum := sha256.Sum256([]byte(host))
- hp := hashPrefix{sum[0], sum[1]}
-
- var buf [hashLen]byte
- hashSufs, ok := hs.loadHashSuffixes(hp)
- if !ok {
- return false
- }
-
- copy(buf[:], hp[:])
- for _, suf := range hashSufs {
- // Slicing is safe here, because we make a copy.
-
- // nolint:looppointer
- copy(buf[hashPrefixLen:], suf[:])
- if buf == sum {
- return true
- }
- }
-
- return false
-}
-
-// hashPrefixesFromStr returns hash prefixes from a dot-separated string.
-func hashPrefixesFromStr(prefixesStr string) (hashPrefixes []hashPrefix, err error) {
- if prefixesStr == "" {
- return nil, nil
- }
-
- prefixSet := stringutil.NewSet()
- prefixStrs := strings.Split(prefixesStr, ".")
- for _, s := range prefixStrs {
- if len(s) != HashPrefixEncLen {
- // Some legacy clients send eight-character hashes instead of
- // four-character ones. For now, remove the final four characters.
- //
- // TODO(a.garipov): Either remove this crutch or support such
- // prefixes better.
- if len(s) == legacyHashPrefixEncLen {
- s = s[:HashPrefixEncLen]
- } else {
- return nil, fmt.Errorf("bad hash len for %q", s)
- }
- }
-
- prefixSet.Add(s)
- }
-
- hashPrefixes = make([]hashPrefix, prefixSet.Len())
- prefixStrs = prefixSet.Values()
- for i, s := range prefixStrs {
- _, err = hex.Decode(hashPrefixes[i][:], []byte(s))
- if err != nil {
- return nil, fmt.Errorf("bad hash encoding for %q", s)
- }
- }
-
- return hashPrefixes, nil
-}
-
-// type check
-var _ agd.Refresher = (*HashStorage)(nil)
-
-// Refresh implements the agd.Refresher interface for *HashStorage. If the file
-// at the storage's path exists and its mtime shows that it's still fresh, it
-// loads the data from the file. Otherwise, it uses the URL of the storage.
-func (hs *HashStorage) Refresh(ctx context.Context) (err error) {
- err = hs.refresh(ctx, false)
-
- // Report the filter update to prometheus.
- promLabels := prometheus.Labels{
- "filter": string(hs.id()),
- }
-
- metrics.SetStatusGauge(metrics.FilterUpdatedStatus.With(promLabels), err)
-
- if err == nil {
- metrics.FilterUpdatedTime.With(promLabels).SetToCurrentTime()
-
- // Count the total number of hashes loaded.
- count := 0
- for _, v := range hs.hashSuffixes {
- count += len(v)
- }
-
- metrics.FilterRulesTotal.With(promLabels).Set(float64(count))
- }
-
- return err
-}
-
-// id returns the ID of the hash storage.
-func (hs *HashStorage) id() (fltID agd.FilterListID) {
- return hs.refr.id
-}
-
-// refresh reloads the hash filter data. If acceptStale is true, do not try to
-// load the list from its URL when there is already a file in the cache
-// directory, regardless of its staleness.
-func (hs *HashStorage) refresh(ctx context.Context, acceptStale bool) (err error) {
- return hs.refr.refresh(ctx, acceptStale)
-}
-
-// resetHosts resets the hosts in the index.
-func (hs *HashStorage) resetHosts(hostsStr string) (err error) {
- hs.mu.Lock()
- defer hs.mu.Unlock()
-
- // Delete all elements without allocating a new map to safe space and
- // performance.
- //
- // This is optimized, see https://github.com/golang/go/issues/20138.
- for hp := range hs.hashSuffixes {
- delete(hs.hashSuffixes, hp)
- }
-
- var n int
- s := bufio.NewScanner(strings.NewReader(hostsStr))
- for s.Scan() {
- host := s.Text()
- if len(host) == 0 || host[0] == '#' {
- continue
- }
-
- sum := sha256.Sum256([]byte(host))
- hp := hashPrefix{sum[0], sum[1]}
-
- // TODO(a.garipov): Convert to array directly when proposal
- // golang/go#46505 is implemented in Go 1.20.
- suf := *(*hashSuffix)(sum[hashPrefixLen:])
- hs.hashSuffixes[hp] = append(hs.hashSuffixes[hp], suf)
-
- n++
- }
-
- err = s.Err()
- if err != nil {
- return fmt.Errorf("scanning hosts: %w", err)
- }
-
- log.Info("filter %s: reset %d hosts", hs.id(), n)
-
- return nil
-}
-
-// Start implements the agd.Service interface for *HashStorage.
-func (hs *HashStorage) Start() (err error) {
- return hs.refrWorker.Start()
-}
-
-// Shutdown implements the agd.Service interface for *HashStorage.
-func (hs *HashStorage) Shutdown(ctx context.Context) (err error) {
- return hs.refrWorker.Shutdown(ctx)
-}
diff --git a/internal/filter/hashstorage/hashstorage.go b/internal/filter/hashstorage/hashstorage.go
new file mode 100644
index 0000000..8234279
--- /dev/null
+++ b/internal/filter/hashstorage/hashstorage.go
@@ -0,0 +1,203 @@
+// Package hashstorage defines a storage of hashes of domain names used for
+// filtering.
+package hashstorage
+
+import (
+ "bufio"
+ "crypto/sha256"
+ "encoding/hex"
+ "fmt"
+ "strings"
+ "sync"
+)
+
+// Hash and hash part length constants.
+const (
+ // PrefixLen is the length of the prefix of the hash of the filtered
+ // hostname.
+ PrefixLen = 2
+
+ // PrefixEncLen is the encoded length of the hash prefix. Two text
+ // bytes per one binary byte.
+ PrefixEncLen = PrefixLen * 2
+
+ // hashLen is the length of the whole hash of the checked hostname.
+ hashLen = sha256.Size
+
+ // suffixLen is the length of the suffix of the hash of the filtered
+ // hostname.
+ suffixLen = hashLen - PrefixLen
+
+ // hashEncLen is the encoded length of the hash. Two text bytes per one
+ // binary byte.
+ hashEncLen = hashLen * 2
+)
+
+// Prefix is the type of the 2-byte prefix of a full 32-byte SHA256 hash of a
+// host being checked.
+type Prefix [PrefixLen]byte
+
+// suffix is the type of the 30-byte suffix of a full 32-byte SHA256 hash of a
+// host being checked.
+type suffix [suffixLen]byte
+
+// Storage stores hashes of the filtered hostnames. All methods are safe for
+// concurrent use.
+type Storage struct {
+ // mu protects hashSuffixes.
+ mu *sync.RWMutex
+ hashSuffixes map[Prefix][]suffix
+}
+
+// New returns a new hash storage containing hashes of the domain names listed
+// in hostnames, one domain name per line.
+func New(hostnames string) (s *Storage, err error) {
+ s = &Storage{
+ mu: &sync.RWMutex{},
+ hashSuffixes: map[Prefix][]suffix{},
+ }
+
+ if hostnames != "" {
+ _, err = s.Reset(hostnames)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return s, nil
+}
+
+// Hashes returns all hashes starting with the given prefixes, if any. The
+// resulting slice shares storage for all underlying strings.
+//
+// TODO(a.garipov): This currently doesn't take duplicates into account.
+func (s *Storage) Hashes(hps []Prefix) (hashes []string) {
+ if len(hps) == 0 {
+ return nil
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ // First, calculate the number of hashes to allocate the buffer.
+ l := 0
+ for _, hp := range hps {
+ hashSufs := s.hashSuffixes[hp]
+ l += len(hashSufs)
+ }
+
+ // Then, allocate the buffer of the appropriate size and write all hashes
+ // into one big buffer and slice it into separate strings to make the
+ // garbage collector's work easier. This assumes that all references to
+ // this buffer will become unreachable at the same time.
+ //
+ // The fact that we iterate over the [s.hashSuffixes] map twice shouldn't
+ // matter, since we assume that len(hps) will be below 5 most of the time.
+ b := &strings.Builder{}
+ b.Grow(l * hashEncLen)
+
+ // Use a buffer and write the resulting buffer into b directly instead of
+ // using hex.NewEncoder, because that seems to incur a significant
+ // performance hit.
+ var buf [hashEncLen]byte
+ for _, hp := range hps {
+ hashSufs := s.hashSuffixes[hp]
+ for _, suf := range hashSufs {
+ // Slicing is safe here, since the contents of hp and suf are being
+ // encoded.
+
+ // nolint:looppointer
+ hex.Encode(buf[:], hp[:])
+ // nolint:looppointer
+ hex.Encode(buf[PrefixEncLen:], suf[:])
+ _, _ = b.Write(buf[:])
+ }
+ }
+
+ str := b.String()
+ hashes = make([]string, 0, l)
+ for i := 0; i < l; i++ {
+ hashes = append(hashes, str[i*hashEncLen:(i+1)*hashEncLen])
+ }
+
+ return hashes
+}
+
+// Matches returns true if the host matches one of the hashes.
+func (s *Storage) Matches(host string) (ok bool) {
+ sum := sha256.Sum256([]byte(host))
+ hp := *(*Prefix)(sum[:PrefixLen])
+
+ var buf [hashLen]byte
+ hashSufs, ok := s.loadHashSuffixes(hp)
+ if !ok {
+ return false
+ }
+
+ copy(buf[:], hp[:])
+ for _, suf := range hashSufs {
+ // Slicing is safe here, because we make a copy.
+
+ // nolint:looppointer
+ copy(buf[PrefixLen:], suf[:])
+ if buf == sum {
+ return true
+ }
+ }
+
+ return false
+}
+
+// Reset resets the hosts in the index using the domain names listed in
+// hostnames, one domain name per line, and returns the total number of
+// processed rules.
+func (s *Storage) Reset(hostnames string) (n int, err error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // Delete all elements without allocating a new map to save space and
+ // improve performance.
+ //
+ // This is optimized, see https://github.com/golang/go/issues/20138.
+ //
+ // TODO(a.garipov): Use clear once golang/go#56351 is implemented.
+ for hp := range s.hashSuffixes {
+ delete(s.hashSuffixes, hp)
+ }
+
+ sc := bufio.NewScanner(strings.NewReader(hostnames))
+ for sc.Scan() {
+ host := sc.Text()
+ if len(host) == 0 || host[0] == '#' {
+ continue
+ }
+
+ sum := sha256.Sum256([]byte(host))
+ hp := *(*Prefix)(sum[:PrefixLen])
+
+ // TODO(a.garipov): Here and everywhere, convert to array directly when
+ // proposal golang/go#46505 is implemented in Go 1.20.
+ suf := *(*suffix)(sum[PrefixLen:])
+ s.hashSuffixes[hp] = append(s.hashSuffixes[hp], suf)
+
+ n++
+ }
+
+ err = sc.Err()
+ if err != nil {
+ return 0, fmt.Errorf("scanning hosts: %w", err)
+ }
+
+ return n, nil
+}
+
+// loadHashSuffixes returns hash suffixes for the given prefix. It is safe for
+// concurrent use. sufs must not be modified.
+func (s *Storage) loadHashSuffixes(hp Prefix) (sufs []suffix, ok bool) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ sufs, ok = s.hashSuffixes[hp]
+
+ return sufs, ok
+}
diff --git a/internal/filter/hashstorage/hashstorage_test.go b/internal/filter/hashstorage/hashstorage_test.go
new file mode 100644
index 0000000..944d304
--- /dev/null
+++ b/internal/filter/hashstorage/hashstorage_test.go
@@ -0,0 +1,142 @@
+package hashstorage_test
+
+import (
+ "crypto/sha256"
+ "encoding/hex"
+ "fmt"
+ "strconv"
+ "strings"
+ "testing"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// Common hostnames for tests.
+const (
+ testHost = "porn.example"
+ otherHost = "otherporn.example"
+)
+
+func TestStorage_Hashes(t *testing.T) {
+ s, err := hashstorage.New(testHost)
+ require.NoError(t, err)
+
+ h := sha256.Sum256([]byte(testHost))
+ want := []string{hex.EncodeToString(h[:])}
+
+ p := hashstorage.Prefix{h[0], h[1]}
+ got := s.Hashes([]hashstorage.Prefix{p})
+ assert.Equal(t, want, got)
+
+ wrong := s.Hashes([]hashstorage.Prefix{{}})
+ assert.Empty(t, wrong)
+}
+
+func TestStorage_Matches(t *testing.T) {
+ s, err := hashstorage.New(testHost)
+ require.NoError(t, err)
+
+ got := s.Matches(testHost)
+ assert.True(t, got)
+
+ got = s.Matches(otherHost)
+ assert.False(t, got)
+}
+
+func TestStorage_Reset(t *testing.T) {
+ s, err := hashstorage.New(testHost)
+ require.NoError(t, err)
+
+ n, err := s.Reset(otherHost)
+ require.NoError(t, err)
+
+ assert.Equal(t, 1, n)
+
+ h := sha256.Sum256([]byte(otherHost))
+ want := []string{hex.EncodeToString(h[:])}
+
+ p := hashstorage.Prefix{h[0], h[1]}
+ got := s.Hashes([]hashstorage.Prefix{p})
+ assert.Equal(t, want, got)
+
+ prevHash := sha256.Sum256([]byte(testHost))
+ prev := s.Hashes([]hashstorage.Prefix{{prevHash[0], prevHash[1]}})
+ assert.Empty(t, prev)
+}
+
+// Sinks for benchmarks.
+var (
+ errSink error
+ strsSink []string
+)
+
+func BenchmarkStorage_Hashes(b *testing.B) {
+ const N = 10_000
+
+ var hosts []string
+ for i := 0; i < N; i++ {
+ hosts = append(hosts, fmt.Sprintf("%d."+testHost, i))
+ }
+
+ s, err := hashstorage.New(strings.Join(hosts, "\n"))
+ require.NoError(b, err)
+
+ var hashPrefixes []hashstorage.Prefix
+ for i := 0; i < 4; i++ {
+ hashPrefixes = append(hashPrefixes, hashstorage.Prefix{hosts[i][0], hosts[i][1]})
+ }
+
+ for n := 1; n <= 4; n++ {
+ b.Run(strconv.FormatInt(int64(n), 10), func(b *testing.B) {
+ hps := hashPrefixes[:n]
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ strsSink = s.Hashes(hps)
+ }
+ })
+ }
+
+ // Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
+ //
+ // goos: linux
+ // goarch: amd64
+ // pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage
+ // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
+ // BenchmarkStorage_Hashes/1-16 29928834 41.76 ns/op 0 B/op 0 allocs/op
+ // BenchmarkStorage_Hashes/2-16 18693033 63.80 ns/op 0 B/op 0 allocs/op
+ // BenchmarkStorage_Hashes/3-16 13492526 92.22 ns/op 0 B/op 0 allocs/op
+ // BenchmarkStorage_Hashes/4-16 9542425 109.2 ns/op 0 B/op 0 allocs/op
+}
+
+func BenchmarkStorage_ResetHosts(b *testing.B) {
+ const N = 1_000
+
+ var hosts []string
+ for i := 0; i < N; i++ {
+ hosts = append(hosts, fmt.Sprintf("%d."+testHost, i))
+ }
+
+ hostnames := strings.Join(hosts, "\n")
+ s, err := hashstorage.New(hostnames)
+ require.NoError(b, err)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, errSink = s.Reset(hostnames)
+ }
+
+ require.NoError(b, errSink)
+
+ // Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
+ //
+ // goos: linux
+ // goarch: amd64
+ // pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage
+ // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
+ // BenchmarkStorage_ResetHosts-16 2212 469343 ns/op 36224 B/op 1002 allocs/op
+}
diff --git a/internal/filter/hashstorage_internal_test.go b/internal/filter/hashstorage_internal_test.go
deleted file mode 100644
index db61dce..0000000
--- a/internal/filter/hashstorage_internal_test.go
+++ /dev/null
@@ -1,91 +0,0 @@
-package filter
-
-import (
- "fmt"
- "strconv"
- "strings"
- "sync"
- "testing"
-
- "github.com/stretchr/testify/require"
-)
-
-var strsSink []string
-
-func BenchmarkHashStorage_Hashes(b *testing.B) {
- const N = 10_000
-
- var hosts []string
- for i := 0; i < N; i++ {
- hosts = append(hosts, fmt.Sprintf("%d.porn.example.com", i))
- }
-
- // Don't use a constructor, since we don't need the whole contents of the
- // storage.
- //
- // TODO(a.garipov): Think of a better way to do this.
- hs := &HashStorage{
- mu: &sync.RWMutex{},
- hashSuffixes: map[hashPrefix][]hashSuffix{},
- refr: &refreshableFilter{id: "test_filter"},
- }
-
- err := hs.resetHosts(strings.Join(hosts, "\n"))
- require.NoError(b, err)
-
- var hashPrefixes []hashPrefix
- for i := 0; i < 4; i++ {
- hashPrefixes = append(hashPrefixes, hashPrefix{hosts[i][0], hosts[i][1]})
- }
-
- for n := 1; n <= 4; n++ {
- b.Run(strconv.FormatInt(int64(n), 10), func(b *testing.B) {
- hps := hashPrefixes[:n]
-
- b.ReportAllocs()
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- strsSink = hs.hashes(hps)
- }
- })
- }
-}
-
-func BenchmarkHashStorage_resetHosts(b *testing.B) {
- const N = 1_000
-
- var hosts []string
- for i := 0; i < N; i++ {
- hosts = append(hosts, fmt.Sprintf("%d.porn.example.com", i))
- }
-
- // Don't use a constructor, since we don't need the whole contents of the
- // storage.
- //
- // TODO(a.garipov): Think of a better way to do this.
- hs := &HashStorage{
- mu: &sync.RWMutex{},
- hashSuffixes: map[hashPrefix][]hashSuffix{},
- refr: &refreshableFilter{id: "test_filter"},
- }
-
- // Reset them once to fill the initial map.
- err := hs.resetHosts(strings.Join(hosts, "\n"))
- require.NoError(b, err)
-
- b.ReportAllocs()
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- err = hs.resetHosts(strings.Join(hosts, "\n"))
- }
-
- require.NoError(b, err)
-
- // Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
- //
- // goos: linux
- // goarch: amd64
- // pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter
- // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
- // BenchmarkHashStorage_resetHosts-16 1404 829505 ns/op 58289 B/op 1011 allocs/op
-}
diff --git a/internal/filter/internal/resultcache/resultcache.go b/internal/filter/internal/resultcache/resultcache.go
index ba5a2db..f230be5 100644
--- a/internal/filter/internal/resultcache/resultcache.go
+++ b/internal/filter/internal/resultcache/resultcache.go
@@ -8,6 +8,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/mathutil"
"github.com/bluele/gcache"
)
@@ -93,13 +94,9 @@ func DefaultKey(host string, qt dnsmsg.RRType, isAns bool) (k Key) {
// Save on allocations by reusing a buffer.
var buf [3]byte
binary.LittleEndian.PutUint16(buf[:2], qt)
- if isAns {
- buf[2] = 1
- } else {
- buf[2] = 0
- }
+ buf[2] = mathutil.BoolToNumber[byte](isAns)
- _, _ = h.Write(buf[:3])
+ _, _ = h.Write(buf[:])
return Key(h.Sum64())
}
diff --git a/internal/filter/refrfilter.go b/internal/filter/refrfilter.go
index 55ec1c4..f776266 100644
--- a/internal/filter/refrfilter.go
+++ b/internal/filter/refrfilter.go
@@ -45,9 +45,8 @@ type refreshableFilter struct {
// typ is the type of this filter used for logging and error reporting.
typ string
- // refreshIvl is the refresh interval for this filter. It is also used to
- // check if the cached file is fresh enough.
- refreshIvl time.Duration
+ // staleness is the time after which a file is considered stale.
+ staleness time.Duration
}
// refresh reloads the filter data. If acceptStale is true, refresh doesn't try
@@ -118,7 +117,7 @@ func (f *refreshableFilter) refreshFromFile(
return "", fmt.Errorf("reading filter file stat: %w", err)
}
- if mtime := fi.ModTime(); !mtime.Add(f.refreshIvl).After(time.Now()) {
+ if mtime := fi.ModTime(); !mtime.Add(f.staleness).After(time.Now()) {
return "", nil
}
}
diff --git a/internal/filter/refrfilter_internal_test.go b/internal/filter/refrfilter_internal_test.go
index e989af0..8f33eb2 100644
--- a/internal/filter/refrfilter_internal_test.go
+++ b/internal/filter/refrfilter_internal_test.go
@@ -30,43 +30,43 @@ func TestRefreshableFilter_RefreshFromFile(t *testing.T) {
name string
cachePath string
wantText string
- refreshIvl time.Duration
+ staleness time.Duration
acceptStale bool
}{{
name: "no_file",
cachePath: "does_not_exist",
wantText: "",
- refreshIvl: 0,
+ staleness: 0,
acceptStale: true,
}, {
name: "file",
cachePath: cachePath,
wantText: defaultText,
- refreshIvl: 0,
+ staleness: 0,
acceptStale: true,
}, {
name: "file_stale",
cachePath: cachePath,
wantText: "",
- refreshIvl: -1 * time.Second,
+ staleness: -1 * time.Second,
acceptStale: false,
}, {
name: "file_stale_accept",
cachePath: cachePath,
wantText: defaultText,
- refreshIvl: -1 * time.Second,
+ staleness: -1 * time.Second,
acceptStale: true,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
f := &refreshableFilter{
- http: nil,
- url: nil,
- id: "test_filter",
- cachePath: tc.cachePath,
- typ: "test filter",
- refreshIvl: tc.refreshIvl,
+ http: nil,
+ url: nil,
+ id: "test_filter",
+ cachePath: tc.cachePath,
+ typ: "test filter",
+ staleness: tc.staleness,
}
var text string
@@ -161,12 +161,12 @@ func TestRefreshableFilter_RefreshFromURL(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
f := &refreshableFilter{
- http: httpCli,
- url: u,
- id: "test_filter",
- cachePath: tc.cachePath,
- typ: "test filter",
- refreshIvl: testTimeout,
+ http: httpCli,
+ url: u,
+ id: "test_filter",
+ cachePath: tc.cachePath,
+ typ: "test filter",
+ staleness: testTimeout,
}
if tc.expectReq {
diff --git a/internal/filter/rulelist.go b/internal/filter/rulelist.go
index 367d00c..ba67be2 100644
--- a/internal/filter/rulelist.go
+++ b/internal/filter/rulelist.go
@@ -72,13 +72,13 @@ func newRuleListFilter(
mu: &sync.RWMutex{},
refr: &refreshableFilter{
http: agdhttp.NewClient(&agdhttp.ClientConfig{
- Timeout: defaultTimeout,
+ Timeout: defaultFilterRefreshTimeout,
}),
- url: l.URL,
- id: l.ID,
- cachePath: filepath.Join(fileCacheDir, string(l.ID)),
- typ: "rule list",
- refreshIvl: l.RefreshIvl,
+ url: l.URL,
+ id: l.ID,
+ cachePath: filepath.Join(fileCacheDir, string(l.ID)),
+ typ: "rule list",
+ staleness: l.RefreshIvl,
},
urlFilterID: newURLFilterID(),
}
diff --git a/internal/filter/safebrowsing.go b/internal/filter/safebrowsing.go
index a37fdec..2087e0a 100644
--- a/internal/filter/safebrowsing.go
+++ b/internal/filter/safebrowsing.go
@@ -2,9 +2,13 @@ package filter
import (
"context"
+ "encoding/hex"
+ "fmt"
"strings"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/AdguardTeam/golibs/log"
+ "github.com/AdguardTeam/golibs/stringutil"
)
// Safe Browsing TXT Record Server
@@ -14,12 +18,12 @@ import (
//
// TODO(a.garipov): Consider making an interface to simplify testing.
type SafeBrowsingServer struct {
- generalHashes *HashStorage
- adultBlockingHashes *HashStorage
+ generalHashes *hashstorage.Storage
+ adultBlockingHashes *hashstorage.Storage
}
// NewSafeBrowsingServer returns a new safe browsing DNS server.
-func NewSafeBrowsingServer(general, adultBlocking *HashStorage) (f *SafeBrowsingServer) {
+func NewSafeBrowsingServer(general, adultBlocking *hashstorage.Storage) (f *SafeBrowsingServer) {
return &SafeBrowsingServer{
generalHashes: general,
adultBlockingHashes: adultBlocking,
@@ -48,8 +52,7 @@ func (srv *SafeBrowsingServer) Hashes(
}
var prefixesStr string
- var strg *HashStorage
-
+ var strg *hashstorage.Storage
if strings.HasSuffix(host, GeneralTXTSuffix) {
prefixesStr = host[:len(host)-len(GeneralTXTSuffix)]
strg = srv.generalHashes
@@ -67,5 +70,45 @@ func (srv *SafeBrowsingServer) Hashes(
return nil, false, err
}
- return strg.hashes(hashPrefixes), true, nil
+ return strg.Hashes(hashPrefixes), true, nil
+}
+
+// legacyPrefixEncLen is the encoded length of a legacy hash.
+const legacyPrefixEncLen = 8
+
+// hashPrefixesFromStr returns hash prefixes from a dot-separated string.
+func hashPrefixesFromStr(prefixesStr string) (hashPrefixes []hashstorage.Prefix, err error) {
+ if prefixesStr == "" {
+ return nil, nil
+ }
+
+ prefixSet := stringutil.NewSet()
+ prefixStrs := strings.Split(prefixesStr, ".")
+ for _, s := range prefixStrs {
+ if len(s) != hashstorage.PrefixEncLen {
+ // Some legacy clients send eight-character hashes instead of
+ // four-character ones. For now, remove the final four characters.
+ //
+ // TODO(a.garipov): Either remove this crutch or support such
+ // prefixes better.
+ if len(s) == legacyPrefixEncLen {
+ s = s[:hashstorage.PrefixEncLen]
+ } else {
+ return nil, fmt.Errorf("bad hash len for %q", s)
+ }
+ }
+
+ prefixSet.Add(s)
+ }
+
+ hashPrefixes = make([]hashstorage.Prefix, prefixSet.Len())
+ prefixStrs = prefixSet.Values()
+ for i, s := range prefixStrs {
+ _, err = hex.Decode(hashPrefixes[i][:], []byte(s))
+ if err != nil {
+ return nil, fmt.Errorf("bad hash encoding for %q", s)
+ }
+ }
+
+ return hashPrefixes, nil
}
diff --git a/internal/filter/safebrowsing_test.go b/internal/filter/safebrowsing_test.go
index 0c811a9..55afa30 100644
--- a/internal/filter/safebrowsing_test.go
+++ b/internal/filter/safebrowsing_test.go
@@ -4,17 +4,11 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
- "net/url"
- "os"
- "path/filepath"
"strings"
"testing"
- "time"
- "github.com/AdguardTeam/AdGuardDNS/internal/agd"
- "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
- "github.com/AdguardTeam/golibs/testutil"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -43,41 +37,10 @@ func TestSafeBrowsingServer(t *testing.T) {
hashStrs[i] = hex.EncodeToString(sum[:])
}
- // Hash Storage
-
- errColl := &agdtest.ErrorCollector{
- OnCollect: func(_ context.Context, err error) {
- panic("not implemented")
- },
- }
-
- cacheDir := t.TempDir()
- cachePath := filepath.Join(cacheDir, string(agd.FilterListIDSafeBrowsing))
- err := os.WriteFile(cachePath, []byte(strings.Join(hosts, "\n")), 0o644)
- require.NoError(t, err)
-
- hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{
- URL: &url.URL{},
- ErrColl: errColl,
- ID: agd.FilterListIDSafeBrowsing,
- CachePath: cachePath,
- RefreshIvl: testRefreshIvl,
- })
+ hashes, err := hashstorage.New(strings.Join(hosts, "\n"))
require.NoError(t, err)
ctx := context.Background()
-
- err = hashes.Start()
- require.NoError(t, err)
- testutil.CleanupAndRequireSuccess(t, func() (err error) {
- return hashes.Shutdown(ctx)
- })
-
- // Give the storage some time to process the hashes.
- //
- // TODO(a.garipov): Think of a less stupid way of doing this.
- time.Sleep(100 * time.Millisecond)
-
testCases := []struct {
name string
host string
@@ -90,14 +53,14 @@ func TestSafeBrowsingServer(t *testing.T) {
wantMatched: false,
}, {
name: "realistic",
- host: hashStrs[realisticHostIdx][:filter.HashPrefixEncLen] + filter.GeneralTXTSuffix,
+ host: hashStrs[realisticHostIdx][:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix,
wantHashStrs: []string{
hashStrs[realisticHostIdx],
},
wantMatched: true,
}, {
name: "same_prefix",
- host: hashStrs[samePrefixHost1Idx][:filter.HashPrefixEncLen] + filter.GeneralTXTSuffix,
+ host: hashStrs[samePrefixHost1Idx][:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix,
wantHashStrs: []string{
hashStrs[samePrefixHost1Idx],
hashStrs[samePrefixHost2Idx],
diff --git a/internal/filter/safesearch_test.go b/internal/filter/safesearch_test.go
index dafaf4e..a295e81 100644
--- a/internal/filter/safesearch_test.go
+++ b/internal/filter/safesearch_test.go
@@ -3,9 +3,7 @@ package filter_test
import (
"context"
"net"
- "os"
"testing"
- "time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
@@ -20,45 +18,29 @@ import (
func TestStorage_FilterFromContext_safeSearch(t *testing.T) {
numLookupIP := 0
- onLookupIP := func(
- _ context.Context,
- fam netutil.AddrFamily,
- _ string,
- ) (ips []net.IP, err error) {
- numLookupIP++
+ resolver := &agdtest.Resolver{
+ OnLookupIP: func(
+ _ context.Context,
+ fam netutil.AddrFamily,
+ _ string,
+ ) (ips []net.IP, err error) {
+ numLookupIP++
- if fam == netutil.AddrFamilyIPv4 {
- return []net.IP{safeSearchIPRespIP4}, nil
- }
+ if fam == netutil.AddrFamilyIPv4 {
+ return []net.IP{safeSearchIPRespIP4}, nil
+ }
- return []net.IP{safeSearchIPRespIP6}, nil
+ return []net.IP{safeSearchIPRespIP6}, nil
+ },
}
- tmpFile, err := os.CreateTemp(t.TempDir(), "")
- require.NoError(t, err)
-
- _, err = tmpFile.Write([]byte("bad.example.com\n"))
- require.NoError(t, err)
- testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
-
- hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{
- CachePath: tmpFile.Name(),
- RefreshIvl: 1 * time.Hour,
- })
- require.NoError(t, err)
-
c := prepareConf(t)
- c.SafeBrowsing.Hashes = hashes
- c.AdultBlocking.Hashes = hashes
-
c.ErrColl = &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, err error) { panic("not implemented") },
}
- c.Resolver = &agdtest.Resolver{
- OnLookupIP: onLookupIP,
- }
+ c.Resolver = resolver
s, err := filter.NewDefaultStorage(c)
require.NoError(t, err)
@@ -80,13 +62,13 @@ func TestStorage_FilterFromContext_safeSearch(t *testing.T) {
host: safeSearchIPHost,
wantIP: safeSearchIPRespIP4,
rrtype: dns.TypeA,
- wantLookups: 0,
+ wantLookups: 1,
}, {
name: "ip6",
host: safeSearchIPHost,
- wantIP: nil,
+ wantIP: safeSearchIPRespIP6,
rrtype: dns.TypeAAAA,
- wantLookups: 0,
+ wantLookups: 1,
}, {
name: "host_ip4",
host: safeSearchHost,
diff --git a/internal/filter/serviceblocker.go b/internal/filter/serviceblocker.go
index fa35448..04cff42 100644
--- a/internal/filter/serviceblocker.go
+++ b/internal/filter/serviceblocker.go
@@ -48,7 +48,7 @@ func newServiceBlocker(indexURL *url.URL, errColl agd.ErrorCollector) (b *servic
return &serviceBlocker{
url: indexURL,
http: agdhttp.NewClient(&agdhttp.ClientConfig{
- Timeout: defaultTimeout,
+ Timeout: defaultFilterRefreshTimeout,
}),
mu: &sync.RWMutex{},
errColl: errColl,
diff --git a/internal/filter/storage.go b/internal/filter/storage.go
index 5fbff15..3a3b114 100644
--- a/internal/filter/storage.go
+++ b/internal/filter/storage.go
@@ -15,7 +15,6 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
- "github.com/AdguardTeam/golibs/timeutil"
"github.com/bluele/gcache"
"github.com/prometheus/client_golang/prometheus"
)
@@ -55,10 +54,10 @@ type DefaultStorage struct {
services *serviceBlocker
// safeBrowsing is the general safe browsing filter.
- safeBrowsing *hashPrefixFilter
+ safeBrowsing *HashPrefix
// adultBlocking is the adult content blocking safe browsing filter.
- adultBlocking *hashPrefixFilter
+ adultBlocking *HashPrefix
// genSafeSearch is the general safe search filter.
genSafeSearch *safeSearch
@@ -111,11 +110,11 @@ type DefaultStorageConfig struct {
// SafeBrowsing is the configuration for the default safe browsing filter.
// It must not be nil.
- SafeBrowsing *HashPrefixConfig
+ SafeBrowsing *HashPrefix
// AdultBlocking is the configuration for the adult content blocking safe
// browsing filter. It must not be nil.
- AdultBlocking *HashPrefixConfig
+ AdultBlocking *HashPrefix
// Now is a function that returns current time.
Now func() (now time.Time)
@@ -156,25 +155,8 @@ type DefaultStorageConfig struct {
// NewDefaultStorage returns a new filter storage. c must not be nil.
func NewDefaultStorage(c *DefaultStorageConfig) (s *DefaultStorage, err error) {
- // TODO(ameshkov): Consider making configurable.
- resolver := agdnet.NewCachingResolver(c.Resolver, 1*timeutil.Day)
-
- safeBrowsing := newHashPrefixFilter(
- c.SafeBrowsing,
- resolver,
- c.ErrColl,
- agd.FilterListIDSafeBrowsing,
- )
-
- adultBlocking := newHashPrefixFilter(
- c.AdultBlocking,
- resolver,
- c.ErrColl,
- agd.FilterListIDAdultBlocking,
- )
-
genSafeSearch := newSafeSearch(&safeSearchConfig{
- resolver: resolver,
+ resolver: c.Resolver,
errColl: c.ErrColl,
list: &agd.FilterList{
URL: c.GeneralSafeSearchRulesURL,
@@ -187,7 +169,7 @@ func NewDefaultStorage(c *DefaultStorageConfig) (s *DefaultStorage, err error) {
})
ytSafeSearch := newSafeSearch(&safeSearchConfig{
- resolver: resolver,
+ resolver: c.Resolver,
errColl: c.ErrColl,
list: &agd.FilterList{
URL: c.YoutubeSafeSearchRulesURL,
@@ -203,11 +185,11 @@ func NewDefaultStorage(c *DefaultStorageConfig) (s *DefaultStorage, err error) {
mu: &sync.RWMutex{},
url: c.FilterIndexURL,
http: agdhttp.NewClient(&agdhttp.ClientConfig{
- Timeout: defaultTimeout,
+ Timeout: defaultFilterRefreshTimeout,
}),
services: newServiceBlocker(c.BlockedServiceIndexURL, c.ErrColl),
- safeBrowsing: safeBrowsing,
- adultBlocking: adultBlocking,
+ safeBrowsing: c.SafeBrowsing,
+ adultBlocking: c.AdultBlocking,
genSafeSearch: genSafeSearch,
ytSafeSearch: ytSafeSearch,
now: c.Now,
@@ -322,7 +304,7 @@ func (s *DefaultStorage) pcBySchedule(sch *agd.ParentalProtectionSchedule) (ok b
func (s *DefaultStorage) safeBrowsingForProfile(
p *agd.Profile,
parentalEnabled bool,
-) (safeBrowsing, adultBlocking *hashPrefixFilter) {
+) (safeBrowsing, adultBlocking *HashPrefix) {
if p.SafeBrowsingEnabled {
safeBrowsing = s.safeBrowsing
}
@@ -359,7 +341,7 @@ func (s *DefaultStorage) safeSearchForProfile(
// in the filtering group. g must not be nil.
func (s *DefaultStorage) safeBrowsingForGroup(
g *agd.FilteringGroup,
-) (safeBrowsing, adultBlocking *hashPrefixFilter) {
+) (safeBrowsing, adultBlocking *HashPrefix) {
if g.SafeBrowsingEnabled {
safeBrowsing = s.safeBrowsing
}
diff --git a/internal/filter/storage_test.go b/internal/filter/storage_test.go
index 75d0edf..9839abc 100644
--- a/internal/filter/storage_test.go
+++ b/internal/filter/storage_test.go
@@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
@@ -124,34 +125,11 @@ func TestStorage_FilterFromContext(t *testing.T) {
}
func TestStorage_FilterFromContext_customAllow(t *testing.T) {
- // Initialize the hashes file and use it with the storage.
- tmpFile, err := os.CreateTemp(t.TempDir(), "")
- require.NoError(t, err)
- testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
-
- _, err = io.WriteString(tmpFile, safeBrowsingHost+"\n")
- require.NoError(t, err)
-
- hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{
- CachePath: tmpFile.Name(),
- RefreshIvl: 1 * time.Hour,
- })
- require.NoError(t, err)
-
- c := prepareConf(t)
-
- c.SafeBrowsing = &filter.HashPrefixConfig{
- Hashes: hashes,
- ReplacementHost: safeBrowsingSafeHost,
- CacheTTL: 10 * time.Second,
- CacheSize: 100,
- }
-
- c.ErrColl = &agdtest.ErrorCollector{
+ errColl := &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, err error) { panic("not implemented") },
}
- c.Resolver = &agdtest.Resolver{
+ resolver := &agdtest.Resolver{
OnLookupIP: func(
_ context.Context,
_ netutil.AddrFamily,
@@ -161,6 +139,35 @@ func TestStorage_FilterFromContext_customAllow(t *testing.T) {
},
}
+ // Initialize the hashes file and use it with the storage.
+ tmpFile, err := os.CreateTemp(t.TempDir(), "")
+ require.NoError(t, err)
+ testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
+
+ _, err = io.WriteString(tmpFile, safeBrowsingHost+"\n")
+ require.NoError(t, err)
+
+ hashes, err := hashstorage.New(safeBrowsingHost)
+ require.NoError(t, err)
+
+ c := prepareConf(t)
+
+ c.SafeBrowsing, err = filter.NewHashPrefix(&filter.HashPrefixConfig{
+ Hashes: hashes,
+ ErrColl: errColl,
+ Resolver: resolver,
+ ID: agd.FilterListIDSafeBrowsing,
+ CachePath: tmpFile.Name(),
+ ReplacementHost: safeBrowsingSafeHost,
+ Staleness: 1 * time.Hour,
+ CacheTTL: 10 * time.Second,
+ CacheSize: 100,
+ })
+ require.NoError(t, err)
+
+ c.ErrColl = errColl
+ c.Resolver = resolver
+
s, err := filter.NewDefaultStorage(c)
require.NoError(t, err)
@@ -211,39 +218,11 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) {
// parental protection from 11:00:00 until 12:59:59.
nowTime := time.Date(2021, 1, 1, 12, 0, 0, 0, time.UTC)
- // Initialize the hashes file and use it with the storage.
- tmpFile, err := os.CreateTemp(t.TempDir(), "")
- require.NoError(t, err)
- testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
-
- _, err = io.WriteString(tmpFile, safeBrowsingHost+"\n")
- require.NoError(t, err)
-
- hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{
- CachePath: tmpFile.Name(),
- RefreshIvl: 1 * time.Hour,
- })
- require.NoError(t, err)
-
- c := prepareConf(t)
-
- // Use AdultBlocking, because SafeBrowsing is NOT affected by the schedule.
- c.AdultBlocking = &filter.HashPrefixConfig{
- Hashes: hashes,
- ReplacementHost: safeBrowsingSafeHost,
- CacheTTL: 10 * time.Second,
- CacheSize: 100,
- }
-
- c.Now = func() (t time.Time) {
- return nowTime
- }
-
- c.ErrColl = &agdtest.ErrorCollector{
+ errColl := &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, err error) { panic("not implemented") },
}
- c.Resolver = &agdtest.Resolver{
+ resolver := &agdtest.Resolver{
OnLookupIP: func(
_ context.Context,
_ netutil.AddrFamily,
@@ -253,6 +232,40 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) {
},
}
+ // Initialize the hashes file and use it with the storage.
+ tmpFile, err := os.CreateTemp(t.TempDir(), "")
+ require.NoError(t, err)
+ testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) })
+
+ _, err = io.WriteString(tmpFile, safeBrowsingHost+"\n")
+ require.NoError(t, err)
+
+ hashes, err := hashstorage.New(safeBrowsingHost)
+ require.NoError(t, err)
+
+ c := prepareConf(t)
+
+ // Use AdultBlocking, because SafeBrowsing is NOT affected by the schedule.
+ c.AdultBlocking, err = filter.NewHashPrefix(&filter.HashPrefixConfig{
+ Hashes: hashes,
+ ErrColl: errColl,
+ Resolver: resolver,
+ ID: agd.FilterListIDAdultBlocking,
+ CachePath: tmpFile.Name(),
+ ReplacementHost: safeBrowsingSafeHost,
+ Staleness: 1 * time.Hour,
+ CacheTTL: 10 * time.Second,
+ CacheSize: 100,
+ })
+ require.NoError(t, err)
+
+ c.Now = func() (t time.Time) {
+ return nowTime
+ }
+
+ c.ErrColl = errColl
+ c.Resolver = resolver
+
s, err := filter.NewDefaultStorage(c)
require.NoError(t, err)
diff --git a/internal/geoip/file.go b/internal/geoip/file.go
index e7845b8..7207558 100644
--- a/internal/geoip/file.go
+++ b/internal/geoip/file.go
@@ -20,9 +20,6 @@ import (
// FileConfig is the file-based GeoIP configuration structure.
type FileConfig struct {
- // ErrColl is the error collector that is used to report errors.
- ErrColl agd.ErrorCollector
-
// ASNPath is the path to the GeoIP database of ASNs.
ASNPath string
@@ -39,8 +36,6 @@ type FileConfig struct {
// File is a file implementation of [geoip.Interface].
type File struct {
- errColl agd.ErrorCollector
-
// mu protects asn, country, country subnet maps, and caches against
// simultaneous access during a refresh.
mu *sync.RWMutex
@@ -77,8 +72,6 @@ type asnSubnets map[agd.ASN]netip.Prefix
// NewFile returns a new GeoIP database that reads information from a file.
func NewFile(c *FileConfig) (f *File, err error) {
f = &File{
- errColl: c.ErrColl,
-
mu: &sync.RWMutex{},
asnPath: c.ASNPath,
diff --git a/internal/geoip/file_test.go b/internal/geoip/file_test.go
index 3288b99..58c629c 100644
--- a/internal/geoip/file_test.go
+++ b/internal/geoip/file_test.go
@@ -1,12 +1,10 @@
package geoip_test
import (
- "context"
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
- "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/golibs/netutil"
"github.com/stretchr/testify/assert"
@@ -14,12 +12,7 @@ import (
)
func TestFile_Data(t *testing.T) {
- var ec agd.ErrorCollector = &agdtest.ErrorCollector{
- OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
- }
-
conf := &geoip.FileConfig{
- ErrColl: ec,
ASNPath: asnPath,
CountryPath: countryPath,
HostCacheSize: 0,
@@ -42,12 +35,7 @@ func TestFile_Data(t *testing.T) {
}
func TestFile_Data_hostCache(t *testing.T) {
- var ec agd.ErrorCollector = &agdtest.ErrorCollector{
- OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
- }
-
conf := &geoip.FileConfig{
- ErrColl: ec,
ASNPath: asnPath,
CountryPath: countryPath,
HostCacheSize: 1,
@@ -74,12 +62,7 @@ func TestFile_Data_hostCache(t *testing.T) {
}
func TestFile_SubnetByLocation(t *testing.T) {
- var ec agd.ErrorCollector = &agdtest.ErrorCollector{
- OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
- }
-
conf := &geoip.FileConfig{
- ErrColl: ec,
ASNPath: asnPath,
CountryPath: countryPath,
HostCacheSize: 0,
@@ -106,12 +89,7 @@ var locSink *agd.Location
var errSink error
func BenchmarkFile_Data(b *testing.B) {
- var ec agd.ErrorCollector = &agdtest.ErrorCollector{
- OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
- }
-
conf := &geoip.FileConfig{
- ErrColl: ec,
ASNPath: asnPath,
CountryPath: countryPath,
HostCacheSize: 0,
@@ -162,12 +140,7 @@ func BenchmarkFile_Data(b *testing.B) {
var fileSink *geoip.File
func BenchmarkNewFile(b *testing.B) {
- var ec agd.ErrorCollector = &agdtest.ErrorCollector{
- OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
- }
-
conf := &geoip.FileConfig{
- ErrColl: ec,
ASNPath: asnPath,
CountryPath: countryPath,
HostCacheSize: 0,
diff --git a/internal/metrics/dnssvc.go b/internal/metrics/dnssvc.go
index 8b02047..1acf3c7 100644
--- a/internal/metrics/dnssvc.go
+++ b/internal/metrics/dnssvc.go
@@ -11,7 +11,7 @@ var DNSSvcRequestByCountryTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "request_per_country_total",
Namespace: namespace,
Subsystem: subsystemDNSSvc,
- Help: "The number of filtered DNS requests labeled by country and continent.",
+ Help: "The number of processed DNS requests labeled by country and continent.",
}, []string{"continent", "country"})
// DNSSvcRequestByASNTotal is a counter with the total number of queries
@@ -20,13 +20,14 @@ var DNSSvcRequestByASNTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "request_per_asn_total",
Namespace: namespace,
Subsystem: subsystemDNSSvc,
- Help: "The number of filtered DNS requests labeled by country and ASN.",
+ Help: "The number of processed DNS requests labeled by country and ASN.",
}, []string{"country", "asn"})
// DNSSvcRequestByFilterTotal is a counter with the total number of queries
-// processed labeled by filter. "filter" contains the ID of the filter list
-// applied. "anonymous" is "0" if the request is from a AdGuard DNS customer,
-// otherwise it is "1".
+// processed labeled by filter. Processed could mean that the request was
+// blocked or unblocked by a rule from that filter list. "filter" contains
+// the ID of the filter list applied. "anonymous" is "0" if the request is
+// from a AdGuard DNS customer, otherwise it is "1".
var DNSSvcRequestByFilterTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "request_per_filter_total",
Namespace: namespace,
diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go
index 0d163c9..5371034 100644
--- a/internal/metrics/metrics.go
+++ b/internal/metrics/metrics.go
@@ -26,6 +26,8 @@ const (
subsystemQueryLog = "querylog"
subsystemRuleStat = "rulestat"
subsystemTLS = "tls"
+ subsystemResearch = "research"
+ subsystemWebSvc = "websvc"
)
// SetUpGauge signals that the server has been started.
diff --git a/internal/metrics/research.go b/internal/metrics/research.go
new file mode 100644
index 0000000..50646a6
--- /dev/null
+++ b/internal/metrics/research.go
@@ -0,0 +1,61 @@
+package metrics
+
+import (
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+)
+
+// ResearchRequestsPerCountryTotal counts the total number of queries per
+// country from anonymous users.
+var ResearchRequestsPerCountryTotal = promauto.NewCounterVec(prometheus.CounterOpts{
+ Name: "requests_per_country_total",
+ Namespace: namespace,
+ Subsystem: subsystemResearch,
+ Help: "The total number of DNS queries per country from anonymous users.",
+}, []string{"country"})
+
+// ResearchBlockedRequestsPerCountryTotal counts the number of blocked queries
+// per country from anonymous users.
+var ResearchBlockedRequestsPerCountryTotal = promauto.NewCounterVec(prometheus.CounterOpts{
+ Name: "blocked_per_country_total",
+ Namespace: namespace,
+ Subsystem: subsystemResearch,
+ Help: "The number of blocked DNS queries per country from anonymous users.",
+}, []string{"filter", "country"})
+
+// ReportResearchMetrics reports metrics to prometheus that we may need to
+// conduct researches.
+//
+// TODO(ameshkov): use [agd.Profile] arg when recursive dependency is resolved.
+func ReportResearchMetrics(
+ anonymous bool,
+ filteringEnabled bool,
+ asn string,
+ ctry string,
+ filterID string,
+ blocked bool,
+) {
+ // The current research metrics only count queries that come to public
+ // DNS servers where filtering is enabled.
+ if !filteringEnabled || !anonymous {
+ return
+ }
+
+ // Ignore AdGuard ASN specifically in order to avoid counting queries that
+ // come from the monitoring. This part is ugly, but since these metrics
+ // are a one-time deal, this is acceptable.
+ //
+ // TODO(ameshkov): think of a better way later if we need to do that again.
+ if asn == "212772" {
+ return
+ }
+
+ if blocked {
+ ResearchBlockedRequestsPerCountryTotal.WithLabelValues(
+ filterID,
+ ctry,
+ ).Inc()
+ }
+
+ ResearchRequestsPerCountryTotal.WithLabelValues(ctry).Inc()
+}
diff --git a/internal/metrics/usercount.go b/internal/metrics/usercount.go
index b49d2b3..b06cf3f 100644
--- a/internal/metrics/usercount.go
+++ b/internal/metrics/usercount.go
@@ -86,7 +86,7 @@ func (c *userCounter) record(now time.Time, ip netip.Addr, syncUpdate bool) {
prevMinuteCounter := c.currentMinuteCounter
c.currentMinute = minuteOfTheDay
- c.currentMinuteCounter = hyperloglog.New()
+ c.currentMinuteCounter = newHyperLogLog()
// If this is the first iteration and prevMinute is -1, don't update the
// counters, since there are none.
@@ -123,7 +123,7 @@ func (c *userCounter) updateCounters(prevMinute int, prevCounter *hyperloglog.Sk
// estimate uses HyperLogLog counters to estimate the hourly and daily users
// count, starting with the minute of the day m.
func (c *userCounter) estimate(m int) (hourly, daily uint64) {
- hourlyCounter, dailyCounter := hyperloglog.New(), hyperloglog.New()
+ hourlyCounter, dailyCounter := newHyperLogLog(), newHyperLogLog()
// Go through all minutes in a day while decreasing the current minute m.
// Decreasing m, as opposed to increasing it or using i as the minute, is
@@ -168,6 +168,11 @@ func decrMod(n, m int) (res int) {
return n - 1
}
+// newHyperLogLog creates a new instance of hyperloglog.Sketch.
+func newHyperLogLog() (sk *hyperloglog.Sketch) {
+ return hyperloglog.New16()
+}
+
// defaultUserCounter is the main user statistics counter.
var defaultUserCounter = newUserCounter()
diff --git a/internal/metrics/websvc.go b/internal/metrics/websvc.go
new file mode 100644
index 0000000..e324e18
--- /dev/null
+++ b/internal/metrics/websvc.go
@@ -0,0 +1,69 @@
+package metrics
+
+import (
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+)
+
+var (
+ webSvcRequestsTotal = promauto.NewCounterVec(prometheus.CounterOpts{
+ Name: "websvc_requests_total",
+ Namespace: namespace,
+ Subsystem: subsystemWebSvc,
+ Help: "The number of DNS requests for websvc.",
+ }, []string{"kind"})
+
+ // WebSvcError404RequestsTotal is a counter with total number of
+ // requests with error 404.
+ WebSvcError404RequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
+ "kind": "error404",
+ })
+
+ // WebSvcError500RequestsTotal is a counter with total number of
+ // requests with error 500.
+ WebSvcError500RequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
+ "kind": "error500",
+ })
+
+ // WebSvcStaticContentRequestsTotal is a counter with total number of
+ // requests for static content.
+ WebSvcStaticContentRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
+ "kind": "static_content",
+ })
+
+ // WebSvcDNSCheckTestRequestsTotal is a counter with total number of
+ // requests for dnscheck_test.
+ WebSvcDNSCheckTestRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
+ "kind": "dnscheck_test",
+ })
+
+ // WebSvcRobotsTxtRequestsTotal is a counter with total number of
+ // requests for robots_txt.
+ WebSvcRobotsTxtRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
+ "kind": "robots_txt",
+ })
+
+ // WebSvcRootRedirectRequestsTotal is a counter with total number of
+ // root redirected requests.
+ WebSvcRootRedirectRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
+ "kind": "root_redirect",
+ })
+
+ // WebSvcLinkedIPProxyRequestsTotal is a counter with total number of
+ // requests with linked ip.
+ WebSvcLinkedIPProxyRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
+ "kind": "linkip",
+ })
+
+ // WebSvcAdultBlockingPageRequestsTotal is a counter with total number
+ // of requests for adult blocking page.
+ WebSvcAdultBlockingPageRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
+ "kind": "adult_blocking_page",
+ })
+
+ // WebSvcSafeBrowsingPageRequestsTotal is a counter with total number
+ // of requests for safe browsing page.
+ WebSvcSafeBrowsingPageRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{
+ "kind": "safe_browsing_page",
+ })
+)
diff --git a/internal/querylog/fs.go b/internal/querylog/fs.go
index 55f8db1..1e4449d 100644
--- a/internal/querylog/fs.go
+++ b/internal/querylog/fs.go
@@ -13,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/mathutil"
)
// FileSystemConfig is the configuration of the file system query log.
@@ -71,11 +72,6 @@ func (l *FileSystem) Write(_ context.Context, e *Entry) (err error) {
metrics.QueryLogItemsCount.Inc()
}()
- var dnssec uint8 = 0
- if e.DNSSEC {
- dnssec = 1
- }
-
entBuf := l.bufferPool.Get().(*entryBuffer)
defer l.bufferPool.Put(entBuf)
entBuf.buf.Reset()
@@ -94,7 +90,7 @@ func (l *FileSystem) Write(_ context.Context, e *Entry) (err error) {
ClientASN: e.ClientASN,
Elapsed: e.Elapsed,
RequestType: e.RequestType,
- DNSSEC: dnssec,
+ DNSSEC: mathutil.BoolToNumber[uint8](e.DNSSEC),
Protocol: e.Protocol,
ResultCode: c,
ResponseCode: e.ResponseCode,
diff --git a/internal/websvc/handler.go b/internal/websvc/handler.go
index 55c20d9..ed09a80 100644
--- a/internal/websvc/handler.go
+++ b/internal/websvc/handler.go
@@ -7,6 +7,7 @@ import (
"os"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
+ "github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
@@ -58,12 +59,16 @@ func (svc *Service) processRec(
body = svc.error404
respHdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML)
}
+
+ metrics.WebSvcError404RequestsTotal.Inc()
case http.StatusInternalServerError:
action = "writing 500"
if len(svc.error500) != 0 {
body = svc.error500
respHdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML)
}
+
+ metrics.WebSvcError500RequestsTotal.Inc()
default:
action = "writing response"
for k, v := range rec.Header() {
@@ -87,6 +92,8 @@ func (svc *Service) serveHTTP(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/dnscheck/test":
svc.dnsCheck.ServeHTTP(w, r)
+
+ metrics.WebSvcDNSCheckTestRequestsTotal.Inc()
case "/robots.txt":
serveRobotsDisallow(w.Header(), w, "handler")
case "/":
@@ -94,6 +101,8 @@ func (svc *Service) serveHTTP(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
} else {
http.Redirect(w, r, svc.rootRedirectURL, http.StatusFound)
+
+ metrics.WebSvcRootRedirectRequestsTotal.Inc()
}
default:
http.NotFound(w, r)
@@ -101,7 +110,7 @@ func (svc *Service) serveHTTP(w http.ResponseWriter, r *http.Request) {
}
// safeBrowsingHandler returns an HTTP handler serving the block page from the
-// blockPagePath. name is used for logging.
+// blockPagePath. name is used for logging and metrics.
func safeBrowsingHandler(name string, blockPage []byte) (h http.Handler) {
f := func(w http.ResponseWriter, r *http.Request) {
hdr := w.Header()
@@ -122,6 +131,15 @@ func safeBrowsingHandler(name string, blockPage []byte) (h http.Handler) {
if err != nil {
logErrorByType(err, "websvc: %s: writing response: %s", name, err)
}
+
+ switch name {
+ case adultBlockingName:
+ metrics.WebSvcAdultBlockingPageRequestsTotal.Inc()
+ case safeBrowsingName:
+ metrics.WebSvcSafeBrowsingPageRequestsTotal.Inc()
+ default:
+ // Go on.
+ }
}
}
@@ -140,10 +158,11 @@ func (sc StaticContent) serveHTTP(w http.ResponseWriter, r *http.Request) (serve
return false
}
- if f.AllowOrigin != "" {
- w.Header().Set(agdhttp.HdrNameAccessControlAllowOrigin, f.AllowOrigin)
+ h := w.Header()
+ for k, v := range f.Headers {
+ h[k] = v
}
- w.Header().Set(agdhttp.HdrNameContentType, f.ContentType)
+
w.WriteHeader(http.StatusOK)
_, err := w.Write(f.Content)
@@ -151,16 +170,15 @@ func (sc StaticContent) serveHTTP(w http.ResponseWriter, r *http.Request) (serve
logErrorByType(err, "websvc: static content: writing %s: %s", p, err)
}
+ metrics.WebSvcStaticContentRequestsTotal.Inc()
+
return true
}
// StaticFile is a single file in a StaticFS.
type StaticFile struct {
- // AllowOrigin is the value for the HTTP Access-Control-Allow-Origin header.
- AllowOrigin string
-
- // ContentType is the value for the HTTP Content-Type header.
- ContentType string
+ // Headers contains headers of the HTTP response.
+ Headers http.Header
// Content is the file content.
Content []byte
@@ -174,6 +192,8 @@ func serveRobotsDisallow(hdr http.Header, w http.ResponseWriter, name string) {
if err != nil {
logErrorByType(err, "websvc: %s: writing response: %s", name, err)
}
+
+ metrics.WebSvcRobotsTxtRequestsTotal.Inc()
}
// logErrorByType writes err to the error log, unless err is a network error or
diff --git a/internal/websvc/handler_test.go b/internal/websvc/handler_test.go
index 128fd5f..24afecf 100644
--- a/internal/websvc/handler_test.go
+++ b/internal/websvc/handler_test.go
@@ -31,8 +31,10 @@ func TestService_ServeHTTP(t *testing.T) {
staticContent := map[string]*websvc.StaticFile{
"/favicon.ico": {
- ContentType: "image/x-icon",
- Content: []byte{},
+ Content: []byte{},
+ Headers: http.Header{
+ agdhttp.HdrNameContentType: []string{"image/x-icon"},
+ },
},
}
@@ -56,22 +58,33 @@ func TestService_ServeHTTP(t *testing.T) {
})
// DNSCheck path.
- assertPathResponse(t, svc, "/dnscheck/test", http.StatusOK)
+ assertResponse(t, svc, "/dnscheck/test", http.StatusOK)
- // Static content path.
- assertPathResponse(t, svc, "/favicon.ico", http.StatusOK)
+ // Static content path with headers.
+ h := http.Header{
+ agdhttp.HdrNameContentType: []string{"image/x-icon"},
+ agdhttp.HdrNameServer: []string{"AdGuardDNS/"},
+ }
+ assertResponseWithHeaders(t, svc, "/favicon.ico", http.StatusOK, h)
// Robots path.
- assertPathResponse(t, svc, "/robots.txt", http.StatusOK)
+ assertResponse(t, svc, "/robots.txt", http.StatusOK)
// Root redirect path.
- assertPathResponse(t, svc, "/", http.StatusFound)
+ assertResponse(t, svc, "/", http.StatusFound)
// Other path.
- assertPathResponse(t, svc, "/other", http.StatusNotFound)
+ assertResponse(t, svc, "/other", http.StatusNotFound)
}
-func assertPathResponse(t *testing.T, svc *websvc.Service, path string, statusCode int) {
+// assertResponse is a helper function that checks status code of HTTP
+// response.
+func assertResponse(
+ t *testing.T,
+ svc *websvc.Service,
+ path string,
+ statusCode int,
+) (rw *httptest.ResponseRecorder) {
t.Helper()
r := httptest.NewRequest(http.MethodGet, (&url.URL{
@@ -79,9 +92,27 @@ func assertPathResponse(t *testing.T, svc *websvc.Service, path string, statusCo
Host: "127.0.0.1",
Path: path,
}).String(), strings.NewReader(""))
- rw := httptest.NewRecorder()
+ rw = httptest.NewRecorder()
svc.ServeHTTP(rw, r)
assert.Equal(t, statusCode, rw.Code)
assert.Equal(t, agdhttp.UserAgent(), rw.Header().Get(agdhttp.HdrNameServer))
+
+ return rw
+}
+
+// assertResponseWithHeaders is a helper function that checks status code and
+// headers of HTTP response.
+func assertResponseWithHeaders(
+ t *testing.T,
+ svc *websvc.Service,
+ path string,
+ statusCode int,
+ header http.Header,
+) {
+ t.Helper()
+
+ rw := assertResponse(t, svc, path, statusCode)
+
+ assert.Equal(t, header, rw.Header())
}
diff --git a/internal/websvc/linkip.go b/internal/websvc/linkip.go
index e7d7a58..90f1668 100644
--- a/internal/websvc/linkip.go
+++ b/internal/websvc/linkip.go
@@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
+ "github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
@@ -148,6 +149,8 @@ func (prx *linkedIPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Debug("%s: proxying %s %s: req %s", prx.logPrefix, m, p, reqID)
prx.httpProxy.ServeHTTP(w, r)
+
+ metrics.WebSvcLinkedIPProxyRequestsTotal.Inc()
} else if r.URL.Path == "/robots.txt" {
serveRobotsDisallow(respHdr, w, prx.logPrefix)
} else {
diff --git a/internal/websvc/websvc.go b/internal/websvc/websvc.go
index d812bad..7a1f4e1 100644
--- a/internal/websvc/websvc.go
+++ b/internal/websvc/websvc.go
@@ -118,8 +118,8 @@ func New(c *Config) (svc *Service) {
error404: c.Error404,
error500: c.Error500,
- adultBlocking: blockPageServers(c.AdultBlocking, "adult blocking", c.Timeout),
- safeBrowsing: blockPageServers(c.SafeBrowsing, "safe browsing", c.Timeout),
+ adultBlocking: blockPageServers(c.AdultBlocking, adultBlockingName, c.Timeout),
+ safeBrowsing: blockPageServers(c.SafeBrowsing, safeBrowsingName, c.Timeout),
}
if c.RootRedirectURL != nil {
@@ -164,6 +164,12 @@ func New(c *Config) (svc *Service) {
return svc
}
+// Names for safeBrowsingHandler for logging and metrics.
+const (
+ safeBrowsingName = "safe browsing"
+ adultBlockingName = "adult blocking"
+)
+
// blockPageServers is a helper function that converts a *BlockPageServer into
// HTTP servers.
func blockPageServers(
diff --git a/scripts/make/go-lint.sh b/scripts/make/go-lint.sh
index 5eb2c89..efc84a8 100644
--- a/scripts/make/go-lint.sh
+++ b/scripts/make/go-lint.sh
@@ -117,7 +117,9 @@ underscores() {
git ls-files '*_*.go'\
| grep -F\
-e '_generate.go'\
+ -e '_linux.go'\
-e '_noreuseport.go'\
+ -e '_others.go'\
-e '_reuseport.go'\
-e '_test.go'\
-e '_unix.go'\