diff --git a/CHANGELOG.md b/CHANGELOG.md index 3763adc..a222fcf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,99 @@ The format is **not** based on [Keep a Changelog][kec], since the project +## AGDNS-916 / Build 456 + + * `ratelimit` now defines rate of requests per second for IPv4 and IPv6 + addresses separately. So replace this: + + ```yaml + ratelimit: + rps: 30 + ipv4_subnet_key_len: 24 + ipv6_subnet_key_len: 48 + ``` + + with this: + + ```yaml + ratelimit: + ipv4: + rps: 30 + subnet_key_len: 24 + ipv6: + rps: 300 + subnet_key_len: 48 + ``` + + + +## AGDNS-907 / Build 449 + + * The objects within the `filtering_groups` have a new property, + `block_firefox_canary`. So replace this: + + ```yaml + filtering_groups: + - + id: default + # … + ``` + + with this: + + ```yaml + filtering_groups: + - + id: default + # … + block_firefox_canary: true + ``` + + The recommended default value is `true`. + + + +## AGDNS-1308 / Build 447 + + * There is now a new env variable `RESEARCH_METRICS` that controls whether + collecting research metrics is enabled or not. Also, the first research + metric is added: `dns_research_blocked_per_country_total`, it counts the + number of blocked requests per country. Its default value is `0`, i.e. + research metrics collection is disabled by default. + + + +## AGDNS-1051 / Build 443 + + * There are two changes in the keys of the `static_content` map. Firstly, + properties `allow_origin` and `content_type` are removed. Secondly, a new + property, called `headers`, is added. So replace this: + + ```yaml + static_content: + '/favicon.ico': + # … + allow_origin: '*' + content_type: 'image/x-icon' + ``` + + with this: + + ```yaml + static_content: + '/favicon.ico': + # … + headers: + 'Access-Control-Allow-Origin': + - '*' + 'Content-Type': + - 'image/x-icon' + ``` + + Adjust or add the values, if necessary. 
+ + + ## AGDNS-1278 / Build 423 * The object `filters` has two new properties, `rule_list_cache_size` and diff --git a/config.dist.yml b/config.dist.yml index 14a99f2..9550cfc 100644 --- a/config.dist.yml +++ b/config.dist.yml @@ -8,8 +8,21 @@ ratelimit: refuseany: true # If response is larger than this, it is counted as several responses. response_size_estimate: 1KB - # Rate of requests per second for one subnet. rps: 30 + # Rate limit options for IPv4 addresses. + ipv4: + # Rate of requests per second for one subnet for IPv4 addresses. + rps: 30 + # The lengths of the subnet prefixes used to calculate rate limiter + # bucket keys for IPv4 addresses. + subnet_key_len: 24 + # Rate limit options for IPv6 addresses. + ipv6: + # Rate of requests per second for one subnet for IPv6 addresses. + rps: 300 + # The lengths of the subnet prefixes used to calculate rate limiter + # bucket keys for IPv6 addresses. + subnet_key_len: 48 # The time during which to count the number of times a client has hit the # rate limit for a back off. # @@ -21,10 +34,6 @@ ratelimit: # How much a client that has hit the rate limit too often stays in the back # off. back_off_duration: 30m - # The lengths of the subnet prefixes used to calculate rate limiter bucket - # keys for IPv4 and IPv6 addresses correspondingly. - ipv4_subnet_key_len: 24 - ipv6_subnet_key_len: 48 # Configuration for the allowlist. allowlist: @@ -156,9 +165,12 @@ web: # servers. Paths must not cross the ones used by the DNS-over-HTTPS server. static_content: '/favicon.ico': - allow_origin: '*' - content_type: 'image/x-icon' content: '' + headers: + Access-Control-Allow-Origin: + - '*' + Content-Type: + - 'image/x-icon' # If not defined, AdGuard DNS will respond with a 404 page to all such # requests. 
root_redirect_url: 'https://adguard-dns.com' @@ -221,6 +233,7 @@ filtering_groups: safe_browsing: enabled: true block_private_relay: false + block_firefox_canary: true - id: 'family' parental: enabled: true @@ -234,6 +247,7 @@ filtering_groups: safe_browsing: enabled: true block_private_relay: false + block_firefox_canary: true - id: 'non_filtering' rule_lists: enabled: false @@ -242,6 +256,7 @@ filtering_groups: safe_browsing: enabled: false block_private_relay: false + block_firefox_canary: true # Server groups and servers. server_groups: diff --git a/doc/configuration.md b/doc/configuration.md index e68825a..49fdc69 100644 --- a/doc/configuration.md +++ b/doc/configuration.md @@ -85,11 +85,23 @@ The `ratelimit` object has the following properties: **Example:** `30m`. - * `rps`: - The rate of requests per second for one subnet. Requests above this are - counted in the backoff count. + * `ipv4`: + The ipv4 configuration object. It has the following fields: - **Example:** `30`. + * `rps`: + The rate of requests per second for one subnet. Requests above this are + counted in the backoff count. + + **Example:** `30`. + + * `subnet_key_len`: + The length of the subnet prefix used to calculate rate limiter bucket keys. + + **Example:** `24`. + + * `ipv6`: + The `ipv6` configuration object has the same properties as the `ipv4` one + above. * `back_off_count`: Maximum number of requests a client can make above the RPS within @@ -120,22 +132,11 @@ The `ratelimit` object has the following properties: **Example:** `30s`. - * `ipv4_subnet_key_len`: - The length of the subnet prefix used to calculate rate limiter bucket keys - for IPv4 addresses. - - **Example:** `24`. - - * `ipv6_subnet_key_len`: - Same as `ipv4_subnet_key_len` above but for IPv6 addresses. - - **Example:** `48`. 
- -For example, if `back_off_period` is `1m`, `back_off_count` is `10`, and `rps` -is `5`, a client (meaning all IP addresses within the subnet defined by -`ipv4_subnet_key_len` and `ipv6_subnet_key_len`) that made 15 requests in one -second or 6 requests (one above `rps`) every second for 10 seconds within one -minute, the client is blocked for `back_off_duration`. +For example, if `back_off_period` is `1m`, `back_off_count` is `10`, and +`ipv4-rps` is `5`, a client (meaning all IP addresses within the subnet defined +by `ipv4-subnet_key_len`) that made 15 requests in one second or 6 requests +(one above `rps`) every second for 10 seconds within one minute, the client is +blocked for `back_off_duration`. [env-consul_allowlist_url]: environment.md#CONSUL_ALLOWLIST_URL @@ -454,13 +455,17 @@ The optional `web` object has the following properties: `safe_browsing` and `adult_blocking` servers. Paths must not duplicate the ones used by the DNS-over-HTTPS server. + Inside of the `headers` map, the header `Content-Type` is required. + **Property example:** ```yaml - 'static_content': + static_content: '/favicon.ico': - 'content_type': 'image/x-icon' - 'content': 'base64content' + content: 'base64content' + headers: + 'Content-Type': + - 'image/x-icon' ``` * `root_redirect_url`: @@ -647,6 +652,12 @@ The items of the `filtering_groups` array have the following properties: **Example:** `false`. + * `block_firefox_canary`: + If true, Firefox canary domain queries are blocked for requests using this + filtering group. + + **Example:** `true`. + ## Server groups diff --git a/doc/environment.md b/doc/environment.md index 48e2b6b..ffc39e3 100644 --- a/doc/environment.md +++ b/doc/environment.md @@ -22,6 +22,7 @@ sensitive configuration. 
All other configuration is stored in the * [`LISTEN_PORT`](#LISTEN_PORT) * [`LOG_TIMESTAMP`](#LOG_TIMESTAMP) * [`QUERYLOG_PATH`](#QUERYLOG_PATH) + * [`RESEARCH_METRICS`](#RESEARCH_METRICS) * [`RULESTAT_URL`](#RULESTAT_URL) * [`SENTRY_DSN`](#SENTRY_DSN) * [`SSL_KEY_LOG_FILE`](#SSL_KEY_LOG_FILE) @@ -198,6 +199,15 @@ The path to the file into which the query log is going to be written. +## `RESEARCH_METRICS` + +If `1`, enable collection of a set of special prometheus metrics (prefix is +`dns_research`). If `0`, disable collection of those metrics. + +**Default:** `0`. + + + ## `RULESTAT_URL` The URL to send filtering rule list statistics to. If empty or unset, the diff --git a/go.mod b/go.mod index d26f926..65ee670 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.19 require ( github.com/AdguardTeam/AdGuardDNS/internal/dnsserver v0.100.0 - github.com/AdguardTeam/golibs v0.11.3 + github.com/AdguardTeam/golibs v0.11.4 github.com/AdguardTeam/urlfilter v0.16.1 github.com/ameshkov/dnscrypt/v2 v2.2.5 github.com/axiomhq/hyperloglog v0.0.0-20220105174342-98591331716a diff --git a/go.sum b/go.sum index a2cbbd8..2ba411c 100644 --- a/go.sum +++ b/go.sum @@ -33,8 +33,8 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9 dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw= -github.com/AdguardTeam/golibs v0.11.3 h1:Oif+REq2WLycQ2Xm3ZPmJdfftptss0HbGWbxdFaC310= -github.com/AdguardTeam/golibs v0.11.3/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA= +github.com/AdguardTeam/golibs v0.11.4 h1:IltyvxwCTN+xxJF5sh6VadF8Zfbf8elgCm9dgijSVzM= +github.com/AdguardTeam/golibs v0.11.4/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA= github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod 
h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU= github.com/AdguardTeam/urlfilter v0.16.1 h1:ZPi0rjqo8cQf2FVdzo6cqumNoHZx2KPXj2yZa1A5BBw= github.com/AdguardTeam/urlfilter v0.16.1/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI= diff --git a/go.work.sum b/go.work.sum index 444de9c..918589d 100644 --- a/go.work.sum +++ b/go.work.sum @@ -9,7 +9,11 @@ dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0 h1:SPOUaucgtVl dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412 h1:GvWw74lx5noHocd+f6HBMXK6DuggBB1dhVkuGZbv7qM= dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c h1:ivON6cwHK1OH26MZyWDCnbTRZZf0IhNsENoNAKFS1g4= git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999 h1:OR8VhtwhcAI3U48/rzBsVOuHi0zDPzYI1xASVcdSgR8= +github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= +github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw= github.com/AdguardTeam/golibs v0.10.7/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw= +github.com/AdguardTeam/golibs v0.11.4 h1:IltyvxwCTN+xxJF5sh6VadF8Zfbf8elgCm9dgijSVzM= +github.com/AdguardTeam/golibs v0.11.4/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA= github.com/AdguardTeam/gomitmproxy v0.2.0 h1:rvCOf17pd1/CnMyMQW891zrEiIQBpQ8cIGjKN9pinUU= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v1.2.0 h1:Rt8g24XnyGTyglgET/PRUNlrUeu9F5L+7FilkXfZgs0= @@ -74,7 +78,7 @@ github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18h github.com/golang/lint v0.0.0-20180702182130-06c8688daad7 h1:2hRPrmiwPrp3fQX967rNJIhQPtiGXdlQWAxKbKw3VHA= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo= -github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.5.9/go.mod 
h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY= github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk= github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw= @@ -198,41 +202,36 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh github.com/yosssi/ace v0.0.5 h1:tUkIP/BLdKqrlrPwcmH0shwEEhTRHoGnc1wFIWmaBUA= github.com/yuin/goldmark v1.4.1 h1:/vn0k+RBvwlxEmP5E7SZMqNxPhfMVFEJiykr15/0XKM= github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opencensus.io v0.22.4 h1:LYy1Hy3MJdrCdMwwzxA/dRok4ejH+RwNGbuoD9fCjto= go4.org v0.0.0-20180809161055-417644f6feb5 h1:+hE86LblG4AyDgwMCLTE6FOlM9+qjHSYS+rKqxUVdsM= golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d h1:E2M5QgjZ/Jg+ObCQAudsXxuTsLj7Nl5RV/lZcQZmKSo= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/image v0.0.0-20190802002840-cff245a6509b h1:+qEpEAPhDZ1o0x3tHzZTQDArnOixOzGD9HUJfcg0mb4= golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 h1:4+4C/Iv2U4fMZBiMCc98MG1In4gJY5YRhtpDNeDeHWs= golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= golang.org/x/net 
v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220516155154-20f960328961/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b h1:clP8eMhB30EHdc0bd2Twtq6kgU7yl5ub2cQLSdrv1Dg= golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852 h1:xYq6+9AtI+xP3M4r0N1hCkHrInHDBohhquRgx9Kk6gI= golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 h1:v6hYoSR9T5oet+pMXwUWkbiVqx/63mlHjefrHmxwfeY= golang.org/x/sys v0.0.0-20220829200755-d48e67d00261/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw= 
-golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI= +golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.30.0 h1:yfrXXP61wVuLb0vBcG6qaOoIoqYEzOQS8jum51jkv2w= @@ -246,6 +245,7 @@ gopkg.in/errgo.v2 v2.1.0 h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= grpc.go4.org v0.0.0-20170609214715-11d0a25b4919 h1:tmXTu+dfa+d9Evp8NpJdgOy6+rt8/x4yG7qPBrtNfLY= honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8= rsc.io/binaryregexp v0.2.0 h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE= diff --git a/internal/agd/dns.go b/internal/agd/dns.go index 39fa7cd..65839af 100644 --- a/internal/agd/dns.go +++ b/internal/agd/dns.go @@ -1,11 +1,6 @@ package agd -import ( - "context" - "net" - - "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" -) +import "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" // Common DNS Message Constants, Types, And Utilities @@ -27,10 +22,3 @@ const ( ProtoDoT = dnsserver.ProtoDoT ProtoDNSCrypt = dnsserver.ProtoDNSCrypt ) - -// Resolver is the 
DNS resolver interface. -// -// See go doc net.Resolver. -type Resolver interface { - LookupIP(ctx context.Context, network, host string) (ips []net.IP, err error) -} diff --git a/internal/agd/filterlist.go b/internal/agd/filterlist.go index c5b0213..0fbe429 100644 --- a/internal/agd/filterlist.go +++ b/internal/agd/filterlist.go @@ -146,6 +146,10 @@ type FilteringGroup struct { // BlockPrivateRelay shows if Apple Private Relay is blocked for requests // using this filtering group. BlockPrivateRelay bool + + // BlockFirefoxCanary shows if Firefox canary domain is blocked for + // requests using this filtering group. + BlockFirefoxCanary bool } // FilteringGroupID is the ID of a filter group. It is an opaque string. diff --git a/internal/agd/profile.go b/internal/agd/profile.go index 58d8e42..672f2c6 100644 --- a/internal/agd/profile.go +++ b/internal/agd/profile.go @@ -73,6 +73,10 @@ type Profile struct { // BlockPrivateRelay shows if Apple Private Relay queries are blocked for // requests from all devices in this profile. BlockPrivateRelay bool + + // BlockFirefoxCanary shows if Firefox canary domain is blocked for + // requests from all devices in this profile. + BlockFirefoxCanary bool } // ProfileID is the ID of a profile. It is an opaque string. diff --git a/internal/agd/profiledb.go b/internal/agd/profiledb.go index 41899ae..39c3b10 100644 --- a/internal/agd/profiledb.go +++ b/internal/agd/profiledb.go @@ -17,6 +17,8 @@ import ( // Data Storage // ProfileDB is the local database of profiles and other data. +// +// TODO(a.garipov): move this logic to the backend package. 
type ProfileDB interface { ProfileByDeviceID(ctx context.Context, id DeviceID) (p *Profile, d *Device, err error) ProfileByIP(ctx context.Context, ip netip.Addr) (p *Profile, d *Device, err error) diff --git a/internal/agdio/agdio.go b/internal/agdio/agdio.go index 4799202..bfd3300 100644 --- a/internal/agdio/agdio.go +++ b/internal/agdio/agdio.go @@ -7,6 +7,8 @@ package agdio import ( "fmt" "io" + + "github.com/AdguardTeam/golibs/mathutil" ) // LimitError is returned when the Limit is reached. @@ -35,9 +37,8 @@ func (lr *limitedReader) Read(p []byte) (n int, err error) { } } - if int64(len(p)) > lr.n { - p = p[0:lr.n] - } + l := mathutil.Min(int64(len(p)), lr.n) + p = p[:l] n, err = lr.r.Read(p) lr.n -= int64(n) diff --git a/internal/backend/profiledb.go b/internal/backend/profiledb.go index c0dde37..4e089b2 100644 --- a/internal/backend/profiledb.go +++ b/internal/backend/profiledb.go @@ -165,6 +165,31 @@ type v1SettingsRespSettings struct { FilteringEnabled bool `json:"filtering_enabled"` Deleted bool `json:"deleted"` BlockPrivateRelay bool `json:"block_private_relay"` + BlockFirefoxCanary bool `json:"block_firefox_canary"` +} + +// type check +var _ json.Unmarshaler = (*v1SettingsRespSettings)(nil) + +// UnmarshalJSON implements the [json.Unmarshaler] interface for +// *v1SettingsRespSettings. It puts default value into BlockFirefoxCanary +// field while it is not implemented on the backend side. +// +// TODO(a.garipov): Remove once the backend starts to always send it. 
+func (rs *v1SettingsRespSettings) UnmarshalJSON(b []byte) (err error) { + type defaultDec v1SettingsRespSettings + + s := defaultDec{ + BlockFirefoxCanary: true, + } + + if err = json.Unmarshal(b, &s); err != nil { + return err + } + + *rs = v1SettingsRespSettings(s) + + return nil } // v1SettingsRespRuleLists is the structure for decoding filtering rule lists @@ -414,10 +439,11 @@ const maxFltRespTTL = 1 * time.Hour func fltRespTTLToInternal(respTTL uint32) (ttl time.Duration, err error) { ttl = time.Duration(respTTL) * time.Second if ttl > maxFltRespTTL { - return ttl, fmt.Errorf("too high: got %d, max %d", respTTL, maxFltRespTTL) + err = fmt.Errorf("too high: got %s, max %s", ttl, maxFltRespTTL) + ttl = maxFltRespTTL } - return ttl, nil + return ttl, err } // toInternal converts r to an [agd.DSProfilesResponse] instance. @@ -461,6 +487,9 @@ func (r *v1SettingsResp) toInternal( reportf(ctx, errColl, "settings at index %d: filtered resp ttl: %w", i, err) // Go on and use the fixed value. + // + // TODO(ameshkov, a.garipov): Consider continuing, like with all + // other validation errors. 
} sbEnabled := s.SafeBrowsing != nil && s.SafeBrowsing.Enabled @@ -479,6 +508,7 @@ func (r *v1SettingsResp) toInternal( QueryLogEnabled: s.QueryLogEnabled, Deleted: s.Deleted, BlockPrivateRelay: s.BlockPrivateRelay, + BlockFirefoxCanary: s.BlockFirefoxCanary, }) } diff --git a/internal/backend/profiledb_test.go b/internal/backend/profiledb_test.go index 268de3a..2e16fa0 100644 --- a/internal/backend/profiledb_test.go +++ b/internal/backend/profiledb_test.go @@ -131,6 +131,7 @@ func testProfileResp(t *testing.T) *agd.PSProfilesResponse { QueryLogEnabled: true, Deleted: false, BlockPrivateRelay: true, + BlockFirefoxCanary: true, }, { Parental: wantParental, ID: "83f3ea8f", @@ -162,6 +163,7 @@ func testProfileResp(t *testing.T) *agd.PSProfilesResponse { QueryLogEnabled: true, Deleted: true, BlockPrivateRelay: false, + BlockFirefoxCanary: false, }}, } diff --git a/internal/backend/testdata/profiles.json b/internal/backend/testdata/profiles.json index a44e55d..11af374 100644 --- a/internal/backend/testdata/profiles.json +++ b/internal/backend/testdata/profiles.json @@ -11,6 +11,7 @@ }, "deleted": false, "block_private_relay": true, + "block_firefox_canary": true, "devices": [ { "id": "118ffe93", @@ -43,6 +44,7 @@ }, "deleted": true, "block_private_relay": false, + "block_firefox_canary": false, "devices": [ { "id": "0d7724fa", diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index 46e3b92..064b47c 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -28,6 +28,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/AdGuardDNS/internal/websvc" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/timeutil" ) // Main is the entry point of application. 
@@ -89,52 +90,75 @@ func Main() { go func() { defer geoIPMu.Unlock() - geoIP, geoIPErr = envs.geoIP(c.GeoIP, errColl) + geoIP, geoIPErr = envs.geoIP(c.GeoIP) }() - // Safe Browsing Hosts + // Safe-browsing and adult-blocking filters + + // TODO(ameshkov): Consider making configurable. + filteringResolver := agdnet.NewCachingResolver( + agdnet.DefaultResolver{}, + 1*timeutil.Day, + ) err = os.MkdirAll(envs.FilterCachePath, agd.DefaultDirPerm) check(err) - safeBrowsingConf := c.SafeBrowsing.toInternal( + safeBrowsingConf, err := c.SafeBrowsing.toInternal( + errColl, + filteringResolver, agd.FilterListIDSafeBrowsing, envs.FilterCachePath, - errColl, ) - safeBrowsingHashes, err := filter.NewHashStorage(safeBrowsingConf) check(err) - err = safeBrowsingHashes.Start() + safeBrowsingFilter, err := filter.NewHashPrefix(safeBrowsingConf) check(err) - adultBlockingConf := c.AdultBlocking.toInternal( + safeBrowsingUpd := agd.NewRefreshWorker(&agd.RefreshWorkerConfig{ + Context: ctxWithDefaultTimeout, + Refresher: safeBrowsingFilter, + ErrColl: errColl, + Name: string(agd.FilterListIDSafeBrowsing), + Interval: safeBrowsingConf.Staleness, + RefreshOnShutdown: false, + RoutineLogsAreDebug: false, + }) + err = safeBrowsingUpd.Start() + check(err) + + adultBlockingConf, err := c.AdultBlocking.toInternal( + errColl, + filteringResolver, agd.FilterListIDAdultBlocking, envs.FilterCachePath, - errColl, ) - adultBlockingHashes, err := filter.NewHashStorage(adultBlockingConf) check(err) - err = adultBlockingHashes.Start() + adultBlockingFilter, err := filter.NewHashPrefix(adultBlockingConf) check(err) - // Filters And Filtering Groups + adultBlockingUpd := agd.NewRefreshWorker(&agd.RefreshWorkerConfig{ + Context: ctxWithDefaultTimeout, + Refresher: adultBlockingFilter, + ErrColl: errColl, + Name: string(agd.FilterListIDAdultBlocking), + Interval: adultBlockingConf.Staleness, + RefreshOnShutdown: false, + RoutineLogsAreDebug: false, + }) + err = adultBlockingUpd.Start() + check(err) - 
fltStrgConf := c.Filters.toInternal(errColl, envs) - fltStrgConf.SafeBrowsing = &filter.HashPrefixConfig{ - Hashes: safeBrowsingHashes, - ReplacementHost: c.SafeBrowsing.BlockHost, - CacheTTL: c.SafeBrowsing.CacheTTL.Duration, - CacheSize: c.SafeBrowsing.CacheSize, - } + // Filter storage and filtering groups - fltStrgConf.AdultBlocking = &filter.HashPrefixConfig{ - Hashes: adultBlockingHashes, - ReplacementHost: c.AdultBlocking.BlockHost, - CacheTTL: c.AdultBlocking.CacheTTL.Duration, - CacheSize: c.AdultBlocking.CacheSize, - } + fltStrgConf := c.Filters.toInternal( + errColl, + filteringResolver, + envs, + safeBrowsingFilter, + adultBlockingFilter, + ) fltStrg, err := filter.NewDefaultStorage(fltStrgConf) check(err) @@ -153,8 +177,6 @@ func Main() { err = fltStrgUpd.Start() check(err) - safeBrowsing := filter.NewSafeBrowsingServer(safeBrowsingHashes, adultBlockingHashes) - // Server Groups fltGroups, err := c.FilteringGroups.toInternal(fltStrg) @@ -329,8 +351,11 @@ func Main() { }, c.Upstream.Healthcheck.Enabled) dnsConf := &dnssvc.Config{ - Messages: messages, - SafeBrowsing: safeBrowsing, + Messages: messages, + SafeBrowsing: filter.NewSafeBrowsingServer( + safeBrowsingConf.Hashes, + adultBlockingConf.Hashes, + ), BillStat: billStatRec, ProfileDB: profDB, DNSCheck: dnsCk, @@ -349,6 +374,7 @@ func Main() { CacheSize: c.Cache.Size, ECSCacheSize: c.Cache.ECSSize, UseECSCache: c.Cache.Type == cacheTypeECS, + ResearchMetrics: bool(envs.ResearchMetrics), } dnsSvc, err := dnssvc.New(dnsConf) @@ -383,11 +409,11 @@ func Main() { ) h := newSignalHandler( - adultBlockingHashes, - safeBrowsingHashes, debugSvc, webSvc, dnsSvc, + safeBrowsingUpd, + adultBlockingUpd, profDBUpd, dnsDBUpd, geoIPUpd, diff --git a/internal/cmd/config.go b/internal/cmd/config.go index 34dddfc..9bed65f 100644 --- a/internal/cmd/config.go +++ b/internal/cmd/config.go @@ -3,14 +3,9 @@ package cmd import ( "fmt" "os" - "path/filepath" - "github.com/AdguardTeam/AdGuardDNS/internal/agd" - 
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" - "github.com/AdguardTeam/AdGuardDNS/internal/filter" "github.com/AdguardTeam/AdGuardDNS/internal/querylog" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/timeutil" "gopkg.in/yaml.v2" ) @@ -213,74 +208,6 @@ func (c *geoIPConfig) validate() (err error) { } } -// allowListConfig is the consul allow list configuration. -type allowListConfig struct { - // List contains IPs and CIDRs. - List []string `yaml:"list"` - - // RefreshIvl time between two updates of allow list from the Consul URL. - RefreshIvl timeutil.Duration `yaml:"refresh_interval"` -} - -// safeBrowsingConfig is the configuration for one of the safe browsing filters. -type safeBrowsingConfig struct { - // URL is the URL used to update the filter. - URL *agdhttp.URL `yaml:"url"` - - // BlockHost is the hostname with which to respond to any requests that - // match the filter. - // - // TODO(a.garipov): Consider replacing with a list of IPv4 and IPv6 - // addresses. - BlockHost string `yaml:"block_host"` - - // CacheSize is the size of the response cache, in entries. - CacheSize int `yaml:"cache_size"` - - // CacheTTL is the TTL of the response cache. - CacheTTL timeutil.Duration `yaml:"cache_ttl"` - - // RefreshIvl defines how often AdGuard DNS refreshes the filter. - RefreshIvl timeutil.Duration `yaml:"refresh_interval"` -} - -// toInternal converts c to the safe browsing filter configuration for the -// filter storage of the DNS server. c is assumed to be valid. 
-func (c *safeBrowsingConfig) toInternal( - id agd.FilterListID, - cacheDir string, - errColl agd.ErrorCollector, -) (conf *filter.HashStorageConfig) { - return &filter.HashStorageConfig{ - URL: netutil.CloneURL(&c.URL.URL), - ErrColl: errColl, - ID: id, - CachePath: filepath.Join(cacheDir, string(id)), - RefreshIvl: c.RefreshIvl.Duration, - } -} - -// validate returns an error if the safe browsing filter configuration is -// invalid. -func (c *safeBrowsingConfig) validate() (err error) { - switch { - case c == nil: - return errNilConfig - case c.URL == nil: - return errors.Error("no url") - case c.BlockHost == "": - return errors.Error("no block_host") - case c.CacheSize <= 0: - return newMustBePositiveError("cache_size", c.CacheSize) - case c.CacheTTL.Duration <= 0: - return newMustBePositiveError("cache_ttl", c.CacheTTL) - case c.RefreshIvl.Duration <= 0: - return newMustBePositiveError("refresh_interval", c.RefreshIvl) - default: - return nil - } -} - // readConfig reads the configuration. func readConfig(confPath string) (c *configuration, err error) { // #nosec G304 -- Trust the path to the configuration file that is given diff --git a/internal/cmd/env.go b/internal/cmd/env.go index fe0f94c..5454692 100644 --- a/internal/cmd/env.go +++ b/internal/cmd/env.go @@ -48,8 +48,9 @@ type environments struct { ListenPort int `env:"LISTEN_PORT" envDefault:"8181"` - LogTimestamp strictBool `env:"LOG_TIMESTAMP" envDefault:"1"` - LogVerbose strictBool `env:"VERBOSE" envDefault:"0"` + LogTimestamp strictBool `env:"LOG_TIMESTAMP" envDefault:"1"` + LogVerbose strictBool `env:"VERBOSE" envDefault:"0"` + ResearchMetrics strictBool `env:"RESEARCH_METRICS" envDefault:"0"` } // readEnvs reads the configuration. @@ -127,12 +128,10 @@ func (envs *environments) buildDNSDB( // geoIP returns an GeoIP database implementation from environment. 
func (envs *environments) geoIP( c *geoIPConfig, - errColl agd.ErrorCollector, ) (g *geoip.File, err error) { log.Debug("using geoip files %q and %q", envs.GeoIPASNPath, envs.GeoIPCountryPath) g, err = geoip.NewFile(&geoip.FileConfig{ - ErrColl: errColl, ASNPath: envs.GeoIPASNPath, CountryPath: envs.GeoIPCountryPath, HostCacheSize: c.HostCacheSize, diff --git a/internal/cmd/filter.go b/internal/cmd/filter.go index da77a78..678d276 100644 --- a/internal/cmd/filter.go +++ b/internal/cmd/filter.go @@ -47,16 +47,21 @@ type filtersConfig struct { // cacheDir must exist. c is assumed to be valid. func (c *filtersConfig) toInternal( errColl agd.ErrorCollector, + resolver agdnet.Resolver, envs *environments, + safeBrowsing *filter.HashPrefix, + adultBlocking *filter.HashPrefix, ) (conf *filter.DefaultStorageConfig) { return &filter.DefaultStorageConfig{ FilterIndexURL: netutil.CloneURL(&envs.FilterIndexURL.URL), BlockedServiceIndexURL: netutil.CloneURL(&envs.BlockedServiceIndexURL.URL), GeneralSafeSearchRulesURL: netutil.CloneURL(&envs.GeneralSafeSearchURL.URL), YoutubeSafeSearchRulesURL: netutil.CloneURL(&envs.YoutubeSafeSearchURL.URL), + SafeBrowsing: safeBrowsing, + AdultBlocking: adultBlocking, Now: time.Now, ErrColl: errColl, - Resolver: agdnet.DefaultResolver{}, + Resolver: resolver, CacheDir: envs.FilterCachePath, CustomFilterCacheSize: c.CustomFilterCacheSize, SafeSearchCacheSize: c.SafeSearchCacheSize, diff --git a/internal/cmd/filteringgroup.go b/internal/cmd/filteringgroup.go index f202354..e1d03cf 100644 --- a/internal/cmd/filteringgroup.go +++ b/internal/cmd/filteringgroup.go @@ -29,6 +29,10 @@ type filteringGroup struct { // BlockPrivateRelay shows if Apple Private Relay queries are blocked for // requests using this filtering group. BlockPrivateRelay bool `yaml:"block_private_relay"` + + // BlockFirefoxCanary shows if Firefox canary domain is blocked for + // requests using this filtering group. 
+ BlockFirefoxCanary bool `yaml:"block_firefox_canary"` } // fltGrpRuleLists contains filter rule lists configuration for a filtering @@ -133,6 +137,7 @@ func (groups filteringGroups) toInternal( GeneralSafeSearch: g.Parental.GeneralSafeSearch, YoutubeSafeSearch: g.Parental.YoutubeSafeSearch, BlockPrivateRelay: g.BlockPrivateRelay, + BlockFirefoxCanary: g.BlockFirefoxCanary, } } diff --git a/internal/cmd/ratelimit.go b/internal/cmd/ratelimit.go index 677eec3..1c25d6b 100644 --- a/internal/cmd/ratelimit.go +++ b/internal/cmd/ratelimit.go @@ -15,14 +15,17 @@ type rateLimitConfig struct { // AllowList is the allowlist of clients. Allowlist *allowListConfig `yaml:"allowlist"` + // Rate limit options for IPv4 addresses. + IPv4 *rateLimitOptions `yaml:"ipv4"` + + // Rate limit options for IPv6 addresses. + IPv6 *rateLimitOptions `yaml:"ipv6"` + // ResponseSizeEstimate is the size of the estimate of the size of one DNS // response for the purposes of rate limiting. Responses over this estimate // are counted as several responses. ResponseSizeEstimate datasize.ByteSize `yaml:"response_size_estimate"` - // RPS is the maximum number of requests per second. - RPS int `yaml:"rps"` - // BackOffCount helps with repeated offenders. It defines, how many times // a client hits the rate limit before being held in the back off. BackOffCount int `yaml:"back_off_count"` @@ -35,18 +38,42 @@ type rateLimitConfig struct { // a client has hit the rate limit for a back off. BackOffPeriod timeutil.Duration `yaml:"back_off_period"` - // IPv4SubnetKeyLen is the length of the subnet prefix used to calculate - // rate limiter bucket keys for IPv4 addresses. - IPv4SubnetKeyLen int `yaml:"ipv4_subnet_key_len"` - - // IPv6SubnetKeyLen is the length of the subnet prefix used to calculate - // rate limiter bucket keys for IPv6 addresses. - IPv6SubnetKeyLen int `yaml:"ipv6_subnet_key_len"` - // RefuseANY, if true, makes the server refuse DNS * queries. 
 	RefuseANY bool `yaml:"refuse_any"`
 }
 
+// allowListConfig is the consul allow list configuration.
+type allowListConfig struct {
+	// List contains IPs and CIDRs.
+	List []string `yaml:"list"`
+
+	// RefreshIvl time between two updates of allow list from the Consul URL.
+	RefreshIvl timeutil.Duration `yaml:"refresh_interval"`
+}
+
+// rateLimitOptions allows defining the maximum number of requests for IPv4 or
+// IPv6 addresses.
+type rateLimitOptions struct {
+	// RPS is the maximum number of requests per second.
+	RPS int `yaml:"rps"`
+
+	// SubnetKeyLen is the length of the subnet prefix used to calculate
+	// rate limiter bucket keys.
+	SubnetKeyLen int `yaml:"subnet_key_len"`
+}
+
+// validate returns an error if rate limit options are invalid.
+func (o *rateLimitOptions) validate() (err error) {
+	if o == nil {
+		return errNilConfig
+	}
+
+	return coalesceError(
+		validatePositive("rps", o.RPS),
+		validatePositive("subnet_key_len", o.SubnetKeyLen),
+	)
+}
+
 // toInternal converts c to the rate limiting configuration for the DNS server.
 // c is assumed to be valid.
func (c *rateLimitConfig) toInternal(al ratelimit.Allowlist) (conf *ratelimit.BackOffConfig) { @@ -55,10 +82,11 @@ func (c *rateLimitConfig) toInternal(al ratelimit.Allowlist) (conf *ratelimit.Ba ResponseSizeEstimate: int(c.ResponseSizeEstimate.Bytes()), Duration: c.BackOffDuration.Duration, Period: c.BackOffPeriod.Duration, - RPS: c.RPS, + IPv4RPS: c.IPv4.RPS, + IPv4SubnetKeyLen: c.IPv4.SubnetKeyLen, + IPv6RPS: c.IPv6.RPS, + IPv6SubnetKeyLen: c.IPv6.SubnetKeyLen, Count: c.BackOffCount, - IPv4SubnetKeyLen: c.IPv4SubnetKeyLen, - IPv6SubnetKeyLen: c.IPv6SubnetKeyLen, RefuseANY: c.RefuseANY, } } @@ -71,14 +99,21 @@ func (c *rateLimitConfig) validate() (err error) { return fmt.Errorf("allowlist: %w", errNilConfig) } + err = c.IPv4.validate() + if err != nil { + return fmt.Errorf("ipv4: %w", err) + } + + err = c.IPv6.validate() + if err != nil { + return fmt.Errorf("ipv6: %w", err) + } + return coalesceError( - validatePositive("rps", c.RPS), validatePositive("back_off_count", c.BackOffCount), validatePositive("back_off_duration", c.BackOffDuration), validatePositive("back_off_period", c.BackOffPeriod), validatePositive("response_size_estimate", c.ResponseSizeEstimate), validatePositive("allowlist.refresh_interval", c.Allowlist.RefreshIvl), - validatePositive("ipv4_subnet_key_len", c.IPv4SubnetKeyLen), - validatePositive("ipv6_subnet_key_len", c.IPv6SubnetKeyLen), ) } diff --git a/internal/cmd/safebrowsing.go b/internal/cmd/safebrowsing.go new file mode 100644 index 0000000..a16976b --- /dev/null +++ b/internal/cmd/safebrowsing.go @@ -0,0 +1,86 @@ +package cmd + +import ( + "path/filepath" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" + "github.com/AdguardTeam/AdGuardDNS/internal/agdnet" + "github.com/AdguardTeam/AdGuardDNS/internal/filter" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/netutil" + 
"github.com/AdguardTeam/golibs/timeutil" +) + +// Safe-browsing and adult-blocking configuration + +// safeBrowsingConfig is the configuration for one of the safe browsing filters. +type safeBrowsingConfig struct { + // URL is the URL used to update the filter. + URL *agdhttp.URL `yaml:"url"` + + // BlockHost is the hostname with which to respond to any requests that + // match the filter. + // + // TODO(a.garipov): Consider replacing with a list of IPv4 and IPv6 + // addresses. + BlockHost string `yaml:"block_host"` + + // CacheSize is the size of the response cache, in entries. + CacheSize int `yaml:"cache_size"` + + // CacheTTL is the TTL of the response cache. + CacheTTL timeutil.Duration `yaml:"cache_ttl"` + + // RefreshIvl defines how often AdGuard DNS refreshes the filter. + RefreshIvl timeutil.Duration `yaml:"refresh_interval"` +} + +// toInternal converts c to the safe browsing filter configuration for the +// filter storage of the DNS server. c is assumed to be valid. +func (c *safeBrowsingConfig) toInternal( + errColl agd.ErrorCollector, + resolver agdnet.Resolver, + id agd.FilterListID, + cacheDir string, +) (fltConf *filter.HashPrefixConfig, err error) { + hashes, err := hashstorage.New("") + if err != nil { + return nil, err + } + + return &filter.HashPrefixConfig{ + Hashes: hashes, + URL: netutil.CloneURL(&c.URL.URL), + ErrColl: errColl, + Resolver: resolver, + ID: id, + CachePath: filepath.Join(cacheDir, string(id)), + ReplacementHost: c.BlockHost, + Staleness: c.RefreshIvl.Duration, + CacheTTL: c.CacheTTL.Duration, + CacheSize: c.CacheSize, + }, nil +} + +// validate returns an error if the safe browsing filter configuration is +// invalid. 
+func (c *safeBrowsingConfig) validate() (err error) { + switch { + case c == nil: + return errNilConfig + case c.URL == nil: + return errors.Error("no url") + case c.BlockHost == "": + return errors.Error("no block_host") + case c.CacheSize <= 0: + return newMustBePositiveError("cache_size", c.CacheSize) + case c.CacheTTL.Duration <= 0: + return newMustBePositiveError("cache_ttl", c.CacheTTL) + case c.RefreshIvl.Duration <= 0: + return newMustBePositiveError("refresh_interval", c.RefreshIvl) + default: + return nil + } +} diff --git a/internal/cmd/websvc.go b/internal/cmd/websvc.go index c72a2ba..5e6834f 100644 --- a/internal/cmd/websvc.go +++ b/internal/cmd/websvc.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "net/netip" + "net/textproto" "os" "path" @@ -386,11 +387,8 @@ func (sc staticContent) validate() (err error) { // staticFile is a single file in a static content mapping. type staticFile struct { - // AllowOrigin is the value for the HTTP Access-Control-Allow-Origin header. - AllowOrigin string `yaml:"allow_origin"` - - // ContentType is the value for the HTTP Content-Type header. - ContentType string `yaml:"content_type"` + // Headers contains headers of the HTTP response. + Headers http.Header `yaml:"headers"` // Content is the file content. Content string `yaml:"content"` @@ -400,8 +398,12 @@ type staticFile struct { // assumed to be valid. func (f *staticFile) toInternal() (file *websvc.StaticFile, err error) { file = &websvc.StaticFile{ - AllowOrigin: f.AllowOrigin, - ContentType: f.ContentType, + Headers: http.Header{}, + } + + for k, vs := range f.Headers { + ck := textproto.CanonicalMIMEHeaderKey(k) + file.Headers[ck] = vs } file.Content, err = base64.StdEncoding.DecodeString(f.Content) @@ -409,17 +411,20 @@ func (f *staticFile) toInternal() (file *websvc.StaticFile, err error) { return nil, fmt.Errorf("content: %w", err) } + // Check Content-Type here as opposed to in validate, because we need + // all keys to be canonicalized first. 
+ if file.Headers.Get(agdhttp.HdrNameContentType) == "" { + return nil, errors.Error("content: " + agdhttp.HdrNameContentType + " header is required") + } + return file, nil } // validate returns an error if the static content file is invalid. func (f *staticFile) validate() (err error) { - switch { - case f == nil: + if f == nil { return errors.Error("no file") - case f.ContentType == "": - return errors.Error("no content_type") - default: - return nil } + + return nil } diff --git a/internal/dnsserver/dnsservertest/msg.go b/internal/dnsserver/dnsservertest/msg.go index f88ac7e..b9aba02 100644 --- a/internal/dnsserver/dnsservertest/msg.go +++ b/internal/dnsserver/dnsservertest/msg.go @@ -222,9 +222,9 @@ func NewEDNS0Padding(msgLen int, UDPBufferSize uint16) (extra dns.RR) { } } -// FindENDS0Option searches for the specified EDNS0 option in the OPT resource +// FindEDNS0Option searches for the specified EDNS0 option in the OPT resource // record of the msg and returns it or nil if it's not present. 
-func FindENDS0Option[T dns.EDNS0](msg *dns.Msg) (o T) { +func FindEDNS0Option[T dns.EDNS0](msg *dns.Msg) (o T) { rr := msg.IsEdns0() if rr == nil { return o diff --git a/internal/dnsserver/go.mod b/internal/dnsserver/go.mod index 9dc1562..af68597 100644 --- a/internal/dnsserver/go.mod +++ b/internal/dnsserver/go.mod @@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardDNS/internal/dnsserver go 1.19 require ( - github.com/AdguardTeam/golibs v0.11.3 + github.com/AdguardTeam/golibs v0.11.4 github.com/ameshkov/dnscrypt/v2 v2.2.5 github.com/ameshkov/dnsstamps v1.0.3 github.com/bluele/gcache v0.0.2 diff --git a/internal/dnsserver/go.sum b/internal/dnsserver/go.sum index 356e549..01acb1f 100644 --- a/internal/dnsserver/go.sum +++ b/internal/dnsserver/go.sum @@ -31,8 +31,7 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= -github.com/AdguardTeam/golibs v0.11.3 h1:Oif+REq2WLycQ2Xm3ZPmJdfftptss0HbGWbxdFaC310= -github.com/AdguardTeam/golibs v0.11.3/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA= +github.com/AdguardTeam/golibs v0.11.4 h1:IltyvxwCTN+xxJF5sh6VadF8Zfbf8elgCm9dgijSVzM= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= diff --git a/internal/dnsserver/listen_noreuseport.go b/internal/dnsserver/listen_noreuseport.go deleted file mode 100644 index 8cf56e6..0000000 --- a/internal/dnsserver/listen_noreuseport.go +++ /dev/null @@ -1,36 +0,0 @@ -//go:build !(aix || darwin || 
dragonfly || freebsd || linux || netbsd || openbsd) - -package dnsserver - -import ( - "context" - "fmt" - "net" - - "github.com/AdguardTeam/golibs/errors" -) - -// listenUDP listens to the specified address on UDP. -func listenUDP(_ context.Context, addr string, _ bool) (conn *net.UDPConn, err error) { - defer func() { err = errors.Annotate(err, "opening packet listener: %w") }() - - c, err := net.ListenPacket("udp", addr) - if err != nil { - return nil, err - } - - conn, ok := c.(*net.UDPConn) - if !ok { - // TODO(ameshkov): should not happen, consider panic here. - err = fmt.Errorf("expected conn of type %T, got %T", conn, c) - - return nil, err - } - - return conn, nil -} - -// listenTCP listens to the specified address on TCP. -func listenTCP(_ context.Context, addr string) (conn net.Listener, err error) { - return net.Listen("tcp", addr) -} diff --git a/internal/dnsserver/listen_reuseport.go b/internal/dnsserver/listen_reuseport.go deleted file mode 100644 index 552782f..0000000 --- a/internal/dnsserver/listen_reuseport.go +++ /dev/null @@ -1,63 +0,0 @@ -//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd - -package dnsserver - -import ( - "context" - "fmt" - "net" - "syscall" - - "github.com/AdguardTeam/golibs/errors" - "golang.org/x/sys/unix" -) - -func reuseportControl(_, _ string, c syscall.RawConn) (err error) { - var opErr error - err = c.Control(func(fd uintptr) { - opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) - }) - if err != nil { - return err - } - - return opErr -} - -// listenUDP listens to the specified address on UDP. If oob flag is set to -// true this method also enables OOB for the listen socket that enables using of -// ReadMsgUDP/WriteMsgUDP. Doing it this way is necessary to correctly discover -// the source address when it listens to 0.0.0.0. 
-func listenUDP(ctx context.Context, addr string, oob bool) (conn *net.UDPConn, err error) { - defer func() { err = errors.Annotate(err, "opening packet listener: %w") }() - - var lc net.ListenConfig - lc.Control = reuseportControl - c, err := lc.ListenPacket(ctx, "udp", addr) - if err != nil { - return nil, err - } - - conn, ok := c.(*net.UDPConn) - if !ok { - // TODO(ameshkov): should not happen, consider panic here. - err = fmt.Errorf("expected conn of type %T, got %T", conn, c) - - return nil, err - } - - if oob { - if err = setUDPSocketOptions(conn); err != nil { - return nil, fmt.Errorf("failed to set socket options: %w", err) - } - } - - return conn, err -} - -// listenTCP listens to the specified address on TCP. -func listenTCP(ctx context.Context, addr string) (l net.Listener, err error) { - var lc net.ListenConfig - lc.Control = reuseportControl - return lc.Listen(ctx, "tcp", addr) -} diff --git a/internal/dnsserver/netext/listenconfig.go b/internal/dnsserver/netext/listenconfig.go new file mode 100644 index 0000000..19a47d3 --- /dev/null +++ b/internal/dnsserver/netext/listenconfig.go @@ -0,0 +1,73 @@ +// Package netext contains extensions of package net in the Go standard library. +package netext + +import ( + "context" + "fmt" + "net" +) + +// ListenConfig is the interface that allows controlling options of connections +// used by the DNS servers defined in this module. Default ListenConfigs are +// the ones returned by [DefaultListenConfigWithOOB] for plain DNS and +// [DefaultListenConfig] for others. +// +// This interface is modeled after [net.ListenConfig]. +type ListenConfig interface { + Listen(ctx context.Context, network, address string) (l net.Listener, err error) + ListenPacket(ctx context.Context, network, address string) (c net.PacketConn, err error) +} + +// DefaultListenConfig returns the default [ListenConfig] used by the servers in +// this module except for the plain-DNS ones, which use +// [DefaultListenConfigWithOOB]. 
+func DefaultListenConfig() (lc ListenConfig) { + return &net.ListenConfig{ + Control: defaultListenControl, + } +} + +// DefaultListenConfigWithOOB returns the default [ListenConfig] used by the +// plain-DNS servers in this module. The resulting ListenConfig sets additional +// socket flags and processes the control-messages of connections created with +// ListenPacket. +func DefaultListenConfigWithOOB() (lc ListenConfig) { + return &listenConfigOOB{ + ListenConfig: net.ListenConfig{ + Control: defaultListenControl, + }, + } +} + +// type check +var _ ListenConfig = (*listenConfigOOB)(nil) + +// listenConfigOOB is a wrapper around [net.ListenConfig] with modifications +// that set the control-message options on packet conns. +type listenConfigOOB struct { + net.ListenConfig +} + +// ListenPacket implements the [ListenConfig] interface for *listenConfigOOB. +// It sets the control-message flags to receive additional out-of-band data to +// correctly discover the source address when it listens to 0.0.0.0 as well as +// in situations when SO_BINDTODEVICE is used. +// +// network must be "udp", "udp4", or "udp6". 
+func (lc *listenConfigOOB) ListenPacket( + ctx context.Context, + network string, + address string, +) (c net.PacketConn, err error) { + c, err = lc.ListenConfig.ListenPacket(ctx, network, address) + if err != nil { + return nil, err + } + + err = setIPOpts(c) + if err != nil { + return nil, fmt.Errorf("setting socket options: %w", err) + } + + return wrapPacketConn(c), nil +} diff --git a/internal/dnsserver/netext/listenconfig_unix.go b/internal/dnsserver/netext/listenconfig_unix.go new file mode 100644 index 0000000..09acf9b --- /dev/null +++ b/internal/dnsserver/netext/listenconfig_unix.go @@ -0,0 +1,44 @@ +//go:build unix + +package netext + +import ( + "net" + "syscall" + + "github.com/AdguardTeam/golibs/errors" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "golang.org/x/sys/unix" +) + +// defaultListenControl is used as a [net.ListenConfig.Control] function to set +// the SO_REUSEPORT socket option on all sockets used by the DNS servers in this +// package. +func defaultListenControl(_, _ string, c syscall.RawConn) (err error) { + var opErr error + err = c.Control(func(fd uintptr) { + opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) + }) + if err != nil { + return err + } + + return errors.WithDeferred(opErr, err) +} + +// setIPOpts sets the IPv4 and IPv6 options on a packet connection. +func setIPOpts(c net.PacketConn) (err error) { + // TODO(a.garipov): Returning an error only if both functions return one + // (which is what module dns does as well) seems rather fragile. Depending + // on the OS, the valid errors are ENOPROTOOPT, EINVAL, and maybe others. + // Investigate and make OS-specific versions to make sure we don't miss any + // real errors. 
+ err6 := ipv6.NewPacketConn(c).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true) + err4 := ipv4.NewPacketConn(c).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true) + if err4 != nil && err6 != nil { + return errors.List("setting ipv4 and ipv6 options", err4, err6) + } + + return nil +} diff --git a/internal/dnsserver/netext/listenconfig_unix_test.go b/internal/dnsserver/netext/listenconfig_unix_test.go new file mode 100644 index 0000000..caa4f6f --- /dev/null +++ b/internal/dnsserver/netext/listenconfig_unix_test.go @@ -0,0 +1,67 @@ +//go:build unix + +package netext_test + +import ( + "context" + "syscall" + "testing" + + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext" + "github.com/AdguardTeam/golibs/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" +) + +func TestDefaultListenConfigWithOOB(t *testing.T) { + lc := netext.DefaultListenConfigWithOOB() + require.NotNil(t, lc) + + type syscallConner interface { + SyscallConn() (c syscall.RawConn, err error) + } + + t.Run("ipv4", func(t *testing.T) { + c, err := lc.ListenPacket(context.Background(), "udp4", "127.0.0.1:0") + require.NoError(t, err) + require.NotNil(t, c) + require.Implements(t, (*syscallConner)(nil), c) + + sc, err := c.(syscallConner).SyscallConn() + require.NoError(t, err) + + err = sc.Control(func(fd uintptr) { + val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT) + require.NoError(t, opErr) + + // TODO(a.garipov): Rewrite this to use actual expected values for + // each OS. + assert.NotEqual(t, 0, val) + }) + require.NoError(t, err) + }) + + t.Run("ipv6", func(t *testing.T) { + c, err := lc.ListenPacket(context.Background(), "udp6", "[::1]:0") + if errors.Is(err, syscall.EADDRNOTAVAIL) { + // Some CI machines have IPv6 disabled. 
+ t.Skipf("ipv6 seems to not be supported: %s", err) + } + + require.NoError(t, err) + require.NotNil(t, c) + require.Implements(t, (*syscallConner)(nil), c) + + sc, err := c.(syscallConner).SyscallConn() + require.NoError(t, err) + + err = sc.Control(func(fd uintptr) { + val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT) + require.NoError(t, opErr) + + assert.NotEqual(t, 0, val) + }) + require.NoError(t, err) + }) +} diff --git a/internal/dnsserver/netext/listenconfig_windows.go b/internal/dnsserver/netext/listenconfig_windows.go new file mode 100644 index 0000000..20e11bf --- /dev/null +++ b/internal/dnsserver/netext/listenconfig_windows.go @@ -0,0 +1,17 @@ +//go:build windows + +package netext + +import ( + "net" + "syscall" +) + +// defaultListenControl is nil on Windows, because it doesn't support +// SO_REUSEPORT. +var defaultListenControl func(_, _ string, _ syscall.RawConn) (_ error) + +// setIPOpts sets the IPv4 and IPv6 options on a packet connection. +func setIPOpts(c net.PacketConn) (err error) { + return nil +} diff --git a/internal/dnsserver/netext/packetconn.go b/internal/dnsserver/netext/packetconn.go new file mode 100644 index 0000000..180f5df --- /dev/null +++ b/internal/dnsserver/netext/packetconn.go @@ -0,0 +1,69 @@ +package netext + +import ( + "net" +) + +// PacketSession contains additional information about a packet read from or +// written to a [SessionPacketConn]. +type PacketSession interface { + LocalAddr() (addr net.Addr) + RemoteAddr() (addr net.Addr) +} + +// NewSimplePacketSession returns a new packet session using the given +// parameters. +func NewSimplePacketSession(laddr, raddr net.Addr) (s PacketSession) { + return &simplePacketSession{ + laddr: laddr, + raddr: raddr, + } +} + +// simplePacketSession is a simple implementation of the [PacketSession] +// interface. 
+type simplePacketSession struct { + laddr net.Addr + raddr net.Addr +} + +// LocalAddr implements the [PacketSession] interface for *simplePacketSession. +func (s *simplePacketSession) LocalAddr() (addr net.Addr) { return s.laddr } + +// RemoteAddr implements the [PacketSession] interface for *simplePacketSession. +func (s *simplePacketSession) RemoteAddr() (addr net.Addr) { return s.raddr } + +// SessionPacketConn extends [net.PacketConn] with methods for working with +// packet sessions. +type SessionPacketConn interface { + net.PacketConn + + ReadFromSession(b []byte) (n int, s PacketSession, err error) + WriteToSession(b []byte, s PacketSession) (n int, err error) +} + +// ReadFromSession is a convenience wrapper for types that may or may not +// implement [SessionPacketConn]. If c implements it, ReadFromSession uses +// c.ReadFromSession. Otherwise, it uses c.ReadFrom and the session is created +// by using [NewSimplePacketSession] with c.LocalAddr. +func ReadFromSession(c net.PacketConn, b []byte) (n int, s PacketSession, err error) { + if spc, ok := c.(SessionPacketConn); ok { + return spc.ReadFromSession(b) + } + + n, raddr, err := c.ReadFrom(b) + s = NewSimplePacketSession(c.LocalAddr(), raddr) + + return n, s, err +} + +// WriteToSession is a convenience wrapper for types that may or may not +// implement [SessionPacketConn]. If c implements it, WriteToSession uses +// c.WriteToSession. Otherwise, it uses c.WriteTo using s.RemoteAddr. 
+func WriteToSession(c net.PacketConn, b []byte, s PacketSession) (n int, err error) { + if spc, ok := c.(SessionPacketConn); ok { + return spc.WriteToSession(b, s) + } + + return c.WriteTo(b, s.RemoteAddr()) +} diff --git a/internal/dnsserver/netext/packetconn_linux.go b/internal/dnsserver/netext/packetconn_linux.go new file mode 100644 index 0000000..8ffa687 --- /dev/null +++ b/internal/dnsserver/netext/packetconn_linux.go @@ -0,0 +1,159 @@ +//go:build linux + +// TODO(a.garipov): Technically, we can expand this to other platforms, but that +// would require separate udpOOBSize constants and tests. + +package netext + +import ( + "fmt" + "net" + "sync" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +// type check +var _ PacketSession = (*packetSession)(nil) + +// packetSession contains additional information about the packet read from a +// UDP connection. It is basically an extended version of [dns.SessionUDP] that +// contains the local address as well. +type packetSession struct { + laddr *net.UDPAddr + raddr *net.UDPAddr + respOOB []byte +} + +// LocalAddr implements the [PacketSession] interface for *packetSession. +func (s *packetSession) LocalAddr() (addr net.Addr) { return s.laddr } + +// RemoteAddr implements the [PacketSession] interface for *packetSession. +func (s *packetSession) RemoteAddr() (addr net.Addr) { return s.raddr } + +// type check +var _ SessionPacketConn = (*sessionPacketConn)(nil) + +// wrapPacketConn wraps c to make it a [SessionPacketConn], if the OS supports +// that. +func wrapPacketConn(c net.PacketConn) (wrapped net.PacketConn) { + return &sessionPacketConn{ + UDPConn: *c.(*net.UDPConn), + } +} + +// sessionPacketConn wraps a UDP connection and implements [SessionPacketConn]. +type sessionPacketConn struct { + net.UDPConn +} + +// oobPool is the pool of byte slices for out-of-band data. 
+var oobPool = &sync.Pool{
+	New: func() (v any) {
+		b := make([]byte, IPDstOOBSize)
+
+		return &b
+	},
+}
+
+// IPDstOOBSize is the required size of the control-message buffer for
+// [net.UDPConn.ReadMsgUDP] to read the original destination on Linux.
+//
+// See packetconn_linux_internal_test.go.
+const IPDstOOBSize = 40
+
+// ReadFromSession implements the [SessionPacketConn] interface for *sessionPacketConn.
+func (c *sessionPacketConn) ReadFromSession(b []byte) (n int, s PacketSession, err error) {
+	oobPtr := oobPool.Get().(*[]byte)
+	defer oobPool.Put(oobPtr)
+
+	var oobn int
+	oob := *oobPtr
+	ps := &packetSession{}
+	n, oobn, _, ps.raddr, err = c.ReadMsgUDP(b, oob)
+	if err != nil {
+		return 0, nil, err
+	}
+
+	var origDstIP net.IP
+	sockLAddr := c.LocalAddr().(*net.UDPAddr)
+	origDstIP, err = origLAddr(oob[:oobn])
+	if err != nil {
+		return 0, nil, fmt.Errorf("getting original addr: %w", err)
+	}
+
+	if origDstIP == nil {
+		ps.laddr = sockLAddr
+	} else {
+		ps.respOOB = newRespOOB(origDstIP)
+		ps.laddr = &net.UDPAddr{
+			IP:   origDstIP,
+			Port: sockLAddr.Port,
+		}
+	}
+
+	return n, ps, nil
+}
+
+// origLAddr returns the original local address from the encoded control-message
+// data, if there is one. If not nil, origDst will have a protocol-appropriate
+// length.
+func origLAddr(oob []byte) (origDst net.IP, err error) {
+	ctrlMsg6 := &ipv6.ControlMessage{}
+	err = ctrlMsg6.Parse(oob)
+	if err != nil {
+		return nil, fmt.Errorf("parsing ipv6 control message: %w", err)
+	}
+
+	if dst := ctrlMsg6.Dst; dst != nil {
+		// Linux maps IPv4 addresses to IPv6 ones by default, so we can get an
+		// IPv4 dst from an IPv6 control-message.
+		origDst = dst.To4()
+		if origDst == nil {
+			origDst = dst
+		}
+
+		return origDst, nil
+	}
+
+	ctrlMsg4 := &ipv4.ControlMessage{}
+	err = ctrlMsg4.Parse(oob)
+	if err != nil {
+		return nil, fmt.Errorf("parsing ipv4 control message: %w", err)
+	}
+
+	return ctrlMsg4.Dst.To4(), nil
+}
+
+// newRespOOB returns an encoded control-message for the response for this IP
+// address. origDst is expected to have a protocol-appropriate length.
+func newRespOOB(origDst net.IP) (b []byte) {
+	switch len(origDst) {
+	case net.IPv4len:
+		cm := &ipv4.ControlMessage{
+			Src: origDst,
+		}
+
+		return cm.Marshal()
+	case net.IPv6len:
+		cm := &ipv6.ControlMessage{
+			Src: origDst,
+		}
+
+		return cm.Marshal()
+	default:
+		return nil
+	}
+}
+
+// WriteToSession implements the [SessionPacketConn] interface for *sessionPacketConn.
+func (c *sessionPacketConn) WriteToSession(b []byte, s PacketSession) (n int, err error) {
+	if ps, ok := s.(*packetSession); ok {
+		n, _, err = c.WriteMsgUDP(b, ps.respOOB, ps.raddr)
+
+		return n, err
+	}
+
+	return c.WriteTo(b, s.RemoteAddr())
+}
diff --git a/internal/dnsserver/netext/packetconn_linux_internal_test.go b/internal/dnsserver/netext/packetconn_linux_internal_test.go
new file mode 100644
index 0000000..ad31abf
--- /dev/null
+++ b/internal/dnsserver/netext/packetconn_linux_internal_test.go
@@ -0,0 +1,25 @@
+//go:build linux
+
+package netext
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"golang.org/x/net/ipv4"
+	"golang.org/x/net/ipv6"
+)
+
+func TestUDPOOBSize(t *testing.T) {
+	// See https://github.com/miekg/dns/blob/v1.1.50/udp.go.
+ + len4 := len(ipv4.NewControlMessage(ipv4.FlagDst | ipv4.FlagInterface)) + len6 := len(ipv6.NewControlMessage(ipv6.FlagDst | ipv6.FlagInterface)) + + max := len4 + if len6 > max { + max = len6 + } + + assert.Equal(t, max, IPDstOOBSize) +} diff --git a/internal/dnsserver/netext/packetconn_linux_test.go b/internal/dnsserver/netext/packetconn_linux_test.go new file mode 100644 index 0000000..6999445 --- /dev/null +++ b/internal/dnsserver/netext/packetconn_linux_test.go @@ -0,0 +1,140 @@ +//go:build linux + +package netext_test + +import ( + "context" + "fmt" + "net" + "os" + "syscall" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TODO(a.garipov): Add IPv6 test. +func TestSessionPacketConn(t *testing.T) { + const numTries = 5 + + // Try the test multiple times to reduce flakiness due to UDP failures. 
+ var success4, success6 bool + for i := 0; i < numTries; i++ { + var isTimeout4, isTimeout6 bool + success4 = t.Run(fmt.Sprintf("ipv4_%d", i), func(t *testing.T) { + isTimeout4 = testSessionPacketConn(t, "udp4", "0.0.0.0:0", net.IP{127, 0, 0, 1}) + }) + + success6 = t.Run(fmt.Sprintf("ipv6_%d", i), func(t *testing.T) { + isTimeout6 = testSessionPacketConn(t, "udp6", "[::]:0", net.IPv6loopback) + }) + + if success4 && success6 { + break + } else if isTimeout4 || isTimeout6 { + continue + } + + t.Fail() + } + + if !success4 { + t.Errorf("ipv4 test failed after %d attempts", numTries) + } else if !success6 { + t.Errorf("ipv6 test failed after %d attempts", numTries) + } +} + +func testSessionPacketConn(t *testing.T, proto, addr string, dstIP net.IP) (isTimeout bool) { + lc := netext.DefaultListenConfigWithOOB() + require.NotNil(t, lc) + + c, err := lc.ListenPacket(context.Background(), proto, addr) + if isTimeoutOrFail(t, err) { + return true + } + + require.NotNil(t, c) + + deadline := time.Now().Add(1 * time.Second) + err = c.SetDeadline(deadline) + require.NoError(t, err) + + laddr := testutil.RequireTypeAssert[*net.UDPAddr](t, c.LocalAddr()) + require.NotNil(t, laddr) + + dstAddr := &net.UDPAddr{ + IP: dstIP, + Port: laddr.Port, + } + + remoteConn, err := net.DialUDP(proto, nil, dstAddr) + if proto == "udp6" && errors.Is(err, syscall.EADDRNOTAVAIL) { + // Some CI machines have IPv6 disabled. 
+ t.Skipf("ipv6 seems to not be supported: %s", err) + } else if isTimeoutOrFail(t, err) { + return true + } + + err = remoteConn.SetDeadline(deadline) + require.NoError(t, err) + + msg := []byte("hello") + msgLen := len(msg) + _, err = remoteConn.Write(msg) + if isTimeoutOrFail(t, err) { + return true + } + + require.Implements(t, (*netext.SessionPacketConn)(nil), c) + + buf := make([]byte, msgLen) + n, sess, err := netext.ReadFromSession(c, buf) + if isTimeoutOrFail(t, err) { + return true + } + + assert.Equal(t, msgLen, n) + assert.Equal(t, net.Addr(dstAddr), sess.LocalAddr()) + assert.Equal(t, remoteConn.LocalAddr(), sess.RemoteAddr()) + assert.Equal(t, msg, buf) + + respMsg := []byte("world") + respMsgLen := len(respMsg) + n, err = netext.WriteToSession(c, respMsg, sess) + if isTimeoutOrFail(t, err) { + return true + } + + assert.Equal(t, respMsgLen, n) + + buf = make([]byte, respMsgLen) + n, err = remoteConn.Read(buf) + if isTimeoutOrFail(t, err) { + return true + } + + assert.Equal(t, respMsgLen, n) + assert.Equal(t, respMsg, buf) + + return false +} + +// isTimeoutOrFail is a helper function that returns true if err is a timeout +// error and also calls require.NoError on err. +func isTimeoutOrFail(t *testing.T, err error) (ok bool) { + t.Helper() + + if err == nil { + return false + } + + defer require.NoError(t, err) + + return errors.Is(err, os.ErrDeadlineExceeded) +} diff --git a/internal/dnsserver/netext/packetconn_others.go b/internal/dnsserver/netext/packetconn_others.go new file mode 100644 index 0000000..d64c8f4 --- /dev/null +++ b/internal/dnsserver/netext/packetconn_others.go @@ -0,0 +1,11 @@ +//go:build !linux + +package netext + +import "net" + +// wrapPacketConn wraps c to make it a [SessionPacketConn], if the OS supports +// that. 
+func wrapPacketConn(c net.PacketConn) (wrapped net.PacketConn) { + return c +} diff --git a/internal/dnsserver/prometheus/ratelimit_test.go b/internal/dnsserver/prometheus/ratelimit_test.go index 15d634f..1cc1157 100644 --- a/internal/dnsserver/prometheus/ratelimit_test.go +++ b/internal/dnsserver/prometheus/ratelimit_test.go @@ -27,7 +27,8 @@ func TestRateLimiterMetricsListener_integration_cache(t *testing.T) { Duration: time.Minute, Count: rps, ResponseSizeEstimate: 1000, - RPS: rps, + IPv4RPS: rps, + IPv6RPS: rps, RefuseANY: true, }) rlMw, err := ratelimit.NewMiddleware(rl, nil) diff --git a/internal/dnsserver/ratelimit/backoff.go b/internal/dnsserver/ratelimit/backoff.go index e1ab0de..c2cc6ce 100644 --- a/internal/dnsserver/ratelimit/backoff.go +++ b/internal/dnsserver/ratelimit/backoff.go @@ -37,15 +37,22 @@ type BackOffConfig struct { // as several responses. ResponseSizeEstimate int - // RPS is the maximum number of requests per second allowed from a single - // subnet. Any requests above this rate are counted as the client's - // back-off count. RPS must be greater than zero. - RPS int + // IPv4RPS is the maximum number of requests per second allowed from a + // single subnet for IPv4 addresses. Any requests above this rate are + // counted as the client's back-off count. RPS must be greater than + // zero. + IPv4RPS int // IPv4SubnetKeyLen is the length of the subnet prefix used to calculate // rate limiter bucket keys for IPv4 addresses. Must be greater than zero. IPv4SubnetKeyLen int + // IPv6RPS is the maximum number of requests per second allowed from a + // single subnet for IPv6 addresses. Any requests above this rate are + // counted as the client's back-off count. RPS must be greater than + // zero. + IPv6RPS int + // IPv6SubnetKeyLen is the length of the subnet prefix used to calculate // rate limiter bucket keys for IPv6 addresses. Must be greater than zero. 
IPv6SubnetKeyLen int @@ -62,14 +69,18 @@ type BackOffConfig struct { // current implementation might be too abstract. Middlewares by themselves // already provide an interface that can be re-implemented by the users. // Perhaps, another layer of abstraction is unnecessary. +// +// TODO(ameshkov): Consider splitting rps and other properties by protocol +// family. type BackOff struct { rpsCounters *cache.Cache hitCounters *cache.Cache allowlist Allowlist count int - rps int respSzEst int + ipv4rps int ipv4SubnetKeyLen int + ipv6rps int ipv6SubnetKeyLen int refuseANY bool } @@ -84,9 +95,10 @@ func NewBackOff(c *BackOffConfig) (l *BackOff) { hitCounters: cache.New(c.Duration, c.Duration), allowlist: c.Allowlist, count: c.Count, - rps: c.RPS, respSzEst: c.ResponseSizeEstimate, + ipv4rps: c.IPv4RPS, ipv4SubnetKeyLen: c.IPv4SubnetKeyLen, + ipv6rps: c.IPv6RPS, ipv6SubnetKeyLen: c.IPv6SubnetKeyLen, refuseANY: c.RefuseANY, } @@ -124,7 +136,12 @@ func (l *BackOff) IsRateLimited( return true, false, nil } - return l.hasHitRateLimit(key), false, nil + rps := l.ipv4rps + if ip.Is6() { + rps = l.ipv6rps + } + + return l.hasHitRateLimit(key, rps), false, nil } // validateAddr returns an error if addr is not a valid IPv4 or IPv6 address. @@ -184,14 +201,15 @@ func (l *BackOff) incBackOff(key string) { l.hitCounters.SetDefault(key, counter) } -// hasHitRateLimit checks value for a subnet. -func (l *BackOff) hasHitRateLimit(subnetIPStr string) (ok bool) { - var r *rps +// hasHitRateLimit checks value for a subnet with rps as a maximum number +// requests per second. 
+func (l *BackOff) hasHitRateLimit(subnetIPStr string, rps int) (ok bool) { + var r *rpsCounter rVal, ok := l.rpsCounters.Get(subnetIPStr) if ok { - r = rVal.(*rps) + r = rVal.(*rpsCounter) } else { - r = newRPS(l.rps) + r = newRPSCounter(rps) l.rpsCounters.SetDefault(subnetIPStr, r) } diff --git a/internal/dnsserver/ratelimit/ratelimit_test.go b/internal/dnsserver/ratelimit/ratelimit_test.go index 8be231d..efa3ba4 100644 --- a/internal/dnsserver/ratelimit/ratelimit_test.go +++ b/internal/dnsserver/ratelimit/ratelimit_test.go @@ -98,8 +98,9 @@ func TestRatelimitMiddleware(t *testing.T) { Duration: time.Minute, Count: rps, ResponseSizeEstimate: 128, - RPS: rps, + IPv4RPS: rps, IPv4SubnetKeyLen: 24, + IPv6RPS: rps, IPv6SubnetKeyLen: 48, RefuseANY: true, }) diff --git a/internal/dnsserver/ratelimit/rps.go b/internal/dnsserver/ratelimit/rps.go index 823db8f..dcc63dd 100644 --- a/internal/dnsserver/ratelimit/rps.go +++ b/internal/dnsserver/ratelimit/rps.go @@ -7,17 +7,17 @@ import ( // Requests Per Second Counter -// rps is a single request per seconds counter. -type rps struct { +// rpsCounter is a single request per seconds counter. +type rpsCounter struct { // mu protects all fields. mu *sync.Mutex ring []int64 idx int } -// newRPS returns a new requests per second counter. n must be above zero. -func newRPS(n int) (r *rps) { - return &rps{ +// newRPSCounter returns a new requests per second counter. n must be above zero. +func newRPSCounter(n int) (r *rpsCounter) { + return &rpsCounter{ mu: &sync.Mutex{}, // Add one, because we need to always keep track of the previous // request. For example, consider n == 1. @@ -28,7 +28,7 @@ func newRPS(n int) (r *rps) { // add adds another request to the counter. above is true if the request goes // above the counter value. It is safe for concurrent use. 
-func (r *rps) add(t time.Time) (above bool) { +func (r *rpsCounter) add(t time.Time) (above bool) { r.mu.Lock() defer r.mu.Unlock() diff --git a/internal/dnsserver/serverbase.go b/internal/dnsserver/serverbase.go index 42512f5..5ea8206 100644 --- a/internal/dnsserver/serverbase.go +++ b/internal/dnsserver/serverbase.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) @@ -31,16 +32,19 @@ type ConfigBase struct { // Handler is a handler that processes incoming DNS messages. // If not set, we'll use the default handler that returns error response // to any query. - Handler Handler + // Metrics is the object we use for collecting performance metrics. // This field is optional. - Metrics MetricsListener // BaseContext is a function that should return the base context. If not // set, we'll be using context.Background(). BaseContext func() (ctx context.Context) + + // ListenConfig, when set, is used to set options of connections used by the + // DNS server. If nil, an appropriate default ListenConfig is used. + ListenConfig netext.ListenConfig } // ServerBase implements base methods that every Server implementation uses. @@ -67,14 +71,19 @@ type ServerBase struct { // metrics is the object we use for collecting performance metrics. metrics MetricsListener + // listenConfig is used to set tcpListener and udpListener. + listenConfig netext.ListenConfig + // Server operation // -- - // will be nil for servers that don't use TCP. + // tcpListener is used to accept new TCP connections. It is nil for servers + // that don't use TCP. tcpListener net.Listener - // will be nil for servers that don't use UDP. - udpListener *net.UDPConn + // udpListener is used to accept new UDP messages. It is nil for servers + // that don't use UDP. 
+ udpListener net.PacketConn // Shutdown handling // -- @@ -94,13 +103,14 @@ var _ Server = (*ServerBase)(nil) // some of its internal properties. func newServerBase(proto Protocol, conf ConfigBase) (s *ServerBase) { s = &ServerBase{ - name: conf.Name, - addr: conf.Addr, - proto: proto, - network: conf.Network, - handler: conf.Handler, - metrics: conf.Metrics, - baseContext: conf.BaseContext, + name: conf.Name, + addr: conf.Addr, + proto: proto, + network: conf.Network, + handler: conf.Handler, + metrics: conf.Metrics, + listenConfig: conf.ListenConfig, + baseContext: conf.BaseContext, } if s.baseContext == nil { @@ -347,18 +357,17 @@ func (s *ServerBase) handlePanicAndRecover(ctx context.Context) { } } -// listenUDP creates a UDP listener for the ServerBase.addr. This function will -// initialize and start ServerBase.udpListener or return an error. If the TCP -// listener is already running, its address is used instead. The point of this -// is to properly handle the case when port 0 is used as both listeners should -// use the same port, and we only learn it after the first one was started. +// listenUDP initializes and starts s.udpListener using s.addr. If the TCP +// listener is already running, its address is used instead to properly handle +// the case when port 0 is used as both listeners should use the same port, and +// we only learn it after the first one was started. func (s *ServerBase) listenUDP(ctx context.Context) (err error) { addr := s.addr if s.tcpListener != nil { addr = s.tcpListener.Addr().String() } - conn, err := listenUDP(ctx, addr, true) + conn, err := s.listenConfig.ListenPacket(ctx, "udp", addr) if err != nil { return err } @@ -368,19 +377,17 @@ func (s *ServerBase) listenUDP(ctx context.Context) (err error) { return nil } -// listenTCP creates a TCP listener for the ServerBase.addr. This function will -// initialize and start ServerBase.tcpListener or return an error. 
If the UDP -// listener is already running, its address is used instead. The point of this -// is to properly handle the case when port 0 is used as both listeners should -// use the same port, and we only learn it after the first one was started. +// listenTCP initializes and starts s.tcpListener using s.addr. If the UDP +// listener is already running, its address is used instead to properly handle +// the case when port 0 is used as both listeners should use the same port, and +// we only learn it after the first one was started. func (s *ServerBase) listenTCP(ctx context.Context) (err error) { addr := s.addr if s.udpListener != nil { addr = s.udpListener.LocalAddr().String() } - var l net.Listener - l, err = listenTCP(ctx, addr) + l, err := s.listenConfig.Listen(ctx, "tcp", addr) if err != nil { return err } diff --git a/internal/dnsserver/serverdns.go b/internal/dnsserver/serverdns.go index c54beb4..88ee67e 100644 --- a/internal/dnsserver/serverdns.go +++ b/internal/dnsserver/serverdns.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" @@ -110,6 +111,10 @@ func newServerDNS(proto Protocol, conf ConfigDNS) (s *ServerDNS) { conf.TCPSize = dns.MinMsgSize } + if conf.ListenConfig == nil { + conf.ListenConfig = netext.DefaultListenConfigWithOOB() + } + s = &ServerDNS{ ServerBase: newServerBase(proto, conf.ConfigBase), conf: conf, @@ -262,10 +267,21 @@ func makePacketBuffer(size int) (f func() any) { } } +// writeDeadlineSetter is an interface for connections that can set write +// deadlines. +type writeDeadlineSetter interface { + SetWriteDeadline(t time.Time) (err error) +} + // withWriteDeadline is a helper that takes the deadline of the context and the // write timeout into account. It sets the write deadline on conn before // calling f and resets it once f is done. 
-func withWriteDeadline(ctx context.Context, writeTimeout time.Duration, conn net.Conn, f func()) { +func withWriteDeadline( + ctx context.Context, + writeTimeout time.Duration, + conn writeDeadlineSetter, + f func(), +) { dl, hasDeadline := ctx.Deadline() if !hasDeadline { dl = time.Now().Add(writeTimeout) diff --git a/internal/dnsserver/serverdns_test.go b/internal/dnsserver/serverdns_test.go index 1ca02f8..0e6f3d6 100644 --- a/internal/dnsserver/serverdns_test.go +++ b/internal/dnsserver/serverdns_test.go @@ -284,8 +284,8 @@ func TestServerDNS_integration_query(t *testing.T) { tc.expectedTruncated, ) - reqKeepAliveOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_TCP_KEEPALIVE](tc.req) - respKeepAliveOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_TCP_KEEPALIVE](resp) + reqKeepAliveOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_TCP_KEEPALIVE](tc.req) + respKeepAliveOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_TCP_KEEPALIVE](resp) if tc.network == dnsserver.NetworkTCP && reqKeepAliveOpt != nil { require.NotNil(t, respKeepAliveOpt) expectedTimeout := uint16(dnsserver.DefaultTCPIdleTimeout.Milliseconds() / 100) diff --git a/internal/dnsserver/serverdnscrypt.go b/internal/dnsserver/serverdnscrypt.go index dfb94a1..84e30d0 100644 --- a/internal/dnsserver/serverdnscrypt.go +++ b/internal/dnsserver/serverdnscrypt.go @@ -2,7 +2,9 @@ package dnsserver import ( "context" + "net" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/ameshkov/dnscrypt/v2" @@ -38,6 +40,10 @@ var _ Server = (*ServerDNSCrypt)(nil) // NewServerDNSCrypt creates a new instance of ServerDNSCrypt. 
func NewServerDNSCrypt(conf ConfigDNSCrypt) (s *ServerDNSCrypt) { + if conf.ListenConfig == nil { + conf.ListenConfig = netext.DefaultListenConfig() + } + return &ServerDNSCrypt{ ServerBase: newServerBase(ProtoDNSCrypt, conf.ConfigBase), conf: conf, @@ -140,7 +146,10 @@ func (s *ServerDNSCrypt) startServeUDP(ctx context.Context) { // TODO(ameshkov): Add context to the ServeTCP and ServeUDP methods in // dnscrypt/v3. Or at least add ServeTCPContext and ServeUDPContext // methods for now. - err := s.dnsCryptServer.ServeUDP(s.udpListener) + // + // TODO(ameshkov): Redo the dnscrypt module to make it not depend on + // *net.UDPConn and use net.PacketConn instead. + err := s.dnsCryptServer.ServeUDP(s.udpListener.(*net.UDPConn)) if err != nil { log.Info("[%s]: Finished listening to udp://%s due to %v", s.Name(), s.Addr(), err) } diff --git a/internal/dnsserver/serverdnsudp.go b/internal/dnsserver/serverdnsudp.go index 2dc4ce2..ea0aecc 100644 --- a/internal/dnsserver/serverdnsudp.go +++ b/internal/dnsserver/serverdnsudp.go @@ -4,23 +4,21 @@ import ( "context" "fmt" "net" - "runtime" "time" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" ) // serveUDP runs the UDP serving loop. 
-func (s *ServerDNS) serveUDP(ctx context.Context, conn *net.UDPConn) (err error) { +func (s *ServerDNS) serveUDP(ctx context.Context, conn net.PacketConn) (err error) { defer log.OnCloserError(conn, log.DEBUG) for s.isStarted() { var m []byte - var sess *dns.SessionUDP + var sess netext.PacketSession m, sess, err = s.readUDPMsg(ctx, conn) if err != nil { // TODO(ameshkov): Consider the situation where the server is shut @@ -59,15 +57,15 @@ func (s *ServerDNS) serveUDP(ctx context.Context, conn *net.UDPConn) (err error) func (s *ServerDNS) serveUDPPacket( ctx context.Context, m []byte, - conn *net.UDPConn, - udpSession *dns.SessionUDP, + conn net.PacketConn, + sess netext.PacketSession, ) { defer s.wg.Done() defer s.handlePanicAndRecover(ctx) rw := &udpResponseWriter{ + udpSession: sess, conn: conn, - udpSession: udpSession, writeTimeout: s.conf.WriteTimeout, } s.serveDNS(ctx, m, rw) @@ -75,15 +73,18 @@ func (s *ServerDNS) serveUDPPacket( } // readUDPMsg reads the next incoming DNS message. -func (s *ServerDNS) readUDPMsg(ctx context.Context, conn *net.UDPConn) (msg []byte, sess *dns.SessionUDP, err error) { +func (s *ServerDNS) readUDPMsg( + ctx context.Context, + conn net.PacketConn, +) (msg []byte, sess netext.PacketSession, err error) { err = conn.SetReadDeadline(time.Now().Add(s.conf.ReadTimeout)) if err != nil { return nil, nil, err } m := s.getUDPBuffer() - var n int - n, sess, err = dns.ReadFromSessionUDP(conn, m) + + n, sess, err := netext.ReadFromSession(conn, m) if err != nil { s.putUDPBuffer(m) @@ -120,30 +121,10 @@ func (s *ServerDNS) putUDPBuffer(m []byte) { s.udpPool.Put(&m) } -// setUDPSocketOptions is a function that is necessary to be able to use -// dns.ReadFromSessionUDP and dns.WriteToSessionUDP. 
-// TODO(ameshkov): https://github.com/AdguardTeam/AdGuardHome/issues/2807 -func setUDPSocketOptions(conn *net.UDPConn) (err error) { - if runtime.GOOS == "windows" { - return nil - } - - // We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection. - // Try enabling receiving of ECN and packet info for both IP versions. - // We expect at least one of those syscalls to succeed. - err6 := ipv6.NewPacketConn(conn).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true) - err4 := ipv4.NewPacketConn(conn).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true) - if err4 != nil && err6 != nil { - return errors.List("error while setting NetworkUDP socket options", err4, err6) - } - - return nil -} - // udpResponseWriter is a ResponseWriter implementation for DNS-over-UDP. type udpResponseWriter struct { - udpSession *dns.SessionUDP - conn *net.UDPConn + udpSession netext.PacketSession + conn net.PacketConn writeTimeout time.Duration } @@ -152,13 +133,15 @@ var _ ResponseWriter = (*udpResponseWriter)(nil) // LocalAddr implements the ResponseWriter interface for *udpResponseWriter. func (r *udpResponseWriter) LocalAddr() (addr net.Addr) { - return r.conn.LocalAddr() + // Don't use r.conn.LocalAddr(), since udpSession may actually contain the + // decoded OOB data, including the real local (dst) address. + return r.udpSession.LocalAddr() } // RemoteAddr implements the ResponseWriter interface for *udpResponseWriter. func (r *udpResponseWriter) RemoteAddr() (addr net.Addr) { - // Don't use r.conn.RemoteAddr(), since udpSession actually contains the - // decoded OOB data, including the remote address. + // Don't use r.conn.RemoteAddr(), since udpSession may actually contain the + // decoded OOB data, including the real remote (src) address. 
return r.udpSession.RemoteAddr() } @@ -173,7 +156,7 @@ func (r *udpResponseWriter) WriteMsg(ctx context.Context, req, resp *dns.Msg) (e } withWriteDeadline(ctx, r.writeTimeout, r.conn, func() { - _, err = dns.WriteToSessionUDP(r.conn, data, r.udpSession) + _, err = netext.WriteToSession(r.conn, data, r.udpSession) }) if err != nil { diff --git a/internal/dnsserver/serverhttps.go b/internal/dnsserver/serverhttps.go index 28bf24c..ba9c5d5 100644 --- a/internal/dnsserver/serverhttps.go +++ b/internal/dnsserver/serverhttps.go @@ -14,6 +14,7 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" @@ -94,6 +95,12 @@ var _ Server = (*ServerHTTPS)(nil) // NewServerHTTPS creates a new ServerHTTPS instance. func NewServerHTTPS(conf ConfigHTTPS) (s *ServerHTTPS) { + if conf.ListenConfig == nil { + // Do not enable OOB here, because ListenPacket is only used by HTTP/3, + // and quic-go sets the necessary flags. + conf.ListenConfig = netext.DefaultListenConfig() + } + s = &ServerHTTPS{ ServerBase: newServerBase(ProtoDoH, conf.ConfigBase), conf: conf, @@ -500,8 +507,7 @@ func (s *ServerHTTPS) listenQUIC(ctx context.Context) (err error) { tlsConf.NextProtos = nextProtoDoH3 } - // Do not enable OOB here as quic-go will do that on its own. - conn, err := listenUDP(ctx, s.addr, false) + conn, err := s.listenConfig.ListenPacket(ctx, "udp", s.addr) if err != nil { return err } @@ -518,24 +524,28 @@ func (s *ServerHTTPS) listenQUIC(ctx context.Context) (err error) { return nil } -// httpContextWithClientInfo adds client info to the context. +// httpContextWithClientInfo adds client info to the context. ctx is never nil, +// even when there is an error. 
func httpContextWithClientInfo( parent context.Context, r *http.Request, ) (ctx context.Context, err error) { + ctx = parent + ci := ClientInfo{ URL: netutil.CloneURL(r.URL), } - // Due to the quic-go bug we should use Host instead of r.TLS: - // https://github.com/lucas-clemente/quic-go/issues/3596 + // Due to the quic-go bug we should use Host instead of r.TLS. See + // https://github.com/quic-go/quic-go/issues/2879 and + // https://github.com/lucas-clemente/quic-go/issues/3596. // - // TODO(ameshkov): remove this when the bug is fixed in quic-go. + // TODO(ameshkov): Remove when quic-go is fixed, likely in v0.32.0. if r.ProtoAtLeast(3, 0) { var host string host, err = netutil.SplitHost(r.Host) if err != nil { - return nil, fmt.Errorf("failed to parse Host: %w", err) + return ctx, fmt.Errorf("failed to parse Host: %w", err) } ci.TLSServerName = host @@ -543,7 +553,7 @@ func httpContextWithClientInfo( ci.TLSServerName = strings.ToLower(r.TLS.ServerName) } - return ContextWithClientInfo(parent, ci), nil + return ContextWithClientInfo(ctx, ci), nil } // httpRequestToMsg reads the DNS message from http.Request. diff --git a/internal/dnsserver/serverhttps_test.go b/internal/dnsserver/serverhttps_test.go index 3a8918a..f927798 100644 --- a/internal/dnsserver/serverhttps_test.go +++ b/internal/dnsserver/serverhttps_test.go @@ -133,7 +133,7 @@ func TestServerHTTPS_integration_serveRequests(t *testing.T) { require.True(t, resp.Response) // EDNS0 padding is only present when request also has padding opt. 
- paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp) + paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp) require.Nil(t, paddingOpt) }) } @@ -338,7 +338,7 @@ func TestServerHTTPS_integration_ENDS0Padding(t *testing.T) { resp := mustDoHReq(t, addr, tlsConfig, http.MethodGet, false, false, req) require.True(t, resp.Response) - paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp) + paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp) require.NotNil(t, paddingOpt) require.NotEmpty(t, paddingOpt.Padding) } diff --git a/internal/dnsserver/serverquic.go b/internal/dnsserver/serverquic.go index 609d3e9..fbca4ee 100644 --- a/internal/dnsserver/serverquic.go +++ b/internal/dnsserver/serverquic.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/bluele/gcache" @@ -98,6 +99,11 @@ func NewServerQUIC(conf ConfigQUIC) (s *ServerQUIC) { tlsConfig.NextProtos = append([]string{nextProtoDoQ}, compatProtoDQ...) } + if conf.ListenConfig == nil { + // Do not enable OOB here as quic-go will do that on its own. + conf.ListenConfig = netext.DefaultListenConfig() + } + s = &ServerQUIC{ ServerBase: newServerBase(ProtoDoQ, conf.ConfigBase), conf: conf, @@ -476,8 +482,7 @@ func (s *ServerQUIC) readQUICMsg( // listenQUIC creates the UDP listener for the ServerQUIC.addr and also starts // the QUIC listener. func (s *ServerQUIC) listenQUIC(ctx context.Context) (err error) { - // Do not enable OOB here as quic-go will do that on its own. 
- conn, err := listenUDP(ctx, s.addr, false) + conn, err := s.listenConfig.ListenPacket(ctx, "udp", s.addr) if err != nil { return err } diff --git a/internal/dnsserver/serverquic_test.go b/internal/dnsserver/serverquic_test.go index 88d09a1..600c975 100644 --- a/internal/dnsserver/serverquic_test.go +++ b/internal/dnsserver/serverquic_test.go @@ -59,7 +59,7 @@ func TestServerQUIC_integration_query(t *testing.T) { assert.True(t, resp.Response) // EDNS0 padding is only present when request also has padding opt. - paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp) + paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp) require.Nil(t, paddingOpt) }() } @@ -97,7 +97,7 @@ func TestServerQUIC_integration_ENDS0Padding(t *testing.T) { require.True(t, resp.Response) require.False(t, resp.Truncated) - paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp) + paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp) require.NotNil(t, paddingOpt) require.NotEmpty(t, paddingOpt.Padding) } diff --git a/internal/dnsserver/servertls.go b/internal/dnsserver/servertls.go index 0a8baaa..2659f5a 100644 --- a/internal/dnsserver/servertls.go +++ b/internal/dnsserver/servertls.go @@ -3,7 +3,6 @@ package dnsserver import ( "context" "crypto/tls" - "net" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" @@ -102,10 +101,9 @@ func (s *ServerTLS) startServeTCP(ctx context.Context) { } } -// listenTLS creates the TLS listener for the ServerTLS.addr. +// listenTLS creates the TLS listener for s.addr. 
func (s *ServerTLS) listenTLS(ctx context.Context) (err error) { - var l net.Listener - l, err = listenTCP(ctx, s.addr) + l, err := s.listenConfig.Listen(ctx, "tcp", s.addr) if err != nil { return err } diff --git a/internal/dnsserver/servertls_test.go b/internal/dnsserver/servertls_test.go index d33390a..aa37a3c 100644 --- a/internal/dnsserver/servertls_test.go +++ b/internal/dnsserver/servertls_test.go @@ -40,7 +40,7 @@ func TestServerTLS_integration_queryTLS(t *testing.T) { require.False(t, resp.Truncated) // EDNS0 padding is only present when request also has padding opt. - paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp) + paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp) require.Nil(t, paddingOpt) } @@ -142,7 +142,7 @@ func TestServerTLS_integration_noTruncateQuery(t *testing.T) { require.False(t, resp.Truncated) // EDNS0 padding is only present when request also has padding opt. - paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp) + paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp) require.Nil(t, paddingOpt) } @@ -231,7 +231,7 @@ func TestServerTLS_integration_ENDS0Padding(t *testing.T) { require.True(t, resp.Response) require.False(t, resp.Truncated) - paddingOpt := dnsservertest.FindENDS0Option[*dns.EDNS0_PADDING](resp) + paddingOpt := dnsservertest.FindEDNS0Option[*dns.EDNS0_PADDING](resp) require.NotNil(t, paddingOpt) require.NotEmpty(t, paddingOpt.Padding) } diff --git a/internal/dnssvc/debug_internal_test.go b/internal/dnssvc/debug_internal_test.go new file mode 100644 index 0000000..24d1c20 --- /dev/null +++ b/internal/dnssvc/debug_internal_test.go @@ -0,0 +1,172 @@ +package dnssvc + +import ( + "context" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" + 
"github.com/AdguardTeam/AdGuardDNS/internal/filter" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newTXTExtra is a helper function that converts strs into DNS TXT resource +// records with Name and Txt fields set to first and second values of each +// tuple. +func newTXTExtra(strs [][2]string) (extra []dns.RR) { + for _, v := range strs { + extra = append(extra, &dns.TXT{ + Hdr: dns.RR_Header{ + Name: v[0], + Rrtype: dns.TypeTXT, + Class: dns.ClassCHAOS, + Ttl: 1, + }, + Txt: []string{v[1]}, + }) + } + + return extra +} + +func TestService_writeDebugResponse(t *testing.T) { + svc := &Service{messages: &dnsmsg.Constructor{FilteredResponseTTL: time.Second}} + + const ( + fltListID1 agd.FilterListID = "fl1" + fltListID2 agd.FilterListID = "fl2" + + blockRule = "||example.com^" + ) + + testCases := []struct { + name string + ri *agd.RequestInfo + reqRes filter.Result + respRes filter.Result + wantExtra []dns.RR + }{{ + name: "normal", + ri: &agd.RequestInfo{}, + reqRes: nil, + respRes: nil, + wantExtra: newTXTExtra([][2]string{ + {"client-ip.adguard-dns.com.", "1.2.3.4"}, + {"resp.res-type.adguard-dns.com.", "normal"}, + }), + }, { + name: "request_result_blocked", + ri: &agd.RequestInfo{}, + reqRes: &filter.ResultBlocked{List: fltListID1, Rule: blockRule}, + respRes: nil, + wantExtra: newTXTExtra([][2]string{ + {"client-ip.adguard-dns.com.", "1.2.3.4"}, + {"req.res-type.adguard-dns.com.", "blocked"}, + {"req.rule.adguard-dns.com.", "||example.com^"}, + {"req.rule-list-id.adguard-dns.com.", "fl1"}, + }), + }, { + name: "response_result_blocked", + ri: &agd.RequestInfo{}, + reqRes: nil, + respRes: &filter.ResultBlocked{List: fltListID2, Rule: blockRule}, + wantExtra: newTXTExtra([][2]string{ + {"client-ip.adguard-dns.com.", "1.2.3.4"}, + {"resp.res-type.adguard-dns.com.", "blocked"}, + {"resp.rule.adguard-dns.com.", "||example.com^"}, + {"resp.rule-list-id.adguard-dns.com.", "fl2"}, + }), + }, { + name: 
"request_result_allowed", + ri: &agd.RequestInfo{}, + reqRes: &filter.ResultAllowed{}, + respRes: nil, + wantExtra: newTXTExtra([][2]string{ + {"client-ip.adguard-dns.com.", "1.2.3.4"}, + {"req.res-type.adguard-dns.com.", "allowed"}, + {"req.rule.adguard-dns.com.", ""}, + {"req.rule-list-id.adguard-dns.com.", ""}, + }), + }, { + name: "response_result_allowed", + ri: &agd.RequestInfo{}, + reqRes: nil, + respRes: &filter.ResultAllowed{}, + wantExtra: newTXTExtra([][2]string{ + {"client-ip.adguard-dns.com.", "1.2.3.4"}, + {"resp.res-type.adguard-dns.com.", "allowed"}, + {"resp.rule.adguard-dns.com.", ""}, + {"resp.rule-list-id.adguard-dns.com.", ""}, + }), + }, { + name: "request_result_modified", + ri: &agd.RequestInfo{}, + reqRes: &filter.ResultModified{ + Rule: "||example.com^$dnsrewrite=REFUSED", + }, + respRes: nil, + wantExtra: newTXTExtra([][2]string{ + {"client-ip.adguard-dns.com.", "1.2.3.4"}, + {"req.res-type.adguard-dns.com.", "modified"}, + {"req.rule.adguard-dns.com.", "||example.com^$dnsrewrite=REFUSED"}, + {"req.rule-list-id.adguard-dns.com.", ""}, + }), + }, { + name: "device", + ri: &agd.RequestInfo{Device: &agd.Device{ID: "dev1234"}}, + reqRes: nil, + respRes: nil, + wantExtra: newTXTExtra([][2]string{ + {"client-ip.adguard-dns.com.", "1.2.3.4"}, + {"device-id.adguard-dns.com.", "dev1234"}, + {"resp.res-type.adguard-dns.com.", "normal"}, + }), + }, { + name: "profile", + ri: &agd.RequestInfo{ + Profile: &agd.Profile{ID: agd.ProfileID("some-profile-id")}, + }, + reqRes: nil, + respRes: nil, + wantExtra: newTXTExtra([][2]string{ + {"client-ip.adguard-dns.com.", "1.2.3.4"}, + {"profile-id.adguard-dns.com.", "some-profile-id"}, + {"resp.res-type.adguard-dns.com.", "normal"}, + }), + }, { + name: "location", + ri: &agd.RequestInfo{Location: &agd.Location{Country: agd.CountryAD}}, + reqRes: nil, + respRes: nil, + wantExtra: newTXTExtra([][2]string{ + {"client-ip.adguard-dns.com.", "1.2.3.4"}, + {"country.adguard-dns.com.", "AD"}, + 
{"asn.adguard-dns.com.", "0"}, + {"resp.res-type.adguard-dns.com.", "normal"}, + }), + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr) + + ctx := agd.ContextWithRequestInfo(context.Background(), tc.ri) + + req := dnsservertest.NewReq("example.com", dns.TypeA, dns.ClassINET) + resp := dnsservertest.NewResp(dns.RcodeSuccess, req) + + err := svc.writeDebugResponse(ctx, rw, req, resp, tc.reqRes, tc.respRes) + require.NoError(t, err) + + msg := rw.Msg() + require.NotNil(t, msg) + + assert.Equal(t, tc.wantExtra, msg.Extra) + }) + } +} diff --git a/internal/dnssvc/deviceid.go b/internal/dnssvc/deviceid.go index 60b67f9..9710664 100644 --- a/internal/dnssvc/deviceid.go +++ b/internal/dnssvc/deviceid.go @@ -195,7 +195,15 @@ func deviceIDFromEDNS(req *dns.Msg) (id agd.DeviceID, err error) { continue } - return agd.NewDeviceID(string(o.Data)) + id, err = agd.NewDeviceID(string(o.Data)) + if err != nil { + return "", &deviceIDError{ + err: err, + typ: "edns option", + } + } + + return id, nil } return "", nil diff --git a/internal/dnssvc/deviceid_internal_test.go b/internal/dnssvc/deviceid_internal_test.go index 30f7b50..2058ae8 100644 --- a/internal/dnssvc/deviceid_internal_test.go +++ b/internal/dnssvc/deviceid_internal_test.go @@ -233,7 +233,8 @@ func TestService_Wrap_deviceIDFromEDNS(t *testing.T) { Data: []byte{}, }, wantDeviceID: "", - wantErrMsg: `bad device id "": too short: got 0 bytes, min 1`, + wantErrMsg: `edns option device id check: bad device id "": ` + + `too short: got 0 bytes, min 1`, }, { name: "bad_device_id", opt: &dns.EDNS0_LOCAL{ @@ -241,7 +242,8 @@ func TestService_Wrap_deviceIDFromEDNS(t *testing.T) { Data: []byte("toolongdeviceid"), }, wantDeviceID: "", - wantErrMsg: `bad device id "toolongdeviceid": too long: got 15 bytes, max 8`, + wantErrMsg: `edns option device id check: bad device id "toolongdeviceid": ` + + `too long: got 15 bytes, max 8`, }, { name: 
"device_id", opt: &dns.EDNS0_LOCAL{ diff --git a/internal/dnssvc/dnssvc.go b/internal/dnssvc/dnssvc.go index f5a6a30..d3ba3da 100644 --- a/internal/dnssvc/dnssvc.go +++ b/internal/dnssvc/dnssvc.go @@ -9,7 +9,6 @@ import ( "fmt" "net/http" "net/netip" - "net/url" "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/billstat" @@ -79,10 +78,6 @@ type Config struct { // Upstream defines the upstream server and the group of fallback servers. Upstream *agd.Upstream - // RootRedirectURL is the URL to which non-DNS and non-Debug HTTP requests - // are redirected. - RootRedirectURL *url.URL - // NewListener, when set, is used instead of the package-level function // NewListener when creating a DNS listener. // @@ -119,6 +114,11 @@ type Config struct { // UseECSCache shows if the EDNS Client Subnet (ECS) aware cache should be // used. UseECSCache bool + + // ResearchMetrics controls whether research metrics are enabled or not. + // This is a set of metrics that we may need temporary, so its collection is + // controlled by a separate setting. + ResearchMetrics bool } // New returns a new DNS service. @@ -150,7 +150,6 @@ func New(c *Config) (svc *Service, err error) { groups := make([]*serverGroup, len(c.ServerGroups)) svc = &Service{ messages: c.Messages, - rootRedirectURL: c.RootRedirectURL, billStat: c.BillStat, errColl: c.ErrColl, fltStrg: c.FilterStorage, @@ -158,6 +157,7 @@ func New(c *Config) (svc *Service, err error) { queryLog: c.QueryLog, ruleStat: c.RuleStat, groups: groups, + researchMetrics: c.ResearchMetrics, } for i, srvGrp := range c.ServerGroups { @@ -212,7 +212,6 @@ var _ agd.Service = (*Service)(nil) // Service is the main DNS service of AdGuard DNS. 
type Service struct { messages *dnsmsg.Constructor - rootRedirectURL *url.URL billStat billstat.Recorder errColl agd.ErrorCollector fltStrg filter.Storage @@ -220,6 +219,7 @@ type Service struct { queryLog querylog.Interface ruleStat rulestat.Interface groups []*serverGroup + researchMetrics bool } // mustStartListener starts l and panics on any error. diff --git a/internal/dnssvc/initmw.go b/internal/dnssvc/initmw.go index e7561bf..87d5c27 100644 --- a/internal/dnssvc/initmw.go +++ b/internal/dnssvc/initmw.go @@ -260,14 +260,6 @@ func (mh *initMwHandler) ServeDNS( // Copy middleware to the local variable to make the code simpler. mw := mh.mw - if specHdlr, name := mw.noReqInfoSpecialHandler(fqdn, qt, cl); specHdlr != nil { - optlog.Debug1("init mw: got no-req-info special handler %s", name) - - // Don't wrap the error, because it's informative enough as is, and - // because if handled is true, the main flow terminates here. - return specHdlr(ctx, rw, req) - } - // Get the request's information, such as GeoIP data and user profiles. 
ri, err := mw.newRequestInfo(ctx, req, rw.RemoteAddr(), fqdn, qt, cl) if err != nil { diff --git a/internal/dnssvc/initmw_internal_test.go b/internal/dnssvc/initmw_internal_test.go index ab05594..f5cb8b7 100644 --- a/internal/dnssvc/initmw_internal_test.go +++ b/internal/dnssvc/initmw_internal_test.go @@ -277,7 +277,7 @@ func TestInitMw_ServeDNS_ddr(t *testing.T) { } } -func TestInitMw_ServeDNS_privateRelay(t *testing.T) { +func TestInitMw_ServeDNS_specialDomain(t *testing.T) { testCases := []struct { name string host string @@ -287,7 +287,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) { profBlocked bool wantRCode dnsmsg.RCode }{{ - name: "blocked_by_fltgrp", + name: "private_relay_blocked_by_fltgrp", host: applePrivateRelayMaskHost, qtype: dns.TypeA, fltGrpBlocked: true, @@ -295,7 +295,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) { profBlocked: false, wantRCode: dns.RcodeNameError, }, { - name: "no_private_relay_domain", + name: "no_special_domain", host: "www.example.com", qtype: dns.TypeA, fltGrpBlocked: true, @@ -311,7 +311,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) { profBlocked: false, wantRCode: dns.RcodeSuccess, }, { - name: "blocked_by_prof", + name: "private_relay_blocked_by_prof", host: applePrivateRelayMaskHost, qtype: dns.TypeA, fltGrpBlocked: false, @@ -319,7 +319,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) { profBlocked: true, wantRCode: dns.RcodeNameError, }, { - name: "allowed_by_prof", + name: "private_relay_allowed_by_prof", host: applePrivateRelayMaskHost, qtype: dns.TypeA, fltGrpBlocked: true, @@ -327,7 +327,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) { profBlocked: false, wantRCode: dns.RcodeSuccess, }, { - name: "allowed_by_both", + name: "private_relay_allowed_by_both", host: applePrivateRelayMaskHost, qtype: dns.TypeA, fltGrpBlocked: false, @@ -335,13 +335,45 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) { profBlocked: false, wantRCode: dns.RcodeSuccess, }, { - name: 
"blocked_by_both", + name: "private_relay_blocked_by_both", host: applePrivateRelayMaskHost, qtype: dns.TypeA, fltGrpBlocked: true, hasProf: true, profBlocked: true, wantRCode: dns.RcodeNameError, + }, { + name: "firefox_canary_allowed_by_prof", + host: firefoxCanaryHost, + qtype: dns.TypeA, + fltGrpBlocked: false, + hasProf: true, + profBlocked: false, + wantRCode: dns.RcodeSuccess, + }, { + name: "firefox_canary_allowed_by_fltgrp", + host: firefoxCanaryHost, + qtype: dns.TypeA, + fltGrpBlocked: false, + hasProf: false, + profBlocked: false, + wantRCode: dns.RcodeSuccess, + }, { + name: "firefox_canary_blocked_by_prof", + host: firefoxCanaryHost, + qtype: dns.TypeA, + fltGrpBlocked: false, + hasProf: true, + profBlocked: true, + wantRCode: dns.RcodeRefused, + }, { + name: "firefox_canary_blocked_by_fltgrp", + host: firefoxCanaryHost, + qtype: dns.TypeA, + fltGrpBlocked: true, + hasProf: false, + profBlocked: false, + wantRCode: dns.RcodeRefused, }} for _, tc := range testCases { @@ -368,9 +400,12 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) { return nil, nil, agd.DeviceNotFoundError{} } - return &agd.Profile{ - BlockPrivateRelay: tc.profBlocked, - }, &agd.Device{}, nil + prof := &agd.Profile{ + BlockPrivateRelay: tc.profBlocked, + BlockFirefoxCanary: tc.profBlocked, + } + + return prof, &agd.Device{}, nil } db := &agdtest.ProfileDB{ OnProfileByDeviceID: func( @@ -406,7 +441,8 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) { FilteredResponseTTL: 10 * time.Second, }, fltGrp: &agd.FilteringGroup{ - BlockPrivateRelay: tc.fltGrpBlocked, + BlockPrivateRelay: tc.fltGrpBlocked, + BlockFirefoxCanary: tc.fltGrpBlocked, }, srvGrp: &agd.ServerGroup{}, srv: &agd.Server{ @@ -436,7 +472,7 @@ func TestInitMw_ServeDNS_privateRelay(t *testing.T) { resp := rw.Msg() require.NotNil(t, resp) - assert.Equal(t, dnsmsg.RCode(resp.Rcode), tc.wantRCode) + assert.Equal(t, tc.wantRCode, dnsmsg.RCode(resp.Rcode)) }) } } diff --git a/internal/dnssvc/middleware.go 
b/internal/dnssvc/middleware.go index 628dd68..5d0719b 100644 --- a/internal/dnssvc/middleware.go +++ b/internal/dnssvc/middleware.go @@ -101,7 +101,7 @@ func (svc *Service) filterQuery( ) (reqRes, respRes filter.Result) { start := time.Now() defer func() { - reportMetrics(ri, reqRes, respRes, time.Since(start)) + svc.reportMetrics(ri, reqRes, respRes, time.Since(start)) }() f := svc.fltStrg.FilterFromContext(ctx, ri) @@ -120,7 +120,7 @@ func (svc *Service) filterQuery( // reportMetrics extracts filtering metrics data from the context and reports it // to Prometheus. -func reportMetrics( +func (svc *Service) reportMetrics( ri *agd.RequestInfo, reqRes filter.Result, respRes filter.Result, @@ -139,7 +139,7 @@ func reportMetrics( metrics.DNSSvcRequestByCountryTotal.WithLabelValues(cont, ctry).Inc() metrics.DNSSvcRequestByASNTotal.WithLabelValues(ctry, asn).Inc() - id, _, _ := filteringData(reqRes, respRes) + id, _, blocked := filteringData(reqRes, respRes) metrics.DNSSvcRequestByFilterTotal.WithLabelValues( string(id), metrics.BoolString(ri.Profile == nil), @@ -147,6 +147,22 @@ func reportMetrics( metrics.DNSSvcFilteringDuration.Observe(elapsedFiltering.Seconds()) metrics.DNSSvcUsersCountUpdate(ri.RemoteIP) + + if svc.researchMetrics { + anonymous := ri.Profile == nil + filteringEnabled := ri.FilteringGroup != nil && + ri.FilteringGroup.RuleListsEnabled && + len(ri.FilteringGroup.RuleListIDs) > 0 + + metrics.ReportResearchMetrics( + anonymous, + filteringEnabled, + asn, + ctry, + string(id), + blocked, + ) + } } // reportf is a helper method for reporting non-critical errors. 
diff --git a/internal/dnssvc/presvcmw_test.go b/internal/dnssvc/presvcmw_test.go new file mode 100644 index 0000000..345dac6 --- /dev/null +++ b/internal/dnssvc/presvcmw_test.go @@ -0,0 +1,135 @@ +package dnssvc + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "net" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" + "github.com/AdguardTeam/AdGuardDNS/internal/filter" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPreServiceMwHandler_ServeDNS(t *testing.T) { + const safeBrowsingHost = "scam.example.net." + + var ( + ip = net.IP{127, 0, 0, 1} + name = "example.com" + ) + + sum := sha256.Sum256([]byte(safeBrowsingHost)) + hashStr := hex.EncodeToString(sum[:]) + hashes, herr := hashstorage.New(safeBrowsingHost) + require.NoError(t, herr) + + srv := filter.NewSafeBrowsingServer(hashes, nil) + host := hashStr[:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix + + ctx := context.Background() + ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{}) + ctx = dnsserver.ContextWithServerInfo(ctx, dnsserver.ServerInfo{}) + ctx = dnsserver.ContextWithStartTime(ctx, time.Now()) + + const ttl = 60 + + testCases := []struct { + name string + req *dns.Msg + dnscheckResp *dns.Msg + ri *agd.RequestInfo + wantAns []dns.RR + }{{ + name: "normal", + req: dnsservertest.CreateMessage(name, dns.TypeA), + dnscheckResp: nil, + ri: &agd.RequestInfo{}, + wantAns: []dns.RR{ + dnsservertest.NewA(name, 100, ip), + }, + }, { + name: "dnscheck", + req: dnsservertest.CreateMessage(name, dns.TypeA), + dnscheckResp: dnsservertest.NewResp( + dns.RcodeSuccess, + 
dnsservertest.NewReq(name, dns.TypeA, dns.ClassINET), + dnsservertest.RRSection{ + RRs: []dns.RR{dnsservertest.NewA(name, ttl, ip)}, + Sec: dnsservertest.SectionAnswer, + }, + ), + ri: &agd.RequestInfo{ + Host: name, + QType: dns.TypeA, + QClass: dns.ClassINET, + }, + wantAns: []dns.RR{ + dnsservertest.NewA(name, ttl, ip), + }, + }, { + name: "with_hashes", + req: dnsservertest.CreateMessage(safeBrowsingHost, dns.TypeTXT), + dnscheckResp: nil, + ri: &agd.RequestInfo{Host: host, QType: dns.TypeTXT}, + wantAns: []dns.RR{&dns.TXT{ + Hdr: dns.RR_Header{ + Name: safeBrowsingHost, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: ttl, + }, + Txt: []string{hashStr}, + }}, + }, { + name: "not_matched", + req: dnsservertest.CreateMessage(name, dns.TypeTXT), + dnscheckResp: nil, + ri: &agd.RequestInfo{Host: name, QType: dns.TypeTXT}, + wantAns: []dns.RR{dnsservertest.NewA(name, 100, ip)}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr) + tctx := agd.ContextWithRequestInfo(ctx, tc.ri) + + dnsCk := &agdtest.DNSCheck{ + OnCheck: func( + ctx context.Context, + msg *dns.Msg, + ri *agd.RequestInfo, + ) (resp *dns.Msg, err error) { + return tc.dnscheckResp, nil + }, + } + + mw := &preServiceMw{ + messages: &dnsmsg.Constructor{ + FilteredResponseTTL: ttl * time.Second, + }, + filter: srv, + checker: dnsCk, + } + handler := dnsservertest.DefaultHandler() + h := mw.Wrap(handler) + + err := h.ServeDNS(tctx, rw, tc.req) + require.NoError(t, err) + + msg := rw.Msg() + require.NotNil(t, msg) + + assert.Equal(t, tc.wantAns, msg.Answer) + }) + } +} diff --git a/internal/dnssvc/preupstreammw_test.go b/internal/dnssvc/preupstreammw_test.go new file mode 100644 index 0000000..58a0c91 --- /dev/null +++ b/internal/dnssvc/preupstreammw_test.go @@ -0,0 +1,264 @@ +package dnssvc + +import ( + "context" + "net" + "net/netip" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + 
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsdb" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" + "github.com/AdguardTeam/golibs/netutil" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + reqHostname = "example.com." + defaultTTL = 3600 +) + +func TestPreUpstreamMwHandler_ServeDNS_withCache(t *testing.T) { + remoteIP := netip.MustParseAddr("1.2.3.4") + aReq := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET) + respIP := remoteIP.AsSlice() + + resp := dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.RRSection{ + RRs: []dns.RR{dnsservertest.NewA(reqHostname, defaultTTL, respIP)}, + Sec: dnsservertest.SectionAnswer, + }) + ctx := agd.ContextWithRequestInfo(context.Background(), &agd.RequestInfo{ + Host: aReq.Question[0].Name, + }) + + const N = 5 + testCases := []struct { + name string + cacheSize int + wantNumReq int + }{{ + name: "no_cache", + cacheSize: 0, + wantNumReq: N, + }, { + name: "with_cache", + cacheSize: 100, + wantNumReq: 1, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + numReq := 0 + handler := dnsserver.HandlerFunc(func( + ctx context.Context, + rw dnsserver.ResponseWriter, + req *dns.Msg, + ) error { + numReq++ + + return rw.WriteMsg(ctx, req, resp) + }) + + mw := &preUpstreamMw{ + db: dnsdb.Empty{}, + cacheSize: tc.cacheSize, + } + h := mw.Wrap(handler) + + for i := 0; i < N; i++ { + req := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET) + addr := &net.UDPAddr{IP: remoteIP.AsSlice(), Port: 53} + nrw := dnsserver.NewNonWriterResponseWriter(addr, addr) + + err := h.ServeDNS(ctx, nrw, req) + require.NoError(t, err) + } + + assert.Equal(t, tc.wantNumReq, numReq) + }) + } +} + +func TestPreUpstreamMwHandler_ServeDNS_withECSCache(t *testing.T) { + aReq := 
dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET) + remoteIP := netip.MustParseAddr("1.2.3.4") + subnet := netip.MustParsePrefix("1.2.3.4/24") + + const ctry = agd.CountryAD + + resp := dnsservertest.NewResp( + dns.RcodeSuccess, + aReq, + dnsservertest.RRSection{ + RRs: []dns.RR{dnsservertest.NewA( + reqHostname, + defaultTTL, + net.IP{1, 2, 3, 4}, + )}, + Sec: dnsservertest.SectionAnswer, + }, + ) + + numReq := 0 + handler := dnsserver.HandlerFunc( + func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) error { + numReq++ + + return rw.WriteMsg(ctx, req, resp) + }, + ) + + geoIP := &agdtest.GeoIP{ + OnSubnetByLocation: func( + ctry agd.Country, + _ agd.ASN, + _ netutil.AddrFamily, + ) (n netip.Prefix, err error) { + return netip.MustParsePrefix("1.2.0.0/16"), nil + }, + OnData: func(_ string, _ netip.Addr) (_ *agd.Location, _ error) { + panic("not implemented") + }, + } + + mw := &preUpstreamMw{ + db: dnsdb.Empty{}, + geoIP: geoIP, + cacheSize: 100, + ecsCacheSize: 100, + useECSCache: true, + } + h := mw.Wrap(handler) + + ctx := agd.ContextWithRequestInfo(context.Background(), &agd.RequestInfo{ + Location: &agd.Location{ + Country: ctry, + }, + ECS: &agd.ECS{ + Location: &agd.Location{ + Country: ctry, + }, + Subnet: subnet, + Scope: 0, + }, + Host: aReq.Question[0].Name, + RemoteIP: remoteIP, + }) + + const N = 5 + var nrw *dnsserver.NonWriterResponseWriter + for i := 0; i < N; i++ { + addr := &net.UDPAddr{IP: remoteIP.AsSlice(), Port: 53} + nrw = dnsserver.NewNonWriterResponseWriter(addr, addr) + req := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET) + + err := h.ServeDNS(ctx, nrw, req) + require.NoError(t, err) + } + + assert.Equal(t, 1, numReq) +} + +func TestPreUpstreamMwHandler_ServeDNS_androidMetric(t *testing.T) { + mw := &preUpstreamMw{db: dnsdb.Empty{}} + + req := dnsservertest.CreateMessage("example.com", dns.TypeA) + resp := new(dns.Msg).SetReply(req) + + ctx := context.Background() + ctx = 
dnsserver.ContextWithServerInfo(ctx, dnsserver.ServerInfo{}) + ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{}) + ctx = dnsserver.ContextWithStartTime(ctx, time.Now()) + ctx = agd.ContextWithRequestInfo(ctx, &agd.RequestInfo{}) + + const ttl = 100 + + testCases := []struct { + name string + req *dns.Msg + resp *dns.Msg + wantName string + wantAns []dns.RR + }{{ + name: "no_changes", + req: dnsservertest.CreateMessage("example.com.", dns.TypeA), + resp: resp, + wantName: "example.com.", + wantAns: nil, + }, { + name: "android-tls-metric", + req: dnsservertest.CreateMessage( + "12345678-dnsotls-ds.metric.gstatic.com.", + dns.TypeA, + ), + resp: resp, + wantName: "00000000-dnsotls-ds.metric.gstatic.com.", + wantAns: nil, + }, { + name: "android-https-metric", + req: dnsservertest.CreateMessage( + "123456-dnsohttps-ds.metric.gstatic.com.", + dns.TypeA, + ), + resp: resp, + wantName: "000000-dnsohttps-ds.metric.gstatic.com.", + wantAns: nil, + }, { + name: "multiple_answers_metric", + req: dnsservertest.CreateMessage( + "123456-dnsohttps-ds.metric.gstatic.com.", + dns.TypeA, + ), + resp: dnsservertest.NewResp( + dns.RcodeSuccess, + req, + dnsservertest.RRSection{ + RRs: []dns.RR{dnsservertest.NewA( + "123456-dnsohttps-ds.metric.gstatic.com.", + ttl, + net.IP{1, 2, 3, 4}, + ), dnsservertest.NewA( + "654321-dnsohttps-ds.metric.gstatic.com.", + ttl, + net.IP{1, 2, 3, 5}, + )}, + Sec: dnsservertest.SectionAnswer, + }, + ), + wantName: "000000-dnsohttps-ds.metric.gstatic.com.", + wantAns: []dns.RR{ + dnsservertest.NewA("123456-dnsohttps-ds.metric.gstatic.com.", ttl, net.IP{1, 2, 3, 4}), + dnsservertest.NewA("123456-dnsohttps-ds.metric.gstatic.com.", ttl, net.IP{1, 2, 3, 5}), + }, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + handler := dnsserver.HandlerFunc(func( + ctx context.Context, + rw dnsserver.ResponseWriter, + req *dns.Msg, + ) error { + assert.Equal(t, tc.wantName, req.Question[0].Name) + + return rw.WriteMsg(ctx, 
req, tc.resp) + }) + h := mw.Wrap(handler) + + rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr) + + err := h.ServeDNS(ctx, rw, tc.req) + require.NoError(t, err) + + msg := rw.Msg() + require.NotNil(t, msg) + + assert.Equal(t, tc.wantAns, msg.Answer) + }) + } +} diff --git a/internal/dnssvc/specialdomain.go b/internal/dnssvc/specialdomain.go index 2c5d605..2803626 100644 --- a/internal/dnssvc/specialdomain.go +++ b/internal/dnssvc/specialdomain.go @@ -32,57 +32,23 @@ const ( // Resolvers for querying the resolver with unknown or absent name. ddrDomain = ddrLabel + "." + resolverArpaDomain - // firefoxCanaryFQDN is the fully-qualified canary domain that Firefox uses - // to check if it should use its own DNS-over-HTTPS settings. + // firefoxCanaryHost is the hostname that Firefox uses to check if it + // should use its own DNS-over-HTTPS settings. // // See https://support.mozilla.org/en-US/kb/configuring-networks-disable-dns-over-https. - firefoxCanaryFQDN = "use-application-dns.net." - - // applePrivateRelayMaskHost and applePrivateRelayMaskH2Host are the - // hostnames that Apple devices use to check if Apple Private Relay can be - // enabled. Returning NXDOMAIN to queries for these domain names blocks - // Apple Private Relay. - // - // See https://developer.apple.com/support/prepare-your-network-for-icloud-private-relay. - applePrivateRelayMaskHost = "mask.icloud.com" - applePrivateRelayMaskH2Host = "mask-h2.icloud.com" + firefoxCanaryHost = "use-application-dns.net" ) -// noReqInfoSpecialHandler returns a handler that can handle a special-domain -// query based only on its question type, class, and target, as well as the -// handler's name for debugging. 
-func (mw *initMw) noReqInfoSpecialHandler( - fqdn string, - qt dnsmsg.RRType, - cl dnsmsg.Class, -) (f dnsserver.HandlerFunc, name string) { - if cl != dns.ClassINET { - return nil, "" - } - - if (qt == dns.TypeA || qt == dns.TypeAAAA) && fqdn == firefoxCanaryFQDN { - return mw.handleFirefoxCanary, "firefox" - } - - return nil, "" -} - -// Firefox Canary - -// handleFirefoxCanary checks if the request is for the fully-qualified domain -// name that Firefox uses to check DoH settings and writes a response if needed. -func (mw *initMw) handleFirefoxCanary( - ctx context.Context, - rw dnsserver.ResponseWriter, - req *dns.Msg, -) (err error) { - metrics.DNSSvcFirefoxRequestsTotal.Inc() - - resp := mw.messages.NewMsgREFUSED(req) - err = rw.WriteMsg(ctx, req, resp) - - return errors.Annotate(err, "writing firefox canary resp: %w") -} +// Hostnames that Apple devices use to check if Apple Private Relay can be +// enabled. Returning NXDOMAIN to queries for these domain names blocks Apple +// Private Relay. +// +// See https://developer.apple.com/support/prepare-your-network-for-icloud-private-relay. +const ( + applePrivateRelayMaskHost = "mask.icloud.com" + applePrivateRelayMaskH2Host = "mask-h2.icloud.com" + applePrivateRelayMaskCanaryHost = "mask-canary.icloud.com" +) // reqInfoSpecialHandler returns a handler that can handle a special-domain // query based on the request info, as well as the handler's name for debugging. @@ -99,11 +65,9 @@ func (mw *initMw) reqInfoSpecialHandler( } else if netutil.IsSubdomain(ri.Host, resolverArpaDomain) { // A badly formed resolver.arpa subdomain query. 
return mw.handleBadResolverARPA, "bad_resolver_arpa" - } else if shouldBlockPrivateRelay(ri) { - return mw.handlePrivateRelay, "apple_private_relay" } - return nil, "" + return mw.specialDomainHandler(ri) } // reqInfoHandlerFunc is an alias for handler functions that additionally accept @@ -222,24 +186,46 @@ func (mw *initMw) handleBadResolverARPA( return errors.Annotate(err, "writing nodata resp for %q: %w", ri.Host) } +// specialDomainHandler returns a handler that can handle a special-domain +// query for Apple Private Relay or Firefox canary domain based on the request +// or profile information, as well as the handler's name for debugging. +func (mw *initMw) specialDomainHandler( + ri *agd.RequestInfo, +) (f reqInfoHandlerFunc, name string) { + qt := ri.QType + if qt != dns.TypeA && qt != dns.TypeAAAA { + return nil, "" + } + + host := ri.Host + prof := ri.Profile + + switch host { + case + applePrivateRelayMaskHost, + applePrivateRelayMaskH2Host, + applePrivateRelayMaskCanaryHost: + if shouldBlockPrivateRelay(ri, prof) { + return mw.handlePrivateRelay, "apple_private_relay" + } + case firefoxCanaryHost: + if shouldBlockFirefoxCanary(ri, prof) { + return mw.handleFirefoxCanary, "firefox" + } + default: + // Go on. + } + + return nil, "" +} + // Apple Private Relay // shouldBlockPrivateRelay returns true if the query is for an Apple Private -// Relay check domain and the request information indicates that Apple Private -// Relay should be blocked. -func shouldBlockPrivateRelay(ri *agd.RequestInfo) (ok bool) { - qt := ri.QType - host := ri.Host - - return (qt == dns.TypeA || qt == dns.TypeAAAA) && - (host == applePrivateRelayMaskHost || host == applePrivateRelayMaskH2Host) && - reqInfoShouldBlockPrivateRelay(ri) -} - -// reqInfoShouldBlockPrivateRelay returns true if Apple Private Relay queries -// should be blocked based on the request information. 
-func reqInfoShouldBlockPrivateRelay(ri *agd.RequestInfo) (ok bool) { - if prof := ri.Profile; prof != nil { +// Relay check domain and the request information or profile indicates that +// Apple Private Relay should be blocked. +func shouldBlockPrivateRelay(ri *agd.RequestInfo, prof *agd.Profile) (ok bool) { + if prof != nil { return prof.BlockPrivateRelay } @@ -260,3 +246,32 @@ func (mw *initMw) handlePrivateRelay( return errors.Annotate(err, "writing private relay resp: %w") } + +// Firefox canary domain + +// shouldBlockFirefoxCanary returns true if the query is for a Firefox canary +// domain and the request information or profile indicates that Firefox canary +// domain should be blocked. +func shouldBlockFirefoxCanary(ri *agd.RequestInfo, prof *agd.Profile) (ok bool) { + if prof != nil { + return prof.BlockFirefoxCanary + } + + return ri.FilteringGroup.BlockFirefoxCanary +} + +// handleFirefoxCanary checks if the request is for the fully-qualified domain +// name that Firefox uses to check DoH settings and writes a response if needed. 
+func (mw *initMw) handleFirefoxCanary( + ctx context.Context, + rw dnsserver.ResponseWriter, + req *dns.Msg, + ri *agd.RequestInfo, +) (err error) { + metrics.DNSSvcFirefoxRequestsTotal.Inc() + + resp := mw.messages.NewMsgREFUSED(req) + err = rw.WriteMsg(ctx, req, resp) + + return errors.Annotate(err, "writing firefox canary resp: %w") +} diff --git a/internal/ecscache/cache.go b/internal/ecscache/cache.go index e59c900..860a504 100644 --- a/internal/ecscache/cache.go +++ b/internal/ecscache/cache.go @@ -10,6 +10,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" "github.com/AdguardTeam/AdGuardDNS/internal/optlog" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/mathutil" "github.com/bluele/gcache" "github.com/miekg/dns" ) @@ -93,22 +94,16 @@ func (mw *Middleware) toCacheKey(cr *cacheRequest, hostHasECS bool) (key uint64) var buf [6]byte binary.LittleEndian.PutUint16(buf[:2], cr.qType) binary.LittleEndian.PutUint16(buf[2:4], cr.qClass) - if cr.reqDO { - buf[4] = 1 - } else { - buf[4] = 0 - } - if cr.subnet.Addr().Is4() { - buf[5] = 0 - } else { - buf[5] = 1 - } + buf[4] = mathutil.BoolToNumber[byte](cr.reqDO) + + addr := cr.subnet.Addr() + buf[5] = mathutil.BoolToNumber[byte](addr.Is6()) _, _ = h.Write(buf[:]) if hostHasECS { - _, _ = h.Write(cr.subnet.Addr().AsSlice()) + _, _ = h.Write(addr.AsSlice()) _ = h.WriteByte(byte(cr.subnet.Bits())) } diff --git a/internal/ecscache/msg.go b/internal/ecscache/msg.go index 1730302..5e7611d 100644 --- a/internal/ecscache/msg.go +++ b/internal/ecscache/msg.go @@ -8,6 +8,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" + "github.com/AdguardTeam/golibs/mathutil" "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" ) @@ -269,9 +270,5 @@ func getTTLIfLower(r dns.RR, ttl uint32) (res uint32) { // Go on. 
} - if httl := r.Header().Ttl; httl < ttl { - return httl - } - - return ttl + return mathutil.Min(r.Header().Ttl, ttl) } diff --git a/internal/errcoll/sentry.go b/internal/errcoll/sentry.go index 26f7fdf..e4eb994 100644 --- a/internal/errcoll/sentry.go +++ b/internal/errcoll/sentry.go @@ -67,13 +67,13 @@ type SentryReportableError interface { // TODO(a.garipov): Make sure that we use this approach everywhere. func isReportable(err error) (ok bool) { var ( - ravErr SentryReportableError - fwdErr *forward.Error - dnsWErr *dnsserver.WriteError + sentryRepErr SentryReportableError + fwdErr *forward.Error + dnsWErr *dnsserver.WriteError ) - if errors.As(err, &ravErr) { - return ravErr.IsSentryReportable() + if errors.As(err, &sentryRepErr) { + return sentryRepErr.IsSentryReportable() } else if errors.As(err, &fwdErr) { return isReportableNetwork(fwdErr.Err) } else if errors.As(err, &dnsWErr) { diff --git a/internal/filter/compfilter.go b/internal/filter/compfilter.go index 090def6..cbfa521 100644 --- a/internal/filter/compfilter.go +++ b/internal/filter/compfilter.go @@ -21,8 +21,8 @@ var _ Interface = (*compFilter)(nil) // compFilter is a composite filter based on several types of safe search // filters and rule lists. type compFilter struct { - safeBrowsing *hashPrefixFilter - adultBlocking *hashPrefixFilter + safeBrowsing *HashPrefix + adultBlocking *HashPrefix genSafeSearch *safeSearch ytSafeSearch *safeSearch diff --git a/internal/filter/filter.go b/internal/filter/filter.go index 6b8dfab..cdc402a 100644 --- a/internal/filter/filter.go +++ b/internal/filter/filter.go @@ -18,10 +18,11 @@ import ( // maxFilterSize is the maximum size of downloaded filters. const maxFilterSize = 196 * int64(datasize.MB) -// defaultTimeout is the default timeout to use when fetching filter data. +// defaultFilterRefreshTimeout is the default timeout to use when fetching +// filter lists data. // // TODO(a.garipov): Consider making timeouts where they are used configurable. 
-const defaultTimeout = 30 * time.Second +const defaultFilterRefreshTimeout = 180 * time.Second // defaultResolveTimeout is the default timeout for resolving hosts for safe // search and safe browsing filters. diff --git a/internal/filter/filter_test.go b/internal/filter/filter_test.go index 300c3d7..39383a0 100644 --- a/internal/filter/filter_test.go +++ b/internal/filter/filter_test.go @@ -181,24 +181,18 @@ func prepareConf(t testing.TB) (c *filter.DefaultStorageConfig) { FilterIndexURL: fltsURL, GeneralSafeSearchRulesURL: ssURL, YoutubeSafeSearchRulesURL: ssURL, - SafeBrowsing: &filter.HashPrefixConfig{ - CacheTTL: 1 * time.Hour, - CacheSize: 100, - }, - AdultBlocking: &filter.HashPrefixConfig{ - CacheTTL: 1 * time.Hour, - CacheSize: 100, - }, - Now: time.Now, - ErrColl: nil, - Resolver: nil, - CacheDir: cacheDir, - CustomFilterCacheSize: 100, - SafeSearchCacheSize: 100, - SafeSearchCacheTTL: 1 * time.Hour, - RuleListCacheSize: 100, - RefreshIvl: testRefreshIvl, - UseRuleListCache: false, + SafeBrowsing: &filter.HashPrefix{}, + AdultBlocking: &filter.HashPrefix{}, + Now: time.Now, + ErrColl: nil, + Resolver: nil, + CacheDir: cacheDir, + CustomFilterCacheSize: 100, + SafeSearchCacheSize: 100, + SafeSearchCacheTTL: 1 * time.Hour, + RuleListCacheSize: 100, + RefreshIvl: testRefreshIvl, + UseRuleListCache: false, } } diff --git a/internal/filter/hashprefix.go b/internal/filter/hashprefix.go index 3bfea1c..3d0dd48 100644 --- a/internal/filter/hashprefix.go +++ b/internal/filter/hashprefix.go @@ -3,13 +3,17 @@ package filter import ( "context" "fmt" + "net/url" "strings" "time" "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" "github.com/AdguardTeam/AdGuardDNS/internal/agdnet" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage" "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/resultcache" "github.com/AdguardTeam/AdGuardDNS/internal/metrics" + "github.com/AdguardTeam/golibs/log" 
"github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" "github.com/prometheus/client_golang/prometheus" @@ -21,13 +25,32 @@ import ( // HashPrefixConfig is the hash-prefix filter configuration structure. type HashPrefixConfig struct { // Hashes are the hostname hashes for this filter. - Hashes *HashStorage + Hashes *hashstorage.Storage + + // URL is the URL used to update the filter. + URL *url.URL + + // ErrColl is used to collect non-critical and rare errors. + ErrColl agd.ErrorCollector + + // Resolver is used to resolve hosts for the hash-prefix filter. + Resolver agdnet.Resolver + + // ID is the ID of this hash storage for logging and error reporting. + ID agd.FilterListID + + // CachePath is the path to the file containing the cached filtered + // hostnames, one per line. + CachePath string // ReplacementHost is the replacement host for this filter. Queries // matched by the filter receive a response with the IP addresses of // this host. ReplacementHost string + // Staleness is the time after which a file is considered stale. + Staleness time.Duration + // CacheTTL is the time-to-live value used to cache the results of the // filter. // @@ -38,57 +61,58 @@ type HashPrefixConfig struct { CacheSize int } -// hashPrefixFilter is a filter that matches hosts by their hashes based on -// a hash-prefix table. -type hashPrefixFilter struct { - hashes *HashStorage +// HashPrefix is a filter that matches hosts by their hashes based on a +// hash-prefix table. +type HashPrefix struct { + hashes *hashstorage.Storage + refr *refreshableFilter resCache *resultcache.Cache[*ResultModified] resolver agdnet.Resolver errColl agd.ErrorCollector repHost string - id agd.FilterListID } -// newHashPrefixFilter returns a new hash-prefix filter. c must not be nil. 
-func newHashPrefixFilter( - c *HashPrefixConfig, - resolver agdnet.Resolver, - errColl agd.ErrorCollector, - id agd.FilterListID, -) (f *hashPrefixFilter) { - f = &hashPrefixFilter{ - hashes: c.Hashes, +// NewHashPrefix returns a new hash-prefix filter. c must not be nil. +func NewHashPrefix(c *HashPrefixConfig) (f *HashPrefix, err error) { + f = &HashPrefix{ + hashes: c.Hashes, + refr: &refreshableFilter{ + http: agdhttp.NewClient(&agdhttp.ClientConfig{ + Timeout: defaultFilterRefreshTimeout, + }), + url: c.URL, + id: c.ID, + cachePath: c.CachePath, + typ: "hash storage", + staleness: c.Staleness, + }, resCache: resultcache.New[*ResultModified](c.CacheSize), - resolver: resolver, - errColl: errColl, + resolver: c.Resolver, + errColl: c.ErrColl, repHost: c.ReplacementHost, - id: id, } - // Patch the refresh function of the hash storage, if there is one, to make - // sure that we clear the result cache during every refresh. - // - // TODO(a.garipov): Create a better way to do that than this spaghetti-style - // patching. Perhaps, lift the entire logic of hash-storage refresh into - // hashPrefixFilter. - if f.hashes != nil && f.hashes.refr != nil { - resetRules := f.hashes.refr.resetRules - f.hashes.refr.resetRules = func(text string) (err error) { - f.resCache.Clear() + f.refr.resetRules = f.resetRules - return resetRules(text) - } + err = f.refresh(context.Background(), true) + if err != nil { + return nil, err } - return f + return f, nil +} + +// id returns the ID of the hash storage. +func (f *HashPrefix) id() (fltID agd.FilterListID) { + return f.refr.id } // type check -var _ qtHostFilter = (*hashPrefixFilter)(nil) +var _ qtHostFilter = (*HashPrefix)(nil) // filterReq implements the qtHostFilter interface for *hashPrefixFilter. It // modifies the response if host matches f. 
-func (f *hashPrefixFilter) filterReq( +func (f *HashPrefix) filterReq( ctx context.Context, ri *agd.RequestInfo, req *dns.Msg, @@ -115,7 +139,7 @@ func (f *hashPrefixFilter) filterReq( var matched string sub := hashableSubdomains(host) for _, s := range sub { - if f.hashes.hashMatches(s) { + if f.hashes.Matches(s) { matched = s break @@ -134,19 +158,19 @@ func (f *hashPrefixFilter) filterReq( var result *dns.Msg ips, err := f.resolver.LookupIP(ctx, fam, f.repHost) if err != nil { - agd.Collectf(ctx, f.errColl, "filter %s: resolving: %w", f.id, err) + agd.Collectf(ctx, f.errColl, "filter %s: resolving: %w", f.id(), err) result = ri.Messages.NewMsgSERVFAIL(req) } else { result, err = ri.Messages.NewIPRespMsg(req, ips...) if err != nil { - return nil, fmt.Errorf("filter %s: creating modified result: %w", f.id, err) + return nil, fmt.Errorf("filter %s: creating modified result: %w", f.id(), err) } } rm = &ResultModified{ Msg: result, - List: f.id, + List: f.id(), Rule: agd.FilterRuleText(matched), } @@ -161,21 +185,21 @@ func (f *hashPrefixFilter) filterReq( } // updateCacheSizeMetrics updates cache size metrics. -func (f *hashPrefixFilter) updateCacheSizeMetrics(size int) { - switch f.id { +func (f *HashPrefix) updateCacheSizeMetrics(size int) { + switch id := f.id(); id { case agd.FilterListIDSafeBrowsing: metrics.HashPrefixFilterSafeBrowsingCacheSize.Set(float64(size)) case agd.FilterListIDAdultBlocking: metrics.HashPrefixFilterAdultBlockingCacheSize.Set(float64(size)) default: - panic(fmt.Errorf("unsupported FilterListID %s", f.id)) + panic(fmt.Errorf("unsupported FilterListID %s", id)) } } // updateCacheLookupsMetrics updates cache lookups metrics. 
-func (f *hashPrefixFilter) updateCacheLookupsMetrics(hit bool) { +func (f *HashPrefix) updateCacheLookupsMetrics(hit bool) { var hitsMetric, missesMetric prometheus.Counter - switch f.id { + switch id := f.id(); id { case agd.FilterListIDSafeBrowsing: hitsMetric = metrics.HashPrefixFilterCacheSafeBrowsingHits missesMetric = metrics.HashPrefixFilterCacheSafeBrowsingMisses @@ -183,7 +207,7 @@ func (f *hashPrefixFilter) updateCacheLookupsMetrics(hit bool) { hitsMetric = metrics.HashPrefixFilterCacheAdultBlockingHits missesMetric = metrics.HashPrefixFilterCacheAdultBlockingMisses default: - panic(fmt.Errorf("unsupported FilterListID %s", f.id)) + panic(fmt.Errorf("unsupported FilterListID %s", id)) } if hit { @@ -194,12 +218,50 @@ func (f *hashPrefixFilter) updateCacheLookupsMetrics(hit bool) { } // name implements the qtHostFilter interface for *hashPrefixFilter. -func (f *hashPrefixFilter) name() (n string) { +func (f *HashPrefix) name() (n string) { if f == nil { return "" } - return string(f.id) + return string(f.id()) +} + +// type check +var _ agd.Refresher = (*HashPrefix)(nil) + +// Refresh implements the [agd.Refresher] interface for *hashPrefixFilter. +func (f *HashPrefix) Refresh(ctx context.Context) (err error) { + return f.refresh(ctx, false) +} + +// refresh reloads the hash filter data. If acceptStale is true, do not try to +// load the list from its URL when there is already a file in the cache +// directory, regardless of its staleness. +func (f *HashPrefix) refresh(ctx context.Context, acceptStale bool) (err error) { + return f.refr.refresh(ctx, acceptStale) +} + +// resetRules resets the hosts in the index. +func (f *HashPrefix) resetRules(text string) (err error) { + n, err := f.hashes.Reset(text) + + // Report the filter update to prometheus. 
+ promLabels := prometheus.Labels{ + "filter": string(f.id()), + } + + metrics.SetStatusGauge(metrics.FilterUpdatedStatus.With(promLabels), err) + + if err != nil { + return err + } + + metrics.FilterUpdatedTime.With(promLabels).SetToCurrentTime() + metrics.FilterRulesTotal.With(promLabels).Set(float64(n)) + + log.Info("filter %s: reset %d hosts", f.id(), n) + + return nil } // subDomainNum defines how many labels should be hashed to match against a hash diff --git a/internal/filter/hashprefix_test.go b/internal/filter/hashprefix_test.go index e5e2c86..fac4d93 100644 --- a/internal/filter/hashprefix_test.go +++ b/internal/filter/hashprefix_test.go @@ -11,6 +11,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" "github.com/AdguardTeam/AdGuardDNS/internal/filter" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" @@ -21,8 +22,10 @@ import ( func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) { cacheDir := t.TempDir() cachePath := filepath.Join(cacheDir, string(agd.FilterListIDSafeBrowsing)) - hosts := "scam.example.net\n" - err := os.WriteFile(cachePath, []byte(hosts), 0o644) + err := os.WriteFile(cachePath, []byte(safeBrowsingHost+"\n"), 0o644) + require.NoError(t, err) + + hashes, err := hashstorage.New("") require.NoError(t, err) errColl := &agdtest.ErrorCollector{ @@ -31,46 +34,33 @@ func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) { }, } - hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{ - URL: nil, - ErrColl: errColl, - ID: agd.FilterListIDSafeBrowsing, - CachePath: cachePath, - RefreshIvl: testRefreshIvl, - }) - require.NoError(t, err) - - ctx := context.Background() - err = hashes.Start() - require.NoError(t, err) - testutil.CleanupAndRequireSuccess(t, func() (err error) { - return hashes.Shutdown(ctx) - }) - - // Fake Data - - onLookupIP := 
func( - _ context.Context, - _ netutil.AddrFamily, - _ string, - ) (ips []net.IP, err error) { - return []net.IP{safeBrowsingSafeIP4}, nil + resolver := &agdtest.Resolver{ + OnLookupIP: func( + _ context.Context, + _ netutil.AddrFamily, + _ string, + ) (ips []net.IP, err error) { + return []net.IP{safeBrowsingSafeIP4}, nil + }, } c := prepareConf(t) - c.SafeBrowsing = &filter.HashPrefixConfig{ + c.SafeBrowsing, err = filter.NewHashPrefix(&filter.HashPrefixConfig{ Hashes: hashes, + ErrColl: errColl, + Resolver: resolver, + ID: agd.FilterListIDSafeBrowsing, + CachePath: cachePath, ReplacementHost: safeBrowsingSafeHost, + Staleness: 1 * time.Hour, CacheTTL: 10 * time.Second, CacheSize: 100, - } + }) + require.NoError(t, err) c.ErrColl = errColl - - c.Resolver = &agdtest.Resolver{ - OnLookupIP: onLookupIP, - } + c.Resolver = resolver s, err := filter.NewDefaultStorage(c) require.NoError(t, err) @@ -93,7 +83,7 @@ func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) { } ri := newReqInfo(g, nil, safeBrowsingSubHost, clientIP, dns.TypeA) - ctx = agd.ContextWithRequestInfo(ctx, ri) + ctx := agd.ContextWithRequestInfo(context.Background(), ri) f := s.FilterFromContext(ctx, ri) require.NotNil(t, f) diff --git a/internal/filter/hashstorage.go b/internal/filter/hashstorage.go deleted file mode 100644 index 9586b45..0000000 --- a/internal/filter/hashstorage.go +++ /dev/null @@ -1,352 +0,0 @@ -package filter - -import ( - "bufio" - "context" - "crypto/sha256" - "encoding/hex" - "fmt" - "net/url" - "strings" - "sync" - "time" - - "github.com/AdguardTeam/AdGuardDNS/internal/agd" - "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" - "github.com/AdguardTeam/AdGuardDNS/internal/metrics" - "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/stringutil" - "github.com/prometheus/client_golang/prometheus" -) - -// Hash Storage - -// Hash and hash part length constants. 
-const ( - // hashPrefixLen is the length of the prefix of the hash of the filtered - // hostname. - hashPrefixLen = 2 - - // HashPrefixEncLen is the encoded length of the hash prefix. Two text - // bytes per one binary byte. - HashPrefixEncLen = hashPrefixLen * 2 - - // hashLen is the length of the whole hash of the checked hostname. - hashLen = sha256.Size - - // hashSuffixLen is the length of the suffix of the hash of the filtered - // hostname. - hashSuffixLen = hashLen - hashPrefixLen - - // hashEncLen is the encoded length of the hash. Two text bytes per one - // binary byte. - hashEncLen = hashLen * 2 - - // legacyHashPrefixEncLen is the encoded length of a legacy hash. - legacyHashPrefixEncLen = 8 -) - -// hashPrefix is the type of the 2-byte prefix of a full 32-byte SHA256 hash of -// a host being checked. -type hashPrefix [hashPrefixLen]byte - -// hashSuffix is the type of the 30-byte suffix of a full 32-byte SHA256 hash of -// a host being checked. -type hashSuffix [hashSuffixLen]byte - -// HashStorage is a storage for hashes of the filtered hostnames. -type HashStorage struct { - // mu protects hashSuffixes. - mu *sync.RWMutex - hashSuffixes map[hashPrefix][]hashSuffix - - // refr contains data for refreshing the filter. - refr *refreshableFilter - - refrWorker *agd.RefreshWorker -} - -// HashStorageConfig is the configuration structure for a *HashStorage. -type HashStorageConfig struct { - // URL is the URL used to update the filter. - URL *url.URL - - // ErrColl is used to collect non-critical and rare errors. - ErrColl agd.ErrorCollector - - // ID is the ID of this hash storage for logging and error reporting. - ID agd.FilterListID - - // CachePath is the path to the file containing the cached filtered - // hostnames, one per line. - CachePath string - - // RefreshIvl is the refresh interval. - RefreshIvl time.Duration -} - -// NewHashStorage returns a new hash storage containing hashes of all hostnames. 
-func NewHashStorage(c *HashStorageConfig) (hs *HashStorage, err error) { - hs = &HashStorage{ - mu: &sync.RWMutex{}, - hashSuffixes: map[hashPrefix][]hashSuffix{}, - refr: &refreshableFilter{ - http: agdhttp.NewClient(&agdhttp.ClientConfig{ - Timeout: defaultTimeout, - }), - url: c.URL, - id: c.ID, - cachePath: c.CachePath, - typ: "hash storage", - refreshIvl: c.RefreshIvl, - }, - } - - // Do not set this in the literal above, since hs is nil there. - hs.refr.resetRules = hs.resetHosts - - refrWorker := agd.NewRefreshWorker(&agd.RefreshWorkerConfig{ - Context: func() (ctx context.Context, cancel context.CancelFunc) { - return context.WithTimeout(context.Background(), defaultTimeout) - }, - Refresher: hs, - ErrColl: c.ErrColl, - Name: string(c.ID), - Interval: c.RefreshIvl, - RefreshOnShutdown: false, - RoutineLogsAreDebug: false, - }) - - hs.refrWorker = refrWorker - - err = hs.refresh(context.Background(), true) - if err != nil { - return nil, fmt.Errorf("initializing %s: %w", c.ID, err) - } - - return hs, nil -} - -// hashes returns all hashes starting with the given prefixes, if any. The -// resulting slice shares storage for all underlying strings. -// -// TODO(a.garipov): This currently doesn't take duplicates into account. -func (hs *HashStorage) hashes(hps []hashPrefix) (hashes []string) { - if len(hps) == 0 { - return nil - } - - hs.mu.RLock() - defer hs.mu.RUnlock() - - // First, calculate the number of hashes to allocate the buffer. - l := 0 - for _, hp := range hps { - hashSufs := hs.hashSuffixes[hp] - l += len(hashSufs) - } - - // Then, allocate the buffer of the appropriate size and write all hashes - // into one big buffer and slice it into separate strings to make the - // garbage collector's work easier. This assumes that all references to - // this buffer will become unreachable at the same time. - // - // The fact that we iterate over the map twice shouldn't matter, since we - // assume that len(hps) will be below 5 most of the time. 
- b := &strings.Builder{} - b.Grow(l * hashEncLen) - - // Use a buffer and write the resulting buffer into b directly instead of - // using hex.NewEncoder, because that seems to incur a significant - // performance hit. - var buf [hashEncLen]byte - for _, hp := range hps { - hashSufs := hs.hashSuffixes[hp] - for _, suf := range hashSufs { - // Slicing is safe here, since the contents of hp and suf are being - // encoded. - - // nolint:looppointer - hex.Encode(buf[:], hp[:]) - // nolint:looppointer - hex.Encode(buf[HashPrefixEncLen:], suf[:]) - _, _ = b.Write(buf[:]) - } - } - - s := b.String() - hashes = make([]string, 0, l) - for i := 0; i < l; i++ { - hashes = append(hashes, s[i*hashEncLen:(i+1)*hashEncLen]) - } - - return hashes -} - -// loadHashSuffixes returns hash suffixes for the given prefix. It is safe for -// concurrent use. -func (hs *HashStorage) loadHashSuffixes(hp hashPrefix) (sufs []hashSuffix, ok bool) { - hs.mu.RLock() - defer hs.mu.RUnlock() - - sufs, ok = hs.hashSuffixes[hp] - - return sufs, ok -} - -// hashMatches returns true if the host matches one of the hashes. -func (hs *HashStorage) hashMatches(host string) (ok bool) { - sum := sha256.Sum256([]byte(host)) - hp := hashPrefix{sum[0], sum[1]} - - var buf [hashLen]byte - hashSufs, ok := hs.loadHashSuffixes(hp) - if !ok { - return false - } - - copy(buf[:], hp[:]) - for _, suf := range hashSufs { - // Slicing is safe here, because we make a copy. - - // nolint:looppointer - copy(buf[hashPrefixLen:], suf[:]) - if buf == sum { - return true - } - } - - return false -} - -// hashPrefixesFromStr returns hash prefixes from a dot-separated string. 
-func hashPrefixesFromStr(prefixesStr string) (hashPrefixes []hashPrefix, err error) { - if prefixesStr == "" { - return nil, nil - } - - prefixSet := stringutil.NewSet() - prefixStrs := strings.Split(prefixesStr, ".") - for _, s := range prefixStrs { - if len(s) != HashPrefixEncLen { - // Some legacy clients send eight-character hashes instead of - // four-character ones. For now, remove the final four characters. - // - // TODO(a.garipov): Either remove this crutch or support such - // prefixes better. - if len(s) == legacyHashPrefixEncLen { - s = s[:HashPrefixEncLen] - } else { - return nil, fmt.Errorf("bad hash len for %q", s) - } - } - - prefixSet.Add(s) - } - - hashPrefixes = make([]hashPrefix, prefixSet.Len()) - prefixStrs = prefixSet.Values() - for i, s := range prefixStrs { - _, err = hex.Decode(hashPrefixes[i][:], []byte(s)) - if err != nil { - return nil, fmt.Errorf("bad hash encoding for %q", s) - } - } - - return hashPrefixes, nil -} - -// type check -var _ agd.Refresher = (*HashStorage)(nil) - -// Refresh implements the agd.Refresher interface for *HashStorage. If the file -// at the storage's path exists and its mtime shows that it's still fresh, it -// loads the data from the file. Otherwise, it uses the URL of the storage. -func (hs *HashStorage) Refresh(ctx context.Context) (err error) { - err = hs.refresh(ctx, false) - - // Report the filter update to prometheus. - promLabels := prometheus.Labels{ - "filter": string(hs.id()), - } - - metrics.SetStatusGauge(metrics.FilterUpdatedStatus.With(promLabels), err) - - if err == nil { - metrics.FilterUpdatedTime.With(promLabels).SetToCurrentTime() - - // Count the total number of hashes loaded. - count := 0 - for _, v := range hs.hashSuffixes { - count += len(v) - } - - metrics.FilterRulesTotal.With(promLabels).Set(float64(count)) - } - - return err -} - -// id returns the ID of the hash storage. 
-func (hs *HashStorage) id() (fltID agd.FilterListID) { - return hs.refr.id -} - -// refresh reloads the hash filter data. If acceptStale is true, do not try to -// load the list from its URL when there is already a file in the cache -// directory, regardless of its staleness. -func (hs *HashStorage) refresh(ctx context.Context, acceptStale bool) (err error) { - return hs.refr.refresh(ctx, acceptStale) -} - -// resetHosts resets the hosts in the index. -func (hs *HashStorage) resetHosts(hostsStr string) (err error) { - hs.mu.Lock() - defer hs.mu.Unlock() - - // Delete all elements without allocating a new map to safe space and - // performance. - // - // This is optimized, see https://github.com/golang/go/issues/20138. - for hp := range hs.hashSuffixes { - delete(hs.hashSuffixes, hp) - } - - var n int - s := bufio.NewScanner(strings.NewReader(hostsStr)) - for s.Scan() { - host := s.Text() - if len(host) == 0 || host[0] == '#' { - continue - } - - sum := sha256.Sum256([]byte(host)) - hp := hashPrefix{sum[0], sum[1]} - - // TODO(a.garipov): Convert to array directly when proposal - // golang/go#46505 is implemented in Go 1.20. - suf := *(*hashSuffix)(sum[hashPrefixLen:]) - hs.hashSuffixes[hp] = append(hs.hashSuffixes[hp], suf) - - n++ - } - - err = s.Err() - if err != nil { - return fmt.Errorf("scanning hosts: %w", err) - } - - log.Info("filter %s: reset %d hosts", hs.id(), n) - - return nil -} - -// Start implements the agd.Service interface for *HashStorage. -func (hs *HashStorage) Start() (err error) { - return hs.refrWorker.Start() -} - -// Shutdown implements the agd.Service interface for *HashStorage. 
-func (hs *HashStorage) Shutdown(ctx context.Context) (err error) { - return hs.refrWorker.Shutdown(ctx) -} diff --git a/internal/filter/hashstorage/hashstorage.go b/internal/filter/hashstorage/hashstorage.go new file mode 100644 index 0000000..8234279 --- /dev/null +++ b/internal/filter/hashstorage/hashstorage.go @@ -0,0 +1,203 @@ +// Package hashstorage defines a storage of hashes of domain names used for +// filtering. +package hashstorage + +import ( + "bufio" + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" + "sync" +) + +// Hash and hash part length constants. +const ( + // PrefixLen is the length of the prefix of the hash of the filtered + // hostname. + PrefixLen = 2 + + // PrefixEncLen is the encoded length of the hash prefix. Two text + // bytes per one binary byte. + PrefixEncLen = PrefixLen * 2 + + // hashLen is the length of the whole hash of the checked hostname. + hashLen = sha256.Size + + // suffixLen is the length of the suffix of the hash of the filtered + // hostname. + suffixLen = hashLen - PrefixLen + + // hashEncLen is the encoded length of the hash. Two text bytes per one + // binary byte. + hashEncLen = hashLen * 2 +) + +// Prefix is the type of the 2-byte prefix of a full 32-byte SHA256 hash of a +// host being checked. +type Prefix [PrefixLen]byte + +// suffix is the type of the 30-byte suffix of a full 32-byte SHA256 hash of a +// host being checked. +type suffix [suffixLen]byte + +// Storage stores hashes of the filtered hostnames. All methods are safe for +// concurrent use. +type Storage struct { + // mu protects hashSuffixes. + mu *sync.RWMutex + hashSuffixes map[Prefix][]suffix +} + +// New returns a new hash storage containing hashes of the domain names listed +// in hostnames, one domain name per line. 
+func New(hostnames string) (s *Storage, err error) { + s = &Storage{ + mu: &sync.RWMutex{}, + hashSuffixes: map[Prefix][]suffix{}, + } + + if hostnames != "" { + _, err = s.Reset(hostnames) + if err != nil { + return nil, err + } + } + + return s, nil +} + +// Hashes returns all hashes starting with the given prefixes, if any. The +// resulting slice shares storage for all underlying strings. +// +// TODO(a.garipov): This currently doesn't take duplicates into account. +func (s *Storage) Hashes(hps []Prefix) (hashes []string) { + if len(hps) == 0 { + return nil + } + + s.mu.RLock() + defer s.mu.RUnlock() + + // First, calculate the number of hashes to allocate the buffer. + l := 0 + for _, hp := range hps { + hashSufs := s.hashSuffixes[hp] + l += len(hashSufs) + } + + // Then, allocate the buffer of the appropriate size and write all hashes + // into one big buffer and slice it into separate strings to make the + // garbage collector's work easier. This assumes that all references to + // this buffer will become unreachable at the same time. + // + // The fact that we iterate over the [s.hashSuffixes] map twice shouldn't + // matter, since we assume that len(hps) will be below 5 most of the time. + b := &strings.Builder{} + b.Grow(l * hashEncLen) + + // Use a buffer and write the resulting buffer into b directly instead of + // using hex.NewEncoder, because that seems to incur a significant + // performance hit. + var buf [hashEncLen]byte + for _, hp := range hps { + hashSufs := s.hashSuffixes[hp] + for _, suf := range hashSufs { + // Slicing is safe here, since the contents of hp and suf are being + // encoded. 
+ + // nolint:looppointer + hex.Encode(buf[:], hp[:]) + // nolint:looppointer + hex.Encode(buf[PrefixEncLen:], suf[:]) + _, _ = b.Write(buf[:]) + } + } + + str := b.String() + hashes = make([]string, 0, l) + for i := 0; i < l; i++ { + hashes = append(hashes, str[i*hashEncLen:(i+1)*hashEncLen]) + } + + return hashes +} + +// Matches returns true if the host matches one of the hashes. +func (s *Storage) Matches(host string) (ok bool) { + sum := sha256.Sum256([]byte(host)) + hp := *(*Prefix)(sum[:PrefixLen]) + + var buf [hashLen]byte + hashSufs, ok := s.loadHashSuffixes(hp) + if !ok { + return false + } + + copy(buf[:], hp[:]) + for _, suf := range hashSufs { + // Slicing is safe here, because we make a copy. + + // nolint:looppointer + copy(buf[PrefixLen:], suf[:]) + if buf == sum { + return true + } + } + + return false +} + +// Reset resets the hosts in the index using the domain names listed in +// hostnames, one domain name per line, and returns the total number of +// processed rules. +func (s *Storage) Reset(hostnames string) (n int, err error) { + s.mu.Lock() + defer s.mu.Unlock() + + // Delete all elements without allocating a new map to save space and + // improve performance. + // + // This is optimized, see https://github.com/golang/go/issues/20138. + // + // TODO(a.garipov): Use clear once golang/go#56351 is implemented. + for hp := range s.hashSuffixes { + delete(s.hashSuffixes, hp) + } + + sc := bufio.NewScanner(strings.NewReader(hostnames)) + for sc.Scan() { + host := sc.Text() + if len(host) == 0 || host[0] == '#' { + continue + } + + sum := sha256.Sum256([]byte(host)) + hp := *(*Prefix)(sum[:PrefixLen]) + + // TODO(a.garipov): Here and everywhere, convert to array directly when + // proposal golang/go#46505 is implemented in Go 1.20. 
+ suf := *(*suffix)(sum[PrefixLen:]) + s.hashSuffixes[hp] = append(s.hashSuffixes[hp], suf) + + n++ + } + + err = sc.Err() + if err != nil { + return 0, fmt.Errorf("scanning hosts: %w", err) + } + + return n, nil +} + +// loadHashSuffixes returns hash suffixes for the given prefix. It is safe for +// concurrent use. sufs must not be modified. +func (s *Storage) loadHashSuffixes(hp Prefix) (sufs []suffix, ok bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + sufs, ok = s.hashSuffixes[hp] + + return sufs, ok +} diff --git a/internal/filter/hashstorage/hashstorage_test.go b/internal/filter/hashstorage/hashstorage_test.go new file mode 100644 index 0000000..944d304 --- /dev/null +++ b/internal/filter/hashstorage/hashstorage_test.go @@ -0,0 +1,142 @@ +package hashstorage_test + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "strconv" + "strings" + "testing" + + "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Common hostnames for tests. 
+const ( + testHost = "porn.example" + otherHost = "otherporn.example" +) + +func TestStorage_Hashes(t *testing.T) { + s, err := hashstorage.New(testHost) + require.NoError(t, err) + + h := sha256.Sum256([]byte(testHost)) + want := []string{hex.EncodeToString(h[:])} + + p := hashstorage.Prefix{h[0], h[1]} + got := s.Hashes([]hashstorage.Prefix{p}) + assert.Equal(t, want, got) + + wrong := s.Hashes([]hashstorage.Prefix{{}}) + assert.Empty(t, wrong) +} + +func TestStorage_Matches(t *testing.T) { + s, err := hashstorage.New(testHost) + require.NoError(t, err) + + got := s.Matches(testHost) + assert.True(t, got) + + got = s.Matches(otherHost) + assert.False(t, got) +} + +func TestStorage_Reset(t *testing.T) { + s, err := hashstorage.New(testHost) + require.NoError(t, err) + + n, err := s.Reset(otherHost) + require.NoError(t, err) + + assert.Equal(t, 1, n) + + h := sha256.Sum256([]byte(otherHost)) + want := []string{hex.EncodeToString(h[:])} + + p := hashstorage.Prefix{h[0], h[1]} + got := s.Hashes([]hashstorage.Prefix{p}) + assert.Equal(t, want, got) + + prevHash := sha256.Sum256([]byte(testHost)) + prev := s.Hashes([]hashstorage.Prefix{{prevHash[0], prevHash[1]}}) + assert.Empty(t, prev) +} + +// Sinks for benchmarks. 
+var ( + errSink error + strsSink []string +) + +func BenchmarkStorage_Hashes(b *testing.B) { + const N = 10_000 + + var hosts []string + for i := 0; i < N; i++ { + hosts = append(hosts, fmt.Sprintf("%d."+testHost, i)) + } + + s, err := hashstorage.New(strings.Join(hosts, "\n")) + require.NoError(b, err) + + var hashPrefixes []hashstorage.Prefix + for i := 0; i < 4; i++ { + hashPrefixes = append(hashPrefixes, hashstorage.Prefix{hosts[i][0], hosts[i][1]}) + } + + for n := 1; n <= 4; n++ { + b.Run(strconv.FormatInt(int64(n), 10), func(b *testing.B) { + hps := hashPrefixes[:n] + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + strsSink = s.Hashes(hps) + } + }) + } + + // Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU: + // + // goos: linux + // goarch: amd64 + // pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage + // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics + // BenchmarkStorage_Hashes/1-16 29928834 41.76 ns/op 0 B/op 0 allocs/op + // BenchmarkStorage_Hashes/2-16 18693033 63.80 ns/op 0 B/op 0 allocs/op + // BenchmarkStorage_Hashes/3-16 13492526 92.22 ns/op 0 B/op 0 allocs/op + // BenchmarkStorage_Hashes/4-16 9542425 109.2 ns/op 0 B/op 0 allocs/op +} + +func BenchmarkStorage_ResetHosts(b *testing.B) { + const N = 1_000 + + var hosts []string + for i := 0; i < N; i++ { + hosts = append(hosts, fmt.Sprintf("%d."+testHost, i)) + } + + hostnames := strings.Join(hosts, "\n") + s, err := hashstorage.New(hostnames) + require.NoError(b, err) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, errSink = s.Reset(hostnames) + } + + require.NoError(b, errSink) + + // Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU: + // + // goos: linux + // goarch: amd64 + // pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage + // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics + // BenchmarkStorage_ResetHosts-16 2212 469343 ns/op 36224 B/op 1002 allocs/op +} diff --git 
a/internal/filter/hashstorage_internal_test.go b/internal/filter/hashstorage_internal_test.go deleted file mode 100644 index db61dce..0000000 --- a/internal/filter/hashstorage_internal_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package filter - -import ( - "fmt" - "strconv" - "strings" - "sync" - "testing" - - "github.com/stretchr/testify/require" -) - -var strsSink []string - -func BenchmarkHashStorage_Hashes(b *testing.B) { - const N = 10_000 - - var hosts []string - for i := 0; i < N; i++ { - hosts = append(hosts, fmt.Sprintf("%d.porn.example.com", i)) - } - - // Don't use a constructor, since we don't need the whole contents of the - // storage. - // - // TODO(a.garipov): Think of a better way to do this. - hs := &HashStorage{ - mu: &sync.RWMutex{}, - hashSuffixes: map[hashPrefix][]hashSuffix{}, - refr: &refreshableFilter{id: "test_filter"}, - } - - err := hs.resetHosts(strings.Join(hosts, "\n")) - require.NoError(b, err) - - var hashPrefixes []hashPrefix - for i := 0; i < 4; i++ { - hashPrefixes = append(hashPrefixes, hashPrefix{hosts[i][0], hosts[i][1]}) - } - - for n := 1; n <= 4; n++ { - b.Run(strconv.FormatInt(int64(n), 10), func(b *testing.B) { - hps := hashPrefixes[:n] - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - strsSink = hs.hashes(hps) - } - }) - } -} - -func BenchmarkHashStorage_resetHosts(b *testing.B) { - const N = 1_000 - - var hosts []string - for i := 0; i < N; i++ { - hosts = append(hosts, fmt.Sprintf("%d.porn.example.com", i)) - } - - // Don't use a constructor, since we don't need the whole contents of the - // storage. - // - // TODO(a.garipov): Think of a better way to do this. - hs := &HashStorage{ - mu: &sync.RWMutex{}, - hashSuffixes: map[hashPrefix][]hashSuffix{}, - refr: &refreshableFilter{id: "test_filter"}, - } - - // Reset them once to fill the initial map. 
- err := hs.resetHosts(strings.Join(hosts, "\n")) - require.NoError(b, err) - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - err = hs.resetHosts(strings.Join(hosts, "\n")) - } - - require.NoError(b, err) - - // Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU: - // - // goos: linux - // goarch: amd64 - // pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter - // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics - // BenchmarkHashStorage_resetHosts-16 1404 829505 ns/op 58289 B/op 1011 allocs/op -} diff --git a/internal/filter/internal/resultcache/resultcache.go b/internal/filter/internal/resultcache/resultcache.go index ba5a2db..f230be5 100644 --- a/internal/filter/internal/resultcache/resultcache.go +++ b/internal/filter/internal/resultcache/resultcache.go @@ -8,6 +8,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/mathutil" "github.com/bluele/gcache" ) @@ -93,13 +94,9 @@ func DefaultKey(host string, qt dnsmsg.RRType, isAns bool) (k Key) { // Save on allocations by reusing a buffer. var buf [3]byte binary.LittleEndian.PutUint16(buf[:2], qt) - if isAns { - buf[2] = 1 - } else { - buf[2] = 0 - } + buf[2] = mathutil.BoolToNumber[byte](isAns) - _, _ = h.Write(buf[:3]) + _, _ = h.Write(buf[:]) return Key(h.Sum64()) } diff --git a/internal/filter/refrfilter.go b/internal/filter/refrfilter.go index 55ec1c4..f776266 100644 --- a/internal/filter/refrfilter.go +++ b/internal/filter/refrfilter.go @@ -45,9 +45,8 @@ type refreshableFilter struct { // typ is the type of this filter used for logging and error reporting. typ string - // refreshIvl is the refresh interval for this filter. It is also used to - // check if the cached file is fresh enough. - refreshIvl time.Duration + // staleness is the time after which a file is considered stale. + staleness time.Duration } // refresh reloads the filter data. 
If acceptStale is true, refresh doesn't try @@ -118,7 +117,7 @@ func (f *refreshableFilter) refreshFromFile( return "", fmt.Errorf("reading filter file stat: %w", err) } - if mtime := fi.ModTime(); !mtime.Add(f.refreshIvl).After(time.Now()) { + if mtime := fi.ModTime(); !mtime.Add(f.staleness).After(time.Now()) { return "", nil } } diff --git a/internal/filter/refrfilter_internal_test.go b/internal/filter/refrfilter_internal_test.go index e989af0..8f33eb2 100644 --- a/internal/filter/refrfilter_internal_test.go +++ b/internal/filter/refrfilter_internal_test.go @@ -30,43 +30,43 @@ func TestRefreshableFilter_RefreshFromFile(t *testing.T) { name string cachePath string wantText string - refreshIvl time.Duration + staleness time.Duration acceptStale bool }{{ name: "no_file", cachePath: "does_not_exist", wantText: "", - refreshIvl: 0, + staleness: 0, acceptStale: true, }, { name: "file", cachePath: cachePath, wantText: defaultText, - refreshIvl: 0, + staleness: 0, acceptStale: true, }, { name: "file_stale", cachePath: cachePath, wantText: "", - refreshIvl: -1 * time.Second, + staleness: -1 * time.Second, acceptStale: false, }, { name: "file_stale_accept", cachePath: cachePath, wantText: defaultText, - refreshIvl: -1 * time.Second, + staleness: -1 * time.Second, acceptStale: true, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { f := &refreshableFilter{ - http: nil, - url: nil, - id: "test_filter", - cachePath: tc.cachePath, - typ: "test filter", - refreshIvl: tc.refreshIvl, + http: nil, + url: nil, + id: "test_filter", + cachePath: tc.cachePath, + typ: "test filter", + staleness: tc.staleness, } var text string @@ -161,12 +161,12 @@ func TestRefreshableFilter_RefreshFromURL(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { f := &refreshableFilter{ - http: httpCli, - url: u, - id: "test_filter", - cachePath: tc.cachePath, - typ: "test filter", - refreshIvl: testTimeout, + http: httpCli, + url: u, + id: "test_filter", 
+ cachePath: tc.cachePath, + typ: "test filter", + staleness: testTimeout, } if tc.expectReq { diff --git a/internal/filter/rulelist.go b/internal/filter/rulelist.go index 367d00c..ba67be2 100644 --- a/internal/filter/rulelist.go +++ b/internal/filter/rulelist.go @@ -72,13 +72,13 @@ func newRuleListFilter( mu: &sync.RWMutex{}, refr: &refreshableFilter{ http: agdhttp.NewClient(&agdhttp.ClientConfig{ - Timeout: defaultTimeout, + Timeout: defaultFilterRefreshTimeout, }), - url: l.URL, - id: l.ID, - cachePath: filepath.Join(fileCacheDir, string(l.ID)), - typ: "rule list", - refreshIvl: l.RefreshIvl, + url: l.URL, + id: l.ID, + cachePath: filepath.Join(fileCacheDir, string(l.ID)), + typ: "rule list", + staleness: l.RefreshIvl, }, urlFilterID: newURLFilterID(), } diff --git a/internal/filter/safebrowsing.go b/internal/filter/safebrowsing.go index a37fdec..2087e0a 100644 --- a/internal/filter/safebrowsing.go +++ b/internal/filter/safebrowsing.go @@ -2,9 +2,13 @@ package filter import ( "context" + "encoding/hex" + "fmt" "strings" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/stringutil" ) // Safe Browsing TXT Record Server @@ -14,12 +18,12 @@ import ( // // TODO(a.garipov): Consider making an interface to simplify testing. type SafeBrowsingServer struct { - generalHashes *HashStorage - adultBlockingHashes *HashStorage + generalHashes *hashstorage.Storage + adultBlockingHashes *hashstorage.Storage } // NewSafeBrowsingServer returns a new safe browsing DNS server. 
-func NewSafeBrowsingServer(general, adultBlocking *HashStorage) (f *SafeBrowsingServer) { +func NewSafeBrowsingServer(general, adultBlocking *hashstorage.Storage) (f *SafeBrowsingServer) { return &SafeBrowsingServer{ generalHashes: general, adultBlockingHashes: adultBlocking, @@ -48,8 +52,7 @@ func (srv *SafeBrowsingServer) Hashes( } var prefixesStr string - var strg *HashStorage - + var strg *hashstorage.Storage if strings.HasSuffix(host, GeneralTXTSuffix) { prefixesStr = host[:len(host)-len(GeneralTXTSuffix)] strg = srv.generalHashes @@ -67,5 +70,45 @@ func (srv *SafeBrowsingServer) Hashes( return nil, false, err } - return strg.hashes(hashPrefixes), true, nil + return strg.Hashes(hashPrefixes), true, nil +} + +// legacyPrefixEncLen is the encoded length of a legacy hash. +const legacyPrefixEncLen = 8 + +// hashPrefixesFromStr returns hash prefixes from a dot-separated string. +func hashPrefixesFromStr(prefixesStr string) (hashPrefixes []hashstorage.Prefix, err error) { + if prefixesStr == "" { + return nil, nil + } + + prefixSet := stringutil.NewSet() + prefixStrs := strings.Split(prefixesStr, ".") + for _, s := range prefixStrs { + if len(s) != hashstorage.PrefixEncLen { + // Some legacy clients send eight-character hashes instead of + // four-character ones. For now, remove the final four characters. + // + // TODO(a.garipov): Either remove this crutch or support such + // prefixes better. 
+ if len(s) == legacyPrefixEncLen { + s = s[:hashstorage.PrefixEncLen] + } else { + return nil, fmt.Errorf("bad hash len for %q", s) + } + } + + prefixSet.Add(s) + } + + hashPrefixes = make([]hashstorage.Prefix, prefixSet.Len()) + prefixStrs = prefixSet.Values() + for i, s := range prefixStrs { + _, err = hex.Decode(hashPrefixes[i][:], []byte(s)) + if err != nil { + return nil, fmt.Errorf("bad hash encoding for %q", s) + } + } + + return hashPrefixes, nil } diff --git a/internal/filter/safebrowsing_test.go b/internal/filter/safebrowsing_test.go index 0c811a9..55afa30 100644 --- a/internal/filter/safebrowsing_test.go +++ b/internal/filter/safebrowsing_test.go @@ -4,17 +4,11 @@ import ( "context" "crypto/sha256" "encoding/hex" - "net/url" - "os" - "path/filepath" "strings" "testing" - "time" - "github.com/AdguardTeam/AdGuardDNS/internal/agd" - "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" "github.com/AdguardTeam/AdGuardDNS/internal/filter" - "github.com/AdguardTeam/golibs/testutil" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -43,41 +37,10 @@ func TestSafeBrowsingServer(t *testing.T) { hashStrs[i] = hex.EncodeToString(sum[:]) } - // Hash Storage - - errColl := &agdtest.ErrorCollector{ - OnCollect: func(_ context.Context, err error) { - panic("not implemented") - }, - } - - cacheDir := t.TempDir() - cachePath := filepath.Join(cacheDir, string(agd.FilterListIDSafeBrowsing)) - err := os.WriteFile(cachePath, []byte(strings.Join(hosts, "\n")), 0o644) - require.NoError(t, err) - - hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{ - URL: &url.URL{}, - ErrColl: errColl, - ID: agd.FilterListIDSafeBrowsing, - CachePath: cachePath, - RefreshIvl: testRefreshIvl, - }) + hashes, err := hashstorage.New(strings.Join(hosts, "\n")) require.NoError(t, err) ctx := context.Background() - - err = hashes.Start() - require.NoError(t, err) - 
testutil.CleanupAndRequireSuccess(t, func() (err error) { - return hashes.Shutdown(ctx) - }) - - // Give the storage some time to process the hashes. - // - // TODO(a.garipov): Think of a less stupid way of doing this. - time.Sleep(100 * time.Millisecond) - testCases := []struct { name string host string @@ -90,14 +53,14 @@ func TestSafeBrowsingServer(t *testing.T) { wantMatched: false, }, { name: "realistic", - host: hashStrs[realisticHostIdx][:filter.HashPrefixEncLen] + filter.GeneralTXTSuffix, + host: hashStrs[realisticHostIdx][:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix, wantHashStrs: []string{ hashStrs[realisticHostIdx], }, wantMatched: true, }, { name: "same_prefix", - host: hashStrs[samePrefixHost1Idx][:filter.HashPrefixEncLen] + filter.GeneralTXTSuffix, + host: hashStrs[samePrefixHost1Idx][:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix, wantHashStrs: []string{ hashStrs[samePrefixHost1Idx], hashStrs[samePrefixHost2Idx], diff --git a/internal/filter/safesearch_test.go b/internal/filter/safesearch_test.go index dafaf4e..a295e81 100644 --- a/internal/filter/safesearch_test.go +++ b/internal/filter/safesearch_test.go @@ -3,9 +3,7 @@ package filter_test import ( "context" "net" - "os" "testing" - "time" "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" @@ -20,45 +18,29 @@ import ( func TestStorage_FilterFromContext_safeSearch(t *testing.T) { numLookupIP := 0 - onLookupIP := func( - _ context.Context, - fam netutil.AddrFamily, - _ string, - ) (ips []net.IP, err error) { - numLookupIP++ + resolver := &agdtest.Resolver{ + OnLookupIP: func( + _ context.Context, + fam netutil.AddrFamily, + _ string, + ) (ips []net.IP, err error) { + numLookupIP++ - if fam == netutil.AddrFamilyIPv4 { - return []net.IP{safeSearchIPRespIP4}, nil - } + if fam == netutil.AddrFamilyIPv4 { + return []net.IP{safeSearchIPRespIP4}, nil + } - return []net.IP{safeSearchIPRespIP6}, nil + return []net.IP{safeSearchIPRespIP6}, nil + 
}, } - tmpFile, err := os.CreateTemp(t.TempDir(), "") - require.NoError(t, err) - - _, err = tmpFile.Write([]byte("bad.example.com\n")) - require.NoError(t, err) - testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) }) - - hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{ - CachePath: tmpFile.Name(), - RefreshIvl: 1 * time.Hour, - }) - require.NoError(t, err) - c := prepareConf(t) - c.SafeBrowsing.Hashes = hashes - c.AdultBlocking.Hashes = hashes - c.ErrColl = &agdtest.ErrorCollector{ OnCollect: func(_ context.Context, err error) { panic("not implemented") }, } - c.Resolver = &agdtest.Resolver{ - OnLookupIP: onLookupIP, - } + c.Resolver = resolver s, err := filter.NewDefaultStorage(c) require.NoError(t, err) @@ -80,13 +62,13 @@ func TestStorage_FilterFromContext_safeSearch(t *testing.T) { host: safeSearchIPHost, wantIP: safeSearchIPRespIP4, rrtype: dns.TypeA, - wantLookups: 0, + wantLookups: 1, }, { name: "ip6", host: safeSearchIPHost, - wantIP: nil, + wantIP: safeSearchIPRespIP6, rrtype: dns.TypeAAAA, - wantLookups: 0, + wantLookups: 1, }, { name: "host_ip4", host: safeSearchHost, diff --git a/internal/filter/serviceblocker.go b/internal/filter/serviceblocker.go index fa35448..04cff42 100644 --- a/internal/filter/serviceblocker.go +++ b/internal/filter/serviceblocker.go @@ -48,7 +48,7 @@ func newServiceBlocker(indexURL *url.URL, errColl agd.ErrorCollector) (b *servic return &serviceBlocker{ url: indexURL, http: agdhttp.NewClient(&agdhttp.ClientConfig{ - Timeout: defaultTimeout, + Timeout: defaultFilterRefreshTimeout, }), mu: &sync.RWMutex{}, errColl: errColl, diff --git a/internal/filter/storage.go b/internal/filter/storage.go index 5fbff15..3a3b114 100644 --- a/internal/filter/storage.go +++ b/internal/filter/storage.go @@ -15,7 +15,6 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" - 
"github.com/AdguardTeam/golibs/timeutil" "github.com/bluele/gcache" "github.com/prometheus/client_golang/prometheus" ) @@ -55,10 +54,10 @@ type DefaultStorage struct { services *serviceBlocker // safeBrowsing is the general safe browsing filter. - safeBrowsing *hashPrefixFilter + safeBrowsing *HashPrefix // adultBlocking is the adult content blocking safe browsing filter. - adultBlocking *hashPrefixFilter + adultBlocking *HashPrefix // genSafeSearch is the general safe search filter. genSafeSearch *safeSearch @@ -111,11 +110,11 @@ type DefaultStorageConfig struct { // SafeBrowsing is the configuration for the default safe browsing filter. // It must not be nil. - SafeBrowsing *HashPrefixConfig + SafeBrowsing *HashPrefix // AdultBlocking is the configuration for the adult content blocking safe // browsing filter. It must not be nil. - AdultBlocking *HashPrefixConfig + AdultBlocking *HashPrefix // Now is a function that returns current time. Now func() (now time.Time) @@ -156,25 +155,8 @@ type DefaultStorageConfig struct { // NewDefaultStorage returns a new filter storage. c must not be nil. func NewDefaultStorage(c *DefaultStorageConfig) (s *DefaultStorage, err error) { - // TODO(ameshkov): Consider making configurable. 
- resolver := agdnet.NewCachingResolver(c.Resolver, 1*timeutil.Day) - - safeBrowsing := newHashPrefixFilter( - c.SafeBrowsing, - resolver, - c.ErrColl, - agd.FilterListIDSafeBrowsing, - ) - - adultBlocking := newHashPrefixFilter( - c.AdultBlocking, - resolver, - c.ErrColl, - agd.FilterListIDAdultBlocking, - ) - genSafeSearch := newSafeSearch(&safeSearchConfig{ - resolver: resolver, + resolver: c.Resolver, errColl: c.ErrColl, list: &agd.FilterList{ URL: c.GeneralSafeSearchRulesURL, @@ -187,7 +169,7 @@ func NewDefaultStorage(c *DefaultStorageConfig) (s *DefaultStorage, err error) { }) ytSafeSearch := newSafeSearch(&safeSearchConfig{ - resolver: resolver, + resolver: c.Resolver, errColl: c.ErrColl, list: &agd.FilterList{ URL: c.YoutubeSafeSearchRulesURL, @@ -203,11 +185,11 @@ func NewDefaultStorage(c *DefaultStorageConfig) (s *DefaultStorage, err error) { mu: &sync.RWMutex{}, url: c.FilterIndexURL, http: agdhttp.NewClient(&agdhttp.ClientConfig{ - Timeout: defaultTimeout, + Timeout: defaultFilterRefreshTimeout, }), services: newServiceBlocker(c.BlockedServiceIndexURL, c.ErrColl), - safeBrowsing: safeBrowsing, - adultBlocking: adultBlocking, + safeBrowsing: c.SafeBrowsing, + adultBlocking: c.AdultBlocking, genSafeSearch: genSafeSearch, ytSafeSearch: ytSafeSearch, now: c.Now, @@ -322,7 +304,7 @@ func (s *DefaultStorage) pcBySchedule(sch *agd.ParentalProtectionSchedule) (ok b func (s *DefaultStorage) safeBrowsingForProfile( p *agd.Profile, parentalEnabled bool, -) (safeBrowsing, adultBlocking *hashPrefixFilter) { +) (safeBrowsing, adultBlocking *HashPrefix) { if p.SafeBrowsingEnabled { safeBrowsing = s.safeBrowsing } @@ -359,7 +341,7 @@ func (s *DefaultStorage) safeSearchForProfile( // in the filtering group. g must not be nil. 
func (s *DefaultStorage) safeBrowsingForGroup( g *agd.FilteringGroup, -) (safeBrowsing, adultBlocking *hashPrefixFilter) { +) (safeBrowsing, adultBlocking *HashPrefix) { if g.SafeBrowsingEnabled { safeBrowsing = s.safeBrowsing } diff --git a/internal/filter/storage_test.go b/internal/filter/storage_test.go index 75d0edf..9839abc 100644 --- a/internal/filter/storage_test.go +++ b/internal/filter/storage_test.go @@ -11,6 +11,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" "github.com/AdguardTeam/AdGuardDNS/internal/filter" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" @@ -124,34 +125,11 @@ func TestStorage_FilterFromContext(t *testing.T) { } func TestStorage_FilterFromContext_customAllow(t *testing.T) { - // Initialize the hashes file and use it with the storage. - tmpFile, err := os.CreateTemp(t.TempDir(), "") - require.NoError(t, err) - testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) }) - - _, err = io.WriteString(tmpFile, safeBrowsingHost+"\n") - require.NoError(t, err) - - hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{ - CachePath: tmpFile.Name(), - RefreshIvl: 1 * time.Hour, - }) - require.NoError(t, err) - - c := prepareConf(t) - - c.SafeBrowsing = &filter.HashPrefixConfig{ - Hashes: hashes, - ReplacementHost: safeBrowsingSafeHost, - CacheTTL: 10 * time.Second, - CacheSize: 100, - } - - c.ErrColl = &agdtest.ErrorCollector{ + errColl := &agdtest.ErrorCollector{ OnCollect: func(_ context.Context, err error) { panic("not implemented") }, } - c.Resolver = &agdtest.Resolver{ + resolver := &agdtest.Resolver{ OnLookupIP: func( _ context.Context, _ netutil.AddrFamily, @@ -161,6 +139,35 @@ func TestStorage_FilterFromContext_customAllow(t *testing.T) { }, } + // Initialize the hashes file and use it with the storage. 
+ tmpFile, err := os.CreateTemp(t.TempDir(), "") + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) }) + + _, err = io.WriteString(tmpFile, safeBrowsingHost+"\n") + require.NoError(t, err) + + hashes, err := hashstorage.New(safeBrowsingHost) + require.NoError(t, err) + + c := prepareConf(t) + + c.SafeBrowsing, err = filter.NewHashPrefix(&filter.HashPrefixConfig{ + Hashes: hashes, + ErrColl: errColl, + Resolver: resolver, + ID: agd.FilterListIDSafeBrowsing, + CachePath: tmpFile.Name(), + ReplacementHost: safeBrowsingSafeHost, + Staleness: 1 * time.Hour, + CacheTTL: 10 * time.Second, + CacheSize: 100, + }) + require.NoError(t, err) + + c.ErrColl = errColl + c.Resolver = resolver + s, err := filter.NewDefaultStorage(c) require.NoError(t, err) @@ -211,39 +218,11 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) { // parental protection from 11:00:00 until 12:59:59. nowTime := time.Date(2021, 1, 1, 12, 0, 0, 0, time.UTC) - // Initialize the hashes file and use it with the storage. - tmpFile, err := os.CreateTemp(t.TempDir(), "") - require.NoError(t, err) - testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) }) - - _, err = io.WriteString(tmpFile, safeBrowsingHost+"\n") - require.NoError(t, err) - - hashes, err := filter.NewHashStorage(&filter.HashStorageConfig{ - CachePath: tmpFile.Name(), - RefreshIvl: 1 * time.Hour, - }) - require.NoError(t, err) - - c := prepareConf(t) - - // Use AdultBlocking, because SafeBrowsing is NOT affected by the schedule. 
- c.AdultBlocking = &filter.HashPrefixConfig{ - Hashes: hashes, - ReplacementHost: safeBrowsingSafeHost, - CacheTTL: 10 * time.Second, - CacheSize: 100, - } - - c.Now = func() (t time.Time) { - return nowTime - } - - c.ErrColl = &agdtest.ErrorCollector{ + errColl := &agdtest.ErrorCollector{ OnCollect: func(_ context.Context, err error) { panic("not implemented") }, } - c.Resolver = &agdtest.Resolver{ + resolver := &agdtest.Resolver{ OnLookupIP: func( _ context.Context, _ netutil.AddrFamily, @@ -253,6 +232,40 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) { }, } + // Initialize the hashes file and use it with the storage. + tmpFile, err := os.CreateTemp(t.TempDir(), "") + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(tmpFile.Name()) }) + + _, err = io.WriteString(tmpFile, safeBrowsingHost+"\n") + require.NoError(t, err) + + hashes, err := hashstorage.New(safeBrowsingHost) + require.NoError(t, err) + + c := prepareConf(t) + + // Use AdultBlocking, because SafeBrowsing is NOT affected by the schedule. + c.AdultBlocking, err = filter.NewHashPrefix(&filter.HashPrefixConfig{ + Hashes: hashes, + ErrColl: errColl, + Resolver: resolver, + ID: agd.FilterListIDAdultBlocking, + CachePath: tmpFile.Name(), + ReplacementHost: safeBrowsingSafeHost, + Staleness: 1 * time.Hour, + CacheTTL: 10 * time.Second, + CacheSize: 100, + }) + require.NoError(t, err) + + c.Now = func() (t time.Time) { + return nowTime + } + + c.ErrColl = errColl + c.Resolver = resolver + s, err := filter.NewDefaultStorage(c) require.NoError(t, err) diff --git a/internal/geoip/file.go b/internal/geoip/file.go index e7845b8..7207558 100644 --- a/internal/geoip/file.go +++ b/internal/geoip/file.go @@ -20,9 +20,6 @@ import ( // FileConfig is the file-based GeoIP configuration structure. type FileConfig struct { - // ErrColl is the error collector that is used to report errors. 
- ErrColl agd.ErrorCollector - // ASNPath is the path to the GeoIP database of ASNs. ASNPath string @@ -39,8 +36,6 @@ type FileConfig struct { // File is a file implementation of [geoip.Interface]. type File struct { - errColl agd.ErrorCollector - // mu protects asn, country, country subnet maps, and caches against // simultaneous access during a refresh. mu *sync.RWMutex @@ -77,8 +72,6 @@ type asnSubnets map[agd.ASN]netip.Prefix // NewFile returns a new GeoIP database that reads information from a file. func NewFile(c *FileConfig) (f *File, err error) { f = &File{ - errColl: c.ErrColl, - mu: &sync.RWMutex{}, asnPath: c.ASNPath, diff --git a/internal/geoip/file_test.go b/internal/geoip/file_test.go index 3288b99..58c629c 100644 --- a/internal/geoip/file_test.go +++ b/internal/geoip/file_test.go @@ -1,12 +1,10 @@ package geoip_test import ( - "context" "net/netip" "testing" "github.com/AdguardTeam/AdGuardDNS/internal/agd" - "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" "github.com/AdguardTeam/AdGuardDNS/internal/geoip" "github.com/AdguardTeam/golibs/netutil" "github.com/stretchr/testify/assert" @@ -14,12 +12,7 @@ import ( ) func TestFile_Data(t *testing.T) { - var ec agd.ErrorCollector = &agdtest.ErrorCollector{ - OnCollect: func(ctx context.Context, err error) { panic("not implemented") }, - } - conf := &geoip.FileConfig{ - ErrColl: ec, ASNPath: asnPath, CountryPath: countryPath, HostCacheSize: 0, @@ -42,12 +35,7 @@ func TestFile_Data(t *testing.T) { } func TestFile_Data_hostCache(t *testing.T) { - var ec agd.ErrorCollector = &agdtest.ErrorCollector{ - OnCollect: func(ctx context.Context, err error) { panic("not implemented") }, - } - conf := &geoip.FileConfig{ - ErrColl: ec, ASNPath: asnPath, CountryPath: countryPath, HostCacheSize: 1, @@ -74,12 +62,7 @@ func TestFile_Data_hostCache(t *testing.T) { } func TestFile_SubnetByLocation(t *testing.T) { - var ec agd.ErrorCollector = &agdtest.ErrorCollector{ - OnCollect: func(ctx context.Context, err error) { 
panic("not implemented") }, - } - conf := &geoip.FileConfig{ - ErrColl: ec, ASNPath: asnPath, CountryPath: countryPath, HostCacheSize: 0, @@ -106,12 +89,7 @@ var locSink *agd.Location var errSink error func BenchmarkFile_Data(b *testing.B) { - var ec agd.ErrorCollector = &agdtest.ErrorCollector{ - OnCollect: func(ctx context.Context, err error) { panic("not implemented") }, - } - conf := &geoip.FileConfig{ - ErrColl: ec, ASNPath: asnPath, CountryPath: countryPath, HostCacheSize: 0, @@ -162,12 +140,7 @@ func BenchmarkFile_Data(b *testing.B) { var fileSink *geoip.File func BenchmarkNewFile(b *testing.B) { - var ec agd.ErrorCollector = &agdtest.ErrorCollector{ - OnCollect: func(ctx context.Context, err error) { panic("not implemented") }, - } - conf := &geoip.FileConfig{ - ErrColl: ec, ASNPath: asnPath, CountryPath: countryPath, HostCacheSize: 0, diff --git a/internal/metrics/dnssvc.go b/internal/metrics/dnssvc.go index 8b02047..1acf3c7 100644 --- a/internal/metrics/dnssvc.go +++ b/internal/metrics/dnssvc.go @@ -11,7 +11,7 @@ var DNSSvcRequestByCountryTotal = promauto.NewCounterVec(prometheus.CounterOpts{ Name: "request_per_country_total", Namespace: namespace, Subsystem: subsystemDNSSvc, - Help: "The number of filtered DNS requests labeled by country and continent.", + Help: "The number of processed DNS requests labeled by country and continent.", }, []string{"continent", "country"}) // DNSSvcRequestByASNTotal is a counter with the total number of queries @@ -20,13 +20,14 @@ var DNSSvcRequestByASNTotal = promauto.NewCounterVec(prometheus.CounterOpts{ Name: "request_per_asn_total", Namespace: namespace, Subsystem: subsystemDNSSvc, - Help: "The number of filtered DNS requests labeled by country and ASN.", + Help: "The number of processed DNS requests labeled by country and ASN.", }, []string{"country", "asn"}) // DNSSvcRequestByFilterTotal is a counter with the total number of queries -// processed labeled by filter. 
"filter" contains the ID of the filter list -// applied. "anonymous" is "0" if the request is from a AdGuard DNS customer, -// otherwise it is "1". +// processed labeled by filter. Processed could mean that the request was +// blocked or unblocked by a rule from that filter list. "filter" contains +// the ID of the filter list applied. "anonymous" is "0" if the request is +// from a AdGuard DNS customer, otherwise it is "1". var DNSSvcRequestByFilterTotal = promauto.NewCounterVec(prometheus.CounterOpts{ Name: "request_per_filter_total", Namespace: namespace, diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 0d163c9..5371034 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -26,6 +26,8 @@ const ( subsystemQueryLog = "querylog" subsystemRuleStat = "rulestat" subsystemTLS = "tls" + subsystemResearch = "research" + subsystemWebSvc = "websvc" ) // SetUpGauge signals that the server has been started. diff --git a/internal/metrics/research.go b/internal/metrics/research.go new file mode 100644 index 0000000..50646a6 --- /dev/null +++ b/internal/metrics/research.go @@ -0,0 +1,61 @@ +package metrics + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +// ResearchRequestsPerCountryTotal counts the total number of queries per +// country from anonymous users. +var ResearchRequestsPerCountryTotal = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "requests_per_country_total", + Namespace: namespace, + Subsystem: subsystemResearch, + Help: "The total number of DNS queries per country from anonymous users.", +}, []string{"country"}) + +// ResearchBlockedRequestsPerCountryTotal counts the number of blocked queries +// per country from anonymous users. 
+var ResearchBlockedRequestsPerCountryTotal = promauto.NewCounterVec(prometheus.CounterOpts{
+	Name:      "blocked_per_country_total",
+	Namespace: namespace,
+	Subsystem: subsystemResearch,
+	Help:      "The number of blocked DNS queries per country from anonymous users.",
+}, []string{"filter", "country"})
+
+// ReportResearchMetrics reports metrics to Prometheus that we may need to
+// conduct research.
+//
+// TODO(ameshkov): use [agd.Profile] arg when recursive dependency is resolved.
+func ReportResearchMetrics(
+	anonymous bool,
+	filteringEnabled bool,
+	asn string,
+	ctry string,
+	filterID string,
+	blocked bool,
+) {
+	// The current research metrics only count queries that come to public
+	// DNS servers where filtering is enabled.
+	if !filteringEnabled || !anonymous {
+		return
+	}
+
+	// Ignore AdGuard ASN specifically in order to avoid counting queries that
+	// come from the monitoring. This part is ugly, but since these metrics
+	// are a one-time deal, this is acceptable.
+	//
+	// TODO(ameshkov): think of a better way later if we need to do that again.
+	if asn == "212772" {
+		return
+	}
+
+	if blocked {
+		ResearchBlockedRequestsPerCountryTotal.WithLabelValues(
+			filterID,
+			ctry,
+		).Inc()
+	}
+
+	ResearchRequestsPerCountryTotal.WithLabelValues(ctry).Inc()
+}
diff --git a/internal/metrics/usercount.go b/internal/metrics/usercount.go
index b49d2b3..b06cf3f 100644
--- a/internal/metrics/usercount.go
+++ b/internal/metrics/usercount.go
@@ -86,7 +86,7 @@ func (c *userCounter) record(now time.Time, ip netip.Addr, syncUpdate bool) {
 		prevMinuteCounter := c.currentMinuteCounter
 
 		c.currentMinute = minuteOfTheDay
-		c.currentMinuteCounter = hyperloglog.New()
+		c.currentMinuteCounter = newHyperLogLog()
 
 		// If this is the first iteration and prevMinute is -1, don't update the
 		// counters, since there are none.
@@ -123,7 +123,7 @@ func (c *userCounter) updateCounters(prevMinute int, prevCounter *hyperloglog.Sk // estimate uses HyperLogLog counters to estimate the hourly and daily users // count, starting with the minute of the day m. func (c *userCounter) estimate(m int) (hourly, daily uint64) { - hourlyCounter, dailyCounter := hyperloglog.New(), hyperloglog.New() + hourlyCounter, dailyCounter := newHyperLogLog(), newHyperLogLog() // Go through all minutes in a day while decreasing the current minute m. // Decreasing m, as opposed to increasing it or using i as the minute, is @@ -168,6 +168,11 @@ func decrMod(n, m int) (res int) { return n - 1 } +// newHyperLogLog creates a new instance of hyperloglog.Sketch. +func newHyperLogLog() (sk *hyperloglog.Sketch) { + return hyperloglog.New16() +} + // defaultUserCounter is the main user statistics counter. var defaultUserCounter = newUserCounter() diff --git a/internal/metrics/websvc.go b/internal/metrics/websvc.go new file mode 100644 index 0000000..e324e18 --- /dev/null +++ b/internal/metrics/websvc.go @@ -0,0 +1,69 @@ +package metrics + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + webSvcRequestsTotal = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "websvc_requests_total", + Namespace: namespace, + Subsystem: subsystemWebSvc, + Help: "The number of DNS requests for websvc.", + }, []string{"kind"}) + + // WebSvcError404RequestsTotal is a counter with total number of + // requests with error 404. + WebSvcError404RequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{ + "kind": "error404", + }) + + // WebSvcError500RequestsTotal is a counter with total number of + // requests with error 500. + WebSvcError500RequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{ + "kind": "error500", + }) + + // WebSvcStaticContentRequestsTotal is a counter with total number of + // requests for static content. 
+ WebSvcStaticContentRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{ + "kind": "static_content", + }) + + // WebSvcDNSCheckTestRequestsTotal is a counter with total number of + // requests for dnscheck_test. + WebSvcDNSCheckTestRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{ + "kind": "dnscheck_test", + }) + + // WebSvcRobotsTxtRequestsTotal is a counter with total number of + // requests for robots_txt. + WebSvcRobotsTxtRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{ + "kind": "robots_txt", + }) + + // WebSvcRootRedirectRequestsTotal is a counter with total number of + // root redirected requests. + WebSvcRootRedirectRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{ + "kind": "root_redirect", + }) + + // WebSvcLinkedIPProxyRequestsTotal is a counter with total number of + // requests with linked ip. + WebSvcLinkedIPProxyRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{ + "kind": "linkip", + }) + + // WebSvcAdultBlockingPageRequestsTotal is a counter with total number + // of requests for adult blocking page. + WebSvcAdultBlockingPageRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{ + "kind": "adult_blocking_page", + }) + + // WebSvcSafeBrowsingPageRequestsTotal is a counter with total number + // of requests for safe browsing page. + WebSvcSafeBrowsingPageRequestsTotal = webSvcRequestsTotal.With(prometheus.Labels{ + "kind": "safe_browsing_page", + }) +) diff --git a/internal/querylog/fs.go b/internal/querylog/fs.go index 55f8db1..1e4449d 100644 --- a/internal/querylog/fs.go +++ b/internal/querylog/fs.go @@ -13,6 +13,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/AdGuardDNS/internal/optlog" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/mathutil" ) // FileSystemConfig is the configuration of the file system query log. 
@@ -71,11 +72,6 @@ func (l *FileSystem) Write(_ context.Context, e *Entry) (err error) { metrics.QueryLogItemsCount.Inc() }() - var dnssec uint8 = 0 - if e.DNSSEC { - dnssec = 1 - } - entBuf := l.bufferPool.Get().(*entryBuffer) defer l.bufferPool.Put(entBuf) entBuf.buf.Reset() @@ -94,7 +90,7 @@ func (l *FileSystem) Write(_ context.Context, e *Entry) (err error) { ClientASN: e.ClientASN, Elapsed: e.Elapsed, RequestType: e.RequestType, - DNSSEC: dnssec, + DNSSEC: mathutil.BoolToNumber[uint8](e.DNSSEC), Protocol: e.Protocol, ResultCode: c, ResponseCode: e.ResponseCode, diff --git a/internal/websvc/handler.go b/internal/websvc/handler.go index 55c20d9..ed09a80 100644 --- a/internal/websvc/handler.go +++ b/internal/websvc/handler.go @@ -7,6 +7,7 @@ import ( "os" "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" + "github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/AdGuardDNS/internal/optlog" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" @@ -58,12 +59,16 @@ func (svc *Service) processRec( body = svc.error404 respHdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML) } + + metrics.WebSvcError404RequestsTotal.Inc() case http.StatusInternalServerError: action = "writing 500" if len(svc.error500) != 0 { body = svc.error500 respHdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML) } + + metrics.WebSvcError500RequestsTotal.Inc() default: action = "writing response" for k, v := range rec.Header() { @@ -87,6 +92,8 @@ func (svc *Service) serveHTTP(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/dnscheck/test": svc.dnsCheck.ServeHTTP(w, r) + + metrics.WebSvcDNSCheckTestRequestsTotal.Inc() case "/robots.txt": serveRobotsDisallow(w.Header(), w, "handler") case "/": @@ -94,6 +101,8 @@ func (svc *Service) serveHTTP(w http.ResponseWriter, r *http.Request) { http.NotFound(w, r) } else { http.Redirect(w, r, svc.rootRedirectURL, http.StatusFound) + + metrics.WebSvcRootRedirectRequestsTotal.Inc() } 
default: http.NotFound(w, r) @@ -101,7 +110,7 @@ func (svc *Service) serveHTTP(w http.ResponseWriter, r *http.Request) { } // safeBrowsingHandler returns an HTTP handler serving the block page from the -// blockPagePath. name is used for logging. +// blockPagePath. name is used for logging and metrics. func safeBrowsingHandler(name string, blockPage []byte) (h http.Handler) { f := func(w http.ResponseWriter, r *http.Request) { hdr := w.Header() @@ -122,6 +131,15 @@ func safeBrowsingHandler(name string, blockPage []byte) (h http.Handler) { if err != nil { logErrorByType(err, "websvc: %s: writing response: %s", name, err) } + + switch name { + case adultBlockingName: + metrics.WebSvcAdultBlockingPageRequestsTotal.Inc() + case safeBrowsingName: + metrics.WebSvcSafeBrowsingPageRequestsTotal.Inc() + default: + // Go on. + } } } @@ -140,10 +158,11 @@ func (sc StaticContent) serveHTTP(w http.ResponseWriter, r *http.Request) (serve return false } - if f.AllowOrigin != "" { - w.Header().Set(agdhttp.HdrNameAccessControlAllowOrigin, f.AllowOrigin) + h := w.Header() + for k, v := range f.Headers { + h[k] = v } - w.Header().Set(agdhttp.HdrNameContentType, f.ContentType) + w.WriteHeader(http.StatusOK) _, err := w.Write(f.Content) @@ -151,16 +170,15 @@ func (sc StaticContent) serveHTTP(w http.ResponseWriter, r *http.Request) (serve logErrorByType(err, "websvc: static content: writing %s: %s", p, err) } + metrics.WebSvcStaticContentRequestsTotal.Inc() + return true } // StaticFile is a single file in a StaticFS. type StaticFile struct { - // AllowOrigin is the value for the HTTP Access-Control-Allow-Origin header. - AllowOrigin string - - // ContentType is the value for the HTTP Content-Type header. - ContentType string + // Headers contains headers of the HTTP response. + Headers http.Header // Content is the file content. 
Content []byte @@ -174,6 +192,8 @@ func serveRobotsDisallow(hdr http.Header, w http.ResponseWriter, name string) { if err != nil { logErrorByType(err, "websvc: %s: writing response: %s", name, err) } + + metrics.WebSvcRobotsTxtRequestsTotal.Inc() } // logErrorByType writes err to the error log, unless err is a network error or diff --git a/internal/websvc/handler_test.go b/internal/websvc/handler_test.go index 128fd5f..24afecf 100644 --- a/internal/websvc/handler_test.go +++ b/internal/websvc/handler_test.go @@ -31,8 +31,10 @@ func TestService_ServeHTTP(t *testing.T) { staticContent := map[string]*websvc.StaticFile{ "/favicon.ico": { - ContentType: "image/x-icon", - Content: []byte{}, + Content: []byte{}, + Headers: http.Header{ + agdhttp.HdrNameContentType: []string{"image/x-icon"}, + }, }, } @@ -56,22 +58,33 @@ func TestService_ServeHTTP(t *testing.T) { }) // DNSCheck path. - assertPathResponse(t, svc, "/dnscheck/test", http.StatusOK) + assertResponse(t, svc, "/dnscheck/test", http.StatusOK) - // Static content path. - assertPathResponse(t, svc, "/favicon.ico", http.StatusOK) + // Static content path with headers. + h := http.Header{ + agdhttp.HdrNameContentType: []string{"image/x-icon"}, + agdhttp.HdrNameServer: []string{"AdGuardDNS/"}, + } + assertResponseWithHeaders(t, svc, "/favicon.ico", http.StatusOK, h) // Robots path. - assertPathResponse(t, svc, "/robots.txt", http.StatusOK) + assertResponse(t, svc, "/robots.txt", http.StatusOK) // Root redirect path. - assertPathResponse(t, svc, "/", http.StatusFound) + assertResponse(t, svc, "/", http.StatusFound) // Other path. - assertPathResponse(t, svc, "/other", http.StatusNotFound) + assertResponse(t, svc, "/other", http.StatusNotFound) } -func assertPathResponse(t *testing.T, svc *websvc.Service, path string, statusCode int) { +// assertResponse is a helper function that checks status code of HTTP +// response. 
+func assertResponse( + t *testing.T, + svc *websvc.Service, + path string, + statusCode int, +) (rw *httptest.ResponseRecorder) { t.Helper() r := httptest.NewRequest(http.MethodGet, (&url.URL{ @@ -79,9 +92,27 @@ func assertPathResponse(t *testing.T, svc *websvc.Service, path string, statusCo Host: "127.0.0.1", Path: path, }).String(), strings.NewReader("")) - rw := httptest.NewRecorder() + rw = httptest.NewRecorder() svc.ServeHTTP(rw, r) assert.Equal(t, statusCode, rw.Code) assert.Equal(t, agdhttp.UserAgent(), rw.Header().Get(agdhttp.HdrNameServer)) + + return rw +} + +// assertResponseWithHeaders is a helper function that checks status code and +// headers of HTTP response. +func assertResponseWithHeaders( + t *testing.T, + svc *websvc.Service, + path string, + statusCode int, + header http.Header, +) { + t.Helper() + + rw := assertResponse(t, svc, path, statusCode) + + assert.Equal(t, header, rw.Header()) } diff --git a/internal/websvc/linkip.go b/internal/websvc/linkip.go index e7d7a58..90f1668 100644 --- a/internal/websvc/linkip.go +++ b/internal/websvc/linkip.go @@ -11,6 +11,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" + "github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/AdGuardDNS/internal/optlog" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" @@ -148,6 +149,8 @@ func (prx *linkedIPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.Debug("%s: proxying %s %s: req %s", prx.logPrefix, m, p, reqID) prx.httpProxy.ServeHTTP(w, r) + + metrics.WebSvcLinkedIPProxyRequestsTotal.Inc() } else if r.URL.Path == "/robots.txt" { serveRobotsDisallow(respHdr, w, prx.logPrefix) } else { diff --git a/internal/websvc/websvc.go b/internal/websvc/websvc.go index d812bad..7a1f4e1 100644 --- a/internal/websvc/websvc.go +++ b/internal/websvc/websvc.go @@ -118,8 +118,8 @@ func New(c *Config) (svc *Service) { error404: c.Error404, error500: 
c.Error500, - adultBlocking: blockPageServers(c.AdultBlocking, "adult blocking", c.Timeout), - safeBrowsing: blockPageServers(c.SafeBrowsing, "safe browsing", c.Timeout), + adultBlocking: blockPageServers(c.AdultBlocking, adultBlockingName, c.Timeout), + safeBrowsing: blockPageServers(c.SafeBrowsing, safeBrowsingName, c.Timeout), } if c.RootRedirectURL != nil { @@ -164,6 +164,12 @@ func New(c *Config) (svc *Service) { return svc } +// Names for safeBrowsingHandler for logging and metrics. +const ( + safeBrowsingName = "safe browsing" + adultBlockingName = "adult blocking" +) + // blockPageServers is a helper function that converts a *BlockPageServer into // HTTP servers. func blockPageServers( diff --git a/scripts/make/go-lint.sh b/scripts/make/go-lint.sh index 5eb2c89..efc84a8 100644 --- a/scripts/make/go-lint.sh +++ b/scripts/make/go-lint.sh @@ -117,7 +117,9 @@ underscores() { git ls-files '*_*.go'\ | grep -F\ -e '_generate.go'\ + -e '_linux.go'\ -e '_noreuseport.go'\ + -e '_others.go'\ -e '_reuseport.go'\ -e '_test.go'\ -e '_unix.go'\