diff --git a/.gitignore b/.gitignore index a52cf5c..57dd421 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,7 @@ /github-mirror/ AdGuardDNS asn.mmdb -config.yml +config.yaml country.mmdb dnsdb.bolt querylog.jsonl diff --git a/CHANGELOG.md b/CHANGELOG.md index a222fcf..982a33f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,117 @@ The format is **not** based on [Keep a Changelog][kec], since the project +## AGDNS-1498 / Build 527 + + * Object `ratelimit` has a new property, `connection_limit`, which allows + setting stream-connection limits. Example configuration: + + ```yaml + ratelimit: + # … + connection_limit: + enabled: true + stop: 1000 + resume: 800 + ``` + + + +## AGDNS-1383 / Build 525 + + * The environment variable `PROFILES_CACHE_PATH` is now sensitive to the file + extension. Use `.json` for the previous behavior of encoding the cache into + a JSON file or `.pb` for encoding it into protobuf. Other extensions are + invalid. + + + +## AGDNS-1381 / Build 518 + + * The new object `network` has been added: + + ```yaml + network: + so_sndbuf: 0 + so_rcvbuf: 0 + ``` + + + +## AGDNS-1383 / Build 515 + + * The environment variable `PROFILES_CACHE_PATH` now has a new special value, + `none`, which disables profile caching entirely. The default value of + `./profilecache.json` has not been changed. + + + +## AGDNS-1479 / Build 513 + + * The profile-cache version has been changed to `6`. Versions of the profile + cache from `3` to `5` are invalid and should not be reused. + + + +## AGDNS-1473 / Build 506 + + * The profile-cache version has been changed to `5`. + + + +## AGDNS-1247 / Build 484 + + * The new object `interface_listeners` has been added: + + ```yaml + interface_listeners: + channel_buffer_size: 1000 + list: + eth0_plain_dns: + interface: 'eth0' + port': 53 + eth0_plain_dns_secondary: + interface: 'eth0' + port': 5353 + ``` + + * The objects within the `server_groups.*.servers` array have a new optional + property, `bind_interfaces`: + + ```yaml + server_groups: + - + # … + servers: + - name: 'default_dns' + # … + bind_interfaces: + - id: 'eth0_plain_dns' + subnet: '127.0.0.0/8' + - id: 'eth0_plain_dns_secondary' + subnet: '127.0.0.0/8' + ``` + + It is mutually exclusive with the current `bind_addresses` field. + + + +## AGDNS-1406 / Build 480 + + * The default behavior of the environment variable `DNSDB_PATH` has been + changed. Previously, if the variable was unset then the default value, + `./dnsdb.bolt`, was used, but if it was an empty string, DNSDB was disabled. + Now both unset and empty value disable DNSDB, which is consistent with the + documentation. + + This means that DNSDB is disabled by default. + + * The default configuration file path has been changed from `config.yml` to + ./config.yaml for consistency with other + services. + + + ## AGDNS-916 / Build 456 * `ratelimit` now defines rate of requests per second for IPv4 and IPv6 @@ -181,7 +292,7 @@ The format is **not** based on [Keep a Changelog][kec], since the project ## AGDNS-842 / Build 372 - * The new environment variable `PROFILES_CACHE_PATH` has been added. Its + * The new environment variable `PROFILES_CACHE_PATH` has been added. Its default value is `./profilecache.json`. Adjust the value, if necessary. @@ -189,7 +300,7 @@ The format is **not** based on [Keep a Changelog][kec], since the project ## AGDNS-891 / Build 371 * The property `server` of `upstream` object has been changed. Now it - is a URL optionally starting with `tcp://` or `udp://`, and then an address + is a URL optionally starting with `tcp://` or `udp://`, and then an address in `ip:port` format. ```yaml diff --git a/HACKING.md b/HACKING.md index 1ce3bde..4e1b89b 100644 --- a/HACKING.md +++ b/HACKING.md @@ -1 +1 @@ -See Adguard Home [`HACKING.md`](https://github.com/AdguardTeam/AdGuardHome/blob/master/HACKING.md). \ No newline at end of file +See the [Adguard Code Guidelines](https://github.com/AdguardTeam/CodeGuidelines/). diff --git a/Makefile b/Makefile index 78e925d..736b97d 100644 --- a/Makefile +++ b/Makefile @@ -59,6 +59,9 @@ go-gen: cd ./internal/agd/ && "$(GO.MACRO)" run ./country_generate.go cd ./internal/geoip/ && "$(GO.MACRO)" run ./asntops_generate.go + cd ./internal/profiledb/internal/filecachepb/ &&\ + protoc --go_opt=paths=source_relative --go_out=. ./filecache.proto + go-check: go-tools go-lint go-test # A quick check to make sure that all operating systems relevant to the diff --git a/README.md b/README.md index 78bbcaf..358f379 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ following features: you need that. [rules system]: https://adguard-dns.io/kb/general/dns-filtering-syntax/ -[API]: https://adguard-dns.io/kb/private-dns/api/ +[API]: https://adguard-dns.io/kb/private-dns/api/overview/ diff --git a/config.dist.yml b/config.dist.yaml similarity index 89% rename from config.dist.yml rename to config.dist.yaml index 02d83fb..10bc23d 100644 --- a/config.dist.yml +++ b/config.dist.yaml @@ -43,6 +43,15 @@ ratelimit: # Time between two updates of allow list. refresh_interval: 1h + # Configuration for the stream connection limiting. + connection_limit: + enabled: true + # The point at which the limiter stops accepting new connections. Once + # the number of active connections reaches this limit, new connections + # wait for the number to decrease below resume. + stop: 1000 + resume: 800 + # DNS cache configuration. cache: # The type of cache to use. Can be 'simple' (a simple LRU cache) or 'ecs' @@ -257,6 +266,21 @@ filtering_groups: block_private_relay: false block_firefox_canary: true +# The configuration for the device-listening feature. Works only on Linux with +# SO_BINDTODEVICE support. +interface_listeners: + # The size of the buffers of the channels used to dispatch TCP connections + # and UDP sessions. + channel_buffer_size: 1000 + # List is the mapping of interface-listener IDs to their configuration. + list: + 'eth0_plain_dns': + interface: 'eth0' + port: 53 + 'eth0_plain_dns_secondary': + interface: 'eth0' + port: 5353 + # Server groups and servers. server_groups: - name: 'adguard_dns_default' @@ -302,8 +326,13 @@ server_groups: # See README for the list of protocol values. protocol: 'dns' linked_ip_enabled: true - bind_addresses: - - '127.0.0.1:53' + # Either bind_interfaces or bind_addresses (see below) can be used for + # the plain-DNS servers. + bind_interfaces: + - id: 'eth0_plain_dns' + subnet: '127.0.0.0/8' + - id: 'eth0_plain_dns_secondary' + subnet: '127.0.0.0/8' - name: 'default_dot' protocol: 'tls' linked_ip_enabled: false @@ -351,3 +380,12 @@ connectivity_check: # Additional information to be exposed through metrics. additional_metrics_info: test_key: 'test_value' + +# Network settings. +network: + # Defines the size of socket send buffer in bytes. Default is zero (uses + # system settings). + so_sndbuf: 0 + # Defines the size of socket receive buffer in bytes. Default is zero + # (uses system settings). + so_rcvbuf: 0 diff --git a/doc/configuration.md b/doc/configuration.md index 49fdc69..50e631c 100644 --- a/doc/configuration.md +++ b/doc/configuration.md @@ -6,9 +6,12 @@ configuration file with comments. ## Contents - * [Recommended values](#recommended) + * [Recommended values and notes](#recommended) * [Result cache sizes](#recommended-result_cache) + * [`SO_RCVBUF` and `SO_SNDBUF` on Linux](#recommended-buffers) + * [Connection limiter](#recommended-connection_limit) * [Rate limiting](#ratelimit) + * [Stream connection limit](#ratelimit-connection_limit) * [Cache](#cache) * [Upstream](#upstream) * [Healthcheck](#upstream-healthcheck) @@ -21,11 +24,13 @@ configuration file with comments. * [Adult-content blocking](#adult_blocking) * [Filters](#filters) * [Filtering groups](#filtering_groups) + * [Network interface listeners](#interface_listeners) * [Server groups](#server_groups) * [TLS](#server_groups-*-tls) * [DDR](#server_groups-*-ddr) * [Servers](#server_groups-*-servers-*) * [Connectivity check](#connectivity-check) + * [Network settings](#network) * [Additional metrics information](#additional_metrics_info) [dist]: ../config.dist.yml @@ -34,7 +39,7 @@ configuration file with comments. -## Recommended values +## Recommended values and notes ### Result cache sizes @@ -59,6 +64,55 @@ from answers, you'll need to multiply the value from the statistic by 5 or 6. + ### `SO_RCVBUF` and `SO_SNDBUF` on Linux + +On Linux OSs the values for these socket options coming from the configuration +file (parameters [`network.so_rcvbuf`](#network-so_rcvbuf) and +[`network.so_sndbuf`](#network-so_sndbuf)) is doubled, and the maximum and +minimum values are controlled by the values in `/proc/`. See `man 7 socket`: + + > `SO_RCVBUF` + > + > \[…\] The kernel doubles this value (to allow space for bookkeeping + > overhead) when it is set using setsockopt(2), and this doubled value is + > returned by getsockopt(2). The default value is set by the + > `/proc/sys/net/core/rmem_default` file, and the maximum allowed value is set + > by the `/proc/sys/net/core/rmem_max` file. The minimum (doubled) value for + > this option is `256`. + > + > \[…\] + > + > `SO_SNDBUF` + > + > \[…\] The default value is set by the `/proc/sys/net/core/wmem_default` + > file, and the maximum allowed value is set by the + > `/proc/sys/net/core/wmem_max` file. The minimum (doubled) value for this + > option is `2048`. + + + + ### Stream connection limit + +Currently, there are the following recommendations for parameters +[`ratelimit.connection_limit.stop`](#ratelimit-connection_limit-stop) and +[`ratelimit.connection_limit.resume`](#ratelimit-connection_limit-resume): + + * `stop` should be about 25 % above the current maximum daily number of used + TCP sockets. That is, if the instance currently has a maximum of 100 000 + TCP sockets in use every day, `stop` should be set to about `125000`. + + * `resume` should be about 20 % above the current maximum daily number of used + TCP sockets. That is, if the instance currently has a maximum of 100 000 + TCP sockets in use every day, `resume` should be set to about `120000`. + +**NOTE:** The number of active stream-connections includes sockets that are +in the process of accepting new connections but have not yet accepted one. That +means that `resume` should be greater than the number of bound addresses. + +These recommendations are to be revised based on the metrics. + + + ## Rate limiting The `ratelimit` object has the following properties: @@ -138,6 +192,30 @@ by `ipv4-subnet_key_len`) that made 15 requests in one second or 6 requests (one above `rps`) every second for 10 seconds within one minute, the client is blocked for `back_off_duration`. + ### Stream connection limit + +The `connection_limit` object has the following properties: + + * `enabled`: + Whether or not the stream-connection limit should be enforced. + + **Example:** `true`. + + * `stop`: + The point at which the limiter stops accepting new connections. Once the + number of active connections reaches this limit, new connections wait for + the number to decrease to or below `resume`. + + **Example:** `1000`. + + * `resume`: + The point at which the limiter starts accepting new connections again after + reaching `stop`. + + **Example:** `800`. + +See also [notes on these parameters](#recommended-connection_limit). + [env-consul_allowlist_url]: environment.md#CONSUL_ALLOWLIST_URL @@ -660,6 +738,39 @@ The items of the `filtering_groups` array have the following properties: +## Network interface listeners + +**NOTE:** The network interface listening works only on Linux with +`SO_BINDTODEVICE` support (2.0.30 and later) and properly setup IP routes. See +the [section on testing `SO_BINDTODEVICE` using Docker][dev-btd]. + +The `interface_listeners` object has the following properties: + + * `channel_buffer_size`: + The size of the buffers of the channels used to dispatch TCP connections and + UDP sessions. + + **Example:** `1000`. + + * `list`: + The mapping of interface-listener IDs to their configuration. + + **Property example:** + + ```yaml + list: + 'eth0_plain_dns': + interface: 'eth0' + port: 53 + 'eth0_plain_dns_secondary': + interface: 'eth0' + port: 5353 + ``` + +[dev-btd]: development.md#testing-bindtodevice + + + ## Server groups The items of the `server_groups` array have the following properties: @@ -829,10 +940,25 @@ The items of the `servers` array have the following properties: **Example:** `true`. * `bind_addresses`: - The array of `ip:port` addresses to listen on. + The array of `ip:port` addresses to listen on. If `bind_addresses` is set, + `bind_interfaces` (see below) should not be set. **Example:** `[127.0.0.1:53, 192.168.1.1:53]`. + * `bind_interfaces`: + The array of [interface listener](#ifl-list) data. If `bind_interfaces` is + set, `bind_addresses` (see above) should not be set. + + **Property example:** + + ```yaml + 'bind_interfaces': + - 'id': eth0_plain_dns' + 'subnet': '172.17.0.0/16' + - 'id': eth0_plain_dns_secondary' + 'subnet': '172.17.0.0/16' + ``` + * `dnscrypt`: The optional DNSCrypt configuration object. It has the following properties: @@ -886,6 +1012,28 @@ The `connectivity_check` object has the following properties: +## Network settings + +The `network` object has the following properties: + + * `so_rcvbuf`: + The size of socket receive buffer (`SO_RCVBUF`), in bytes. Default is zero, + which means use the default system settings. + + See also [notes on these parameters](#recommended-buffers). + + **Example:** `1048576`. + + * `so_sndbuf`: + The size of socket send buffer (`SO_SNDBUF`), in bytes. Default is zero, + which means use the default system settings. + + See also [notes on these parameters](#recommended-buffers). + + **Example:** `1048576`. + + + ## Additional metrics information The `additional_metrics_info` object is a map of strings with extra information diff --git a/doc/debugdns.md b/doc/debugdns.md index a2d63ef..16e64fc 100644 --- a/doc/debugdns.md +++ b/doc/debugdns.md @@ -87,6 +87,17 @@ In the `ADDITIONAL SECTION`, the following debug information is returned: ```none asn.adguard-dns.com. 10 CH TXT "1234" ``` + + * `subdivision`: + User's location subdivision code. This field could be empty even if user's + country code is present. The full name is `subdivision.adguard-dns.com`. + + **Example:** + + ```none + country.adguard-dns.com. 10 CH TXT "US" + subdivision.adguard-dns.com. 10 CH TXT "CA" + ``` The following debug records can have one of two prefixes: `req` or `resp`. The prefix depends on whether the filtering was applied to the request or the diff --git a/doc/development.md b/doc/development.md index c31a5e4..b925b2d 100644 --- a/doc/development.md +++ b/doc/development.md @@ -68,10 +68,26 @@ This is not an extensive list. See `../Makefile`.
make go-gen
- Regenerate the automatically generated Go files. Those generated files - are ../internal/agd/country_generate.go and - ../internal/geoip/asntops_generate.go. They need to be - periodically updated. +

+ Regenerate the automatically generated Go files that need to be + periodically updated. Those generated files are: +

+ +

+ You'll need to + install protoc + for the last one. +

make go-lint
@@ -158,7 +174,7 @@ dnscrypt generate -p testdns -o ./dnscrypt.yml ```sh cd ../ -cp -f config.dist.yml config.yml +cp -f config.dist.yaml config.yaml ``` @@ -190,6 +206,7 @@ We'll use the test versions of the GeoIP databases here. rm -f -r ./test/cache/ mkdir ./test/cache curl 'https://raw.githubusercontent.com/maxmind/MaxMind-DB/main/test-data/GeoIP2-Country-Test.mmdb' -o ./test/GeoIP2-Country-Test.mmdb +curl 'https://raw.githubusercontent.com/maxmind/MaxMind-DB/main/test-data/GeoIP2-City-Test.mmdb' -o ./test/GeoIP2-City-Test.mmdb curl 'https://raw.githubusercontent.com/maxmind/MaxMind-DB/main/test-data/GeoLite2-ASN-Test.mmdb' -o ./test/GeoLite2-ASN-Test.mmdb ``` @@ -206,8 +223,8 @@ You'll need to supply the following: See the [external HTTP API documentation][externalhttp]. -You may need to change the listen ports in `config.yml` which are less than 1024 -to some other ports. Otherwise, `sudo` or `doas` is required to run +You may need to change the listen ports in `config.yaml` which are less than +1024 to some other ports. Otherwise, `sudo` or `doas` is required to run `AdGuardDNS`. Examples below are for the configuration with the following changes: @@ -224,14 +241,14 @@ env \ BACKEND_ENDPOINT='PUT BACKEND URL HERE' \ BLOCKED_SERVICE_INDEX_URL='https://atropnikov.github.io/HostlistsRegistry/assets/services.json'\ CONSUL_ALLOWLIST_URL='PUT CONSUL ALLOWLIST URL HERE' \ - CONFIG_PATH='./config.yml' \ + CONFIG_PATH='./config.yaml' \ DNSDB_PATH='./test/cache/dnsdb.bolt' \ FILTER_INDEX_URL='https://atropnikov.github.io/HostlistsRegistry/assets/filters.json' \ FILTER_CACHE_PATH='./test/cache' \ PROFILES_CACHE_PATH='./test/profilecache.json' \ GENERAL_SAFE_SEARCH_URL='https://adguardteam.github.io/HostlistsRegistry/assets/engines_safe_search.txt' \ GEOIP_ASN_PATH='./test/GeoLite2-ASN-Test.mmdb' \ - GEOIP_COUNTRY_PATH='./test/GeoIP2-Country-Test.mmdb' \ + GEOIP_COUNTRY_PATH='./test/GeoIP2-City-Test.mmdb' \ QUERYLOG_PATH='./test/cache/querylog.jsonl' \ LISTEN_ADDR='127.0.0.1' \ LISTEN_PORT='8081' \ diff --git a/doc/environment.md b/doc/environment.md index ffc39e3..f4d8828 100644 --- a/doc/environment.md +++ b/doc/environment.md @@ -59,7 +59,7 @@ requirements section][ext-blocked] on the expected format of the response. The path to the configuration file. -**Default:** `./config.yml`. +**Default:** `./config.yaml`. @@ -122,8 +122,29 @@ The path to the directory with the filter lists cache. ## `PROFILES_CACHE_PATH` -The path to the profile cache file. The profile cache is read on start and is -later updated on every [full refresh][conf-backend-full_refresh_interval]. +The path to the profile cache file: + + * `none` means that the profile caching is disabled. + + * A file with the extension `.pb` means that the profiles are cached in the + protobuf format. + + Use the following command to inspect the cache, assuming that the version is + correct: + + ```sh + protoc\ + --decode\ + profiledb.FileCache\ + ./internal/profiledb/internal/filecachepb/filecache.proto\ + < /path/to/profilecache.pb + ``` + + * A file with the extension `.json` means that the profiles are cached in the + JSON format. This format is **deprecated** and is not recommended. + +The profile cache is read on start and is later updated on every +[full refresh][conf-backend-full_refresh_interval]. **Default:** `./profilecache.json`. diff --git a/go.mod b/go.mod index f2a2282..6db877b 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.20 require ( github.com/AdguardTeam/AdGuardDNS/internal/dnsserver v0.100.0 - github.com/AdguardTeam/golibs v0.12.1 + github.com/AdguardTeam/golibs v0.13.2 github.com/AdguardTeam/urlfilter v0.16.1 github.com/ameshkov/dnscrypt/v2 v2.2.5 github.com/axiomhq/hyperloglog v0.0.0-20230201085229-3ddf4bad03dc @@ -19,13 +19,14 @@ require ( github.com/prometheus/client_golang v1.14.0 github.com/prometheus/client_model v0.3.0 github.com/prometheus/common v0.41.0 - github.com/quic-go/quic-go v0.33.0 + github.com/quic-go/quic-go v0.35.1 github.com/stretchr/testify v1.8.2 go.etcd.io/bbolt v1.3.7 - golang.org/x/exp v0.0.0-20230307190834-24139beb5833 + golang.org/x/exp v0.0.0-20230321023759-10a507213a29 golang.org/x/net v0.8.0 golang.org/x/sys v0.6.0 golang.org/x/time v0.3.0 + google.golang.org/protobuf v1.30.0 gopkg.in/yaml.v2 v2.4.0 ) @@ -47,13 +48,12 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/procfs v0.9.0 // indirect github.com/quic-go/qpack v0.4.0 // indirect - github.com/quic-go/qtls-go1-19 v0.2.1 // indirect - github.com/quic-go/qtls-go1-20 v0.1.1 // indirect + github.com/quic-go/qtls-go1-19 v0.3.2 // indirect + github.com/quic-go/qtls-go1-20 v0.2.2 // indirect golang.org/x/crypto v0.7.0 // indirect golang.org/x/mod v0.9.0 // indirect golang.org/x/text v0.8.0 // indirect golang.org/x/tools v0.7.0 // indirect - google.golang.org/protobuf v1.28.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 703fbdc..da53b2a 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw= -github.com/AdguardTeam/golibs v0.12.1 h1:bJfFzCnUCl+QsP6prUltM2Sjt0fTiDBPlxuAwfKP3g8= -github.com/AdguardTeam/golibs v0.12.1/go.mod h1:rIglKDHdLvFT1UbhumBLHO9S4cvWS9MEyT1njommI/Y= +github.com/AdguardTeam/golibs v0.13.2 h1:BPASsyQKmb+b8VnvsNOHp7bKfcZl9Z+Z2UhPjOiupSc= +github.com/AdguardTeam/golibs v0.13.2/go.mod h1:7ylQLv2Lqsc3UW3jHoITynYk6Y1tYtgEMkR09ppfsN8= github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU= github.com/AdguardTeam/urlfilter v0.16.1 h1:ZPi0rjqo8cQf2FVdzo6cqumNoHZx2KPXj2yZa1A5BBw= github.com/AdguardTeam/urlfilter v0.16.1/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI= @@ -91,12 +91,12 @@ github.com/prometheus/procfs v0.9.0 h1:wzCHvIvM5SxWqYvwgVL7yJY8Lz3PKn49KQtpgMYJf github.com/prometheus/procfs v0.9.0/go.mod h1:+pB4zwohETzFnmlpe6yd2lSc+0/46IYZRB/chUwxUZY= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/qtls-go1-19 v0.2.1 h1:aJcKNMkH5ASEJB9FXNeZCyTEIHU1J7MmHyz1Q1TSG1A= -github.com/quic-go/qtls-go1-19 v0.2.1/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= -github.com/quic-go/qtls-go1-20 v0.1.1 h1:KbChDlg82d3IHqaj2bn6GfKRj84Per2VGf5XV3wSwQk= -github.com/quic-go/qtls-go1-20 v0.1.1/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= -github.com/quic-go/quic-go v0.33.0 h1:ItNoTDN/Fm/zBlq769lLJc8ECe9gYaW40veHCCco7y0= -github.com/quic-go/quic-go v0.33.0/go.mod h1:YMuhaAV9/jIu0XclDXwZPAsP/2Kgr5yMYhe9oxhhOFA= +github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U= +github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= +github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E= +github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= +github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo= +github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/shirou/gopsutil/v3 v3.21.8 h1:nKct+uP0TV8DjjNiHanKf8SAuub+GNsbrOtM9Nl9biA= github.com/shirou/gopsutil/v3 v3.21.8/go.mod h1:YWp/H8Qs5fVmf17v7JNZzA0mPJ+mS2e9JdiUF9LlKzQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -120,8 +120,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= -golang.org/x/exp v0.0.0-20230307190834-24139beb5833 h1:SChBja7BCQewoTAU7IgvucQKMIXrEpFxNMs0spT3/5s= -golang.org/x/exp v0.0.0-20230307190834-24139beb5833/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug= +golang.org/x/exp v0.0.0-20230321023759-10a507213a29/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs= golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= @@ -170,8 +170,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= -google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/go.work.sum b/go.work.sum index 820e538..a688d3a 100644 --- a/go.work.sum +++ b/go.work.sum @@ -1,17 +1,27 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.37.0/go.mod h1:TS1dMSSfndXH133OKGwekG838Om/cQT0BUHV3HcBgoo= cloud.google.com/go v0.65.0 h1:Dg9iHVQfrhq82rUNu9ZxUDrJLaxFUe/HlCVaLyRruq8= cloud.google.com/go/bigquery v1.8.0 h1:PQcPefKFdaIzjQFbiyOgAqyx8q5djaE7x9Sqe712DPA= cloud.google.com/go/datastore v1.1.0 h1:/May9ojXjRkPBNVrq+oWLqmWCkr4OU5uRY29bu0mRyQ= cloud.google.com/go/pubsub v1.3.1 h1:ukjixP1wl0LpnZ6LWtZJ0mX5tBmjp1f8Sqer8Z2OMUU= cloud.google.com/go/storage v1.10.0 h1:STgFzyU5/8miMl0//zKh2aQeTyeaUH3WN9bSUiJ09bA= dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3 h1:hJiie5Bf3QucGRa4ymsAUOxyhYwGEz1xrsVk0P8erlw= +dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9 h1:VpgP7xuJadIUuKccphEpTJnWhS2jkQyMt6Y7pJCD7fY= dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0 h1:SPOUaucgtVls75mg+X7CXigS71EnsfVUK/2CgVrwqgw= +dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBrvjyP0v+ecvNYvCpyZgu5/xkfAUhi6wJj28eUfSU= dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412 h1:GvWw74lx5noHocd+f6HBMXK6DuggBB1dhVkuGZbv7qM= +dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c h1:ivON6cwHK1OH26MZyWDCnbTRZZf0IhNsENoNAKFS1g4= +dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999 h1:OR8VhtwhcAI3U48/rzBsVOuHi0zDPzYI1xASVcdSgR8= +git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= github.com/AdguardTeam/golibs v0.10.7/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw= github.com/AdguardTeam/gomitmproxy v0.2.0 h1:rvCOf17pd1/CnMyMQW891zrEiIQBpQ8cIGjKN9pinUU= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.2.0 h1:Rt8g24XnyGTyglgET/PRUNlrUeu9F5L+7FilkXfZgs0= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802 h1:1BDTz0u9nC3//pOCMdNH+CiXJVYJh5UQNCOBG7jbELc= github.com/CloudyKit/fastprinter v0.0.0-20200109182630-33d98a066a53 h1:sR+/8Yb4slttB4vD+b9btVEnWgL3Q00OBTzVT8B9C0c= @@ -22,39 +32,58 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafo github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2cI+DC1Mbh0TJxzA3RcLoMsFw+aXw7E= github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239 h1:kFOfPq6dUM1hTo4JG6LR5AXSUEsOjtdm0kw0FtQtMJA= +github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625 h1:ckJgFhFWywOx+YLEMIJsTb+NV6NexWICk5+AMSuz3ss= +github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23 h1:D21IyuvjDCshj1/qq+pCNd3VZOAEI9jy6Bi131YlXgI= +github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= github.com/census-instrumentation/opencensus-proto v0.2.1 h1:glEXhBS5PSLLv4IXzLA5yPRVX4bilULVyxxbrfOtDAk= github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/chzyer/logex v1.1.10 h1:Swpa1K6QvQznwJRcfTfQJmTE72DqScAa40E+fbHEXEE= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/logex v1.2.0/go.mod h1:9+9sk7u7pGNWYMkh0hdiL++6OeibzJccyQU4p4MedaY= +github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e h1:fY5BOSpyZCqRo5OhCuC+XN+r/bBCmeuuJtjz+bCNIf8= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/readline v1.5.0 h1:lSwwFrbNviGePhkewF1az4oLmcwqCZijQ2/Wi3BGHAI= github.com/chzyer/readline v1.5.0/go.mod h1:x22KAscuvRqlLoK9CsoYsmxoXZMMFVyOl86cAH8qUic= +github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1 h1:q763qf9huN11kDQavWsoZXJNW3xEE4JJyHa5Q25/sd8= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/chzyer/test v0.0.0-20210722231415-061457976a23/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f h1:WBZRG4aNOuI15bLRrCgN8fCq8E5Xuty6jGbmSNEvSsU= github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0 h1:sDMmm+q/3+BukdIpxwO365v/Rbspp2Nt5XntgQRXq8Q= github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d h1:t5Wuyh53qYyg9eqn4BbnlIT+vmhyww0TatL+zT3uWgI= +github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/creack/pty v1.1.9 h1:uDmaGzcdjhF4i/plgjmEsriH11Y0o7RKapEf/LDaM3w= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eknkc/amber v0.0.0-20171010120322-cdade1c07385 h1:clC1lXBpe2kTj2VHdaIu9ajZQe4kcEY9j0NsnDDBZ3o= github.com/envoyproxy/go-control-plane v0.9.4 h1:rEvIZUSZ3fx39WIi3JkQqQBitGwpELBIYWeBVh6wn+E= github.com/envoyproxy/protoc-gen-validate v0.1.0 h1:EQciDnbrYxy13PgWoY8AqoxGiPrpgBZ1R8UNe3ddc+A= github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= github.com/flosch/pongo2/v4 v4.0.2 h1:gv+5Pe3vaSVmiJvh/BZa82b7/00YUGm0PIyVVLop0Hw= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ= +github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= github.com/francoispqt/gojay v1.2.13 h1:d2m3sFjloqoIUQU3TsHBgj6qg/BVGlTBeHDUmyJnXKk= +github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5Ai1i3InKU= github.com/getsentry/sentry-go v0.13.0 h1:20dgTiUSfxRB/EhMPtxcL9ZEbM1ZdR+W/7f7NWD+xWo= github.com/getsentry/sentry-go v0.13.0/go.mod h1:EOsfu5ZdvKPfeHYV6pTVQnsjfp30+XA7//UooKNumH0= github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-gonic/gin v1.8.1 h1:4+fr/el88TOO3ewCmQr8cx/CtZ/umlIRIs5M4NTNjf8= github.com/gliderlabs/ssh v0.1.1 h1:j3L6gSLQalDETeEg/Jg0mGY0/y/N6zI2xX1978P0Uqw= +github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w= +github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1 h1:QbL/5oDUmRBzO9/Z7Seo6zf912W/a6Sr4Eu0G/3Jho0= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4 h1:WtGNWLvXpe6ZudgnXrq0barxBImvnnJoMEhXAzcbM0I= github.com/go-kit/kit v0.9.0 h1:wDJmvq38kDhkVxi50ni9ykkdUr1PKgqKOoi01fa0Mdk= @@ -68,42 +97,65 @@ github.com/go-playground/validator/v10 v10.11.1 h1:prmOlTVv+YjZjmRmNSF3VmspqJIxJ github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/goccy/go-json v0.9.11 h1:/pAaQDLHEoCq/5FFmSKBswWmK6H0e8g4159Kc/X/nqk= github.com/gogo/protobuf v1.1.1 h1:72R+M5VuhED/KujmZVcIquuo8mBgX4oVda//DQb3PXo= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= github.com/golang/lint v0.0.0-20180702182130-06c8688daad7 h1:2hRPrmiwPrp3fQX967rNJIhQPtiGXdlQWAxKbKw3VHA= +github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY= +github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk= +github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw= github.com/google/martian v2.1.0+incompatible h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no= +github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0 h1:pMen7vLs8nvgEYhywH3KDWJIJTeEr2ULsVWHWYHQyBs= +github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99 h1:Ak8CrdlwwXwAZxzS66vgPt4U8yUZX7JwLvVR58FN5jM= +github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/googleapis/gax-go v2.0.0+incompatible h1:j0GKcs05QVmm7yesiZq2+9cxHkNK9YM6zKx4D2qucQU= +github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= +github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= github.com/googleapis/gax-go/v2 v2.0.5 h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7 h1:pdN6V1QBWetyv/0+wjACpqVH+eVULgEjkurDLq3goeM= +github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/grpc-ecosystem/grpc-gateway v1.5.0 h1:WcmKMm43DR7RdtlkEXQJyo5ws8iTp98CyhCCbOHMvNI= +github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= github.com/hashicorp/golang-lru v0.5.1 h1:0hERBMJE1eitiLkihrMvRVBYAkpHzc/J3QdDN+dAcgU= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6 h1:UDMh68UUwekSh5iP2OMhRRZJiiBccgV7axzUG8vi56c= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639 h1:mV02weKRL81bEnm8A0HT1/CAelMQDBuQIfLw8n+d6xI= +github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20220319035150-800ac71e25c2 h1:rcanfLhLDA8nozr/K289V1zcntHr3V+SHlXwzz1ZI2g= github.com/ianlancetaylor/demangle v0.0.0-20220319035150-800ac71e25c2/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= +github.com/ianlancetaylor/demangle v0.0.0-20220517205856-0058ec4f073c/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= github.com/influxdata/influxdb v1.7.6 h1:8mQ7A/V+3noMGCt/P9pD09ISaiz9XvgCk303UYA3gcs= github.com/iris-contrib/jade v1.1.4 h1:WoYdfyJFfZIUgqNAeOyRfTNQZOksSlZ6+FnXR3AEpX0= github.com/iris-contrib/schema v0.0.6 h1:CPSBLyx2e91H2yJzPuhGuifVRnZBBJ3pCOMbOvPZaTw= github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1 h1:ujPKutqRlJtcfWk6toYVYagwra7HQHbXOaS171b4Tg8= +github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGARJA= github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1 h1:6QPYqodiu3GuPL+7mfx+NwDdp2eTkp9IfEUpgAwUN0o= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= github.com/kataras/blocks v0.0.7 h1:cF3RDY/vxnSRezc7vLFlQFTYXG/yAr1o7WImJuZbzC4= @@ -113,20 +165,25 @@ github.com/kataras/pio v0.0.11 h1:kqreJ5KOEXGMwHAWHDwIl+mjfNCPhAwZPa8gK7MKlyw= github.com/kataras/sitemap v0.0.6 h1:w71CRMMKYMJh6LR2wTgnk5hSgjVNB9KL60n5e2KHvLY= github.com/kataras/tunnel v0.0.4 h1:sCAqWuJV7nPzGrlb0os3j49lk2JhILT0rID38NHNLpA= github.com/kisielk/gotool v1.0.0 h1:AV2c/EiW3KqPNT9ZKl07ehoAGi4C5/01Cfbblndcapg= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.15.11 h1:Lcadnb3RKGin4FYM/orgq0qde+nc15E5Cbqg4B9Sx9c= github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515 h1:T+h1c/A9Gawja4Y9mFVWj2vyii2bbUNDw3kt9VxK2EY= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1 h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw= github.com/kr/pty v1.1.3 h1:/Um6a/ZmD5tF7peoOJ5oN5KMQ0DrGVQSXLNwyckutPk= +github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/labstack/echo/v4 v4.9.0 h1:wPOF1CE6gvt/kmbMR4dGzWvHMPT+sAEUJOwOTtvITVY= github.com/labstack/gommon v0.3.1 h1:OomWaJXm7xR6L1HmEtGyQf26TEn7V6X88mktX9kee9o= github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/lucas-clemente/quic-go v0.25.0/go.mod h1:YtzP8bxRVCBlO77yRanE264+fY/T2U9ZlW1AaHOsMOg= github.com/lucas-clemente/quic-go v0.27.1/go.mod h1:AzgQoPda7N+3IqMMMkywBKggIFo2KT6pfnlrQ2QieeI= github.com/lunixbochs/vtclean v1.0.0 h1:xu2sLAri4lGiovBDQKxl5mrXyESr3gUr5m5SM5+LVb8= +github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/mailgun/raymond/v2 v2.0.46 h1:aOYHhvTpF5USySJ0o7cpPno/Uh2I5qg2115K25A+Ft4= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe h1:W/GaMY0y69G4cFlmsC6B9sbuo2fP8OFP1ABjt4kPz+w= +github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/marten-seemann/qtls-go1-15 v0.1.4/go.mod h1:GyFwywLKkRt+6mfU99csTEY1joMZz5vmB1WNZH3P81I= github.com/marten-seemann/qtls-go1-16 v0.1.4/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk= @@ -137,54 +194,98 @@ github.com/marten-seemann/qtls-go1-18 v0.1.0/go.mod h1:PUhIQk19LoFt2174H4+an8TYv github.com/marten-seemann/qtls-go1-18 v0.1.1/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/microcosm-cc/bluemonday v1.0.1 h1:SIYunPjnlXcW+gVfvm0IlSeR5U3WZUOLfVmqg85Go44= +github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= github.com/microcosm-cc/bluemonday v1.0.21 h1:dNH3e4PSyE4vNX+KlRGHT5KrSvjeUkoNPwEORjffHJg= github.com/miekg/dns v1.1.47/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f h1:KUppIJq7/+SVif2QVs3tOP0zanoHgBEVAwHxUSIzRqU= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86 h1:D6paGObi5Wud7xg83MaEFyjxQB1W5bz5d0IFppr+ymk= +github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab h1:eFXv9Nu1lGbrNbj619aWwZfVF5HBrm9Plte8aNptuTI= +github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= +github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= +github.com/onsi/ginkgo/v2 v2.8.1/go.mod h1:N1/NbDngAFcSLdyZ+/aYTYGSlq9qMCS/cNKGJjy+csc= +github.com/onsi/gomega v1.20.1/go.mod h1:DtrZpjmvpn2mPm4YWQa0/ALMDj9v4YxLgojwPeREyVo= github.com/onsi/gomega v1.22.1/go.mod h1:x6n7VNe4hw0vkyYUM4mjIXx3JbLiPaBPNgB7PRQ1tuM= github.com/onsi/gomega v1.24.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2vQAg= +github.com/onsi/gomega v1.27.1/go.mod h1:aHX5xOykVYzWOV4WqQy0sy8BQptgukenXpCXfadcIAw= github.com/openzipkin/zipkin-go v0.1.1 h1:A/ADD6HaPnAKj3yS7HjGHRK77qi41Hi0DirOOIQAeIw= +github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= github.com/pelletier/go-toml/v2 v2.0.5 h1:ipoSadvV8oGUjnUbMub59IDPPwfxF694nG/jwbMiyQg= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= +github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/rogpeppe/go-internal v1.3.0 h1:RR9dF3JtopPvtkroDZuVD7qquD0bnHlKSqaQhgwt8yk= github.com/russross/blackfriday v1.5.2 h1:HyvC0ARfnZBqnXwABFeSZHpKvJHJJfPz81GNueLj0oo= +github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/schollz/closestmatch v2.1.0+incompatible h1:Uel2GXEpJqOWBrlyI+oY9LTiyyjYS17cCYRqP13/SHk= github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ= +github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4 h1:Fth6mevc5rX7glNLpbAMJnqKlfIkcTjZCSHEeqvKbcI= +github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48 h1:vabduItPAIz9px5iryD5peyx7O3Ya8TBThapgXim98o= +github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= github.com/shurcooL/github_flavored_markdown v0.0.0-20181002035957-2122de532470 h1:qb9IthCFBmROJ6YBS31BEMeSYjOscSiG+EO+JVNTz64= +github.com/shurcooL/github_flavored_markdown v0.0.0-20181002035957-2122de532470/go.mod h1:2dOwnU2uBioM+SGy2aZoq1f/Sd1l9OkAeAUvjSyvgU0= github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e h1:MZM7FHLqUHYI0Y/mQAt3d2aYa0SiNms/hFqC9qJYolM= +github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041 h1:llrF3Fs4018ePo4+G/HV/uQUqEI1HMDjCeOf2V6puPc= +github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= github.com/shurcooL/gofontwoff v0.0.0-20180329035133-29b52fc0a18d h1:Yoy/IzG4lULT6qZg62sVC+qyBL8DQkmD2zv6i7OImrc= +github.com/shurcooL/gofontwoff v0.0.0-20180329035133-29b52fc0a18d/go.mod h1:05UtEgK5zq39gLST6uB0cf3NEHjETfB4Fgr3Gx5R9Vw= github.com/shurcooL/gopherjslib v0.0.0-20160914041154-feb6d3990c2c h1:UOk+nlt1BJtTcH15CT7iNO7YVWTfTv/DNwEAQHLIaDQ= +github.com/shurcooL/gopherjslib v0.0.0-20160914041154-feb6d3990c2c/go.mod h1:8d3azKNyqcHP1GaQE/c6dDgjkgSx2BZ4IoEi4F1reUI= github.com/shurcooL/highlight_diff v0.0.0-20170515013008-09bb4053de1b h1:vYEG87HxbU6dXj5npkeulCS96Dtz5xg3jcfCgpcvbIw= +github.com/shurcooL/highlight_diff v0.0.0-20170515013008-09bb4053de1b/go.mod h1:ZpfEhSmds4ytuByIcDnOLkTHGUI6KNqRNPDLHDk+mUU= github.com/shurcooL/highlight_go v0.0.0-20181028180052-98c3abbbae20 h1:7pDq9pAMCQgRohFmd25X8hIH8VxmT3TaDm+r9LHxgBk= +github.com/shurcooL/highlight_go v0.0.0-20181028180052-98c3abbbae20/go.mod h1:UDKB5a1T23gOMUJrI+uSuH0VRDStOiUVSjBTRDVBVag= github.com/shurcooL/home v0.0.0-20181020052607-80b7ffcb30f9 h1:MPblCbqA5+z6XARjScMfz1TqtJC7TuTRj0U9VqIBs6k= +github.com/shurcooL/home v0.0.0-20181020052607-80b7ffcb30f9/go.mod h1:+rgNQw2P9ARFAs37qieuu7ohDNQ3gds9msbT2yn85sg= github.com/shurcooL/htmlg v0.0.0-20170918183704-d01228ac9e50 h1:crYRwvwjdVh1biHzzciFHe8DrZcYrVcZFlJtykhRctg= +github.com/shurcooL/htmlg v0.0.0-20170918183704-d01228ac9e50/go.mod h1:zPn1wHpTIePGnXSHpsVPWEktKXHr6+SS6x/IKRb7cpw= github.com/shurcooL/httperror v0.0.0-20170206035902-86b7830d14cc h1:eHRtZoIi6n9Wo1uR+RU44C247msLWwyA89hVKwRLkMk= +github.com/shurcooL/httperror v0.0.0-20170206035902-86b7830d14cc/go.mod h1:aYMfkZ6DWSJPJ6c4Wwz3QtW22G7mf/PEgaB9k/ik5+Y= github.com/shurcooL/httpfs v0.0.0-20171119174359-809beceb2371 h1:SWV2fHctRpRrp49VXJ6UZja7gU9QLHwRpIPBN89SKEo= +github.com/shurcooL/httpfs v0.0.0-20171119174359-809beceb2371/go.mod h1:ZY1cvUeJuFPAdZ/B6v7RHavJWZn2YPVFQ1OSXhCGOkg= github.com/shurcooL/httpgzip v0.0.0-20180522190206-b1c53ac65af9 h1:fxoFD0in0/CBzXoyNhMTjvBZYW6ilSnTw7N7y/8vkmM= +github.com/shurcooL/httpgzip v0.0.0-20180522190206-b1c53ac65af9/go.mod h1:919LwcH0M7/W4fcZ0/jy0qGght1GIhqyS/EgWGH2j5Q= github.com/shurcooL/issues v0.0.0-20181008053335-6292fdc1e191 h1:T4wuULTrzCKMFlg3HmKHgXAF8oStFb/+lOIupLV2v+o= +github.com/shurcooL/issues v0.0.0-20181008053335-6292fdc1e191/go.mod h1:e2qWDig5bLteJ4fwvDAc2NHzqFEthkqn7aOZAOpj+PQ= github.com/shurcooL/issuesapp v0.0.0-20180602232740-048589ce2241 h1:Y+TeIabU8sJD10Qwd/zMty2/LEaT9GNDaA6nyZf+jgo= +github.com/shurcooL/issuesapp v0.0.0-20180602232740-048589ce2241/go.mod h1:NPpHK2TI7iSaM0buivtFUc9offApnI0Alt/K8hcHy0I= github.com/shurcooL/notifications v0.0.0-20181007000457-627ab5aea122 h1:TQVQrsyNaimGwF7bIhzoVC9QkKm4KsWd8cECGzFx8gI= +github.com/shurcooL/notifications v0.0.0-20181007000457-627ab5aea122/go.mod h1:b5uSkrEVM1jQUspwbixRBhaIjIzL2xazXp6kntxYle0= github.com/shurcooL/octicon v0.0.0-20181028054416-fa4f57f9efb2 h1:bu666BQci+y4S0tVRVjsHUeRon6vUXmsGBwdowgMrg4= +github.com/shurcooL/octicon v0.0.0-20181028054416-fa4f57f9efb2/go.mod h1:eWdoE5JD4R5UVWDucdOPg1g2fqQRq78IQa9zlOV1vpQ= github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82 h1:LneqU9PHDsg/AkPDU3AkqMxnMYL+imaqkpflHu73us8= +github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82/go.mod h1:TCR1lToEk4d2s07G3XGfz2QrgHXg4RJBvjrOozvoWfk= github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95 h1:/vdW8Cb7EXrkqWGufVMES1OH2sU9gKVb2n9/1y5NMBY= +github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537 h1:YGaxtkYjb8mnTvtufv2LKLwCQu2/C7qFB7UtrOlTWOY= +github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537/go.mod h1:QJTqeLYEDaXHZDBsXlPCDqdhQuJkuw4NOtaxYe3xii4= github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133 h1:JtcyT0rk/9PKOdnKQzuDR+FSjh7SGtJwpgVpfZBRKlQ= +github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5kWdCj2z2KEozexVbfEZIWiTjhE0+UjmZgPqehw= github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d h1:yKm7XZV6j9Ev6lojP2XaIshpT4ymkqhMeSghO5Ps00E= +github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e h1:qpG93cPwA5f7s/ZPBJnGOYQNK/vKsaDaseuKT5Asee8= +github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ= github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4= github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07 h1:UyzmZLoiDWMRywV4DUYb9Fbt8uiOSooupjTq10vpvnU= +github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/tdewolff/minify/v2 v2.12.4 h1:kejsHQMM17n6/gwdw53qsi6lg0TGddZADVyQOz1KMdE= github.com/tdewolff/parse/v2 v2.6.4 h1:KCkDvNUMof10e3QExio9OPZJT8SbdKojLBumw8YZycQ= github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= @@ -193,60 +294,137 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/fasthttp v1.40.0 h1:CRq/00MfruPGFLTQKY8b+8SfdK60TxNztjRMnH0t1Yc= github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= github.com/viant/assertly v0.4.8 h1:5x1GzBaRteIwTr5RAGFVG14uNeRFxVNbXPWrK2qAgpc= +github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0 h1:6TteTDQ68CjgcCe8wH3D3ZhUQQOJXMTbj/D9rkk2a1k= +github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/yosssi/ace v0.0.5 h1:tUkIP/BLdKqrlrPwcmH0shwEEhTRHoGnc1wFIWmaBUA= github.com/yuin/goldmark v1.4.1 h1:/vn0k+RBvwlxEmP5E7SZMqNxPhfMVFEJiykr15/0XKM= github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go.opencensus.io v0.22.4 h1:LYy1Hy3MJdrCdMwwzxA/dRok4ejH+RwNGbuoD9fCjto= go4.org v0.0.0-20180809161055-417644f6feb5 h1:+hE86LblG4AyDgwMCLTE6FOlM9+qjHSYS+rKqxUVdsM= +go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d h1:E2M5QgjZ/Jg+ObCQAudsXxuTsLj7Nl5RV/lZcQZmKSo= +golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= +golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20221019170559-20944726eadf/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= +golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/exp v0.0.0-20230306221820-f0f767cdffd6/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/image v0.0.0-20190802002840-cff245a6509b h1:+qEpEAPhDZ1o0x3tHzZTQDArnOixOzGD9HUJfcg0mb4= +golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 h1:4+4C/Iv2U4fMZBiMCc98MG1In4gJY5YRhtpDNeDeHWs= golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220516155154-20f960328961/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b h1:clP8eMhB30EHdc0bd2Twtq6kgU7yl5ub2cQLSdrv1Dg= golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852 h1:xYq6+9AtI+xP3M4r0N1hCkHrInHDBohhquRgx9Kk6gI= +golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190316082340-a2f829d7f35f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 h1:v6hYoSR9T5oet+pMXwUWkbiVqx/63mlHjefrHmxwfeY= golang.org/x/sys v0.0.0-20220829200755-d48e67d00261/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw= golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI= golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= +golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= +google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= +google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= google.golang.org/api v0.30.0 h1:yfrXXP61wVuLb0vBcG6qaOoIoqYEzOQS8jum51jkv2w= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.6.6 h1:lMO5rYAqUxkmaj76jAkRUvt5JZgFymx/+Q5Mzfivuhc= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg= +google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20200825200019-8632dd797987 h1:PDIOdWxZ8eRizhKa1AAvY53xsvLB1cWorMjslvY3VA8= +google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= +google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= +google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.31.0 h1:T7P4R73V3SSDPhH7WW7ATbfViLtmamH0DKrP3f9AuDI= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/alecthomas/kingpin.v2 v2.2.6 h1:jMFz6MfLP0/4fUyZle81rXUoxOBFi19VUFKVDOQfozc= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0 h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= grpc.go4.org v0.0.0-20170609214715-11d0a25b4919 h1:tmXTu+dfa+d9Evp8NpJdgOy6+rt8/x4yG7qPBrtNfLY= +grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= +honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8= rsc.io/binaryregexp v0.2.0 h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE= rsc.io/quote/v3 v3.1.0 h1:9JKUTTIUgS6kzR9mK1YuGKv6Nl+DijDNIc0ghT58FaY= rsc.io/sampler v1.3.0 h1:7uVkIFmeBqHfdjD+gZwtXXI+RODJ2Wc4O7MPEh/QiW4= sourcegraph.com/sourcegraph/go-diff v0.5.0 h1:eTiIR0CoWjGzJcnQ3OkhIl/b9GJovq4lSAVRt0ZFEG8= +sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck= sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4 h1:JPJh2pk3+X4lXAkZIk2RuE/7/FoK9maXw+TNPJhVS/c= +sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0= diff --git a/internal/agd/agd_test.go b/internal/agd/agd_test.go index 960e071..b1395a5 100644 --- a/internal/agd/agd_test.go +++ b/internal/agd/agd_test.go @@ -1,11 +1,9 @@ package agd_test import ( - "net/netip" "testing" "time" - "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/golibs/testutil" ) @@ -17,12 +15,3 @@ func TestMain(m *testing.M) { // testTimeout is the timeout for common test operations. const testTimeout = 1 * time.Second - -// testProfID is the profile ID for tests. -const testProfID agd.ProfileID = "prof1234" - -// testDevID is the device ID for tests. -const testDevID agd.DeviceID = "dev1234" - -// testClientIPv4 is the client IP for tests -var testClientIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) diff --git a/internal/agd/country_generate.go b/internal/agd/country_generate.go index db1c993..066ef1f 100644 --- a/internal/agd/country_generate.go +++ b/internal/agd/country_generate.go @@ -10,6 +10,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" + "github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/log" "golang.org/x/exp/slices" ) @@ -22,7 +23,7 @@ func main() { req, err := http.NewRequest(http.MethodGet, csvURL, nil) check(err) - req.Header.Add("User-Agent", agdhttp.UserAgent()) + req.Header.Add(httphdr.UserAgent, agdhttp.UserAgent()) resp, err := c.Do(req) check(err) diff --git a/internal/agd/device.go b/internal/agd/device.go index 221f815..a2b7c26 100644 --- a/internal/agd/device.go +++ b/internal/agd/device.go @@ -12,17 +12,25 @@ import ( // Devices // Device is a device of a device attached to a profile. +// +// NOTE: Do not change fields of this structure without incrementing +// [internal/profiledb/internal.FileCacheVersion]. type Device struct { // ID is the unique ID of the device. ID DeviceID - // LinkedIP, when non-nil, allows AdGuard DNS to identify a device by its IP - // address when it can only use plain DNS. - LinkedIP *netip.Addr + // LinkedIP, when non-empty, allows AdGuard DNS to identify a device by its + // IP address when it can only use plain DNS. + LinkedIP netip.Addr // Name is the human-readable name of the device. Name DeviceName + // DedicatedIPs are the destination (server) IP-addresses dedicated to this + // device, if any. A device can use one of these addresses as a DNS server + // address for AdGuard DNS to recognize it. + DedicatedIPs []netip.Addr + // FilteringEnabled defines whether queries from the device should be // filtered in any way at all. FilteringEnabled bool diff --git a/internal/agd/error.go b/internal/agd/error.go index 0fc315c..6057f70 100644 --- a/internal/agd/error.go +++ b/internal/agd/error.go @@ -25,53 +25,6 @@ func (err *ArgumentError) Error() (msg string) { return fmt.Sprintf("argument %s is invalid: %s", err.Name, err.Message) } -// EntityName is the type for names of entities. Currently only used in errors. -type EntityName string - -// Current entity names. -const ( - EntityNameDevice EntityName = "device" - EntityNameProfile EntityName = "profile" -) - -// NotFoundError is an error returned by lookup methods when an entity wasn't -// found. -// -// We use separate types that implement a common interface instead of a single -// structure to reduce allocations. -type NotFoundError interface { - error - - // EntityName returns the name of the entity that couldn't be found. - EntityName() (e EntityName) -} - -// DeviceNotFoundError is a NotFoundError returned by lookup methods when -// a device wasn't found. -type DeviceNotFoundError struct{} - -// type check -var _ NotFoundError = DeviceNotFoundError{} - -// Error implements the NotFoundError interface for DeviceNotFoundError. -func (DeviceNotFoundError) Error() (msg string) { return "device not found" } - -// EntityName implements the NotFoundError interface for DeviceNotFoundError. -func (DeviceNotFoundError) EntityName() (e EntityName) { return EntityNameDevice } - -// ProfileNotFoundError is a NotFoundError returned by lookup methods when -// a profile wasn't found. -type ProfileNotFoundError struct{} - -// type check -var _ NotFoundError = ProfileNotFoundError{} - -// Error implements the NotFoundError interface for ProfileNotFoundError. -func (ProfileNotFoundError) Error() (msg string) { return "profile not found" } - -// EntityName implements the NotFoundError interface for ProfileNotFoundError. -func (ProfileNotFoundError) EntityName() (e EntityName) { return EntityNameProfile } - // NotACountryError is returned from NewCountry when the string doesn't represent // a valid country. type NotACountryError struct { diff --git a/internal/agd/location.go b/internal/agd/location.go index 406c15b..a41780b 100644 --- a/internal/agd/location.go +++ b/internal/agd/location.go @@ -4,9 +4,19 @@ package agd // Location represents the GeoIP location data about an IP address. type Location struct { - Country Country + // Country is the country whose subnets contain the IP address. + Country Country + + // Continent is the continent whose subnets contain the IP address. Continent Continent - ASN ASN + + // TopSubdivision is the ISO-code of the political subdivision of a country + // whose subnets contain the IP address. This field may be empty. + TopSubdivision string + + // ASN is the number of the autonomous system whose subnets contain the IP + // address. + ASN ASN } // ASN is the autonomous system number of an IP address. diff --git a/internal/agd/profile.go b/internal/agd/profile.go index 1b0696a..913ea3c 100644 --- a/internal/agd/profile.go +++ b/internal/agd/profile.go @@ -5,6 +5,7 @@ import ( "math" "time" + "github.com/AdguardTeam/AdGuardDNS/internal/agdtime" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" "github.com/AdguardTeam/golibs/errors" ) @@ -15,8 +16,8 @@ import ( // the infrastructure, a profile is also called a “DNS server”. We call it // profile, because it's less confusing. // -// NOTE: Increment [defaultProfileDBCacheVersion] on any change of this -// structure. +// NOTE: Do not change fields of this structure without incrementing +// [internal/profiledb/internal.FileCacheVersion]. // // TODO(a.garipov): Consider making it closer to the config file and the backend // response by grouping parental, rule list, and safe browsing settings into @@ -24,63 +25,107 @@ import ( type Profile struct { // Parental are the parental settings for this profile. They are ignored if // FilteringEnabled is set to false. + // + // NOTE: Do not change fields of this structure without incrementing + // [internal/profiledb/internal.FileCacheVersion]. Parental *ParentalProtectionSettings // BlockingMode defines the way blocked responses are constructed. + // + // NOTE: Do not change fields of this structure without incrementing + // [internal/profiledb/internal.FileCacheVersion]. BlockingMode dnsmsg.BlockingModeCodec // ID is the unique ID of this profile. + // + // NOTE: Do not change fields of this structure without incrementing + // [internal/profiledb/internal.FileCacheVersion]. ID ProfileID // UpdateTime shows the last time this profile was updated from the backend. // This is NOT the time of update in the backend's database, since the // backend doesn't send this information. + // + // NOTE: Do not change fields of this structure without incrementing + // [internal/profiledb/internal.FileCacheVersion]. UpdateTime time.Time - // Devices are the devices attached to this profile. Every element of the - // slice must be non-nil. - Devices []*Device + // DeviceIDs are the IDs of devices attached to this profile. + // + // NOTE: Do not change fields of this structure without incrementing + // [internal/profiledb/internal.FileCacheVersion]. + DeviceIDs []DeviceID // RuleListIDs are the IDs of the filtering rule lists enabled for this // profile. They are ignored if FilteringEnabled or RuleListsEnabled are // set to false. + // + // NOTE: Do not change fields of this structure without incrementing + // [internal/profiledb/internal.FileCacheVersion]. RuleListIDs []FilterListID // CustomRules are the custom filtering rules for this profile. They are // ignored if RuleListsEnabled is set to false. + // + // NOTE: Do not change fields of this structure without incrementing + // [internal/profiledb/internal.FileCacheVersion]. CustomRules []FilterRuleText // FilteredResponseTTL is the time-to-live value used for responses sent to // the devices of this profile. + // + // NOTE: Do not change fields of this structure without incrementing + // [internal/profiledb/internal.FileCacheVersion]. FilteredResponseTTL time.Duration // FilteringEnabled defines whether queries from devices of this profile // should be filtered in any way at all. + // + // NOTE: Do not change fields of this structure without incrementing + // [internal/profiledb/internal.FileCacheVersion]. FilteringEnabled bool // SafeBrowsingEnabled defines whether queries from devices of this profile // should be filtered using the safe browsing filter. Requires // FilteringEnabled to be set to true. + // + // NOTE: Do not change fields of this structure without incrementing + // [internal/profiledb/internal.FileCacheVersion]. SafeBrowsingEnabled bool // RuleListsEnabled defines whether queries from devices of this profile // should be filtered using the filtering rule lists in RuleListIDs. // Requires FilteringEnabled to be set to true. + // + // NOTE: Do not change fields of this structure without incrementing + // [internal/profiledb/internal.FileCacheVersion]. RuleListsEnabled bool // QueryLogEnabled defines whether query logs should be saved for the // devices of this profile. + // + // NOTE: Do not change fields of this structure without incrementing + // [internal/profiledb/internal.FileCacheVersion]. QueryLogEnabled bool // Deleted shows if this profile is deleted. + // + // NOTE: Do not change fields of this structure without incrementing + // [internal/profiledb/internal.FileCacheVersion]. Deleted bool // BlockPrivateRelay shows if Apple Private Relay queries are blocked for // requests from all devices in this profile. + // + // NOTE: Do not change fields of this structure without incrementing + // [internal/profiledb/internal.FileCacheVersion]. BlockPrivateRelay bool // BlockFirefoxCanary shows if Firefox canary domain is blocked for // requests from all devices in this profile. + // + // NOTE: Do not change fields of this structure without incrementing + // [internal/profiledb/internal.FileCacheVersion]. BlockFirefoxCanary bool } @@ -163,23 +208,26 @@ type WeeklySchedule [7]DayRange // ParentalProtectionSchedule is the schedule of a client's parental protection. // All fields must not be nil. +// +// NOTE: Do not change fields of this structure without incrementing +// [internal/profiledb/internal.FileCacheVersion]. type ParentalProtectionSchedule struct { // Week is the parental protection schedule for every day of the week. Week *WeeklySchedule // TimeZone is the profile's time zone. - TimeZone *time.Location + TimeZone *agdtime.Location } // Contains returns true if t is within the allowed schedule. func (s *ParentalProtectionSchedule) Contains(t time.Time) (ok bool) { - t = t.In(s.TimeZone) + t = t.In(&s.TimeZone.Location) r := s.Week[int(t.Weekday())] if r.IsZeroLength() { return false } - day := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, s.TimeZone) + day := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, &s.TimeZone.Location) start := day.Add(time.Duration(r.Start) * time.Minute) end := day.Add(time.Duration(r.End+1)*time.Minute - 1*time.Nanosecond) @@ -187,6 +235,9 @@ func (s *ParentalProtectionSchedule) Contains(t time.Time) (ok bool) { } // ParentalProtectionSettings are the parental protection settings of a profile. +// +// NOTE: Do not change fields of this structure without incrementing +// [internal/profiledb/internal.FileCacheVersion]. type ParentalProtectionSettings struct { Schedule *ParentalProtectionSchedule diff --git a/internal/agd/profile_test.go b/internal/agd/profile_test.go index 7ade079..5c61018 100644 --- a/internal/agd/profile_test.go +++ b/internal/agd/profile_test.go @@ -5,6 +5,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/agdtime" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/stretchr/testify/assert" @@ -78,7 +79,7 @@ func TestParentalProtectionSchedule_Contains(t *testing.T) { time.Saturday: agd.ZeroLengthDayRange(), }, - TimeZone: time.UTC, + TimeZone: agdtime.UTC(), } // allDaySchedule, 00:00:00 to 23:59:59. @@ -95,7 +96,7 @@ func TestParentalProtectionSchedule_Contains(t *testing.T) { time.Saturday: agd.ZeroLengthDayRange(), }, - TimeZone: time.UTC, + TimeZone: agdtime.UTC(), } testCases := []struct { diff --git a/internal/agd/profiledb.go b/internal/agd/profiledb.go deleted file mode 100644 index fcb1a3d..0000000 --- a/internal/agd/profiledb.go +++ /dev/null @@ -1,423 +0,0 @@ -package agd - -import ( - "context" - "encoding/json" - "fmt" - "net/netip" - "os" - "sync" - "time" - - "github.com/AdguardTeam/AdGuardDNS/internal/metrics" - "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" -) - -// Data Storage - -// ProfileDB is the local database of profiles and other data. -// -// TODO(a.garipov): move this logic to the backend package. -type ProfileDB interface { - ProfileByDeviceID(ctx context.Context, id DeviceID) (p *Profile, d *Device, err error) - ProfileByIP(ctx context.Context, ip netip.Addr) (p *Profile, d *Device, err error) -} - -// DefaultProfileDB is the default implementation of the ProfileDB interface -// that can refresh itself from the provided storage. -type DefaultProfileDB struct { - // mapsMu protects the deviceToProfile, deviceIDToIP, and ipToDeviceID maps. - mapsMu *sync.RWMutex - - // refreshMu protects syncTime and syncTimeFull. These are only used within - // Refresh, so this is also basically a refresh serializer. - refreshMu *sync.Mutex - - // storage returns the data for this profiledb. - storage ProfileStorage - - // deviceToProfile maps device IDs to their profiles. It is cleared lazily - // whenever a device is found to be missing from a profile. - deviceToProfile map[DeviceID]*Profile - - // deviceIDToIP maps device IDs to their linked IP addresses. It is used to - // take changes in IP address linking into account during refreshes. It is - // cleared lazily whenever a device is found to be missing from a profile. - deviceIDToIP map[DeviceID]netip.Addr - - // ipToDeviceID maps linked IP addresses to the IDs of their devices. It is - // cleared lazily whenever a device is found to be missing from a profile. - ipToDeviceID map[netip.Addr]DeviceID - - // syncTime is the time of the last synchronization. It is used in refresh - // requests to the storage. - syncTime time.Time - - // syncTimeFull is the time of the last full synchronization. - syncTimeFull time.Time - - // cacheFilePath is the path to profiles cache file. - cacheFilePath string - - // fullSyncIvl is the interval between two full synchronizations with the - // storage - fullSyncIvl time.Duration -} - -// NewDefaultProfileDB returns a new default profile profiledb. The initial -// refresh is performed immediately with the constant timeout of 1 minute, -// beyond which an empty profiledb is returned. db is never nil. -func NewDefaultProfileDB( - ds ProfileStorage, - fullRefreshIvl time.Duration, - cacheFilePath string, -) (db *DefaultProfileDB, err error) { - db = &DefaultProfileDB{ - mapsMu: &sync.RWMutex{}, - refreshMu: &sync.Mutex{}, - storage: ds, - syncTime: time.Time{}, - syncTimeFull: time.Time{}, - deviceToProfile: make(map[DeviceID]*Profile), - deviceIDToIP: make(map[DeviceID]netip.Addr), - ipToDeviceID: make(map[netip.Addr]DeviceID), - fullSyncIvl: fullRefreshIvl, - cacheFilePath: cacheFilePath, - } - - err = db.loadProfileCache() - if err != nil { - log.Error("profiledb: cache: loading: %s", err) - } - - // initialTimeout defines the maximum duration of the first attempt to load - // the profiledb. - const initialTimeout = 1 * time.Minute - - ctx, cancel := context.WithTimeout(context.Background(), initialTimeout) - defer cancel() - - log.Info("profiledb: initial refresh") - - err = db.Refresh(ctx) - if err != nil { - if errors.Is(err, context.DeadlineExceeded) { - log.Info("profiledb: warning: initial refresh timeout: %s", err) - - return db, nil - } - - return nil, fmt.Errorf("initial refresh: %w", err) - } - - log.Info("profiledb: initial refresh succeeded") - - return db, nil -} - -// type check -var _ Refresher = (*DefaultProfileDB)(nil) - -// Refresh implements the Refresher interface for *DefaultProfileDB. It updates -// the internal maps from the data it receives from the storage. -func (db *DefaultProfileDB) Refresh(ctx context.Context) (err error) { - startTime := time.Now() - defer func() { - metrics.ProfilesSyncTime.SetToCurrentTime() - metrics.ProfilesCountGauge.Set(float64(len(db.deviceToProfile))) - metrics.ProfilesSyncDuration.Observe(time.Since(startTime).Seconds()) - metrics.SetStatusGauge(metrics.ProfilesSyncStatus, err) - }() - - reqID := NewRequestID() - ctx = WithRequestID(ctx, reqID) - - defer func() { err = errors.Annotate(err, "req %s: %w", reqID) }() - - db.refreshMu.Lock() - defer db.refreshMu.Unlock() - - isFullSync := time.Since(db.syncTimeFull) >= db.fullSyncIvl - syncTime := db.syncTime - if isFullSync { - syncTime = time.Time{} - } - - var resp *PSProfilesResponse - resp, err = db.storage.Profiles(ctx, &PSProfilesRequest{ - SyncTime: syncTime, - }) - if err != nil { - return fmt.Errorf("updating profiles: %w", err) - } - - profiles := resp.Profiles - devNum := db.setProfiles(profiles) - log.Debug("profiledb: req %s: got %d profiles with %d devices", reqID, len(profiles), devNum) - metrics.ProfilesNewCountGauge.Set(float64(len(profiles))) - - db.syncTime = resp.SyncTime - if isFullSync { - db.syncTimeFull = time.Now() - - err = db.saveProfileCache(ctx) - if err != nil { - return fmt.Errorf("saving cache: %w", err) - } - } - - return nil -} - -// profileCache is the structure for profiles db cache. -type profileCache struct { - SyncTime time.Time `json:"sync_time"` - Profiles []*Profile `json:"profiles"` - Version int `json:"version"` -} - -// saveStorageCache saves profiles data to cache file. -func (db *DefaultProfileDB) saveProfileCache(ctx context.Context) (err error) { - log.Info("profiledb: saving profile cache") - - var resp *PSProfilesResponse - resp, err = db.storage.Profiles(ctx, &PSProfilesRequest{ - SyncTime: time.Time{}, - }) - - if err != nil { - return err - } - - data := &profileCache{ - Profiles: resp.Profiles, - Version: defaultProfileDBCacheVersion, - SyncTime: time.Now(), - } - - cache, err := json.Marshal(data) - if err != nil { - return fmt.Errorf("encoding json: %w", err) - } - - err = os.WriteFile(db.cacheFilePath, cache, 0o600) - if err != nil { - // Don't wrap the error, because it's informative enough as is. - return err - } - - log.Info("profiledb: cache: saved %d profiles to %q", len(resp.Profiles), db.cacheFilePath) - - return nil -} - -// defaultProfileDBCacheVersion is the version of cached data structure. It's -// manually incremented on every change in [profileCache] structure. -const defaultProfileDBCacheVersion = 2 - -// loadProfileCache loads profiles data from cache file. -func (db *DefaultProfileDB) loadProfileCache() (err error) { - log.Info("profiledb: loading cache") - - data, err := db.loadStorageCache() - if err != nil { - return fmt.Errorf("loading cache: %w", err) - } - - if data == nil { - log.Info("profiledb: cache is empty") - - return nil - } - - if data.Version == defaultProfileDBCacheVersion { - profiles := data.Profiles - devNum := db.setProfiles(profiles) - log.Info("profiledb: cache: got %d profiles with %d devices", len(profiles), devNum) - - db.syncTime = data.SyncTime - db.syncTimeFull = data.SyncTime - } else { - log.Info( - "profiledb: cache version %d is different from %d", - data.Version, - defaultProfileDBCacheVersion, - ) - } - - return nil -} - -// loadStorageCache loads data from cache file. -func (db *DefaultProfileDB) loadStorageCache() (data *profileCache, err error) { - file, err := os.Open(db.cacheFilePath) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - // File could be deleted or not yet created, go on. - return nil, nil - } - - return nil, err - } - defer func() { err = errors.WithDeferred(err, file.Close()) }() - - data = &profileCache{} - err = json.NewDecoder(file).Decode(data) - if err != nil { - return nil, fmt.Errorf("decoding json: %w", err) - } - - return data, nil -} - -// setProfiles adds or updates the data for all profiles. -func (db *DefaultProfileDB) setProfiles(profiles []*Profile) (devNum int) { - db.mapsMu.Lock() - defer db.mapsMu.Unlock() - - for _, p := range profiles { - devNum += len(p.Devices) - for _, d := range p.Devices { - db.deviceToProfile[d.ID] = p - if d.LinkedIP == nil { - // Delete any records from the device-to-IP map just in case - // there used to be one. - delete(db.deviceIDToIP, d.ID) - - continue - } - - newIP := *d.LinkedIP - if prevIP, ok := db.deviceIDToIP[d.ID]; !ok || prevIP != newIP { - // The IP has changed. Remove the previous records before - // setting the new ones. - delete(db.ipToDeviceID, prevIP) - delete(db.deviceIDToIP, d.ID) - } - - db.ipToDeviceID[newIP] = d.ID - db.deviceIDToIP[d.ID] = newIP - } - } - - return devNum -} - -// type check -var _ ProfileDB = (*DefaultProfileDB)(nil) - -// ProfileByDeviceID implements the ProfileDB interface for *DefaultProfileDB. -func (db *DefaultProfileDB) ProfileByDeviceID( - ctx context.Context, - id DeviceID, -) (p *Profile, d *Device, err error) { - db.mapsMu.RLock() - defer db.mapsMu.RUnlock() - - return db.profileByDeviceID(ctx, id) -} - -// profileByDeviceID returns the profile and the device by the ID of the device, -// if found. Any returned errors will have the underlying type of -// NotFoundError. It assumes that db is currently locked for reading. -func (db *DefaultProfileDB) profileByDeviceID( - _ context.Context, - id DeviceID, -) (p *Profile, d *Device, err error) { - // Do not use errors.Annotate here, because it allocates even when the error - // is nil. Also do not use fmt.Errorf in a defer, because it allocates when - // a device is not found, which is the most common case. - // - // TODO(a.garipov): Find out, why does it allocate and perhaps file an - // issue about that in the Go issue tracker. - - var ok bool - p, ok = db.deviceToProfile[id] - if !ok { - return nil, nil, ProfileNotFoundError{} - } - - for _, pd := range p.Devices { - if pd.ID == id { - d = pd - - break - } - } - - if d == nil { - // Perhaps, the device has been deleted. May happen when the device was - // found by a linked IP. - return nil, nil, fmt.Errorf("rechecking devices: %w", DeviceNotFoundError{}) - } - - return p, d, nil -} - -// ProfileByIP implements the ProfileDB interface for *DefaultProfileDB. ip -// must be valid. -func (db *DefaultProfileDB) ProfileByIP( - ctx context.Context, - ip netip.Addr, -) (p *Profile, d *Device, err error) { - // Do not use errors.Annotate here, because it allocates even when the error - // is nil. Also do not use fmt.Errorf in a defer, because it allocates when - // a device is not found, which is the most common case. - - db.mapsMu.RLock() - defer db.mapsMu.RUnlock() - - id, ok := db.ipToDeviceID[ip] - if !ok { - return nil, nil, DeviceNotFoundError{} - } - - p, d, err = db.profileByDeviceID(ctx, id) - if errors.Is(err, DeviceNotFoundError{}) { - // Probably, the device has been deleted. Remove it from our profiledb - // in a goroutine, since that requires a write lock. - go db.removeDeviceByIP(id, ip) - - // Go on and return the error. - } - - if err != nil { - // Don't add the device ID to the error here, since it is already added - // by profileByDeviceID. - return nil, nil, fmt.Errorf("profile by linked device id: %w", err) - } - - return p, d, nil -} - -// removeDeviceByIP removes the device with the given ID and linked IP address -// from the profiledb. It is intended to be used as a goroutine. -func (db *DefaultProfileDB) removeDeviceByIP(id DeviceID, ip netip.Addr) { - defer log.OnPanicAndExit("removeDeviceByIP", 1) - - db.mapsMu.Lock() - defer db.mapsMu.Unlock() - - delete(db.ipToDeviceID, ip) - delete(db.deviceIDToIP, id) - delete(db.deviceToProfile, id) -} - -// ProfileStorage is a storage of data about profiles and other entities. -type ProfileStorage interface { - // Profiles returns all profiles known to this particular data storage. req - // must not be nil. - Profiles(ctx context.Context, req *PSProfilesRequest) (resp *PSProfilesResponse, err error) -} - -// PSProfilesRequest is the ProfileStorage.Profiles request. -type PSProfilesRequest struct { - SyncTime time.Time -} - -// PSProfilesResponse is the ProfileStorage.Profiles response. -type PSProfilesResponse struct { - SyncTime time.Time - Profiles []*Profile -} diff --git a/internal/agd/profiledb_test.go b/internal/agd/profiledb_test.go deleted file mode 100644 index 37020e7..0000000 --- a/internal/agd/profiledb_test.go +++ /dev/null @@ -1,153 +0,0 @@ -package agd_test - -import ( - "context" - "net/netip" - "path/filepath" - "testing" - "time" - - "github.com/AdguardTeam/AdGuardDNS/internal/agd" - "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" - "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// newDefaultProfileDB returns a new default profile database for tests. -func newDefaultProfileDB(tb testing.TB, dev *agd.Device) (db *agd.DefaultProfileDB) { - tb.Helper() - - onProfiles := func( - _ context.Context, - _ *agd.PSProfilesRequest, - ) (resp *agd.PSProfilesResponse, err error) { - return &agd.PSProfilesResponse{ - Profiles: []*agd.Profile{{ - BlockingMode: dnsmsg.BlockingModeCodec{ - Mode: &dnsmsg.BlockingModeNullIP{}, - }, - ID: testProfID, - Devices: []*agd.Device{dev}, - }}, - }, nil - } - - ds := &agdtest.ProfileStorage{ - OnProfiles: onProfiles, - } - - cacheFilePath := filepath.Join(tb.TempDir(), "profiles.json") - db, err := agd.NewDefaultProfileDB(ds, 1*time.Minute, cacheFilePath) - require.NoError(tb, err) - - return db -} - -func TestDefaultProfileDB(t *testing.T) { - dev := &agd.Device{ - ID: testDevID, - LinkedIP: &testClientIPv4, - } - - db := newDefaultProfileDB(t, dev) - - t.Run("by_device_id", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - - p, d, err := db.ProfileByDeviceID(ctx, testDevID) - require.NoError(t, err) - - assert.Equal(t, testProfID, p.ID) - assert.Equal(t, d, dev) - }) - - t.Run("by_ip", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - - p, d, err := db.ProfileByIP(ctx, testClientIPv4) - require.NoError(t, err) - - assert.Equal(t, testProfID, p.ID) - assert.Equal(t, d, dev) - }) -} - -var profSink *agd.Profile - -var devSink *agd.Device - -var errSink error - -func BenchmarkDefaultProfileDB_ProfileByDeviceID(b *testing.B) { - dev := &agd.Device{ - ID: testDevID, - } - - db := newDefaultProfileDB(b, dev) - ctx := context.Background() - - b.Run("success", func(b *testing.B) { - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - profSink, devSink, errSink = db.ProfileByDeviceID(ctx, testDevID) - } - - assert.NotNil(b, profSink) - assert.NotNil(b, devSink) - assert.NoError(b, errSink) - }) - - const wrongDevID = testDevID + "_bad" - - b.Run("not_found", func(b *testing.B) { - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - profSink, devSink, errSink = db.ProfileByDeviceID(ctx, wrongDevID) - } - - assert.Nil(b, profSink) - assert.Nil(b, devSink) - assert.ErrorAs(b, errSink, new(agd.NotFoundError)) - }) -} - -func BenchmarkDefaultProfileDB_ProfileByIP(b *testing.B) { - dev := &agd.Device{ - ID: testDevID, - LinkedIP: &testClientIPv4, - } - - db := newDefaultProfileDB(b, dev) - ctx := context.Background() - - b.Run("success", func(b *testing.B) { - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - profSink, devSink, errSink = db.ProfileByIP(ctx, testClientIPv4) - } - - assert.NotNil(b, profSink) - assert.NotNil(b, devSink) - assert.NoError(b, errSink) - }) - - wrongClientIP := netip.MustParseAddr("5.6.7.8") - - b.Run("not_found", func(b *testing.B) { - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - profSink, devSink, errSink = db.ProfileByIP(ctx, wrongClientIP) - } - - assert.Nil(b, profSink) - assert.Nil(b, devSink) - assert.ErrorAs(b, errSink, new(agd.NotFoundError)) - }) -} diff --git a/internal/agd/server.go b/internal/agd/server.go index 1211779..c010dab 100644 --- a/internal/agd/server.go +++ b/internal/agd/server.go @@ -102,8 +102,6 @@ type Server struct { // ServerBindData are the socket binding data for a server. Either AddrPort or // ListenConfig with Address must be set. // -// TODO(a.garipov): Add support for ListenConfig in the config file. -// // TODO(a.garipov): Consider turning this into a sum type. // // TODO(a.garipov): Consider renaming this and the one in websvc to something diff --git a/internal/agd/upstream.go b/internal/agd/upstream.go deleted file mode 100644 index 77952be..0000000 --- a/internal/agd/upstream.go +++ /dev/null @@ -1,26 +0,0 @@ -package agd - -import ( - "net/netip" - "time" - - "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/forward" -) - -// Upstream - -// Upstream module configuration. -type Upstream struct { - // Server is the upstream server we're using to forward DNS queries. - Server netip.AddrPort - - // Network is the Server network protocol. - Network forward.Network - - // FallbackServers is a list of the DNS servers we're using to fallback to - // when the upstream server fails to respond. - FallbackServers []netip.AddrPort - - // Timeout is the timeout for all outgoing DNS requests. - Timeout time.Duration -} diff --git a/internal/agdhttp/agdhttp.go b/internal/agdhttp/agdhttp.go index 53dc6b1..b0a51a8 100644 --- a/internal/agdhttp/agdhttp.go +++ b/internal/agdhttp/agdhttp.go @@ -12,20 +12,6 @@ import ( // Common Constants, Functions And Types -// HTTP header name constants. -const ( - HdrNameAcceptEncoding = "Accept-Encoding" - HdrNameAccessControlAllowOrigin = "Access-Control-Allow-Origin" - HdrNameContentType = "Content-Type" - HdrNameContentEncoding = "Content-Encoding" - HdrNameServer = "Server" - HdrNameTrailer = "Trailer" - HdrNameUserAgent = "User-Agent" - - HdrNameXError = "X-Error" - HdrNameXRequestID = "X-Request-Id" -) - // HTTP header value constants. const ( HdrValApplicationJSON = "application/json" diff --git a/internal/agdhttp/client.go b/internal/agdhttp/client.go index a155756..3c46d66 100644 --- a/internal/agdhttp/client.go +++ b/internal/agdhttp/client.go @@ -9,6 +9,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/golibs/httphdr" ) // Client is a wrapper around http.Client. @@ -85,15 +86,15 @@ func (c *Client) do( } if contentType != "" { - req.Header.Set(HdrNameContentType, contentType) + req.Header.Set(httphdr.ContentType, contentType) } reqID, ok := agd.RequestIDFromContext(ctx) if ok { - req.Header.Set(HdrNameXRequestID, string(reqID)) + req.Header.Set(httphdr.XRequestID, string(reqID)) } - req.Header.Set(HdrNameUserAgent, c.userAgent) + req.Header.Set(httphdr.UserAgent, c.userAgent) resp, err = c.http.Do(req) if err != nil && resp != nil && resp.Header != nil { diff --git a/internal/agdhttp/error.go b/internal/agdhttp/error.go index e32e57c..ed22a3a 100644 --- a/internal/agdhttp/error.go +++ b/internal/agdhttp/error.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/httphdr" ) // Common HTTP Errors @@ -40,7 +41,7 @@ func CheckStatus(resp *http.Response, expected int) (err error) { } return &StatusError{ - ServerName: resp.Header.Get(HdrNameServer), + ServerName: resp.Header.Get(httphdr.Server), Expected: expected, Got: resp.StatusCode, } @@ -73,6 +74,6 @@ func (err *ServerError) Unwrap() (unwrapped error) { func WrapServerError(err error, resp *http.Response) (wrapped *ServerError) { return &ServerError{ Err: err, - ServerName: resp.Header.Get(HdrNameServer), + ServerName: resp.Header.Get(httphdr.Server), } } diff --git a/internal/agdhttp/error_test.go b/internal/agdhttp/error_test.go index a88f501..f9666ca 100644 --- a/internal/agdhttp/error_test.go +++ b/internal/agdhttp/error_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" + "github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" ) @@ -41,7 +42,7 @@ func TestCheckStatus(t *testing.T) { resp := &http.Response{ StatusCode: tc.got, Header: http.Header{ - agdhttp.HdrNameServer: []string{tc.srv}, + httphdr.Server: []string{tc.srv}, }, } err := agdhttp.CheckStatus(resp, tc.exp) @@ -73,7 +74,7 @@ func TestServerError(t *testing.T) { t.Run(tc.name, func(t *testing.T) { resp := &http.Response{ Header: http.Header{ - agdhttp.HdrNameServer: []string{tc.srv}, + httphdr.Server: []string{tc.srv}, }, } err := agdhttp.WrapServerError(tc.err, resp) diff --git a/internal/agdnet/agdnet_example_test.go b/internal/agdnet/agdnet_example_test.go index 2f9260b..aa9dcf2 100644 --- a/internal/agdnet/agdnet_example_test.go +++ b/internal/agdnet/agdnet_example_test.go @@ -7,14 +7,18 @@ import ( ) func ExampleAndroidMetricDomainReplacement() { + printResult := func(input string) { + fmt.Printf("%-42q: %q\n", input, agdnet.AndroidMetricDomainReplacement(input)) + } + anAndroidDomain := "12345678-dnsotls-ds.metric.gstatic.com." - fmt.Printf("%-42q: %q\n", anAndroidDomain, agdnet.AndroidMetricDomainReplacement(anAndroidDomain)) + printResult(anAndroidDomain) anAndroidDomain = "123456-dnsohttps-ds.metric.gstatic.com." - fmt.Printf("%-42q: %q\n", anAndroidDomain, agdnet.AndroidMetricDomainReplacement(anAndroidDomain)) + printResult(anAndroidDomain) notAndroidDomain := "example.com" - fmt.Printf("%-42q: %q\n", notAndroidDomain, agdnet.AndroidMetricDomainReplacement(notAndroidDomain)) + printResult(notAndroidDomain) // Output: // "12345678-dnsotls-ds.metric.gstatic.com." : "00000000-dnsotls-ds.metric.gstatic.com." diff --git a/internal/agdtest/agdtest.go b/internal/agdtest/agdtest.go index 3c8b0e7..672745c 100644 --- a/internal/agdtest/agdtest.go +++ b/internal/agdtest/agdtest.go @@ -10,7 +10,11 @@ import ( // FilteredResponseTTL is the common filtering response TTL for tests. It is // also used by [NewConstructor]. -const FilteredResponseTTL = 10 * time.Second +const FilteredResponseTTL = FilteredResponseTTLSec * time.Second + +// FilteredResponseTTLSec is the common filtering response TTL for tests, as a +// number to simplify message creation. +const FilteredResponseTTLSec = 10 // NewConstructor returns a standard dnsmsg.Constructor for tests. func NewConstructor() (c *dnsmsg.Constructor) { diff --git a/internal/agdtest/interface.go b/internal/agdtest/interface.go index 43b0eb2..3036150 100644 --- a/internal/agdtest/interface.go +++ b/internal/agdtest/interface.go @@ -11,9 +11,11 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/billstat" "github.com/AdguardTeam/AdGuardDNS/internal/dnscheck" "github.com/AdguardTeam/AdGuardDNS/internal/dnsdb" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit" "github.com/AdguardTeam/AdGuardDNS/internal/filter" "github.com/AdguardTeam/AdGuardDNS/internal/geoip" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb" "github.com/AdguardTeam/AdGuardDNS/internal/querylog" "github.com/AdguardTeam/AdGuardDNS/internal/rulestat" "github.com/AdguardTeam/golibs/netutil" @@ -22,7 +24,144 @@ import ( // Interface Mocks // -// Keep entities in this file in alphabetic order. +// Keep entities within a module/package in alphabetic order. + +// Module std + +// Package net +// +// TODO(a.garipov): Move these to golibs? + +// type check +var _ net.Conn = (*Conn)(nil) + +// Conn is the [net.Conn] for tests. +type Conn struct { + OnClose func() (err error) + OnLocalAddr func() (laddr net.Addr) + OnRead func(b []byte) (n int, err error) + OnRemoteAddr func() (raddr net.Addr) + OnSetDeadline func(t time.Time) (err error) + OnSetReadDeadline func(t time.Time) (err error) + OnSetWriteDeadline func(t time.Time) (err error) + OnWrite func(b []byte) (n int, err error) +} + +// Close implements the [net.Conn] interface for *Conn. +func (c *Conn) Close() (err error) { + return c.OnClose() +} + +// LocalAddr implements the [net.Conn] interface for *Conn. +func (c *Conn) LocalAddr() (laddr net.Addr) { + return c.OnLocalAddr() +} + +// Read implements the [net.Conn] interface for *Conn. +func (c *Conn) Read(b []byte) (n int, err error) { + return c.OnRead(b) +} + +// RemoteAddr implements the [net.Conn] interface for *Conn. +func (c *Conn) RemoteAddr() (raddr net.Addr) { + return c.OnRemoteAddr() +} + +// SetDeadline implements the [net.Conn] interface for *Conn. +func (c *Conn) SetDeadline(t time.Time) (err error) { + return c.OnSetDeadline(t) +} + +// SetReadDeadline implements the [net.Conn] interface for *Conn. +func (c *Conn) SetReadDeadline(t time.Time) (err error) { + return c.OnSetReadDeadline(t) +} + +// SetWriteDeadline implements the [net.Conn] interface for *Conn. +func (c *Conn) SetWriteDeadline(t time.Time) (err error) { + return c.OnSetWriteDeadline(t) +} + +// Write implements the [net.Conn] interface for *Conn. +func (c *Conn) Write(b []byte) (n int, err error) { + return c.OnWrite(b) +} + +// type check +var _ net.Listener = (*Listener)(nil) + +// Listener is a [net.Listener] for tests. +type Listener struct { + OnAccept func() (c net.Conn, err error) + OnAddr func() (addr net.Addr) + OnClose func() (err error) +} + +// Accept implements the [net.Listener] interface for *Listener. +func (l *Listener) Accept() (c net.Conn, err error) { + return l.OnAccept() +} + +// Addr implements the [net.Listener] interface for *Listener. +func (l *Listener) Addr() (addr net.Addr) { + return l.OnAddr() +} + +// Close implements the [net.Listener] interface for *Listener. +func (l *Listener) Close() (err error) { + return l.OnClose() +} + +// type check +var _ net.PacketConn = (*PacketConn)(nil) + +// PacketConn is the [net.PacketConn] for tests. +type PacketConn struct { + OnClose func() (err error) + OnLocalAddr func() (laddr net.Addr) + OnReadFrom func(b []byte) (n int, addr net.Addr, err error) + OnSetDeadline func(t time.Time) (err error) + OnSetReadDeadline func(t time.Time) (err error) + OnSetWriteDeadline func(t time.Time) (err error) + OnWriteTo func(b []byte, addr net.Addr) (n int, err error) +} + +// Close implements the [net.PacketConn] interface for *PacketConn. +func (c *PacketConn) Close() (err error) { + return c.OnClose() +} + +// LocalAddr implements the [net.PacketConn] interface for *PacketConn. +func (c *PacketConn) LocalAddr() (laddr net.Addr) { + return c.OnLocalAddr() +} + +// ReadFrom implements the [net.PacketConn] interface for *PacketConn. +func (c *PacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + return c.OnReadFrom(b) +} + +// SetDeadline implements the [net.PacketConn] interface for *PacketConn. +func (c *PacketConn) SetDeadline(t time.Time) (err error) { + return c.OnSetDeadline(t) +} + +// SetReadDeadline implements the [net.PacketConn] interface for *PacketConn. +func (c *PacketConn) SetReadDeadline(t time.Time) (err error) { + return c.OnSetReadDeadline(t) +} + +// SetWriteDeadline implements the [net.PacketConn] interface for *PacketConn. +func (c *PacketConn) SetWriteDeadline(t time.Time) (err error) { + return c.OnSetWriteDeadline(t) +} + +// WriteTo implements the [net.PacketConn] interface for *PacketConn. +func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + return c.OnWriteTo(b, addr) +} + +// Module AdGuardDNS // type check var _ agd.ErrorCollector = (*ErrorCollector)(nil) @@ -39,56 +178,6 @@ func (c *ErrorCollector) Collect(ctx context.Context, err error) { c.OnCollect(ctx, err) } -// type check -var _ agd.ProfileDB = (*ProfileDB)(nil) - -// ProfileDB is an agd.ProfileDB for tests. -type ProfileDB struct { - OnProfileByDeviceID func( - ctx context.Context, - id agd.DeviceID, - ) (p *agd.Profile, d *agd.Device, err error) - OnProfileByIP func( - ctx context.Context, - ip netip.Addr, - ) (p *agd.Profile, d *agd.Device, err error) -} - -// ProfileByDeviceID implements the agd.ProfileDB interface for *ProfileDB. -func (db *ProfileDB) ProfileByDeviceID( - ctx context.Context, - id agd.DeviceID, -) (p *agd.Profile, d *agd.Device, err error) { - return db.OnProfileByDeviceID(ctx, id) -} - -// ProfileByIP implements the agd.ProfileDB interface for *ProfileDB. -func (db *ProfileDB) ProfileByIP( - ctx context.Context, - ip netip.Addr, -) (p *agd.Profile, d *agd.Device, err error) { - return db.OnProfileByIP(ctx, ip) -} - -// type check -var _ agd.ProfileStorage = (*ProfileStorage)(nil) - -// ProfileStorage is a agd.ProfileStorage for tests. -type ProfileStorage struct { - OnProfiles func( - ctx context.Context, - req *agd.PSProfilesRequest, - ) (resp *agd.PSProfilesResponse, err error) -} - -// Profiles implements the agd.ProfileStorage interface for *ProfileStorage. -func (ds *ProfileStorage) Profiles( - ctx context.Context, - req *agd.PSProfilesRequest, -) (resp *agd.PSProfilesResponse, err error) { - return ds.OnProfiles(ctx, req) -} - // type check var _ agd.Refresher = (*Refresher)(nil) @@ -206,7 +295,7 @@ func (db *DNSDB) Record(ctx context.Context, resp *dns.Msg, ri *agd.RequestInfo) // type check var _ filter.Interface = (*Filter)(nil) -// Filter is a filter.Interface for tests. +// Filter is a [filter.Interface] for tests. type Filter struct { OnFilterRequest func( ctx context.Context, @@ -218,10 +307,9 @@ type Filter struct { resp *dns.Msg, ri *agd.RequestInfo, ) (r filter.Result, err error) - OnClose func() (err error) } -// FilterRequest implements the filter.Interface interface for *Filter. +// FilterRequest implements the [filter.Interface] interface for *Filter. func (f *Filter) FilterRequest( ctx context.Context, req *dns.Msg, @@ -230,7 +318,7 @@ func (f *Filter) FilterRequest( return f.OnFilterRequest(ctx, req, ri) } -// FilterResponse implements the filter.Interface interface for *Filter. +// FilterResponse implements the [filter.Interface] interface for *Filter. func (f *Filter) FilterResponse( ctx context.Context, resp *dns.Msg, @@ -239,21 +327,36 @@ func (f *Filter) FilterResponse( return f.OnFilterResponse(ctx, resp, ri) } -// Close implements the filter.Interface interface for *Filter. -func (f *Filter) Close() (err error) { - return f.OnClose() +// type check +var _ filter.HashMatcher = (*HashMatcher)(nil) + +// HashMatcher is a [filter.HashMatcher] for tests. +type HashMatcher struct { + OnMatchByPrefix func( + ctx context.Context, + host string, + ) (hashes []string, matched bool, err error) +} + +// MatchByPrefix implements the [filter.HashMatcher] interface for *HashMatcher. +func (m *HashMatcher) MatchByPrefix( + ctx context.Context, + host string, +) (hashes []string, matched bool, err error) { + return m.OnMatchByPrefix(ctx, host) } // type check var _ filter.Storage = (*FilterStorage)(nil) -// FilterStorage is an filter.Storage for tests. +// FilterStorage is a [filter.Storage] for tests. type FilterStorage struct { OnFilterFromContext func(ctx context.Context, ri *agd.RequestInfo) (f filter.Interface) OnHasListID func(id agd.FilterListID) (ok bool) } -// FilterFromContext implements the filter.Storage interface for *FilterStorage. +// FilterFromContext implements the [filter.Storage] interface for +// *FilterStorage. func (s *FilterStorage) FilterFromContext( ctx context.Context, ri *agd.RequestInfo, @@ -261,7 +364,7 @@ func (s *FilterStorage) FilterFromContext( return s.OnFilterFromContext(ctx, ri) } -// HasListID implements the filter.Storage interface for *FilterStorage. +// HasListID implements the [filter.Storage] interface for *FilterStorage. func (s *FilterStorage) HasListID(id agd.FilterListID) (ok bool) { return s.OnHasListID(id) } @@ -295,6 +398,73 @@ func (g *GeoIP) Data(host string, ip netip.Addr) (l *agd.Location, err error) { return g.OnData(host, ip) } +// Package profiledb + +// type check +var _ profiledb.Interface = (*ProfileDB)(nil) + +// ProfileDB is a [profiledb.Interface] for tests. +type ProfileDB struct { + OnProfileByDeviceID func( + ctx context.Context, + id agd.DeviceID, + ) (p *agd.Profile, d *agd.Device, err error) + OnProfileByDedicatedIP func( + ctx context.Context, + ip netip.Addr, + ) (p *agd.Profile, d *agd.Device, err error) + OnProfileByLinkedIP func( + ctx context.Context, + ip netip.Addr, + ) (p *agd.Profile, d *agd.Device, err error) +} + +// ProfileByDeviceID implements the [profiledb.Interface] interface for +// *ProfileDB. +func (db *ProfileDB) ProfileByDeviceID( + ctx context.Context, + id agd.DeviceID, +) (p *agd.Profile, d *agd.Device, err error) { + return db.OnProfileByDeviceID(ctx, id) +} + +// ProfileByDedicatedIP implements the [profiledb.Interface] interface for +// *ProfileDB. +func (db *ProfileDB) ProfileByDedicatedIP( + ctx context.Context, + ip netip.Addr, +) (p *agd.Profile, d *agd.Device, err error) { + return db.OnProfileByDedicatedIP(ctx, ip) +} + +// ProfileByLinkedIP implements the [profiledb.Interface] interface for +// *ProfileDB. +func (db *ProfileDB) ProfileByLinkedIP( + ctx context.Context, + ip netip.Addr, +) (p *agd.Profile, d *agd.Device, err error) { + return db.OnProfileByLinkedIP(ctx, ip) +} + +// type check +var _ profiledb.Storage = (*ProfileStorage)(nil) + +// ProfileStorage is a profiledb.Storage for tests. +type ProfileStorage struct { + OnProfiles func( + ctx context.Context, + req *profiledb.StorageRequest, + ) (resp *profiledb.StorageResponse, err error) +} + +// Profiles implements the [profiledb.Storage] interface for *ProfileStorage. +func (s *ProfileStorage) Profiles( + ctx context.Context, + req *profiledb.StorageRequest, +) (resp *profiledb.StorageResponse, err error) { + return s.OnProfiles(ctx, req) +} + // Package querylog // type check @@ -327,6 +497,39 @@ func (s *RuleStat) Collect(ctx context.Context, id agd.FilterListID, text agd.Fi // Module dnsserver +// Package netext + +var _ netext.ListenConfig = (*ListenConfig)(nil) + +// ListenConfig is a [netext.ListenConfig] for tests. +type ListenConfig struct { + OnListen func(ctx context.Context, network, address string) (l net.Listener, err error) + OnListenPacket func( + ctx context.Context, + network string, + address string, + ) (conn net.PacketConn, err error) +} + +// Listen implements the [netext.ListenConfig] interface for *ListenConfig. +func (c *ListenConfig) Listen( + ctx context.Context, + network string, + address string, +) (l net.Listener, err error) { + return c.OnListen(ctx, network, address) +} + +// ListenPacket implements the [netext.ListenConfig] interface for +// *ListenConfig. +func (c *ListenConfig) ListenPacket( + ctx context.Context, + network string, + address string, +) (conn net.PacketConn, err error) { + return c.OnListenPacket(ctx, network, address) +} + // Package ratelimit // type check diff --git a/internal/agdtime/agdtime.go b/internal/agdtime/agdtime.go new file mode 100644 index 0000000..c0da4c8 --- /dev/null +++ b/internal/agdtime/agdtime.go @@ -0,0 +1,63 @@ +// Package agdtime contains time-related utilities. +package agdtime + +import ( + "encoding" + "time" + + "github.com/AdguardTeam/golibs/errors" +) + +// Location is a wrapper around time.Location that can de/serialize itself from +// and to JSON. +// +// TODO(a.garipov): Move to timeutil. +type Location struct { + time.Location +} + +// LoadLocation is a wrapper around [time.LoadLocation] that returns a +// *Location instead. +func LoadLocation(name string) (l *Location, err error) { + tl, err := time.LoadLocation(name) + if err != nil { + // Don't wrap the error, because this function is a wrapper. + return nil, err + } + + return &Location{ + Location: *tl, + }, nil +} + +// UTC returns [time.UTC] as *Location. +func UTC() (l *Location) { + return &Location{ + Location: *time.UTC, + } +} + +// type check +var _ encoding.TextMarshaler = Location{} + +// MarshalText implements the [encoding.TextMarshaler] interface for Location. +func (l Location) MarshalText() (text []byte, err error) { + return []byte(l.String()), nil +} + +var _ encoding.TextUnmarshaler = (*Location)(nil) + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface for +// *Location. +func (l *Location) UnmarshalText(b []byte) (err error) { + defer func() { err = errors.Annotate(err, "unmarshaling location: %w") }() + + tl, err := time.LoadLocation(string(b)) + if err != nil { + return err + } + + l.Location = *tl + + return nil +} diff --git a/internal/agdtime/agdtime_example_test.go b/internal/agdtime/agdtime_example_test.go new file mode 100644 index 0000000..d160558 --- /dev/null +++ b/internal/agdtime/agdtime_example_test.go @@ -0,0 +1,41 @@ +package agdtime_test + +import ( + "bytes" + "encoding/json" + "fmt" + + "github.com/AdguardTeam/AdGuardDNS/internal/agdtime" +) + +func ExampleLocation() { + var req struct { + TimeZone *agdtime.Location `json:"tmz"` + } + + l, err := agdtime.LoadLocation("Europe/Brussels") + if err != nil { + panic(err) + } + + req.TimeZone = l + buf := &bytes.Buffer{} + err = json.NewEncoder(buf).Encode(req) + if err != nil { + panic(err) + } + + fmt.Print(buf) + + req.TimeZone = nil + err = json.NewDecoder(buf).Decode(&req) + if err != nil { + panic(err) + } + + fmt.Printf("%+v\n", req) + + // Output: + // {"tmz":"Europe/Brussels"} + // {TimeZone:Europe/Brussels} +} diff --git a/internal/backend/profiledb.go b/internal/backend/profiledb.go index e7b7aba..a61544e 100644 --- a/internal/backend/profiledb.go +++ b/internal/backend/profiledb.go @@ -12,10 +12,13 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" + "github.com/AdguardTeam/AdGuardDNS/internal/agdtime" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/timeutil" + "golang.org/x/exp/slices" ) // Profile Storage @@ -47,7 +50,7 @@ func NewProfileStorage(c *ProfileStorageConfig) (s *ProfileStorage) { } } -// ProfileStorage is the implementation of the [agd.ProfileStorage] interface +// ProfileStorage is the implementation of the [profiledb.Storage] interface // that retrieves the profile and device information from the business logic // backend. It is safe for concurrent use. // @@ -61,13 +64,13 @@ type ProfileStorage struct { } // type check -var _ agd.ProfileStorage = (*ProfileStorage)(nil) +var _ profiledb.Storage = (*ProfileStorage)(nil) -// Profiles implements the [agd.ProfileStorage] interface for *ProfileStorage. +// Profiles implements the [profiledb.Storage] interface for *ProfileStorage. func (s *ProfileStorage) Profiles( ctx context.Context, - req *agd.PSProfilesRequest, -) (resp *agd.PSProfilesResponse, err error) { + req *profiledb.StorageRequest, +) (resp *profiledb.StorageResponse, err error) { q := url.Values{} if !req.SyncTime.IsZero() { syncTimeStr := strconv.FormatInt(req.SyncTime.UnixMilli(), 10) @@ -127,7 +130,12 @@ type v1SettingsRespSchedule struct { Friday *[2]timeutil.Duration `json:"fri"` Saturday *[2]timeutil.Duration `json:"sat"` Sunday *[2]timeutil.Duration `json:"sun"` - TimeZone string `json:"tmz"` + + // TimeZone is the tzdata name of the time zone. + // + // NOTE: Do not use *agdtime.Location here so that lookup failures are + // properly mitigated in [v1SettingsRespParental.toInternal]. + TimeZone string `json:"tmz"` } // v1SettingsRespParental is the structure for decoding the settings.*.parental @@ -146,10 +154,11 @@ type v1SettingsRespParental struct { // v1SettingsRespDevice is the structure for decoding the settings.devices // property of the response from the backend. type v1SettingsRespDevice struct { - LinkedIP *netip.Addr `json:"linked_ip"` - ID string `json:"id"` - Name string `json:"name"` - FilteringEnabled bool `json:"filtering_enabled"` + LinkedIP netip.Addr `json:"linked_ip"` + ID string `json:"id"` + Name string `json:"name"` + DedicatedIPs []netip.Addr `json:"dedicated_ips"` + FilteringEnabled bool `json:"filtering_enabled"` } // v1SettingsRespSettings is the structure for decoding the settings property of @@ -232,12 +241,12 @@ func (p *v1SettingsRespParental) toInternal( sch = &agd.ParentalProtectionSchedule{} // TODO(a.garipov): Cache location lookup results. - sch.TimeZone, err = time.LoadLocation(psch.TimeZone) + sch.TimeZone, err = agdtime.LoadLocation(psch.TimeZone) if err != nil { // Report the error and assume UTC. reportf(ctx, errColl, "settings at index %d: schedule: time zone: %w", settIdx, err) - sch.TimeZone = time.UTC + sch.TimeZone = agdtime.UTC() } sch.Week = &agd.WeeklySchedule{} @@ -330,10 +339,10 @@ func devicesToInternal( errColl agd.ErrorCollector, settIdx int, respDevices []*v1SettingsRespDevice, -) (devices []*agd.Device) { +) (devices []*agd.Device, ids []agd.DeviceID) { l := len(respDevices) if l == 0 { - return nil + return nil, nil } devices = make([]*agd.Device, 0, l) @@ -344,8 +353,11 @@ func devicesToInternal( continue } + // TODO(a.garipov): Consider validating uniqueness of linked and + // dedicated IPs. dev := &agd.Device{ LinkedIP: d.LinkedIP, + DedicatedIPs: slices.Clone(d.DedicatedIPs), FilteringEnabled: d.FilteringEnabled, } @@ -368,10 +380,11 @@ func devicesToInternal( continue } + ids = append(ids, dev.ID) devices = append(devices, dev) } - return devices + return devices, ids } // filterListsToInternal is a helper that converts the filter lists from the @@ -458,12 +471,12 @@ func (r *v1SettingsResp) toInternal( // TODO(a.garipov): Here and in other functions, consider just adding the // error collector to the context. errColl agd.ErrorCollector, -) (pr *agd.PSProfilesResponse) { +) (pr *profiledb.StorageResponse) { if r == nil { return nil } - pr = &agd.PSProfilesResponse{ + pr = &profiledb.StorageResponse{ SyncTime: time.Unix(0, r.SyncTime*1_000_000), Profiles: make([]*agd.Profile, 0, len(r.Settings)), } @@ -476,7 +489,7 @@ func (r *v1SettingsResp) toInternal( continue } - devices := devicesToInternal(ctx, errColl, i, s.Devices) + devices, deviceIDs := devicesToInternal(ctx, errColl, i, s.Devices) rlEnabled, ruleLists := filterListsToInternal(ctx, errColl, i, s.RuleLists) rules := rulesToInternal(ctx, errColl, i, s.CustomRules) @@ -499,12 +512,14 @@ func (r *v1SettingsResp) toInternal( sbEnabled := s.SafeBrowsing != nil && s.SafeBrowsing.Enabled + pr.Devices = append(pr.Devices, devices...) + pr.Profiles = append(pr.Profiles, &agd.Profile{ Parental: parental, BlockingMode: s.BlockingMode, ID: id, UpdateTime: updTime, - Devices: devices, + DeviceIDs: deviceIDs, RuleListIDs: ruleLists, CustomRules: rules, FilteredResponseTTL: fltRespTTL, diff --git a/internal/backend/profiledb_test.go b/internal/backend/profiledb_test.go index 521d234..6abcb22 100644 --- a/internal/backend/profiledb_test.go +++ b/internal/backend/profiledb_test.go @@ -13,8 +13,10 @@ import ( "time" "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/agdtime" "github.com/AdguardTeam/AdGuardDNS/internal/backend" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -61,7 +63,7 @@ func TestProfileStorage_Profiles(t *testing.T) { require.NotNil(t, ds) ctx := context.Background() - req := &agd.PSProfilesRequest{ + req := &profiledb.StorageRequest{ SyncTime: syncTime, } @@ -81,10 +83,10 @@ func TestProfileStorage_Profiles(t *testing.T) { // testProfileResp returns profile resp corresponding with testdata. // // Keep in sync with the testdata one. -func testProfileResp(t *testing.T) *agd.PSProfilesResponse { +func testProfileResp(t *testing.T) (resp *profiledb.StorageResponse) { t.Helper() - wantLoc, err := time.LoadLocation("GMT") + wantLoc, err := agdtime.LoadLocation("GMT") require.NoError(t, err) dayRange := agd.DayRange{ @@ -121,7 +123,7 @@ func testProfileResp(t *testing.T) *agd.PSProfilesResponse { }, } - want := &agd.PSProfilesResponse{ + want := &profiledb.StorageResponse{ SyncTime: syncTime, Profiles: []*agd.Profile{{ Parental: nil, @@ -130,15 +132,10 @@ func testProfileResp(t *testing.T) *agd.PSProfilesResponse { }, ID: "37f97ee9", UpdateTime: updTime, - Devices: []*agd.Device{{ - ID: "118ffe93", - Name: "Device 1", - FilteringEnabled: true, - }, { - ID: "b9e1a762", - Name: "Device 2", - FilteringEnabled: true, - }}, + DeviceIDs: []agd.DeviceID{ + "118ffe93", + "b9e1a762", + }, RuleListIDs: []agd.FilterListID{"1"}, CustomRules: nil, FilteredResponseTTL: 10 * time.Second, @@ -154,24 +151,12 @@ func testProfileResp(t *testing.T) *agd.PSProfilesResponse { BlockingMode: wantBlockingMode, ID: "83f3ea8f", UpdateTime: updTime, - Devices: []*agd.Device{{ - ID: "0d7724fa", - Name: "Device 1", - FilteringEnabled: true, - }, { - ID: "6d2ac775", - Name: "Device 2", - FilteringEnabled: true, - }, { - ID: "94d4c481", - Name: "Device 3", - FilteringEnabled: true, - }, { - ID: "ada436e3", - LinkedIP: &wantLinkedIP, - Name: "Device 4", - FilteringEnabled: true, - }}, + DeviceIDs: []agd.DeviceID{ + "0d7724fa", + "6d2ac775", + "94d4c481", + "ada436e3", + }, RuleListIDs: []agd.FilterListID{"1"}, CustomRules: []agd.FilterRuleText{"||example.org^"}, FilteredResponseTTL: 3600 * time.Second, @@ -183,6 +168,35 @@ func testProfileResp(t *testing.T) *agd.PSProfilesResponse { BlockPrivateRelay: false, BlockFirefoxCanary: false, }}, + Devices: []*agd.Device{{ + ID: "118ffe93", + Name: "Device 1", + FilteringEnabled: true, + }, { + ID: "b9e1a762", + Name: "Device 2", + FilteringEnabled: true, + }, { + ID: "0d7724fa", + Name: "Device 1", + FilteringEnabled: true, + }, { + ID: "6d2ac775", + Name: "Device 2", + FilteringEnabled: true, + }, { + ID: "94d4c481", + Name: "Device 3", + DedicatedIPs: []netip.Addr{ + netip.MustParseAddr("1.2.3.4"), + }, + FilteringEnabled: true, + }, { + ID: "ada436e3", + LinkedIP: wantLinkedIP, + Name: "Device 4", + FilteringEnabled: true, + }}, } return want diff --git a/internal/backend/testdata/profiles.json b/internal/backend/testdata/profiles.json index 6a937f8..26f3dee 100644 --- a/internal/backend/testdata/profiles.json +++ b/internal/backend/testdata/profiles.json @@ -50,18 +50,23 @@ { "id": "6d2ac775", "name": "Device 2", + "linked_ip": null, "filtering_enabled": true }, { "id": "94d4c481", "name": "Device 3", + "linked_ip": "", + "dedicated_ips": [ + "1.2.3.4" + ], "filtering_enabled": true }, { "id": "ada436e3", "name": "Device 4", - "filtering_enabled": true, - "linked_ip": "1.2.3.4" + "linked_ip": "1.2.3.4", + "filtering_enabled": true } ], "parental": { diff --git a/internal/bindtodevice/bindtodevice.go b/internal/bindtodevice/bindtodevice.go index d1aaa5d..41c19c2 100644 --- a/internal/bindtodevice/bindtodevice.go +++ b/internal/bindtodevice/bindtodevice.go @@ -1,40 +1,6 @@ // Package bindtodevice contains an implementation of the [netext.ListenConfig] // interface that uses Linux's SO_BINDTODEVICE socket option to be able to bind // to a device. -// -// TODO(a.garipov): Finish the package. The current plan is to eventually have -// something like this: -// -// mgr, err := bindtodevice.New() -// err := mgr.Add("wlp3s0_plain_dns", "wlp3s0", 53) -// subnet := netip.MustParsePrefix("1.2.3.0/24") -// lc, err := mgr.ListenConfig("wlp3s0_plain_dns", subnet) -// err := mgr.Start() -// -// Approximate YAML configuration example: -// -// 'interface_listeners': -// # Put listeners into a list so that there is space for future additional -// # settings, such as timeouts and buffer sizes. -// 'list': -// 'iface0_plain_dns': -// 'interface': 'iface0' -// 'port': 53 -// 'iface0_plain_dns_secondary': -// 'interface': 'iface0' -// 'port': 5353 -// # … -// # … -// 'server_groups': -// # … -// 'servers': -// - 'name': 'default_dns' -// # … -// bind_interfaces: -// - 'id': 'iface0_plain_dns' -// 'subnet': '1.2.3.0/24' -// - 'id': 'iface0_plain_dns_secondary' -// 'subnet': '1.2.3.0/24' package bindtodevice import ( diff --git a/internal/bindtodevice/bindtodevice_internal_test.go b/internal/bindtodevice/bindtodevice_internal_test.go index 1925747..d79c0b3 100644 --- a/internal/bindtodevice/bindtodevice_internal_test.go +++ b/internal/bindtodevice/bindtodevice_internal_test.go @@ -2,9 +2,13 @@ package bindtodevice import ( "net" + "net/netip" "time" ) +// Common timeout for tests +const testTimeout = 1 * time.Second + // Common addresses for tests. var ( testLAddr = &net.UDPAddr{ @@ -17,5 +21,8 @@ var ( } ) -// Common timeout for tests -const testTimeout = 1 * time.Second +// Common subnets for tests. +var ( + testSubnetIPv4 = netip.MustParsePrefix("1.2.3.0/24") + testSubnetIPv6 = netip.MustParsePrefix("1234:5678::/64") +) diff --git a/internal/bindtodevice/bindtodevice_test.go b/internal/bindtodevice/bindtodevice_test.go index 7f70d8c..f138b03 100644 --- a/internal/bindtodevice/bindtodevice_test.go +++ b/internal/bindtodevice/bindtodevice_test.go @@ -1,6 +1,16 @@ package bindtodevice_test -import "github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice" +import ( + "net/netip" + "testing" + + "github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice" + "github.com/AdguardTeam/golibs/testutil" +) + +func TestMain(m *testing.M) { + testutil.DiscardLogOutput(m) +} // Common interface listener IDs for tests const ( @@ -18,3 +28,6 @@ const ( // testIfaceName is the common network interface name for tests. const testIfaceName = "not_a_real_iface0" + +// testSubnetIPv4 is a common subnet for tests. +var testSubnetIPv4 = netip.MustParsePrefix("1.2.3.0/24") diff --git a/internal/bindtodevice/chanlistenconfig_linux_internal_test.go b/internal/bindtodevice/chanlistenconfig_linux_internal_test.go index 3fb0639..c00a924 100644 --- a/internal/bindtodevice/chanlistenconfig_linux_internal_test.go +++ b/internal/bindtodevice/chanlistenconfig_linux_internal_test.go @@ -11,8 +11,8 @@ import ( ) func TestChanListenConfig(t *testing.T) { - pc := newChanPacketConn(nil, nil, testLAddr) - lsnr := newChanListener(nil, testLAddr) + pc := newChanPacketConn(nil, testSubnetIPv4, nil, testLAddr) + lsnr := newChanListener(nil, testSubnetIPv4, testLAddr) c := chanListenConfig{ packetConn: pc, listener: lsnr, diff --git a/internal/bindtodevice/chanlistener_linux.go b/internal/bindtodevice/chanlistener_linux.go index 433e63b..503dedc 100644 --- a/internal/bindtodevice/chanlistener_linux.go +++ b/internal/bindtodevice/chanlistener_linux.go @@ -4,6 +4,7 @@ package bindtodevice import ( "net" + "net/netip" "sync" ) @@ -13,17 +14,22 @@ import ( // Listeners of this type are returned by [chanListenConfig.Listen] and are used // in module dnsserver to make the bind-to-device logic work in DNS-over-TCP. type chanListener struct { - closeOnce *sync.Once - conns chan net.Conn - laddr net.Addr + // mu protects conns (against closure) and isClosed. + mu *sync.Mutex + conns chan net.Conn + laddr net.Addr + subnet netip.Prefix + isClosed bool } // newChanListener returns a new properly initialized *chanListener. -func newChanListener(conns chan net.Conn, laddr net.Addr) (l *chanListener) { +func newChanListener(conns chan net.Conn, subnet netip.Prefix, laddr net.Addr) (l *chanListener) { return &chanListener{ - closeOnce: &sync.Once{}, - conns: conns, - laddr: laddr, + mu: &sync.Mutex{}, + conns: conns, + laddr: laddr, + subnet: subnet, + isClosed: false, } } @@ -46,15 +52,30 @@ func (l *chanListener) Addr() (addr net.Addr) { return l.laddr } // Close implements the [net.Listener] interface for *chanListener. func (l *chanListener) Close() (err error) { - closedNow := false - l.closeOnce.Do(func() { - close(l.conns) - closedNow = true - }) + l.mu.Lock() + defer l.mu.Unlock() - if !closedNow { + if l.isClosed { return wrapConnError(tnChanLsnr, "Close", l.laddr, net.ErrClosed) } + close(l.conns) + l.isClosed = true + return nil } + +// send is a helper method to send a conn to the listener's channel. ok is +// false if the listener is closed. +func (l *chanListener) send(conn net.Conn) (ok bool) { + l.mu.Lock() + defer l.mu.Unlock() + + if l.isClosed { + return false + } + + l.conns <- conn + + return true +} diff --git a/internal/bindtodevice/chanlistener_linux_internal_test.go b/internal/bindtodevice/chanlistener_linux_internal_test.go index da85c27..7ed4086 100644 --- a/internal/bindtodevice/chanlistener_linux_internal_test.go +++ b/internal/bindtodevice/chanlistener_linux_internal_test.go @@ -12,7 +12,7 @@ import ( func TestChanListener_Accept(t *testing.T) { conns := make(chan net.Conn, 1) - l := newChanListener(conns, testLAddr) + l := newChanListener(conns, testSubnetIPv4, testLAddr) // A simple way to have a distinct net.Conn without actually implementing // the entire interface. @@ -32,14 +32,14 @@ func TestChanListener_Accept(t *testing.T) { } func TestChanListener_Addr(t *testing.T) { - l := newChanListener(nil, testLAddr) + l := newChanListener(nil, testSubnetIPv4, testLAddr) got := l.Addr() assert.Equal(t, testLAddr, got) } func TestChanListener_Close(t *testing.T) { conns := make(chan net.Conn) - l := newChanListener(conns, testLAddr) + l := newChanListener(conns, testSubnetIPv4, testLAddr) err := l.Close() assert.NoError(t, err) diff --git a/internal/bindtodevice/chanpacketconn_linux.go b/internal/bindtodevice/chanpacketconn_linux.go index dac6e09..94e3051 100644 --- a/internal/bindtodevice/chanpacketconn_linux.go +++ b/internal/bindtodevice/chanpacketconn_linux.go @@ -5,6 +5,7 @@ package bindtodevice import ( "fmt" "net" + "net/netip" "os" "sync" "time" @@ -19,32 +20,38 @@ import ( // are used in module dnsserver to make the bind-to-device logic work in // DNS-over-UDP. type chanPacketConn struct { - closeOnce *sync.Once - sessions chan *packetSession - laddr net.Addr + // mu protects sessions (against closure) and isClosed. + mu *sync.Mutex + sessions chan *packetSession + + writeRequests chan *packetConnWriteReq // deadlineMu protects readDeadline and writeDeadline. deadlineMu *sync.RWMutex readDeadline time.Time writeDeadline time.Time - writeRequests chan *packetConnWriteReq + laddr net.Addr + subnet netip.Prefix + isClosed bool } // newChanPacketConn returns a new properly initialized *chanPacketConn. func newChanPacketConn( sessions chan *packetSession, + subnet netip.Prefix, writeRequests chan *packetConnWriteReq, laddr net.Addr, ) (c *chanPacketConn) { return &chanPacketConn{ - closeOnce: &sync.Once{}, - sessions: sessions, - laddr: laddr, + mu: &sync.Mutex{}, + sessions: sessions, + writeRequests: writeRequests, deadlineMu: &sync.RWMutex{}, - writeRequests: writeRequests, + laddr: laddr, + subnet: subnet, } } @@ -70,16 +77,16 @@ var _ netext.SessionPacketConn = (*chanPacketConn)(nil) // Close implements the [netext.SessionPacketConn] interface for // *chanPacketConn. func (c *chanPacketConn) Close() (err error) { - closedNow := false - c.closeOnce.Do(func() { - close(c.sessions) - closedNow = true - }) + c.mu.Lock() + defer c.mu.Unlock() - if !closedNow { + if c.isClosed { return wrapConnError(tnChanPConn, "Close", c.laddr, net.ErrClosed) } + close(c.sessions) + c.isClosed = true + return nil } @@ -289,3 +296,18 @@ func sendWithTimer[T any](ch chan<- T, v T, timerCh <-chan time.Time) (err error return os.ErrDeadlineExceeded } } + +// send is a helper method to send a session to the packet connection's channel. +// ok is false if the listener is closed. +func (c *chanPacketConn) send(sess *packetSession) (ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.isClosed { + return false + } + + c.sessions <- sess + + return true +} diff --git a/internal/bindtodevice/chanpacketconn_linux_internal_test.go b/internal/bindtodevice/chanpacketconn_linux_internal_test.go index 7ee22b1..a917f9a 100644 --- a/internal/bindtodevice/chanpacketconn_linux_internal_test.go +++ b/internal/bindtodevice/chanpacketconn_linux_internal_test.go @@ -14,7 +14,7 @@ import ( func TestChanPacketConn_Close(t *testing.T) { sessions := make(chan *packetSession) - c := newChanPacketConn(sessions, nil, testLAddr) + c := newChanPacketConn(sessions, testSubnetIPv4, nil, testLAddr) err := c.Close() assert.NoError(t, err) @@ -23,14 +23,14 @@ func TestChanPacketConn_Close(t *testing.T) { } func TestChanPacketConn_LocalAddr(t *testing.T) { - c := newChanPacketConn(nil, nil, testLAddr) + c := newChanPacketConn(nil, testSubnetIPv4, nil, testLAddr) got := c.LocalAddr() assert.Equal(t, testLAddr, got) } func TestChanPacketConn_ReadFromSession(t *testing.T) { sessions := make(chan *packetSession, 1) - c := newChanPacketConn(sessions, nil, testLAddr) + c := newChanPacketConn(sessions, testSubnetIPv4, nil, testLAddr) body := []byte("hello") bodyLen := len(body) @@ -79,7 +79,7 @@ func TestChanPacketConn_ReadFromSession(t *testing.T) { func TestChanPacketConn_WriteToSession(t *testing.T) { sessions := make(chan *packetSession, 1) writes := make(chan *packetConnWriteReq, 1) - c := newChanPacketConn(sessions, writes, testLAddr) + c := newChanPacketConn(sessions, testSubnetIPv4, writes, testLAddr) body := []byte("hello") bodyLen := len(body) @@ -148,7 +148,7 @@ func checkWriteReqAndRespond( } func TestChanPacketConn_deadlines(t *testing.T) { - c := newChanPacketConn(nil, nil, testLAddr) + c := newChanPacketConn(nil, testSubnetIPv4, nil, testLAddr) deadline := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) testCases := []struct { diff --git a/internal/bindtodevice/chanindex_linux.go b/internal/bindtodevice/connindex_linux.go similarity index 52% rename from internal/bindtodevice/chanindex_linux.go rename to internal/bindtodevice/connindex_linux.go index 3f4937f..856d376 100644 --- a/internal/bindtodevice/chanindex_linux.go +++ b/internal/bindtodevice/connindex_linux.go @@ -4,33 +4,20 @@ package bindtodevice import ( "fmt" - "net" "net/netip" "golang.org/x/exp/slices" ) -// chanIndex is the data structure that contains the channels, to which the -// [Manager] sends new connections and packets based on their protocol (TCP vs. -// UDP), and subnet. +// connIndex is the data structure that contains the channel listeners and +// packet connections, to which the [Manager] sends new connections and packets +// based on their protocol (TCP vs. UDP), and subnet. // // In both slices a subnet with the largest prefix (the narrowest subnet) is // sorted closer to the beginning. -type chanIndex struct { - packetConns []*indexPacketConn - listeners []*indexListener -} - -// indexPacketConn contains data of a [chanPacketConn] in the index. -type indexPacketConn struct { - channel chan *packetSession - subnet netip.Prefix -} - -// indexListener contains data of a [chanListener] in the index. -type indexListener struct { - channel chan net.Conn - subnet netip.Prefix +type connIndex struct { + packetConns []*chanPacketConn + listeners []*chanListener } // subnetSortsBefore returns true if subnet x sorts before subnet y. @@ -58,26 +45,18 @@ func subnetCompare(x, y netip.Prefix) (cmp int) { } } -// addPacketConnChannel adds the channel to the subnet index. It returns an -// error if there is already one for this subnet. subnet should be masked. +// addPacketConn adds the channel packet connection to the index. It returns an +// error if there is already one for this subnet. c.subnet should be masked. // // TODO(a.garipov): Merge with [addListenerChannel]. -func (idx *chanIndex) addPacketConnChannel( - subnet netip.Prefix, - ch chan *packetSession, -) (err error) { - c := &indexPacketConn{ - channel: ch, - subnet: subnet, - } - - cmpFunc := func(x, y *indexPacketConn) (cmp int) { +func (idx *connIndex) addPacketConn(c *chanPacketConn) (err error) { + cmpFunc := func(x, y *chanPacketConn) (cmp int) { return subnetCompare(x.subnet, y.subnet) } newIdx, ok := slices.BinarySearchFunc(idx.packetConns, c, cmpFunc) if ok { - return fmt.Errorf("packetconn channel for subnet %s already registered", subnet) + return fmt.Errorf("packetconn channel for subnet %s already registered", c.subnet) } // TODO(a.garipov): Consider using a list for idx.packetConns. Currently, @@ -88,23 +67,18 @@ func (idx *chanIndex) addPacketConnChannel( return nil } -// addListenerChannel adds the channel to the subnet index. It returns an error -// if there is already one for this subnet. subnet should be masked. +// addListener adds the channel listener to the index. It returns an error if +// there is already one for this subnet. l.subnet should be masked. // // TODO(a.garipov): Merge with [addPacketConnChannel]. -func (idx *chanIndex) addListenerChannel(subnet netip.Prefix, ch chan net.Conn) (err error) { - l := &indexListener{ - channel: ch, - subnet: subnet, - } - - cmpFunc := func(x, y *indexListener) (cmp int) { +func (idx *connIndex) addListener(l *chanListener) (err error) { + cmpFunc := func(x, y *chanListener) (cmp int) { return subnetCompare(x.subnet, y.subnet) } newIdx, ok := slices.BinarySearchFunc(idx.listeners, l, cmpFunc) if ok { - return fmt.Errorf("listener channel for subnet %s already registered", subnet) + return fmt.Errorf("listener channel for subnet %s already registered", l.subnet) } // TODO(a.garipov): Consider using a list for idx.listeners. Currently, @@ -115,24 +89,24 @@ func (idx *chanIndex) addListenerChannel(subnet netip.Prefix, ch chan net.Conn) return nil } -// packetConnChannel returns a packet-connection channel which accepts -// connections to local address laddr or nil if there is no such channel -func (idx *chanIndex) packetConnChannel(laddr netip.Addr) (ch chan *packetSession) { - for _, c := range idx.packetConns { +// packetConn returns a channel packet connection which accepts connections to +// local address laddr or nil if there is no such channel +func (idx *connIndex) packetConn(laddr netip.Addr) (c *chanPacketConn) { + for _, c = range idx.packetConns { if c.subnet.Contains(laddr) { - return c.channel + return c } } return nil } -// listenerChannel returns a listener channel which accepts connections to local +// listener returns a channel listener which accepts connections to local // address laddr or nil if there is no such channel -func (idx *chanIndex) listenerChannel(laddr netip.Addr) (ch chan net.Conn) { - for _, l := range idx.listeners { +func (idx *connIndex) listener(laddr netip.Addr) (l *chanListener) { + for _, l = range idx.listeners { if l.subnet.Contains(laddr) { - return l.channel + return l } } diff --git a/internal/bindtodevice/chanindex_linux_test.go b/internal/bindtodevice/connindex_linux_test.go similarity index 100% rename from internal/bindtodevice/chanindex_linux_test.go rename to internal/bindtodevice/connindex_linux_test.go diff --git a/internal/bindtodevice/interfacelistener_linux.go b/internal/bindtodevice/interfacelistener_linux.go index 197b357..0d2f33e 100644 --- a/internal/bindtodevice/interfacelistener_linux.go +++ b/internal/bindtodevice/interfacelistener_linux.go @@ -17,7 +17,7 @@ import ( // interfaceListener contains information about a single interface listener. type interfaceListener struct { - channels *chanIndex + conns *connIndex writeRequests chan *packetConnWriteReq done chan unit listenConf *net.ListenConfig @@ -64,14 +64,16 @@ func (l *interfaceListener) listenTCP(errCh chan<- error) { } laddr := netutil.NetAddrToAddrPort(conn.LocalAddr()) - ch := l.channels.listenerChannel(laddr.Addr()) - if ch == nil { + lsnr := l.conns.listener(laddr.Addr()) + if lsnr == nil { log.Info("%s: no channel for laddr %s", logPrefix, laddr) continue } - ch <- conn + if !lsnr.send(conn) { + log.Info("%s: channel for laddr %s is closed", logPrefix, laddr) + } } } @@ -120,14 +122,16 @@ func (l *interfaceListener) listenUDP(errCh chan<- error) { } laddr := sess.laddr.AddrPort().Addr() - ch := l.channels.packetConnChannel(laddr) - if ch == nil { + chanPConn := l.conns.packetConn(laddr) + if chanPConn == nil { log.Info("%s: no channel for laddr %s", logPrefix, laddr) continue } - ch <- sess + if !chanPConn.send(sess) { + log.Info("%s: channel for laddr %s is closed", logPrefix, laddr) + } } } diff --git a/internal/bindtodevice/interfacestorage.go b/internal/bindtodevice/interfacestorage.go new file mode 100644 index 0000000..58fcd76 --- /dev/null +++ b/internal/bindtodevice/interfacestorage.go @@ -0,0 +1,77 @@ +package bindtodevice + +import ( + "fmt" + "net" + "net/netip" + + "github.com/AdguardTeam/golibs/netutil" +) + +// NetInterface represents a network interface (aka device). +// +// TODO(a.garipov): Consider moving this and InterfaceStorage to netutil. +type NetInterface interface { + Subnets() (subnets []netip.Prefix, err error) +} + +// type check +var _ NetInterface = osInterface{} + +// osInterface is a wapper around [*net.Interface] that implements the +// [NetInterface] interface. +type osInterface struct { + iface *net.Interface +} + +// Subnets implements the [NetInterface] interface for osInterface. +func (osIface osInterface) Subnets() (subnets []netip.Prefix, err error) { + name := osIface.iface.Name + ifaceAddrs, err := osIface.iface.Addrs() + if err != nil { + return nil, fmt.Errorf("getting addrs for interface %s: %w", name, err) + } + + subnets = make([]netip.Prefix, 0, len(ifaceAddrs)) + for _, addr := range ifaceAddrs { + ipNet, ok := addr.(*net.IPNet) + if !ok { + return nil, fmt.Errorf("addr for interface %s is %T, not *net.IPNet", name, addr) + } + + var subnet netip.Prefix + subnet, err = netutil.IPNetToPrefixNoMapped(ipNet) + if err != nil { + return nil, fmt.Errorf("converting addr for interface %s: %w", name, err) + } + + subnets = append(subnets, subnet) + } + + return subnets, nil +} + +// InterfaceStorage is the interface for storages of network interfaces (aka +// devices). Its main implementation is [DefaultInterfaceStorage]. +type InterfaceStorage interface { + InterfaceByName(name string) (iface NetInterface, err error) +} + +// type check +var _ InterfaceStorage = DefaultInterfaceStorage{} + +// DefaultInterfaceStorage is the storage that uses the OS's network interfaces. +type DefaultInterfaceStorage struct{} + +// InterfaceByName implements the [InterfaceStorage] interface for +// DefaultInterfaceStorage. +func (DefaultInterfaceStorage) InterfaceByName(name string) (iface NetInterface, err error) { + netIface, err := net.InterfaceByName(name) + if err != nil { + return nil, fmt.Errorf("looking up interface %s: %w", name, err) + } + + return &osInterface{ + iface: netIface, + }, nil +} diff --git a/internal/bindtodevice/manager.go b/internal/bindtodevice/manager.go index 3eebd4c..598ab39 100644 --- a/internal/bindtodevice/manager.go +++ b/internal/bindtodevice/manager.go @@ -5,6 +5,10 @@ import "github.com/AdguardTeam/AdGuardDNS/internal/agd" // ManagerConfig is the configuration structure for [NewManager]. All fields // must be set. type ManagerConfig struct { + // InterfaceStorage is used to get the information about the system's + // network interfaces. Normally, this is [DefaultInterfaceStorage]. + InterfaceStorage InterfaceStorage + // ErrColl is the error collector that is used to collect non-critical // errors. ErrColl agd.ErrorCollector @@ -13,3 +17,14 @@ type ManagerConfig struct { // dispatch TCP connections and UDP sessions. ChannelBufferSize int } + +// ControlConfig is the configuration of socket options. +type ControlConfig struct { + // RcvBufSize defines the size of socket receive buffer in bytes. Default + // is zero (uses system settings). + RcvBufSize int + + // SndBufSize defines the size of socket send buffer in bytes. Default is + // zero (uses system settings). + SndBufSize int +} diff --git a/internal/bindtodevice/manager_linux.go b/internal/bindtodevice/manager_linux.go index 13c79fb..567ad5f 100644 --- a/internal/bindtodevice/manager_linux.go +++ b/internal/bindtodevice/manager_linux.go @@ -18,6 +18,7 @@ import ( // Manager creates individual listeners and dispatches connections to them. type Manager struct { + interfaces InterfaceStorage closeOnce *sync.Once ifaceListeners map[ID]*interfaceListener errColl agd.ErrorCollector @@ -28,6 +29,7 @@ type Manager struct { // NewManager returns a new manager of interface listeners. func NewManager(c *ManagerConfig) (m *Manager) { return &Manager{ + interfaces: c.InterfaceStorage, closeOnce: &sync.Once{}, ifaceListeners: map[ID]*interfaceListener{}, errColl: c.ErrColl, @@ -36,12 +38,25 @@ func NewManager(c *ManagerConfig) (m *Manager) { } } -// Add creates a new interface-listener record in m. +// defaultCtrlConf is the default control config. By default, don't alter +// anything. defaultCtrlConf must not be mutated. +var defaultCtrlConf = &ControlConfig{ + RcvBufSize: 0, + SndBufSize: 0, +} + +// Add creates a new interface-listener record in m. If conf is nil, a default +// configuration is used. // // Add must not be called after Start is called. -func (m *Manager) Add(id ID, ifaceName string, port uint16) (err error) { +func (m *Manager) Add(id ID, ifaceName string, port uint16, conf *ControlConfig) (err error) { defer func() { err = errors.Annotate(err, "adding interface listener with id %q: %w", id) }() + _, err = m.interfaces.InterfaceByName(ifaceName) + if err != nil { + return fmt.Errorf("looking up interface %q: %w", ifaceName, err) + } + validateDup := func(lsnrID ID, lsnr *interfaceListener) (lsnrErr error) { lsnrIfaceName, lsnrPort := lsnr.ifaceName, lsnr.port if lsnrID == id { @@ -71,11 +86,15 @@ func (m *Manager) Add(id ID, ifaceName string, port uint16) (err error) { return err } + if conf == nil { + conf = defaultCtrlConf + } + m.ifaceListeners[id] = &interfaceListener{ - channels: &chanIndex{}, + conns: &connIndex{}, writeRequests: make(chan *packetConnWriteReq, m.chanBufSize), done: m.done, - listenConf: newListenConfig(ifaceName), + listenConf: newListenConfig(ifaceName, conf), errColl: m.errColl, ifaceName: ifaceName, port: port, @@ -90,53 +109,96 @@ func (m *Manager) Add(id ID, ifaceName string, port uint16) (err error) { // // ListenConfig must not be called after Start is called. func (m *Manager) ListenConfig(id ID, subnet netip.Prefix) (c netext.ListenConfig, err error) { - if masked := subnet.Masked(); subnet != masked { - return nil, fmt.Errorf( - "subnet %s for interface listener %q not masked (expected %s)", + defer func() { + err = errors.Annotate( + err, + "creating listen config for subnet %s and listener with id %q: %w", subnet, id, - masked, ) - } - + }() l, ok := m.ifaceListeners[id] if !ok { - return nil, fmt.Errorf("no listener for interface %q", id) + return nil, errors.Error("no interface listener found") } - connCh := make(chan net.Conn, m.chanBufSize) - err = l.channels.addListenerChannel(subnet, connCh) + err = m.validateIfaceSubnet(l.ifaceName, subnet) if err != nil { - return nil, fmt.Errorf("adding tcp conn channel: %w", err) + // Don't wrap the error, because it's informative enough as is. + return nil, err + } + + lsnrCh := make(chan net.Conn, m.chanBufSize) + lsnr := newChanListener(lsnrCh, subnet, &prefixNetAddr{ + prefix: subnet, + network: "tcp", + port: l.port, + }) + + err = l.conns.addListener(lsnr) + if err != nil { + return nil, fmt.Errorf("adding tcp conn: %w", err) } sessCh := make(chan *packetSession, m.chanBufSize) - err = l.channels.addPacketConnChannel(subnet, sessCh) + pConn := newChanPacketConn(sessCh, subnet, l.writeRequests, &prefixNetAddr{ + prefix: subnet, + network: "udp", + port: l.port, + }) + + err = l.conns.addPacketConn(pConn) if err != nil { // Technically shouldn't happen, since [chanIndex.addListenerChannel] // has already checked for duplicates. - return nil, fmt.Errorf("adding udp conn channel: %w", err) + return nil, fmt.Errorf("adding udp conn: %w", err) } return &chanListenConfig{ - packetConn: newChanPacketConn(sessCh, l.writeRequests, &prefixNetAddr{ - prefix: subnet, - network: "udp", - port: l.port, - }), - listener: newChanListener(connCh, &prefixNetAddr{ - prefix: subnet, - network: "tcp", - port: l.port, - }), + packetConn: pConn, + listener: lsnr, }, nil } +// validateIfaceSubnet validates the interface with the name ifaceName exists +// and that it can accept addresses from subnet. +func (m *Manager) validateIfaceSubnet(ifaceName string, subnet netip.Prefix) (err error) { + if masked := subnet.Masked(); subnet != masked { + return fmt.Errorf("subnet not masked (expected %s)", masked) + } + + iface, err := m.interfaces.InterfaceByName(ifaceName) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } + + ifaceSubnets, err := iface.Subnets() + if err != nil { + return fmt.Errorf("getting subnets: %w", err) + } + + for _, s := range ifaceSubnets { + if s.Contains(subnet.Addr()) && s.Bits() <= subnet.Bits() { + return nil + } + } + + return fmt.Errorf("interface %s does not contain subnet %s", ifaceName, subnet) +} + // type check var _ agd.Service = (*Manager)(nil) -// Start implements the [agd.Service] interface for *Manager. +// Start implements the [agd.Service] interface for *Manager. If m is nil, +// Start returns nil, since this feature is optional. +// +// TODO(a.garipov): Consider an interface solution. func (m *Manager) Start() (err error) { + if m == nil { + return nil + } + numListen := 2 * len(m.ifaceListeners) errCh := make(chan error, numListen) @@ -162,10 +224,17 @@ func (m *Manager) Start() (err error) { return nil } -// Shutdown implements the [agd.Service] interface for *Manager. +// Shutdown implements the [agd.Service] interface for *Manager. If m is nil, +// Shutdown returns nil, since this feature is optional. +// +// TODO(a.garipov): Consider an interface solution. // // TODO(a.garipov): Consider waiting for all sockets to close. func (m *Manager) Shutdown(_ context.Context) (err error) { + if m == nil { + return nil + } + closedNow := false m.closeOnce.Do(func() { close(m.done) diff --git a/internal/bindtodevice/manager_linux_test.go b/internal/bindtodevice/manager_linux_test.go index e33ea39..e7531d8 100644 --- a/internal/bindtodevice/manager_linux_test.go +++ b/internal/bindtodevice/manager_linux_test.go @@ -18,12 +18,46 @@ import ( // TODO(a.garipov): Add tests for other platforms? +// type check +var _ bindtodevice.InterfaceStorage = (*fakeInterfaceStorage)(nil) + +// fakeInterfaceStorage is a fake [bindtodevice.InterfaceStorage] for tests. +type fakeInterfaceStorage struct { + OnInterfaceByName func(name string) (iface bindtodevice.NetInterface, err error) +} + +// InterfaceByName implements the [bindtodevice.InterfaceStorage] interface +// for *fakeInterfaceStorage. +func (s *fakeInterfaceStorage) InterfaceByName( + name string, +) (iface bindtodevice.NetInterface, err error) { + return s.OnInterfaceByName(name) +} + +// type check +var _ bindtodevice.NetInterface = (*fakeInterface)(nil) + +// fakeInterface is a fake [bindtodevice.Interface] for tests. +type fakeInterface struct { + OnSubnets func() (subnets []netip.Prefix, err error) +} + +// Subnets implements the [bindtodevice.Interface] interface for *fakeInterface. +func (iface *fakeInterface) Subnets() (subnets []netip.Prefix, err error) { + return iface.OnSubnets() +} + func TestManager_Add(t *testing.T) { errColl := &agdtest.ErrorCollector{ OnCollect: func(_ context.Context, _ error) { panic("not implemented") }, } m := bindtodevice.NewManager(&bindtodevice.ManagerConfig{ + InterfaceStorage: &fakeInterfaceStorage{ + OnInterfaceByName: func(_ string) (iface bindtodevice.NetInterface, err error) { + return nil, nil + }, + }, ErrColl: errColl, ChannelBufferSize: 1, }) @@ -32,22 +66,22 @@ func TestManager_Add(t *testing.T) { // Don't use a table, since the results of these subtests depend on each // other. t.Run("success", func(t *testing.T) { - err := m.Add(testID1, testIfaceName, testPort1) + err := m.Add(testID1, testIfaceName, testPort1, nil) assert.NoError(t, err) }) t.Run("dup_id", func(t *testing.T) { - err := m.Add(testID1, testIfaceName, testPort1) + err := m.Add(testID1, testIfaceName, testPort1, nil) assert.Error(t, err) }) t.Run("dup_iface_port", func(t *testing.T) { - err := m.Add(testID2, testIfaceName, testPort1) + err := m.Add(testID2, testIfaceName, testPort1, nil) assert.Error(t, err) }) t.Run("success_other", func(t *testing.T) { - err := m.Add(testID2, testIfaceName, testPort2) + err := m.Add(testID2, testIfaceName, testPort2, nil) assert.NoError(t, err) }) } @@ -57,17 +91,27 @@ func TestManager_ListenConfig(t *testing.T) { OnCollect: func(_ context.Context, _ error) { panic("not implemented") }, } + subnet := testSubnetIPv4 + ifaceWithSubnet := &fakeInterface{ + OnSubnets: func() (subnets []netip.Prefix, err error) { + return []netip.Prefix{subnet}, nil + }, + } + m := bindtodevice.NewManager(&bindtodevice.ManagerConfig{ + InterfaceStorage: &fakeInterfaceStorage{ + OnInterfaceByName: func(_ string) (iface bindtodevice.NetInterface, err error) { + return ifaceWithSubnet, nil + }, + }, ErrColl: errColl, ChannelBufferSize: 1, }) require.NotNil(t, m) - err := m.Add(testID1, testIfaceName, testPort1) + err := m.Add(testID1, testIfaceName, testPort1, nil) require.NoError(t, err) - subnet := netip.MustParsePrefix("1.2.3.0/24") - // Don't use a table, since the results of these subtests depend on each // other. t.Run("not_found", func(t *testing.T) { @@ -94,6 +138,60 @@ func TestManager_ListenConfig(t *testing.T) { assert.Nil(t, lc) assert.Error(t, lcErr) }) + + t.Run("no_subnet", func(t *testing.T) { + ifaceWithoutSubnet := &fakeInterface{ + OnSubnets: func() (subnets []netip.Prefix, err error) { + return nil, nil + }, + } + + noSubnetMgr := bindtodevice.NewManager(&bindtodevice.ManagerConfig{ + InterfaceStorage: &fakeInterfaceStorage{ + OnInterfaceByName: func(_ string) (iface bindtodevice.NetInterface, err error) { + return ifaceWithoutSubnet, nil + }, + }, + ErrColl: errColl, + ChannelBufferSize: 1, + }) + require.NotNil(t, noSubnetMgr) + + subTestErr := noSubnetMgr.Add(testID1, testIfaceName, testPort1, nil) + require.NoError(t, subTestErr) + + lc, subTestErr := noSubnetMgr.ListenConfig(testID1, subnet) + assert.Nil(t, lc) + assert.Error(t, subTestErr) + }) + + t.Run("narrower_subnet", func(t *testing.T) { + ifaceWithNarrowerSubnet := &fakeInterface{ + OnSubnets: func() (subnets []netip.Prefix, err error) { + narrowerSubnet := netip.PrefixFrom(subnet.Addr(), subnet.Bits()+4) + + return []netip.Prefix{narrowerSubnet}, nil + }, + } + + narrowSubnetMgr := bindtodevice.NewManager(&bindtodevice.ManagerConfig{ + InterfaceStorage: &fakeInterfaceStorage{ + OnInterfaceByName: func(_ string) (iface bindtodevice.NetInterface, err error) { + return ifaceWithNarrowerSubnet, nil + }, + }, + ErrColl: errColl, + ChannelBufferSize: 1, + }) + require.NotNil(t, narrowSubnetMgr) + + subTestErr := narrowSubnetMgr.Add(testID1, testIfaceName, testPort1, nil) + require.NoError(t, subTestErr) + + lc, subTestErr := narrowSubnetMgr.ListenConfig(testID1, subnet) + assert.Nil(t, lc) + assert.Error(t, subTestErr) + }) } func TestManager(t *testing.T) { @@ -113,13 +211,14 @@ func TestManager(t *testing.T) { } m := bindtodevice.NewManager(&bindtodevice.ManagerConfig{ + InterfaceStorage: bindtodevice.DefaultInterfaceStorage{}, ErrColl: errColl, ChannelBufferSize: 1, }) require.NotNil(t, m) // TODO(a.garipov): Add support for zero port. - err := m.Add(testID1, ifaceName, testPort1) + err := m.Add(testID1, ifaceName, testPort1, nil) require.NoError(t, err) subnet, err := netutil.IPNetToPrefixNoMapped(&net.IPNet{ diff --git a/internal/bindtodevice/manager_others.go b/internal/bindtodevice/manager_others.go index 457e481..d31f91b 100644 --- a/internal/bindtodevice/manager_others.go +++ b/internal/bindtodevice/manager_others.go @@ -13,12 +13,12 @@ import ( // Manager creates individual listeners and dispatches connections to them. // -// It is only suported on Linux. +// It is only supported on Linux. type Manager struct{} // NewManager returns a new manager of interface listeners. // -// It is only suported on Linux. +// It is only supported on Linux. func NewManager(c *ManagerConfig) (m *Manager) { return &Manager{} } @@ -29,14 +29,16 @@ const errUnsupported errors.Error = "bindtodevice is only supported on linux" // Add creates a new interface-listener record in m. // -// It is only suported on Linux. -func (m *Manager) Add(id ID, ifaceName string, port uint16) (err error) { return errUnsupported } +// It is only supported on Linux. +func (m *Manager) Add(id ID, ifaceName string, port uint16, cc *ControlConfig) (err error) { + return errUnsupported +} // ListenConfig returns a new netext.ListenConfig that receives connections from // the interface listener with the given id and the destination addresses of // which fall within subnet. subnet should be masked. // -// It is only suported on Linux. +// It is only supported on Linux. func (m *Manager) ListenConfig(id ID, subnet netip.Prefix) (c netext.ListenConfig, err error) { return nil, errUnsupported } @@ -44,12 +46,26 @@ func (m *Manager) ListenConfig(id ID, subnet netip.Prefix) (c netext.ListenConfi // type check var _ agd.Service = (*Manager)(nil) -// Start implements the [agd.Service] interface for *Manager. +// Start implements the [agd.Service] interface for *Manager. If m is nil, +// Start returns nil, since this feature is optional. // -// It is only suported on Linux. -func (m *Manager) Start() (err error) { return errUnsupported } +// It is only supported on Linux. +func (m *Manager) Start() (err error) { + if m == nil { + return nil + } -// Shutdown implements the [agd.Service] interface for *Manager. + return errUnsupported +} + +// Shutdown implements the [agd.Service] interface for *Manager. If m is nil, +// Shutdown returns nil, since this feature is optional. // -// It is only suported on Linux. -func (m *Manager) Shutdown(_ context.Context) (err error) { return errUnsupported } +// It is only supported on Linux. +func (m *Manager) Shutdown(_ context.Context) (err error) { + if m == nil { + return nil + } + + return errUnsupported +} diff --git a/internal/bindtodevice/prefixaddr_linux_internal_test.go b/internal/bindtodevice/prefixaddr_linux_internal_test.go index bc95326..7df5be7 100644 --- a/internal/bindtodevice/prefixaddr_linux_internal_test.go +++ b/internal/bindtodevice/prefixaddr_linux_internal_test.go @@ -3,6 +3,7 @@ package bindtodevice import ( + "fmt" "net/netip" "testing" @@ -11,16 +12,42 @@ import ( func TestPrefixAddr(t *testing.T) { const ( - wantStr = "1.2.3.0:56789/24" + port = 56789 network = "tcp" ) - pa := &prefixNetAddr{ - prefix: netip.MustParsePrefix("1.2.3.0/24"), - network: network, - port: 56789, - } + testCases := []struct { + in *prefixNetAddr + want string + name string + }{{ + in: &prefixNetAddr{ + prefix: testSubnetIPv4, + network: network, + port: port, + }, + want: fmt.Sprintf( + "%s/%d", + netip.AddrPortFrom(testSubnetIPv4.Addr(), port), testSubnetIPv4.Bits(), + ), + name: "ipv4", + }, { + in: &prefixNetAddr{ + prefix: testSubnetIPv6, + network: network, + port: port, + }, + want: fmt.Sprintf( + "%s/%d", + netip.AddrPortFrom(testSubnetIPv6.Addr(), port), testSubnetIPv6.Bits(), + ), + name: "ipv6", + }} - assert.Equal(t, wantStr, pa.String()) - assert.Equal(t, network, pa.Network()) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, tc.in.String()) + assert.Equal(t, network, tc.in.Network()) + }) + } } diff --git a/internal/bindtodevice/socket_linux.go b/internal/bindtodevice/socket_linux.go index 122fa17..ad0e910 100644 --- a/internal/bindtodevice/socket_linux.go +++ b/internal/bindtodevice/socket_linux.go @@ -12,106 +12,100 @@ import ( "golang.org/x/sys/unix" ) -// newListenConfig returns a [net.ListenConfig] that can bind to a network -// interface (device) by its name. -func newListenConfig(devName string) (lc *net.ListenConfig) { - c := &net.ListenConfig{ - Control: func(network, address string, c syscall.RawConn) (err error) { - return listenControl(devName, network, address, c) - }, - } +// setSockOptFunc is a function that sets a socket option on fd. +type setSockOptFunc func(fd int) (err error) - return c +// newIntSetSockOptFunc returns an integer socket-option function with the given +// parameters. +func newIntSetSockOptFunc(name string, lvl, opt, val int) (o setSockOptFunc) { + return func(fd int) (err error) { + opErr := unix.SetsockoptInt(fd, lvl, opt, val) + + return errors.Annotate(opErr, "setting %s: %w", name) + } } -// listenControl is used as a [net.ListenConfig.Control] function to set -// additional socket options, including SO_BINDTODEVICE. -func listenControl(devName, network, _ string, c syscall.RawConn) (err error) { - var ctrlFunc func(fd uintptr, devName string) (err error) +// newStringSetSockOptFunc returns a string socket-option function with the +// given parameters. +func newStringSetSockOptFunc(name string, lvl, opt int, val string) (o setSockOptFunc) { + return func(fd int) (err error) { + opErr := unix.SetsockoptString(fd, lvl, opt, val) + + return errors.Annotate(opErr, "setting %s: %w", name) + } +} + +// newListenConfig returns a [net.ListenConfig] that can bind to a network +// interface (device) by its name. ctrlConf must not be nil. +func newListenConfig(devName string, ctrlConf *ControlConfig) (lc *net.ListenConfig) { + return &net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) (err error) { + return listenControlWithSO(ctrlConf, devName, network, address, c) + }, + } +} + +// listenControlWithSO is used as a [net.ListenConfig.Control] function to set +// additional socket options. +func listenControlWithSO( + ctrlConf *ControlConfig, + devName string, + network string, + _ string, + c syscall.RawConn, +) (err error) { + opts := []setSockOptFunc{ + newStringSetSockOptFunc("SO_BINDTODEVICE", unix.SOL_SOCKET, unix.SO_BINDTODEVICE, devName), + // Use SO_REUSEADDR as well, which is not technically necessary, to + // help with the situation of sockets hanging in CLOSE_WAIT for too + // long. + newIntSetSockOptFunc("SO_REUSEADDR", unix.SOL_SOCKET, unix.SO_REUSEADDR, 1), + newIntSetSockOptFunc("SO_REUSEPORT", unix.SOL_SOCKET, unix.SO_REUSEPORT, 1), + } switch network { case "tcp", "tcp4", "tcp6": - ctrlFunc = setTCPSockOpt + // Socket options for TCP connection already set. Go on. case "udp", "udp4", "udp6": - ctrlFunc = setUDPSockOpt + opts = append( + opts, + newIntSetSockOptFunc("IP_RECVORIGDSTADDR", unix.IPPROTO_IP, unix.IP_RECVORIGDSTADDR, 1), + newIntSetSockOptFunc("IP_FREEBIND", unix.IPPROTO_IP, unix.IP_FREEBIND, 1), + newIntSetSockOptFunc("IPV6_RECVORIGDSTADDR", unix.IPPROTO_IPV6, unix.IPV6_RECVORIGDSTADDR, 1), + newIntSetSockOptFunc("IPV6_FREEBIND", unix.IPPROTO_IPV6, unix.IPV6_FREEBIND, 1), + ) default: return fmt.Errorf("bad network %q", network) } + if ctrlConf.SndBufSize > 0 { + opts = append( + opts, + newIntSetSockOptFunc("SO_SNDBUF", unix.SOL_SOCKET, unix.SO_SNDBUF, ctrlConf.SndBufSize), + ) + } + + if ctrlConf.RcvBufSize > 0 { + opts = append( + opts, + newIntSetSockOptFunc("SO_RCVBUF", unix.SOL_SOCKET, unix.SO_RCVBUF, ctrlConf.RcvBufSize), + ) + } + var opErr error err = c.Control(func(fd uintptr) { - opErr = ctrlFunc(fd, devName) + d := int(fd) + for _, opt := range opts { + opErr = opt(d) + if opErr != nil { + return + } + } }) return errors.WithDeferred(opErr, err) } -// setTCPSockOpt sets the SO_BINDTODEVICE and other socket options for a TCP -// connection. -func setTCPSockOpt(fd uintptr, devName string) (err error) { - defer func() { err = errors.Annotate(err, "setting tcp opts: %w") }() - - fdInt := int(fd) - err = unix.SetsockoptString(fdInt, unix.SOL_SOCKET, unix.SO_BINDTODEVICE, devName) - if err != nil { - return fmt.Errorf("setting SO_BINDTODEVICE: %w", err) - } - - err = unix.SetsockoptInt(fdInt, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) - if err != nil { - return fmt.Errorf("setting SO_REUSEPORT: %w", err) - } - - return nil -} - -// setUDPSockOpt sets the SO_BINDTODEVICE and other socket options for a UDP -// connection. -func setUDPSockOpt(fd uintptr, devName string) (err error) { - defer func() { err = errors.Annotate(err, "setting udp opts: %w") }() - - fdInt := int(fd) - err = unix.SetsockoptString(fdInt, unix.SOL_SOCKET, unix.SO_BINDTODEVICE, devName) - if err != nil { - return fmt.Errorf("setting SO_BINDTODEVICE: %w", err) - } - - intOpts := []struct { - name string - level int - opt int - }{{ - name: "SO_REUSEPORT", - level: unix.SOL_SOCKET, - opt: unix.SO_REUSEPORT, - }, { - name: "IP_RECVORIGDSTADDR", - level: unix.IPPROTO_IP, - opt: unix.IP_RECVORIGDSTADDR, - }, { - name: "IP_FREEBIND", - level: unix.IPPROTO_IP, - opt: unix.IP_FREEBIND, - }, { - name: "IPV6_RECVORIGDSTADDR", - level: unix.IPPROTO_IPV6, - opt: unix.IPV6_RECVORIGDSTADDR, - }, { - name: "IPV6_FREEBIND", - level: unix.IPPROTO_IPV6, - opt: unix.IPV6_FREEBIND, - }} - - for _, o := range intOpts { - err = unix.SetsockoptInt(fdInt, o.level, o.opt, 1) - if err != nil { - return fmt.Errorf("setting %s: %w", o.name, err) - } - } - - return nil -} - // readPacketSession is a helper that reads a packet-session data from a UDP // connection. func readPacketSession(c *net.UDPConn, bodySize int) (sess *packetSession, err error) { @@ -165,8 +159,11 @@ func sockAddrData(sockAddr unix.Sockaddr) (origDstAddr *net.UDPAddr, respOOB []b Port: sockAddr.Port, } + // Set both addresses to make sure that users receive the correct source + // IP address even when virtual interfaces are involved. pktInfo := &unix.Inet4Pktinfo{ - Addr: sockAddr.Addr, + Addr: sockAddr.Addr, + Spec_dst: sockAddr.Addr, } respOOB = unix.PktInfo4(pktInfo) diff --git a/internal/bindtodevice/socket_linux_internal_test.go b/internal/bindtodevice/socket_linux_internal_test.go index c2bd054..bb062f5 100644 --- a/internal/bindtodevice/socket_linux_internal_test.go +++ b/internal/bindtodevice/socket_linux_internal_test.go @@ -9,6 +9,7 @@ import ( "net/netip" "os" "strings" + "syscall" "testing" "time" @@ -18,6 +19,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/exp/slices" + "golang.org/x/sys/unix" ) // TestInterfaceEnvVarName is the environment variable name the presence and @@ -88,7 +90,7 @@ func TestListenControl(t *testing.T) { } ifaceName := iface.Name - lc := newListenConfig(ifaceName) + lc := newListenConfig(ifaceName, &ControlConfig{}) require.NotNil(t, lc) t.Run("tcp", func(t *testing.T) { @@ -380,3 +382,81 @@ func closestIP(t testing.TB, n *net.IPNet, ip net.IP) (closest net.IP) { return nil } + +func TestListenControlWithSO(t *testing.T) { + const ( + sndBufSize = 10000 + rcvBufSize = 20000 + ) + + iface, _ := InterfaceForTests(t) + if iface == nil { + t.Skipf("test %s skipped: please set env var %s", t.Name(), TestInterfaceEnvVarName) + } + + ifaceName := iface.Name + lc := newListenConfig( + ifaceName, + &ControlConfig{ + RcvBufSize: rcvBufSize, + SndBufSize: sndBufSize, + }, + ) + require.NotNil(t, lc) + + type syscallConner interface { + SyscallConn() (c syscall.RawConn, err error) + } + + t.Run("udp", func(t *testing.T) { + c, err := lc.ListenPacket(context.Background(), "udp", "0.0.0.0:0") + require.NoError(t, err) + require.NotNil(t, c) + require.Implements(t, (*syscallConner)(nil), c) + + sc, err := c.(syscallConner).SyscallConn() + require.NoError(t, err) + + err = sc.Control(func(fd uintptr) { + val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF) + require.NoError(t, opErr) + + assert.Equal(t, sndBufSize*2, val) + }) + require.NoError(t, err) + + err = sc.Control(func(fd uintptr) { + val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF) + require.NoError(t, opErr) + + assert.Equal(t, rcvBufSize*2, val) + }) + require.NoError(t, err) + }) + + t.Run("tcp", func(t *testing.T) { + c, err := lc.Listen(context.Background(), "tcp", "0.0.0.0:0") + require.NoError(t, err) + require.NotNil(t, c) + require.Implements(t, (*syscallConner)(nil), c) + + sc, err := c.(syscallConner).SyscallConn() + require.NoError(t, err) + + err = sc.Control(func(fd uintptr) { + val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF) + require.NoError(t, opErr) + + assert.Equal(t, sndBufSize*2, val) + }) + require.NoError(t, err) + + err = sc.Control(func(fd uintptr) { + val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF) + require.NoError(t, opErr) + + assert.Equal(t, rcvBufSize*2, val) + }) + require.NoError(t, err) + }) +} diff --git a/internal/cmd/backend.go b/internal/cmd/backend.go index 78e9fc2..c1592c3 100644 --- a/internal/cmd/backend.go +++ b/internal/cmd/backend.go @@ -8,6 +8,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/backend" "github.com/AdguardTeam/AdGuardDNS/internal/billstat" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/timeutil" ) @@ -76,7 +77,7 @@ func setupBackend( envs *environments, sigHdlr signalHandler, errColl agd.ErrorCollector, -) (profDB *agd.DefaultProfileDB, rec *billstat.RuntimeRecorder, err error) { +) (profDB *profiledb.Default, rec *billstat.RuntimeRecorder, err error) { profStrgConf, billStatConf := conf.toInternal(envs, errColl) rec = billstat.NewRuntimeRecorder(&billstat.RuntimeRecorderConfig{ Uploader: backend.NewBillStat(billStatConf), @@ -103,7 +104,7 @@ func setupBackend( sigHdlr.add(billStatRefr) profStrg := backend.NewProfileStorage(profStrgConf) - profDB, err = agd.NewDefaultProfileDB( + profDB, err = profiledb.New( profStrg, conf.FullRefreshIvl.Duration, envs.ProfilesCachePath, diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index e9de141..f8dde60 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -16,9 +16,9 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/dnscheck" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/forward" - "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus" "github.com/AdguardTeam/AdGuardDNS/internal/dnssvc" "github.com/AdguardTeam/AdGuardDNS/internal/filter" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix" "github.com/AdguardTeam/AdGuardDNS/internal/geoip" "github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/AdGuardDNS/internal/websvc" @@ -130,11 +130,23 @@ func Main() { fltGroups, err := c.FilteringGroups.toInternal(fltStrg) check(err) + // Network interface listener + + btdCtrlConf, ctrlConf := c.Network.toInternal() + + btdMgr, err := c.InterfaceListeners.toInternal(errColl, btdCtrlConf) + check(err) + + err = btdMgr.Start() + check(err) + + sigHdlr.add(btdMgr) + // Server groups messages := dnsmsg.NewConstructor(&dnsmsg.BlockingModeNullIP{}, c.Filters.ResponseTTL.Duration) - srvGrps, err := c.ServerGroups.toInternal(messages, fltGroups) + srvGrps, err := c.ServerGroups.toInternal(messages, btdMgr, fltGroups) check(err) // TLS keys logging @@ -173,7 +185,7 @@ func Main() { // Rate limiting consulAllowlistURL := &envs.ConsulAllowlistURL.URL - rateLimiter, err := setupRateLimiter(c.RateLimit, consulAllowlistURL, sigHdlr, errColl) + rateLimiter, connLimiter, err := setupRateLimiter(c.RateLimit, consulAllowlistURL, sigHdlr, errColl) check(err) // GeoIP database @@ -197,24 +209,21 @@ func Main() { // DNS service - metricsListener := prometheus.NewForwardMetricsListener(len(c.Upstream.FallbackServers) + 1) - - upstream, err := c.Upstream.toInternal() + fwdConf, err := c.Upstream.toInternal() check(err) - handler := forward.NewHandler(&forward.HandlerConfig{ - Address: upstream.Server, - Network: upstream.Network, - MetricsListener: metricsListener, - HealthcheckDomainTmpl: c.Upstream.Healthcheck.DomainTmpl, - FallbackAddresses: c.Upstream.FallbackServers, - Timeout: c.Upstream.Timeout.Duration, - HealthcheckBackoffDuration: c.Upstream.Healthcheck.BackoffDuration.Duration, - }, c.Upstream.Healthcheck.Enabled) + handler := forward.NewHandler(fwdConf) + + // TODO(a.garipov): Consider making these configurable via the configuration + // file. + hashStorages := map[string]*hashprefix.Storage{ + filter.GeneralTXTSuffix: safeBrowsingHashes, + filter.AdultBlockingTXTSuffix: adultBlockingHashes, + } dnsConf := &dnssvc.Config{ Messages: messages, - SafeBrowsing: filter.NewSafeBrowsingServer(safeBrowsingHashes, adultBlockingHashes), + SafeBrowsing: hashprefix.NewMatcher(hashStorages), BillStat: billStatRec, ProfileDB: profDB, DNSCheck: dnsCk, @@ -226,21 +235,20 @@ func Main() { Handler: handler, QueryLog: c.buildQueryLog(envs), RuleStat: ruleStat, - Upstream: upstream, RateLimit: rateLimiter, + ConnLimiter: connLimiter, FilteringGroups: fltGroups, ServerGroups: srvGrps, CacheSize: c.Cache.Size, ECSCacheSize: c.Cache.ECSSize, UseECSCache: c.Cache.Type == cacheTypeECS, ResearchMetrics: bool(envs.ResearchMetrics), + ControlConf: ctrlConf, } dnsSvc, err := dnssvc.New(dnsConf) check(err) - sigHdlr.add(dnsSvc) - // Connectivity check err = connectivityCheck(dnsConf, c.ConnectivityCheck) diff --git a/internal/cmd/config.go b/internal/cmd/config.go index af24751..f7475bc 100644 --- a/internal/cmd/config.go +++ b/internal/cmd/config.go @@ -61,6 +61,13 @@ type configuration struct { // ConnectivityCheck is the connectivity check configuration. ConnectivityCheck *connCheckConfig `yaml:"connectivity_check"` + // InterfaceListeners is the configuration for the network interface + // listeners and their common parameters. + InterfaceListeners *interfaceListenersConfig `yaml:"interface_listeners"` + + // Network is the configuration for network listeners. + Network *network `yaml:"network"` + // AdditionalMetricsInfo is extra information, which is exposed by metrics. AdditionalMetricsInfo additionalInfo `yaml:"additional_metrics_info"` @@ -140,6 +147,12 @@ func (c *configuration) validate() (err error) { }, { validate: c.ConnectivityCheck.validate, name: "connectivity_check", + }, { + validate: c.InterfaceListeners.validate, + name: "interface_listeners", + }, { + validate: c.Network.validate, + name: "network", }, { validate: c.AdditionalMetricsInfo.validate, name: "additional_metrics_info", diff --git a/internal/cmd/env.go b/internal/cmd/env.go index 4a51188..fd8d26c 100644 --- a/internal/cmd/env.go +++ b/internal/cmd/env.go @@ -34,8 +34,8 @@ type environments struct { YoutubeSafeSearchURL *agdhttp.URL `env:"YOUTUBE_SAFE_SEARCH_URL,notEmpty"` RuleStatURL *agdhttp.URL `env:"RULESTAT_URL"` - ConfPath string `env:"CONFIG_PATH" envDefault:"./config.yml"` - DNSDBPath string `env:"DNSDB_PATH" envDefault:"./dnsdb.bolt"` + ConfPath string `env:"CONFIG_PATH" envDefault:"./config.yaml"` + DNSDBPath string `env:"DNSDB_PATH"` FilterCachePath string `env:"FILTER_CACHE_PATH" envDefault:"./filters/"` ProfilesCachePath string `env:"PROFILES_CACHE_PATH" envDefault:"./profilecache.json"` GeoIPASNPath string `env:"GEOIP_ASN_PATH" envDefault:"./asn.mmdb"` diff --git a/internal/cmd/filter.go b/internal/cmd/filter.go index fca8414..7e8511c 100644 --- a/internal/cmd/filter.go +++ b/internal/cmd/filter.go @@ -8,6 +8,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agdnet" "github.com/AdguardTeam/AdGuardDNS/internal/filter" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/timeutil" ) @@ -51,8 +52,8 @@ func (c *filtersConfig) toInternal( errColl agd.ErrorCollector, resolver agdnet.Resolver, envs *environments, - safeBrowsing *filter.HashPrefix, - adultBlocking *filter.HashPrefix, + safeBrowsing *hashprefix.Filter, + adultBlocking *hashprefix.Filter, ) (conf *filter.DefaultStorageConfig) { return &filter.DefaultStorageConfig{ FilterIndexURL: netutil.CloneURL(&envs.FilterIndexURL.URL), diff --git a/internal/cmd/ifacelistener.go b/internal/cmd/ifacelistener.go new file mode 100644 index 0000000..092115b --- /dev/null +++ b/internal/cmd/ifacelistener.go @@ -0,0 +1,102 @@ +package cmd + +import ( + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/agdmaps" + "github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice" + "github.com/AdguardTeam/golibs/errors" +) + +// Network interface listener configuration + +// interfaceListenersConfig contains the optional configuration for the network +// interface listeners and their common parameters. +type interfaceListenersConfig struct { + // List is the ID-to-configuration mapping of network interface listeners. + List map[bindtodevice.ID]*interfaceListener `yaml:"list"` + + // ChannelBufferSize is the size of the buffers of the channels used to + // dispatch TCP connections and UDP sessions. + ChannelBufferSize int `yaml:"channel_buffer_size"` +} + +// toInternal converts c to a bindtodevice.Manager. c is assumed to be valid. +func (c *interfaceListenersConfig) toInternal( + errColl agd.ErrorCollector, + ctrlConf *bindtodevice.ControlConfig, +) (m *bindtodevice.Manager, err error) { + if c == nil { + return nil, nil + } + + m = bindtodevice.NewManager(&bindtodevice.ManagerConfig{ + InterfaceStorage: bindtodevice.DefaultInterfaceStorage{}, + ErrColl: errColl, + ChannelBufferSize: c.ChannelBufferSize, + }) + + err = agdmaps.OrderedRangeError( + c.List, + func(id bindtodevice.ID, l *interfaceListener) (addErr error) { + return errors.Annotate(m.Add(id, l.Interface, l.Port, ctrlConf), "adding listener %q: %w", id) + }, + ) + + if err != nil { + return nil, err + } + + return m, nil +} + +// validate returns an error if the network interface listeners configuration is +// invalid. +func (c *interfaceListenersConfig) validate() (err error) { + switch { + case c == nil: + // This configuration is optional. + // + // TODO(a.garipov): Consider making required or not relying on nil + // values. + return nil + case c.ChannelBufferSize <= 0: + return newMustBePositiveError("channel_buffer_size", c.ChannelBufferSize) + case len(c.List) == 0: + return errors.Error("no list") + default: + // Go on. + } + + err = agdmaps.OrderedRangeError( + c.List, + func(id bindtodevice.ID, l *interfaceListener) (lsnrErr error) { + return errors.Annotate(l.validate(), "interface %q: %w", id) + }, + ) + + return err +} + +// interfaceListener contains configuration for a single network interface +// listener. +type interfaceListener struct { + // Interface is the name of the network interface in the system. + Interface string `yaml:"interface"` + + // Port is the port number on which to listen for incoming connections. + Port uint16 `yaml:"port"` +} + +// validate returns an error if the interface listener configuration is invalid. +func (l *interfaceListener) validate() (err error) { + switch { + case l == nil: + return errNilConfig + case l.Port == 0: + return errors.Error("port must not be zero") + case l.Interface == "": + return errors.Error("no interface") + default: + return nil + } +} diff --git a/internal/cmd/network.go b/internal/cmd/network.go new file mode 100644 index 0000000..232784a --- /dev/null +++ b/internal/cmd/network.go @@ -0,0 +1,51 @@ +package cmd + +import ( + "github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext" +) + +// network defines the network settings. +// +// TODO(a.garipov): Use [datasize.ByteSize] for sizes. +type network struct { + // SndBufSize defines the size of socket send buffer in bytes. Default is + // zero (uses system settings). + SndBufSize int `yaml:"so_sndbuf"` + + // RcvBufSize defines the size of socket receive buffer in bytes. Default + // is zero (uses system settings). + RcvBufSize int `yaml:"so_rcvbuf"` +} + +// validate returns an error if the network configuration is invalid. +func (n *network) validate() (err error) { + if n == nil { + return errNilConfig + } + + if n.SndBufSize < 0 { + return newMustBeNonNegativeError("so_sndbuf", n.SndBufSize) + } + + if n.RcvBufSize < 0 { + return newMustBeNonNegativeError("so_rcvbuf", n.RcvBufSize) + } + + return nil +} + +// toInternal converts n to the bindtodevice control configuration and network +// extension control configuration. +func (n *network) toInternal() (bc *bindtodevice.ControlConfig, nc *netext.ControlConfig) { + bc = &bindtodevice.ControlConfig{ + SndBufSize: n.SndBufSize, + RcvBufSize: n.RcvBufSize, + } + nc = &netext.ControlConfig{ + SndBufSize: n.SndBufSize, + RcvBufSize: n.RcvBufSize, + } + + return bc, nc +} diff --git a/internal/cmd/ratelimit.go b/internal/cmd/ratelimit.go index e4834b0..ccc107a 100644 --- a/internal/cmd/ratelimit.go +++ b/internal/cmd/ratelimit.go @@ -6,8 +6,11 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agdnet" + "github.com/AdguardTeam/AdGuardDNS/internal/connlimiter" "github.com/AdguardTeam/AdGuardDNS/internal/consul" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit" + "github.com/AdguardTeam/AdGuardDNS/internal/metrics" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/timeutil" "github.com/c2h5oh/datasize" ) @@ -19,6 +22,10 @@ type rateLimitConfig struct { // AllowList is the allowlist of clients. Allowlist *allowListConfig `yaml:"allowlist"` + // ConnectionLimit is the configuration for the limits on stream + // connections. + ConnectionLimit *connLimitConfig `yaml:"connection_limit"` + // Rate limit options for IPv4 addresses. IPv4 *rateLimitOptions `yaml:"ipv4"` @@ -97,12 +104,18 @@ func (c *rateLimitConfig) toInternal(al ratelimit.Allowlist) (conf *ratelimit.Ba // validate returns an error if the safe rate limiting configuration is invalid. func (c *rateLimitConfig) validate() (err error) { - if c == nil { + switch { + case c == nil: return errNilConfig - } else if c.Allowlist == nil { + case c.Allowlist == nil: return fmt.Errorf("allowlist: %w", errNilConfig) } + err = c.ConnectionLimit.validate() + if err != nil { + return fmt.Errorf("connection_limit: %w", err) + } + err = c.IPv4.validate() if err != nil { return fmt.Errorf("ipv4: %w", err) @@ -129,16 +142,16 @@ func setupRateLimiter( consulAllowlist *url.URL, sigHdlr signalHandler, errColl agd.ErrorCollector, -) (rateLimiter *ratelimit.BackOff, err error) { +) (rateLimiter *ratelimit.BackOff, connLimiter *connlimiter.Limiter, err error) { allowSubnets, err := agdnet.ParseSubnets(conf.Allowlist.List...) if err != nil { - return nil, fmt.Errorf("parsing allowlist subnets: %w", err) + return nil, nil, fmt.Errorf("parsing allowlist subnets: %w", err) } allowlist := ratelimit.NewDynamicAllowlist(allowSubnets, nil) refresher, err := consul.NewAllowlistRefresher(allowlist, consulAllowlist) if err != nil { - return nil, fmt.Errorf("creating allowlist refresher: %w", err) + return nil, nil, fmt.Errorf("creating allowlist refresher: %w", err) } refr := agd.NewRefreshWorker(&agd.RefreshWorkerConfig{ @@ -152,10 +165,67 @@ func setupRateLimiter( }) err = refr.Start() if err != nil { - return nil, fmt.Errorf("starting allowlist refresher: %w", err) + return nil, nil, fmt.Errorf("starting allowlist refresher: %w", err) } sigHdlr.add(refr) - return ratelimit.NewBackOff(conf.toInternal(allowlist)), nil + return ratelimit.NewBackOff(conf.toInternal(allowlist)), conf.ConnectionLimit.toInternal(), nil +} + +// connLimitConfig is the configuration structure for the stream-connection +// limiter. +type connLimitConfig struct { + // Stop is the point at which the limiter stops accepting new connections. + // Once the number of active connections reaches this limit, new connections + // wait for the number to decrease below Resume. + // + // Stop must be greater than zero and greater than or equal to Resume. + Stop uint64 `yaml:"stop"` + + // Resume is the point at which the limiter starts accepting new connections + // again. + // + // Resume must be greater than zero and less than or equal to Stop. + Resume uint64 `yaml:"resume"` + + // Enabled, if true, enables stream-connection limiting. + Enabled bool `yaml:"enabled"` +} + +// toInternal converts c to the connection limiter to use. c is assumed to be +// valid. +func (c *connLimitConfig) toInternal() (l *connlimiter.Limiter) { + if !c.Enabled { + return nil + } + + l, err := connlimiter.New(&connlimiter.Config{ + Stop: c.Stop, + Resume: c.Resume, + }) + if err != nil { + panic(err) + } + + metrics.ConnLimiterLimits.WithLabelValues("stop").Set(float64(c.Stop)) + metrics.ConnLimiterLimits.WithLabelValues("resume").Set(float64(c.Resume)) + + return l +} + +// validate returns an error if the connection limit configuration is invalid. +func (c *connLimitConfig) validate() (err error) { + switch { + case c == nil: + return errNilConfig + case !c.Enabled: + return nil + case c.Stop == 0: + return newMustBePositiveError("stop", c.Stop) + case c.Resume > c.Stop: + return errors.Error("resume: must be less than or equal to stop") + default: + return nil + } } diff --git a/internal/cmd/safebrowsing.go b/internal/cmd/safebrowsing.go index e8cc36a..55f9503 100644 --- a/internal/cmd/safebrowsing.go +++ b/internal/cmd/safebrowsing.go @@ -7,8 +7,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" "github.com/AdguardTeam/AdGuardDNS/internal/agdnet" - "github.com/AdguardTeam/AdGuardDNS/internal/filter" - "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/timeutil" @@ -45,13 +44,13 @@ func (c *safeBrowsingConfig) toInternal( resolver agdnet.Resolver, id agd.FilterListID, cacheDir string, -) (fltConf *filter.HashPrefixConfig, err error) { - hashes, err := hashstorage.New("") +) (fltConf *hashprefix.FilterConfig, err error) { + hashes, err := hashprefix.NewStorage("") if err != nil { return nil, err } - return &filter.HashPrefixConfig{ + return &hashprefix.FilterConfig{ Hashes: hashes, URL: netutil.CloneURL(&c.URL.URL), ErrColl: errColl, @@ -95,13 +94,13 @@ func setupHashPrefixFilter( cachePath string, sigHdlr signalHandler, errColl agd.ErrorCollector, -) (strg *hashstorage.Storage, flt *filter.HashPrefix, err error) { +) (strg *hashprefix.Storage, flt *hashprefix.Filter, err error) { fltConf, err := conf.toInternal(errColl, resolver, id, cachePath) if err != nil { return nil, nil, fmt.Errorf("configuring hash prefix filter %s: %w", id, err) } - flt, err = filter.NewHashPrefix(fltConf) + flt, err = hashprefix.NewFilter(fltConf) if err != nil { return nil, nil, fmt.Errorf("creating hash prefix filter %s: %w", id, err) } diff --git a/internal/cmd/server.go b/internal/cmd/server.go index 56c64f7..1aea133 100644 --- a/internal/cmd/server.go +++ b/internal/cmd/server.go @@ -5,6 +5,8 @@ import ( "net/netip" "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext" "github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/stringutil" @@ -14,10 +16,18 @@ import ( // toInternal returns the configuration of DNS servers for a single server // group. srvs is assumed to be valid. -func (srvs servers) toInternal(tlsConfig *agd.TLS) (dnsSrvs []*agd.Server, err error) { +func (srvs servers) toInternal( + tlsConfig *agd.TLS, + btdMgr *bindtodevice.Manager, +) (dnsSrvs []*agd.Server, err error) { dnsSrvs = make([]*agd.Server, 0, len(srvs)) for _, srv := range srvs { - bindData := srv.bindData() + var bindData []*agd.ServerBindData + bindData, err = srv.bindData(btdMgr) + if err != nil { + return nil, fmt.Errorf("server %q: %w", srv.Name, err) + } + name := agd.ServerName(srv.Name) switch p := srv.Protocol; p { case srvProtoDNS: @@ -158,27 +168,56 @@ type server struct { // Protocol is the protocol of the server. Protocol serverProto `yaml:"protocol"` - // BindAddresses are addresses this server binds to. + // BindAddresses are addresses this server binds to. If BindAddresses is + // set, BindInterfaces must not be set. BindAddresses []netip.AddrPort `yaml:"bind_addresses"` + // BindInterfaces are network interface data for this server to bind to. If + // BindInterfaces is set, BindAddresses must not be set. + BindInterfaces []*serverBindInterface `yaml:"bind_interfaces"` + // LinkedIPEnabled shows if the linked IP addresses should be used to detect // profiles on this server. LinkedIPEnabled bool `yaml:"linked_ip_enabled"` } // bindData returns the socket binding data for this server. -func (s *server) bindData() (bindData []*agd.ServerBindData) { - addrs := s.BindAddresses - bindData = make([]*agd.ServerBindData, 0, len(addrs)) - for _, addr := range addrs { +func (s *server) bindData( + btdMgr *bindtodevice.Manager, +) (bindData []*agd.ServerBindData, err error) { + if addrs := s.BindAddresses; len(addrs) > 0 { + bindData = make([]*agd.ServerBindData, 0, len(addrs)) + for _, addr := range addrs { + bindData = append(bindData, &agd.ServerBindData{ + AddrPort: addr, + }) + } + + return bindData, nil + } + + if btdMgr == nil { + err = errors.Error("bind_interfaces are only supported when interface_listeners are set") + + return nil, err + } + + ifaces := s.BindInterfaces + bindData = make([]*agd.ServerBindData, 0, len(ifaces)) + for i, iface := range s.BindInterfaces { + var lc netext.ListenConfig + lc, err = btdMgr.ListenConfig(iface.ID, iface.Subnet) + if err != nil { + return nil, fmt.Errorf("bind_interface at index %d: %w", i, err) + } + bindData = append(bindData, &agd.ServerBindData{ - AddrPort: addr, + ListenConfig: lc, + Address: string(iface.ID), }) } - // TODO(a.garipov): Support bind_interfaces. - - return bindData + return bindData, nil } // validate returns an error if the configuration is invalid. @@ -188,13 +227,12 @@ func (s *server) validate() (err error) { return errNilConfig case s.Name == "": return errors.Error("no name") - case len(s.BindAddresses) == 0: - return errors.Error("no bind_addresses") } - err = validateAddrs(s.BindAddresses) + err = s.validateBindData() if err != nil { - return fmt.Errorf("bind_addresses: %w", err) + // Don't wrap the error, because it's informative enough as is. + return err } err = s.Protocol.validate() @@ -209,3 +247,62 @@ func (s *server) validate() (err error) { return nil } + +// validateBindData returns an error if the server's binding data aren't valid. +func (s *server) validateBindData() (err error) { + bindAddrsSet, bindIfacesSet := len(s.BindAddresses) > 0, len(s.BindInterfaces) > 0 + if bindAddrsSet { + if bindIfacesSet { + return errors.Error("bind_addresses and bind_interfaces cannot both be set") + } + + err = validateAddrs(s.BindAddresses) + if err != nil { + return fmt.Errorf("bind_addresses: %w", err) + } + + return nil + } + + if !bindIfacesSet { + return errors.Error("neither bind_addresses nor bind_interfaces is set") + } + + if s.Protocol != srvProtoDNS { + return fmt.Errorf( + "bind_interfaces: only supported for protocol %q, got %q", + srvProtoDNS, + s.Protocol, + ) + } + + for i, bindIface := range s.BindInterfaces { + err = bindIface.validate() + if err != nil { + return fmt.Errorf("bind_interfaces: at index %d: %w", i, err) + } + } + + return nil +} + +// serverBindInterface contains the data for a network interface binding. +type serverBindInterface struct { + ID bindtodevice.ID `yaml:"id"` + Subnet netip.Prefix `yaml:"subnet"` +} + +// validate returns an error if the network interface binding configuration is +// invalid. +func (c *serverBindInterface) validate() (err error) { + switch { + case c == nil: + return errNilConfig + case c.ID == "": + return errors.Error("no id") + case !c.Subnet.IsValid(): + return errors.Error("bad subnet") + default: + return nil + } +} diff --git a/internal/cmd/servergroup.go b/internal/cmd/servergroup.go index aee2e32..70b35d0 100644 --- a/internal/cmd/servergroup.go +++ b/internal/cmd/servergroup.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/stringutil" @@ -19,6 +20,7 @@ type serverGroups []*serverGroup // service. srvGrps is assumed to be valid. func (srvGrps serverGroups) toInternal( messages *dnsmsg.Constructor, + btdMgr *bindtodevice.Manager, fltGrps map[agd.FilteringGroupID]*agd.FilteringGroup, ) (svcSrvGrps []*agd.ServerGroup, err error) { svcSrvGrps = make([]*agd.ServerGroup, len(srvGrps)) @@ -42,7 +44,7 @@ func (srvGrps serverGroups) toInternal( FilteringGroup: fltGrpID, } - svcSrvGrps[i].Servers, err = g.Servers.toInternal(tlsConf) + svcSrvGrps[i].Servers, err = g.Servers.toInternal(tlsConf, btdMgr) if err != nil { return nil, fmt.Errorf("server group %q: %w", g.Name, err) } diff --git a/internal/cmd/upstream.go b/internal/cmd/upstream.go index a1171eb..9380031 100644 --- a/internal/cmd/upstream.go +++ b/internal/cmd/upstream.go @@ -6,9 +6,11 @@ import ( "net/netip" "net/url" "strings" + "time" "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/forward" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/timeutil" ) @@ -33,18 +35,32 @@ type upstreamConfig struct { } // toInternal converts c to the data storage configuration for the DNS server. -func (c *upstreamConfig) toInternal() (conf *agd.Upstream, err error) { - net, addrPort, err := splitUpstreamURL(c.Server) +func (c *upstreamConfig) toInternal() (fwdConf *forward.HandlerConfig, err error) { + network, addrPort, err := splitUpstreamURL(c.Server) if err != nil { return nil, err } - return &agd.Upstream{ - Server: addrPort, - Network: net, - FallbackServers: c.FallbackServers, - Timeout: c.Timeout.Duration, - }, nil + fallbacks := c.FallbackServers + metricsListener := prometheus.NewForwardMetricsListener(len(fallbacks) + 1) + + var hcInit time.Duration + if c.Healthcheck.Enabled { + hcInit = c.Healthcheck.Timeout.Duration + } + + fwdConf = &forward.HandlerConfig{ + Address: addrPort, + Network: network, + MetricsListener: metricsListener, + HealthcheckDomainTmpl: c.Healthcheck.DomainTmpl, + FallbackAddresses: c.FallbackServers, + Timeout: c.Timeout.Duration, + HealthcheckBackoffDuration: c.Healthcheck.BackoffDuration.Duration, + HealthcheckInitDuration: hcInit, + } + + return fwdConf, nil } // validate returns an error if the upstream configuration is invalid. diff --git a/internal/cmd/websvc.go b/internal/cmd/websvc.go index 5e6834f..be3e6a0 100644 --- a/internal/cmd/websvc.go +++ b/internal/cmd/websvc.go @@ -13,6 +13,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" "github.com/AdguardTeam/AdGuardDNS/internal/websvc" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/timeutil" ) @@ -413,8 +414,8 @@ func (f *staticFile) toInternal() (file *websvc.StaticFile, err error) { // Check Content-Type here as opposed to in validate, because we need // all keys to be canonicalized first. - if file.Headers.Get(agdhttp.HdrNameContentType) == "" { - return nil, errors.Error("content: " + agdhttp.HdrNameContentType + " header is required") + if file.Headers.Get(httphdr.ContentType) == "" { + return nil, errors.Error("content: " + httphdr.ContentType + " header is required") } return file, nil diff --git a/internal/connlimiter/conn.go b/internal/connlimiter/conn.go new file mode 100644 index 0000000..eb61a9a --- /dev/null +++ b/internal/connlimiter/conn.go @@ -0,0 +1,51 @@ +package connlimiter + +import ( + "net" + "sync/atomic" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" + "github.com/AdguardTeam/AdGuardDNS/internal/metrics" + "github.com/AdguardTeam/AdGuardDNS/internal/optlog" + "github.com/AdguardTeam/golibs/errors" +) + +// limitConn is a wrapper for a stream connection that decreases the counter +// value on close. +// +// See https://pkg.go.dev/golang.org/x/net/netutil#LimitListener. +type limitConn struct { + net.Conn + + decrement func() + start time.Time + serverInfo dnsserver.ServerInfo + isClosed atomic.Bool +} + +// Close closes the underlying connection and decrements the counter. +func (c *limitConn) Close() (err error) { + defer func() { err = errors.Annotate(err, "limit conn: %w") }() + + if !c.isClosed.CompareAndSwap(false, true) { + return net.ErrClosed + } + + // Close the connection immediately and wait for the counter decrement and + // metrics later. + err = c.Conn.Close() + + connLife := time.Since(c.start).Seconds() + name := c.serverInfo.Name + optlog.Debug3("connlimiter: %s: closed conn from %s after %fs", name, c.RemoteAddr(), connLife) + metrics.StreamConnLifeDuration.WithLabelValues( + name, + c.serverInfo.Proto.String(), + c.serverInfo.Addr, + ).Observe(connLife) + + c.decrement() + + return err +} diff --git a/internal/connlimiter/counter.go b/internal/connlimiter/counter.go new file mode 100644 index 0000000..e91c66a --- /dev/null +++ b/internal/connlimiter/counter.go @@ -0,0 +1,35 @@ +package connlimiter + +// counter is the simultaneous stream-connection counter. It stops accepting +// new connections once it reaches stop and resumes when the number of active +// connections goes back to resume. +// +// Note that current is the number of both active stream-connections as well as +// goroutines that are currently in the process of accepting a new connection +// but haven't accepted one yet. +type counter struct { + current uint64 + stop uint64 + resume uint64 + isAccepting bool +} + +// increment tries to add the connection to the current active connection count. +// If the counter does not accept new connections, shouldAccept is false. +func (c *counter) increment() (shouldAccept bool) { + if !c.isAccepting { + return false + } + + c.current++ + c.isAccepting = c.current < c.stop + + return true +} + +// decrement decreases the number of current active connections. +func (c *counter) decrement() { + c.current-- + + c.isAccepting = c.isAccepting || c.current <= c.resume +} diff --git a/internal/connlimiter/counter_internal_test.go b/internal/connlimiter/counter_internal_test.go new file mode 100644 index 0000000..2d63480 --- /dev/null +++ b/internal/connlimiter/counter_internal_test.go @@ -0,0 +1,42 @@ +package connlimiter + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCounter(t *testing.T) { + t.Run("same", func(t *testing.T) { + c := &counter{ + current: 0, + stop: 1, + resume: 1, + isAccepting: true, + } + + assert.True(t, c.increment()) + assert.False(t, c.increment()) + + c.decrement() + assert.True(t, c.increment()) + assert.False(t, c.increment()) + }) + + t.Run("more", func(t *testing.T) { + c := &counter{ + current: 0, + stop: 2, + resume: 1, + isAccepting: true, + } + + assert.True(t, c.increment()) + assert.True(t, c.increment()) + assert.False(t, c.increment()) + + c.decrement() + assert.True(t, c.increment()) + assert.False(t, c.increment()) + }) +} diff --git a/internal/connlimiter/limiter.go b/internal/connlimiter/limiter.go new file mode 100644 index 0000000..797345b --- /dev/null +++ b/internal/connlimiter/limiter.go @@ -0,0 +1,73 @@ +package connlimiter + +import ( + "fmt" + "net" + "sync" + + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" + "github.com/AdguardTeam/AdGuardDNS/internal/metrics" +) + +// Config is the configuration structure for the stream-connection limiter. +type Config struct { + // Stop is the point at which the limiter stops accepting new connections. + // Once the number of active connections reaches this limit, new connections + // wait for the number to decrease to or below Resume. + // + // Stop must be greater than zero and greater than or equal to Resume. + Stop uint64 + + // Resume is the point at which the limiter starts accepting new connections + // again. + // + // Resume must be greater than zero and less than or equal to Stop. + Resume uint64 +} + +// Limiter is the stream-connection limiter. +type Limiter struct { + // counterCond is the shared condition variable that protects counter. + counterCond *sync.Cond + + // counter is the shared counter of active stream-connections. + counter *counter +} + +// New returns a new *Limiter. +func New(c *Config) (l *Limiter, err error) { + if c == nil || c.Stop == 0 || c.Resume > c.Stop { + return nil, fmt.Errorf("bad limiter config: %+v", c) + } + + return &Limiter{ + counterCond: sync.NewCond(&sync.Mutex{}), + counter: &counter{ + current: 0, + stop: c.Stop, + resume: c.Resume, + isAccepting: true, + }, + }, nil +} + +// Limit wraps lsnr to control the number of active connections. srvInfo is +// used for logging and metrics. +func (l *Limiter) Limit(lsnr net.Listener, srvInfo dnsserver.ServerInfo) (limited net.Listener) { + name, addr := srvInfo.Name, srvInfo.Addr + proto := srvInfo.Proto.String() + + return &limitListener{ + Listener: lsnr, + + counterCond: l.counterCond, + counter: l.counter, + + serverInfo: srvInfo, + + activeGauge: metrics.ConnLimiterActiveStreamConns.WithLabelValues(name, proto, addr), + waitingHist: metrics.StreamConnWaitDuration.WithLabelValues(name, proto, addr), + + isClosed: false, + } +} diff --git a/internal/connlimiter/limiter_test.go b/internal/connlimiter/limiter_test.go new file mode 100644 index 0000000..568e974 --- /dev/null +++ b/internal/connlimiter/limiter_test.go @@ -0,0 +1,122 @@ +package connlimiter_test + +import ( + "net" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" + "github.com/AdguardTeam/AdGuardDNS/internal/connlimiter" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" + "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMain(m *testing.M) { + testutil.DiscardLogOutput(m) +} + +// testTimeout is the common timeout for tests. +const testTimeout = 1 * time.Second + +// testServerInfo is the common server information for tests. +var testServerInfo = dnsserver.ServerInfo{ + Name: "test_server", + Addr: "127.0.0.1:0", + Proto: agd.ProtoDoT, +} + +func TestLimiter(t *testing.T) { + l, err := connlimiter.New(&connlimiter.Config{ + Stop: 1, + Resume: 1, + }) + require.NoError(t, err) + + conn := &agdtest.Conn{ + OnClose: func() (err error) { return nil }, + OnLocalAddr: func() (laddr net.Addr) { panic("not implemented") }, + OnRead: func(b []byte) (n int, err error) { panic("not implemented") }, + OnRemoteAddr: func() (addr net.Addr) { + return &net.TCPAddr{ + IP: netutil.IPv4Localhost().AsSlice(), + Port: 1234, + } + }, + OnSetDeadline: func(t time.Time) (err error) { panic("not implemented") }, + OnSetReadDeadline: func(t time.Time) (err error) { panic("not implemented") }, + OnSetWriteDeadline: func(t time.Time) (err error) { panic("not implemented") }, + OnWrite: func(b []byte) (n int, err error) { panic("not implemented") }, + } + + lsnr := &agdtest.Listener{ + OnAccept: func() (c net.Conn, err error) { return conn, nil }, + OnAddr: func() (addr net.Addr) { + return &net.TCPAddr{ + IP: netutil.IPv4Localhost().AsSlice(), + Port: 853, + } + }, + OnClose: func() (err error) { return nil }, + } + + limited := l.Limit(lsnr, testServerInfo) + + // Accept one connection. + gotConn, err := limited.Accept() + require.NoError(t, err) + + // Try accepting another connection. This should block until gotConn is + // closed. + otherStarted, otherListened := make(chan struct{}, 1), make(chan struct{}, 1) + go func() { + pt := &testutil.PanicT{} + + otherStarted <- struct{}{} + + otherConn, otherErr := limited.Accept() + require.NoError(pt, otherErr) + + otherListened <- struct{}{} + + require.NoError(pt, otherConn.Close()) + }() + + // Wait for the other goroutine to start. + testutil.RequireReceive(t, otherStarted, testTimeout) + + // Assert that the other connection hasn't been accepted. + var otherAccepted bool + select { + case <-otherListened: + otherAccepted = true + default: + otherAccepted = false + } + assert.False(t, otherAccepted) + + require.NoError(t, gotConn.Close()) + + // Check that double close causes an error. + assert.ErrorIs(t, gotConn.Close(), net.ErrClosed) + + testutil.RequireReceive(t, otherListened, testTimeout) + + err = limited.Close() + require.NoError(t, err) + + // Check that double close causes an error. + assert.ErrorIs(t, limited.Close(), net.ErrClosed) +} + +func TestLimiter_badConf(t *testing.T) { + l, err := connlimiter.New(&connlimiter.Config{ + Stop: 1, + Resume: 2, + }) + assert.Nil(t, l) + assert.Error(t, err) +} diff --git a/internal/connlimiter/listenconfig.go b/internal/connlimiter/listenconfig.go new file mode 100644 index 0000000..a30e689 --- /dev/null +++ b/internal/connlimiter/listenconfig.go @@ -0,0 +1,54 @@ +package connlimiter + +import ( + "context" + "net" + + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext" +) + +// type check +var _ netext.ListenConfig = (*ListenConfig)(nil) + +// ListenConfig is a [netext.ListenConfig] that uses a [*Limiter] to limit the +// number of active stream-connections. +type ListenConfig struct { + listenConfig netext.ListenConfig + limiter *Limiter +} + +// NewListenConfig returns a new netext.ListenConfig that uses l to limit the +// number of active stream-connections. +func NewListenConfig(c netext.ListenConfig, l *Limiter) (limited *ListenConfig) { + return &ListenConfig{ + listenConfig: c, + limiter: l, + } +} + +// ListenPacket implements the [netext.ListenConfig] interface for +// *ListenConfig. +func (c *ListenConfig) ListenPacket( + ctx context.Context, + network string, + address string, +) (conn net.PacketConn, err error) { + return c.listenConfig.ListenPacket(ctx, network, address) +} + +// Listen implements the [netext.ListenConfig] interface for *ListenConfig. +// Listen returns a net.Listener wrapped by c's limiter. ctx must contain a +// [dnsserver.ServerInfo]. +func (c *ListenConfig) Listen( + ctx context.Context, + network string, + address string, +) (l net.Listener, err error) { + l, err = c.listenConfig.Listen(ctx, network, address) + if err != nil { + return nil, err + } + + return c.limiter.Limit(l, dnsserver.MustServerInfoFromContext(ctx)), nil +} diff --git a/internal/connlimiter/listenconfig_test.go b/internal/connlimiter/listenconfig_test.go new file mode 100644 index 0000000..58a2e4e --- /dev/null +++ b/internal/connlimiter/listenconfig_test.go @@ -0,0 +1,75 @@ +package connlimiter_test + +import ( + "context" + "net" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" + "github.com/AdguardTeam/AdGuardDNS/internal/connlimiter" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestListenConfig(t *testing.T) { + pc := &agdtest.PacketConn{ + OnClose: func() (err error) { panic("not implemented") }, + OnLocalAddr: func() (laddr net.Addr) { panic("not implemented") }, + OnReadFrom: func(b []byte) (n int, addr net.Addr, err error) { panic("not implemented") }, + OnSetDeadline: func(t time.Time) (err error) { panic("not implemented") }, + OnSetReadDeadline: func(t time.Time) (err error) { panic("not implemented") }, + OnSetWriteDeadline: func(t time.Time) (err error) { panic("not implemented") }, + OnWriteTo: func(b []byte, addr net.Addr) (n int, err error) { panic("not implemented") }, + } + + lsnr := &agdtest.Listener{ + OnAccept: func() (c net.Conn, err error) { panic("not implemented") }, + OnAddr: func() (addr net.Addr) { panic("not implemented") }, + OnClose: func() (err error) { return nil }, + } + + c := &agdtest.ListenConfig{ + OnListen: func( + ctx context.Context, + network string, + address string, + ) (l net.Listener, err error) { + return lsnr, nil + }, + OnListenPacket: func( + ctx context.Context, + network string, + address string, + ) (conn net.PacketConn, err error) { + return pc, nil + }, + } + + l, err := connlimiter.New(&connlimiter.Config{ + Stop: 1, + Resume: 1, + }) + require.NoError(t, err) + + limited := connlimiter.NewListenConfig(c, l) + + ctx := dnsserver.ContextWithServerInfo(context.Background(), testServerInfo) + gotLsnr, err := limited.Listen(ctx, "", "") + require.NoError(t, err) + + // TODO(a.garipov): Add more testing logic here if [Limiter] becomes + // unexported. + assert.NotEqual(t, lsnr, gotLsnr) + + err = gotLsnr.Close() + require.NoError(t, err) + + gotPC, err := limited.ListenPacket(ctx, "", "") + require.NoError(t, err) + + // TODO(a.garipov): Add more testing logic here if [Limiter] becomes + // unexported. + assert.Equal(t, pc, gotPC) +} diff --git a/internal/connlimiter/listener.go b/internal/connlimiter/listener.go new file mode 100644 index 0000000..59d1f97 --- /dev/null +++ b/internal/connlimiter/listener.go @@ -0,0 +1,139 @@ +// Package connlimiter describes a limiter of the number of active +// stream-connections. +package connlimiter + +import ( + "net" + "sync" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" + "github.com/AdguardTeam/AdGuardDNS/internal/optlog" + "github.com/AdguardTeam/golibs/errors" + "github.com/prometheus/client_golang/prometheus" +) + +// limitListener is a wrapper that uses a counter to limit the number of active +// stream-connections. +// +// See https://pkg.go.dev/golang.org/x/net/netutil#LimitListener. +type limitListener struct { + net.Listener + + // counterCond is the condition variable that protects counter and isClosed + // through its locker, as well as signals when connections can be accepted + // again or when the listener has been closed. + counterCond *sync.Cond + + // counter is the shared counter for all listeners. + counter *counter + + // activeGauge is the metrics gauge of currently active stream-connections. + activeGauge prometheus.Gauge + + // waitingHist is the metrics histogram of how much a connection spends + // waiting for an accept. + waitingHist prometheus.Observer + + // serverInfo is used for logging and metrics in both the listener itself + // and in its conns. + serverInfo dnsserver.ServerInfo + + // isClosed shows whether this listener has been closed. + isClosed bool +} + +// Accept returns a new connection if the counter allows it. Otherwise, it +// waits until the counter allows it or the listener is closed. +func (l *limitListener) Accept() (conn net.Conn, err error) { + defer func() { err = errors.Annotate(err, "limit listener: %w") }() + + waitStart := time.Now() + + isClosed := l.increment() + if isClosed { + return nil, net.ErrClosed + } + + l.waitingHist.Observe(time.Since(waitStart).Seconds()) + l.activeGauge.Inc() + + conn, err = l.Listener.Accept() + if err != nil { + l.decrement() + + return nil, err + } + + return &limitConn{ + Conn: conn, + + decrement: l.decrement, + start: time.Now(), + serverInfo: l.serverInfo, + }, nil +} + +// increment waits until it can increase the number of active connections +// in the counter. If the listener is closed while waiting, increment exits and +// returns true +func (l *limitListener) increment() (isClosed bool) { + l.counterCond.L.Lock() + defer l.counterCond.L.Unlock() + + // Make sure to check both that the counter allows this connection and that + // the listener hasn't been closed. Only log about waiting for an increment + // when such waiting actually took place. + waited := false + for !l.counter.increment() && !l.isClosed { + if !waited { + optlog.Debug1("connlimiter: server %s: accept waiting", l.serverInfo.Name) + + waited = true + } + + l.counterCond.Wait() + } + + if waited { + optlog.Debug1("connlimiter: server %s: accept stopped waiting", l.serverInfo.Name) + } + + return l.isClosed +} + +// decrement decreases the number of active connections in the counter and +// broadcasts the change. +func (l *limitListener) decrement() { + l.counterCond.L.Lock() + defer l.counterCond.L.Unlock() + + l.activeGauge.Dec() + + l.counter.decrement() + + l.counterCond.Signal() +} + +// Close closes the underlying listener and signals to all goroutines waiting +// for an accept that the listener is closed now. +func (l *limitListener) Close() (err error) { + defer func() { err = errors.Annotate(err, "limit listener: %w") }() + + l.counterCond.L.Lock() + defer l.counterCond.L.Unlock() + + if l.isClosed { + return net.ErrClosed + } + + // Close the listener immediately; change the boolean and broadcast the + // change later. + err = l.Listener.Close() + + l.isClosed = true + + l.counterCond.Broadcast() + + return err +} diff --git a/internal/dnscheck/consul.go b/internal/dnscheck/consul.go index 190c3ab..d48a5f0 100644 --- a/internal/dnscheck/consul.go +++ b/internal/dnscheck/consul.go @@ -17,6 +17,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" "github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" @@ -303,8 +304,8 @@ func (cc *Consul) serveCheckTest(ctx context.Context, w http.ResponseWriter, r * } h := w.Header() - h.Set(agdhttp.HdrNameContentType, agdhttp.HdrValApplicationJSON) - h.Set(agdhttp.HdrNameAccessControlAllowOrigin, agdhttp.HdrValWildcard) + h.Set(httphdr.ContentType, agdhttp.HdrValApplicationJSON) + h.Set(httphdr.AccessControlAllowOrigin, agdhttp.HdrValWildcard) err = json.NewEncoder(w).Encode(inf) if err != nil { diff --git a/internal/dnsdb/bolt_test.go b/internal/dnsdb/bolt_test.go index 0ca8069..a5bc228 100644 --- a/internal/dnsdb/bolt_test.go +++ b/internal/dnsdb/bolt_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "net/url" "os" + "strings" "testing" "github.com/AdguardTeam/AdGuardDNS/internal/agd" @@ -15,6 +16,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" "github.com/AdguardTeam/AdGuardDNS/internal/dnsdb" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" + "github.com/AdguardTeam/golibs/httphdr" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -39,9 +41,9 @@ func TestBolt_ServeHTTP(t *testing.T) { const dname = "some-domain.name" successHdr := http.Header{ - agdhttp.HdrNameContentType: []string{agdhttp.HdrValTextCSV}, - agdhttp.HdrNameTrailer: []string{agdhttp.HdrNameXError}, - agdhttp.HdrNameContentEncoding: []string{"gzip"}, + httphdr.ContentType: []string{agdhttp.HdrValTextCSV}, + httphdr.Trailer: []string{httphdr.XError}, + httphdr.ContentEncoding: []string{"gzip"}, } newMsg := func(rcode int, name string, qtype uint16) (m *dns.Msg) { @@ -104,7 +106,10 @@ func TestBolt_ServeHTTP(t *testing.T) { for _, m := range msgs { ctx := context.Background() db.Record(ctx, m, &agd.RequestInfo{ - Host: m.Question[0].Name, + // Emulate the logic from init middleware. + // + // See [dnssvc.initMw.newRequestInfo]. + Host: strings.TrimSuffix(m.Question[0].Name, "."), }) err := db.Refresh(context.Background()) @@ -117,7 +122,7 @@ func TestBolt_ServeHTTP(t *testing.T) { (&url.URL{Scheme: "http", Host: "example.com"}).String(), nil, ) - r.Header.Add(agdhttp.HdrNameAcceptEncoding, "gzip") + r.Header.Add(httphdr.AcceptEncoding, "gzip") for _, tc := range testCases { db := newTmpBolt(t) diff --git a/internal/dnsdb/bolthttp.go b/internal/dnsdb/bolthttp.go index 2d969a0..0f32a15 100644 --- a/internal/dnsdb/bolthttp.go +++ b/internal/dnsdb/bolthttp.go @@ -15,6 +15,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/httphdr" "go.etcd.io/bbolt" ) @@ -38,7 +39,7 @@ func (db *Bolt) ServeHTTP(w http.ResponseWriter, r *http.Request) { } h := w.Header() - h.Add(agdhttp.HdrNameContentType, agdhttp.HdrValTextCSV) + h.Add(httphdr.ContentType, agdhttp.HdrValTextCSV) if dbPath == "" { // No data. @@ -47,10 +48,10 @@ func (db *Bolt) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - h.Set(agdhttp.HdrNameTrailer, agdhttp.HdrNameXError) + h.Set(httphdr.Trailer, httphdr.XError) defer func() { if err != nil { - h.Set(agdhttp.HdrNameXError, err.Error()) + h.Set(httphdr.XError, err.Error()) agd.Collectf(ctx, db.errColl, "dnsdb: http handler error: %w", err) } }() @@ -59,8 +60,8 @@ func (db *Bolt) ServeHTTP(w http.ResponseWriter, r *http.Request) { var rw io.Writer = w // TODO(a.garipov): Consider parsing the quality value. - if strings.Contains(r.Header.Get(agdhttp.HdrNameAcceptEncoding), "gzip") { - h.Set(agdhttp.HdrNameContentEncoding, "gzip") + if strings.Contains(r.Header.Get(httphdr.AcceptEncoding), "gzip") { + h.Set(httphdr.ContentEncoding, "gzip") gw := gzip.NewWriter(w) defer func() { err = errors.WithDeferred(err, gw.Close()) }() diff --git a/internal/dnsmsg/constructor.go b/internal/dnsmsg/constructor.go index ee282de..e1cd697 100644 --- a/internal/dnsmsg/constructor.go +++ b/internal/dnsmsg/constructor.go @@ -195,12 +195,12 @@ func (c *Constructor) newHdr(req *dns.Msg, rrType RRType) (hdr dns.RR_Header) { } // newHdrWithClass returns a new resource record header with specified class. -func (c *Constructor) newHdrWithClass(req *dns.Msg, rrType RRType, class dns.Class) (hdr dns.RR_Header) { +func (c *Constructor) newHdrWithClass(req *dns.Msg, rrType RRType, cl dns.Class) (h dns.RR_Header) { return dns.RR_Header{ Name: req.Question[0].Name, Rrtype: rrType, Ttl: uint32(c.fltRespTTL.Seconds()), - Class: uint16(class), + Class: uint16(cl), } } diff --git a/internal/dnsserver/cache/cache_test.go b/internal/dnsserver/cache/cache_test.go index 6207c53..6673096 100644 --- a/internal/dnsserver/cache/cache_test.go +++ b/internal/dnsserver/cache/cache_test.go @@ -23,14 +23,8 @@ func TestMiddleware_Wrap(t *testing.T) { aReq := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET) cnameReq := dnsservertest.NewReq(reqHostname, dns.TypeCNAME, dns.ClassINET) - cnameAns := dnsservertest.RRSection{ - RRs: []dns.RR{dnsservertest.NewCNAME(reqHostname, 3600, reqCname)}, - Sec: dnsservertest.SectionAnswer, - } - soaNs := dnsservertest.RRSection{ - RRs: []dns.RR{dnsservertest.NewSOA(reqHostname, 3600, reqNs1, reqNs2)}, - Sec: dnsservertest.SectionNs, - } + cnameAns := dnsservertest.SectionAnswer{dnsservertest.NewCNAME(reqHostname, 3600, reqCname)} + soaNs := dnsservertest.SectionNs{dnsservertest.NewSOA(reqHostname, 3600, reqNs1, reqNs2)} const N = 5 testCases := []struct { @@ -40,8 +34,8 @@ func TestMiddleware_Wrap(t *testing.T) { wantNumReq int }{{ req: aReq, - resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.RRSection{ - RRs: []dns.RR{dnsservertest.NewA(reqHostname, 3600, net.IP{1, 2, 3, 4})}, + resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{ + dnsservertest.NewA(reqHostname, 3600, net.IP{1, 2, 3, 4}), }), name: "simple_a", wantNumReq: 1, @@ -67,9 +61,8 @@ func TestMiddleware_Wrap(t *testing.T) { wantNumReq: N, }, { req: aReq, - resp: dnsservertest.NewResp(dns.RcodeNameError, aReq, dnsservertest.RRSection{ - RRs: []dns.RR{dnsservertest.NewNS(reqHostname, 3600, reqNs1)}, - Sec: dnsservertest.SectionNs, + resp: dnsservertest.NewResp(dns.RcodeNameError, aReq, dnsservertest.SectionNs{ + dnsservertest.NewNS(reqHostname, 3600, reqNs1), }), name: "non_authoritative_nxdomain", // TODO(ameshkov): Consider https://datatracker.ietf.org/doc/html/rfc2308#section-3. @@ -86,15 +79,15 @@ func TestMiddleware_Wrap(t *testing.T) { wantNumReq: 1, }, { req: cnameReq, - resp: dnsservertest.NewResp(dns.RcodeSuccess, cnameReq, dnsservertest.RRSection{ - RRs: []dns.RR{dnsservertest.NewCNAME(reqHostname, 3600, reqCname)}, + resp: dnsservertest.NewResp(dns.RcodeSuccess, cnameReq, dnsservertest.SectionAnswer{ + dnsservertest.NewCNAME(reqHostname, 3600, reqCname), }), name: "simple_cname_ans", wantNumReq: 1, }, { req: aReq, - resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.RRSection{ - RRs: []dns.RR{dnsservertest.NewA(reqHostname, 0, net.IP{1, 2, 3, 4})}, + resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{ + dnsservertest.NewA(reqHostname, 0, net.IP{1, 2, 3, 4}), }), name: "expired_one", wantNumReq: N, diff --git a/internal/dnsserver/dnsservertest/handler.go b/internal/dnsserver/dnsservertest/handler.go index 3f5a67f..18b444b 100644 --- a/internal/dnsserver/dnsservertest/handler.go +++ b/internal/dnsserver/dnsservertest/handler.go @@ -2,14 +2,19 @@ package dnsservertest import ( "context" - "net" + "time" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" ) -// CreateTestHandler creates a [dnsserver.Handler] with the specified parameters. +// AnswerTTL is the default TTL of the test handler's answers. +const AnswerTTL time.Duration = 100 * time.Second + +// CreateTestHandler creates a [dnsserver.Handler] with the specified +// parameters. All responses will have the [TestAnsTTL] TTL. func CreateTestHandler(recordsCount int) (h dnsserver.Handler) { f := func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) (err error) { // Check that necessary context keys are set. @@ -20,30 +25,23 @@ func CreateTestHandler(recordsCount int) (h dnsserver.Handler) { return errors.Error("client info does not contain server name") } - hostname := req.Question[0].Name - - resp := &dns.Msg{ - Compress: true, + ans := make(SectionAnswer, 0, recordsCount) + hdr := dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: uint32(AnswerTTL.Seconds()), } - resp.SetReply(req) + ip := netutil.IPv4Localhost().Prev() for i := 0; i < recordsCount; i++ { - hdr := dns.RR_Header{ - Name: hostname, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 100, - } - - a := &dns.A{ - // Add 1 to make sure that each IP is valid. - A: net.IP{127, 0, 0, byte(i + 1)}, - Hdr: hdr, - } - - resp.Answer = append(resp.Answer, a) + // Add 1 to make sure that each IP is valid. + ip = ip.Next() + ans = append(ans, &dns.A{Hdr: hdr, A: ip.AsSlice()}) } + resp := NewResp(dns.RcodeSuccess, req, ans) + _ = rw.WriteMsg(ctx, req, resp) return nil diff --git a/internal/dnsserver/dnsservertest/msg.go b/internal/dnsserver/dnsservertest/msg.go index 6feac03..8481ced 100644 --- a/internal/dnsserver/dnsservertest/msg.go +++ b/internal/dnsserver/dnsservertest/msg.go @@ -4,23 +4,17 @@ import ( "net" "testing" + "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" "github.com/stretchr/testify/require" ) // CreateMessage creates a DNS message for the specified hostname and qtype. func CreateMessage(hostname string, qtype uint16) (m *dns.Msg) { - return &dns.Msg{ - MsgHdr: dns.MsgHdr{ - Id: dns.Id(), - RecursionDesired: true, - }, - Question: []dns.Question{{ - Name: dns.Fqdn(hostname), - Qtype: qtype, - Qclass: dns.ClassINET, - }}, - } + m = NewReq(hostname, qtype, dns.ClassINET) + m.RecursionDesired = true + + return m } // RequireResponse checks that the DNS response we received is what was @@ -29,9 +23,9 @@ func RequireResponse( t *testing.T, req *dns.Msg, resp *dns.Msg, - expectedRecordsCount int, - expectedRCode int, - expectedTruncated bool, + wantAnsLen int, + wantRCode int, + wantTruncated bool, ) { t.Helper() @@ -40,26 +34,64 @@ func RequireResponse( // Check that Opcode is not changed in the response // regardless of the response status require.Equal(t, req.Opcode, resp.Opcode) - require.Equal(t, expectedRCode, resp.Rcode) - require.Equal(t, expectedTruncated, resp.Truncated) + require.Equal(t, wantRCode, resp.Rcode) + require.Equal(t, wantTruncated, resp.Truncated) require.True(t, resp.Response) // Response must not have a Z flag set even for a query that does // See https://github.com/miekg/dns/issues/975 require.False(t, resp.Zero) - require.Equal(t, expectedRecordsCount, len(resp.Answer)) + require.Len(t, resp.Answer, wantAnsLen) // Check that there's an OPT record in the response if len(req.Extra) > 0 { - require.True(t, len(resp.Extra) > 0) + require.NotEmpty(t, resp.Extra) } - if expectedRecordsCount > 0 { - a, ok := resp.Answer[0].(*dns.A) - require.True(t, ok) + if wantAnsLen > 0 { + a := testutil.RequireTypeAssert[*dns.A](t, resp.Answer[0]) require.Equal(t, req.Question[0].Name, a.Hdr.Name) } } +// RRSection is the resource record set to be appended to a new message created +// by [NewReq] and [NewResp]. It's essentially a sum type of: +// +// - [SectionAnswer] +// - [SectionNs] +// - [SectionExtra] +type RRSection interface { + // appendTo modifies m adding the resource record set into it appropriately. + appendTo(m *dns.Msg) +} + +// type check +var ( + _ RRSection = SectionAnswer{} + _ RRSection = SectionNs{} + _ RRSection = SectionExtra{} +) + +// SectionAnswer should wrap a resource record set for the Answer section of DNS +// message. +type SectionAnswer []dns.RR + +// appendTo implements the [RRSection] interface for SectionAnswer. +func (rrs SectionAnswer) appendTo(m *dns.Msg) { m.Answer = append(m.Answer, ([]dns.RR)(rrs)...) } + +// SectionNs should wrap a resource record set for the Ns section of DNS +// message. +type SectionNs []dns.RR + +// appendTo implements the [RRSection] interface for SectionNs. +func (rrs SectionNs) appendTo(m *dns.Msg) { m.Ns = append(m.Ns, ([]dns.RR)(rrs)...) } + +// SectionExtra should wrap a resource record set for the Extra section of DNS +// message. +type SectionExtra []dns.RR + +// appendTo implements the [RRSection] interface for SectionExtra. +func (rrs SectionExtra) appendTo(m *dns.Msg) { m.Extra = append(m.Extra, ([]dns.RR)(rrs)...) } + // NewReq returns the new DNS request with a single question for name, qtype, // qclass, and rrs added. func NewReq(name string, qtype, qclass uint16, rrs ...RRSection) (req *dns.Msg) { @@ -68,13 +100,15 @@ func NewReq(name string, qtype, qclass uint16, rrs ...RRSection) (req *dns.Msg) Id: dns.Id(), }, Question: []dns.Question{{ - Name: name, + Name: dns.Fqdn(name), Qtype: qtype, Qclass: qclass, }}, } - withRRs(req, rrs...) + for _, rr := range rrs { + rr.appendTo(req) + } return req } @@ -86,50 +120,13 @@ func NewResp(rcode int, req *dns.Msg, rrs ...RRSection) (resp *dns.Msg) { resp.RecursionAvailable = true resp.Compress = true - withRRs(resp, rrs...) + for _, rr := range rrs { + rr.appendTo(resp) + } return resp } -// MsgSection is used to specify the resource record set of the DNS message. -type MsgSection int - -// Possible values of the MsgSection. -const ( - SectionAnswer MsgSection = iota - SectionNs - SectionExtra -) - -// RRSection is the slice of resource records to be appended to a new message -// created by NewReq and NewResp. -// -// TODO(e.burkov): Use separate types for different sections of DNS message -// instead of constants. -type RRSection struct { - RRs []dns.RR - Sec MsgSection -} - -// withRRs adds rrs to the m. Invalid rrs are skipped. -func withRRs(m *dns.Msg, rrs ...RRSection) { - for _, r := range rrs { - var msgRR *[]dns.RR - switch r.Sec { - case SectionAnswer: - msgRR = &m.Answer - case SectionNs: - msgRR = &m.Ns - case SectionExtra: - msgRR = &m.Extra - default: - continue - } - - *msgRR = append(*msgRR, r.RRs...) - } -} - // NewCNAME constructs the new resource record of type CNAME. func NewCNAME(name string, ttl uint32, target string) (rr dns.RR) { return &dns.CNAME{ diff --git a/internal/dnsserver/dnsservertest/msg_test.go b/internal/dnsserver/dnsservertest/msg_test.go new file mode 100644 index 0000000..88e9682 --- /dev/null +++ b/internal/dnsserver/dnsservertest/msg_test.go @@ -0,0 +1,76 @@ +package dnsservertest_test + +import ( + "fmt" + "net" + + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" + "github.com/AdguardTeam/golibs/netutil" + "github.com/miekg/dns" +) + +func ExampleNewReq() { + const nonUniqueID = 1234 + + m := dnsservertest.NewReq("example.org.", dns.TypeA, dns.ClassINET, dnsservertest.SectionExtra{ + dnsservertest.NewECSExtra(netutil.IPv4Zero(), uint16(netutil.AddrFamilyIPv4), 0, 0), + }) + m.Id = nonUniqueID + fmt.Println(m) + + // Output: + // + // ;; opcode: QUERY, status: NOERROR, id: 1234 + // ;; flags:; QUERY: 1, ANSWER: 0, AUTHORITY: 0, ADDITIONAL: 1 + // + // ;; OPT PSEUDOSECTION: + // ; EDNS: version 0; flags:; udp: 0 + // ; SUBNET: 0.0.0.0/0/0 + // + // ;; QUESTION SECTION: + // ;example.org. IN A +} + +func ExampleNewResp() { + const ( + nonUniqueID = 1234 + testFQDN = "example.org." + realTestFQDN = "real." + testFQDN + ) + + m := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET, dnsservertest.SectionExtra{ + dnsservertest.NewECSExtra(netutil.IPv4Zero(), uint16(netutil.AddrFamilyIPv4), 0, 0), + }) + m.Id = nonUniqueID + + m = dnsservertest.NewResp(dns.RcodeSuccess, m, dnsservertest.SectionAnswer{ + dnsservertest.NewCNAME(testFQDN, 3600, realTestFQDN), + dnsservertest.NewA(realTestFQDN, 3600, net.IP{1, 2, 3, 4}), + }, dnsservertest.SectionNs{ + dnsservertest.NewSOA(realTestFQDN, 1000, "ns."+realTestFQDN, "mbox."+realTestFQDN), + dnsservertest.NewNS(testFQDN, 1000, "ns."+testFQDN), + }, dnsservertest.SectionExtra{ + m.IsEdns0(), + }) + fmt.Println(m) + + // Output: + // + // ;; opcode: QUERY, status: NOERROR, id: 1234 + // ;; flags: qr ra; QUERY: 1, ANSWER: 2, AUTHORITY: 2, ADDITIONAL: 1 + // + // ;; OPT PSEUDOSECTION: + // ; EDNS: version 0; flags:; udp: 0 + // ; SUBNET: 0.0.0.0/0/0 + // + // ;; QUESTION SECTION: + // ;example.org. IN A + // + // ;; ANSWER SECTION: + // example.org. 3600 IN CNAME real.example.org. + // real.example.org. 3600 IN A 1.2.3.4 + // + // ;; AUTHORITY SECTION: + // real.example.org. 1000 IN SOA ns.real.example.org. mbox.real.example.org. 0 0 0 0 0 + // example.org. 1000 IN NS ns.example.org. +} diff --git a/internal/dnsserver/example_test.go b/internal/dnsserver/example_test.go index 71ed64f..4d058a0 100644 --- a/internal/dnsserver/example_test.go +++ b/internal/dnsserver/example_test.go @@ -59,7 +59,7 @@ func ExampleWithMiddlewares() { forwarder := forward.NewHandler(&forward.HandlerConfig{ Address: netip.MustParseAddrPort("94.140.14.140:53"), Network: forward.NetworkAny, - }, true) + }) middleware := querylog.NewLogMiddleware(os.Stdout) handler := dnsserver.WithMiddlewares(forwarder, middleware) diff --git a/internal/dnsserver/forward/example_test.go b/internal/dnsserver/forward/example_test.go index 897a283..b22035c 100644 --- a/internal/dnsserver/forward/example_test.go +++ b/internal/dnsserver/forward/example_test.go @@ -20,7 +20,7 @@ func ExampleNewHandler() { FallbackAddresses: []netip.AddrPort{ netip.MustParseAddrPort("1.1.1.1:53"), }, - }, false), + }), }, } diff --git a/internal/dnsserver/forward/forward.go b/internal/dnsserver/forward/forward.go index a683ccc..c6e3fa8 100644 --- a/internal/dnsserver/forward/forward.go +++ b/internal/dnsserver/forward/forward.go @@ -38,20 +38,6 @@ import ( // queries to the specified upstreams. It also implements [io.Closer], allowing // resource reuse. type Handler struct { - // lastFailedHealthcheck shows the last time of failed healthcheck. - // - // It is of type int64 to be accessed by package atomic. The field is - // arranged for 64-bit alignment on the first position. - lastFailedHealthcheck int64 - - // useFallbacks is not zero if the main upstream server failed health check - // probes and therefore the fallback upstream servers should be used for - // resolving. - // - // It is of type uint64 to be accessed by package atomic. The field is - // arranged for 64-bit alignment on the second position. - useFallbacks uint64 - // metrics is a listener for the handler events. metrics MetricsListener @@ -65,12 +51,21 @@ type Handler struct { // fallbacks is a list of fallback DNS servers. fallbacks []Upstream + // lastFailedHealthcheck contains the Unix time of the last time of failed + // healthcheck. + lastFailedHealthcheck atomic.Int64 + // timeout specifies the query timeout for upstreams and fallbacks. timeout time.Duration // hcBackoffTime specifies the delay before returning to the main upstream // after failed healthcheck probe. hcBackoff time.Duration + + // useFallbacks is true if the main upstream server failed health check + // probes and therefore the fallback upstream servers should be used for + // resolving. + useFallbacks atomic.Bool } // ErrNoResponse is returned from Handler's methods when the desired response @@ -113,14 +108,21 @@ type HandlerConfig struct { // upstream until this time has passed. If the healthcheck is still // performed, each failed check advances the backoff. HealthcheckBackoffDuration time.Duration + + // HealthcheckInitDuration is the time duration for initial upstream + // healthcheck. + HealthcheckInitDuration time.Duration } // NewHandler initializes a new instance of Handler. It also performs a health -// check afterwards if initialHealthcheck is true. Note, that this handler only -// support plain DNS upstreams. c must not be nil. -func NewHandler(c *HandlerConfig, initialHealthcheck bool) (h *Handler) { +// check afterwards if c.HealthcheckInitDuration is not zero. Note, that this +// handler only support plain DNS upstreams. c must not be nil. +func NewHandler(c *HandlerConfig) (h *Handler) { h = &Handler{ - upstream: NewUpstreamPlain(c.Address, c.Network), + upstream: NewUpstreamPlain(&UpstreamPlainConfig{ + Network: c.Network, + Address: c.Address, + }), hcDomainTmpl: c.HealthcheckDomainTmpl, timeout: c.Timeout, hcBackoff: c.HealthcheckBackoffDuration, @@ -134,13 +136,19 @@ func NewHandler(c *HandlerConfig, initialHealthcheck bool) (h *Handler) { h.fallbacks = make([]Upstream, len(c.FallbackAddresses)) for i, addr := range c.FallbackAddresses { - h.fallbacks[i] = NewUpstreamPlain(addr, NetworkAny) + h.fallbacks[i] = NewUpstreamPlain(&UpstreamPlainConfig{ + Network: NetworkAny, + Address: addr, + }) } - if initialHealthcheck { + if c.HealthcheckInitDuration > 0 { + ctx, cancel := context.WithTimeout(context.Background(), c.HealthcheckInitDuration) + defer cancel() + // Ignore the error since it's considered non-critical and also should // have been logged already. - _ = h.refresh(context.Background(), true) + _ = h.refresh(ctx, true) } return h @@ -176,7 +184,7 @@ func (h *Handler) ServeDNS( ) (err error) { defer func() { err = annotate(err, h.upstream) }() - useFallbacks := atomic.LoadUint64(&h.useFallbacks) != 0 + useFallbacks := h.useFallbacks.Load() var resp *dns.Msg if !useFallbacks { resp, err = h.exchange(ctx, h.upstream, req) diff --git a/internal/dnsserver/forward/forward_test.go b/internal/dnsserver/forward/forward_test.go index 52fbe5a..7c3221e 100644 --- a/internal/dnsserver/forward/forward_test.go +++ b/internal/dnsserver/forward/forward_test.go @@ -4,6 +4,7 @@ import ( "context" "net/netip" "testing" + "time" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" @@ -17,6 +18,21 @@ func TestMain(m *testing.M) { testutil.DiscardLogOutput(m) } +// testTimeout is the timeout for tests. +const testTimeout = 1 * time.Second + +// newTimeoutCtx is a test helper that returns a context with a timeout of +// [testTimeout] and its cancel function being called in the test cleanup. +// It should not be used where cancellation is expected sooner. +func newTimeoutCtx(tb testing.TB, parent context.Context) (ctx context.Context) { + tb.Helper() + + ctx, cancel := context.WithTimeout(parent, testTimeout) + tb.Cleanup(cancel) + + return ctx +} + func TestHandler_ServeDNS(t *testing.T) { srv, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler()) @@ -24,13 +40,14 @@ func TestHandler_ServeDNS(t *testing.T) { handler := forward.NewHandler(&forward.HandlerConfig{ Address: netip.MustParseAddrPort(addr), Network: forward.NetworkAny, - }, true) + Timeout: testTimeout, + }) req := dnsservertest.CreateMessage("example.org.", dns.TypeA) rw := dnsserver.NewNonWriterResponseWriter(srv.LocalUDPAddr(), srv.LocalUDPAddr()) // Check the handler's ServeDNS method - err := handler.ServeDNS(context.Background(), rw, req) + err := handler.ServeDNS(newTimeoutCtx(t, context.Background()), rw, req) require.NoError(t, err) res := rw.Msg() @@ -46,7 +63,8 @@ func TestHandler_ServeDNS_fallbackNetError(t *testing.T) { FallbackAddresses: []netip.AddrPort{ netip.MustParseAddrPort(srv.LocalUDPAddr().String()), }, - }, true) + Timeout: testTimeout, + }) req := dnsservertest.CreateMessage("example.org.", dns.TypeA) rw := dnsserver.NewNonWriterResponseWriter(srv.LocalUDPAddr(), srv.LocalUDPAddr()) diff --git a/internal/dnsserver/forward/healthcheck.go b/internal/dnsserver/forward/healthcheck.go index b385dd4..d873efe 100644 --- a/internal/dnsserver/forward/healthcheck.go +++ b/internal/dnsserver/forward/healthcheck.go @@ -6,7 +6,6 @@ import ( "math/rand" "strconv" "strings" - "sync/atomic" "time" "github.com/AdguardTeam/golibs/errors" @@ -25,25 +24,25 @@ func (h *Handler) refresh(ctx context.Context, shouldReport bool) (err error) { return nil } - var useFallbacks uint64 - lastFailed := atomic.LoadInt64(&h.lastFailedHealthcheck) + var useFallbacks bool + lastFailed := h.lastFailedHealthcheck.Load() shouldReturnToMain := time.Since(time.Unix(lastFailed, 0)) >= h.hcBackoff if !shouldReturnToMain { // Make sure that useFallbacks is left true if the main upstream is // still in the backoff mode. - useFallbacks = 1 + useFallbacks = true log.Debug("forward: healthcheck: in backoff, will not return to main on success") } err = h.healthcheck(ctx) if err != nil { - atomic.StoreInt64(&h.lastFailedHealthcheck, time.Now().Unix()) - useFallbacks = 1 + h.lastFailedHealthcheck.Store(time.Now().Unix()) + useFallbacks = true } - statusChanged := atomic.CompareAndSwapUint64(&h.useFallbacks, 1-useFallbacks, useFallbacks) + statusChanged := h.useFallbacks.CompareAndSwap(!useFallbacks, useFallbacks) if statusChanged || shouldReport { - h.setUpstreamStatus(useFallbacks == 0) + h.setUpstreamStatus(!useFallbacks) } return errors.Annotate(err, "forward: %w") diff --git a/internal/dnsserver/forward/healthcheck_test.go b/internal/dnsserver/forward/healthcheck_test.go index 8a3e2aa..7c224f9 100644 --- a/internal/dnsserver/forward/healthcheck_test.go +++ b/internal/dnsserver/forward/healthcheck_test.go @@ -15,8 +15,10 @@ import ( ) func TestHandler_Refresh(t *testing.T) { - var upstreamUp uint64 - var upstreamRequestsCount uint64 + var upstreamIsUp atomic.Bool + var upstreamRequestsCount atomic.Int64 + + defaultHandler := dnsservertest.DefaultHandler() // This handler writes an empty message if upstreamUp flag is false. handlerFunc := dnsserver.HandlerFunc(func( @@ -24,16 +26,15 @@ func TestHandler_Refresh(t *testing.T) { rw dnsserver.ResponseWriter, req *dns.Msg, ) (err error) { - atomic.AddUint64(&upstreamRequestsCount, 1) + upstreamRequestsCount.Add(1) nrw := dnsserver.NewNonWriterResponseWriter(rw.LocalAddr(), rw.RemoteAddr()) - handler := dnsservertest.DefaultHandler() - err = handler.ServeDNS(ctx, nrw, req) + err = defaultHandler.ServeDNS(ctx, nrw, req) if err != nil { return err } - if atomic.LoadUint64(&upstreamUp) == 0 { + if !upstreamIsUp.Load() { return rw.WriteMsg(ctx, req, &dns.Msg{}) } @@ -41,7 +42,7 @@ func TestHandler_Refresh(t *testing.T) { }) upstream, _ := dnsservertest.RunDNSServer(t, handlerFunc) - fallback, _ := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler()) + fallback, _ := dnsservertest.RunDNSServer(t, defaultHandler) handler := forward.NewHandler(&forward.HandlerConfig{ Address: netip.MustParseAddrPort(upstream.LocalUDPAddr().String()), Network: forward.NetworkAny, @@ -49,38 +50,41 @@ func TestHandler_Refresh(t *testing.T) { FallbackAddresses: []netip.AddrPort{ netip.MustParseAddrPort(fallback.LocalUDPAddr().String()), }, - // Make sure that the handler routs queries back to the main upstream + Timeout: testTimeout, + // Make sure that the handler routes queries back to the main upstream // immediately. HealthcheckBackoffDuration: 0, - }, false) + }) req := dnsservertest.CreateMessage("example.org.", dns.TypeA) rw := dnsserver.NewNonWriterResponseWriter(fallback.LocalUDPAddr(), fallback.LocalUDPAddr()) - err := handler.ServeDNS(context.Background(), rw, req) - require.Error(t, err) - assert.Equal(t, uint64(1), atomic.LoadUint64(&upstreamRequestsCount)) + ctx := context.Background() - err = handler.Refresh(context.Background()) + err := handler.ServeDNS(newTimeoutCtx(t, ctx), rw, req) require.Error(t, err) - assert.Equal(t, uint64(2), atomic.LoadUint64(&upstreamRequestsCount)) + assert.Equal(t, int64(2), upstreamRequestsCount.Load()) - err = handler.ServeDNS(context.Background(), rw, req) + err = handler.Refresh(newTimeoutCtx(t, ctx)) + require.Error(t, err) + assert.Equal(t, int64(4), upstreamRequestsCount.Load()) + + err = handler.ServeDNS(newTimeoutCtx(t, ctx), rw, req) require.NoError(t, err) - assert.Equal(t, uint64(2), atomic.LoadUint64(&upstreamRequestsCount)) + assert.Equal(t, int64(4), upstreamRequestsCount.Load()) // Now, set upstream up. - atomic.StoreUint64(&upstreamUp, 1) + upstreamIsUp.Store(true) - err = handler.ServeDNS(context.Background(), rw, req) + err = handler.ServeDNS(newTimeoutCtx(t, ctx), rw, req) require.NoError(t, err) - assert.Equal(t, uint64(2), atomic.LoadUint64(&upstreamRequestsCount)) + assert.Equal(t, int64(4), upstreamRequestsCount.Load()) - err = handler.Refresh(context.Background()) + err = handler.Refresh(newTimeoutCtx(t, ctx)) require.NoError(t, err) - assert.Equal(t, uint64(3), atomic.LoadUint64(&upstreamRequestsCount)) + assert.Equal(t, int64(5), upstreamRequestsCount.Load()) - err = handler.ServeDNS(context.Background(), rw, req) + err = handler.ServeDNS(newTimeoutCtx(t, ctx), rw, req) require.NoError(t, err) - assert.Equal(t, uint64(4), atomic.LoadUint64(&upstreamRequestsCount)) + assert.Equal(t, int64(6), upstreamRequestsCount.Load()) } diff --git a/internal/dnsserver/forward/upstreamplain.go b/internal/dnsserver/forward/upstreamplain.go index 97d1a32..edf3e24 100644 --- a/internal/dnsserver/forward/upstreamplain.go +++ b/internal/dnsserver/forward/upstreamplain.go @@ -7,6 +7,7 @@ import ( "io" "net" "net/netip" + "strings" "sync" "time" @@ -15,7 +16,7 @@ import ( "github.com/miekg/dns" ) -// Network is a enumeration of networks UpstreamPlain supports +// Network is a enumeration of networks UpstreamPlain supports. type Network string const ( @@ -72,11 +73,21 @@ type UpstreamPlain struct { // type check var _ Upstream = (*UpstreamPlain)(nil) -// NewUpstreamPlain creates and initializes a new instance of UpstreamPlain. -func NewUpstreamPlain(addr netip.AddrPort, network Network) (ups *UpstreamPlain) { +// UpstreamPlainConfig is the configuration structure for a plain-DNS upstream. +type UpstreamPlainConfig struct { + // Network is the network to use for this upstream. + Network Network + + // Address is the address of the upstream DNS server. + Address netip.AddrPort +} + +// NewUpstreamPlain returns a new properly initialized *UpstreamPlain. c must +// not be nil. +func NewUpstreamPlain(c *UpstreamPlainConfig) (ups *UpstreamPlain) { ups = &UpstreamPlain{ - addr: addr, - network: network, + addr: c.Address, + network: c.Network, } ups.connsPoolUDP = pool.NewPool(poolMaxCapacity, makeConnsPoolFactory(ups, NetworkUDP)) @@ -127,33 +138,39 @@ func (u *UpstreamPlain) String() (str string) { return fmt.Sprintf("%s://%s", u.network, u.addr) } -// exchangeUDP attempts to send the DNS request over UDP. It returns a -// fallbackToTCP flag to signal if we should fallback to using TCP instead. -// this may happen if the response received over UDP was truncated and +// exchangeUDP attempts to send the DNS request over UDP. It returns a +// fallbackToTCP flag to signal if the caller should fallback to using TCP +// instead. This may happen if the response received over UDP was truncated and // TCP is enabled for this upstream or if UDP is disabled. func (u *UpstreamPlain) exchangeUDP( ctx context.Context, req *dns.Msg, ) (fallbackToTCP bool, resp *dns.Msg, err error) { if u.network == NetworkTCP { - // fallback to TCP immediately. + // Fallback to TCP immediately. return true, nil, nil } resp, err = u.exchangeNet(ctx, req, NetworkUDP) if err != nil { - // error means that the upstream is dead, no need to fallback to TCP. - return false, resp, err + // The network error always causes the subsequent query attempt using + // fresh UDP connection, so if it happened again, the upstream is likely + // dead and using TCP appears meaningless. See [exchangeNet]. + // + // Thus, non-network errors are considered being related to the + // response. It may also happen the received response is intended for + // another timeouted request sent from the same source port, but falling + // back to TCP in this case shouldn't hurt. + fallbackToTCP = !isExpectedConnErr(err) + + return fallbackToTCP, resp, err } - // If the response is truncated and we can use TCP, make sure that we'll - // fallback to TCP. We also fallback to TCP if we received a response with - // the wrong ID (it may happen with the servers under heavy load). - if (resp.Truncated || resp.Id != req.Id) && u.network != NetworkUDP { - fallbackToTCP = true - } + // Also, fallback to TCP if the received response is truncated and the + // upstream isn't UDP-only. + fallbackToTCP = u.network != NetworkUDP && resp != nil && resp.Truncated - return fallbackToTCP, resp, err + return fallbackToTCP, resp, nil } // exchangeNet sends a DNS query using the specified network (either TCP or UDP). @@ -203,25 +220,36 @@ func (u *UpstreamPlain) exchangeNet( return resp, err } -// validateResponse checks if the response is valid for the original query. For -// instance, it is possible to receive a response to a different query, and we -// must be sure that we received what was expected. -func (u *UpstreamPlain) validateResponse(req, resp *dns.Msg) (err error) { +// validatePlainResponse returns an error if the response is not valid for the +// original request. This is required because we might receive a response to a +// different query, e.g. when the server is under heavy load. +func validatePlainResponse(req, resp *dns.Msg) (err error) { if req.Id != resp.Id { return dns.ErrId } - if len(resp.Question) != 1 { - return ErrQuestion + if qlen := len(resp.Question); qlen != 1 { + return fmt.Errorf("%w: only 1 question allowed; got %d", ErrQuestion, qlen) } - if req.Question[0].Name != resp.Question[0].Name { - return ErrQuestion + reqQ, respQ := req.Question[0], resp.Question[0] + + if reqQ.Qtype != respQ.Qtype { + return fmt.Errorf("%w: mismatched type %s", ErrQuestion, dns.Type(respQ.Qtype)) + } + + // Compare the names case-insensitively, just like CoreDNS does. + if !strings.EqualFold(reqQ.Name, respQ.Name) { + return fmt.Errorf("%w: mismatched name %q", ErrQuestion, respQ.Name) } return nil } +// defaultUDPTimeout is the default timeout for waiting a valid DNS message or +// network error. +const defaultUDPTimeout = 1 * time.Minute + // processConn writes the query to the connection and then reads the response // from it. We might be dealing with an idle dead connection so if we get // a network error here, we'll attempt to open a new connection and call this @@ -236,7 +264,7 @@ func (u *UpstreamPlain) processConn( req *dns.Msg, buf []byte, bufLen int, -) (msg *dns.Msg, err error) { +) (resp *dns.Msg, err error) { // Make sure that we return the connection to the pool in the end or close // if there was any error. defer func() { @@ -248,7 +276,12 @@ func (u *UpstreamPlain) processConn( }() // Prepare a context with a deadline if needed. - if deadline, ok := ctx.Deadline(); ok { + deadline, ok := ctx.Deadline() + if !ok && network == NetworkUDP { + deadline, ok = time.Now().Add(defaultUDPTimeout), true + } + + if ok { err = conn.SetDeadline(deadline) if err != nil { return nil, fmt.Errorf("setting deadline: %w", err) @@ -261,19 +294,28 @@ func (u *UpstreamPlain) processConn( return nil, fmt.Errorf("writing request: %w", err) } - var resp *dns.Msg + return u.readValidMsg(req, network, conn, buf) +} + +// readValidMsg reads the response from conn to buf, parses and validates it. +func (u *UpstreamPlain) readValidMsg( + req *dns.Msg, + network Network, + conn net.Conn, + buf []byte, +) (resp *dns.Msg, err error) { resp, err = u.readMsg(network, conn, buf) if err != nil { - // Error is already wrapped. + // Don't wrap the error, because it's informative enough as is. return nil, err } - err = u.validateResponse(req, resp) + err = validatePlainResponse(req, resp) if err != nil { - return nil, fmt.Errorf("validating response: %w", err) + return resp, fmt.Errorf("validating %s response: %w", network, err) } - return resp, err + return resp, nil } // readMsg reads the response from the specified connection and parses it. @@ -302,7 +344,8 @@ func (u *UpstreamPlain) readMsg(network Network, conn net.Conn, buf []byte) (*dn if n < minDNSMessageSize { return nil, fmt.Errorf("invalid msg: %w", dns.ErrShortRead) } - ret := new(dns.Msg) + + ret := &dns.Msg{} err = ret.Unpack(buf) if err != nil { return nil, fmt.Errorf("unpacking msg: %w", err) diff --git a/internal/dnsserver/forward/upstreamplain_test.go b/internal/dnsserver/forward/upstreamplain_test.go index 1d1b049..22e60a7 100644 --- a/internal/dnsserver/forward/upstreamplain_test.go +++ b/internal/dnsserver/forward/upstreamplain_test.go @@ -5,11 +5,14 @@ import ( "net/netip" "testing" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/forward" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -31,11 +34,14 @@ func TestUpstreamPlain_Exchange(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { _, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler()) - u := forward.NewUpstreamPlain(netip.MustParseAddrPort(addr), tc.network) + u := forward.NewUpstreamPlain(&forward.UpstreamPlainConfig{ + Network: tc.network, + Address: netip.MustParseAddrPort(addr), + }) defer log.OnCloserError(u, log.DEBUG) req := dnsservertest.CreateMessage("example.org.", dns.TypeA) - res, err := u.Exchange(context.Background(), req) + res, err := u.Exchange(newTimeoutCtx(t, context.Background()), req) require.NoError(t, err) require.NotNil(t, res) dnsservertest.RequireResponse(t, req, res, 1, dns.RcodeSuccess, false) @@ -71,34 +77,199 @@ func TestUpstreamPlain_Exchange_truncated(t *testing.T) { return rw.WriteMsg(ctx, req, res) }) - _, addr := dnsservertest.RunDNSServer(t, handlerFunc) + _, addrStr := dnsservertest.RunDNSServer(t, handlerFunc) // Create a test message. req := dnsservertest.CreateMessage("example.org.", dns.TypeA) // First, check that we receive truncated response over UDP. - uAddr := netip.MustParseAddrPort(addr) - uUDP := forward.NewUpstreamPlain(uAddr, forward.NetworkUDP) + addr := netip.MustParseAddrPort(addrStr) + uUDP := forward.NewUpstreamPlain(&forward.UpstreamPlainConfig{ + Network: forward.NetworkUDP, + Address: addr, + }) defer log.OnCloserError(uUDP, log.DEBUG) - res, err := uUDP.Exchange(context.Background(), req) + ctx := context.Background() + + res, err := uUDP.Exchange(newTimeoutCtx(t, ctx), req) require.NoError(t, err) dnsservertest.RequireResponse(t, req, res, 0, dns.RcodeSuccess, true) // Second, check that nothing is truncated over TCP. - uTCP := forward.NewUpstreamPlain(uAddr, forward.NetworkTCP) + uTCP := forward.NewUpstreamPlain(&forward.UpstreamPlainConfig{ + Network: forward.NetworkTCP, + Address: addr, + }) defer log.OnCloserError(uTCP, log.DEBUG) - res, err = uTCP.Exchange(context.Background(), req) + res, err = uTCP.Exchange(newTimeoutCtx(t, ctx), req) require.NoError(t, err) dnsservertest.RequireResponse(t, req, res, 1, dns.RcodeSuccess, false) // Now with NetworkANY response is also not truncated since the upstream // fallbacks to TCP. - uAny := forward.NewUpstreamPlain(uAddr, forward.NetworkAny) + uAny := forward.NewUpstreamPlain(&forward.UpstreamPlainConfig{ + Network: forward.NetworkAny, + Address: addr, + }) defer log.OnCloserError(uAny, log.DEBUG) - res, err = uAny.Exchange(context.Background(), req) + res, err = uAny.Exchange(newTimeoutCtx(t, ctx), req) require.NoError(t, err) dnsservertest.RequireResponse(t, req, res, 1, dns.RcodeSuccess, false) } + +func TestUpstreamPlain_Exchange_fallbackFail(t *testing.T) { + pt := testutil.PanicT{} + + // Use only unbuffered channels to block until received and validated. + netCh := make(chan string) + respCh := make(chan struct{}) + + h := dnsserver.HandlerFunc(func( + ctx context.Context, + rw dnsserver.ResponseWriter, + req *dns.Msg, + ) (err error) { + testutil.RequireSend(pt, netCh, rw.RemoteAddr().Network(), testTimeout) + + resp := dnsservertest.NewResp(dns.RcodeSuccess, req) + + // Make all responses invalid. + resp.Id = req.Id + 1 + + return rw.WriteMsg(ctx, req, resp) + }) + + _, addr := dnsservertest.RunDNSServer(t, h) + u := forward.NewUpstreamPlain(&forward.UpstreamPlainConfig{ + Network: forward.NetworkUDP, + Address: netip.MustParseAddrPort(addr), + }) + testutil.CleanupAndRequireSuccess(t, u.Close) + + req := dnsservertest.CreateMessage("example.org.", dns.TypeA) + + var resp *dns.Msg + var err error + go func() { + resp, err = u.Exchange(newTimeoutCtx(t, context.Background()), req) + testutil.RequireSend(pt, respCh, struct{}{}, testTimeout) + }() + + // First attempt should use UDP and fail due to bad ID. + network, _ := testutil.RequireReceive(t, netCh, testTimeout) + require.Equal(t, string(forward.NetworkUDP), network) + + // Second attempt should use TCP and succeed. + network, _ = testutil.RequireReceive(t, netCh, testTimeout) + require.Equal(t, string(forward.NetworkTCP), network) + + testutil.RequireReceive(t, respCh, testTimeout) + require.ErrorIs(t, err, dns.ErrId) + assert.NotNil(t, resp) +} + +func TestUpstreamPlain_Exchange_fallbackSuccess(t *testing.T) { + const ( + // network is set to UDP to ensure that falling back to TCP will still + // be performed. + network = forward.NetworkUDP + + goodDomain = "domain.example." + badDomain = "bad.example." + ) + + pt := testutil.PanicT{} + + req := dnsservertest.CreateMessage(goodDomain, dns.TypeA) + resp := dnsservertest.NewResp(dns.RcodeSuccess, req) + + // Prepare malformed responses. + + badIDResp := dnsmsg.Clone(resp) + badIDResp.Id = ^req.Id + + badQNumResp := dnsmsg.Clone(resp) + badQNumResp.Question = append(badQNumResp.Question, req.Question[0]) + + badQnameResp := dnsmsg.Clone(resp) + badQnameResp.Question[0].Name = badDomain + + badQtypeResp := dnsmsg.Clone(resp) + badQtypeResp.Question[0].Qtype = dns.TypeMX + + testCases := []struct { + udpResp *dns.Msg + name string + }{{ + udpResp: badIDResp, + name: "wrong_id", + }, { + udpResp: badQNumResp, + name: "wrong_question)_number", + }, { + udpResp: badQnameResp, + name: "wrong_qname", + }, { + udpResp: badQtypeResp, + name: "wrong_qtype", + }} + + for _, tc := range testCases { + clonedReq := dnsmsg.Clone(req) + badResp := dnsmsg.Clone(tc.udpResp) + goodResp := dnsmsg.Clone(resp) + + // Use only unbuffered channels to block until received and validated. + netCh := make(chan string) + respCh := make(chan struct{}) + + h := dnsserver.HandlerFunc(func( + ctx context.Context, + rw dnsserver.ResponseWriter, + req *dns.Msg, + ) (err error) { + network := rw.RemoteAddr().Network() + testutil.RequireSend(pt, netCh, network, testTimeout) + + if network == string(forward.NetworkUDP) { + // Respond with invalid message via UDP. + return rw.WriteMsg(ctx, req, badResp) + } + + // Respond with valid message via TCP. + return rw.WriteMsg(ctx, req, goodResp) + }) + + t.Run(tc.name, func(t *testing.T) { + _, addr := dnsservertest.RunDNSServer(t, dnsserver.HandlerFunc(h)) + + u := forward.NewUpstreamPlain(&forward.UpstreamPlainConfig{ + Network: network, + Address: netip.MustParseAddrPort(addr), + }) + testutil.CleanupAndRequireSuccess(t, u.Close) + + var actualResp *dns.Msg + var err error + go func() { + actualResp, err = u.Exchange(newTimeoutCtx(t, context.Background()), clonedReq) + testutil.RequireSend(pt, respCh, struct{}{}, testTimeout) + }() + + // First attempt should use UDP and fail due to bad ID. + network, _ := testutil.RequireReceive(t, netCh, testTimeout) + require.Equal(t, string(forward.NetworkUDP), network) + + // Second attempt should use TCP and succeed. + network, _ = testutil.RequireReceive(t, netCh, testTimeout) + require.Equal(t, string(forward.NetworkTCP), network) + + testutil.RequireReceive(t, respCh, testTimeout) + require.NoError(t, err) + dnsservertest.RequireResponse(t, req, actualResp, 0, dns.RcodeSuccess, false) + }) + } +} diff --git a/internal/dnsserver/go.mod b/internal/dnsserver/go.mod index 3544026..38d5bd0 100644 --- a/internal/dnsserver/go.mod +++ b/internal/dnsserver/go.mod @@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardDNS/internal/dnsserver go 1.20 require ( - github.com/AdguardTeam/golibs v0.12.1 + github.com/AdguardTeam/golibs v0.13.2 github.com/ameshkov/dnscrypt/v2 v2.2.5 github.com/ameshkov/dnsstamps v1.0.3 github.com/bluele/gcache v0.0.2 @@ -13,7 +13,7 @@ require ( github.com/prometheus/client_golang v1.14.0 github.com/quic-go/quic-go v0.33.0 github.com/stretchr/testify v1.8.2 - golang.org/x/exp v0.0.0-20230307190834-24139beb5833 + golang.org/x/exp v0.0.0-20230321023759-10a507213a29 golang.org/x/net v0.8.0 golang.org/x/sys v0.6.0 ) diff --git a/internal/dnsserver/go.sum b/internal/dnsserver/go.sum index a0d5590..0dd980d 100644 --- a/internal/dnsserver/go.sum +++ b/internal/dnsserver/go.sum @@ -1,5 +1,4 @@ -github.com/AdguardTeam/golibs v0.12.1 h1:bJfFzCnUCl+QsP6prUltM2Sjt0fTiDBPlxuAwfKP3g8= -github.com/AdguardTeam/golibs v0.12.1/go.mod h1:rIglKDHdLvFT1UbhumBLHO9S4cvWS9MEyT1njommI/Y= +github.com/AdguardTeam/golibs v0.13.2 h1:BPASsyQKmb+b8VnvsNOHp7bKfcZl9Z+Z2UhPjOiupSc= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw= @@ -80,8 +79,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= -golang.org/x/exp v0.0.0-20230307190834-24139beb5833 h1:SChBja7BCQewoTAU7IgvucQKMIXrEpFxNMs0spT3/5s= -golang.org/x/exp v0.0.0-20230307190834-24139beb5833/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs= golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= diff --git a/internal/dnsserver/netext/listenconfig.go b/internal/dnsserver/netext/listenconfig.go index 19a47d3..d148645 100644 --- a/internal/dnsserver/netext/listenconfig.go +++ b/internal/dnsserver/netext/listenconfig.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "net" + "syscall" ) // ListenConfig is the interface that allows controlling options of connections @@ -18,23 +19,43 @@ type ListenConfig interface { ListenPacket(ctx context.Context, network, address string) (c net.PacketConn, err error) } +// defaultCtrlConf is the default control config. By default, don't alter +// anything. defaultCtrlConf must not be mutated. +var defaultCtrlConf = &ControlConfig{ + RcvBufSize: 0, + SndBufSize: 0, +} + // DefaultListenConfig returns the default [ListenConfig] used by the servers in // this module except for the plain-DNS ones, which use -// [DefaultListenConfigWithOOB]. -func DefaultListenConfig() (lc ListenConfig) { +// [DefaultListenConfigWithOOB]. If conf is nil, a default configuration is +// used. +func DefaultListenConfig(conf *ControlConfig) (lc ListenConfig) { + if conf == nil { + conf = defaultCtrlConf + } + return &net.ListenConfig{ - Control: defaultListenControl, + Control: func(_, _ string, c syscall.RawConn) (err error) { + return listenControlWithSO(conf, c) + }, } } // DefaultListenConfigWithOOB returns the default [ListenConfig] used by the // plain-DNS servers in this module. The resulting ListenConfig sets additional // socket flags and processes the control-messages of connections created with -// ListenPacket. -func DefaultListenConfigWithOOB() (lc ListenConfig) { +// ListenPacket. If conf is nil, a default configuration is used. +func DefaultListenConfigWithOOB(conf *ControlConfig) (lc ListenConfig) { + if conf == nil { + conf = defaultCtrlConf + } + return &listenConfigOOB{ ListenConfig: net.ListenConfig{ - Control: defaultListenControl, + Control: func(_, _ string, c syscall.RawConn) (err error) { + return listenControlWithSO(conf, c) + }, }, } } @@ -71,3 +92,14 @@ func (lc *listenConfigOOB) ListenPacket( return wrapPacketConn(c), nil } + +// ControlConfig is the configuration of socket options. +type ControlConfig struct { + // RcvBufSize defines the size of socket receive buffer in bytes. Default + // is zero (uses system settings). + RcvBufSize int + + // SndBufSize defines the size of socket send buffer in bytes. Default is + // zero (uses system settings). + SndBufSize int +} diff --git a/internal/dnsserver/netext/listenconfig_unix.go b/internal/dnsserver/netext/listenconfig_unix.go index 3435b0c..23fca54 100644 --- a/internal/dnsserver/netext/listenconfig_unix.go +++ b/internal/dnsserver/netext/listenconfig_unix.go @@ -13,17 +13,50 @@ import ( "golang.org/x/sys/unix" ) -// defaultListenControl is used as a [net.ListenConfig.Control] function to set -// the SO_REUSEPORT socket option on all sockets used by the DNS servers in this -// package. -func defaultListenControl(_, _ string, c syscall.RawConn) (err error) { +// setSockOptFunc is a function that sets a socket option on fd. +type setSockOptFunc func(fd int) (err error) + +// newSetSockOptFunc returns a socket-option function with the given parameters. +func newSetSockOptFunc(name string, lvl, opt, val int) (o setSockOptFunc) { + return func(fd int) (err error) { + err = unix.SetsockoptInt(fd, lvl, opt, val) + + return errors.Annotate(err, "setting %s: %w", name) + } +} + +// listenControlWithSO is used as a [net.ListenConfig.Control] function to set +// the SO_REUSEPORT, SO_SNDBUF, and SO_RCVBUF socket options on all sockets +// used by the DNS servers in this package. conf must not be nil. +func listenControlWithSO(conf *ControlConfig, c syscall.RawConn) (err error) { + opts := []setSockOptFunc{ + newSetSockOptFunc("SO_REUSEPORT", unix.SOL_SOCKET, unix.SO_REUSEPORT, 1), + } + + if conf.SndBufSize > 0 { + opts = append( + opts, + newSetSockOptFunc("SO_SNDBUF", unix.SOL_SOCKET, unix.SO_SNDBUF, conf.SndBufSize), + ) + } + + if conf.RcvBufSize > 0 { + opts = append( + opts, + newSetSockOptFunc("SO_RCVBUF", unix.SOL_SOCKET, unix.SO_RCVBUF, conf.RcvBufSize), + ) + } + var opErr error err = c.Control(func(fd uintptr) { - opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) + fdInt := int(fd) + for _, opt := range opts { + opErr = opt(fdInt) + if opErr != nil { + return + } + } }) - if err != nil { - return err - } return errors.WithDeferred(opErr, err) } diff --git a/internal/dnsserver/netext/listenconfig_unix_test.go b/internal/dnsserver/netext/listenconfig_unix_test.go index caa4f6f..d9e3155 100644 --- a/internal/dnsserver/netext/listenconfig_unix_test.go +++ b/internal/dnsserver/netext/listenconfig_unix_test.go @@ -15,7 +15,7 @@ import ( ) func TestDefaultListenConfigWithOOB(t *testing.T) { - lc := netext.DefaultListenConfigWithOOB() + lc := netext.DefaultListenConfigWithOOB(nil) require.NotNil(t, lc) type syscallConner interface { @@ -65,3 +65,81 @@ func TestDefaultListenConfigWithOOB(t *testing.T) { require.NoError(t, err) }) } + +func TestDefaultListenConfigWithSO(t *testing.T) { + const ( + sndBufSize = 10000 + rcvBufSize = 20000 + ) + + lc := netext.DefaultListenConfigWithOOB(&netext.ControlConfig{ + SndBufSize: sndBufSize, + RcvBufSize: rcvBufSize, + }) + require.NotNil(t, lc) + + type syscallConner interface { + SyscallConn() (c syscall.RawConn, err error) + } + + t.Run("ipv4", func(t *testing.T) { + c, err := lc.ListenPacket(context.Background(), "udp4", "127.0.0.1:0") + require.NoError(t, err) + require.NotNil(t, c) + require.Implements(t, (*syscallConner)(nil), c) + + sc, err := c.(syscallConner).SyscallConn() + require.NoError(t, err) + + err = sc.Control(func(fd uintptr) { + val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF) + require.NoError(t, opErr) + + // TODO(a.garipov): Rewrite this to use actual expected values for + // each OS. + assert.LessOrEqual(t, sndBufSize, val) + }) + require.NoError(t, err) + + err = sc.Control(func(fd uintptr) { + val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF) + require.NoError(t, opErr) + + assert.LessOrEqual(t, rcvBufSize, val) + }) + require.NoError(t, err) + }) + + t.Run("ipv6", func(t *testing.T) { + c, err := lc.ListenPacket(context.Background(), "udp6", "[::1]:0") + if errors.Is(err, syscall.EADDRNOTAVAIL) { + // Some CI machines have IPv6 disabled. + t.Skipf("ipv6 seems to not be supported: %s", err) + } + + require.NoError(t, err) + require.NotNil(t, c) + require.Implements(t, (*syscallConner)(nil), c) + + sc, err := c.(syscallConner).SyscallConn() + require.NoError(t, err) + + err = sc.Control(func(fd uintptr) { + val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF) + require.NoError(t, opErr) + + // TODO(a.garipov): Rewrite this to use actual expected values for + // each OS. + assert.LessOrEqual(t, sndBufSize, val) + }) + require.NoError(t, err) + + err = sc.Control(func(fd uintptr) { + val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF) + require.NoError(t, opErr) + + assert.LessOrEqual(t, rcvBufSize, val) + }) + require.NoError(t, err) + }) +} diff --git a/internal/dnsserver/netext/listenconfig_windows.go b/internal/dnsserver/netext/listenconfig_windows.go index 20e11bf..8172b4a 100644 --- a/internal/dnsserver/netext/listenconfig_windows.go +++ b/internal/dnsserver/netext/listenconfig_windows.go @@ -7,9 +7,9 @@ import ( "syscall" ) -// defaultListenControl is nil on Windows, because it doesn't support -// SO_REUSEPORT. -var defaultListenControl func(_, _ string, _ syscall.RawConn) (_ error) +// listenControlWithSO is nil on Windows, because it doesn't support socket +// options. +var listenControlWithSO func(_ *ControlConfig, _ syscall.RawConn) (_ error) // setIPOpts sets the IPv4 and IPv6 options on a packet connection. func setIPOpts(c net.PacketConn) (err error) { diff --git a/internal/dnsserver/netext/packetconn_linux_test.go b/internal/dnsserver/netext/packetconn_linux_test.go index 6999445..e315457 100644 --- a/internal/dnsserver/netext/packetconn_linux_test.go +++ b/internal/dnsserver/netext/packetconn_linux_test.go @@ -51,7 +51,7 @@ func TestSessionPacketConn(t *testing.T) { } func testSessionPacketConn(t *testing.T, proto, addr string, dstIP net.IP) (isTimeout bool) { - lc := netext.DefaultListenConfigWithOOB() + lc := netext.DefaultListenConfigWithOOB(nil) require.NotNil(t, lc) c, err := lc.ListenPacket(context.Background(), proto, addr) diff --git a/internal/dnsserver/prometheus/forward_test.go b/internal/dnsserver/prometheus/forward_test.go index d2ee285..da5975f 100644 --- a/internal/dnsserver/prometheus/forward_test.go +++ b/internal/dnsserver/prometheus/forward_test.go @@ -24,7 +24,7 @@ func TestForwardMetricsListener_integration_request(t *testing.T) { Address: netip.MustParseAddrPort(addr), Network: forward.NetworkAny, MetricsListener: prometheus.NewForwardMetricsListener(0), - }, true) + }) // Prepare a test DNS message and call the handler's ServeDNS function. // It will then call the metrics listener and prom metrics should be diff --git a/internal/dnsserver/serverbench_test.go b/internal/dnsserver/serverbench_test.go index ad59fe7..2e0d14c 100644 --- a/internal/dnsserver/serverbench_test.go +++ b/internal/dnsserver/serverbench_test.go @@ -308,7 +308,7 @@ func BenchmarkServeQUIC(b *testing.B) { } // Open QUIC session - sess, err := quic.DialAddr(addr.String(), tlsConfig, nil) + sess, err := quic.DialAddr(context.Background(), addr.String(), tlsConfig, nil) require.NoError(b, err) defer func() { err = sess.CloseWithError(0, "") diff --git a/internal/dnsserver/serverdns.go b/internal/dnsserver/serverdns.go index 88ee67e..1e568a6 100644 --- a/internal/dnsserver/serverdns.go +++ b/internal/dnsserver/serverdns.go @@ -112,7 +112,7 @@ func newServerDNS(proto Protocol, conf ConfigDNS) (s *ServerDNS) { } if conf.ListenConfig == nil { - conf.ListenConfig = netext.DefaultListenConfigWithOOB() + conf.ListenConfig = netext.DefaultListenConfigWithOOB(nil) } s = &ServerDNS{ diff --git a/internal/dnsserver/serverdns_test.go b/internal/dnsserver/serverdns_test.go index 0e6f3d6..c522870 100644 --- a/internal/dnsserver/serverdns_test.go +++ b/internal/dnsserver/serverdns_test.go @@ -411,6 +411,8 @@ func TestServerDNS_integration_udpMsgIgnore(t *testing.T) { } func TestServerDNS_integration_tcpMsgIgnore(t *testing.T) { + t.Parallel() + testCases := []struct { name string buf []byte @@ -459,7 +461,10 @@ func TestServerDNS_integration_tcpMsgIgnore(t *testing.T) { } for _, tc := range testCases { + tc := tc t.Run(tc.name, func(t *testing.T) { + t.Parallel() + _, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler()) conn, err := net.Dial("tcp", addr) require.Nil(t, err) diff --git a/internal/dnsserver/serverdnscrypt.go b/internal/dnsserver/serverdnscrypt.go index 84e30d0..4297882 100644 --- a/internal/dnsserver/serverdnscrypt.go +++ b/internal/dnsserver/serverdnscrypt.go @@ -41,7 +41,7 @@ var _ Server = (*ServerDNSCrypt)(nil) // NewServerDNSCrypt creates a new instance of ServerDNSCrypt. func NewServerDNSCrypt(conf ConfigDNSCrypt) (s *ServerDNSCrypt) { if conf.ListenConfig == nil { - conf.ListenConfig = netext.DefaultListenConfig() + conf.ListenConfig = netext.DefaultListenConfig(nil) } return &ServerDNSCrypt{ diff --git a/internal/dnsserver/serverhttps.go b/internal/dnsserver/serverhttps.go index 7c45652..0bc8fa2 100644 --- a/internal/dnsserver/serverhttps.go +++ b/internal/dnsserver/serverhttps.go @@ -16,6 +16,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" @@ -85,7 +86,7 @@ type ServerHTTPS struct { h3Server *http3.Server // quicListener is a listener that we use to serve DoH3 requests. - quicListener quic.EarlyListener + quicListener *quic.EarlyListener conf ConfigHTTPS } @@ -98,7 +99,7 @@ func NewServerHTTPS(conf ConfigHTTPS) (s *ServerHTTPS) { if conf.ListenConfig == nil { // Do not enable OOB here, because ListenPacket is only used by HTTP/3, // and quic-go sets the necessary flags. - conf.ListenConfig = netext.DefaultListenConfig() + conf.ListenConfig = netext.DefaultListenConfig(nil) } s = &ServerHTTPS{ @@ -289,7 +290,7 @@ func (s *ServerHTTPS) serveHTTPS(ctx context.Context, hs *http.Server, l net.Lis } // serveH3 is launched in a worker goroutine and serves HTTP/3 requests. -func (s *ServerHTTPS) serveH3(ctx context.Context, hs *http3.Server, ql quic.EarlyListener) { +func (s *ServerHTTPS) serveH3(ctx context.Context, hs *http3.Server, ql *quic.EarlyListener) { defer s.wg.Done() // Do not recover from panics here since if this goroutine panics, the @@ -441,10 +442,10 @@ func (h *httpHandler) writeResponse( switch ct { case MimeTypeDoH: buf, err = resp.Pack() - w.Header().Set("Content-Type", MimeTypeDoH) + w.Header().Set(httphdr.ContentType, MimeTypeDoH) case MimeTypeJSON: buf, err = dnsMsgToJSON(resp) - w.Header().Set("Content-Type", MimeTypeJSON) + w.Header().Set(httphdr.ContentType, MimeTypeJSON) default: return fmt.Errorf("invalid content type: %s", ct) } @@ -458,8 +459,8 @@ func (h *httpHandler) writeResponse( // lifetime (see Section 4.2 of [RFC7234]) so that the DoH client is // more likely to use fresh DNS data. maxAge := minimalTTL(resp) - w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%f", maxAge.Seconds())) - w.Header().Set("Content-Length", strconv.Itoa(len(buf))) + w.Header().Set(httphdr.CacheControl, fmt.Sprintf("max-age=%f", maxAge.Seconds())) + w.Header().Set(httphdr.ContentLength, strconv.Itoa(len(buf))) w.WriteHeader(http.StatusOK) // Write the actual response diff --git a/internal/dnsserver/serverhttps_test.go b/internal/dnsserver/serverhttps_test.go index fd9803e..ad32426 100644 --- a/internal/dnsserver/serverhttps_test.go +++ b/internal/dnsserver/serverhttps_test.go @@ -17,6 +17,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" + "github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" @@ -27,6 +28,8 @@ import ( ) func TestServerHTTPS_integration_serveRequests(t *testing.T) { + t.Parallel() + testCases := []struct { name string method string @@ -101,7 +104,10 @@ func TestServerHTTPS_integration_serveRequests(t *testing.T) { }} for _, tc := range testCases { + tc := tc t.Run(tc.name, func(t *testing.T) { + t.Parallel() + tlsConfig := dnsservertest.CreateServerTLSConfig("example.org") srv, err := dnsservertest.RunLocalHTTPSServer( dnsservertest.DefaultHandler(), @@ -356,7 +362,7 @@ func TestServerHTTPS_0RTT(t *testing.T) { // quicConfig with TokenStore set so that 0-RTT was enabled. quicConfig := &quic.Config{ TokenStore: quic.NewLRUTokenStore(1, 10), - Tracer: quicTracer, + Tracer: quicTracer.TracerForConnection, } // ClientSessionCache in the tls.Config must also be set for 0-RTT to work. @@ -516,7 +522,7 @@ func createDoH3Client( tlsCfg *tls.Config, cfg *quic.Config, ) (c quic.EarlyConnection, e error) { - return quic.DialAddrEarlyContext(ctx, httpsAddr.String(), tlsCfg, cfg) + return quic.DialAddrEarly(ctx, httpsAddr.String(), tlsCfg, cfg) }, QuicConfig: quicConfig, TLSClientConfig: tlsConfig, @@ -550,8 +556,8 @@ func createDoHRequest(proto, method string, msg *dns.Msg) (r *http.Request, err return nil, err } - r.Header.Set("Content-Type", dnsserver.MimeTypeDoH) - r.Header.Set("Accept", dnsserver.MimeTypeDoH) + r.Header.Set(httphdr.ContentType, dnsserver.MimeTypeDoH) + r.Header.Set(httphdr.Accept, dnsserver.MimeTypeDoH) return r, nil } @@ -582,8 +588,8 @@ func createJSONRequest( return nil, err } - r.Header.Set("Content-Type", dnsserver.MimeTypeJSON) - r.Header.Set("Accept", dnsserver.MimeTypeJSON) + r.Header.Set(httphdr.ContentType, dnsserver.MimeTypeJSON) + r.Header.Set(httphdr.Accept, dnsserver.MimeTypeJSON) return r, err } diff --git a/internal/dnsserver/serverquic.go b/internal/dnsserver/serverquic.go index 016f4cc..ad0fb93 100644 --- a/internal/dnsserver/serverquic.go +++ b/internal/dnsserver/serverquic.go @@ -81,7 +81,7 @@ type ServerQUIC struct { pool *ants.Pool // quicListener is a listener that we use to accept DoQ connections. - quicListener quic.Listener + quicListener *quic.Listener // bytesPool is a pool to avoid unnecessary allocations when reading // DNS packets. @@ -101,7 +101,7 @@ func NewServerQUIC(conf ConfigQUIC) (s *ServerQUIC) { if conf.ListenConfig == nil { // Do not enable OOB here as quic-go will do that on its own. - conf.ListenConfig = netext.DefaultListenConfig() + conf.ListenConfig = netext.DefaultListenConfig(nil) } s = &ServerQUIC{ @@ -220,7 +220,7 @@ func (s *ServerQUIC) startServeQUIC(ctx context.Context) { } // serveQUIC listens for incoming QUIC connections. -func (s *ServerQUIC) serveQUIC(ctx context.Context, l quic.Listener) (err error) { +func (s *ServerQUIC) serveQUIC(ctx context.Context, l *quic.Listener) (err error) { connWg := &sync.WaitGroup{} // Wait until all conns are processed before exiting this method defer connWg.Wait() @@ -261,7 +261,7 @@ func (s *ServerQUIC) serveQUIC(ctx context.Context, l quic.Listener) (err error) // acceptQUICConn is a wrapper around quic.Listener.Accept that makes sure that the // timeout is handled properly. -func acceptQUICConn(ctx context.Context, l quic.Listener) (conn quic.Connection, err error) { +func acceptQUICConn(ctx context.Context, l *quic.Listener) (conn quic.Connection, err error) { ctx, cancel := context.WithDeadline(ctx, time.Now().Add(DefaultReadTimeout)) defer cancel() @@ -662,9 +662,7 @@ func newServerQUICConfig(metrics MetricsListener) (conf *quic.Config) { RequireAddressValidation: v.requiresValidation, // Enable 0-RTT by default for all addresses, it's beneficial for the // performance. - Allow0RTT: func(net.Addr) (ok bool) { - return true - }, + Allow0RTT: true, } } diff --git a/internal/dnsserver/serverquic_test.go b/internal/dnsserver/serverquic_test.go index cc0a86d..5415c4a 100644 --- a/internal/dnsserver/serverquic_test.go +++ b/internal/dnsserver/serverquic_test.go @@ -34,7 +34,7 @@ func TestServerQUIC_integration_query(t *testing.T) { }) // Open a QUIC connection. - conn, err := quic.DialAddr(addr.String(), tlsConfig, nil) + conn, err := quic.DialAddr(context.Background(), addr.String(), tlsConfig, nil) require.NoError(t, err) defer testutil.CleanupAndRequireSuccess(t, func() (err error) { @@ -83,7 +83,7 @@ func TestServerQUIC_integration_ENDS0Padding(t *testing.T) { }) // Open a QUIC connection. - conn, err := quic.DialAddr(addr.String(), tlsConfig, nil) + conn, err := quic.DialAddr(context.Background(), addr.String(), tlsConfig, nil) require.NoError(t, err) defer func(conn quic.Connection, code quic.ApplicationErrorCode, s string) { @@ -122,7 +122,7 @@ func TestServerQUIC_integration_0RTT(t *testing.T) { // quicConfig with TokenStore set so that 0-RTT was enabled. quicConfig := &quic.Config{ TokenStore: quic.NewLRUTokenStore(1, 10), - Tracer: quicTracer, + Tracer: quicTracer.TracerForConnection, } // ClientSessionCache in the tls.Config must also be set for 0-RTT to work. @@ -156,7 +156,7 @@ func TestServerQUIC_integration_largeQuery(t *testing.T) { }) // Open a QUIC connection. - conn, err := quic.DialAddr(addr.String(), tlsConfig, nil) + conn, err := quic.DialAddr(context.Background(), addr.String(), tlsConfig, nil) require.NoError(t, err) defer testutil.CleanupAndRequireSuccess(t, func() (err error) { @@ -190,7 +190,7 @@ func testQUICExchange( tlsConfig *tls.Config, quicConfig *quic.Config, ) { - conn, err := quic.DialAddrEarly(addr.String(), tlsConfig, quicConfig) + conn, err := quic.DialAddrEarly(context.Background(), addr.String(), tlsConfig, quicConfig) require.NoError(t, err) defer testutil.CleanupAndRequireSuccess(t, func() (err error) { diff --git a/internal/dnsserver/servertls_test.go b/internal/dnsserver/servertls_test.go index aa37a3c..a97688d 100644 --- a/internal/dnsserver/servertls_test.go +++ b/internal/dnsserver/servertls_test.go @@ -45,6 +45,8 @@ func TestServerTLS_integration_queryTLS(t *testing.T) { } func TestServerTLS_integration_msgIgnore(t *testing.T) { + t.Parallel() + testCases := []struct { name string buf []byte @@ -88,7 +90,10 @@ func TestServerTLS_integration_msgIgnore(t *testing.T) { } for _, tc := range testCases { + tc := tc t.Run(tc.name, func(t *testing.T) { + t.Parallel() + tlsConfig := dnsservertest.CreateServerTLSConfig("example.org") h := dnsservertest.DefaultHandler() addr := dnsservertest.RunTLSServer(t, h, tlsConfig) diff --git a/internal/dnssvc/debug.go b/internal/dnssvc/debug.go index 3dd60c7..e40c320 100644 --- a/internal/dnssvc/debug.go +++ b/internal/dnssvc/debug.go @@ -17,15 +17,16 @@ import ( // Debug header name constants. const ( - hdrNameResType = "res-type" - hdrNameRuleListID = "rule-list-id" - hdrNameRule = "rule" - hdrNameClientIP = "client-ip" - hdrNameDeviceID = "device-id" - hdrNameProfileID = "profile-id" - hdrNameCountry = "country" - hdrNameASN = "asn" - hdrNameHost = "adguard-dns.com." + hdrNameResType = "res-type" + hdrNameRuleListID = "rule-list-id" + hdrNameRule = "rule" + hdrNameClientIP = "client-ip" + hdrNameDeviceID = "device-id" + hdrNameProfileID = "profile-id" + hdrNameCountry = "country" + hdrNameASN = "asn" + hdrNameSubdivision = "subdivision" + hdrNameHost = "adguard-dns.com." ) // writeDebugResponse writes the debug response to rw. @@ -97,17 +98,41 @@ func (svc *Service) appendDebugExtraFromContext( } } - if d := ri.Location; d != nil { - setQuestionName(debugReq, "", hdrNameCountry) - err = svc.messages.AppendDebugExtra(debugReq, resp, string(d.Country)) + if loc := ri.Location; loc != nil { + err = svc.appendDebugExtraFromLocation(loc, debugReq, resp) if err != nil { - return fmt.Errorf("adding %s extra: %w", hdrNameCountry, err) + // Don't wrap the error, because it's informative enough as is. + return err } + } - setQuestionName(debugReq, "", hdrNameASN) - err = svc.messages.AppendDebugExtra(debugReq, resp, strconv.FormatUint(uint64(d.ASN), 10)) + return nil +} + +// appendDebugExtraFromLocation adds debug info to response got from request +// info location. loc should not be nil. +func (svc *Service) appendDebugExtraFromLocation( + loc *agd.Location, + debugReq *dns.Msg, + resp *dns.Msg, +) (err error) { + setQuestionName(debugReq, "", hdrNameCountry) + err = svc.messages.AppendDebugExtra(debugReq, resp, string(loc.Country)) + if err != nil { + return fmt.Errorf("adding %s extra: %w", hdrNameCountry, err) + } + + setQuestionName(debugReq, "", hdrNameASN) + err = svc.messages.AppendDebugExtra(debugReq, resp, strconv.FormatUint(uint64(loc.ASN), 10)) + if err != nil { + return fmt.Errorf("adding %s extra: %w", hdrNameASN, err) + } + + if subdivision := loc.TopSubdivision; subdivision != "" { + setQuestionName(debugReq, "", hdrNameSubdivision) + err = svc.messages.AppendDebugExtra(debugReq, resp, subdivision) if err != nil { - return fmt.Errorf("adding %s extra: %w", hdrNameASN, err) + return fmt.Errorf("adding %s extra: %w", hdrNameSubdivision, err) } } diff --git a/internal/dnssvc/debug_internal_test.go b/internal/dnssvc/debug_internal_test.go index 91c5eb5..46c9547 100644 --- a/internal/dnssvc/debug_internal_test.go +++ b/internal/dnssvc/debug_internal_test.go @@ -26,7 +26,7 @@ func newTXTExtra(strs [][2]string) (extra []dns.RR) { Name: v[0], Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, - Ttl: uint32(agdtest.FilteredResponseTTL.Seconds()), + Ttl: agdtest.FilteredResponseTTLSec, }, Txt: []string{v[1]}, }) @@ -47,6 +47,7 @@ func TestService_writeDebugResponse(t *testing.T) { blockRule = "||example.com^" ) + clientIPStr := testClientIP.String() testCases := []struct { name string ri *agd.RequestInfo @@ -59,7 +60,7 @@ func TestService_writeDebugResponse(t *testing.T) { reqRes: nil, respRes: nil, wantExtra: newTXTExtra([][2]string{ - {"client-ip.adguard-dns.com.", "1.2.3.4"}, + {"client-ip.adguard-dns.com.", clientIPStr}, {"resp.res-type.adguard-dns.com.", "normal"}, }), }, { @@ -68,7 +69,7 @@ func TestService_writeDebugResponse(t *testing.T) { reqRes: &filter.ResultBlocked{List: fltListID1, Rule: blockRule}, respRes: nil, wantExtra: newTXTExtra([][2]string{ - {"client-ip.adguard-dns.com.", "1.2.3.4"}, + {"client-ip.adguard-dns.com.", clientIPStr}, {"req.res-type.adguard-dns.com.", "blocked"}, {"req.rule.adguard-dns.com.", "||example.com^"}, {"req.rule-list-id.adguard-dns.com.", "fl1"}, @@ -79,7 +80,7 @@ func TestService_writeDebugResponse(t *testing.T) { reqRes: nil, respRes: &filter.ResultBlocked{List: fltListID2, Rule: blockRule}, wantExtra: newTXTExtra([][2]string{ - {"client-ip.adguard-dns.com.", "1.2.3.4"}, + {"client-ip.adguard-dns.com.", clientIPStr}, {"resp.res-type.adguard-dns.com.", "blocked"}, {"resp.rule.adguard-dns.com.", "||example.com^"}, {"resp.rule-list-id.adguard-dns.com.", "fl2"}, @@ -90,7 +91,7 @@ func TestService_writeDebugResponse(t *testing.T) { reqRes: &filter.ResultAllowed{}, respRes: nil, wantExtra: newTXTExtra([][2]string{ - {"client-ip.adguard-dns.com.", "1.2.3.4"}, + {"client-ip.adguard-dns.com.", clientIPStr}, {"req.res-type.adguard-dns.com.", "allowed"}, {"req.rule.adguard-dns.com.", ""}, {"req.rule-list-id.adguard-dns.com.", ""}, @@ -101,7 +102,7 @@ func TestService_writeDebugResponse(t *testing.T) { reqRes: nil, respRes: &filter.ResultAllowed{}, wantExtra: newTXTExtra([][2]string{ - {"client-ip.adguard-dns.com.", "1.2.3.4"}, + {"client-ip.adguard-dns.com.", clientIPStr}, {"resp.res-type.adguard-dns.com.", "allowed"}, {"resp.rule.adguard-dns.com.", ""}, {"resp.rule-list-id.adguard-dns.com.", ""}, @@ -114,31 +115,31 @@ func TestService_writeDebugResponse(t *testing.T) { }, respRes: nil, wantExtra: newTXTExtra([][2]string{ - {"client-ip.adguard-dns.com.", "1.2.3.4"}, + {"client-ip.adguard-dns.com.", clientIPStr}, {"req.res-type.adguard-dns.com.", "modified"}, {"req.rule.adguard-dns.com.", "||example.com^$dnsrewrite=REFUSED"}, {"req.rule-list-id.adguard-dns.com.", ""}, }), }, { name: "device", - ri: &agd.RequestInfo{Device: &agd.Device{ID: "dev1234"}}, + ri: &agd.RequestInfo{Device: &agd.Device{ID: testDeviceID}}, reqRes: nil, respRes: nil, wantExtra: newTXTExtra([][2]string{ - {"client-ip.adguard-dns.com.", "1.2.3.4"}, - {"device-id.adguard-dns.com.", "dev1234"}, + {"client-ip.adguard-dns.com.", clientIPStr}, + {"device-id.adguard-dns.com.", testDeviceID}, {"resp.res-type.adguard-dns.com.", "normal"}, }), }, { name: "profile", ri: &agd.RequestInfo{ - Profile: &agd.Profile{ID: agd.ProfileID("some-profile-id")}, + Profile: &agd.Profile{ID: testProfileID}, }, reqRes: nil, respRes: nil, wantExtra: newTXTExtra([][2]string{ - {"client-ip.adguard-dns.com.", "1.2.3.4"}, - {"profile-id.adguard-dns.com.", "some-profile-id"}, + {"client-ip.adguard-dns.com.", clientIPStr}, + {"profile-id.adguard-dns.com.", testProfileID}, {"resp.res-type.adguard-dns.com.", "normal"}, }), }, { @@ -147,11 +148,25 @@ func TestService_writeDebugResponse(t *testing.T) { reqRes: nil, respRes: nil, wantExtra: newTXTExtra([][2]string{ - {"client-ip.adguard-dns.com.", "1.2.3.4"}, - {"country.adguard-dns.com.", "AD"}, + {"client-ip.adguard-dns.com.", clientIPStr}, + {"country.adguard-dns.com.", string(agd.CountryAD)}, {"asn.adguard-dns.com.", "0"}, {"resp.res-type.adguard-dns.com.", "normal"}, }), + }, { + name: "location_subdivision", + ri: &agd.RequestInfo{ + Location: &agd.Location{Country: agd.CountryAD, TopSubdivision: "CA"}, + }, + reqRes: nil, + respRes: nil, + wantExtra: newTXTExtra([][2]string{ + {"client-ip.adguard-dns.com.", clientIPStr}, + {"country.adguard-dns.com.", string(agd.CountryAD)}, + {"asn.adguard-dns.com.", "0"}, + {"subdivision.adguard-dns.com.", "CA"}, + {"resp.res-type.adguard-dns.com.", "normal"}, + }), }} for _, tc := range testCases { diff --git a/internal/dnssvc/deviceid_internal_test.go b/internal/dnssvc/deviceid_internal_test.go index f70ebe0..ebb6559 100644 --- a/internal/dnssvc/deviceid_internal_test.go +++ b/internal/dnssvc/deviceid_internal_test.go @@ -46,8 +46,8 @@ func TestService_Wrap_deviceID(t *testing.T) { proto: agd.ProtoDoT, }, { name: "tls_device_id", - cliSrvName: "dev.dns.example.com", - wantDeviceID: "dev", + cliSrvName: testDeviceID + ".dns.example.com", + wantDeviceID: testDeviceID, wantErrMsg: "", wildcards: []string{"*.dns.example.com"}, proto: agd.ProtoDoT, @@ -61,7 +61,7 @@ func TestService_Wrap_deviceID(t *testing.T) { proto: agd.ProtoDoT, }, { name: "tls_deep_subdomain", - cliSrvName: "abc.def.dns.example.com", + cliSrvName: "abc." + testDeviceID + ".dns.example.com", wantDeviceID: "", wantErrMsg: "", wildcards: []string{"*.dns.example.com"}, @@ -79,8 +79,8 @@ func TestService_Wrap_deviceID(t *testing.T) { proto: agd.ProtoDoT, }, { name: "quic_device_id", - cliSrvName: "dev.dns.example.com", - wantDeviceID: "dev", + cliSrvName: testDeviceID + ".dns.example.com", + wantDeviceID: testDeviceID, wantErrMsg: "", wildcards: []string{"*.dns.example.com"}, proto: agd.ProtoDoQ, @@ -93,8 +93,8 @@ func TestService_Wrap_deviceID(t *testing.T) { proto: agd.ProtoDoT, }, { name: "tls_device_id_subdomain_wildcard", - cliSrvName: "dev.sub.dns.example.com", - wantDeviceID: "dev", + cliSrvName: testDeviceID + ".sub.dns.example.com", + wantDeviceID: testDeviceID, wantErrMsg: "", wildcards: []string{ "*.dns.example.com", @@ -135,13 +135,13 @@ func TestService_Wrap_deviceIDHTTPS(t *testing.T) { wantErrMsg: "", }, { name: "device_id", - path: "/dns-query/cli", - wantDeviceID: "cli", + path: "/dns-query/" + testDeviceID, + wantDeviceID: testDeviceID, wantErrMsg: "", }, { name: "device_id_slash", - path: "/dns-query/cli/", - wantDeviceID: "cli", + path: "/dns-query/" + testDeviceID + "/", + wantDeviceID: testDeviceID, wantErrMsg: "", }, { name: "bad_url", @@ -150,9 +150,9 @@ func TestService_Wrap_deviceIDHTTPS(t *testing.T) { wantErrMsg: `http url device id check: bad path "/foo"`, }, { name: "extra", - path: "/dns-query/cli/foo", + path: "/dns-query/" + testDeviceID + "/foo", wantDeviceID: "", - wantErrMsg: `http url device id check: bad path "/dns-query/cli/foo": ` + + wantErrMsg: `http url device id check: bad path "/dns-query/` + testDeviceID + `/foo": ` + `extra parts`, }, { name: "bad_device_id", @@ -184,11 +184,9 @@ func TestService_Wrap_deviceIDHTTPS(t *testing.T) { } t.Run("domain_name", func(t *testing.T) { - const want = "dev" - u := &url.URL{ Scheme: "https", - Host: want + ".dns.example.com", + Host: testDeviceID + ".dns.example.com", Path: "/dns-query", } @@ -204,7 +202,7 @@ func TestService_Wrap_deviceIDHTTPS(t *testing.T) { deviceID, err := deviceIDFromContext(ctx, proto, []string{"*.dns.example.com"}) require.NoError(t, err) - assert.Equal(t, agd.DeviceID(want), deviceID) + assert.Equal(t, agd.DeviceID(testDeviceID), deviceID) }) } diff --git a/internal/dnssvc/dnssvc.go b/internal/dnssvc/dnssvc.go index a5872ad..ca715dc 100644 --- a/internal/dnssvc/dnssvc.go +++ b/internal/dnssvc/dnssvc.go @@ -11,6 +11,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/billstat" + "github.com/AdguardTeam/AdGuardDNS/internal/connlimiter" "github.com/AdguardTeam/AdGuardDNS/internal/dnscheck" "github.com/AdguardTeam/AdGuardDNS/internal/dnsdb" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" @@ -20,6 +21,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit" "github.com/AdguardTeam/AdGuardDNS/internal/filter" "github.com/AdguardTeam/AdGuardDNS/internal/geoip" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb" "github.com/AdguardTeam/AdGuardDNS/internal/querylog" "github.com/AdguardTeam/AdGuardDNS/internal/rulestat" "github.com/AdguardTeam/golibs/errors" @@ -38,15 +40,22 @@ type Config struct { // messages for this DNS service. Messages *dnsmsg.Constructor - // SafeBrowsing is the safe browsing TXT server. - SafeBrowsing *filter.SafeBrowsingServer + // ControlConf is the configuration of socket options. + ControlConf *netext.ControlConfig + + // ConnLimiter, if not nil, is used to limit the number of simultaneously + // active stream-connections. + ConnLimiter *connlimiter.Limiter + + // SafeBrowsing is the safe browsing TXT hash matcher. + SafeBrowsing filter.HashMatcher // BillStat is used to collect billing statistics. BillStat billstat.Recorder // ProfileDB is the AdGuard DNS profile database used to fetch data about // profiles, devices, and so on. - ProfileDB agd.ProfileDB + ProfileDB profiledb.Interface // DNSCheck is used by clients to check if they use AdGuard DNS. DNSCheck dnscheck.Interface @@ -75,14 +84,11 @@ type Config struct { // rule lists. RuleStat rulestat.Interface - // Upstream defines the upstream server and the group of fallback servers. - Upstream *agd.Upstream - // NewListener, when set, is used instead of the package-level function // NewListener when creating a DNS listener. // // TODO(a.garipov): The handler and service logic should really not be - // internwined in this way. See AGDNS-1327. + // intertwined in this way. See AGDNS-1327. NewListener NewListenerFunc // Handler is used as the main DNS handler instead of a simple forwarder. @@ -170,9 +176,9 @@ func New(c *Config) (svc *Service, err error) { dnsHdlr := dnsserver.WithMiddlewares( handler, &preServiceMw{ - messages: c.Messages, - filter: c.SafeBrowsing, - checker: c.DNSCheck, + messages: c.Messages, + hashMatcher: c.SafeBrowsing, + checker: c.DNSCheck, }, svc, ) @@ -475,8 +481,13 @@ func newServers( name := listenerName(s.Name, addr, s.Protocol) + lc := bindData.ListenConfig + if lc == nil { + lc = newListenConfig(c.ControlConf, c.ConnLimiter, s.Protocol) + } + var l Listener - l, err = newListener(s, name, addr, h, c.NonDNS, c.ErrColl, bindData.ListenConfig) + l, err = newListener(s, name, addr, h, c.NonDNS, c.ErrColl, lc) if err != nil { return nil, fmt.Errorf("server %q: %w", s.Name, err) } @@ -496,3 +507,26 @@ func newServers( return servers, nil } + +// newListenConfig returns the netext.ListenConfig used by the plain-DNS +// servers. The resulting ListenConfig sets additional socket flags and +// processes the control messages of connections created with ListenPacket. +// Additionally, if l is not nil, it is used to limit the number of +// simultaneously active stream-connections. +func newListenConfig( + ctrlConf *netext.ControlConfig, + l *connlimiter.Limiter, + p agd.Protocol, +) (lc netext.ListenConfig) { + if p == agd.ProtoDNS { + lc = netext.DefaultListenConfigWithOOB(ctrlConf) + } else { + lc = netext.DefaultListenConfig(ctrlConf) + } + + if l != nil { + lc = connlimiter.NewListenConfig(lc, l) + } + + return lc +} diff --git a/internal/dnssvc/dnssvc_internal_test.go b/internal/dnssvc/dnssvc_internal_test.go new file mode 100644 index 0000000..5138cd9 --- /dev/null +++ b/internal/dnssvc/dnssvc_internal_test.go @@ -0,0 +1,26 @@ +package dnssvc + +import ( + "net" + "net/netip" +) + +// Common addresses for tests. +var ( + testClientIP = net.IP{1, 2, 3, 4} + testRAddr = &net.TCPAddr{ + IP: testClientIP, + Port: 12345, + } + + testClientAddrPort = testRAddr.AddrPort() + testClientAddr = testClientAddrPort.Addr() + + testServerAddr = netip.MustParseAddr("5.6.7.8") +) + +// testDeviceID is the common device ID for tests +const testDeviceID = "dev1234" + +// testProfileID is the common profile ID for tests +const testProfileID = "prof1234" diff --git a/internal/dnssvc/dnssvc_test.go b/internal/dnssvc/dnssvc_test.go index c98aa78..ada8233 100644 --- a/internal/dnssvc/dnssvc_test.go +++ b/internal/dnssvc/dnssvc_test.go @@ -173,12 +173,6 @@ func TestService_Start(t *testing.T) { Name: "test_group", Servers: []*agd.Server{srv}, }}, - Upstream: &agd.Upstream{ - Server: netip.MustParseAddrPort("8.8.8.8:53"), - FallbackServers: []netip.AddrPort{ - netip.MustParseAddrPort("1.1.1.1:53"), - }, - }, } svc, err := dnssvc.New(c) @@ -246,12 +240,6 @@ func TestNew(t *testing.T) { Name: "test_group", Servers: srvs, }}, - Upstream: &agd.Upstream{ - Server: netip.MustParseAddrPort("8.8.8.8:53"), - FallbackServers: []netip.AddrPort{ - netip.MustParseAddrPort("1.1.1.1:53"), - }, - }, } svc, err := dnssvc.New(c) diff --git a/internal/dnssvc/initmw.go b/internal/dnssvc/initmw.go index c4dbd36..4272a59 100644 --- a/internal/dnssvc/initmw.go +++ b/internal/dnssvc/initmw.go @@ -12,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" "github.com/AdguardTeam/AdGuardDNS/internal/geoip" "github.com/AdguardTeam/AdGuardDNS/internal/optlog" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" @@ -27,8 +28,6 @@ import ( // one. // // TODO(a.garipov): Add tests. -// -// TODO(a.garipov): Make other middlewares more compact as well. See AGDNS-328. type initMw struct { // messages is used to build the responses specific for the request's // context. @@ -43,8 +42,8 @@ type initMw struct { // srv is the current server which serves the request. srv *agd.Server - // db is the storage of user profiles. - db agd.ProfileDB + // db is the database of user profiles and devices. + db profiledb.Interface // geoIP detects the location of the request source. geoIP geoip.Interface @@ -69,6 +68,7 @@ func (mw *initMw) Wrap(h dnsserver.Handler) (wrapped dnsserver.Handler) { func (mw *initMw) newRequestInfo( ctx context.Context, req *dns.Msg, + laddr net.Addr, raddr net.Addr, fqdn string, qt dnsmsg.RRType, @@ -100,7 +100,8 @@ func (mw *initMw) newRequestInfo( } // Add the profile information, if any. - err = mw.addProfile(ctx, ri, req) + localIP := netutil.NetAddrToAddrPort(laddr).Addr() + err = mw.addProfile(ctx, ri, req, localIP) if err != nil { // Don't wrap the error, because it's informative enough as is. return nil, err @@ -149,7 +150,12 @@ func (mw *initMw) locationData(ctx context.Context, ip netip.Addr, typ string) ( // addProfile adds profile and device information, if any, to the request // information. -func (mw *initMw) addProfile(ctx context.Context, ri *agd.RequestInfo, req *dns.Msg) (err error) { +func (mw *initMw) addProfile( + ctx context.Context, + ri *agd.RequestInfo, + req *dns.Msg, + localIP netip.Addr, +) (err error) { defer func() { err = errors.Annotate(err, "getting profile from req: %w") }() var id agd.DeviceID @@ -170,14 +176,12 @@ func (mw *initMw) addProfile(ctx context.Context, ri *agd.RequestInfo, req *dns. optlog.Debug2("init mw: got device id %q and ip %s", id, ri.RemoteIP) - prof, dev, byWhat, err := mw.profile(ctx, ri.RemoteIP, id, mw.srv.Protocol) + prof, dev, byWhat, err := mw.profile(ctx, localIP, ri.RemoteIP, id) if err != nil { - // Use two errors.Is calls to prevent unnecessary allocations. - if !errors.Is(err, agd.DeviceNotFoundError{}) && - !errors.Is(err, agd.ProfileNotFoundError{}) { + if !errors.Is(err, profiledb.ErrDeviceNotFound) { // Very unlikely, since those two error types are the only ones - // currently returned from the profile DB. - return err + // currently returned from the default profile DB. + return fmt.Errorf("unexpected profiledb error: %s", err) } optlog.Debug1("init mw: profile or device not found: %s", err) @@ -193,32 +197,51 @@ func (mw *initMw) addProfile(ctx context.Context, ri *agd.RequestInfo, req *dns. return nil } +// Constants for the parameter by which a device has been found. +const ( + byDeviceID = "device id" + byDedicatedIP = "dedicated ip" + byLinkedIP = "linked ip" +) + // profile finds the profile by the client data. func (mw *initMw) profile( ctx context.Context, - ip netip.Addr, + localIP netip.Addr, + remoteIP netip.Addr, id agd.DeviceID, - p agd.Protocol, ) (prof *agd.Profile, dev *agd.Device, byWhat string, err error) { if id != "" { prof, dev, err = mw.db.ProfileByDeviceID(ctx, id) + if err != nil { + return nil, nil, "", err + } - return prof, dev, "device id", err + return prof, dev, byDeviceID, nil } if !mw.srv.LinkedIPEnabled { - optlog.Debug1("init mw: not matching by ip for server %s", mw.srv.Name) + optlog.Debug1("init mw: not matching by linked or dedicated ip for server %s", mw.srv.Name) - return nil, nil, "", agd.ProfileNotFoundError{} - } else if p != agd.ProtoDNS { - optlog.Debug1("init mw: not matching by ip for proto %v", p) + return nil, nil, "", profiledb.ErrDeviceNotFound + } else if p := mw.srv.Protocol; p != agd.ProtoDNS { + optlog.Debug1("init mw: not matching by linked or dedicated ip for proto %v", p) - return nil, nil, "", agd.ProfileNotFoundError{} + return nil, nil, "", profiledb.ErrDeviceNotFound } - prof, dev, err = mw.db.ProfileByIP(ctx, ip) + byWhat = byDedicatedIP + prof, dev, err = mw.db.ProfileByDedicatedIP(ctx, localIP) + if errors.Is(err, profiledb.ErrDeviceNotFound) { + byWhat = byLinkedIP + prof, dev, err = mw.db.ProfileByLinkedIP(ctx, remoteIP) + } - return prof, dev, "linked ip", err + if err != nil { + return nil, nil, "", err + } + + return prof, dev, byWhat, nil } // initMwHandler implements the [dnsserver.Handler] interface and will be used @@ -259,7 +282,7 @@ func (mh *initMwHandler) ServeDNS( mw := mh.mw // Get the request's information, such as GeoIP data and user profiles. - ri, err := mw.newRequestInfo(ctx, req, rw.RemoteAddr(), fqdn, qt, cl) + ri, err := mw.newRequestInfo(ctx, req, rw.LocalAddr(), rw.RemoteAddr(), fqdn, qt, cl) if err != nil { var ecsErr dnsmsg.BadECSError if errors.As(err, &ecsErr) { diff --git a/internal/dnssvc/initmw_internal_test.go b/internal/dnssvc/initmw_internal_test.go index 74ffb13..fd10fce 100644 --- a/internal/dnssvc/initmw_internal_test.go +++ b/internal/dnssvc/initmw_internal_test.go @@ -3,7 +3,6 @@ package dnssvc import ( "context" "crypto/tls" - "net" "net/netip" "testing" @@ -11,6 +10,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" @@ -21,10 +21,148 @@ import ( "golang.org/x/exp/maps" ) -// testRAddr is the common remote address for tests. -var testRAddr = &net.TCPAddr{ - IP: net.IP{1, 2, 3, 4}, - Port: 12345, +func TestInitMw_profile(t *testing.T) { + prof := &agd.Profile{ + ID: testProfileID, + DeviceIDs: []agd.DeviceID{ + testDeviceID, + }, + } + dev := &agd.Device{ + ID: testDeviceID, + LinkedIP: testClientAddr, + DedicatedIPs: []netip.Addr{ + testServerAddr, + }, + } + + testCases := []struct { + wantDev *agd.Device + wantProf *agd.Profile + wantByWhat string + wantErrMsg string + name string + id agd.DeviceID + proto agd.Protocol + linkedIPEnabled bool + }{{ + wantDev: nil, + wantProf: nil, + wantByWhat: "", + wantErrMsg: "device not found", + name: "no_device_id", + id: "", + proto: agd.ProtoDNS, + linkedIPEnabled: true, + }, { + wantDev: dev, + wantProf: prof, + wantByWhat: byDeviceID, + wantErrMsg: "", + name: "device_id", + id: testDeviceID, + proto: agd.ProtoDNS, + linkedIPEnabled: true, + }, { + wantDev: dev, + wantProf: prof, + wantByWhat: byLinkedIP, + wantErrMsg: "", + name: "linked_ip", + id: "", + proto: agd.ProtoDNS, + linkedIPEnabled: true, + }, { + wantDev: nil, + wantProf: nil, + wantByWhat: "", + wantErrMsg: "device not found", + name: "linked_ip_dot", + id: "", + proto: agd.ProtoDoT, + linkedIPEnabled: true, + }, { + wantDev: nil, + wantProf: nil, + wantByWhat: "", + wantErrMsg: "device not found", + name: "linked_ip_disabled", + id: "", + proto: agd.ProtoDoT, + linkedIPEnabled: false, + }, { + wantDev: dev, + wantProf: prof, + wantByWhat: byDedicatedIP, + wantErrMsg: "", + name: "dedicated_ip", + id: "", + proto: agd.ProtoDNS, + linkedIPEnabled: true, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + db := &agdtest.ProfileDB{ + OnProfileByDeviceID: func( + _ context.Context, + gotID agd.DeviceID, + ) (p *agd.Profile, d *agd.Device, err error) { + assert.Equal(t, tc.id, gotID) + + if tc.wantByWhat == byDeviceID { + return prof, dev, nil + } + + return nil, nil, profiledb.ErrDeviceNotFound + }, + OnProfileByDedicatedIP: func( + _ context.Context, + gotLocalIP netip.Addr, + ) (p *agd.Profile, d *agd.Device, err error) { + assert.Equal(t, testServerAddr, gotLocalIP) + + if tc.wantByWhat == byDedicatedIP { + return prof, dev, nil + } + + return nil, nil, profiledb.ErrDeviceNotFound + }, + OnProfileByLinkedIP: func( + _ context.Context, + gotRemoteIP netip.Addr, + ) (p *agd.Profile, d *agd.Device, err error) { + assert.Equal(t, testClientAddr, gotRemoteIP) + + if tc.wantByWhat == byLinkedIP { + return prof, dev, nil + } + + return nil, nil, profiledb.ErrDeviceNotFound + }, + } + + mw := &initMw{ + srv: &agd.Server{ + Protocol: tc.proto, + LinkedIPEnabled: tc.linkedIPEnabled, + }, + db: db, + } + ctx := context.Background() + + gotProf, gotDev, gotByWhat, err := mw.profile( + ctx, + testServerAddr, + testClientAddr, + tc.id, + ) + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + assert.Equal(t, tc.wantProf, gotProf) + assert.Equal(t, tc.wantDev, gotDev) + assert.Equal(t, tc.wantByWhat, gotByWhat) + }) + } } func TestInitMw_ServeDNS_ddr(t *testing.T) { @@ -32,15 +170,14 @@ func TestInitMw_ServeDNS_ddr(t *testing.T) { resolverName = "dns.example.com" resolverFQDN = resolverName + "." - deviceID = "dev1234" - targetWithID = deviceID + ".d." + resolverName + "." + targetWithID = testDeviceID + ".d." + resolverName + "." ddrFQDN = ddrDomain + "." dohPath = "/dns-query" ) - testDevice := &agd.Device{ID: deviceID} + testDevice := &agd.Device{ID: testDeviceID} srvs := map[agd.ServerName]*agd.Server{ "dot": { @@ -98,15 +235,21 @@ func TestInitMw_ServeDNS_ddr(t *testing.T) { _ context.Context, _ agd.DeviceID, ) (p *agd.Profile, d *agd.Device, err error) { - p = &agd.Profile{Devices: []*agd.Device{dev}} + p = &agd.Profile{} return p, dev, nil }, - OnProfileByIP: func( + OnProfileByDedicatedIP: func( _ context.Context, _ netip.Addr, ) (p *agd.Profile, d *agd.Device, err error) { - p = &agd.Profile{Devices: []*agd.Device{dev}} + return nil, nil, profiledb.ErrDeviceNotFound + }, + OnProfileByLinkedIP: func( + _ context.Context, + _ netip.Addr, + ) (p *agd.Profile, d *agd.Device, err error) { + p = &agd.Profile{} return p, dev, nil }, @@ -397,12 +540,12 @@ func TestInitMw_ServeDNS_specialDomain(t *testing.T) { return rw.WriteMsg(ctx, req, resp) }) - onProfileByIP := func( + onProfileByLinkedIP := func( _ context.Context, _ netip.Addr, ) (p *agd.Profile, d *agd.Device, err error) { if !tc.hasProf { - return nil, nil, agd.DeviceNotFoundError{} + return nil, nil, profiledb.ErrDeviceNotFound } prof := &agd.Profile{ @@ -419,7 +562,13 @@ func TestInitMw_ServeDNS_specialDomain(t *testing.T) { ) (p *agd.Profile, d *agd.Device, err error) { panic("not implemented") }, - OnProfileByIP: onProfileByIP, + OnProfileByDedicatedIP: func( + _ context.Context, + _ netip.Addr, + ) (p *agd.Profile, d *agd.Device, err error) { + return nil, nil, profiledb.ErrDeviceNotFound + }, + OnProfileByLinkedIP: onProfileByLinkedIP, } geoIP := &agdtest.GeoIP{ @@ -543,7 +692,7 @@ func BenchmarkInitMw_Wrap(b *testing.B) { ctx := context.Background() ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{ - TLSServerName: "dev1234.dns.example.com", + TLSServerName: testDeviceID + ".dns.example.com", }) req := &dns.Msg{ @@ -573,6 +722,18 @@ func BenchmarkInitMw_Wrap(b *testing.B) { ) (p *agd.Profile, d *agd.Device, err error) { return prof, dev, nil }, + OnProfileByDedicatedIP: func( + _ context.Context, + _ netip.Addr, + ) (p *agd.Profile, d *agd.Device, err error) { + panic("not implemented") + }, + OnProfileByLinkedIP: func( + _ context.Context, + _ netip.Addr, + ) (p *agd.Profile, d *agd.Device, err error) { + panic("not implemented") + }, } b.Run("success", func(b *testing.B) { b.ReportAllocs() @@ -589,7 +750,19 @@ func BenchmarkInitMw_Wrap(b *testing.B) { _ context.Context, _ agd.DeviceID, ) (p *agd.Profile, d *agd.Device, err error) { - return nil, nil, agd.ProfileNotFoundError{} + return nil, nil, profiledb.ErrDeviceNotFound + }, + OnProfileByDedicatedIP: func( + _ context.Context, + _ netip.Addr, + ) (p *agd.Profile, d *agd.Device, err error) { + panic("not implemented") + }, + OnProfileByLinkedIP: func( + _ context.Context, + _ netip.Addr, + ) (p *agd.Profile, d *agd.Device, err error) { + panic("not implemented") }, } b.Run("not_found", func(b *testing.B) { @@ -616,6 +789,18 @@ func BenchmarkInitMw_Wrap(b *testing.B) { ) (p *agd.Profile, d *agd.Device, err error) { return prof, dev, nil }, + OnProfileByDedicatedIP: func( + _ context.Context, + _ netip.Addr, + ) (p *agd.Profile, d *agd.Device, err error) { + panic("not implemented") + }, + OnProfileByLinkedIP: func( + _ context.Context, + _ netip.Addr, + ) (p *agd.Profile, d *agd.Device, err error) { + panic("not implemented") + }, } b.Run("firefox_canary", func(b *testing.B) { b.ReportAllocs() @@ -630,13 +815,13 @@ func BenchmarkInitMw_Wrap(b *testing.B) { ddrReq := &dns.Msg{ Question: []dns.Question{{ // Check the worst case when wildcards are checked. - Name: "_dns.dev1234.dns.example.com.", + Name: "_dns." + testDeviceID + ".dns.example.com.", Qtype: dns.TypeSVCB, Qclass: dns.ClassINET, }}, } devWithID := &agd.Device{ - ID: "dev1234", + ID: testDeviceID, } mw.db = &agdtest.ProfileDB{ OnProfileByDeviceID: func( @@ -645,6 +830,18 @@ func BenchmarkInitMw_Wrap(b *testing.B) { ) (p *agd.Profile, d *agd.Device, err error) { return prof, devWithID, nil }, + OnProfileByDedicatedIP: func( + _ context.Context, + _ netip.Addr, + ) (p *agd.Profile, d *agd.Device, err error) { + panic("not implemented") + }, + OnProfileByLinkedIP: func( + _ context.Context, + _ netip.Addr, + ) (p *agd.Profile, d *agd.Device, err error) { + panic("not implemented") + }, } b.Run("ddr", func(b *testing.B) { b.ReportAllocs() diff --git a/internal/dnssvc/middleware.go b/internal/dnssvc/middleware.go index 5d0719b..f1a3294 100644 --- a/internal/dnssvc/middleware.go +++ b/internal/dnssvc/middleware.go @@ -3,6 +3,7 @@ package dnssvc import ( "context" "strconv" + "strings" "time" "github.com/AdguardTeam/AdGuardDNS/internal/agd" @@ -11,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/AdGuardDNS/internal/optlog" "github.com/miekg/dns" + "golang.org/x/exp/slices" ) // Middlewares @@ -52,30 +54,58 @@ func (mh *svcHandler) ServeDNS( optlog.Debug2("processing request %q from %s", reqID, raddr) defer optlog.Debug2("finished processing request %q from %s", reqID, raddr) - // Assume that the cache is always hot and that we can always send the - // request and filter it out along with the response later if we need - // it. + reqInfo := agd.MustRequestInfoFromContext(ctx) + flt := mh.svc.fltStrg.FilterFromContext(ctx, reqInfo) + + modReq, reqRes, elapsedReq := mh.svc.filterRequest(ctx, req, flt, reqInfo) + nwrw := makeNonWriter(rw) - err = mh.next.ServeDNS(ctx, nwrw, req) + if modReq != nil { + // Modified request is set only if the request was modified by a CNAME + // rewrite rule, so resolve the request as if it was for the rewritten + // name. + + // Clone the request informaton and replace the host name with the + // rewritten one, since the request information from current context + // must only be accessed for reading, see [agd.RequestInfo]. Shallow + // copy is enough, because we only change the [agd.RequestInfo.Host] + // field, which is a string. + modReqInfo := &agd.RequestInfo{} + *modReqInfo = *reqInfo + modReqInfo.Host = strings.ToLower(strings.TrimSuffix(modReq.Question[0].Name, ".")) + + modReqCtx := agd.ContextWithRequestInfo(ctx, modReqInfo) + + optlog.Debug2( + "dnssvc: request for %q rewritten to %q by CNAME rewrite rule", + reqInfo.Host, + modReqInfo.Host, + ) + + err = mh.next.ServeDNS(modReqCtx, nwrw, modReq) + } else { + err = mh.next.ServeDNS(ctx, nwrw, req) + } if err != nil { return err } - ri := agd.MustRequestInfoFromContext(ctx) origResp := nwrw.Msg() - reqRes, respRes := mh.svc.filterQuery(ctx, req, origResp, ri) + respRes, elapsedResp := mh.svc.filterResponse(ctx, req, origResp, flt, reqInfo, modReq) + + mh.svc.reportMetrics(reqInfo, reqRes, respRes, elapsedReq+elapsedResp) if isDebug { return mh.svc.writeDebugResponse(ctx, rw, req, origResp, reqRes, respRes) } - resp, err := writeFilteredResp(ctx, ri, rw, req, origResp, reqRes, respRes) + resp, err := writeFilteredResp(ctx, reqInfo, rw, req, origResp, reqRes, respRes) if err != nil { // Don't wrap the error, because it's informative enough as is. return err } - mh.svc.recordQueryInfo(ctx, req, resp, origResp, ri, reqRes, respRes) + mh.svc.recordQueryInfo(ctx, req, resp, origResp, reqInfo, reqRes, respRes) return nil } @@ -91,31 +121,73 @@ func makeNonWriter(rw dnsserver.ResponseWriter) (nwrw *dnsserver.NonWriterRespon return dnsserver.NewNonWriterResponseWriter(rw.LocalAddr(), rw.RemoteAddr()) } -// filterQuery is a wrapper for f.FilterRequest and f.FilterResponse that treats -// filtering errors non-critical. It also records filtering metrics. -func (svc *Service) filterQuery( +// rewrittenRequest returns a request from res in case it's a CNAME rewrite, and +// returns nil otherwise. Note that the returned message is always a request +// since any other rewrite rule type turns into response. +func rewrittenRequest(res filter.Result) (req *dns.Msg) { + if res, ok := res.(*filter.ResultModified); ok && !res.Msg.Response { + return res.Msg + } + + return nil +} + +// filterRequest applies f to req and returns the result of filtering. If the +// result is the CNAME rewrite, it also returns the modified request to resolve. +// It also returns the time elapsed on filtering. +func (svc *Service) filterRequest( ctx context.Context, req *dns.Msg, - origResp *dns.Msg, + f filter.Interface, ri *agd.RequestInfo, -) (reqRes, respRes filter.Result) { +) (modReq *dns.Msg, reqRes filter.Result, elapsed time.Duration) { start := time.Now() - defer func() { - svc.reportMetrics(ri, reqRes, respRes, time.Since(start)) - }() - - f := svc.fltStrg.FilterFromContext(ctx, ri) reqRes, err := f.FilterRequest(ctx, req, ri) if err != nil { svc.reportf(ctx, "filtering request: %w", err) } - respRes, err = f.FilterResponse(ctx, origResp, ri) - if err != nil { - svc.reportf(ctx, "dnssvc: filtering original response: %w", err) + // Consider this operation related to filtering and account the elapsed + // time. + modReq = rewrittenRequest(reqRes) + + return modReq, reqRes, time.Since(start) +} + +// filterResponse applies f to resp and returns the result of filtering. If +// origReq has a different question name than resp, the request assumed being +// CNAME-rewritten and no filtering performed on resp, the CNAME is prepended to +// resp answer section instead. It also returns the time elapsed on filtering. +func (svc *Service) filterResponse( + ctx context.Context, + req *dns.Msg, + resp *dns.Msg, + f filter.Interface, + ri *agd.RequestInfo, + modReq *dns.Msg, +) (respRes filter.Result, elapsed time.Duration) { + start := time.Now() + + if modReq != nil { + // Return the request name to its original state, since it was + // previously rewritten by CNAME rewrite rule. + resp.Question[0] = req.Question[0] + + // Prepend the CNAME answer to the response and don't filter it. + var rr dns.RR = ri.Messages.NewAnswerCNAME(req, modReq.Question[0].Name) + resp.Answer = slices.Insert(resp.Answer, 0, rr) + + // Also consider this operation related to filtering and account the + // elapsed time. + return nil, time.Since(start) } - return reqRes, respRes + respRes, err := f.FilterResponse(ctx, resp, ri) + if err != nil { + svc.reportf(ctx, "filtering response: %w", err) + } + + return respRes, time.Since(start) } // reportMetrics extracts filtering metrics data from the context and reports it @@ -139,7 +211,7 @@ func (svc *Service) reportMetrics( metrics.DNSSvcRequestByCountryTotal.WithLabelValues(cont, ctry).Inc() metrics.DNSSvcRequestByASNTotal.WithLabelValues(ctry, asn).Inc() - id, _, blocked := filteringData(reqRes, respRes) + id, _, isBlocked := filteringData(reqRes, respRes) metrics.DNSSvcRequestByFilterTotal.WithLabelValues( string(id), metrics.BoolString(ri.Profile == nil), @@ -149,19 +221,7 @@ func (svc *Service) reportMetrics( metrics.DNSSvcUsersCountUpdate(ri.RemoteIP) if svc.researchMetrics { - anonymous := ri.Profile == nil - filteringEnabled := ri.FilteringGroup != nil && - ri.FilteringGroup.RuleListsEnabled && - len(ri.FilteringGroup.RuleListIDs) > 0 - - metrics.ReportResearchMetrics( - anonymous, - filteringEnabled, - asn, - ctry, - string(id), - blocked, - ) + metrics.ReportResearchMetrics(ri, id, isBlocked) } } diff --git a/internal/dnssvc/middleware_test.go b/internal/dnssvc/middleware_test.go index 7615556..b7ea00e 100644 --- a/internal/dnssvc/middleware_test.go +++ b/internal/dnssvc/middleware_test.go @@ -5,6 +5,7 @@ import ( "net" "net/http" "net/netip" + "strings" "testing" "time" @@ -12,6 +13,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" "github.com/AdguardTeam/AdGuardDNS/internal/dnssvc" "github.com/AdguardTeam/AdGuardDNS/internal/filter" "github.com/AdguardTeam/AdGuardDNS/internal/querylog" @@ -22,46 +24,95 @@ import ( "github.com/stretchr/testify/require" ) -func TestService_Wrap_withClient(t *testing.T) { - // Part 1. Server Configuration - // - // Configure a server with fakes to make sure that the wrapped handler - // has all necessary entities and data in place. - // - // TODO(a.garipov): Put this thing into some kind of helper so that we - // could create several such tests. +const ( + // testProfID is the [agd.ProfileID] for tests. + testProfID agd.ProfileID = "prof1234" - const ( - id agd.ProfileID = "prof1234" - devID agd.DeviceID = "dev1234" - fltListID agd.FilterListID = "flt1234" - ) + // testDevID is the [agd.DeviceID] for tests. + testDevID agd.DeviceID = "dev1234" + + // testFltListID is the [agd.FilterListID] for tests. + testFltListID agd.FilterListID = "flt1234" + + // testSrvName is the [agd.ServerName] for tests. + testSrvName agd.ServerName = "test_server_dns_tls" + + // testSrvGrpName is the [agd.ServerGroupName] for tests. + testSrvGrpName agd.ServerGroupName = "test_group" + + // testDevIDWildcard is the wildcard domain for retrieving [agd.DeviceID] in + // tests. Use [strings.ReplaceAll] to replace the "*" symbol with the + // actual [agd.DeviceID]. + testDevIDWildcard string = "*.dns.example.com" +) + +// testTimeout is the common timeout for tests. +const testTimeout time.Duration = 1 * time.Second + +// newTestService creates a new [dnssvc.Service] for tests. The service built +// of stubs, that use the following data: +// +// - A filtering group containing a filter with [testFltListID] and enabled +// rule lists. +// - A device with [testDevID] and enabled filtering. +// - A profile with [testProfID] with enabled filtering and query +// logging, containing the device. +// - GeoIP database always returning [agd.CountryAD], [agd.ContinentEU], and +// ASN of 42. +// - A server with [testSrvName] under group with [testSrvGrpName], matching +// the DeviceID with [testDevIDWildcard]. +// +// Each stub also uses the corresponding channels to send the data it receives +// from the service. If the channel is [nil], the stub ignores it. Each +// sending to a channel wrapped with [testutil.RequireSend] using [testTimeout]. +// +// It also uses the [dnsservertest.DefaultHandler] to create the DNS handler. +func newTestService( + t testing.TB, + flt filter.Interface, + errCollCh chan<- error, + profileDBCh chan<- agd.DeviceID, + querylogCh chan<- *querylog.Entry, + geoIPCh chan<- string, + dnsDBCh chan<- *agd.RequestInfo, + ruleStatCh chan<- agd.FilterRuleText, +) (svc *dnssvc.Service, srvAddr netip.AddrPort) { + t.Helper() + + pt := testutil.PanicT{} dev := &agd.Device{ - ID: devID, + ID: testDevID, FilteringEnabled: true, } prof := &agd.Profile{ - ID: id, - Devices: []*agd.Device{dev}, - RuleListIDs: []agd.FilterListID{fltListID}, - FilteredResponseTTL: 10 * time.Second, + ID: testProfID, + DeviceIDs: []agd.DeviceID{testDevID}, + RuleListIDs: []agd.FilterListID{testFltListID}, + FilteredResponseTTL: agdtest.FilteredResponseTTL, FilteringEnabled: true, QueryLogEnabled: true, } - dbDeviceIDs := make(chan agd.DeviceID, 1) db := &agdtest.ProfileDB{ OnProfileByDeviceID: func( _ context.Context, id agd.DeviceID, ) (p *agd.Profile, d *agd.Device, err error) { - dbDeviceIDs <- id + if profileDBCh != nil { + testutil.RequireSend(pt, profileDBCh, id, testTimeout) + } return prof, dev, nil }, - OnProfileByIP: func( + OnProfileByDedicatedIP: func( + _ context.Context, + _ netip.Addr, + ) (p *agd.Profile, d *agd.Device, err error) { + panic("not implemented") + }, + OnProfileByLinkedIP: func( ctx context.Context, ip netip.Addr, ) (p *agd.Profile, d *agd.Device, err error) { @@ -71,24 +122,19 @@ func TestService_Wrap_withClient(t *testing.T) { // Make sure that any panics and errors within handlers are caught and // that they fail the test by panicking. - errCh := make(chan error, 1) - go func() { - pt := testutil.PanicT{} - - err, ok := <-errCh - if !ok { - return - } - - require.NoError(pt, err) - }() - errColl := &agdtest.ErrorCollector{ OnCollect: func(_ context.Context, err error) { - errCh <- err + if errCollCh != nil { + testutil.RequireSend(pt, errCollCh, err, testTimeout) + } }, } + loc := &agd.Location{ + Country: agd.CountryAD, + Continent: agd.ContinentEU, + ASN: 42, + } geoIP := &agdtest.GeoIP{ OnSubnetByLocation: func( _ agd.Country, @@ -97,34 +143,13 @@ func TestService_Wrap_withClient(t *testing.T) { ) (n netip.Prefix, err error) { panic("not implemented") }, - OnData: func(_ string, _ netip.Addr) (l *agd.Location, err error) { - return &agd.Location{ - Country: agd.CountryAD, - Continent: agd.ContinentEU, - ASN: 42, - }, nil - }, - } + OnData: func(host string, _ netip.Addr) (l *agd.Location, err error) { + if geoIPCh != nil { + testutil.RequireSend(pt, geoIPCh, host, testTimeout) + } - fltDomainCh := make(chan string, 1) - flt := &agdtest.Filter{ - OnFilterRequest: func( - _ context.Context, - req *dns.Msg, - _ *agd.RequestInfo, - ) (r filter.Result, err error) { - fltDomainCh <- req.Question[0].Name - - return nil, nil + return loc, nil }, - OnFilterResponse: func( - _ context.Context, - _ *dns.Msg, - _ *agd.RequestInfo, - ) (r filter.Result, err error) { - return nil, nil - }, - OnClose: func() (err error) { panic("not implemented") }, } fltStrg := &agdtest.FilterStorage{ @@ -134,23 +159,21 @@ func TestService_Wrap_withClient(t *testing.T) { OnHasListID: func(_ agd.FilterListID) (ok bool) { panic("not implemented") }, } - logDomainCh := make(chan string, 1) - logQTypeCh := make(chan dnsmsg.RRType, 1) var ql querylog.Interface = &agdtest.QueryLog{ OnWrite: func(_ context.Context, e *querylog.Entry) (err error) { - logDomainCh <- e.DomainFQDN - logQTypeCh <- e.RequestType + if querylogCh != nil { + testutil.RequireSend(pt, querylogCh, e, testTimeout) + } return nil }, } - srvAddr := netip.MustParseAddrPort("94.149.14.14:853") - srvName := agd.ServerName("test_server_dns_tls") + srvAddr = netip.MustParseAddrPort("94.149.14.14:853") srvs := []*agd.Server{{ DNSCrypt: nil, TLS: nil, - Name: srvName, + Name: testSrvName, BindData: []*agd.ServerBindData{{ AddrPort: srvAddr, }}, @@ -161,20 +184,6 @@ func TestService_Wrap_withClient(t *testing.T) { tl.onStart = func(_ context.Context) (err error) { return nil } tl.onShutdown = func(_ context.Context) (err error) { return nil } - var h dnsserver.Handler = dnsserver.HandlerFunc(func( - ctx context.Context, - rw dnsserver.ResponseWriter, - r *dns.Msg, - ) (err error) { - resp := &dns.Msg{} - resp.SetReply(r) - resp.Answer = append(resp.Answer, &dns.A{ - A: net.IP{1, 2, 3, 4}, - }) - - return rw.WriteMsg(ctx, r, resp) - }) - dnsCk := &agdtest.DNSCheck{ OnCheck: func( _ context.Context, @@ -185,17 +194,19 @@ func TestService_Wrap_withClient(t *testing.T) { }, } - numDNSDBReq := 0 dnsDB := &agdtest.DNSDB{ - OnRecord: func(_ context.Context, _ *dns.Msg, _ *agd.RequestInfo) { - numDNSDBReq++ + OnRecord: func(_ context.Context, _ *dns.Msg, ri *agd.RequestInfo) { + if dnsDBCh != nil { + testutil.RequireSend(pt, dnsDBCh, ri, testTimeout) + } }, } - numRuleStatReq := 0 ruleStat := &agdtest.RuleStat{ - OnCollect: func(_ context.Context, _ agd.FilterListID, _ agd.FilterRuleText) { - numRuleStatReq++ + OnCollect: func(_ context.Context, _ agd.FilterListID, text agd.FilterRuleText) { + if ruleStatCh != nil { + testutil.RequireSend(pt, ruleStatCh, text, testTimeout) + } }, } @@ -207,11 +218,13 @@ func TestService_Wrap_withClient(t *testing.T) { ) (drop, allowlisted bool, err error) { return true, false, nil }, - OnCountResponses: func(_ context.Context, _ *dns.Msg, _ netip.Addr) {}, + OnCountResponses: func(_ context.Context, _ *dns.Msg, _ netip.Addr) { + panic("not implemented") + }, } - fltGrpID := agd.FilteringGroupID("1234") - srvGrpName := agd.ServerGroupName("test_group") + testFltGrpID := agd.FilteringGroupID("1234") + c := &dnssvc.Config{ Messages: agdtest.NewConstructor(), BillStat: &agdtest.BillStatRecorder{ @@ -234,31 +247,25 @@ func TestService_Wrap_withClient(t *testing.T) { GeoIP: geoIP, QueryLog: ql, RuleStat: ruleStat, - Upstream: &agd.Upstream{ - Server: netip.MustParseAddrPort("8.8.8.8:53"), - FallbackServers: []netip.AddrPort{ - netip.MustParseAddrPort("1.1.1.1:53"), - }, - }, - NewListener: newTestListenerFunc(tl), - Handler: h, - RateLimit: rl, + NewListener: newTestListenerFunc(tl), + Handler: dnsservertest.DefaultHandler(), + RateLimit: rl, FilteringGroups: map[agd.FilteringGroupID]*agd.FilteringGroup{ - fltGrpID: { - ID: fltGrpID, - RuleListIDs: []agd.FilterListID{fltListID}, + testFltGrpID: { + ID: testFltGrpID, + RuleListIDs: []agd.FilterListID{testFltListID}, RuleListsEnabled: true, }, }, ServerGroups: []*agd.ServerGroup{{ TLS: &agd.TLS{ - DeviceIDWildcards: []string{"*.dns.example.com"}, + DeviceIDWildcards: []string{testDevIDWildcard}, }, DDR: &agd.DDR{ Enabled: true, }, - Name: srvGrpName, - FilteringGroup: fltGrpID, + Name: testSrvGrpName, + FilteringGroup: testFltGrpID, Servers: srvs, }}, } @@ -266,56 +273,190 @@ func TestService_Wrap_withClient(t *testing.T) { svc, err := dnssvc.New(c) require.NoError(t, err) require.NotNil(t, svc) + + err = svc.Start() + require.NoError(t, err) testutil.CleanupAndRequireSuccess(t, func() (err error) { return svc.Shutdown(context.Background()) }) - err = svc.Start() - require.NoError(t, err) + return svc, srvAddr +} - // Part 2. Testing Proper - // - // Create a context, a request, and a simple handler. Wrap the handler - // and make sure that all processing went as needed. +func TestService_Wrap(t *testing.T) { + profileDBCh := make(chan agd.DeviceID, 1) + querylogCh := make(chan *querylog.Entry, 1) + geoIPCh := make(chan string, 2) + dnsDBCh := make(chan *agd.RequestInfo, 1) + ruleStatCh := make(chan agd.FilterRuleText, 1) + + errCollCh := make(chan error, 1) + go func() { + for err := range errCollCh { + require.NoError(t, err) + } + }() + + const domain = "example.org" + + domainFQDN := dns.Fqdn(domain) - domain := "example.org." reqType := dns.TypeA - req := &dns.Msg{ - MsgHdr: dns.MsgHdr{ - Id: dns.Id(), - RecursionDesired: true, - }, - Question: []dns.Question{{ - Name: domain, - Qtype: reqType, - Qclass: dns.ClassINET, - }}, - } + req := dnsservertest.CreateMessage(domain, reqType) + + clientAddr := &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 12345} ctx := context.Background() ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{ - TLSServerName: string(devID) + ".dns.example.com", + TLSServerName: strings.ReplaceAll(testDevIDWildcard, "*", string(testDevID)), }) ctx = dnsserver.ContextWithServerInfo(ctx, dnsserver.ServerInfo{ Proto: agd.ProtoDoT, }) - ctx = dnsserver.ContextWithStartTime(ctx, time.Now()) - clientAddr := &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 12345} - rw := &testResponseWriter{ - onLocalAddr: func() (a net.Addr) { return net.TCPAddrFromAddrPort(srvAddr) }, - onRemoteAddr: func() (a net.Addr) { return clientAddr }, - onWriteMsg: func(_ context.Context, _, _ *dns.Msg) (err error) { - return nil - }, - } - err = svc.Handle(ctx, srvGrpName, srvName, rw, req) - require.NoError(t, err) + t.Run("simple_success", func(t *testing.T) { + noMatch := func( + _ context.Context, + m *dns.Msg, + _ *agd.RequestInfo, + ) (r filter.Result, err error) { + pt := testutil.PanicT{} + require.NotEmpty(pt, m.Question) + require.Equal(pt, domainFQDN, m.Question[0].Name) - assert.Equal(t, devID, <-dbDeviceIDs) - assert.Equal(t, domain, <-fltDomainCh) - assert.Equal(t, domain, <-logDomainCh) - assert.Equal(t, reqType, <-logQTypeCh) - assert.Equal(t, 1, numDNSDBReq) - assert.Equal(t, 1, numRuleStatReq) + return nil, nil + } + + flt := &agdtest.Filter{ + OnFilterRequest: noMatch, + OnFilterResponse: noMatch, + } + + svc, srvAddr := newTestService( + t, + flt, + errCollCh, + profileDBCh, + querylogCh, + geoIPCh, + dnsDBCh, + ruleStatCh, + ) + + rw := dnsserver.NewNonWriterResponseWriter( + net.TCPAddrFromAddrPort(srvAddr), + clientAddr, + ) + + ctx = dnsserver.ContextWithStartTime(ctx, time.Now()) + + err := svc.Handle(ctx, testSrvGrpName, testSrvName, rw, req) + require.NoError(t, err) + + resp := rw.Msg() + dnsservertest.RequireResponse(t, req, resp, 1, dns.RcodeSuccess, false) + + assert.Equal(t, testDevID, <-profileDBCh) + + logEntry := <-querylogCh + assert.Equal(t, domainFQDN, logEntry.DomainFQDN) + assert.Equal(t, reqType, logEntry.RequestType) + + assert.Equal(t, "", <-geoIPCh) + assert.Equal(t, domain, <-geoIPCh) + + dnsDBReqInfo := <-dnsDBCh + assert.NotNil(t, dnsDBReqInfo) + assert.Equal(t, agd.FilterRuleText(""), <-ruleStatCh) + }) + + t.Run("request_cname", func(t *testing.T) { + const ( + cname = "cname.example.org" + cnameRule agd.FilterRuleText = "||" + domain + "^$dnsrewrite=" + cname + ) + + cnameFQDN := dns.Fqdn(cname) + + flt := &agdtest.Filter{ + OnFilterRequest: func( + _ context.Context, + m *dns.Msg, + _ *agd.RequestInfo, + ) (r filter.Result, err error) { + // Pretend a CNAME rewrite matched the request. + mod := dnsmsg.Clone(m) + mod.Question[0].Name = cnameFQDN + + return &filter.ResultModified{ + Msg: mod, + List: testFltListID, + Rule: cnameRule, + }, nil + }, + OnFilterResponse: func( + _ context.Context, + _ *dns.Msg, + _ *agd.RequestInfo, + ) (filter.Result, error) { + panic("not implemented") + }, + } + + svc, srvAddr := newTestService( + t, + flt, + errCollCh, + profileDBCh, + querylogCh, + geoIPCh, + dnsDBCh, + ruleStatCh, + ) + + rw := dnsserver.NewNonWriterResponseWriter( + net.TCPAddrFromAddrPort(srvAddr), + clientAddr, + ) + + ctx = dnsserver.ContextWithStartTime(ctx, time.Now()) + + err := svc.Handle(ctx, testSrvGrpName, testSrvName, rw, req) + require.NoError(t, err) + + resp := rw.Msg() + require.NotNil(t, resp) + require.Len(t, resp.Answer, 2) + + assert.Equal(t, []dns.RR{&dns.CNAME{ + Hdr: dns.RR_Header{ + Name: domainFQDN, + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + Ttl: uint32(agdtest.FilteredResponseTTL.Seconds()), + }, + Target: cnameFQDN, + }, &dns.A{ + Hdr: dns.RR_Header{ + Name: cnameFQDN, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: uint32(dnsservertest.AnswerTTL.Seconds()), + }, + A: netutil.IPv4Localhost().AsSlice(), + }}, resp.Answer) + + assert.Equal(t, testDevID, <-profileDBCh) + + logEntry := <-querylogCh + assert.Equal(t, domainFQDN, logEntry.DomainFQDN) + assert.Equal(t, reqType, logEntry.RequestType) + + assert.Equal(t, "", <-geoIPCh) + assert.Equal(t, cname, <-geoIPCh) + + dnsDBReqInfo := <-dnsDBCh + assert.Equal(t, cname, dnsDBReqInfo.Host) + assert.Equal(t, cnameRule, <-ruleStatCh) + }) } diff --git a/internal/dnssvc/presvcmw.go b/internal/dnssvc/presvcmw.go index 2229632..16e0c48 100644 --- a/internal/dnssvc/presvcmw.go +++ b/internal/dnssvc/presvcmw.go @@ -25,8 +25,8 @@ type preServiceMw struct { // messages is used to construct TXT responses. messages *dnsmsg.Constructor - // filter is the safe browsing DNS filter. - filter *filter.SafeBrowsingServer + // hashMatcher is the safe browsing DNS hashMatcher. + hashMatcher filter.HashMatcher // checker is used to detect and process DNS-check requests. checker dnscheck.Interface @@ -91,7 +91,7 @@ func (mh *preServiceMwHandler) respondWithHashes( ) (err error) { optlog.Debug1("presvc mw: safe browsing: got txt req for %q", ri.Host) - hashes, matched, err := mh.mw.filter.Hashes(ctx, ri.Host) + hashes, matched, err := mh.mw.hashMatcher.MatchByPrefix(ctx, ri.Host) if err != nil { // Don't return or collect this error to prevent DDoS of the error // collector by sending bad requests. diff --git a/internal/dnssvc/presvcmw_test.go b/internal/dnssvc/presvcmw_test.go index 7f2071f..f6db9da 100644 --- a/internal/dnssvc/presvcmw_test.go +++ b/internal/dnssvc/presvcmw_test.go @@ -14,7 +14,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" "github.com/AdguardTeam/AdGuardDNS/internal/filter" - "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -30,11 +30,7 @@ func TestPreServiceMwHandler_ServeDNS(t *testing.T) { sum := sha256.Sum256([]byte(safeBrowsingHost)) hashStr := hex.EncodeToString(sum[:]) - hashes, herr := hashstorage.New(safeBrowsingHost) - require.NoError(t, herr) - - srv := filter.NewSafeBrowsingServer(hashes, nil) - host := hashStr[:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix + host := hashStr[:hashprefix.PrefixEncLen] + filter.GeneralTXTSuffix ctx := context.Background() ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{}) @@ -48,12 +44,14 @@ func TestPreServiceMwHandler_ServeDNS(t *testing.T) { req *dns.Msg dnscheckResp *dns.Msg ri *agd.RequestInfo + hashes []string wantAns []dns.RR }{{ name: "normal", req: dnsservertest.CreateMessage(name, dns.TypeA), dnscheckResp: nil, ri: &agd.RequestInfo{}, + hashes: nil, wantAns: []dns.RR{ dnsservertest.NewA(name, 100, ip), }, @@ -63,16 +61,14 @@ func TestPreServiceMwHandler_ServeDNS(t *testing.T) { dnscheckResp: dnsservertest.NewResp( dns.RcodeSuccess, dnsservertest.NewReq(name, dns.TypeA, dns.ClassINET), - dnsservertest.RRSection{ - RRs: []dns.RR{dnsservertest.NewA(name, ttl, ip)}, - Sec: dnsservertest.SectionAnswer, - }, + dnsservertest.SectionAnswer{dnsservertest.NewA(name, ttl, ip)}, ), ri: &agd.RequestInfo{ Host: name, QType: dns.TypeA, QClass: dns.ClassINET, }, + hashes: nil, wantAns: []dns.RR{ dnsservertest.NewA(name, ttl, ip), }, @@ -81,6 +77,7 @@ func TestPreServiceMwHandler_ServeDNS(t *testing.T) { req: dnsservertest.CreateMessage(safeBrowsingHost, dns.TypeTXT), dnscheckResp: nil, ri: &agd.RequestInfo{Host: host, QType: dns.TypeTXT}, + hashes: []string{hashStr}, wantAns: []dns.RR{&dns.TXT{ Hdr: dns.RR_Header{ Name: safeBrowsingHost, @@ -95,6 +92,7 @@ func TestPreServiceMwHandler_ServeDNS(t *testing.T) { req: dnsservertest.CreateMessage(name, dns.TypeTXT), dnscheckResp: nil, ri: &agd.RequestInfo{Host: name, QType: dns.TypeTXT}, + hashes: nil, wantAns: []dns.RR{dnsservertest.NewA(name, 100, ip)}, }} @@ -113,10 +111,19 @@ func TestPreServiceMwHandler_ServeDNS(t *testing.T) { }, } + hashMatcher := &agdtest.HashMatcher{ + OnMatchByPrefix: func( + ctx context.Context, + host string, + ) (hashes []string, matched bool, err error) { + return tc.hashes, len(tc.hashes) > 0, nil + }, + } + mw := &preServiceMw{ - messages: dnsmsg.NewConstructor(&dnsmsg.BlockingModeNullIP{}, ttl*time.Second), - filter: srv, - checker: dnsCk, + messages: dnsmsg.NewConstructor(&dnsmsg.BlockingModeNullIP{}, ttl*time.Second), + hashMatcher: hashMatcher, + checker: dnsCk, } handler := dnsservertest.DefaultHandler() h := mw.Wrap(handler) diff --git a/internal/dnssvc/preupstreammw_test.go b/internal/dnssvc/preupstreammw_test.go index 58a0c91..09a5d6a 100644 --- a/internal/dnssvc/preupstreammw_test.go +++ b/internal/dnssvc/preupstreammw_test.go @@ -28,9 +28,8 @@ func TestPreUpstreamMwHandler_ServeDNS_withCache(t *testing.T) { aReq := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET) respIP := remoteIP.AsSlice() - resp := dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.RRSection{ - RRs: []dns.RR{dnsservertest.NewA(reqHostname, defaultTTL, respIP)}, - Sec: dnsservertest.SectionAnswer, + resp := dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{ + dnsservertest.NewA(reqHostname, defaultTTL, respIP), }) ctx := agd.ContextWithRequestInfo(context.Background(), &agd.RequestInfo{ Host: aReq.Question[0].Name, @@ -91,18 +90,9 @@ func TestPreUpstreamMwHandler_ServeDNS_withECSCache(t *testing.T) { const ctry = agd.CountryAD - resp := dnsservertest.NewResp( - dns.RcodeSuccess, - aReq, - dnsservertest.RRSection{ - RRs: []dns.RR{dnsservertest.NewA( - reqHostname, - defaultTTL, - net.IP{1, 2, 3, 4}, - )}, - Sec: dnsservertest.SectionAnswer, - }, - ) + resp := dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{ + dnsservertest.NewA(reqHostname, defaultTTL, net.IP{1, 2, 3, 4}), + }) numReq := 0 handler := dnsserver.HandlerFunc( @@ -176,8 +166,16 @@ func TestPreUpstreamMwHandler_ServeDNS_androidMetric(t *testing.T) { ctx = dnsserver.ContextWithStartTime(ctx, time.Now()) ctx = agd.ContextWithRequestInfo(ctx, &agd.RequestInfo{}) + ipA := net.IP{1, 2, 3, 4} + ipB := net.IP{1, 2, 3, 5} + const ttl = 100 + const ( + httpsDomain = "-dnsohttps-ds.metric.gstatic.com." + tlsDomain = "-dnsotls-ds.metric.gstatic.com." + ) + testCases := []struct { name string req *dns.Msg @@ -191,49 +189,28 @@ func TestPreUpstreamMwHandler_ServeDNS_androidMetric(t *testing.T) { wantName: "example.com.", wantAns: nil, }, { - name: "android-tls-metric", - req: dnsservertest.CreateMessage( - "12345678-dnsotls-ds.metric.gstatic.com.", - dns.TypeA, - ), + name: "android-tls-metric", + req: dnsservertest.CreateMessage("12345678"+tlsDomain, dns.TypeA), resp: resp, - wantName: "00000000-dnsotls-ds.metric.gstatic.com.", + wantName: "00000000" + tlsDomain, wantAns: nil, }, { - name: "android-https-metric", - req: dnsservertest.CreateMessage( - "123456-dnsohttps-ds.metric.gstatic.com.", - dns.TypeA, - ), + name: "android-https-metric", + req: dnsservertest.CreateMessage("123456"+httpsDomain, dns.TypeA), resp: resp, - wantName: "000000-dnsohttps-ds.metric.gstatic.com.", + wantName: "000000" + httpsDomain, wantAns: nil, }, { name: "multiple_answers_metric", - req: dnsservertest.CreateMessage( - "123456-dnsohttps-ds.metric.gstatic.com.", - dns.TypeA, - ), - resp: dnsservertest.NewResp( - dns.RcodeSuccess, - req, - dnsservertest.RRSection{ - RRs: []dns.RR{dnsservertest.NewA( - "123456-dnsohttps-ds.metric.gstatic.com.", - ttl, - net.IP{1, 2, 3, 4}, - ), dnsservertest.NewA( - "654321-dnsohttps-ds.metric.gstatic.com.", - ttl, - net.IP{1, 2, 3, 5}, - )}, - Sec: dnsservertest.SectionAnswer, - }, - ), - wantName: "000000-dnsohttps-ds.metric.gstatic.com.", + req: dnsservertest.CreateMessage("123456"+httpsDomain, dns.TypeA), + resp: dnsservertest.NewResp(dns.RcodeSuccess, req, dnsservertest.SectionAnswer{ + dnsservertest.NewA("123456"+httpsDomain, ttl, ipA), + dnsservertest.NewA("654321"+httpsDomain, ttl, ipB), + }), + wantName: "000000" + httpsDomain, wantAns: []dns.RR{ - dnsservertest.NewA("123456-dnsohttps-ds.metric.gstatic.com.", ttl, net.IP{1, 2, 3, 4}), - dnsservertest.NewA("123456-dnsohttps-ds.metric.gstatic.com.", ttl, net.IP{1, 2, 3, 5}), + dnsservertest.NewA("123456"+httpsDomain, ttl, ipA), + dnsservertest.NewA("123456"+httpsDomain, ttl, ipB), }, }} @@ -248,6 +225,7 @@ func TestPreUpstreamMwHandler_ServeDNS_androidMetric(t *testing.T) { return rw.WriteMsg(ctx, req, tc.resp) }) + h := mw.Wrap(handler) rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr) diff --git a/internal/dnssvc/record.go b/internal/dnssvc/record.go index f7adcb3..2da0b7f 100644 --- a/internal/dnssvc/record.go +++ b/internal/dnssvc/record.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/netip" + "strings" "time" "github.com/AdguardTeam/AdGuardDNS/internal/agd" @@ -68,7 +69,14 @@ func (svc *Service) recordQueryInfo( var respCtry agd.Country if !respIP.IsUnspecified() { - respCtry = svc.country(ctx, ri.Host, respIP) + host := ri.Host + if modReq := rewrittenRequest(reqRes); modReq != nil { + // If the request was modified by CNAME rule, the actual result + // belongs to the hostname from that CNAME. + host = strings.TrimSuffix(modReq.Question[0].Name, ".") + } + + respCtry = svc.country(ctx, host, respIP) } q := req.Question[0] diff --git a/internal/dnssvc/resp.go b/internal/dnssvc/resp.go index 7ed198a..12bfa1c 100644 --- a/internal/dnssvc/resp.go +++ b/internal/dnssvc/resp.go @@ -31,16 +31,23 @@ func writeFilteredResp( case *filter.ResultAllowed: err = rw.WriteMsg(ctx, req, resp) if err != nil { - err = fmt.Errorf("writing allowed response: %w", err) + err = fmt.Errorf("writing response to allowed request: %w", err) } else { written = resp } case *filter.ResultModified: - err = rw.WriteMsg(ctx, req, reqRes.Msg) + if reqRes.Msg.Response { + // Only use the request filtering result in case it's already a + // response. Otherwise, it's a CNAME rewrite result, which isn't + // filtered after resolving. + resp = reqRes.Msg + } + + err = rw.WriteMsg(ctx, req, resp) if err != nil { - err = fmt.Errorf("writing modified response: %w", err) + err = fmt.Errorf("writing response to modified request: %w", err) } else { - written = reqRes.Msg + written = resp } default: // Consider unhandled sum type members as unrecoverable programmer diff --git a/internal/dnssvc/resp_internal_test.go b/internal/dnssvc/resp_internal_test.go index d8d65e4..a1bfe7a 100644 --- a/internal/dnssvc/resp_internal_test.go +++ b/internal/dnssvc/resp_internal_test.go @@ -4,10 +4,9 @@ import ( "context" "net" "testing" - "time" "github.com/AdguardTeam/AdGuardDNS/internal/agd" - "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" + "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" "github.com/AdguardTeam/AdGuardDNS/internal/filter" @@ -20,10 +19,10 @@ import ( func TestWriteFilteredResp(t *testing.T) { const ( - fltRespTTL = 42 - respTTL = 10 + respTTL = 60 ) + const fltRespTTL = agdtest.FilteredResponseTTLSec respIP := net.IP{1, 2, 3, 4} rewrIP := net.IP{5, 6, 7, 8} blockIP := netutil.IPv4Zero() @@ -85,14 +84,14 @@ func TestWriteFilteredResp(t *testing.T) { ctx := context.Background() ri := &agd.RequestInfo{ - Messages: dnsmsg.NewConstructor(&dnsmsg.BlockingModeNullIP{}, fltRespTTL*time.Second), + Messages: agdtest.NewConstructor(), } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { rw := dnsserver.NewNonWriterResponseWriter(nil, nil) resp := dnsservertest.NewResp(dns.RcodeSuccess, req) - resp.Answer = append(resp.Answer, dnsservertest.NewA(domain, 10, respIP)) + resp.Answer = append(resp.Answer, dnsservertest.NewA(domain, respTTL, respIP)) written, err := writeFilteredResp(ctx, ri, rw, req, resp, tc.reqRes, tc.respRes) require.NoError(t, err) diff --git a/internal/ecscache/ecscache_test.go b/internal/ecscache/ecscache_test.go index 26ae13b..79309b6 100644 --- a/internal/ecscache/ecscache_test.go +++ b/internal/ecscache/ecscache_test.go @@ -42,13 +42,11 @@ var remoteIP = netip.MustParseAddr("1.2.3.4") func TestMiddleware_Wrap_noECS(t *testing.T) { aReq := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET) cnameReq := dnsservertest.NewReq(reqHostname, dns.TypeCNAME, dns.ClassINET) - cnameAns := dnsservertest.RRSection{ - RRs: []dns.RR{dnsservertest.NewCNAME(reqHostname, defaultTTL, reqCNAME)}, - Sec: dnsservertest.SectionAnswer, + cnameAns := dnsservertest.SectionAnswer{ + dnsservertest.NewCNAME(reqHostname, defaultTTL, reqCNAME), } - soaNS := dnsservertest.RRSection{ - RRs: []dns.RR{dnsservertest.NewSOA(reqHostname, defaultTTL, reqNS1, reqNS2)}, - Sec: dnsservertest.SectionNs, + soaNS := dnsservertest.SectionNs{ + dnsservertest.NewSOA(reqHostname, defaultTTL, reqNS1, reqNS2), } const N = 5 @@ -60,8 +58,8 @@ func TestMiddleware_Wrap_noECS(t *testing.T) { wantTTL uint32 }{{ req: aReq, - resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.RRSection{ - RRs: []dns.RR{dnsservertest.NewA(reqHostname, defaultTTL, net.IP{1, 2, 3, 4})}, + resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{ + dnsservertest.NewA(reqHostname, defaultTTL, net.IP{1, 2, 3, 4}), }), name: "simple_a", wantNumReq: 1, @@ -92,9 +90,8 @@ func TestMiddleware_Wrap_noECS(t *testing.T) { wantTTL: defaultTTL, }, { req: aReq, - resp: dnsservertest.NewResp(dns.RcodeNameError, aReq, dnsservertest.RRSection{ - RRs: []dns.RR{dnsservertest.NewNS(reqHostname, defaultTTL, reqNS1)}, - Sec: dnsservertest.SectionNs, + resp: dnsservertest.NewResp(dns.RcodeNameError, aReq, dnsservertest.SectionNs{ + dnsservertest.NewNS(reqHostname, defaultTTL, reqNS1), }), name: "non_authoritative_nxdomain", // TODO(ameshkov): Consider https://datatracker.ietf.org/doc/html/rfc2308#section-3. @@ -114,16 +111,16 @@ func TestMiddleware_Wrap_noECS(t *testing.T) { wantTTL: ecscache.ServFailMaxCacheTTL, }, { req: cnameReq, - resp: dnsservertest.NewResp(dns.RcodeSuccess, cnameReq, dnsservertest.RRSection{ - RRs: []dns.RR{dnsservertest.NewCNAME(reqHostname, defaultTTL, reqCNAME)}, + resp: dnsservertest.NewResp(dns.RcodeSuccess, cnameReq, dnsservertest.SectionAnswer{ + dnsservertest.NewCNAME(reqHostname, defaultTTL, reqCNAME), }), name: "simple_cname_ans", wantNumReq: 1, wantTTL: defaultTTL, }, { req: aReq, - resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.RRSection{ - RRs: []dns.RR{dnsservertest.NewA(reqHostname, 0, net.IP{1, 2, 3, 4})}, + resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{ + dnsservertest.NewA(reqHostname, 0, net.IP{1, 2, 3, 4}), }), name: "expired_one", wantNumReq: N, @@ -270,14 +267,12 @@ func TestMiddleware_Wrap_ecs(t *testing.T) { resp := dnsservertest.NewResp( dns.RcodeSuccess, aReq, - dnsservertest.RRSection{ - RRs: []dns.RR{dnsservertest.NewA(reqHostname, defaultTTL, net.IP{1, 2, 3, 4})}, - Sec: dnsservertest.SectionAnswer, - }, - dnsservertest.RRSection{ - RRs: []dns.RR{tc.respECS}, - Sec: dnsservertest.SectionExtra, - }, + dnsservertest.SectionAnswer{dnsservertest.NewA( + reqHostname, + defaultTTL, + net.IP{1, 2, 3, 4}, + )}, + dnsservertest.SectionExtra{tc.respECS}, ) numReq := 0 @@ -337,13 +332,12 @@ func TestMiddleware_Wrap_ecsOrder(t *testing.T) { newResp := func(t *testing.T, req *dns.Msg, answer, extra dns.RR) (resp *dns.Msg) { t.Helper() - return dnsservertest.NewResp(dns.RcodeSuccess, req, dnsservertest.RRSection{ - RRs: []dns.RR{answer}, - Sec: dnsservertest.SectionAnswer, - }, dnsservertest.RRSection{ - RRs: []dns.RR{extra}, - Sec: dnsservertest.SectionExtra, - }) + return dnsservertest.NewResp( + dns.RcodeSuccess, + req, + dnsservertest.SectionAnswer{answer}, + dnsservertest.SectionExtra{extra}, + ) } reqNoECS := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET) diff --git a/internal/errcoll/sentry.go b/internal/errcoll/sentry.go index e4eb994..86f828e 100644 --- a/internal/errcoll/sentry.go +++ b/internal/errcoll/sentry.go @@ -79,9 +79,9 @@ func isReportable(err error) (ok bool) { } else if errors.As(err, &dnsWErr) { switch dnsWErr.Protocol { case "tcp": - return isReportableTCP(dnsWErr.Err) + return isReportableWriteTCP(dnsWErr.Err) case "udp": - return isReportableUDP(dnsWErr.Err) + return isReportableWriteUDP(dnsWErr.Err) default: return true } @@ -102,25 +102,28 @@ func isReportableNetwork(err error) (ok bool) { return errors.As(err, &netErr) && !netErr.Timeout() } -// isReportableTCP returns true if err is a TCP or TLS error that should be +// isReportableWriteTCP returns true if err is a TCP or TLS error that should be // reported. -func isReportableTCP(err error) (ok bool) { +func isReportableWriteTCP(err error) (ok bool) { if isConnectionBreak(err) { return false } - // Ignore the TLS errors that are probably caused by a network error and - // errors about protocol versions. + // Ignore the TLS errors that are probably caused by a network error, a + // record overflow attempt, and errors about protocol versions. + // + // See also AGDNS-1520. // // TODO(a.garipov): Propose exporting these from crypto/tls. errStr := err.Error() return !strings.Contains(errStr, "bad record MAC") && - !strings.Contains(errStr, "protocol version not supported") + !strings.Contains(errStr, "protocol version not supported") && + !strings.Contains(errStr, "local error: tls: record overflow") } -// isReportableUDP returns true if err is a UDP error that should be reported. -func isReportableUDP(err error) (ok bool) { +// isReportableWriteUDP returns true if err is a UDP error that should be reported. +func isReportableWriteUDP(err error) (ok bool) { switch { case errors.Is(err, io.EOF), diff --git a/internal/filter/compfilter.go b/internal/filter/compfilter.go deleted file mode 100644 index b579a2a..0000000 --- a/internal/filter/compfilter.go +++ /dev/null @@ -1,315 +0,0 @@ -package filter - -import ( - "context" - "fmt" - "strings" - - "github.com/AdguardTeam/AdGuardDNS/internal/agd" - "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" - "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/urlfilter/rules" - "github.com/miekg/dns" -) - -// Composite Filter - -// type check -var _ Interface = (*compFilter)(nil) - -// compFilter is a composite filter based on several types of safe search -// filters and rule lists. -type compFilter struct { - safeBrowsing *HashPrefix - adultBlocking *HashPrefix - - genSafeSearch *safeSearch - ytSafeSearch *safeSearch - - ruleLists []*ruleListFilter -} - -// qtHostFilter is a filter that can filter a request based on its query type -// and host. -// -// TODO(a.garipov): See if devirtualizing this interface would give us any -// considerable performance gains. -type qtHostFilter interface { - filterReq( - ctx context.Context, - ri *agd.RequestInfo, - req *dns.Msg, - ) (r Result, err error) - name() (n string) -} - -// FilterRequest implements the Interface interface for *compFilter. If there -// is a safe search result, it returns it. Otherwise, it returns the action -// created from the filter list network rule with the highest priority. If f is -// empty, it returns nil with no error. -func (f *compFilter) FilterRequest( - ctx context.Context, - req *dns.Msg, - ri *agd.RequestInfo, -) (r Result, err error) { - if f.isEmpty() { - return nil, nil - } - - // Prepare common data for filters. - reqID := ri.ID - log.Debug("filters: filtering req %s: %d rule lists", reqID, len(f.ruleLists)) - - // Firstly, check the profile's filter list rules, custom rules, and the - // rules from blocked services settings. - host := ri.Host - flRes := f.filterMsg(ri, host, ri.QType, req, false) - switch flRes := flRes.(type) { - case *ResultAllowed: - // Skip any additional filtering if the domain is explicitly allowed by - // user's custom rule. - if flRes.List == agd.FilterListIDCustom { - return flRes, nil - } - case *ResultBlocked: - // Skip any additional filtering if the domain is already blocked. - return flRes, nil - default: - // Go on. - } - - // Secondly, apply the safe browsing and safe search filters in the - // following order. - // - // DO NOT change the order of filters without necessity. - filters := []qtHostFilter{ - f.safeBrowsing, - f.adultBlocking, - f.genSafeSearch, - f.ytSafeSearch, - } - - for _, flt := range filters { - name := flt.name() - if name == "" { - // A nil filter, skip. - continue - } - - log.Debug("filter %s: filtering req %s", name, reqID) - r, err = flt.filterReq(ctx, ri, req) - log.Debug("filter %s: finished filtering req %s, errors: %v", name, reqID, err) - if err != nil { - return nil, err - } else if r != nil { - return r, nil - } - } - - // Thirdly, return the previously obtained filter list result. - return flRes, nil -} - -// FilterResponse implements the Interface interface for *compFilter. It -// returns the action created from the filter list network rule with the highest -// priority. If f is empty, it returns nil with no error. -func (f *compFilter) FilterResponse( - ctx context.Context, - resp *dns.Msg, - ri *agd.RequestInfo, -) (r Result, err error) { - if f.isEmpty() || len(resp.Answer) == 0 { - return nil, nil - } - - for _, ans := range resp.Answer { - host, rrType, ok := parseRespAnswer(ans) - if !ok { - continue - } - - r = f.filterMsg(ri, host, rrType, resp, true) - if r != nil { - break - } - } - - return r, nil -} - -// parseRespAnswer parses hostname and rrType from the answer if there are any. -// If ans is of a type that doesn't have an IP address or a hostname in it, ok -// is false. -func parseRespAnswer(ans dns.RR) (hostname string, rrType dnsmsg.RRType, ok bool) { - switch ans := ans.(type) { - case *dns.A: - return ans.A.String(), dns.TypeA, true - case *dns.AAAA: - return ans.AAAA.String(), dns.TypeAAAA, true - case *dns.CNAME: - return strings.TrimSuffix(ans.Target, "."), dns.TypeCNAME, true - default: - return "", dns.TypeNone, false - } -} - -// isEmpty returns true if this composite filter is an empty filter. -func (f *compFilter) isEmpty() (ok bool) { - return f == nil || (f.safeBrowsing == nil && - f.adultBlocking == nil && - f.genSafeSearch == nil && - f.ytSafeSearch == nil && - len(f.ruleLists) == 0) -} - -// filterMsg filters one question's or answer's information through all rule -// list filters of the composite filter. -func (f *compFilter) filterMsg( - ri *agd.RequestInfo, - host string, - rrType dnsmsg.RRType, - msg *dns.Msg, - answer bool, -) (r Result) { - var devName agd.DeviceName - if d := ri.Device; d != nil { - devName = d.Name - } - - var networkRules []*rules.NetworkRule - var hostRules4 []*rules.HostRule - var hostRules6 []*rules.HostRule - for _, rl := range f.ruleLists { - dr := rl.dnsResult(ri.RemoteIP, string(devName), host, rrType, answer) - if dr == nil { - continue - } - - // Collect only custom $dnsrewrite rules. It's much more easy - // to process dnsrewrite rules only from one list, cause when - // there is no problem with merging them among different lists. - if !answer && rl.id() == agd.FilterListIDCustom { - dnsRewriteResult := processDNSRewrites(ri.Messages, msg, dr.DNSRewrites(), host) - if dnsRewriteResult != nil { - dnsRewriteResult.List = rl.id() - - return dnsRewriteResult - } - } - - networkRules = append(networkRules, dr.NetworkRules...) - hostRules4 = append(hostRules4, dr.HostRulesV4...) - hostRules6 = append(hostRules6, dr.HostRulesV6...) - } - - mr := rules.NewMatchingResult(networkRules, nil) - if nr := mr.GetBasicResult(); nr != nil { - return f.ruleDataToResult(nr.FilterListID, nr.RuleText, nr.Whitelist) - } - - return f.hostsRulesToResult(hostRules4, hostRules6, rrType) -} - -// mustRuleListDataByURLFilterID returns the rule list data by its synthetic -// integer ID in the urlfilter engine. It panics if id is not found. -func (f *compFilter) mustRuleListDataByURLFilterID(id int) (fltID agd.FilterListID, subID string) { - for _, rl := range f.ruleLists { - if rl.urlFilterID == id { - return rl.id(), rl.subID - } - } - - // Technically shouldn't happen, since id is supposed to be among the rule - // list filters in the composite filter. - panic(fmt.Errorf("filter: synthetic id %d not found", id)) -} - -// hostsRulesToResult converts /etc/hosts-style rules into a filtering action. -func (f *compFilter) hostsRulesToResult( - hostRules4 []*rules.HostRule, - hostRules6 []*rules.HostRule, - rrType dnsmsg.RRType, -) (r Result) { - if len(hostRules4) == 0 && len(hostRules6) == 0 { - return nil - } - - // Only use the first matched rule, since we currently don't care about the - // IP addresses in the rule. If the request is neither an A one nor an AAAA - // one, or if there are no matching rules of the requested type, then use - // whatever rule isn't empty. - // - // See also AGDNS-591. - var resHostRule *rules.HostRule - if rrType == dns.TypeA && len(hostRules4) > 0 { - resHostRule = hostRules4[0] - } else if rrType == dns.TypeAAAA && len(hostRules6) > 0 { - resHostRule = hostRules6[0] - } else { - if len(hostRules4) > 0 { - resHostRule = hostRules4[0] - } else { - resHostRule = hostRules6[0] - } - } - - return f.ruleDataToResult(resHostRule.FilterListID, resHostRule.RuleText, false) -} - -// ruleDataToResult converts a urlfilter rule data into a filtering result. -func (f *compFilter) ruleDataToResult( - urlFilterID int, - ruleText string, - allowlist bool, -) (r Result) { - // Use the urlFilterID crutch to find the actual IDs of the filtering rule - // list and blocked service. - fltID, subID := f.mustRuleListDataByURLFilterID(urlFilterID) - - var rule agd.FilterRuleText - if fltID == agd.FilterListIDBlockedService { - rule = agd.FilterRuleText(subID) - } else { - rule = agd.FilterRuleText(ruleText) - } - - if allowlist { - log.Debug("rule list %s: allowed by rule %s", fltID, rule) - - return &ResultAllowed{ - List: fltID, - Rule: rule, - } - } - - log.Debug("rule list %s: blocked by rule %s", fltID, rule) - - return &ResultBlocked{ - List: fltID, - Rule: rule, - } -} - -// Close implements the Filter interface for *compFilter. It closes all -// underlying filters. -func (f *compFilter) Close() (err error) { - if f.isEmpty() { - return nil - } - - errs := make([]error, len(f.ruleLists)) - for i, rl := range f.ruleLists { - err = rl.Close() - if err != nil { - errs[i] = fmt.Errorf("rule list at index %d: %w", i, err) - } - } - - err = errors.Join(errs...) - if err != nil { - return fmt.Errorf("closing filters: %w", err) - } - - return nil -} diff --git a/internal/filter/compfilter_internal_test.go b/internal/filter/compfilter_internal_test.go deleted file mode 100644 index d0ab69c..0000000 --- a/internal/filter/compfilter_internal_test.go +++ /dev/null @@ -1,284 +0,0 @@ -package filter - -import ( - "context" - "net" - "strings" - "testing" - - "github.com/AdguardTeam/AdGuardDNS/internal/agd" - "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" - "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" - "github.com/miekg/dns" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestCompFilter_FilterRequest_badrequest(t *testing.T) { - const ( - fltListID1 agd.FilterListID = "fl1" - fltListID2 agd.FilterListID = "fl2" - - blockRule = "||example.com^" - ) - - rl1, err := newRuleListFltFromStr(blockRule, fltListID1, "", 0, false) - require.NoError(t, err) - - rl2, err := newRuleListFltFromStr("||example.com^$badfilter", fltListID2, "", 0, false) - require.NoError(t, err) - - req := &dns.Msg{ - Question: []dns.Question{{ - Name: testReqFQDN, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }}, - } - - testCases := []struct { - name string - wantRes Result - ruleLists []*ruleListFilter - }{{ - name: "block", - wantRes: &ResultBlocked{List: fltListID1, Rule: blockRule}, - ruleLists: []*ruleListFilter{rl1}, - }, { - name: "badfilter_no_block", - wantRes: nil, - ruleLists: []*ruleListFilter{rl2}, - }, { - name: "badfilter_removes_block", - wantRes: nil, - ruleLists: []*ruleListFilter{rl1, rl2}, - }} - - ri := &agd.RequestInfo{ - Messages: newConstructor(), - Host: testReqHost, - RemoteIP: testRemoteIP, - QType: dns.TypeA, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - f := &compFilter{ - ruleLists: tc.ruleLists, - } - - ctx := context.Background() - - res, rerr := f.FilterRequest(ctx, req, ri) - require.NoError(t, rerr) - - assert.Equal(t, tc.wantRes, res) - }) - } -} - -func TestCompFilter_FilterRequest_hostsRules(t *testing.T) { - const ( - fltListID agd.FilterListID = "fl1" - - reqHost4 = "www.example.com" - reqHost6 = "www.example.net" - - blockRule4 = "127.0.0.1 www.example.com" - blockRule6 = "::1 www.example.net" - ) - - const rules = blockRule4 + "\n" + blockRule6 - - rl, err := newRuleListFltFromStr(rules, fltListID, "", 0, false) - require.NoError(t, err) - - f := &compFilter{ - ruleLists: []*ruleListFilter{rl}, - } - - testCases := []struct { - wantRes Result - name string - reqHost string - reqType dnsmsg.RRType - }{{ - wantRes: &ResultBlocked{List: fltListID, Rule: blockRule4}, - name: "a", - reqHost: reqHost4, - reqType: dns.TypeA, - }, { - wantRes: &ResultBlocked{List: fltListID, Rule: blockRule6}, - name: "aaaa", - reqHost: reqHost6, - reqType: dns.TypeAAAA, - }, { - wantRes: &ResultBlocked{List: fltListID, Rule: blockRule6}, - name: "a_with_ipv6_rule", - reqHost: reqHost6, - reqType: dns.TypeA, - }, { - wantRes: &ResultBlocked{List: fltListID, Rule: blockRule4}, - name: "aaaa_with_ipv4_rule", - reqHost: reqHost4, - reqType: dns.TypeAAAA, - }, { - wantRes: &ResultBlocked{List: fltListID, Rule: blockRule4}, - name: "mx", - reqHost: reqHost4, - reqType: dns.TypeMX, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ri := &agd.RequestInfo{ - Messages: newConstructor(), - Host: tc.reqHost, - RemoteIP: testRemoteIP, - QType: tc.reqType, - } - - req := &dns.Msg{ - Question: []dns.Question{{ - Name: dns.Fqdn(tc.reqHost), - Qtype: tc.reqType, - Qclass: dns.ClassINET, - }}, - } - - ctx := context.Background() - - res, rerr := f.FilterRequest(ctx, req, ri) - require.NoError(t, rerr) - - assert.Equal(t, tc.wantRes, res) - assert.Equal(t, tc.wantRes, res) - }) - } -} - -func TestCompFilter_FilterRequest_dnsrewrite(t *testing.T) { - const ( - fltListID1 agd.FilterListID = "fl1" - fltListID2 agd.FilterListID = "fl2" - - fltListIDCustom = agd.FilterListIDCustom - - blockRule = "||example.com^" - dnsRewriteRuleRefused = "||example.com^$dnsrewrite=REFUSED" - dnsRewriteRuleCname = "||example.com^$dnsrewrite=cname" - dnsRewriteRule1 = "||example.com^$dnsrewrite=1.2.3.4" - dnsRewriteRule2 = "||example.com^$dnsrewrite=1.2.3.5" - ) - - rl1, err := newRuleListFltFromStr(blockRule, fltListID1, "", 0, false) - require.NoError(t, err) - - rl2, err := newRuleListFltFromStr(dnsRewriteRuleRefused, fltListID2, "", 0, false) - require.NoError(t, err) - - rlCustomRefused, err := newRuleListFltFromStr( - dnsRewriteRuleRefused, - fltListIDCustom, - string(testProfID), - 0, - false, - ) - require.NoError(t, err) - - rlCustomCname, err := newRuleListFltFromStr( - dnsRewriteRuleCname, - fltListIDCustom, - "prof1235", - 0, - false, - ) - require.NoError(t, err) - - rlCustom2, err := newRuleListFltFromStr( - strings.Join([]string{dnsRewriteRule1, dnsRewriteRule2}, "\n"), - fltListIDCustom, - "prof1236", - 0, - false, - ) - require.NoError(t, err) - - question := dns.Question{ - Name: testReqFQDN, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - } - - req := &dns.Msg{ - Question: []dns.Question{question}, - } - - testCases := []struct { - name string - wantRes Result - ruleLists []*ruleListFilter - }{{ - name: "block", - wantRes: &ResultBlocked{List: fltListID1, Rule: blockRule}, - ruleLists: []*ruleListFilter{rl1}, - }, { - name: "dnsrewrite_no_effect", - wantRes: &ResultBlocked{List: fltListID1, Rule: blockRule}, - ruleLists: []*ruleListFilter{rl1, rl2}, - }, { - name: "dnsrewrite_block", - wantRes: &ResultModified{ - Msg: dnsservertest.NewResp(dns.RcodeRefused, req, dnsservertest.RRSection{}), - List: fltListIDCustom, - Rule: dnsRewriteRuleRefused, - }, - ruleLists: []*ruleListFilter{rl1, rl2, rlCustomRefused}, - }, { - name: "dnsrewrite_cname", - wantRes: &ResultModified{ - Msg: dnsservertest.NewResp(dns.RcodeSuccess, req, dnsservertest.RRSection{ - RRs: []dns.RR{ - dnsservertest.NewCNAME(testReqHost, 10, "cname"), - }, - }), - List: fltListIDCustom, - Rule: dnsRewriteRuleCname, - }, - ruleLists: []*ruleListFilter{rl1, rl2, rlCustomCname}, - }, { - name: "dnsrewrite_answers", - wantRes: &ResultModified{ - Msg: dnsservertest.NewResp(dns.RcodeSuccess, req, dnsservertest.RRSection{ - RRs: []dns.RR{ - dnsservertest.NewA(testReqHost, 10, net.IP{1, 2, 3, 4}), - dnsservertest.NewA(testReqHost, 10, net.IP{1, 2, 3, 5}), - }, - }), - List: fltListIDCustom, - Rule: "", - }, - ruleLists: []*ruleListFilter{rl1, rl2, rlCustom2}, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - f := &compFilter{ - ruleLists: tc.ruleLists, - } - - ctx := context.Background() - ri := &agd.RequestInfo{ - Messages: newConstructor(), - Host: testReqHost, - RemoteIP: testRemoteIP, - QType: dns.TypeA, - } - - res, rerr := f.FilterRequest(ctx, req, ri) - require.NoError(t, rerr) - - assert.Equal(t, tc.wantRes, res) - }) - } -} diff --git a/internal/filter/custom_internal_test.go b/internal/filter/custom_internal_test.go deleted file mode 100644 index 5a64acc..0000000 --- a/internal/filter/custom_internal_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package filter - -import ( - "context" - "testing" - "time" - - "github.com/AdguardTeam/AdGuardDNS/internal/agd" - "github.com/bluele/gcache" -) - -// errorCollector is an agd.ErrorCollector for tests. This is a copy of the -// code from package agdtest to evade an import cycle. -type errorCollector struct { - OnCollect func(ctx context.Context, err error) -} - -// type check -var _ agd.ErrorCollector = (*errorCollector)(nil) - -// Collect implements the agd.GeoIP interface for *GeoIP. -func (c *errorCollector) Collect(ctx context.Context, err error) { - c.OnCollect(ctx, err) -} - -var ruleListsSink []*ruleListFilter - -func BenchmarkCustomFilters_ruleCache(b *testing.B) { - f := &customFilters{ - cache: gcache.New(1).LRU().Build(), - errColl: &errorCollector{ - OnCollect: func(ctx context.Context, err error) { panic("not implemented") }, - }, - } - - p := &agd.Profile{ - ID: testProfID, - UpdateTime: time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC), - CustomRules: []agd.FilterRuleText{ - "||example.com", - "||example.org", - "||example.net", - }, - } - - ctx := context.Background() - - b.Run("cache", func(b *testing.B) { - rls := make([]*ruleListFilter, 0, 1) - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - ruleListsSink = f.appendRuleLists(ctx, rls, p) - } - }) - - b.Run("no_cache", func(b *testing.B) { - rls := make([]*ruleListFilter, 0, 1) - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - // Update the time on each iteration to make sure that the cache is - // never used. - p.UpdateTime.Add(1 * time.Millisecond) - ruleListsSink = f.appendRuleLists(ctx, rls, p) - } - }) -} diff --git a/internal/filter/dnsrewrite_internal_test.go b/internal/filter/dnsrewrite_internal_test.go deleted file mode 100644 index 83aa3d3..0000000 --- a/internal/filter/dnsrewrite_internal_test.go +++ /dev/null @@ -1,137 +0,0 @@ -package filter - -import ( - "net" - "testing" - - "github.com/AdguardTeam/golibs/testutil" - - "github.com/AdguardTeam/AdGuardDNS/internal/agd" - "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" - "github.com/AdguardTeam/urlfilter/rules" - "github.com/miekg/dns" - "github.com/stretchr/testify/assert" -) - -func Test_processDNSRewriteRules(t *testing.T) { - cnameRule, _ := rules.NewNetworkRule("|cname^$dnsrewrite=new-cname", 1) - aRecordRule, _ := rules.NewNetworkRule("|a-record^$dnsrewrite=127.0.0.1", 1) - refusedRule, _ := rules.NewNetworkRule("|refused^$dnsrewrite=REFUSED", 1) - - testCases := []struct { - name string - want *DNSRewriteResult - dnsr []*rules.NetworkRule - }{{ - name: "empty", - want: &DNSRewriteResult{ - Response: DNSRewriteResultResponse{}, - }, - dnsr: []*rules.NetworkRule{}, - }, { - name: "cname", - want: &DNSRewriteResult{ - ResRuleText: agd.FilterRuleText(cnameRule.RuleText), - CanonName: cnameRule.DNSRewrite.NewCNAME, - }, - dnsr: []*rules.NetworkRule{ - cnameRule, - aRecordRule, - refusedRule, - }, - }, { - name: "refused", - want: &DNSRewriteResult{ - ResRuleText: agd.FilterRuleText(refusedRule.RuleText), - RCode: refusedRule.DNSRewrite.RCode, - }, - dnsr: []*rules.NetworkRule{ - aRecordRule, - refusedRule, - }, - }, { - name: "a_record", - want: &DNSRewriteResult{ - Rules: []*rules.NetworkRule{aRecordRule}, - RCode: aRecordRule.DNSRewrite.RCode, - Response: DNSRewriteResultResponse{ - aRecordRule.DNSRewrite.RRType: []rules.RRValue{aRecordRule.DNSRewrite.Value}, - }, - }, - dnsr: []*rules.NetworkRule{aRecordRule}, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - got := processDNSRewriteRules(tc.dnsr) - assert.Equal(t, tc.want, got) - }) - } -} - -func Test_filterDNSRewrite(t *testing.T) { - cnameRule, _ := rules.NewNetworkRule("|cname^$dnsrewrite=new-cname", 1) - aRecordRule, _ := rules.NewNetworkRule("|a-record^$dnsrewrite=127.0.0.1", 1) - refusedRule, _ := rules.NewNetworkRule("|refused^$dnsrewrite=REFUSED", 1) - - messages := newConstructor() - - req := dnsservertest.NewReq(testReqFQDN, dns.TypeA, dns.ClassINET) - - testCases := []struct { - dnsrr *DNSRewriteResult - want *dns.Msg - name string - wantErr string - }{{ - dnsrr: &DNSRewriteResult{ - Response: DNSRewriteResultResponse{}, - }, - want: dnsservertest.NewResp(dns.RcodeSuccess, req, dnsservertest.RRSection{}), - name: "empty", - wantErr: "", - }, { - dnsrr: &DNSRewriteResult{ - Rules: []*rules.NetworkRule{cnameRule}, - CanonName: cnameRule.DNSRewrite.NewCNAME, - }, - want: nil, - name: "cname", - wantErr: "no dns rewrite rule responses", - }, { - dnsrr: &DNSRewriteResult{ - Rules: []*rules.NetworkRule{refusedRule}, - RCode: refusedRule.DNSRewrite.RCode, - }, - want: nil, - name: "refused", - wantErr: "non-success answer", - }, { - dnsrr: &DNSRewriteResult{ - Rules: []*rules.NetworkRule{aRecordRule}, - RCode: aRecordRule.DNSRewrite.RCode, - Response: DNSRewriteResultResponse{ - aRecordRule.DNSRewrite.RRType: []rules.RRValue{aRecordRule.DNSRewrite.Value}, - }, - }, - want: dnsservertest.NewResp( - aRecordRule.DNSRewrite.RCode, - req, - dnsservertest.RRSection{ - RRs: []dns.RR{ - dnsservertest.NewA(testReqHost, 10, net.IP{127, 0, 0, 1}), - }, - }, - ), - name: "a_record", - wantErr: "", - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - resp, err := filterDNSRewrite(messages, req, tc.dnsrr) - testutil.AssertErrorMsg(t, tc.wantErr, err) - assert.Equal(t, tc.want, resp) - }) - } -} diff --git a/internal/filter/filter.go b/internal/filter/filter.go index e176111..91652b5 100644 --- a/internal/filter/filter.go +++ b/internal/filter/filter.go @@ -5,50 +5,14 @@ package filter import ( "context" - "time" - "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal" - "github.com/c2h5oh/datasize" - "github.com/miekg/dns" ) -// Common Constants, Functions, and Types - -// maxFilterSize is the maximum size of downloaded filters. -// -// TODO(ameshkov): Consider making configurable. -const maxFilterSize = 256 * int64(datasize.MB) - -// defaultFilterRefreshTimeout is the default timeout to use when fetching -// filter lists data. -// -// TODO(a.garipov): Consider making timeouts where they are used configurable. -const defaultFilterRefreshTimeout = 180 * time.Second - -// defaultResolveTimeout is the default timeout for resolving hosts for safe -// search and safe browsing filters. -// -// TODO(ameshkov): Consider making configurable. -const defaultResolveTimeout = 1 * time.Second - // Interface is the DNS request and response filter interface. -type Interface interface { - // FilterRequest filters the DNS request for the provided client. All - // parameters must be non-nil. req must have exactly one question. If a is - // nil, the request doesn't match any of the rules. - FilterRequest(ctx context.Context, req *dns.Msg, ri *agd.RequestInfo) (r Result, err error) +type Interface = internal.Interface - // FilterResponse filters the DNS response for the provided client. All - // parameters must be non-nil. If a is nil, the response doesn't match any - // of the rules. - FilterResponse(ctx context.Context, resp *dns.Msg, ri *agd.RequestInfo) (r Result, err error) - - // Close closes the filter and frees resources associated with it. - Close() (err error) -} - -// Filtering Result Aliases +// Filtering result aliases // Result is a sum type of all possible filtering actions. See the following // types as implementations: @@ -69,3 +33,17 @@ type ResultBlocked = internal.ResultBlocked // ResultModified means that this request or response was rewritten or modified // by a rewrite rule within the given filter list. type ResultModified = internal.ResultModified + +// Hash matching for safe-browsing and adult-content blocking + +// HashMatcher is the interface for a safe-browsing and adult-blocking hash +// matcher, which is used to respond to a TXT query based on the domain name. +type HashMatcher interface { + MatchByPrefix(ctx context.Context, host string) (hashes []string, matched bool, err error) +} + +// Default safe-browsing host suffixes. +const ( + GeneralTXTSuffix = ".sb.dns.adguard.com" + AdultBlockingTXTSuffix = ".pc.dns.adguard.com" +) diff --git a/internal/filter/filter_internal_test.go b/internal/filter/filter_internal_test.go deleted file mode 100644 index 31e659e..0000000 --- a/internal/filter/filter_internal_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package filter - -import ( - "net/netip" - "time" - - "github.com/AdguardTeam/AdGuardDNS/internal/agd" - "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" -) - -// Common test constants. - -// testTimeout is the timeout for tests. -const testTimeout = 1 * time.Second - -// testProfID is the profile ID for tests. -const testProfID agd.ProfileID = "prof1234" - -// testReqHost is the request host for tests. -const testReqHost = "www.example.com" - -// testReqFQDN is the request FQDN for tests. -const testReqFQDN = testReqHost + "." - -// testRemoteIP is the client IP for tests -var testRemoteIP = netip.MustParseAddr("1.2.3.4") - -// newConstructor returns a standard dnsmsg.Constructor for tests. -// -// TODO(a.garipov): Use [agdtest.NewConstructor] once the package is split and -// import cycles are resolved. -func newConstructor() (c *dnsmsg.Constructor) { - return dnsmsg.NewConstructor(&dnsmsg.BlockingModeNullIP{}, 10*time.Second) -} diff --git a/internal/filter/filter_test.go b/internal/filter/filter_test.go index 8c360c2..de8edb6 100644 --- a/internal/filter/filter_test.go +++ b/internal/filter/filter_test.go @@ -17,6 +17,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" "github.com/AdguardTeam/AdGuardDNS/internal/filter" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/require" ) @@ -182,8 +183,8 @@ func prepareConf(t testing.TB) (c *filter.DefaultStorageConfig) { FilterIndexURL: fltsURL, GeneralSafeSearchRulesURL: ssURL, YoutubeSafeSearchRulesURL: ssURL, - SafeBrowsing: &filter.HashPrefix{}, - AdultBlocking: &filter.HashPrefix{}, + SafeBrowsing: &hashprefix.Filter{}, + AdultBlocking: &hashprefix.Filter{}, Now: time.Now, ErrColl: nil, Resolver: nil, diff --git a/internal/filter/hashprefix.go b/internal/filter/hashprefix/filter.go similarity index 62% rename from internal/filter/hashprefix.go rename to internal/filter/hashprefix/filter.go index 3d0dd48..c1d8708 100644 --- a/internal/filter/hashprefix.go +++ b/internal/filter/hashprefix/filter.go @@ -1,4 +1,4 @@ -package filter +package hashprefix import ( "context" @@ -8,9 +8,8 @@ import ( "time" "github.com/AdguardTeam/AdGuardDNS/internal/agd" - "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" "github.com/AdguardTeam/AdGuardDNS/internal/agdnet" - "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal" "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/resultcache" "github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/golibs/log" @@ -20,12 +19,10 @@ import ( "golang.org/x/net/publicsuffix" ) -// Hash-prefix filter - -// HashPrefixConfig is the hash-prefix filter configuration structure. -type HashPrefixConfig struct { +// FilterConfig is the hash-prefix filter configuration structure. +type FilterConfig struct { // Hashes are the hostname hashes for this filter. - Hashes *hashstorage.Storage + Hashes *Storage // URL is the URL used to update the filter. URL *url.URL @@ -54,71 +51,67 @@ type HashPrefixConfig struct { // CacheTTL is the time-to-live value used to cache the results of the // filter. // - // TODO(a.garipov): Currently unused. + // TODO(a.garipov): Currently unused. See AGDNS-398. CacheTTL time.Duration // CacheSize is the size of the filter's result cache. CacheSize int } -// HashPrefix is a filter that matches hosts by their hashes based on a +// Filter is a filter that matches hosts by their hashes based on a // hash-prefix table. -type HashPrefix struct { - hashes *hashstorage.Storage - refr *refreshableFilter - resCache *resultcache.Cache[*ResultModified] +type Filter struct { + hashes *Storage + refr *internal.Refreshable + resCache *resultcache.Cache[*internal.ResultModified] resolver agdnet.Resolver errColl agd.ErrorCollector + id agd.FilterListID repHost string } -// NewHashPrefix returns a new hash-prefix filter. c must not be nil. -func NewHashPrefix(c *HashPrefixConfig) (f *HashPrefix, err error) { - f = &HashPrefix{ - hashes: c.Hashes, - refr: &refreshableFilter{ - http: agdhttp.NewClient(&agdhttp.ClientConfig{ - Timeout: defaultFilterRefreshTimeout, - }), - url: c.URL, - id: c.ID, - cachePath: c.CachePath, - typ: "hash storage", - staleness: c.Staleness, - }, - resCache: resultcache.New[*ResultModified](c.CacheSize), +// NewFilter returns a new hash-prefix filter. c must not be nil. +func NewFilter(c *FilterConfig) (f *Filter, err error) { + id := c.ID + f = &Filter{ + hashes: c.Hashes, + resCache: resultcache.New[*internal.ResultModified](c.CacheSize), resolver: c.Resolver, errColl: c.ErrColl, + id: id, repHost: c.ReplacementHost, } - f.refr.resetRules = f.resetRules + f.refr = internal.NewRefreshable( + &agd.FilterList{ + ID: id, + URL: c.URL, + RefreshIvl: c.Staleness, + }, + c.CachePath, + ) err = f.refresh(context.Background(), true) if err != nil { + // Don't wrap the error, because it's informative enough as is. return nil, err } return f, nil } -// id returns the ID of the hash storage. -func (f *HashPrefix) id() (fltID agd.FilterListID) { - return f.refr.id -} - // type check -var _ qtHostFilter = (*HashPrefix)(nil) +var _ internal.RequestFilter = (*Filter)(nil) -// filterReq implements the qtHostFilter interface for *hashPrefixFilter. It -// modifies the response if host matches f. -func (f *HashPrefix) filterReq( +// FilterRequest implements the [internal.RequestFilter] interface for +// *Filter. It modifies the response if host matches f. +func (f *Filter) FilterRequest( ctx context.Context, - ri *agd.RequestInfo, req *dns.Msg, -) (r Result, err error) { - host, qt := ri.Host, ri.QType - cacheKey := resultcache.DefaultKey(host, qt, false) + ri *agd.RequestInfo, +) (r internal.Result, err error) { + host, qt, cl := ri.Host, ri.QType, ri.QClass + cacheKey := resultcache.DefaultKey(host, qt, cl, false) rm, ok := f.resCache.Get(cacheKey) f.updateCacheLookupsMetrics(ok) if ok { @@ -152,25 +145,25 @@ func (f *HashPrefix) filterReq( return nil, nil } - ctx, cancel := context.WithTimeout(ctx, defaultResolveTimeout) + ctx, cancel := context.WithTimeout(ctx, internal.DefaultResolveTimeout) defer cancel() var result *dns.Msg ips, err := f.resolver.LookupIP(ctx, fam, f.repHost) if err != nil { - agd.Collectf(ctx, f.errColl, "filter %s: resolving: %w", f.id(), err) + agd.Collectf(ctx, f.errColl, "filter %s: resolving: %w", f.id, err) result = ri.Messages.NewMsgSERVFAIL(req) } else { result, err = ri.Messages.NewIPRespMsg(req, ips...) if err != nil { - return nil, fmt.Errorf("filter %s: creating modified result: %w", f.id(), err) + return nil, fmt.Errorf("filter %s: creating modified result: %w", f.id, err) } } - rm = &ResultModified{ + rm = &internal.ResultModified{ Msg: result, - List: f.id(), + List: f.id, Rule: agd.FilterRuleText(matched), } @@ -185,8 +178,8 @@ func (f *HashPrefix) filterReq( } // updateCacheSizeMetrics updates cache size metrics. -func (f *HashPrefix) updateCacheSizeMetrics(size int) { - switch id := f.id(); id { +func (f *Filter) updateCacheSizeMetrics(size int) { + switch id := f.id; id { case agd.FilterListIDSafeBrowsing: metrics.HashPrefixFilterSafeBrowsingCacheSize.Set(float64(size)) case agd.FilterListIDAdultBlocking: @@ -197,9 +190,9 @@ func (f *HashPrefix) updateCacheSizeMetrics(size int) { } // updateCacheLookupsMetrics updates cache lookups metrics. -func (f *HashPrefix) updateCacheLookupsMetrics(hit bool) { +func (f *Filter) updateCacheLookupsMetrics(hit bool) { var hitsMetric, missesMetric prometheus.Counter - switch id := f.id(); id { + switch id := f.id; id { case agd.FilterListIDSafeBrowsing: hitsMetric = metrics.HashPrefixFilterCacheSafeBrowsingHits missesMetric = metrics.HashPrefixFilterCacheSafeBrowsingMisses @@ -207,7 +200,7 @@ func (f *HashPrefix) updateCacheLookupsMetrics(hit bool) { hitsMetric = metrics.HashPrefixFilterCacheAdultBlockingHits missesMetric = metrics.HashPrefixFilterCacheAdultBlockingMisses default: - panic(fmt.Errorf("unsupported FilterListID %s", id)) + panic(fmt.Errorf("unsupported filter list id %s", id)) } if hit { @@ -217,49 +210,37 @@ func (f *HashPrefix) updateCacheLookupsMetrics(hit bool) { } } -// name implements the qtHostFilter interface for *hashPrefixFilter. -func (f *HashPrefix) name() (n string) { - if f == nil { - return "" - } - - return string(f.id()) -} - // type check -var _ agd.Refresher = (*HashPrefix)(nil) +var _ agd.Refresher = (*Filter)(nil) // Refresh implements the [agd.Refresher] interface for *hashPrefixFilter. -func (f *HashPrefix) Refresh(ctx context.Context) (err error) { +func (f *Filter) Refresh(ctx context.Context) (err error) { return f.refresh(ctx, false) } -// refresh reloads the hash filter data. If acceptStale is true, do not try to -// load the list from its URL when there is already a file in the cache -// directory, regardless of its staleness. -func (f *HashPrefix) refresh(ctx context.Context, acceptStale bool) (err error) { - return f.refr.refresh(ctx, acceptStale) -} - -// resetRules resets the hosts in the index. -func (f *HashPrefix) resetRules(text string) (err error) { - n, err := f.hashes.Reset(text) - - // Report the filter update to prometheus. - promLabels := prometheus.Labels{ - "filter": string(f.id()), - } - - metrics.SetStatusGauge(metrics.FilterUpdatedStatus.With(promLabels), err) - +// refresh reloads and resets the hash-filter data. If acceptStale is true, do +// not try to load the list from its URL when there is already a file in the +// cache directory, regardless of its staleness. +func (f *Filter) refresh(ctx context.Context, acceptStale bool) (err error) { + text, err := f.refr.Refresh(ctx, acceptStale) if err != nil { + // Don't wrap the error, because it's informative enough as is. return err } - metrics.FilterUpdatedTime.With(promLabels).SetToCurrentTime() - metrics.FilterRulesTotal.With(promLabels).Set(float64(n)) + n, err := f.hashes.Reset(text) + fltIDStr := string(f.id) + metrics.SetStatusGauge(metrics.FilterUpdatedStatus.WithLabelValues(fltIDStr), err) + if err != nil { + return fmt.Errorf("resetting: %w", err) + } - log.Info("filter %s: reset %d hosts", f.id(), n) + f.resCache.Clear() + + metrics.FilterUpdatedTime.WithLabelValues(fltIDStr).SetToCurrentTime() + metrics.FilterRulesTotal.WithLabelValues(fltIDStr).Set(float64(n)) + + log.Info("filter %s: reset %d hosts", f.id, n) return nil } diff --git a/internal/filter/hashprefix/filter_test.go b/internal/filter/hashprefix/filter_test.go new file mode 100644 index 0000000..386ced8 --- /dev/null +++ b/internal/filter/hashprefix/filter_test.go @@ -0,0 +1,338 @@ +package hashprefix_test + +import ( + "context" + "net" + "net/http" + "os" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/filtertest" + "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/testutil" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFilter_FilterRequest(t *testing.T) { + cachePath, srvURL := filtertest.PrepareRefreshable(t, nil, testHost, http.StatusOK) + + strg, err := hashprefix.NewStorage("") + require.NoError(t, err) + + replIP := net.IP{1, 2, 3, 4} + f, err := hashprefix.NewFilter(&hashprefix.FilterConfig{ + Hashes: strg, + URL: srvURL, + ErrColl: &agdtest.ErrorCollector{ + OnCollect: func(_ context.Context, _ error) { + panic("not implemented") + }, + }, + Resolver: &agdtest.Resolver{ + OnLookupIP: func( + _ context.Context, + _ netutil.AddrFamily, + _ string, + ) (ips []net.IP, err error) { + return []net.IP{replIP}, nil + }, + }, + ID: agd.FilterListIDAdultBlocking, + CachePath: cachePath, + ReplacementHost: "repl.example", + Staleness: 1 * time.Minute, + CacheTTL: 1 * time.Minute, + CacheSize: 1, + }) + require.NoError(t, err) + + messages := agdtest.NewConstructor() + + testCases := []struct { + name string + host string + qType dnsmsg.RRType + wantResult bool + }{{ + name: "not_a_or_aaaa", + host: testHost, + qType: dns.TypeTXT, + wantResult: false, + }, { + name: "success", + host: testHost, + qType: dns.TypeA, + wantResult: true, + }, { + name: "success_subdomain", + host: "a.b.c." + testHost, + qType: dns.TypeA, + wantResult: true, + }, { + name: "no_match", + host: testOtherHost, + qType: dns.TypeA, + wantResult: false, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := dnsservertest.NewReq( + dns.Fqdn(tc.host), + tc.qType, + dns.ClassINET, + ) + ri := &agd.RequestInfo{ + Messages: messages, + Host: tc.host, + QType: tc.qType, + } + + ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout) + t.Cleanup(cancel) + + var r internal.Result + r, err = f.FilterRequest(ctx, req, ri) + require.NoError(t, err) + + if tc.wantResult { + wantRes := newModifiedResult(t, req, messages, replIP) + assert.Equal(t, wantRes, r) + } else { + assert.Nil(t, r) + } + }) + } + + t.Run("cached_success", func(t *testing.T) { + req := dnsservertest.NewReq( + dns.Fqdn(testHost), + dns.TypeA, + dns.ClassINET, + ) + ri := &agd.RequestInfo{ + Messages: messages, + Host: testHost, + QType: dns.TypeA, + } + + ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout) + t.Cleanup(cancel) + + var r internal.Result + r, err = f.FilterRequest(ctx, req, ri) + require.NoError(t, err) + + wantRes := newModifiedResult(t, req, messages, replIP) + assert.Equal(t, wantRes, r) + }) + + t.Run("cached_no_match", func(t *testing.T) { + req := dnsservertest.NewReq( + dns.Fqdn(testOtherHost), + dns.TypeA, + dns.ClassINET, + ) + ri := &agd.RequestInfo{ + Messages: messages, + Host: testOtherHost, + QType: dns.TypeA, + } + + ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout) + t.Cleanup(cancel) + + var r internal.Result + r, err = f.FilterRequest(ctx, req, ri) + require.NoError(t, err) + + assert.Nil(t, r) + }) +} + +// newModifiedResult is a helper for creating modified results for tests. +func newModifiedResult( + tb testing.TB, + req *dns.Msg, + messages *dnsmsg.Constructor, + replIP net.IP, +) (r *internal.ResultModified) { + resp, err := messages.NewIPRespMsg(req, replIP) + require.NoError(tb, err) + + return &internal.ResultModified{ + Msg: resp, + List: testFltListID, + Rule: testHost, + } +} + +func TestFilter_Refresh(t *testing.T) { + reqCh := make(chan struct{}, 1) + cachePath, srvURL := filtertest.PrepareRefreshable(t, reqCh, testHost, http.StatusOK) + + strg, err := hashprefix.NewStorage("") + require.NoError(t, err) + + f, err := hashprefix.NewFilter(&hashprefix.FilterConfig{ + Hashes: strg, + URL: srvURL, + ErrColl: &agdtest.ErrorCollector{ + OnCollect: func(_ context.Context, _ error) { + panic("not implemented") + }, + }, + Resolver: &agdtest.Resolver{ + OnLookupIP: func( + _ context.Context, + _ netutil.AddrFamily, + _ string, + ) (ips []net.IP, err error) { + panic("not implemented") + }, + }, + ID: agd.FilterListIDAdultBlocking, + CachePath: cachePath, + ReplacementHost: "", + Staleness: 1 * time.Minute, + CacheTTL: 1 * time.Minute, + CacheSize: 1, + }) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout) + t.Cleanup(cancel) + + err = f.Refresh(ctx) + assert.NoError(t, err) + + testutil.RequireReceive(t, reqCh, filtertest.Timeout) +} + +func TestFilter_FilterRequest_staleCache(t *testing.T) { + refrCh := make(chan struct{}, 1) + cachePath, srvURL := filtertest.PrepareRefreshable(t, refrCh, testHost, http.StatusOK) + + // Put some initial data into the cache to avoid the first refresh. + + cf, err := os.OpenFile(cachePath, os.O_WRONLY|os.O_APPEND, os.ModeAppend) + require.NoError(t, err) + + _, err = cf.WriteString(testOtherHost) + require.NoError(t, err) + require.NoError(t, cf.Close()) + + // Create the filter. + + strg, err := hashprefix.NewStorage("") + require.NoError(t, err) + + replIP := net.IP{1, 2, 3, 4} + fconf := &hashprefix.FilterConfig{ + Hashes: strg, + URL: srvURL, + ErrColl: &agdtest.ErrorCollector{ + OnCollect: func(_ context.Context, _ error) { + panic("not implemented") + }, + }, + Resolver: &agdtest.Resolver{ + OnLookupIP: func( + _ context.Context, + _ netutil.AddrFamily, + _ string, + ) (ips []net.IP, err error) { + return []net.IP{replIP}, nil + }, + }, + ID: agd.FilterListIDAdultBlocking, + CachePath: cachePath, + ReplacementHost: "repl.example", + Staleness: 1 * time.Minute, + CacheTTL: 1 * time.Minute, + CacheSize: 1, + } + f, err := hashprefix.NewFilter(fconf) + require.NoError(t, err) + + messages := agdtest.NewConstructor() + + // Test the following: + // + // 1. Check that the stale rules cache is used. + // 2. Refresh the stale rules cache. + // 3. Ensure the result cache is cleared. + // 4. Ensure the stale rules aren't used. + + testHostReq := dnsservertest.NewReq(dns.Fqdn(testHost), dns.TypeA, dns.ClassINET) + testReqInfo := &agd.RequestInfo{Messages: messages, Host: testHost, QType: dns.TypeA} + + testOtherHostReq := dnsservertest.NewReq(dns.Fqdn(testOtherHost), dns.TypeA, dns.ClassINET) + testOtherReqInfo := &agd.RequestInfo{Messages: messages, Host: testOtherHost, QType: dns.TypeA} + + t.Run("hit_cached_host", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout) + t.Cleanup(cancel) + + var r internal.Result + r, err = f.FilterRequest(ctx, testOtherHostReq, testOtherReqInfo) + require.NoError(t, err) + + var resp *dns.Msg + resp, err = messages.NewIPRespMsg(testOtherHostReq, replIP) + require.NoError(t, err) + + assert.Equal(t, &internal.ResultModified{ + Msg: resp, + List: testFltListID, + Rule: testOtherHost, + }, r) + }) + + t.Run("refresh", func(t *testing.T) { + // Make the cache stale. + now := time.Now() + err = os.Chtimes(cachePath, now, now.Add(-2*fconf.Staleness)) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout) + t.Cleanup(cancel) + + err = f.Refresh(ctx) + assert.NoError(t, err) + + testutil.RequireReceive(t, refrCh, filtertest.Timeout) + }) + + t.Run("previously_cached", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout) + t.Cleanup(cancel) + + var r internal.Result + r, err = f.FilterRequest(ctx, testOtherHostReq, testOtherReqInfo) + require.NoError(t, err) + + assert.Nil(t, r) + }) + + t.Run("new_host", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout) + t.Cleanup(cancel) + + var r internal.Result + r, err = f.FilterRequest(ctx, testHostReq, testReqInfo) + require.NoError(t, err) + + wantRes := newModifiedResult(t, testHostReq, messages, replIP) + assert.Equal(t, wantRes, r) + }) +} diff --git a/internal/filter/hashprefix/hashprefix.go b/internal/filter/hashprefix/hashprefix.go new file mode 100644 index 0000000..34dfcd1 --- /dev/null +++ b/internal/filter/hashprefix/hashprefix.go @@ -0,0 +1,32 @@ +// Package hashprefix defines a storage of hashes of domain names used for +// filtering and serving TXT records with domain-name hashes. +package hashprefix + +import "crypto/sha256" + +// Hash and hash part length constants. +const ( + // PrefixLen is the length of the hash prefix of the filtered hostname. + PrefixLen = 2 + + // PrefixEncLen is the encoded length of the hash prefix. Two text + // bytes per one binary byte. + PrefixEncLen = PrefixLen * 2 + + // hashLen is the length of the whole hash of the checked hostname. + hashLen = sha256.Size + + // suffixLen is the length of the hash suffix of the filtered hostname. + suffixLen = hashLen - PrefixLen + + // hashEncLen is the encoded length of the hash. Two text bytes per one + // binary byte. + hashEncLen = hashLen * 2 +) + +// Prefix is the type of the SHA256 hash prefix used to match against the +// domain-name database. +type Prefix [PrefixLen]byte + +// suffix is the type of the rest of a SHA256 hash of the filtered domain names. +type suffix [suffixLen]byte diff --git a/internal/filter/hashprefix/hashprefix_test.go b/internal/filter/hashprefix/hashprefix_test.go new file mode 100644 index 0000000..de2de4f --- /dev/null +++ b/internal/filter/hashprefix/hashprefix_test.go @@ -0,0 +1,21 @@ +package hashprefix_test + +import ( + "testing" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/golibs/testutil" +) + +func TestMain(m *testing.M) { + testutil.DiscardLogOutput(m) +} + +// testFltListID is the common filtering-list for tests. +const testFltListID = agd.FilterListIDAdultBlocking + +// Common hostnames for tests. +const ( + testHost = "porn.example" + testOtherHost = "otherporn.example" +) diff --git a/internal/filter/hashprefix/matcher.go b/internal/filter/hashprefix/matcher.go new file mode 100644 index 0000000..5d8f752 --- /dev/null +++ b/internal/filter/hashprefix/matcher.go @@ -0,0 +1,105 @@ +package hashprefix + +import ( + "context" + "encoding/hex" + "fmt" + "strings" + + "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/stringutil" +) + +// Matcher is a hash-prefix matcher that uses the hash-prefix storages as the +// source of its data. +type Matcher struct { + // storages is a mapping of domain-name suffixes to the storage containing + // hashes for this domain. + storages map[string]*Storage +} + +// NewMatcher returns a new hash-prefix matcher. storages is a mapping of +// domain-name suffixes to the storage containing hashes for this domain. +func NewMatcher(storages map[string]*Storage) (m *Matcher) { + return &Matcher{ + storages: storages, + } +} + +// MatchByPrefix implements the [filter.HashMatcher] interface for *Matcher. It +// returns the matched hashes if the host matched one of the domain names in m's +// storages. +// +// TODO(a.garipov): Use the context for logging etc. +func (m *Matcher) MatchByPrefix( + _ context.Context, + host string, +) (hashes []string, matched bool, err error) { + var ( + suffix string + prefixesStr string + strg *Storage + ) + + for suffix, strg = range m.storages { + if strings.HasSuffix(host, suffix) { + prefixesStr = host[:len(host)-len(suffix)] + matched = true + + break + } + } + + if !matched { + return nil, false, nil + } + + log.Debug("hashprefix matcher: got prefixes string %q", prefixesStr) + + hashPrefixes, err := prefixesFromStr(prefixesStr) + if err != nil { + return nil, false, err + } + + return strg.Hashes(hashPrefixes), true, nil +} + +// legacyPrefixEncLen is the encoded length of a legacy hash. +const legacyPrefixEncLen = 8 + +// prefixesFromStr returns hash prefixes from a dot-separated string. +func prefixesFromStr(prefixesStr string) (hashPrefixes []Prefix, err error) { + if prefixesStr == "" { + return nil, nil + } + + prefixSet := stringutil.NewSet() + prefixStrs := strings.Split(prefixesStr, ".") + for _, s := range prefixStrs { + if len(s) != PrefixEncLen { + // Some legacy clients send eight-character hashes instead of + // four-character ones. For now, remove the final four characters. + // + // TODO(a.garipov): Either remove this crutch or support such + // prefixes better. + if len(s) == legacyPrefixEncLen { + s = s[:PrefixEncLen] + } else { + return nil, fmt.Errorf("bad hash len for %q", s) + } + } + + prefixSet.Add(s) + } + + hashPrefixes = make([]Prefix, prefixSet.Len()) + prefixStrs = prefixSet.Values() + for i, s := range prefixStrs { + _, err = hex.Decode(hashPrefixes[i][:], []byte(s)) + if err != nil { + return nil, fmt.Errorf("bad hash encoding for %q", s) + } + } + + return hashPrefixes, nil +} diff --git a/internal/filter/safebrowsing_test.go b/internal/filter/hashprefix/matcher_test.go similarity index 63% rename from internal/filter/safebrowsing_test.go rename to internal/filter/hashprefix/matcher_test.go index 55afa30..1c57288 100644 --- a/internal/filter/safebrowsing_test.go +++ b/internal/filter/hashprefix/matcher_test.go @@ -1,4 +1,4 @@ -package filter_test +package hashprefix_test import ( "context" @@ -8,23 +8,29 @@ import ( "testing" "github.com/AdguardTeam/AdGuardDNS/internal/filter" - "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestSafeBrowsingServer(t *testing.T) { - // Hashes +// type check +// +// TODO(a.garipov): Move this into the actual package instead of keeping it in +// the test package if [filter.Storage] and [filter.compFilter] are moved. +var _ filter.HashMatcher = (*hashprefix.Matcher)(nil) +func TestMatcher(t *testing.T) { const ( realisticHostIdx = iota samePrefixHost1Idx samePrefixHost2Idx ) + const suffix = filter.GeneralTXTSuffix + hosts := []string{ // Data closer to real world. - realisticHostIdx: safeBrowsingHost, + realisticHostIdx: "scam.example.net", // Additional data that has the same prefixes. samePrefixHost1Idx: "3z", @@ -37,7 +43,7 @@ func TestSafeBrowsingServer(t *testing.T) { hashStrs[i] = hex.EncodeToString(sum[:]) } - hashes, err := hashstorage.New(strings.Join(hosts, "\n")) + hashes, err := hashprefix.NewStorage(strings.Join(hosts, "\n")) require.NoError(t, err) ctx := context.Background() @@ -53,14 +59,14 @@ func TestSafeBrowsingServer(t *testing.T) { wantMatched: false, }, { name: "realistic", - host: hashStrs[realisticHostIdx][:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix, + host: hashStrs[realisticHostIdx][:hashprefix.PrefixEncLen] + suffix, wantHashStrs: []string{ hashStrs[realisticHostIdx], }, wantMatched: true, }, { name: "same_prefix", - host: hashStrs[samePrefixHost1Idx][:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix, + host: hashStrs[samePrefixHost1Idx][:hashprefix.PrefixEncLen] + suffix, wantHashStrs: []string{ hashStrs[samePrefixHost1Idx], hashStrs[samePrefixHost2Idx], @@ -70,11 +76,13 @@ func TestSafeBrowsingServer(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - srv := filter.NewSafeBrowsingServer(hashes, nil) + srv := hashprefix.NewMatcher(map[string]*hashprefix.Storage{ + suffix: hashes, + }) var gotHashStrs []string var matched bool - gotHashStrs, matched, err = srv.Hashes(ctx, tc.host) + gotHashStrs, matched, err = srv.MatchByPrefix(ctx, tc.host) require.NoError(t, err) assert.Equal(t, tc.wantMatched, matched) diff --git a/internal/filter/hashstorage/hashstorage.go b/internal/filter/hashprefix/storage.go similarity index 77% rename from internal/filter/hashstorage/hashstorage.go rename to internal/filter/hashprefix/storage.go index 29afda9..49a522f 100644 --- a/internal/filter/hashstorage/hashstorage.go +++ b/internal/filter/hashprefix/storage.go @@ -1,6 +1,4 @@ -// Package hashstorage defines a storage of hashes of domain names used for -// filtering. -package hashstorage +package hashprefix import ( "bufio" @@ -11,47 +9,19 @@ import ( "sync" ) -// Hash and hash part length constants. -const ( - // PrefixLen is the length of the prefix of the hash of the filtered - // hostname. - PrefixLen = 2 - - // PrefixEncLen is the encoded length of the hash prefix. Two text - // bytes per one binary byte. - PrefixEncLen = PrefixLen * 2 - - // hashLen is the length of the whole hash of the checked hostname. - hashLen = sha256.Size - - // suffixLen is the length of the suffix of the hash of the filtered - // hostname. - suffixLen = hashLen - PrefixLen - - // hashEncLen is the encoded length of the hash. Two text bytes per one - // binary byte. - hashEncLen = hashLen * 2 -) - -// Prefix is the type of the 2-byte prefix of a full 32-byte SHA256 hash of a -// host being checked. -type Prefix [PrefixLen]byte - -// suffix is the type of the 30-byte suffix of a full 32-byte SHA256 hash of a -// host being checked. -type suffix [suffixLen]byte - // Storage stores hashes of the filtered hostnames. All methods are safe for // concurrent use. +// +// TODO(a.garipov): See if we could unexport this. type Storage struct { // mu protects hashSuffixes. mu *sync.RWMutex hashSuffixes map[Prefix][]suffix } -// New returns a new hash storage containing hashes of the domain names listed -// in hostnames, one domain name per line. -func New(hostnames string) (s *Storage, err error) { +// NewStorage returns a new hash storage containing hashes of the domain names +// listed in hostnames, one domain name per line. +func NewStorage(hostnames string) (s *Storage, err error) { s = &Storage{ mu: &sync.RWMutex{}, hashSuffixes: map[Prefix][]suffix{}, diff --git a/internal/filter/hashstorage/hashstorage_test.go b/internal/filter/hashprefix/storage_test.go similarity index 71% rename from internal/filter/hashstorage/hashstorage_test.go rename to internal/filter/hashprefix/storage_test.go index 944d304..98f6a94 100644 --- a/internal/filter/hashstorage/hashstorage_test.go +++ b/internal/filter/hashprefix/storage_test.go @@ -1,4 +1,4 @@ -package hashstorage_test +package hashprefix_test import ( "crypto/sha256" @@ -8,61 +8,55 @@ import ( "strings" "testing" - "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// Common hostnames for tests. -const ( - testHost = "porn.example" - otherHost = "otherporn.example" -) - func TestStorage_Hashes(t *testing.T) { - s, err := hashstorage.New(testHost) + s, err := hashprefix.NewStorage(testHost) require.NoError(t, err) h := sha256.Sum256([]byte(testHost)) want := []string{hex.EncodeToString(h[:])} - p := hashstorage.Prefix{h[0], h[1]} - got := s.Hashes([]hashstorage.Prefix{p}) + p := hashprefix.Prefix{h[0], h[1]} + got := s.Hashes([]hashprefix.Prefix{p}) assert.Equal(t, want, got) - wrong := s.Hashes([]hashstorage.Prefix{{}}) + wrong := s.Hashes([]hashprefix.Prefix{{}}) assert.Empty(t, wrong) } func TestStorage_Matches(t *testing.T) { - s, err := hashstorage.New(testHost) + s, err := hashprefix.NewStorage(testHost) require.NoError(t, err) got := s.Matches(testHost) assert.True(t, got) - got = s.Matches(otherHost) + got = s.Matches(testOtherHost) assert.False(t, got) } func TestStorage_Reset(t *testing.T) { - s, err := hashstorage.New(testHost) + s, err := hashprefix.NewStorage(testHost) require.NoError(t, err) - n, err := s.Reset(otherHost) + n, err := s.Reset(testOtherHost) require.NoError(t, err) assert.Equal(t, 1, n) - h := sha256.Sum256([]byte(otherHost)) + h := sha256.Sum256([]byte(testOtherHost)) want := []string{hex.EncodeToString(h[:])} - p := hashstorage.Prefix{h[0], h[1]} - got := s.Hashes([]hashstorage.Prefix{p}) + p := hashprefix.Prefix{h[0], h[1]} + got := s.Hashes([]hashprefix.Prefix{p}) assert.Equal(t, want, got) prevHash := sha256.Sum256([]byte(testHost)) - prev := s.Hashes([]hashstorage.Prefix{{prevHash[0], prevHash[1]}}) + prev := s.Hashes([]hashprefix.Prefix{{prevHash[0], prevHash[1]}}) assert.Empty(t, prev) } @@ -80,12 +74,12 @@ func BenchmarkStorage_Hashes(b *testing.B) { hosts = append(hosts, fmt.Sprintf("%d."+testHost, i)) } - s, err := hashstorage.New(strings.Join(hosts, "\n")) + s, err := hashprefix.NewStorage(strings.Join(hosts, "\n")) require.NoError(b, err) - var hashPrefixes []hashstorage.Prefix + var hashPrefixes []hashprefix.Prefix for i := 0; i < 4; i++ { - hashPrefixes = append(hashPrefixes, hashstorage.Prefix{hosts[i][0], hosts[i][1]}) + hashPrefixes = append(hashPrefixes, hashprefix.Prefix{hosts[i][0], hosts[i][1]}) } for n := 1; n <= 4; n++ { @@ -104,7 +98,7 @@ func BenchmarkStorage_Hashes(b *testing.B) { // // goos: linux // goarch: amd64 - // pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage + // pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics // BenchmarkStorage_Hashes/1-16 29928834 41.76 ns/op 0 B/op 0 allocs/op // BenchmarkStorage_Hashes/2-16 18693033 63.80 ns/op 0 B/op 0 allocs/op @@ -121,7 +115,7 @@ func BenchmarkStorage_ResetHosts(b *testing.B) { } hostnames := strings.Join(hosts, "\n") - s, err := hashstorage.New(hostnames) + s, err := hashprefix.NewStorage(hostnames) require.NoError(b, err) b.ReportAllocs() @@ -136,7 +130,7 @@ func BenchmarkStorage_ResetHosts(b *testing.B) { // // goos: linux // goarch: amd64 - // pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage + // pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics // BenchmarkStorage_ResetHosts-16 2212 469343 ns/op 36224 B/op 1002 allocs/op } diff --git a/internal/filter/hashprefix_test.go b/internal/filter/hashprefix_test.go deleted file mode 100644 index fac4d93..0000000 --- a/internal/filter/hashprefix_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package filter_test - -import ( - "context" - "net" - "os" - "path/filepath" - "testing" - "time" - - "github.com/AdguardTeam/AdGuardDNS/internal/agd" - "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" - "github.com/AdguardTeam/AdGuardDNS/internal/filter" - "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage" - "github.com/AdguardTeam/golibs/netutil" - "github.com/AdguardTeam/golibs/testutil" - "github.com/miekg/dns" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) { - cacheDir := t.TempDir() - cachePath := filepath.Join(cacheDir, string(agd.FilterListIDSafeBrowsing)) - err := os.WriteFile(cachePath, []byte(safeBrowsingHost+"\n"), 0o644) - require.NoError(t, err) - - hashes, err := hashstorage.New("") - require.NoError(t, err) - - errColl := &agdtest.ErrorCollector{ - OnCollect: func(_ context.Context, err error) { - panic("not implemented") - }, - } - - resolver := &agdtest.Resolver{ - OnLookupIP: func( - _ context.Context, - _ netutil.AddrFamily, - _ string, - ) (ips []net.IP, err error) { - return []net.IP{safeBrowsingSafeIP4}, nil - }, - } - - c := prepareConf(t) - - c.SafeBrowsing, err = filter.NewHashPrefix(&filter.HashPrefixConfig{ - Hashes: hashes, - ErrColl: errColl, - Resolver: resolver, - ID: agd.FilterListIDSafeBrowsing, - CachePath: cachePath, - ReplacementHost: safeBrowsingSafeHost, - Staleness: 1 * time.Hour, - CacheTTL: 10 * time.Second, - CacheSize: 100, - }) - require.NoError(t, err) - - c.ErrColl = errColl - c.Resolver = resolver - - s, err := filter.NewDefaultStorage(c) - require.NoError(t, err) - - g := &agd.FilteringGroup{ - ID: "default", - RuleListIDs: []agd.FilterListID{}, - ParentalEnabled: true, - SafeBrowsingEnabled: true, - } - - // Test - - req := &dns.Msg{ - Question: []dns.Question{{ - Name: safeBrowsingSubFQDN, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }}, - } - - ri := newReqInfo(g, nil, safeBrowsingSubHost, clientIP, dns.TypeA) - ctx := agd.ContextWithRequestInfo(context.Background(), ri) - - f := s.FilterFromContext(ctx, ri) - require.NotNil(t, f) - testutil.CleanupAndRequireSuccess(t, f.Close) - - var r filter.Result - r, err = f.FilterRequest(ctx, req, ri) - require.NoError(t, err) - - rm := testutil.RequireTypeAssert[*filter.ResultModified](t, r) - - assert.Equal(t, rm.Rule, agd.FilterRuleText(safeBrowsingHost)) - assert.Equal(t, rm.List, agd.FilterListIDSafeBrowsing) -} diff --git a/internal/filter/internal/composite/composite.go b/internal/filter/internal/composite/composite.go new file mode 100644 index 0000000..e480c6f --- /dev/null +++ b/internal/filter/internal/composite/composite.go @@ -0,0 +1,369 @@ +// Package composite implements a composite filter based on several types of +// filters and the logic of the filter application. +package composite + +import ( + "context" + "fmt" + "strings" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/rulelist" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/safesearch" + "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/urlfilter/rules" + "github.com/miekg/dns" +) + +// Filter is a composite filter based on several types of safe-search and +// rule-list filters. +// +// An empty composite filter is a filter that always returns a nil filtering +// result. +type Filter struct { + safeBrowsing *hashprefix.Filter + adultBlocking *hashprefix.Filter + + genSafeSearch *safesearch.Filter + ytSafeSearch *safesearch.Filter + + // custom is the custom rule-list filter of the profile, if any. + custom *rulelist.Immutable + + // ruleLists are the enabled rule-list filters of the profile or filtering + // group. + ruleLists []*rulelist.Refreshable + + // svcLists are the rule-list filters of the profile's enabled blocked + // services, if any. + svcLists []*rulelist.Immutable +} + +// Config is the configuration structure for the composite filter. +type Config struct { + // SafeBrowsing is the safe-browsing filter to apply, if any. + SafeBrowsing *hashprefix.Filter + + // AdultBlocking is the adult-content filter to apply, if any. + AdultBlocking *hashprefix.Filter + + // GeneralSafeSearch is the general safe-search filter to apply, if any. + GeneralSafeSearch *safesearch.Filter + + // YouTubeSafeSearch is the youtube safe-search filter to apply, if any. + YouTubeSafeSearch *safesearch.Filter + + // Custom is the custom rule-list filter of the profile, if any. + Custom *rulelist.Immutable + + // RuleLists are the enabled rule-list filters of the profile or filtering + // group. + RuleLists []*rulelist.Refreshable + + // ServiceLists are the rule-list filters of the profile's enabled blocked + // services, if any. + ServiceLists []*rulelist.Immutable +} + +// New returns a new composite filter. If c is nil or empty, f returns a filter +// that always returns a nil filtering result. +func New(c *Config) (f *Filter) { + if c == nil { + return &Filter{} + } + + return &Filter{ + safeBrowsing: c.SafeBrowsing, + adultBlocking: c.AdultBlocking, + genSafeSearch: c.GeneralSafeSearch, + ytSafeSearch: c.YouTubeSafeSearch, + custom: c.Custom, + ruleLists: c.RuleLists, + svcLists: c.ServiceLists, + } +} + +// type check +var _ internal.Interface = (*Filter)(nil) + +// FilterRequest implements the [internal.Interface] interface for *Filter. If +// there is a safe-search result, it returns it. Otherwise, it returns the +// action created from the filter list network rule with the highest priority. +// If f is empty, it returns nil with no error. +func (f *Filter) FilterRequest( + ctx context.Context, + req *dns.Msg, + ri *agd.RequestInfo, +) (r internal.Result, err error) { + if f.isEmpty() { + return nil, nil + } + + // Prepare common data for filters. + reqID := ri.ID + log.Debug("filters: filtering req %s: %d rule lists", reqID, len(f.ruleLists)) + + // Firstly, check the profile's rule-list filtering, the custom rules, and + // the rules from blocked services settings. + host := ri.Host + rlRes := f.filterWithRuleLists(ri, host, ri.QType, req, false) + switch flRes := rlRes.(type) { + case *internal.ResultAllowed: + // Skip any additional filtering if the domain is explicitly allowed by + // user's custom rule. + if flRes.List == agd.FilterListIDCustom { + return flRes, nil + } + case *internal.ResultBlocked: + // Skip any additional filtering if the domain is already blocked. + return flRes, nil + default: + // Go on. + } + + // Secondly, apply the safe browsing and safe search request filters in the + // following order. + // + // DO NOT change the order of reqFilters without necessity. + reqFilters := []struct { + filter internal.RequestFilter + id agd.FilterListID + }{{ + filter: nullify(f.safeBrowsing), + id: agd.FilterListIDSafeBrowsing, + }, { + filter: nullify(f.adultBlocking), + id: agd.FilterListIDAdultBlocking, + }, { + filter: nullify(f.genSafeSearch), + id: agd.FilterListIDGeneralSafeSearch, + }, { + filter: nullify(f.ytSafeSearch), + id: agd.FilterListIDYoutubeSafeSearch, + }} + + for _, rf := range reqFilters { + if rf.filter == nil { + continue + } + + log.Debug("filter %s: filtering req %s", rf.id, reqID) + r, err = rf.filter.FilterRequest(ctx, req, ri) + log.Debug("filter %s: finished filtering req %s, errors: %v", rf.id, reqID, err) + if err != nil { + return nil, err + } else if r != nil { + return r, nil + } + } + + // Thirdly, return the previously obtained filter list result. + return rlRes, nil +} + +// nullify returns a nil interface value if flt is a nil pointer. Otherwise, it +// returns flt converted to the interface type. It is used to avoid situations +// where an interface value doesn't have any data but does have a type. +func nullify[T *safesearch.Filter | *hashprefix.Filter](flt T) (fr internal.RequestFilter) { + if flt == nil { + return nil + } + + return internal.RequestFilter(flt) +} + +// FilterResponse implements the [internal.Interface] interface for *Filter. It +// returns the action created from the filter list network rule with the highest +// priority. If f is empty, it returns nil with no error. +func (f *Filter) FilterResponse( + ctx context.Context, + resp *dns.Msg, + ri *agd.RequestInfo, +) (r internal.Result, err error) { + if f.isEmpty() { + return nil, nil + } + + for _, ans := range resp.Answer { + host, rrType, ok := parseRespAnswer(ans) + if !ok { + continue + } + + r = f.filterWithRuleLists(ri, host, rrType, resp, true) + if r != nil { + break + } + } + + return r, nil +} + +// parseRespAnswer parses hostname and rrType from the answer if there are any. +// If ans is of a type that doesn't have an IP address or a hostname in it, ok +// is false. +func parseRespAnswer(ans dns.RR) (hostname string, rrType dnsmsg.RRType, ok bool) { + switch ans := ans.(type) { + case *dns.A: + return ans.A.String(), dns.TypeA, true + case *dns.AAAA: + return ans.AAAA.String(), dns.TypeAAAA, true + case *dns.CNAME: + return strings.TrimSuffix(ans.Target, "."), dns.TypeCNAME, true + default: + return "", dns.TypeNone, false + } +} + +// isEmpty returns true if this composite filter is an empty filter. +func (f *Filter) isEmpty() (ok bool) { + return f == nil || + (f.safeBrowsing == nil && + f.adultBlocking == nil && + f.genSafeSearch == nil && + f.ytSafeSearch == nil && + f.custom == nil && + len(f.ruleLists) == 0 && + len(f.svcLists) == 0) +} + +// filterWithRuleLists filters one question's or answer's information through +// all rule list filters of the composite filter. +func (f *Filter) filterWithRuleLists( + ri *agd.RequestInfo, + host string, + rrType dnsmsg.RRType, + msg *dns.Msg, + isAnswer bool, +) (r internal.Result) { + var devName string + if d := ri.Device; d != nil { + devName = string(d.Name) + } + + ufRes := &urlFilterResult{} + for _, rl := range f.ruleLists { + ufRes.add(rl.DNSResult(ri.RemoteIP, devName, host, rrType, isAnswer)) + } + + if f.custom != nil { + dr := f.custom.DNSResult(ri.RemoteIP, devName, host, rrType, isAnswer) + // Collect only custom $dnsrewrite rules. It's much easier to process + // dnsrewrite rules only from one list, cause when there is no problem + // with merging them among different lists. + if !isAnswer { + modified := processDNSRewrites(ri.Messages, msg, dr.DNSRewrites(), host) + if modified != nil { + return modified + } + } + + ufRes.add(dr) + } + + for _, rl := range f.svcLists { + ufRes.add(rl.DNSResult(ri.RemoteIP, devName, host, rrType, isAnswer)) + } + + mr := rules.NewMatchingResult(ufRes.networkRules, nil) + if nr := mr.GetBasicResult(); nr != nil { + return f.ruleDataToResult(nr.FilterListID, nr.RuleText, nr.Whitelist) + } + + return f.hostsRulesToResult(ufRes.hostRules4, ufRes.hostRules6, rrType) +} + +// mustRuleListDataByURLFilterID returns the rule list data by its synthetic +// integer ID in the urlfilter engine. It panics if id is not found. +func (f *Filter) mustRuleListDataByURLFilterID( + id int, +) (fltID agd.FilterListID, svcID agd.BlockedServiceID) { + for _, rl := range f.ruleLists { + if rl.URLFilterID() == id { + return rl.ID() + } + } + + if rl := f.custom; rl != nil && rl.URLFilterID() == id { + return rl.ID() + } + + for _, rl := range f.svcLists { + if rl.URLFilterID() == id { + return rl.ID() + } + } + + // Technically shouldn't happen, since id is supposed to be among the rule + // list filters in the composite filter. + panic(fmt.Errorf("filter: synthetic id %d not found", id)) +} + +// hostsRulesToResult converts /etc/hosts-style rules into a filtering action. +func (f *Filter) hostsRulesToResult( + hostRules4 []*rules.HostRule, + hostRules6 []*rules.HostRule, + rrType dnsmsg.RRType, +) (r internal.Result) { + if len(hostRules4) == 0 && len(hostRules6) == 0 { + return nil + } + + // Only use the first matched rule, since we currently don't care about the + // IP addresses in the rule. If the request is neither an A one nor an AAAA + // one, or if there are no matching rules of the requested type, then use + // whatever rule isn't empty. + // + // See also AGDNS-591. + var resHostRule *rules.HostRule + if rrType == dns.TypeA && len(hostRules4) > 0 { + resHostRule = hostRules4[0] + } else if rrType == dns.TypeAAAA && len(hostRules6) > 0 { + resHostRule = hostRules6[0] + } else { + if len(hostRules4) > 0 { + resHostRule = hostRules4[0] + } else { + resHostRule = hostRules6[0] + } + } + + return f.ruleDataToResult(resHostRule.FilterListID, resHostRule.RuleText, false) +} + +// ruleDataToResult converts a urlfilter rule data into a filtering result. +func (f *Filter) ruleDataToResult( + urlFilterID int, + ruleText string, + allowlist bool, +) (r internal.Result) { + // Use the urlFilterID crutch to find the actual IDs of the filtering rule + // list and blocked service. + fltID, svcID := f.mustRuleListDataByURLFilterID(urlFilterID) + + var rule agd.FilterRuleText + if fltID == agd.FilterListIDBlockedService { + rule = agd.FilterRuleText(svcID) + } else { + rule = agd.FilterRuleText(ruleText) + } + + if allowlist { + log.Debug("rule list %s: allowed by rule %s", fltID, rule) + + return &internal.ResultAllowed{ + List: fltID, + Rule: rule, + } + } + + log.Debug("rule list %s: blocked by rule %s", fltID, rule) + + return &internal.ResultBlocked{ + List: fltID, + Rule: rule, + } +} diff --git a/internal/filter/internal/composite/composite_test.go b/internal/filter/internal/composite/composite_test.go new file mode 100644 index 0000000..8e517d4 --- /dev/null +++ b/internal/filter/internal/composite/composite_test.go @@ -0,0 +1,521 @@ +package composite_test + +import ( + "context" + "net" + "net/http" + "net/netip" + "path/filepath" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/composite" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/filtertest" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/rulelist" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/safesearch" + "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/testutil" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMain(m *testing.M) { + testutil.DiscardLogOutput(m) +} + +// Common filter list IDs for tests. +const ( + testFltListID1 agd.FilterListID = "fl1" + testFltListID2 agd.FilterListID = "fl2" +) + +// newFromStr is a helper to create a rule-list filter from a rule text and a +// filtering-list ID. +func newFromStr(tb testing.TB, text string, id agd.FilterListID) (rl *rulelist.Refreshable) { + tb.Helper() + + rl, err := rulelist.NewFromString(text, id, "", 0, false) + require.NoError(tb, err) + + return rl +} + +// newImmutable is a helper to create an immutable rule-list filter from a rule +// text and a filtering-list ID. +func newImmutable(tb testing.TB, text string, id agd.FilterListID) (rl *rulelist.Immutable) { + tb.Helper() + + rl, err := rulelist.NewImmutable(text, id, "", 0, false) + require.NoError(tb, err) + + return rl +} + +// newReqData returns data for calling FilterRequest. The context uses +// [filtertest.Timeout] and [tb.Cleanup] is used for its cancellation. Both req +// and ri use [filtertest.ReqFQDN], [dns.TypeA], and [dns.ClassINET] for the +// request data. +func newReqData(tb testing.TB) (ctx context.Context, req *dns.Msg, ri *agd.RequestInfo) { + ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout) + tb.Cleanup(cancel) + + req = dnsservertest.NewReq(filtertest.ReqFQDN, dns.TypeA, dns.ClassINET) + ri = &agd.RequestInfo{ + Messages: agdtest.NewConstructor(), + Host: filtertest.ReqHost, + QType: dns.TypeA, + QClass: dns.ClassINET, + } + + return ctx, req, ri +} + +func TestFilter_nil(t *testing.T) { + testCases := []struct { + flt *composite.Filter + name string + }{{ + flt: nil, + name: "nil", + }, { + flt: composite.New(nil), + name: "config_nil", + }, { + flt: composite.New(&composite.Config{}), + name: "config_empty", + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx, req, ri := newReqData(t) + res, err := tc.flt.FilterRequest(ctx, req, ri) + assert.NoError(t, err) + assert.Nil(t, res) + + resp := dnsservertest.NewResp(dns.RcodeSuccess, req) + res, err = tc.flt.FilterResponse(ctx, resp, ri) + assert.NoError(t, err) + assert.Nil(t, res) + }) + } +} + +func TestFilter_FilterRequest_badfilter(t *testing.T) { + const ( + blockRule = filtertest.BlockRule + badFilterRule = filtertest.BlockRule + "$badfilter" + ) + + rl1 := newFromStr(t, blockRule, testFltListID1) + rl2 := newFromStr(t, badFilterRule, testFltListID2) + + testCases := []struct { + name string + wantRes internal.Result + ruleLists []*rulelist.Refreshable + }{{ + name: "block", + wantRes: &internal.ResultBlocked{ + List: testFltListID1, + Rule: blockRule, + }, + ruleLists: []*rulelist.Refreshable{rl1}, + }, { + name: "badfilter_no_block", + wantRes: nil, + ruleLists: []*rulelist.Refreshable{rl2}, + }, { + name: "badfilter_removes_block", + wantRes: nil, + ruleLists: []*rulelist.Refreshable{rl1, rl2}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + f := composite.New(&composite.Config{ + RuleLists: tc.ruleLists, + }) + + ctx, req, ri := newReqData(t) + res, err := f.FilterRequest(ctx, req, ri) + require.NoError(t, err) + + assert.Equal(t, tc.wantRes, res) + }) + } +} + +func TestFilter_FilterRequest_customAllow(t *testing.T) { + const allowRule = "@@" + filtertest.BlockRule + + blockingRL := newFromStr(t, filtertest.BlockRule, testFltListID1) + customRL := newImmutable(t, allowRule, agd.FilterListIDCustom) + + f := composite.New(&composite.Config{ + Custom: customRL, + RuleLists: []*rulelist.Refreshable{blockingRL}, + }) + + ctx, req, ri := newReqData(t) + res, err := f.FilterRequest(ctx, req, ri) + require.NoError(t, err) + + want := &internal.ResultAllowed{ + List: agd.FilterListIDCustom, + Rule: allowRule, + } + assert.Equal(t, want, res) +} + +func TestFilter_FilterRequest_dnsrewrite(t *testing.T) { + const ( + blockRule = filtertest.BlockRule + dnsRewriteRuleRefused = filtertest.BlockRule + "$dnsrewrite=REFUSED" + dnsRewriteRuleCname = filtertest.BlockRule + "$dnsrewrite=new-cname.example" + dnsRewrite2Rules = filtertest.BlockRule + "$dnsrewrite=1.2.3.4\n" + + filtertest.BlockRule + "$dnsrewrite=1.2.3.5" + dnsRewriteRuleTXT = filtertest.BlockRule + "$dnsrewrite=NOERROR;TXT;abcdefg" + ) + + var ( + rlNonRewrite = newFromStr(t, blockRule, testFltListID1) + rlRewriteIgnored = newFromStr(t, dnsRewriteRuleRefused, testFltListID2) + rlCustomRefused = newImmutable(t, dnsRewriteRuleRefused, agd.FilterListIDCustom) + rlCustomCname = newImmutable(t, dnsRewriteRuleCname, agd.FilterListIDCustom) + rlCustom2Rules = newImmutable(t, dnsRewrite2Rules, agd.FilterListIDCustom) + ) + + req := dnsservertest.NewReq(filtertest.ReqFQDN, dns.TypeA, dns.ClassINET) + + // Create a CNAME-modified request. + modifiedReq := dnsmsg.Clone(req) + modifiedReq.Question[0].Name = "new-cname.example." + + testCases := []struct { + custom *rulelist.Immutable + wantRes internal.Result + name string + ruleLists []*rulelist.Refreshable + }{{ + custom: nil, + wantRes: &internal.ResultBlocked{List: testFltListID1, Rule: blockRule}, + name: "block", + ruleLists: []*rulelist.Refreshable{rlNonRewrite}, + }, { + custom: nil, + wantRes: &internal.ResultBlocked{List: testFltListID1, Rule: blockRule}, + name: "dnsrewrite_no_effect", + ruleLists: []*rulelist.Refreshable{rlNonRewrite, rlRewriteIgnored}, + }, { + custom: rlCustomRefused, + wantRes: &internal.ResultModified{ + Msg: dnsservertest.NewResp(dns.RcodeRefused, req), + List: agd.FilterListIDCustom, + Rule: dnsRewriteRuleRefused, + }, + name: "dnsrewrite_block", + ruleLists: []*rulelist.Refreshable{rlNonRewrite, rlRewriteIgnored}, + }, { + custom: rlCustomCname, + wantRes: &internal.ResultModified{ + Msg: modifiedReq, + List: agd.FilterListIDCustom, + Rule: dnsRewriteRuleCname, + }, + name: "dnsrewrite_cname", + ruleLists: []*rulelist.Refreshable{rlNonRewrite, rlRewriteIgnored}, + }, { + custom: rlCustom2Rules, + wantRes: &internal.ResultModified{ + Msg: dnsservertest.NewResp(dns.RcodeSuccess, req, dnsservertest.SectionAnswer{ + dnsservertest.NewA( + filtertest.ReqFQDN, + agdtest.FilteredResponseTTLSec, + net.IP{1, 2, 3, 4}, + ), + dnsservertest.NewA( + filtertest.ReqFQDN, + agdtest.FilteredResponseTTLSec, + net.IP{1, 2, 3, 5}, + ), + }), + List: agd.FilterListIDCustom, + Rule: "", + }, + name: "dnsrewrite_answers", + ruleLists: []*rulelist.Refreshable{rlNonRewrite, rlRewriteIgnored}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + f := composite.New(&composite.Config{ + Custom: tc.custom, + RuleLists: tc.ruleLists, + }) + + ctx := context.Background() + ri := &agd.RequestInfo{ + Messages: agdtest.NewConstructor(), + Host: filtertest.ReqHost, + QType: dns.TypeA, + } + + res, fltErr := f.FilterRequest(ctx, req, ri) + require.NoError(t, fltErr) + + assert.Equal(t, tc.wantRes, res) + }) + } +} + +func TestFilter_FilterRequest_hostsRules(t *testing.T) { + const ( + reqHost4 = "www.example.com" + reqHost6 = "www.example.net" + ) + + const ( + blockRule4 = "127.0.0.1 www.example.com" + blockRule6 = "::1 www.example.net" + rules = blockRule4 + "\n" + blockRule6 + ) + + rl := newFromStr(t, rules, testFltListID1) + f := composite.New(&composite.Config{ + RuleLists: []*rulelist.Refreshable{rl}, + }) + + resBlocked4 := &internal.ResultBlocked{ + List: testFltListID1, + Rule: blockRule4, + } + + resBlocked6 := &internal.ResultBlocked{ + List: testFltListID1, + Rule: blockRule6, + } + + testCases := []struct { + wantRes internal.Result + name string + reqHost string + reqType dnsmsg.RRType + }{{ + wantRes: resBlocked4, + name: "a", + reqHost: reqHost4, + reqType: dns.TypeA, + }, { + wantRes: resBlocked6, + name: "aaaa", + reqHost: reqHost6, + reqType: dns.TypeAAAA, + }, { + wantRes: resBlocked6, + name: "a_with_ipv6_rule", + reqHost: reqHost6, + reqType: dns.TypeA, + }, { + wantRes: resBlocked4, + name: "aaaa_with_ipv4_rule", + reqHost: reqHost4, + reqType: dns.TypeAAAA, + }, { + wantRes: resBlocked4, + name: "mx", + reqHost: reqHost4, + reqType: dns.TypeMX, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ri := &agd.RequestInfo{ + Messages: agdtest.NewConstructor(), + Host: tc.reqHost, + QType: tc.reqType, + } + + req := &dns.Msg{ + Question: []dns.Question{{ + Name: dns.Fqdn(tc.reqHost), + Qtype: tc.reqType, + Qclass: dns.ClassINET, + }}, + } + + ctx := context.Background() + + res, rerr := f.FilterRequest(ctx, req, ri) + require.NoError(t, rerr) + + assert.Equal(t, tc.wantRes, res) + assert.Equal(t, tc.wantRes, res) + }) + } +} + +func TestFilter_FilterRequest_safeSearch(t *testing.T) { + const safeSearchIPStr = "1.2.3.4" + + const rewriteRule = filtertest.BlockRule + "$dnsrewrite=NOERROR;A;" + safeSearchIPStr + + var safeSearchIP net.IP = netip.MustParseAddr(safeSearchIPStr).AsSlice() + cachePath, srvURL := filtertest.PrepareRefreshable(t, nil, rewriteRule, http.StatusOK) + + const fltListID = agd.FilterListIDGeneralSafeSearch + + gen := safesearch.New(&safesearch.Config{ + List: &agd.FilterList{ + URL: srvURL, + ID: fltListID, + }, + Resolver: &agdtest.Resolver{ + OnLookupIP: func( + _ context.Context, + _ netutil.AddrFamily, + _ string, + ) (ips []net.IP, err error) { + return []net.IP{safeSearchIP}, nil + }, + }, + ErrColl: &agdtest.ErrorCollector{ + OnCollect: func(_ context.Context, _ error) { + panic("not implemented") + }, + }, + CacheDir: filepath.Dir(cachePath), + CacheTTL: 1 * time.Minute, + CacheSize: 100, + }) + + ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout) + t.Cleanup(cancel) + + err := gen.Refresh(ctx, false) + require.NoError(t, err) + + f := composite.New(&composite.Config{ + GeneralSafeSearch: gen, + }) + + ctx, req, ri := newReqData(t) + res, err := f.FilterRequest(ctx, req, ri) + require.NoError(t, err) + + wantResp := dnsservertest.NewResp(dns.RcodeSuccess, req, dnsservertest.SectionAnswer{ + dnsservertest.NewA(filtertest.ReqFQDN, agdtest.FilteredResponseTTLSec, safeSearchIP), + }) + want := &internal.ResultModified{ + Msg: wantResp, + List: fltListID, + Rule: filtertest.ReqHost, + } + assert.Equal(t, want, res) +} + +func TestFilter_FilterRequest_services(t *testing.T) { + const svcID = "test_service" + + svcRL, err := rulelist.NewImmutable( + filtertest.BlockRule, + agd.FilterListIDBlockedService, + svcID, + 0, + false, + ) + require.NoError(t, err) + + f := composite.New(&composite.Config{ + ServiceLists: []*rulelist.Immutable{svcRL}, + }) + + ctx, req, ri := newReqData(t) + res, err := f.FilterRequest(ctx, req, ri) + require.NoError(t, err) + + want := &internal.ResultBlocked{ + List: agd.FilterListIDBlockedService, + Rule: svcID, + } + assert.Equal(t, want, res) +} + +func TestFilter_FilterResponse(t *testing.T) { + const cnameReqFQDN = "sub." + filtertest.ReqFQDN + + const ( + blockedCNAME = filtertest.ReqHost + blockedIPv4Str = "1.2.3.4" + blockedIPv6Str = "1234::cdef" + blockRules = blockedCNAME + "\n" + blockedIPv4Str + "\n" + blockedIPv6Str + "\n" + ) + + var ( + blockedIPv4 net.IP = netip.MustParseAddr(blockedIPv4Str).AsSlice() + blockedIPv6 net.IP = netip.MustParseAddr(blockedIPv6Str).AsSlice() + ) + + blockingRL := newFromStr(t, blockRules, testFltListID1) + f := composite.New(&composite.Config{ + RuleLists: []*rulelist.Refreshable{blockingRL}, + }) + + const ttl = agdtest.FilteredResponseTTLSec + + testCases := []struct { + name string + reqFQDN string + wantRule agd.FilterRuleText + respAns dnsservertest.SectionAnswer + qType dnsmsg.RRType + }{{ + name: "cname", + reqFQDN: cnameReqFQDN, + wantRule: filtertest.ReqHost, + respAns: dnsservertest.SectionAnswer{ + dnsservertest.NewCNAME(cnameReqFQDN, ttl, filtertest.ReqFQDN), + dnsservertest.NewA(filtertest.ReqFQDN, ttl, net.IP{1, 2, 3, 4}), + }, + qType: dns.TypeA, + }, { + name: "ipv4", + reqFQDN: filtertest.ReqFQDN, + wantRule: blockedIPv4Str, + respAns: dnsservertest.SectionAnswer{ + dnsservertest.NewA(filtertest.ReqFQDN, ttl, blockedIPv4), + }, + qType: dns.TypeA, + }, { + name: "ipv6", + reqFQDN: filtertest.ReqFQDN, + wantRule: blockedIPv6Str, + respAns: dnsservertest.SectionAnswer{ + dnsservertest.NewAAAA(filtertest.ReqFQDN, ttl, blockedIPv6), + }, + qType: dns.TypeAAAA, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx, req, ri := newReqData(t) + req.Question[0].Name = tc.reqFQDN + req.Question[0].Qtype = tc.qType + + resp := dnsservertest.NewResp(dns.RcodeSuccess, req, tc.respAns) + res, err := f.FilterResponse(ctx, resp, ri) + require.NoError(t, err) + + want := &internal.ResultBlocked{ + List: testFltListID1, + Rule: tc.wantRule, + } + assert.Equal(t, want, res) + }) + } +} diff --git a/internal/filter/internal/composite/dnsresult.go b/internal/filter/internal/composite/dnsresult.go new file mode 100644 index 0000000..7d4cd27 --- /dev/null +++ b/internal/filter/internal/composite/dnsresult.go @@ -0,0 +1,27 @@ +package composite + +import ( + "github.com/AdguardTeam/urlfilter" + "github.com/AdguardTeam/urlfilter/rules" +) + +// urlFilterResult is an entity simplifying the collection and compilation of +// urlfilter results. +// +// TODO(a.garipov): Think of ways to move all urlfilter result processing to +// ./internal/rulelist. +type urlFilterResult struct { + networkRules []*rules.NetworkRule + hostRules4 []*rules.HostRule + hostRules6 []*rules.HostRule +} + +// add appends the rules from dr to the slices within r. If dr is nil, add does +// nothing. +func (r *urlFilterResult) add(dr *urlfilter.DNSResult) { + if dr != nil { + r.networkRules = append(r.networkRules, dr.NetworkRules...) + r.hostRules4 = append(r.hostRules4, dr.HostRulesV4...) + r.hostRules6 = append(r.hostRules6, dr.HostRulesV6...) + } +} diff --git a/internal/filter/dnsrewrite.go b/internal/filter/internal/composite/dnsrewrite.go similarity index 84% rename from internal/filter/dnsrewrite.go rename to internal/filter/internal/composite/dnsrewrite.go index df5126c..8b30f45 100644 --- a/internal/filter/dnsrewrite.go +++ b/internal/filter/internal/composite/dnsrewrite.go @@ -1,11 +1,13 @@ -package filter +package composite import ( "fmt" "net" + "strings" "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/urlfilter/rules" @@ -13,13 +15,13 @@ import ( ) // processDNSRewrites processes $dnsrewrite rules dnsr and creates a filtering -// result, if necessary. +// result, if necessary. res.List, if any, is set to [agd.FilterListIDCustom]. func processDNSRewrites( messages *dnsmsg.Constructor, req *dns.Msg, dnsr []*rules.NetworkRule, host string, -) (res *ResultModified) { +) (res *internal.ResultModified) { if len(dnsr) == 0 { return nil } @@ -27,16 +29,18 @@ func processDNSRewrites( dnsRewriteResult := processDNSRewriteRules(dnsr) if resCanonName := dnsRewriteResult.CanonName; resCanonName != "" { - if resCanonName == host { + // Rewrite the question name to a matched CNAME. + if strings.EqualFold(resCanonName, host) { // A rewrite of a host to itself. return nil } - resp := messages.NewRespMsg(req) - resp.Answer = append(resp.Answer, messages.NewAnswerCNAME(req, resCanonName)) + req = dnsmsg.Clone(req) + req.Question[0].Name = dns.Fqdn(resCanonName) - return &ResultModified{ - Msg: resp, + return &internal.ResultModified{ + Msg: req, + List: agd.FilterListIDCustom, Rule: dnsRewriteResult.ResRuleText, } } @@ -45,8 +49,9 @@ func processDNSRewrites( resp := messages.NewRespMsg(req) resp.Rcode = dnsRewriteResult.RCode - return &ResultModified{ + return &internal.ResultModified{ Msg: resp, + List: agd.FilterListIDCustom, Rule: dnsRewriteResult.ResRuleText, } } @@ -56,36 +61,37 @@ func processDNSRewrites( return nil } - return &ResultModified{ - Msg: resp, + return &internal.ResultModified{ + Msg: resp, + List: agd.FilterListIDCustom, } } -// DNSRewriteResult is the result of application of $dnsrewrite rules. -type DNSRewriteResult struct { - Response DNSRewriteResultResponse +// dnsRewriteResult is the result of application of $dnsrewrite rules. +type dnsRewriteResult struct { + Response dnsRewriteResultResponse CanonName string ResRuleText agd.FilterRuleText Rules []*rules.NetworkRule RCode rules.RCode } -// DNSRewriteResultResponse is the collection of DNS response records +// dnsRewriteResultResponse is the collection of DNS response records // the server returns. -type DNSRewriteResultResponse map[rules.RRType][]rules.RRValue +type dnsRewriteResultResponse map[rules.RRType][]rules.RRValue // processDNSRewriteRules processes DNS rewrite rules in dnsr. The result will // have either CanonName or RCode or Response set. -func processDNSRewriteRules(dnsr []*rules.NetworkRule) (res *DNSRewriteResult) { - dnsrr := &DNSRewriteResult{ - Response: DNSRewriteResultResponse{}, +func processDNSRewriteRules(dnsr []*rules.NetworkRule) (res *dnsRewriteResult) { + dnsrr := &dnsRewriteResult{ + Response: dnsRewriteResultResponse{}, } for _, rule := range dnsr { dr := rule.DNSRewrite if dr.NewCNAME != "" { // NewCNAME rules have a higher priority than other rules. - return &DNSRewriteResult{ + return &dnsRewriteResult{ ResRuleText: agd.FilterRuleText(rule.RuleText), CanonName: dr.NewCNAME, } @@ -99,7 +105,7 @@ func processDNSRewriteRules(dnsr []*rules.NetworkRule) (res *DNSRewriteResult) { default: // RcodeRefused and other such codes have higher priority. Return // immediately. - return &DNSRewriteResult{ + return &dnsRewriteResult{ ResRuleText: agd.FilterRuleText(rule.RuleText), RCode: dr.RCode, } @@ -114,7 +120,7 @@ func processDNSRewriteRules(dnsr []*rules.NetworkRule) (res *DNSRewriteResult) { func filterDNSRewrite( messages *dnsmsg.Constructor, req *dns.Msg, - dnsrr *DNSRewriteResult, + dnsrr *dnsRewriteResult, ) (resp *dns.Msg, err error) { if dnsrr.RCode != dns.RcodeSuccess { return nil, errors.Error("non-success answer") diff --git a/internal/filter/custom.go b/internal/filter/internal/custom/custom.go similarity index 61% rename from internal/filter/custom.go rename to internal/filter/internal/custom/custom.go index ac4ce24..eccb4e9 100644 --- a/internal/filter/custom.go +++ b/internal/filter/internal/custom/custom.go @@ -1,4 +1,6 @@ -package filter +// Package custom contains the caching storage of filters made from custom +// filtering rules of profiles. +package custom import ( "context" @@ -7,51 +9,50 @@ import ( "time" "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/rulelist" "github.com/AdguardTeam/AdGuardDNS/internal/metrics" - "github.com/AdguardTeam/AdGuardDNS/internal/optlog" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/stringutil" "github.com/bluele/gcache" ) -// Custom Filters For Profiles - -// customFilters contains custom filters made from custom filtering rules of -// profiles. -type customFilters struct { +// Filters contains custom filters made from custom filtering rules of profiles. +type Filters struct { cache gcache.Cache errColl agd.ErrorCollector } -// appendRuleLists appends the custom rule list filter made from the profile's -// custom rules list, if any, to rls. -func (f *customFilters) appendRuleLists( - ctx context.Context, - rls []*ruleListFilter, - p *agd.Profile, -) (res []*ruleListFilter) { +// New returns a new custom filter storage. +func New(cache gcache.Cache, errColl agd.ErrorCollector) (f *Filters) { + return &Filters{ + cache: cache, + errColl: errColl, + } +} + +// Get returns the custom rule-list filter made from the profile's custom rules +// list, if any. +func (f *Filters) Get(ctx context.Context, p *agd.Profile) (rl *rulelist.Immutable) { if len(p.CustomRules) == 0 { // Technically, there could be an old filter left in the cache, but it // will eventually be evicted, so don't do anything about it. - return rls + return nil } - optlog.Debug2("%s: compiling custom filter for profile %s", strgLogPrefix, p.ID) - defer optlog.Debug2("%s: finished compiling custom filter for profile %s", strgLogPrefix, p.ID) - // Report the custom filters cache lookup to prometheus so that we could // keep track of whether the cache size is enough. defer func() { - if rls == nil { + if rl == nil { metrics.FilterCustomCacheLookupsMisses.Inc() } else { metrics.FilterCustomCacheLookupsHits.Inc() } }() - rl := f.get(p) + rl = f.get(p) if rl != nil { - return append(rls, rl) + return rl } // TODO(a.garipov): Consider making a copy of strings.Join for @@ -68,8 +69,17 @@ func (f *customFilters) appendRuleLists( stringutil.WriteToBuilder(b, string(r), "\n") } - // Don't use cache for users' custom filters. - rl, err := newRuleListFltFromStr(b.String(), agd.FilterListIDCustom, string(p.ID), 0, false) + rl, err := rulelist.NewImmutable( + b.String(), + agd.FilterListIDCustom, + "", + // Don't use cache for users' custom filters, because resultcache + // doesn't take $client rules into account. + // + // TODO(a.garipov): Consider enabling caching if necessary. + 0, + false, + ) if err != nil { // In a rare situation where the custom rules are so badly formed that // we cannot even create a filtering engine, consider that there is no @@ -80,14 +90,16 @@ func (f *customFilters) appendRuleLists( return nil } + log.Info("%s/%s: got %d rules", agd.FilterListIDCustom, p.ID, rl.RulesCount()) + f.set(p, rl) - return append(rls, rl) + return rl } -// get returns the cached custom rule list filter, if there is one and the +// get returns the cached custom rule-list filter, if there is one and the // profile hasn't changed since the filter was cached. -func (f *customFilters) get(p *agd.Profile) (rl *ruleListFilter) { +func (f *Filters) get(p *agd.Profile) (rl *rulelist.Immutable) { itemVal, err := f.cache.Get(p.ID) if errors.Is(err, gcache.KeyNotFoundError) { return nil @@ -104,8 +116,8 @@ func (f *customFilters) get(p *agd.Profile) (rl *ruleListFilter) { return item.ruleList } -// set caches the custom rule list filter. -func (f *customFilters) set(p *agd.Profile, rl *ruleListFilter) { +// set caches the custom rule-list filter. +func (f *Filters) set(p *agd.Profile, rl *rulelist.Immutable) { item := &customFilterCacheItem{ updTime: p.UpdateTime, ruleList: rl, @@ -121,5 +133,5 @@ func (f *customFilters) set(p *agd.Profile, rl *ruleListFilter) { // customFilterCacheItem is an item of the custom filter cache. type customFilterCacheItem struct { updTime time.Time - ruleList *ruleListFilter + ruleList *rulelist.Immutable } diff --git a/internal/filter/internal/custom/custom_test.go b/internal/filter/internal/custom/custom_test.go new file mode 100644 index 0000000..136a5c6 --- /dev/null +++ b/internal/filter/internal/custom/custom_test.go @@ -0,0 +1,102 @@ +package custom_test + +import ( + "context" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/custom" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/rulelist" + "github.com/AdguardTeam/golibs/testutil" + "github.com/bluele/gcache" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMain(m *testing.M) { + testutil.DiscardLogOutput(m) +} + +// testProfID is the profile ID for tests. +const testProfID agd.ProfileID = "prof1234" + +func TestFilters_Get(t *testing.T) { + f := custom.New( + gcache.New(1).LRU().Build(), + &agdtest.ErrorCollector{ + OnCollect: func(ctx context.Context, err error) { panic("not implemented") }, + }, + ) + + p := &agd.Profile{ + ID: testProfID, + UpdateTime: time.Now(), + CustomRules: []agd.FilterRuleText{ + "||first.example", + }, + } + + ctx := context.Background() + + rl := f.Get(ctx, p) + require.NotNil(t, rl) + + // Recheck cached. + cachedRL := f.Get(ctx, p) + require.NotNil(t, cachedRL) + + assert.Same(t, rl, cachedRL) +} + +var ruleListSink *rulelist.Immutable + +func BenchmarkFilters_Get(b *testing.B) { + f := custom.New( + gcache.New(1).LRU().Build(), + &agdtest.ErrorCollector{ + OnCollect: func(ctx context.Context, err error) { panic("not implemented") }, + }, + ) + + p := &agd.Profile{ + ID: testProfID, + UpdateTime: time.Now(), + CustomRules: []agd.FilterRuleText{ + "||first.example", + "||second.example", + "||third.example", + }, + } + + ctx := context.Background() + + b.Run("cache", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ruleListSink = f.Get(ctx, p) + } + }) + + b.Run("no_cache", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Update the time on each iteration to make sure that the cache is + // never used. + p.UpdateTime = p.UpdateTime.Add(1 * time.Millisecond) + ruleListSink = f.Get(ctx, p) + } + }) + + // Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU: + // + // goos: linux + // goarch: amd64 + // pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/custom + // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics + // BenchmarkFilters_Get/cache-16 7870251 233.4 ns/op 16 B/op 1 allocs/op + // BenchmarkFilters_Get/no_cache-16 53073 23490 ns/op 14610 B/op 93 allocs/op +} diff --git a/internal/filter/internal/filtertest/filtertest.go b/internal/filter/internal/filtertest/filtertest.go new file mode 100644 index 0000000..61082cc --- /dev/null +++ b/internal/filter/internal/filtertest/filtertest.go @@ -0,0 +1,80 @@ +// Package filtertest contains common constants and utilities for the internal +// filtering packages. +package filtertest + +import ( + "io" + "net/http" + "net/http/httptest" + "net/netip" + "net/url" + "os" + "path/filepath" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" + "github.com/AdguardTeam/golibs/httphdr" + "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/require" +) + +// BlockRule is the common blocking rule for filtering tests that blocks +// [ReqHost]. +const BlockRule = "|" + ReqHost + "^" + +// RemoteIP is the common client IP for filtering tests +var RemoteIP = netip.MustParseAddr("1.2.3.4") + +// ReqHost is the common request host for filtering tests. +const ReqHost = "www.host.example" + +// ReqFQDN is the common request FQDN for filtering tests. +const ReqFQDN = ReqHost + "." + +// ServerName is the common server name for filtering tests. +const ServerName = "testServer/1.0" + +// Timeout is the common timeout for filtering tests. +const Timeout = 1 * time.Second + +// PrepareRefreshable launches an HTTP server serving the given text and code, +// as well as creates a cache file. If code is zero, the server isn't started. +// If reqCh not nil, a signal is sent every time the server is called. The +// server uses [ServerName] as the value of the Server header. +func PrepareRefreshable( + tb testing.TB, + reqCh chan<- struct{}, + text string, + code int, +) (cachePath string, srvURL *url.URL) { + tb.Helper() + + if code != 0 { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + pt := testutil.PanicT{} + if reqCh != nil { + testutil.RequireSend(pt, reqCh, struct{}{}, Timeout) + } + + w.Header().Set(httphdr.Server, ServerName) + + w.WriteHeader(code) + + _, writeErr := io.WriteString(w, text) + require.NoError(pt, writeErr) + })) + tb.Cleanup(srv.Close) + + var err error + srvURL, err = agdhttp.ParseHTTPURL(srv.URL) + require.NoError(tb, err) + } + + cacheDir := tb.TempDir() + cacheFile, err := os.CreateTemp(cacheDir, filepath.Base(tb.Name())) + require.NoError(tb, err) + require.NoError(tb, cacheFile.Close()) + + return cacheFile.Name(), srvURL +} diff --git a/internal/filter/internal/internal.go b/internal/filter/internal/internal.go index 6565cd8..1db7315 100644 --- a/internal/filter/internal/internal.go +++ b/internal/filter/internal/internal.go @@ -3,3 +3,48 @@ // // TODO(a.garipov): Move more code to subpackages, see AGDNS-824. package internal + +import ( + "context" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/c2h5oh/datasize" + "github.com/miekg/dns" +) + +// Make sure that the signatures for FilterRequest match. +var _ RequestFilter = (Interface)(nil) + +// Interface is the DNS request and response filter interface. +type Interface interface { + // FilterRequest filters the DNS request for the provided client. All + // parameters must be non-nil. req must have exactly one question. If a is + // nil, the request doesn't match any of the rules. + FilterRequest(ctx context.Context, req *dns.Msg, ri *agd.RequestInfo) (r Result, err error) + + // FilterResponse filters the DNS response for the provided client. All + // parameters must be non-nil. If a is nil, the response doesn't match any + // of the rules. + FilterResponse(ctx context.Context, resp *dns.Msg, ri *agd.RequestInfo) (r Result, err error) +} + +// maxFilterSize is the maximum size of downloaded filters. +const maxFilterSize = 256 * int64(datasize.MB) + +// DefaultFilterRefreshTimeout is the default timeout to use when fetching +// filter lists data. +// +// TODO(a.garipov): Consider making timeouts where they are used configurable. +const DefaultFilterRefreshTimeout = 3 * time.Minute + +// DefaultResolveTimeout is the default timeout for resolving hosts for +// safe-search and safe-browsing filters. +// +// TODO(ameshkov): Consider making configurable. +const DefaultResolveTimeout = 1 * time.Second + +// RequestFilter can filter a request based on the request info. +type RequestFilter interface { + FilterRequest(ctx context.Context, req *dns.Msg, ri *agd.RequestInfo) (r Result, err error) +} diff --git a/internal/filter/internal/internal_test.go b/internal/filter/internal/internal_test.go new file mode 100644 index 0000000..c76bb5b --- /dev/null +++ b/internal/filter/internal/internal_test.go @@ -0,0 +1,11 @@ +package internal_test + +import ( + "testing" + + "github.com/AdguardTeam/golibs/testutil" +) + +func TestMain(m *testing.M) { + testutil.DiscardLogOutput(m) +} diff --git a/internal/filter/refrfilter.go b/internal/filter/internal/refreshable.go similarity index 58% rename from internal/filter/refrfilter.go rename to internal/filter/internal/refreshable.go index f776266..57bc5f7 100644 --- a/internal/filter/refrfilter.go +++ b/internal/filter/internal/refreshable.go @@ -1,4 +1,4 @@ -package filter +package internal import ( "context" @@ -16,88 +16,79 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" "github.com/AdguardTeam/AdGuardDNS/internal/agdio" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/log" "github.com/google/renameio" ) -// Refreshable Filter - -// refreshableFilter contains entities common to filters that can refresh -// themselves from a file and a URL. -type refreshableFilter struct { +// Refreshable contains entities common to filters that can refresh themselves +// from a file and a URL. +type Refreshable struct { // http is the HTTP client used to refresh the filter. http *agdhttp.Client // url is the URL used to refresh the filter. url *url.URL - // resetRules is the function that compiles the rules and replaces the - // filtering engine. It must not be nil and should return an informative - // error. - resetRules func(text string) (err error) - - // id is the identifier of the filter. + // id is the filter list ID, if any. id agd.FilterListID // cachePath is the path to the file containing the cached filter rules. cachePath string - // typ is the type of this filter used for logging and error reporting. - typ string - // staleness is the time after which a file is considered stale. staleness time.Duration } -// refresh reloads the filter data. If acceptStale is true, refresh doesn't try +// NewRefreshable returns a new refreshable filter. All parameters must be +// non-zero. +func NewRefreshable(l *agd.FilterList, cachePath string) (f *Refreshable) { + return &Refreshable{ + http: agdhttp.NewClient(&agdhttp.ClientConfig{ + Timeout: DefaultFilterRefreshTimeout, + }), + url: l.URL, + id: l.ID, + cachePath: cachePath, + staleness: l.RefreshIvl, + } +} + +// Refresh reloads the filter data. If acceptStale is true, refresh doesn't try // to load the filter data from its URL when there is already a file in the // cache directory, regardless of its staleness. -func (f *refreshableFilter) refresh( +func (f *Refreshable) Refresh( ctx context.Context, acceptStale bool, -) (err error) { - // TODO(a.garipov): Consider adding a helper for enriching errors with - // context deadline data. - start := time.Now() - deadline, hasDeadline := ctx.Deadline() +) (text string, err error) { + now := time.Now() - defer func() { - if err != nil && hasDeadline { - err = fmt.Errorf("started refresh at %s, deadline at %s: %w", start, deadline, err) - } + defer func() { err = errors.Annotate(err, "%s: %w", f.id) }() - err = errors.Annotate(err, "%s %q: %w", f.typ, f.id) - }() - - text, err := f.refreshFromFile(acceptStale) + text, err = f.refreshFromFile(acceptStale, now) if err != nil { - return fmt.Errorf("refreshing from file %q: %w", f.cachePath, err) + return "", fmt.Errorf("refreshing from file %q: %w", f.cachePath, err) } if text == "" { - log.Info("filter %s: refreshing from url %q", f.id, f.url) + log.Info("%s: refreshing from url %q", f.id, f.url) - text, err = f.refreshFromURL(ctx) + text, err = f.refreshFromURL(ctx, now) if err != nil { - return fmt.Errorf("refreshing from url %q: %w", f.url, err) + return "", fmt.Errorf("refreshing from url %q: %w", f.url, err) } } - err = f.resetRules(text) - if err != nil { - // Don't wrap the error, because it's informative enough as is. - return err - } - - return nil + return text, nil } // refreshFromFile loads filter data from a file if the file's mtime shows that -// it's still fresh. If acceptStale is true, and the cache file exists, the -// data is read from there regardless of its staleness. If err is nil and text -// is empty, a refresh from a URL is required. -func (f *refreshableFilter) refreshFromFile( +// it's still fresh relative to updTime. If acceptStale is true, and the cache +// file exists, the data is read from there regardless of its staleness. If err +// is nil and text is empty, a refresh from a URL is required. +func (f *Refreshable) refreshFromFile( acceptStale bool, + updTime time.Time, ) (text string, err error) { // #nosec G304 -- Assume that cachePath is always cacheDir + a valid, // no-slash filter list ID. @@ -117,7 +108,7 @@ func (f *refreshableFilter) refreshFromFile( return "", fmt.Errorf("reading filter file stat: %w", err) } - if mtime := fi.ModTime(); !mtime.Add(f.staleness).After(time.Now()) { + if mtime := fi.ModTime(); !mtime.Add(f.staleness).After(updTime) { return "", nil } } @@ -132,15 +123,19 @@ func (f *refreshableFilter) refreshFromFile( } // refreshFromURL loads the filter data from u, puts it into the file specified -// by cachePath, and also returns its content. -func (f *refreshableFilter) refreshFromURL(ctx context.Context) (text string, err error) { +// by cachePath, returns its content, and also sets its atime and mtime to +// updTime. +func (f *Refreshable) refreshFromURL( + ctx context.Context, + updTime time.Time, +) (text string, err error) { // TODO(a.garipov): Cache these like renameio recommends. tmpDir := renameio.TempDir(filepath.Dir(f.cachePath)) tmpFile, err := renameio.TempFile(tmpDir, f.cachePath) if err != nil { return "", fmt.Errorf("creating temporary filter file: %w", err) } - defer func() { err = withDeferredTmpCleanup(err, tmpFile) }() + defer func() { err = f.withDeferredTmpCleanup(err, tmpFile, updTime) }() resp, err := f.http.Get(ctx, f.url) if err != nil { @@ -148,10 +143,17 @@ func (f *refreshableFilter) refreshFromURL(ctx context.Context) (text string, er } defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }() - srv := resp.Header.Get(agdhttp.HdrNameServer) + srv := resp.Header.Get(httphdr.Server) cl := resp.ContentLength - log.Info("loading from %q: got content-length %d, code %d, srv %q", f.url, cl, resp.StatusCode, srv) + log.Info( + "%s: loading from %q: got content-length %d, code %d, srv %q", + f.id, + f.url, + cl, + resp.StatusCode, + srv, + ) err = agdhttp.CheckStatus(resp, http.StatusOK) if err != nil { @@ -178,12 +180,21 @@ func (f *refreshableFilter) refreshFromURL(ctx context.Context) (text string, er // withDeferredTmpCleanup is a helper that performs the necessary cleanups and // finalizations of the temporary files based on the returned error. -func withDeferredTmpCleanup(returned error, tmpFile *renameio.PendingFile) (err error) { +func (f *Refreshable) withDeferredTmpCleanup( + returned error, + tmpFile *renameio.PendingFile, + updTime time.Time, +) (err error) { + // Make sure that any error returned from here is marked as a deferred one. if returned != nil { return errors.WithDeferred(returned, tmpFile.Cleanup()) } - // Make sure that the error returned from CloseAtomicallyReplace is marked - // as a deferred one. - return errors.WithDeferred(nil, tmpFile.CloseAtomicallyReplace()) + err = tmpFile.CloseAtomicallyReplace() + if err != nil { + return errors.WithDeferred(nil, err) + } + + // Set the modification and access times to the moment the refresh started. + return errors.WithDeferred(nil, os.Chtimes(f.cachePath, updTime, updTime)) } diff --git a/internal/filter/internal/refreshable_test.go b/internal/filter/internal/refreshable_test.go new file mode 100644 index 0000000..146098a --- /dev/null +++ b/internal/filter/internal/refreshable_test.go @@ -0,0 +1,198 @@ +package internal_test + +import ( + "context" + "net/http" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/filtertest" + "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// refrID is the ID of a [agd.FilterList] used for testing. +const refrID = "test_id" + +func TestRefreshable_Refresh(t *testing.T) { + const ( + defaultFileText = "||filefilter.example\n" + defaultURLText = "||urlfilter.example\n" + ) + + testCases := []struct { + name string + wantText string + wantErrMsg string + srvText string + staleness time.Duration + srvCode int + acceptStale bool + expectReq bool + useCacheFile bool + }{{ + name: "no_file", + wantText: defaultURLText, + wantErrMsg: "", + srvText: defaultURLText, + staleness: 0, + srvCode: http.StatusOK, + acceptStale: true, + expectReq: true, + useCacheFile: false, + }, { + name: "no_file_http_empty", + wantText: "", + wantErrMsg: refrID + `: refreshing from url "URL": ` + + `server "` + filtertest.ServerName + `": empty text, not resetting`, + srvText: "", + staleness: 0, + srvCode: http.StatusOK, + acceptStale: true, + expectReq: true, + useCacheFile: false, + }, { + name: "no_file_http_error", + wantText: "", + wantErrMsg: refrID + `: refreshing from url "URL": ` + + `server "` + filtertest.ServerName + `": ` + + `status code error: expected 200, got 500`, + srvText: "internal server error", + staleness: 0, + srvCode: http.StatusInternalServerError, + acceptStale: true, + expectReq: true, + useCacheFile: false, + }, { + name: "file", + wantText: defaultFileText, + wantErrMsg: "", + srvText: "", + staleness: 1 * time.Hour, + srvCode: 0, + acceptStale: true, + expectReq: false, + useCacheFile: true, + }, { + name: "file_stale", + wantText: defaultURLText, + wantErrMsg: "", + srvText: defaultURLText, + staleness: -1 * time.Hour, + srvCode: http.StatusOK, + acceptStale: false, + expectReq: true, + useCacheFile: true, + }, { + name: "file_stale_accept", + wantText: defaultFileText, + wantErrMsg: "", + srvText: "", + staleness: -1 * time.Hour, + srvCode: 0, + acceptStale: true, + expectReq: false, + useCacheFile: true, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var err error + + reqCh := make(chan struct{}, 1) + var cachePath string + realCachePath, srvURL := filtertest.PrepareRefreshable(t, reqCh, tc.srvText, tc.srvCode) + if tc.useCacheFile { + cachePath = realCachePath + + err = os.WriteFile(cachePath, []byte(defaultFileText), 0o600) + require.NoError(t, err) + } else { + cachePath = filepath.Join(t.TempDir(), "does_not_exist") + } + + fl := &agd.FilterList{ + URL: srvURL, + ID: refrID, + RefreshIvl: tc.staleness, + } + f := internal.NewRefreshable(fl, cachePath) + + ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout) + t.Cleanup(cancel) + + var gotText string + gotText, err = f.Refresh(ctx, tc.acceptStale) + if tc.expectReq { + testutil.RequireReceive(t, reqCh, filtertest.Timeout) + } + + // Since we only get the actual URL within the subtest, replace it + // here and check the error message. + if srvURL != nil { + tc.wantErrMsg = strings.ReplaceAll(tc.wantErrMsg, "URL", srvURL.String()) + } + + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + assert.Equal(t, tc.wantText, gotText) + }) + } +} + +func TestRefreshable_Refresh_properStaleness(t *testing.T) { + const ( + responseDur = time.Second / 5 + staleness = time.Hour + ) + + reqCh := make(chan struct{}) + cachePath, addr := filtertest.PrepareRefreshable(t, reqCh, filtertest.BlockRule, http.StatusOK) + + fl := &agd.FilterList{ + URL: addr, + ID: refrID, + RefreshIvl: staleness, + } + f := internal.NewRefreshable(fl, cachePath) + + ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout) + t.Cleanup(cancel) + + var err error + var now time.Time + go func() { + <-reqCh + now = time.Now() + _, err = f.Refresh(ctx, false) + <-reqCh + }() + + // Start the refresh. + reqCh <- struct{}{} + + // Hold the handler to guarantee the refresh will endure some time. + time.Sleep(responseDur) + + // Continue the refresh. + testutil.RequireReceive(t, reqCh, filtertest.Timeout) + + // Ensure the refresh finished. + reqCh <- struct{}{} + + require.NoError(t, err) + + file, err := os.Open(cachePath) + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, file.Close) + + fi, err := file.Stat() + require.NoError(t, err) + + assert.InDelta(t, fi.ModTime().Sub(now), 0, float64(time.Millisecond)) +} diff --git a/internal/filter/internal/resultcache/resultcache.go b/internal/filter/internal/resultcache/resultcache.go index f230be5..ea4b4d9 100644 --- a/internal/filter/internal/resultcache/resultcache.go +++ b/internal/filter/internal/resultcache/resultcache.go @@ -83,7 +83,7 @@ var hashSeed = maphash.MakeSeed() // DefaultKey produces a cache key based on host, qt, and isAns using the // default algorithm. -func DefaultKey(host string, qt dnsmsg.RRType, isAns bool) (k Key) { +func DefaultKey(host string, qt dnsmsg.RRType, cl dnsmsg.Class, isAns bool) (k Key) { // Use maphash explicitly instead of using a key structure to reduce // allocations and optimize interface conversion up the stack. h := &maphash.Hash{} @@ -92,9 +92,10 @@ func DefaultKey(host string, qt dnsmsg.RRType, isAns bool) (k Key) { _, _ = h.WriteString(host) // Save on allocations by reusing a buffer. - var buf [3]byte + var buf [5]byte binary.LittleEndian.PutUint16(buf[:2], qt) - buf[2] = mathutil.BoolToNumber[byte](isAns) + binary.LittleEndian.PutUint16(buf[2:4], cl) + buf[4] = mathutil.BoolToNumber[byte](isAns) _, _ = h.Write(buf[:]) diff --git a/internal/filter/internal/resultcache/resultcache_test.go b/internal/filter/internal/resultcache/resultcache_test.go index 5594800..aaa0190 100644 --- a/internal/filter/internal/resultcache/resultcache_test.go +++ b/internal/filter/internal/resultcache/resultcache_test.go @@ -11,8 +11,8 @@ import ( // Common keys for tests. var ( - testKey = resultcache.DefaultKey("example.com", dns.TypeA, true) - otherKey = resultcache.DefaultKey("example.org", dns.TypeAAAA, false) + testKey = resultcache.DefaultKey("example.com", dns.TypeA, dns.ClassINET, true) + otherKey = resultcache.DefaultKey("example.org", dns.TypeAAAA, dns.ClassINET, false) ) // val is the common value for tests. diff --git a/internal/filter/internal/rulelist/immutable.go b/internal/filter/internal/rulelist/immutable.go new file mode 100644 index 0000000..8126a34 --- /dev/null +++ b/internal/filter/internal/rulelist/immutable.go @@ -0,0 +1,40 @@ +package rulelist + +import ( + "github.com/AdguardTeam/AdGuardDNS/internal/agd" +) + +// Immutable is a rule-list filter that doesn't refresh or change. +// It is used for users' custom rule-lists as well as in service blocking. +// +// TODO(a.garipov): Consider not using rule-list engines for service and custom +// filters at all. It could be faster to simply go through all enabled rules +// sequentially instead. Alternatively, rework the urlfilter.DNSEngine and make +// it use the sequential scan if the number of rules is less than some constant +// value. +// +// See AGDNS-342. +type Immutable struct { + // TODO(a.garipov): Find ways to embed it in a way that shows the methods, + // doesn't result in double dereferences, and doesn't cause naming issues. + *filter +} + +// NewImmutable returns a new immutable DNS request and response filter using +// the provided rule text and ID. +func NewImmutable( + text string, + id agd.FilterListID, + svcID agd.BlockedServiceID, + memCacheSize int, + useMemCache bool, +) (f *Immutable, err error) { + f = &Immutable{} + f.filter, err = newFilter(text, id, svcID, memCacheSize, useMemCache) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return nil, err + } + + return f, nil +} diff --git a/internal/filter/internal/rulelist/refreshable.go b/internal/filter/internal/rulelist/refreshable.go new file mode 100644 index 0000000..52d7031 --- /dev/null +++ b/internal/filter/internal/rulelist/refreshable.go @@ -0,0 +1,136 @@ +package rulelist + +import ( + "context" + "fmt" + "net/netip" + "path/filepath" + "sync" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal" + "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/urlfilter" + "github.com/AdguardTeam/urlfilter/filterlist" +) + +// Refreshable is a refreshable DNS request and response filter based on filter +// rule lists. +// +// TODO(a.garipov): Consider adding a separate version that uses a single engine +// for multiple rule lists and using it to optimize the filtering using default +// filtering groups. +type Refreshable struct { + *filter + + // mu protects filter.engine. + mu *sync.RWMutex + + // refr contains data for refreshing the filter. + refr *internal.Refreshable +} + +// NewRefreshable returns a new refreshable DNS request and response filter +// based on the provided rule list. l must be non-nil. The initial refresh +// should be called explicitly if necessary. +func NewRefreshable( + l *agd.FilterList, + fileCacheDir string, + memCacheSize int, + useMemCache bool, +) (f *Refreshable) { + f = &Refreshable{ + mu: &sync.RWMutex{}, + refr: internal.NewRefreshable(l, filepath.Join(fileCacheDir, string(l.ID))), + } + + var err error + f.filter, err = newFilter("", l.ID, "", memCacheSize, useMemCache) + if err != nil { + // Should never happen, since text is empty. + panic(fmt.Errorf("unexpected filter error: %w", err)) + } + + return f +} + +// NewFromString returns a new DNS request and response filter using the +// provided rule text and ID. +// +// TODO(a.garipov): Only used in tests. Consider removing later. +func NewFromString( + text string, + id agd.FilterListID, + svcID agd.BlockedServiceID, + memCacheSize int, + useMemCache bool, +) (f *Refreshable, err error) { + f = &Refreshable{ + mu: &sync.RWMutex{}, + } + + f.filter, err = newFilter(text, id, svcID, memCacheSize, useMemCache) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return nil, err + } + + return f, nil +} + +// DNSResult returns the result of applying the urlfilter DNS filtering engine. +// If the request is not filtered, DNSResult returns nil. +func (f *Refreshable) DNSResult( + clientIP netip.Addr, + clientName string, + host string, + rrType dnsmsg.RRType, + isAns bool, +) (res *urlfilter.DNSResult) { + f.mu.RLock() + defer f.mu.RUnlock() + + return f.filter.DNSResult(clientIP, clientName, host, rrType, isAns) +} + +// Refresh reloads the rule list data. If acceptStale is true, do not try to +// load the list from its URL when there is already a file in the cache +// directory, regardless of its staleness. +func (f *Refreshable) Refresh(ctx context.Context, acceptStale bool) (err error) { + text, err := f.refr.Refresh(ctx, acceptStale) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } + + // TODO(a.garipov): Add filterlist.BytesRuleList. + strList := &filterlist.StringRuleList{ + ID: f.urlFilterID, + RulesText: text, + IgnoreCosmetic: true, + } + + s, err := filterlist.NewRuleStorage([]filterlist.RuleList{strList}) + if err != nil { + return fmt.Errorf("creating rule storage: %w", err) + } + + f.mu.Lock() + defer f.mu.Unlock() + + f.cache.Clear() + f.engine = urlfilter.NewDNSEngine(s) + + log.Info("%s: reset %d rules", f.id, f.engine.RulesCount) + + return nil +} + +// RulesCount returns the number of rules in the filter's engine. +func (f *Refreshable) RulesCount() (n int) { + f.mu.RLock() + defer f.mu.RUnlock() + + return f.filter.RulesCount() +} diff --git a/internal/filter/internal/rulelist/refreshable_test.go b/internal/filter/internal/rulelist/refreshable_test.go new file mode 100644 index 0000000..01086ee --- /dev/null +++ b/internal/filter/internal/rulelist/refreshable_test.go @@ -0,0 +1,107 @@ +package rulelist_test + +import ( + "context" + "net/http" + "net/netip" + "path/filepath" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/filtertest" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/rulelist" + "github.com/AdguardTeam/golibs/testutil" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMain(m *testing.M) { + testutil.DiscardLogOutput(m) +} + +// testReqHost is the request host for tests. +const testReqHost = "blocked.example" + +// testRemoteIP is the client IP for tests +var testRemoteIP = netip.MustParseAddr("1.2.3.4") + +// testFltListID is the common filter list IDs for tests. +const testFltListID agd.FilterListID = "fl1" + +// testBlockRule is the common blocking rule for tests. +const testBlockRule = "||" + testReqHost + "\n" + +func TestRefreshable_RulesCount(t *testing.T) { + rl, err := rulelist.NewFromString(testBlockRule, testFltListID, "", 0, false) + require.NoError(t, err) + + assert.Equal(t, 1, rl.RulesCount()) +} + +func TestRefreshable_DNSResult_cache(t *testing.T) { + rl, err := rulelist.NewFromString(testBlockRule, testFltListID, "", 100, true) + require.NoError(t, err) + + const qt = dns.TypeA + + t.Run("blocked", func(t *testing.T) { + dr := rl.DNSResult(testRemoteIP, "", testReqHost, qt, false) + require.NotNil(t, dr) + + assert.Len(t, dr.NetworkRules, 1) + + cachedDR := rl.DNSResult(testRemoteIP, "", testReqHost, qt, false) + require.NotNil(t, cachedDR) + + assert.Same(t, dr, cachedDR) + }) + + t.Run("none", func(t *testing.T) { + const otherHost = "other.example" + + dr := rl.DNSResult(testRemoteIP, "", otherHost, qt, false) + assert.Nil(t, dr) + + cachedDR := rl.DNSResult(testRemoteIP, "", otherHost, dns.TypeA, false) + assert.Nil(t, cachedDR) + }) +} + +func TestRefreshable_ID(t *testing.T) { + const svcID = agd.BlockedServiceID("test_service") + rl, err := rulelist.NewFromString(testBlockRule, testFltListID, svcID, 0, false) + require.NoError(t, err) + + gotID, gotSvcID := rl.ID() + assert.Equal(t, testFltListID, gotID) + assert.Equal(t, svcID, gotSvcID) +} + +func TestRefreshable_Refresh(t *testing.T) { + cachePath, srvURL := filtertest.PrepareRefreshable(t, nil, testBlockRule, http.StatusOK) + rl := rulelist.NewRefreshable( + &agd.FilterList{ + URL: srvURL, + ID: testFltListID, + RefreshIvl: 1 * time.Hour, + }, + filepath.Dir(cachePath), + 100, + true, + ) + + ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout) + t.Cleanup(cancel) + + err := rl.Refresh(ctx, false) + require.NoError(t, err) + + assert.Equal(t, 1, rl.RulesCount()) + + dr := rl.DNSResult(testRemoteIP, "", testReqHost, dns.TypeA, false) + require.NotNil(t, dr) + + assert.Len(t, dr.NetworkRules, 1) +} diff --git a/internal/filter/internal/rulelist/rulelist.go b/internal/filter/internal/rulelist/rulelist.go new file mode 100644 index 0000000..114e50e --- /dev/null +++ b/internal/filter/internal/rulelist/rulelist.go @@ -0,0 +1,154 @@ +// Package rulelist contains the implementation of the standard rule-list +// filter that wraps an urlfilter filtering-engine. +package rulelist + +import ( + "fmt" + "math/rand" + "net/netip" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/resultcache" + "github.com/AdguardTeam/urlfilter" + "github.com/AdguardTeam/urlfilter/filterlist" + "github.com/miekg/dns" +) + +// newURLFilterID returns a new random ID for the urlfilter DNS engine to use. +func newURLFilterID() (id int) { + // #nosec G404 -- Do not use cryptographically random ID generation, since + // these are only used in one place, internal/filter.compFilter.filterMsg, + // and are not used in any security-sensitive context. + // + // Despite the fact that the type of integer filter list IDs in module + // urlfilter is int, the module actually assumes that the ID is a + // non-negative integer, or at least not a largely negative one. Otherwise, + // some of its low-level optimizations seem to break. + return int(rand.Int31()) +} + +// filter is the basic rule-list filter that doesn't refresh or change in any +// other way. +type filter struct { + // engine is the DNS filtering engine. + // + // NOTE: We do not save the [filterlist.RuleList] used to create the engine + // to close it, because we exclusively use [filterlist.StringRuleList], + // which doesn't require closing. + engine *urlfilter.DNSEngine + + // cache contains cached results of filtering. + // + // TODO(ameshkov): Add metrics for these caches. + cache *resultcache.Cache[*urlfilter.DNSResult] + + // id is the filter list ID, if any. + id agd.FilterListID + + // svcID is the additional identifier for blocked service lists. If id is + svcID agd.BlockedServiceID + + // urlFilterID is the synthetic integer identifier for the urlfilter engine. + // + // TODO(a.garipov): Change the type to a string in module urlfilter and + // remove this crutch. + urlFilterID int +} + +// newFilter returns a new basic DNS request and response filter using the +// provided rule text and ID. +func newFilter( + text string, + id agd.FilterListID, + svcID agd.BlockedServiceID, + memCacheSize int, + useMemCache bool, +) (f *filter, err error) { + f = &filter{ + id: id, + svcID: svcID, + urlFilterID: newURLFilterID(), + } + + if useMemCache { + f.cache = resultcache.New[*urlfilter.DNSResult](memCacheSize) + } + + // TODO(a.garipov): Add filterlist.BytesRuleList. + strList := &filterlist.StringRuleList{ + ID: f.urlFilterID, + RulesText: text, + IgnoreCosmetic: true, + } + + s, err := filterlist.NewRuleStorage([]filterlist.RuleList{strList}) + if err != nil { + return nil, fmt.Errorf("creating rule storage: %w", err) + } + + f.engine = urlfilter.NewDNSEngine(s) + + return f, nil +} + +// DNSResult returns the result of applying the urlfilter DNS filtering engine. +// If the request is not filtered, DNSResult returns nil. +func (f *filter) DNSResult( + clientIP netip.Addr, + clientName string, + host string, + rrType dnsmsg.RRType, + isAns bool, +) (res *urlfilter.DNSResult) { + var ok bool + var cacheKey resultcache.Key + + // Don't waste resources on computing the cache key if the cache is not + // enabled. + useCache := f.cache != nil + if useCache { + // TODO(a.garipov): Add real class here. + cacheKey = resultcache.DefaultKey(host, rrType, dns.ClassINET, isAns) + res, ok = f.cache.Get(cacheKey) + if ok { + return res + } + } + + dnsReq := &urlfilter.DNSRequest{ + Hostname: host, + // TODO(a.garipov): Make this a net.IP in module urlfilter. + ClientIP: clientIP.String(), + ClientName: clientName, + DNSType: rrType, + Answer: isAns, + } + + res, ok = f.engine.MatchRequest(dnsReq) + if !ok && len(res.NetworkRules) == 0 { + res = nil + } + + if useCache { + f.cache.Set(cacheKey, res) + } + + return res +} + +// ID returns the filter list ID of this rule list filter, as well as the ID of +// the blocked service, if any. +func (f *filter) ID() (id agd.FilterListID, svcID agd.BlockedServiceID) { + return f.id, f.svcID +} + +// RulesCount returns the number of rules in the filter's engine. +func (f *filter) RulesCount() (n int) { + return f.engine.RulesCount +} + +// URLFilterID returns the synthetic ID used for the urlfilter module. +func (f *filter) URLFilterID() (n int) { + return f.urlFilterID +} diff --git a/internal/filter/internal/safesearch/safesearch.go b/internal/filter/internal/safesearch/safesearch.go new file mode 100644 index 0000000..1a50da1 --- /dev/null +++ b/internal/filter/internal/safesearch/safesearch.go @@ -0,0 +1,185 @@ +// Package safesearch contains the implementation of the safe-search filter +// that uses lists of DNS rewrite rules to enforce safe search. +package safesearch + +import ( + "context" + "fmt" + "net" + "net/netip" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/agdnet" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/resultcache" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/rulelist" + "github.com/AdguardTeam/AdGuardDNS/internal/optlog" + "github.com/AdguardTeam/golibs/netutil" + "github.com/miekg/dns" +) + +// Filter modifies the results of queries to search-engine addresses and +// rewrites them to the IP addresses of their safe versions. +type Filter struct { + resCache *resultcache.Cache[*internal.ResultModified] + flt *rulelist.Refreshable + resolver agdnet.Resolver + errColl agd.ErrorCollector + id agd.FilterListID +} + +// Config contains configuration for the safe-search filter. +type Config struct { + // List is the filtering-rule list used to filter requests. + List *agd.FilterList + + // Resolver is used to resolve the IP addresses of replacement hosts. + Resolver agdnet.Resolver + + // ErrColl is used to report errors of replacement-host resolving. + ErrColl agd.ErrorCollector + + // CacheDir is the path to the directory where the cached filter files are + // put. The directory must exist. + CacheDir string + + // CacheTTL is the time to live of the result cache-items. + // + //lint:ignore U1000 TODO(a.garipov): Currently unused. See AGDNS-398. + CacheTTL time.Duration + + // CacheSize is the number of items in the result cache. + CacheSize int +} + +// New returns a new safe-search filter. c must not be nil. The initial +// refresh should be called explicitly if necessary. +func New(c *Config) (f *Filter) { + id := c.List.ID + + return &Filter{ + resCache: resultcache.New[*internal.ResultModified](c.CacheSize), + // Don't use the rule list cache, since safeSearch already has its own. + flt: rulelist.NewRefreshable(c.List, c.CacheDir, 0, false), + resolver: c.Resolver, + errColl: c.ErrColl, + id: id, + } +} + +// type check +var _ internal.RequestFilter = (*Filter)(nil) + +// FilterRequest implements the [internal.RequestFilter] interface for *Filter. +// It modifies the response if host matches f. +func (f *Filter) FilterRequest( + ctx context.Context, + req *dns.Msg, + ri *agd.RequestInfo, +) (r internal.Result, err error) { + qt := ri.QType + fam := netutil.AddrFamilyFromRRType(qt) + if fam == netutil.AddrFamilyNone { + return nil, nil + } + + host := ri.Host + cacheKey := resultcache.DefaultKey(host, qt, ri.QClass, false) + rm, ok := f.resCache.Get(cacheKey) + if ok { + if rm == nil { + // Return nil explicitly instead of modifying CloneForReq to return + // nil if the result is nil to avoid a “non-nil nil” value. + return nil, nil + } + + return rm.CloneForReq(req), nil + } + + repHost, ok := f.safeSearchHost(host, qt) + if !ok { + optlog.Debug2("filter %s: host %q is not on the list", f.id, host) + + f.resCache.Set(cacheKey, nil) + + return nil, nil + } + + optlog.Debug2("filter %s: found host %q", f.id, repHost) + + ctx, cancel := context.WithTimeout(ctx, internal.DefaultResolveTimeout) + defer cancel() + + var result *dns.Msg + ips, err := f.resolver.LookupIP(ctx, fam, repHost) + if err != nil { + agd.Collectf(ctx, f.errColl, "filter %s: resolving: %w", f.id, err) + + result = ri.Messages.NewMsgSERVFAIL(req) + } else { + result, err = ri.Messages.NewIPRespMsg(req, ips...) + if err != nil { + return nil, fmt.Errorf("filter %s: creating modified result: %w", f.id, err) + } + } + + rm = &internal.ResultModified{ + Msg: result, + List: f.id, + Rule: agd.FilterRuleText(host), + } + + // Copy the result to make sure that modifications to the result message + // down the pipeline don't interfere with the cached value. + // + // See AGDNS-359. + f.resCache.Set(cacheKey, rm.Clone()) + + return rm, nil +} + +// safeSearchHost returns the replacement host for the given host and question +// type, if any. qt should be either [dns.TypeA] or [dns.TypeAAAA]. +func (f *Filter) safeSearchHost(host string, qt dnsmsg.RRType) (ssHost string, ok bool) { + dr := f.flt.DNSResult(netip.Addr{}, "", host, qt, false) + if dr == nil { + return "", false + } + + for _, nr := range dr.DNSRewrites() { + drw := nr.DNSRewrite + if drw.RCode != dns.RcodeSuccess { + continue + } + + if nc := drw.NewCNAME; nc != "" { + return nc, true + } + + // All the rules in safe search rule lists are expected to have either + // A/AAAA or CNAME type. + switch drw.RRType { + case dns.TypeA, dns.TypeAAAA: + return drw.Value.(net.IP).String(), true + default: + continue + } + } + + return "", false +} + +// Refresh reloads the rule list data. If acceptStale is true, and the cache +// file exists, the data is read from there regardless of its staleness. +func (f *Filter) Refresh(ctx context.Context, acceptStale bool) (err error) { + err = f.flt.Refresh(ctx, acceptStale) + if err != nil { + return err + } + + f.resCache.Clear() + + return nil +} diff --git a/internal/filter/internal/safesearch/safesearch_test.go b/internal/filter/internal/safesearch/safesearch_test.go new file mode 100644 index 0000000..1bb4015 --- /dev/null +++ b/internal/filter/internal/safesearch/safesearch_test.go @@ -0,0 +1,198 @@ +package safesearch_test + +import ( + "context" + "net" + "net/http" + "net/netip" + "path/filepath" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/filtertest" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/safesearch" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/testutil" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMain(m *testing.M) { + testutil.DiscardLogOutput(m) +} + +// testSafeIPStr is the string representation of the IP address of the safe +// version of [testEngineWithIP]. +const testSafeIPStr = "1.2.3.4" + +// testIPOfEngineWithIP is the IP address of the safe version of +// search-engine-ip.example. +var testIPOfEngineWithIP net.IP = netip.MustParseAddr(testSafeIPStr).AsSlice() + +// testIPOfEngineWithDomain is the IP address of the safe version of +// search-engine-domain.example. +var testIPOfEngineWithDomain = net.IP{1, 2, 3, 5} + +// Common domain names for tests. +const ( + testOther = "other.example" + testEngineWithIP = "search-engine-ip.example" + testEngineWithDomain = "search-engine-domain.example" + testSafeDomain = "safe-search-engine-domain.example" +) + +// testFilterRules is are common filtering rules for tests. +const testFilterRules = `|` + testEngineWithIP + `^$dnsrewrite=NOERROR;A;` + testSafeIPStr + "\n" + + `|` + testEngineWithDomain + `^$dnsrewrite=NOERROR;CNAME;` + testSafeDomain + +func TestFilter(t *testing.T) { + reqCh := make(chan struct{}, 1) + cachePath, srvURL := filtertest.PrepareRefreshable(t, reqCh, testFilterRules, http.StatusOK) + + id, err := agd.NewFilterListID(filepath.Base(cachePath)) + require.NoError(t, err) + + f := safesearch.New(&safesearch.Config{ + List: &agd.FilterList{ + ID: id, + URL: srvURL, + }, + Resolver: &agdtest.Resolver{ + OnLookupIP: func( + _ context.Context, + _ netutil.AddrFamily, + host string, + ) (ips []net.IP, err error) { + switch host { + case testSafeIPStr: + return []net.IP{testIPOfEngineWithIP}, nil + case testSafeDomain: + return []net.IP{testIPOfEngineWithDomain}, nil + default: + return nil, errors.Error("test resolver error") + } + }, + }, + ErrColl: &agdtest.ErrorCollector{ + OnCollect: func(ctx context.Context, err error) { + panic("not implemented") + }, + }, + CacheDir: filepath.Dir(cachePath), + CacheTTL: 1 * time.Minute, + CacheSize: 100, + }) + + refrCtx, refrCancel := context.WithTimeout(context.Background(), filtertest.Timeout) + t.Cleanup(refrCancel) + + err = f.Refresh(refrCtx, true) + require.NoError(t, err) + + testutil.RequireReceive(t, reqCh, filtertest.Timeout) + + t.Run("no_match", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout) + t.Cleanup(cancel) + + req, ri := newReq(t, testOther, dns.TypeA) + res, fltErr := f.FilterRequest(ctx, req, ri) + require.NoError(t, fltErr) + + assert.Nil(t, res) + + t.Run("cached", func(t *testing.T) { + res, fltErr = f.FilterRequest(ctx, req, ri) + require.NoError(t, fltErr) + + // TODO(a.garipov): Find a way to make caches more inspectable. + assert.Nil(t, res) + }) + }) + + t.Run("txt", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout) + t.Cleanup(cancel) + + req, ri := newReq(t, testEngineWithIP, dns.TypeTXT) + res, fltErr := f.FilterRequest(ctx, req, ri) + require.NoError(t, fltErr) + + assert.Nil(t, res) + }) + + t.Run("ip", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout) + t.Cleanup(cancel) + + req, ri := newReq(t, testEngineWithIP, dns.TypeA) + res, fltErr := f.FilterRequest(ctx, req, ri) + require.NoError(t, fltErr) + + rm := testutil.RequireTypeAssert[*internal.ResultModified](t, res) + require.Len(t, rm.Msg.Answer, 1) + + assert.Equal(t, rm.Rule, agd.FilterRuleText(testEngineWithIP)) + + a := testutil.RequireTypeAssert[*dns.A](t, rm.Msg.Answer[0]) + assert.Equal(t, testIPOfEngineWithIP, a.A) + + t.Run("cached", func(t *testing.T) { + newReq, newRI := newReq(t, testEngineWithIP, dns.TypeA) + + var cachedRes internal.Result + cachedRes, fltErr = f.FilterRequest(ctx, newReq, newRI) + require.NoError(t, fltErr) + + // Do not assert that the results are the same, since a modified + // result of a safe search is always cloned. But assert that the + // non-clonable fields are equal and that the message has reply + // fields set properly. + cachedRM := testutil.RequireTypeAssert[*internal.ResultModified](t, cachedRes) + assert.NotSame(t, cachedRM, rm) + assert.Equal(t, cachedRM.Msg.Id, newReq.Id) + assert.Equal(t, cachedRM.List, rm.List) + assert.Equal(t, cachedRM.Rule, rm.Rule) + }) + }) + + t.Run("domain", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout) + t.Cleanup(cancel) + + req, ri := newReq(t, testEngineWithDomain, dns.TypeA) + res, fltErr := f.FilterRequest(ctx, req, ri) + require.NoError(t, fltErr) + + rm := testutil.RequireTypeAssert[*internal.ResultModified](t, res) + require.Len(t, rm.Msg.Answer, 1) + + assert.Equal(t, rm.Rule, agd.FilterRuleText(testEngineWithDomain)) + + a := testutil.RequireTypeAssert[*dns.A](t, rm.Msg.Answer[0]) + assert.Equal(t, testIPOfEngineWithDomain, a.A) + }) +} + +// newReq is a test helper that returns the DNS request and its accompanying +// request info with the given data. +func newReq(tb testing.TB, host string, qt dnsmsg.RRType) (req *dns.Msg, ri *agd.RequestInfo) { + tb.Helper() + + req = dnsservertest.NewReq(host, qt, dns.ClassINET) + ri = &agd.RequestInfo{ + Messages: agdtest.NewConstructor(), + Host: host, + QType: qt, + QClass: dns.ClassINET, + } + + return req, ri +} diff --git a/internal/filter/internal/serviceblock/index.go b/internal/filter/internal/serviceblock/index.go new file mode 100644 index 0000000..3524ea4 --- /dev/null +++ b/internal/filter/internal/serviceblock/index.go @@ -0,0 +1,96 @@ +package serviceblock + +import ( + "context" + "fmt" + "strings" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/rulelist" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" +) + +// indexResp is the struct for the JSON response from a blocked service index +// API. +type indexResp struct { + BlockedServices []*indexRespService `json:"blocked_services"` +} + +// toInternal converts the services from the index to serviceRuleLists. +func (r *indexResp) toInternal( + ctx context.Context, + errColl agd.ErrorCollector, + cacheSize int, + useCache bool, +) (services serviceRuleLists, err error) { + l := len(r.BlockedServices) + if l == 0 { + return nil, nil + } + + services = make(serviceRuleLists, l) + errs := make([]error, len(r.BlockedServices)) + for i, svc := range r.BlockedServices { + var ( + svcID agd.BlockedServiceID + rl *rulelist.Immutable + ) + svcID, rl, err = svc.toInternal(ctx, errColl, cacheSize, useCache) + if err != nil { + errs[i] = fmt.Errorf("service at index %d: %w", i, err) + + continue + } + + services[svcID] = rl + } + + err = errors.Join(errs...) + if err != nil { + return nil, fmt.Errorf("converting blocked services: %w", err) + } + + return services, nil +} + +// indexRespService is the struct for a filter from the JSON response from a +// blocked service index API. +type indexRespService struct { + ID string `json:"id"` + Rules []string `json:"rules"` +} + +// toInternal converts the service from the index to a rule-list filter. +func (svc *indexRespService) toInternal( + ctx context.Context, + errColl agd.ErrorCollector, + cacheSize int, + useCache bool, +) (svcID agd.BlockedServiceID, rl *rulelist.Immutable, err error) { + svcID, err = agd.NewBlockedServiceID(svc.ID) + if err != nil { + return "", nil, fmt.Errorf("validating id: %w", err) + } + + if len(svc.Rules) == 0 { + reportErr := fmt.Errorf("service filter: no rules for service with id %s", svcID) + errColl.Collect(ctx, reportErr) + log.Info("warning: %s", reportErr) + } + + rl, err = rulelist.NewImmutable( + strings.Join(svc.Rules, "\n"), + agd.FilterListIDBlockedService, + svcID, + cacheSize, + useCache, + ) + if err != nil { + return "", nil, fmt.Errorf("compiling %s: %w", svc.ID, err) + } + + log.Info("%s/%s: got %d rules", agd.FilterListIDBlockedService, svcID, rl.RulesCount()) + + return svcID, rl, nil +} diff --git a/internal/filter/internal/serviceblock/serviceblock.go b/internal/filter/internal/serviceblock/serviceblock.go new file mode 100644 index 0000000..de378fd --- /dev/null +++ b/internal/filter/internal/serviceblock/serviceblock.go @@ -0,0 +1,149 @@ +// Package serviceblock contains an implementation of a filter that blocks +// services using rule lists. The blocking is based on the parental-control +// settings in the profile. +package serviceblock + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "sync" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/rulelist" + "github.com/AdguardTeam/AdGuardDNS/internal/metrics" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" +) + +// Filter is a service-blocking filter that uses rule lists that it gets from an +// index. +type Filter struct { + // url is the URL from which the services are fetched. + url *url.URL + + // http is the HTTP client used to refresh the filter. + http *agdhttp.Client + + // mu protects services. + mu *sync.RWMutex + + // services is an ID to filter mapping. + services serviceRuleLists + + // errColl used to collect non-critical and rare errors. + errColl agd.ErrorCollector +} + +// serviceRuleLists is convenient alias for a ID to filter mapping. +type serviceRuleLists = map[agd.BlockedServiceID]*rulelist.Immutable + +// New returns a fully initialized service blocker. +func New(indexURL *url.URL, errColl agd.ErrorCollector) (f *Filter) { + return &Filter{ + url: indexURL, + http: agdhttp.NewClient(&agdhttp.ClientConfig{ + Timeout: internal.DefaultFilterRefreshTimeout, + }), + mu: &sync.RWMutex{}, + errColl: errColl, + } +} + +// RuleLists returns the rule-list filters for the given blocked service IDs. +// The order of the elements in rls is undefined. +func (f *Filter) RuleLists( + ctx context.Context, + ids []agd.BlockedServiceID, +) (rls []*rulelist.Immutable) { + if len(ids) == 0 { + return nil + } + + f.mu.RLock() + defer f.mu.RUnlock() + + for _, id := range ids { + rl := f.services[id] + if rl != nil { + rls = append(rls, rl) + + continue + } + + reportErr := fmt.Errorf("service filter: no service with id %s", id) + f.errColl.Collect(ctx, reportErr) + log.Info("warning: %s", reportErr) + } + + return rls +} + +// Refresh loads new service data from the index URL. +func (f *Filter) Refresh(ctx context.Context, cacheSize int, useCache bool) (err error) { + fltIDStr := string(agd.FilterListIDBlockedService) + defer func() { + if err != nil { + metrics.FilterUpdatedStatus.WithLabelValues(fltIDStr).Set(0) + } + }() + + resp, err := f.loadIndex(ctx) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } + + services, err := resp.toInternal(ctx, f.errColl, cacheSize, useCache) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } + + count := 0 + for _, s := range services { + count += s.RulesCount() + } + + metrics.FilterRulesTotal.WithLabelValues(fltIDStr).Set(float64(count)) + metrics.FilterUpdatedTime.WithLabelValues(fltIDStr).SetToCurrentTime() + metrics.FilterUpdatedStatus.WithLabelValues(fltIDStr).Set(1) + + f.mu.Lock() + defer f.mu.Unlock() + + f.services = services + + return nil +} + +// loadIndex fetches, decodes, and returns the blocked service index data. +func (f *Filter) loadIndex(ctx context.Context) (resp *indexResp, err error) { + defer func() { err = errors.Annotate(err, "loading blocked service index from %q: %w", f.url) }() + + httpResp, err := f.http.Get(ctx, f.url) + if err != nil { + return nil, fmt.Errorf("requesting: %w", err) + } + defer func() { err = errors.WithDeferred(err, httpResp.Body.Close()) }() + + err = agdhttp.CheckStatus(httpResp, http.StatusOK) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return nil, err + } + + resp = &indexResp{} + err = json.NewDecoder(httpResp.Body).Decode(resp) + if err != nil { + return nil, agdhttp.WrapServerError(fmt.Errorf("decoding: %w", err), httpResp) + } + + log.Debug("service filter: loaded index with %d blocked services", len(resp.BlockedServices)) + + return resp, nil +} diff --git a/internal/filter/internal/serviceblock/serviceblock_test.go b/internal/filter/internal/serviceblock/serviceblock_test.go new file mode 100644 index 0000000..e93d24b --- /dev/null +++ b/internal/filter/internal/serviceblock/serviceblock_test.go @@ -0,0 +1,74 @@ +package serviceblock_test + +import ( + "context" + "net/http" + "testing" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/filtertest" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/serviceblock" + "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Common blocked service IDs for tests. +const ( + testSvcID1 agd.BlockedServiceID = "svc_1" + testSvcID2 agd.BlockedServiceID = "svc_2" +) + +// testData is a sample of a service index response. +// +// See https://github.com/atropnikov/HostlistsRegistry/blob/main/assets/services.json. +const testData string = `{ + "blocked_services": [ + { + "id": "` + string(testSvcID1) + `", + "name": "Service 1", + "rules": [ + "||service-1.example^" + ] + }, + { + "id": "` + string(testSvcID2) + `", + "name": "Service 2", + "rules": [ + "||service-2.example^" + ] + } + ] +}` + +func TestFilter(t *testing.T) { + reqCh := make(chan struct{}, 1) + _, srvURL := filtertest.PrepareRefreshable(t, reqCh, testData, http.StatusOK) + + errColl := &agdtest.ErrorCollector{ + OnCollect: func(ctx context.Context, err error) { + panic("not implemented") + }, + } + + f := serviceblock.New(srvURL, errColl) + + ctx := context.Background() + err := f.Refresh(ctx, 0, false) + require.NoError(t, err) + + testutil.RequireReceive(t, reqCh, filtertest.Timeout) + + svcIDs := []agd.BlockedServiceID{testSvcID1, testSvcID2} + rls := f.RuleLists(ctx, svcIDs) + require.Len(t, rls, 2) + + gotFltIDs := make([]agd.FilterListID, 2) + gotSvcIDs := make([]agd.BlockedServiceID, 2) + gotFltIDs[0], gotSvcIDs[0] = rls[0].ID() + gotFltIDs[1], gotSvcIDs[1] = rls[1].ID() + assert.Equal(t, agd.FilterListIDBlockedService, gotFltIDs[0]) + assert.Equal(t, agd.FilterListIDBlockedService, gotFltIDs[1]) + assert.ElementsMatch(t, svcIDs, gotSvcIDs) +} diff --git a/internal/filter/refrfilter_internal_test.go b/internal/filter/refrfilter_internal_test.go deleted file mode 100644 index 8f33eb2..0000000 --- a/internal/filter/refrfilter_internal_test.go +++ /dev/null @@ -1,186 +0,0 @@ -package filter - -import ( - "context" - "io" - "net/http" - "net/http/httptest" - "os" - "testing" - "time" - - "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" - "github.com/AdguardTeam/golibs/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestRefreshableFilter_RefreshFromFile(t *testing.T) { - dir := t.TempDir() - f, err := os.CreateTemp(dir, t.Name()) - require.NoError(t, err) - - const defaultText = "||example.com\n" - _, err = io.WriteString(f, defaultText) - require.NoError(t, err) - - cachePath := f.Name() - - testCases := []struct { - name string - cachePath string - wantText string - staleness time.Duration - acceptStale bool - }{{ - name: "no_file", - cachePath: "does_not_exist", - wantText: "", - staleness: 0, - acceptStale: true, - }, { - name: "file", - cachePath: cachePath, - wantText: defaultText, - staleness: 0, - acceptStale: true, - }, { - name: "file_stale", - cachePath: cachePath, - wantText: "", - staleness: -1 * time.Second, - acceptStale: false, - }, { - name: "file_stale_accept", - cachePath: cachePath, - wantText: defaultText, - staleness: -1 * time.Second, - acceptStale: true, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - f := &refreshableFilter{ - http: nil, - url: nil, - id: "test_filter", - cachePath: tc.cachePath, - typ: "test filter", - staleness: tc.staleness, - } - - var text string - text, err = f.refreshFromFile(tc.acceptStale) - require.NoError(t, err) - - assert.Equal(t, tc.wantText, text) - }) - } -} - -func TestRefreshableFilter_RefreshFromURL(t *testing.T) { - const defaultText = "||example.com\n" - - codeCh := make(chan int, 1) - textCh := make(chan string, 1) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - pt := testutil.PanicT{} - - w.WriteHeader(<-codeCh) - - _, err := io.WriteString(w, <-textCh) - require.NoError(pt, err) - })) - t.Cleanup(srv.Close) - - u, err := agdhttp.ParseHTTPURL(srv.URL) - require.NoError(t, err) - - httpCli := agdhttp.NewClient(&agdhttp.ClientConfig{ - Timeout: testTimeout, - }) - - dir := t.TempDir() - f, err := os.CreateTemp(dir, t.Name()) - require.NoError(t, err) - - _, err = io.WriteString(f, defaultText) - require.NoError(t, err) - - cachePath := f.Name() - - testCases := []struct { - name string - cachePath string - text string - wantText string - wantErrMsg string - timeout time.Duration - code int - expectReq bool - }{{ - name: "success", - cachePath: cachePath, - text: defaultText, - wantText: defaultText, - wantErrMsg: "", - timeout: testTimeout, - code: http.StatusOK, - expectReq: true, - }, { - name: "not_found", - cachePath: cachePath, - text: defaultText, - wantText: "", - wantErrMsg: `server "": status code error: expected 200, got 404`, - timeout: testTimeout, - code: http.StatusNotFound, - expectReq: true, - }, { - name: "timeout", - cachePath: cachePath, - text: defaultText, - wantText: "", - wantErrMsg: `requesting: Get "` + u.String() + `": context deadline exceeded`, - timeout: 0, - code: http.StatusOK, - // Context deadline errors are returned before any actual HTTP - // requesting happens. - expectReq: false, - }, { - name: "empty", - cachePath: cachePath, - text: "", - wantText: "", - wantErrMsg: `server "": empty text, not resetting`, - timeout: testTimeout, - code: http.StatusOK, - expectReq: true, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - f := &refreshableFilter{ - http: httpCli, - url: u, - id: "test_filter", - cachePath: tc.cachePath, - typ: "test filter", - staleness: testTimeout, - } - - if tc.expectReq { - codeCh <- tc.code - textCh <- tc.text - } - - ctx, cancel := context.WithTimeout(context.Background(), tc.timeout) - defer cancel() - - var text string - text, err = f.refreshFromURL(ctx) - assert.Equal(t, tc.wantText, text) - testutil.AssertErrorMsg(t, tc.wantErrMsg, err) - }) - } -} diff --git a/internal/filter/rulelist.go b/internal/filter/rulelist.go deleted file mode 100644 index ba67be2..0000000 --- a/internal/filter/rulelist.go +++ /dev/null @@ -1,239 +0,0 @@ -package filter - -import ( - "context" - "fmt" - "math/rand" - "net/netip" - "path/filepath" - "sync" - - "github.com/AdguardTeam/AdGuardDNS/internal/agd" - "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" - "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" - "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/resultcache" - "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/urlfilter" - "github.com/AdguardTeam/urlfilter/filterlist" -) - -// Rule-list filter - -// ruleListFilter is a DNS request and response filter based on filter rule -// lists. -// -// TODO(a.garipov): Consider adding a separate version that uses a single engine -// for multiple rule lists and using it to optimize the filtering using default -// filtering groups. -type ruleListFilter struct { - // mu protects engine and ruleList. - mu *sync.RWMutex - - // refr contains data for refreshing the filter. - refr *refreshableFilter - - // engine is the DNS filtering engine. - engine *urlfilter.DNSEngine - - // dnsResCache contains cached results of filtering. - // - // TODO(ameshkov): Add metrics for these caches. - dnsResCache *resultcache.Cache[*urlfilter.DNSResult] - - // ruleList is the filtering rule ruleList used by the engine. - // - // TODO(a.garipov): Consider making engines in module urlfilter closeable - // and remove this crutch. - ruleList filterlist.RuleList - - // subID is the additional identifier for custom rule lists and blocked - // service lists. If id is [agd.FilterListIDBlockedService], it contains an - // [agd.BlockedServiceID], and if id is [agd.FilterListIDCustom], it - // contains an [agd.ProfileID]. - subID string - - // urlFilterID is the synthetic integer identifier for the urlfilter engine. - // - // TODO(a.garipov): Change the type to a string in module urlfilter and - // remove this crutch. - urlFilterID int -} - -// newRuleListFilter returns a new DNS request and response filter based on the -// provided rule list. l must be non-nil. The initial refresh should be called -// explicitly if necessary. -func newRuleListFilter( - l *agd.FilterList, - fileCacheDir string, - cacheSize int, - useCache bool, -) (flt *ruleListFilter) { - flt = &ruleListFilter{ - mu: &sync.RWMutex{}, - refr: &refreshableFilter{ - http: agdhttp.NewClient(&agdhttp.ClientConfig{ - Timeout: defaultFilterRefreshTimeout, - }), - url: l.URL, - id: l.ID, - cachePath: filepath.Join(fileCacheDir, string(l.ID)), - typ: "rule list", - staleness: l.RefreshIvl, - }, - urlFilterID: newURLFilterID(), - } - - if useCache { - flt.dnsResCache = resultcache.New[*urlfilter.DNSResult](cacheSize) - } - - // Do not set this in the literal above, since flt is nil there. - flt.refr.resetRules = flt.resetRules - - return flt -} - -// newRuleListFltFromStr returns a new DNS request and response filter using the -// provided rule text and ID. -func newRuleListFltFromStr( - text string, - id agd.FilterListID, - subID string, - cacheSize int, - useCache bool, -) (flt *ruleListFilter, err error) { - flt = &ruleListFilter{ - mu: &sync.RWMutex{}, - refr: &refreshableFilter{ - id: id, - typ: "rule list", - }, - subID: subID, - urlFilterID: newURLFilterID(), - } - - if useCache { - flt.dnsResCache = resultcache.New[*urlfilter.DNSResult](cacheSize) - } - - err = flt.resetRules(text) - if err != nil { - return nil, err - } - - return flt, nil -} - -// newURLFilterID returns a new random ID for the urlfilter DNS engine to use. -func newURLFilterID() (id int) { - // #nosec G404 -- Do not use cryptographically random ID generation, since - // these are only used in one place, compFilter.filterMsg, and are not used - // in any security-sensitive context. - // - // Despite the fact that the type of integer filter list IDs in module - // urlfilter is int, the module actually assumes that the ID is - // a non-negative integer, or at least not a largely negative one. - // Otherwise, some of its low-level optimizations seem to break. - return int(rand.Int31()) -} - -// dnsResult returns the result of applying the urlfilter DNS filtering engine. -// If the request is not filtered, dnsResult returns nil. -func (f *ruleListFilter) dnsResult( - cliIP netip.Addr, - cliName string, - host string, - rrType dnsmsg.RRType, - isAns bool, -) (dr *urlfilter.DNSResult) { - var ok bool - var cacheKey resultcache.Key - - // Don't waste resources on computing the cache key if the cache is not - // enabled. - useCache := f.dnsResCache != nil - if useCache { - cacheKey = resultcache.DefaultKey(host, rrType, isAns) - dr, ok = f.dnsResCache.Get(cacheKey) - if ok { - return dr - } - } - - dnsReq := &urlfilter.DNSRequest{ - Hostname: host, - // TODO(a.garipov): Make this a net.IP in module urlfilter. - ClientIP: cliIP.String(), - ClientName: cliName, - DNSType: rrType, - Answer: isAns, - } - - f.mu.RLock() - defer f.mu.RUnlock() - - dr, ok = f.engine.MatchRequest(dnsReq) - if !ok && len(dr.NetworkRules) == 0 { - dr = nil - } - - if useCache { - f.dnsResCache.Set(cacheKey, dr) - } - - return dr -} - -// Close implements the [io.Closer] interface for *ruleListFilter. -func (f *ruleListFilter) Close() (err error) { - f.mu.Lock() - defer f.mu.Unlock() - - if err = f.ruleList.Close(); err != nil { - return fmt.Errorf("closing rule list %q: %w", f.id(), err) - } - - return nil -} - -// id returns the ID of the rule list. -func (f *ruleListFilter) id() (fltID agd.FilterListID) { - return f.refr.id -} - -// refresh reloads the rule list data. If acceptStale is true, do not try to -// load the list from its URL when there is already a file in the cache -// directory, regardless of its staleness. -func (f *ruleListFilter) refresh(ctx context.Context, acceptStale bool) (err error) { - return f.refr.refresh(ctx, acceptStale) -} - -// resetRules resets the filtering rules. -func (f *ruleListFilter) resetRules(text string) (err error) { - // TODO(a.garipov): Add filterlist.BytesRuleList. - strList := &filterlist.StringRuleList{ - ID: f.urlFilterID, - RulesText: text, - IgnoreCosmetic: true, - } - - s, err := filterlist.NewRuleStorage([]filterlist.RuleList{strList}) - if err != nil { - return fmt.Errorf("creating list storage: %w", err) - } - - f.mu.Lock() - defer f.mu.Unlock() - - f.dnsResCache.Clear() - f.ruleList = strList - f.engine = urlfilter.NewDNSEngine(s) - - if f.subID != "" { - log.Info("filter %s/%s: reset %d rules", f.id(), f.subID, f.engine.RulesCount) - } else { - log.Info("filter %s: reset %d rules", f.id(), f.engine.RulesCount) - } - - return nil -} diff --git a/internal/filter/rulelist_internal_test.go b/internal/filter/rulelist_internal_test.go deleted file mode 100644 index 9347a93..0000000 --- a/internal/filter/rulelist_internal_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package filter - -import ( - "testing" - - "github.com/miekg/dns" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestRuleListFilter_dnsResult_cache(t *testing.T) { - rl, err := newRuleListFltFromStr("||example.com^", "fl1", "", 100, true) - require.NoError(t, err) - - t.Run("blocked", func(t *testing.T) { - dr := rl.dnsResult(testRemoteIP, "", testReqHost, dns.TypeA, false) - require.NotNil(t, dr) - - assert.Len(t, dr.NetworkRules, 1) - - cachedDR := rl.dnsResult(testRemoteIP, "", testReqHost, dns.TypeA, false) - require.NotNil(t, cachedDR) - - assert.Same(t, dr, cachedDR) - }) - - t.Run("none", func(t *testing.T) { - const otherHost = "other.example" - - dr := rl.dnsResult(testRemoteIP, "", otherHost, dns.TypeA, false) - assert.Nil(t, dr) - - cachedDR := rl.dnsResult(testRemoteIP, "", otherHost, dns.TypeA, false) - assert.Nil(t, cachedDR) - }) -} diff --git a/internal/filter/rulelist_test.go b/internal/filter/rulelist_test.go deleted file mode 100644 index a5217e1..0000000 --- a/internal/filter/rulelist_test.go +++ /dev/null @@ -1,404 +0,0 @@ -package filter_test - -import ( - "context" - "testing" - - "github.com/AdguardTeam/AdGuardDNS/internal/agd" - "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" - "github.com/AdguardTeam/AdGuardDNS/internal/filter" - "github.com/AdguardTeam/golibs/testutil" - "github.com/miekg/dns" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// TODO(a.garipov): Try to turn these into table-driven tests. - -func TestStorage_FilterFromContext_ruleList_request(t *testing.T) { - c := prepareConf(t) - - c.ErrColl = &agdtest.ErrorCollector{ - OnCollect: func(_ context.Context, err error) { panic("not implemented") }, - } - - s, err := filter.NewDefaultStorage(c) - require.NoError(t, err) - - g := &agd.FilteringGroup{ - ID: "default", - RuleListIDs: []agd.FilterListID{testFilterID}, - RuleListsEnabled: true, - } - - p := &agd.Profile{ - RuleListIDs: []agd.FilterListID{testFilterID}, - FilteringEnabled: true, - RuleListsEnabled: true, - } - - t.Run("blocked", func(t *testing.T) { - req := &dns.Msg{ - Question: []dns.Question{{ - Name: blockedFQDN, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }}, - } - - ri := newReqInfo(g, nil, blockedHost, clientIP, dns.TypeA) - ctx := agd.ContextWithRequestInfo(context.Background(), ri) - - f := s.FilterFromContext(ctx, ri) - require.NotNil(t, f) - testutil.CleanupAndRequireSuccess(t, f.Close) - - var r filter.Result - r, err = f.FilterRequest(ctx, req, ri) - require.NoError(t, err) - - rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r) - - assert.Contains(t, rb.Rule, blockedHost) - assert.Equal(t, rb.List, testFilterID) - }) - - t.Run("allowed", func(t *testing.T) { - req := &dns.Msg{ - Question: []dns.Question{{ - Name: allowedFQDN, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }}, - } - - ri := newReqInfo(g, nil, allowedHost, clientIP, dns.TypeA) - ctx := agd.ContextWithRequestInfo(context.Background(), ri) - - f := s.FilterFromContext(ctx, ri) - require.NotNil(t, f) - testutil.CleanupAndRequireSuccess(t, f.Close) - - var r filter.Result - r, err = f.FilterRequest(ctx, req, ri) - require.NoError(t, err) - - ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r) - - assert.Contains(t, ra.Rule, allowedHost) - assert.Equal(t, ra.List, testFilterID) - }) - - t.Run("blocked_client", func(t *testing.T) { - req := &dns.Msg{ - Question: []dns.Question{{ - Name: blockedClientFQDN, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }}, - } - - ri := newReqInfo(g, nil, blockedClientHost, clientIP, dns.TypeA) - ctx := agd.ContextWithRequestInfo(context.Background(), ri) - - f := s.FilterFromContext(ctx, ri) - require.NotNil(t, f) - testutil.CleanupAndRequireSuccess(t, f.Close) - - var r filter.Result - r, err = f.FilterRequest(ctx, req, ri) - require.NoError(t, err) - - rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r) - - assert.Contains(t, rb.Rule, blockedClientHost) - assert.Equal(t, rb.List, testFilterID) - }) - - t.Run("allowed_client", func(t *testing.T) { - req := &dns.Msg{ - Question: []dns.Question{{ - Name: allowedClientFQDN, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }}, - } - - ri := newReqInfo(g, nil, allowedClientHost, clientIP, dns.TypeA) - ctx := agd.ContextWithRequestInfo(context.Background(), ri) - - f := s.FilterFromContext(ctx, ri) - require.NotNil(t, f) - testutil.CleanupAndRequireSuccess(t, f.Close) - - var r filter.Result - r, err = f.FilterRequest(ctx, req, ri) - require.NoError(t, err) - - ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r) - - assert.Contains(t, ra.Rule, allowedClientHost) - assert.Equal(t, ra.List, testFilterID) - }) - - t.Run("blocked_device", func(t *testing.T) { - req := &dns.Msg{ - Question: []dns.Question{{ - Name: blockedDeviceFQDN, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }}, - } - - ri := newReqInfo(g, p, blockedDeviceHost, deviceIP, dns.TypeA) - ctx := agd.ContextWithRequestInfo(context.Background(), ri) - - f := s.FilterFromContext(ctx, ri) - require.NotNil(t, f) - testutil.CleanupAndRequireSuccess(t, f.Close) - - var r filter.Result - r, err = f.FilterRequest(ctx, req, ri) - require.NoError(t, err) - - rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r) - - assert.Contains(t, rb.Rule, blockedDeviceHost) - assert.Equal(t, rb.List, testFilterID) - }) - - t.Run("allowed_device", func(t *testing.T) { - req := &dns.Msg{ - Question: []dns.Question{{ - Name: allowedDeviceFQDN, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }}, - } - - ri := newReqInfo(g, p, allowedDeviceHost, deviceIP, dns.TypeA) - ctx := agd.ContextWithRequestInfo(context.Background(), ri) - - f := s.FilterFromContext(ctx, ri) - require.NotNil(t, f) - testutil.CleanupAndRequireSuccess(t, f.Close) - - var r filter.Result - r, err = f.FilterRequest(ctx, req, ri) - require.NoError(t, err) - - ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r) - - assert.Contains(t, ra.Rule, allowedDeviceHost) - assert.Equal(t, ra.List, testFilterID) - }) - - t.Run("none", func(t *testing.T) { - req := &dns.Msg{ - Question: []dns.Question{{ - Name: otherNetFQDN, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }}, - } - - ri := newReqInfo(g, nil, otherNetHost, clientIP, dns.TypeA) - ctx := agd.ContextWithRequestInfo(context.Background(), ri) - - f := s.FilterFromContext(ctx, ri) - require.NotNil(t, f) - testutil.CleanupAndRequireSuccess(t, f.Close) - - var r filter.Result - r, err = f.FilterRequest(ctx, req, ri) - require.NoError(t, err) - - assert.Nil(t, r) - }) -} - -func TestStorage_FilterFromContext_ruleList_response(t *testing.T) { - c := prepareConf(t) - - c.ErrColl = &agdtest.ErrorCollector{ - OnCollect: func(_ context.Context, err error) { panic("not implemented") }, - } - - s, err := filter.NewDefaultStorage(c) - require.NoError(t, err) - - g := &agd.FilteringGroup{ - ID: "default", - RuleListIDs: []agd.FilterListID{testFilterID}, - RuleListsEnabled: true, - } - - ri := newReqInfo(g, nil, otherNetHost, clientIP, dns.TypeA) - ctx := agd.ContextWithRequestInfo(context.Background(), ri) - - f := s.FilterFromContext(ctx, ri) - require.NotNil(t, f) - testutil.CleanupAndRequireSuccess(t, f.Close) - - question := []dns.Question{{ - Name: otherNetFQDN, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }} - - t.Run("blocked_a", func(t *testing.T) { - resp := &dns.Msg{ - Question: question, - Answer: []dns.RR{&dns.A{ - A: blockedIP4, - }}, - } - - var r filter.Result - r, err = f.FilterResponse(ctx, resp, ri) - require.NoError(t, err) - - rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r) - - assert.Contains(t, rb.Rule, blockedIP4.String()) - assert.Equal(t, rb.List, testFilterID) - }) - - t.Run("allowed_a", func(t *testing.T) { - resp := &dns.Msg{ - Question: question, - Answer: []dns.RR{&dns.A{ - A: allowedIP4, - }}, - } - - var r filter.Result - r, err = f.FilterResponse(ctx, resp, ri) - require.NoError(t, err) - - ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r) - - assert.Contains(t, ra.Rule, allowedIP4.String()) - assert.Equal(t, ra.List, testFilterID) - }) - - t.Run("blocked_cname", func(t *testing.T) { - resp := &dns.Msg{ - Question: question, - Answer: []dns.RR{&dns.CNAME{ - Target: blockedFQDN, - }}, - } - - var r filter.Result - r, err = f.FilterResponse(ctx, resp, ri) - require.NoError(t, err) - - rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r) - - assert.Contains(t, rb.Rule, blockedHost) - assert.Equal(t, rb.List, testFilterID) - }) - - t.Run("allowed_cname", func(t *testing.T) { - resp := &dns.Msg{ - Question: question, - Answer: []dns.RR{&dns.CNAME{ - Target: allowedFQDN, - }}, - } - - var r filter.Result - r, err = f.FilterResponse(ctx, resp, ri) - require.NoError(t, err) - - ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r) - - assert.Contains(t, ra.Rule, allowedHost) - assert.Equal(t, ra.List, testFilterID) - }) - - t.Run("blocked_client", func(t *testing.T) { - resp := &dns.Msg{ - Question: question, - Answer: []dns.RR{&dns.CNAME{ - Target: blockedClientFQDN, - }}, - } - - var r filter.Result - r, err = f.FilterResponse(ctx, resp, ri) - require.NoError(t, err) - - rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r) - - assert.Contains(t, rb.Rule, blockedClientHost) - assert.Equal(t, rb.List, testFilterID) - }) - - t.Run("allowed_client", func(t *testing.T) { - req := &dns.Msg{ - Question: question, - Answer: []dns.RR{&dns.CNAME{ - Target: allowedClientFQDN, - }}, - } - - var r filter.Result - r, err = f.FilterResponse(ctx, req, ri) - require.NoError(t, err) - - ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r) - - assert.Contains(t, ra.Rule, allowedClientHost) - assert.Equal(t, ra.List, testFilterID) - }) - - t.Run("exception_cname", func(t *testing.T) { - req := &dns.Msg{ - Question: question, - Answer: []dns.RR{&dns.CNAME{ - Target: "cname.exception.", - }}, - } - - var r filter.Result - r, err = f.FilterResponse(ctx, req, ri) - require.NoError(t, err) - - assert.Nil(t, r) - }) - - t.Run("exception_cname_blocked", func(t *testing.T) { - req := &dns.Msg{ - Question: question, - Answer: []dns.RR{&dns.CNAME{ - Target: "cname.blocked.", - }}, - } - - var r filter.Result - r, err = f.FilterResponse(ctx, req, ri) - require.NoError(t, err) - - rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r) - - assert.Contains(t, rb.Rule, "cname.blocked") - assert.Equal(t, rb.List, testFilterID) - }) - - t.Run("none", func(t *testing.T) { - req := &dns.Msg{ - Question: question, - Answer: []dns.RR{&dns.CNAME{ - Target: otherOrgFQDN, - }}, - } - - var r filter.Result - r, err = f.FilterRequest(ctx, req, ri) - require.NoError(t, err) - - assert.Nil(t, r) - }) -} diff --git a/internal/filter/safebrowsing.go b/internal/filter/safebrowsing.go deleted file mode 100644 index 2087e0a..0000000 --- a/internal/filter/safebrowsing.go +++ /dev/null @@ -1,114 +0,0 @@ -package filter - -import ( - "context" - "encoding/hex" - "fmt" - "strings" - - "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage" - "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/stringutil" -) - -// Safe Browsing TXT Record Server - -// SafeBrowsingServer is a safe browsing server that responds to TXT DNS queries -// to known domains. -// -// TODO(a.garipov): Consider making an interface to simplify testing. -type SafeBrowsingServer struct { - generalHashes *hashstorage.Storage - adultBlockingHashes *hashstorage.Storage -} - -// NewSafeBrowsingServer returns a new safe browsing DNS server. -func NewSafeBrowsingServer(general, adultBlocking *hashstorage.Storage) (f *SafeBrowsingServer) { - return &SafeBrowsingServer{ - generalHashes: general, - adultBlockingHashes: adultBlocking, - } -} - -// Default safe browsing host suffixes. -// -// TODO(ameshkov): Consider making these configurable. -const ( - GeneralTXTSuffix = ".sb.dns.adguard.com" - AdultBlockingTXTSuffix = ".pc.dns.adguard.com" -) - -// Hashes returns the matched hashes if the host matched one of the domain names -// in srv. -// -// TODO(a.garipov): Use the context for logging etc. -func (srv *SafeBrowsingServer) Hashes( - _ context.Context, - host string, -) (hashes []string, matched bool, err error) { - // TODO(a.garipov): Remove this if SafeBrowsingServer becomes an interface. - if srv == nil { - return nil, false, nil - } - - var prefixesStr string - var strg *hashstorage.Storage - if strings.HasSuffix(host, GeneralTXTSuffix) { - prefixesStr = host[:len(host)-len(GeneralTXTSuffix)] - strg = srv.generalHashes - } else if strings.HasSuffix(host, AdultBlockingTXTSuffix) { - prefixesStr = host[:len(host)-len(AdultBlockingTXTSuffix)] - strg = srv.adultBlockingHashes - } else { - return nil, false, nil - } - - log.Debug("safe browsing txt srv: got prefixes string %q", prefixesStr) - - hashPrefixes, err := hashPrefixesFromStr(prefixesStr) - if err != nil { - return nil, false, err - } - - return strg.Hashes(hashPrefixes), true, nil -} - -// legacyPrefixEncLen is the encoded length of a legacy hash. -const legacyPrefixEncLen = 8 - -// hashPrefixesFromStr returns hash prefixes from a dot-separated string. -func hashPrefixesFromStr(prefixesStr string) (hashPrefixes []hashstorage.Prefix, err error) { - if prefixesStr == "" { - return nil, nil - } - - prefixSet := stringutil.NewSet() - prefixStrs := strings.Split(prefixesStr, ".") - for _, s := range prefixStrs { - if len(s) != hashstorage.PrefixEncLen { - // Some legacy clients send eight-character hashes instead of - // four-character ones. For now, remove the final four characters. - // - // TODO(a.garipov): Either remove this crutch or support such - // prefixes better. - if len(s) == legacyPrefixEncLen { - s = s[:hashstorage.PrefixEncLen] - } else { - return nil, fmt.Errorf("bad hash len for %q", s) - } - } - - prefixSet.Add(s) - } - - hashPrefixes = make([]hashstorage.Prefix, prefixSet.Len()) - prefixStrs = prefixSet.Values() - for i, s := range prefixStrs { - _, err = hex.Decode(hashPrefixes[i][:], []byte(s)) - if err != nil { - return nil, fmt.Errorf("bad hash encoding for %q", s) - } - } - - return hashPrefixes, nil -} diff --git a/internal/filter/safesearch.go b/internal/filter/safesearch.go deleted file mode 100644 index f691736..0000000 --- a/internal/filter/safesearch.go +++ /dev/null @@ -1,191 +0,0 @@ -package filter - -import ( - "context" - "fmt" - "net" - "time" - - "github.com/AdguardTeam/AdGuardDNS/internal/agd" - "github.com/AdguardTeam/AdGuardDNS/internal/agdnet" - "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" - "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/resultcache" - "github.com/AdguardTeam/AdGuardDNS/internal/optlog" - "github.com/AdguardTeam/golibs/netutil" - "github.com/AdguardTeam/urlfilter" - "github.com/miekg/dns" -) - -// Safe search - -// safeSearch is a filter that enforces safe search. -type safeSearch struct { - // resCache contains cached results. - resCache *resultcache.Cache[*ResultModified] - - // flt is used to filter requests. - flt *ruleListFilter - - // resolver resolves IP addresses. - resolver agdnet.Resolver - - // errColl is used to report rare errors. - errColl agd.ErrorCollector -} - -// safeSearchConfig contains configuration for the safe search filter. -type safeSearchConfig struct { - list *agd.FilterList - resolver agdnet.Resolver - errColl agd.ErrorCollector - cacheDir string - //lint:ignore U1000 TODO(a.garipov): Currently unused. See AGDNS-398. - ttl time.Duration - cacheSize int -} - -// newSafeSearch returns a new safe search filter. c must not be nil. The -// initial refresh should be called explicitly if necessary. -func newSafeSearch(c *safeSearchConfig) (f *safeSearch) { - return &safeSearch{ - resCache: resultcache.New[*ResultModified](c.cacheSize), - // Don't use the rule list cache, since safeSearch already has its own. - flt: newRuleListFilter(c.list, c.cacheDir, 0, false), - resolver: c.resolver, - errColl: c.errColl, - } -} - -// type check -var _ qtHostFilter = (*safeSearch)(nil) - -// filterReq implements the qtHostFilter interface for *safeSearch. It modifies -// the response if host matches f. -func (f *safeSearch) filterReq( - ctx context.Context, - ri *agd.RequestInfo, - req *dns.Msg, -) (r Result, err error) { - qt := ri.QType - fam := netutil.AddrFamilyFromRRType(qt) - if fam == netutil.AddrFamilyNone { - return nil, nil - } - - host := ri.Host - cacheKey := resultcache.DefaultKey(host, qt, false) - repHost, ok := f.safeSearchHost(host, qt) - if !ok { - optlog.Debug2("filter %s: host %q is not on the list", f.flt.id(), host) - - f.resCache.Set(cacheKey, nil) - - return nil, nil - } - - optlog.Debug2("filter %s: found host %q", f.flt.id(), repHost) - - rm, ok := f.resCache.Get(cacheKey) - if ok { - if rm == nil { - // Return nil explicitly instead of modifying CloneForReq to return - // nil if the result is nil to avoid a “non-nil nil” value. - return nil, nil - } - - return rm.CloneForReq(req), nil - } - - ctx, cancel := context.WithTimeout(ctx, defaultResolveTimeout) - defer cancel() - - var result *dns.Msg - ips, err := f.resolver.LookupIP(ctx, fam, repHost) - if err != nil { - agd.Collectf(ctx, f.errColl, "filter %s: resolving: %w", f.flt.id(), err) - - result = ri.Messages.NewMsgSERVFAIL(req) - } else { - result, err = ri.Messages.NewIPRespMsg(req, ips...) - if err != nil { - return nil, fmt.Errorf("filter %s: creating modified result: %w", f.flt.id(), err) - } - } - - rm = &ResultModified{ - Msg: result, - List: f.flt.id(), - Rule: agd.FilterRuleText(host), - } - - // Copy the result to make sure that modifications to the result message - // down the pipeline don't interfere with the cached value. - // - // See AGDNS-359. - f.resCache.Set(cacheKey, rm.Clone()) - - return rm, nil -} - -// safeSearchHost returns the replacement host for the given host and question -// type, if any. qt should be either [dns.TypeA] or [dns.TypeAAAA]. -func (f *safeSearch) safeSearchHost(host string, qt dnsmsg.RRType) (ssHost string, ok bool) { - dnsReq := &urlfilter.DNSRequest{ - Hostname: host, - DNSType: qt, - Answer: false, - } - - f.flt.mu.RLock() - defer f.flt.mu.RUnlock() - - // Omit matching the result since it's always false for rewrite rules. - dr, _ := f.flt.engine.MatchRequest(dnsReq) - if dr == nil { - return "", false - } - - for _, nr := range dr.DNSRewrites() { - drw := nr.DNSRewrite - if drw.RCode != dns.RcodeSuccess { - continue - } - - if nc := drw.NewCNAME; nc != "" { - return nc, true - } - - // All the rules in safe search rule lists are expected to have either - // A/AAAA or CNAME type. - switch drw.RRType { - case dns.TypeA, dns.TypeAAAA: - return drw.Value.(net.IP).String(), true - default: - continue - } - } - - return "", false -} - -// name implements the qtHostFilter interface for *safeSearch. -func (f *safeSearch) name() (n string) { - if f == nil || f.flt == nil { - return "" - } - - return string(f.flt.id()) -} - -// refresh reloads the rule list data. If acceptStale is true, and the cache -// file exists, the data is read from there regardless of its staleness. -func (f *safeSearch) refresh(ctx context.Context, acceptStale bool) (err error) { - err = f.flt.refresh(ctx, acceptStale) - if err != nil { - return err - } - - f.resCache.Clear() - - return nil -} diff --git a/internal/filter/safesearch_test.go b/internal/filter/safesearch_test.go deleted file mode 100644 index a295e81..0000000 --- a/internal/filter/safesearch_test.go +++ /dev/null @@ -1,137 +0,0 @@ -package filter_test - -import ( - "context" - "net" - "testing" - - "github.com/AdguardTeam/AdGuardDNS/internal/agd" - "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" - "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" - "github.com/AdguardTeam/AdGuardDNS/internal/filter" - "github.com/AdguardTeam/golibs/netutil" - "github.com/AdguardTeam/golibs/testutil" - "github.com/miekg/dns" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestStorage_FilterFromContext_safeSearch(t *testing.T) { - numLookupIP := 0 - resolver := &agdtest.Resolver{ - OnLookupIP: func( - _ context.Context, - fam netutil.AddrFamily, - _ string, - ) (ips []net.IP, err error) { - numLookupIP++ - - if fam == netutil.AddrFamilyIPv4 { - return []net.IP{safeSearchIPRespIP4}, nil - } - - return []net.IP{safeSearchIPRespIP6}, nil - }, - } - - c := prepareConf(t) - - c.ErrColl = &agdtest.ErrorCollector{ - OnCollect: func(_ context.Context, err error) { panic("not implemented") }, - } - - c.Resolver = resolver - - s, err := filter.NewDefaultStorage(c) - require.NoError(t, err) - - g := &agd.FilteringGroup{ - ID: "default", - ParentalEnabled: true, - GeneralSafeSearch: true, - } - - testCases := []struct { - name string - host string - wantIP net.IP - rrtype uint16 - wantLookups int - }{{ - name: "ip4", - host: safeSearchIPHost, - wantIP: safeSearchIPRespIP4, - rrtype: dns.TypeA, - wantLookups: 1, - }, { - name: "ip6", - host: safeSearchIPHost, - wantIP: safeSearchIPRespIP6, - rrtype: dns.TypeAAAA, - wantLookups: 1, - }, { - name: "host_ip4", - host: safeSearchHost, - wantIP: safeSearchIPRespIP4, - rrtype: dns.TypeA, - wantLookups: 1, - }, { - name: "host_ip6", - host: safeSearchHost, - wantIP: safeSearchIPRespIP6, - rrtype: dns.TypeAAAA, - wantLookups: 1, - }} - - for _, tc := range testCases { - numLookupIP = 0 - req := dnsservertest.CreateMessage(tc.host, tc.rrtype) - - t.Run(tc.name, func(t *testing.T) { - ri := newReqInfo(g, nil, tc.host, clientIP, tc.rrtype) - ctx := agd.ContextWithRequestInfo(context.Background(), ri) - - f := s.FilterFromContext(ctx, ri) - require.NotNil(t, f) - testutil.CleanupAndRequireSuccess(t, f.Close) - - var r filter.Result - r, err = f.FilterRequest(ctx, req, ri) - require.NoError(t, err) - - assert.Equal(t, tc.wantLookups, numLookupIP) - - rm, ok := r.(*filter.ResultModified) - require.True(t, ok) - - assert.Contains(t, rm.Rule, tc.host) - assert.Equal(t, rm.List, agd.FilterListIDGeneralSafeSearch) - - res := rm.Msg - require.NotNil(t, res) - - if tc.wantIP == nil { - assert.Nil(t, res.Answer) - - return - } - - require.Len(t, res.Answer, 1) - - switch tc.rrtype { - case dns.TypeA: - a, aok := res.Answer[0].(*dns.A) - require.True(t, aok) - - assert.Equal(t, tc.wantIP, a.A) - case dns.TypeAAAA: - aaaa, aaaaok := res.Answer[0].(*dns.AAAA) - require.True(t, aaaaok) - - assert.Equal(t, tc.wantIP, aaaa.AAAA) - default: - t.Fatalf("unexpected question type %d", tc.rrtype) - } - }) - } -} diff --git a/internal/filter/serviceblocker.go b/internal/filter/serviceblocker.go deleted file mode 100644 index ccb1f60..0000000 --- a/internal/filter/serviceblocker.go +++ /dev/null @@ -1,223 +0,0 @@ -package filter - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "net/url" - "strings" - "sync" - - "github.com/AdguardTeam/AdGuardDNS/internal/agd" - "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" - "github.com/AdguardTeam/AdGuardDNS/internal/metrics" - "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" - "github.com/prometheus/client_golang/prometheus" -) - -// Service Blocking Filter - -// serviceBlocker is a filter that blocks services based on the settings in -// profile. -// -// TODO(a.garipov): Add tests. -type serviceBlocker struct { - // url is the URL from which the services are fetched. - url *url.URL - - // http is the HTTP client used to refresh the filter. - http *agdhttp.Client - - // mu protects services. - mu *sync.RWMutex - - // services is an ID to filter mapping. - services serviceRuleLists - - // errColl used to collect non-critical and rare errors. - errColl agd.ErrorCollector -} - -// serviceRuleLists is convenient alias for a ID to filter mapping. -type serviceRuleLists = map[agd.BlockedServiceID]*ruleListFilter - -// newServiceBlocker returns a fully initialized service blocker. -func newServiceBlocker(indexURL *url.URL, errColl agd.ErrorCollector) (b *serviceBlocker) { - return &serviceBlocker{ - url: indexURL, - http: agdhttp.NewClient(&agdhttp.ClientConfig{ - Timeout: defaultFilterRefreshTimeout, - }), - mu: &sync.RWMutex{}, - errColl: errColl, - } -} - -// ruleLists returns the rule list filters for the given blocked service IDs. -// The order of the elements in rls is undefined. -func (b *serviceBlocker) ruleLists(ids []agd.BlockedServiceID) (rls []*ruleListFilter) { - if len(ids) == 0 { - return nil - } - - b.mu.RLock() - defer b.mu.RUnlock() - - for _, id := range ids { - rl := b.services[id] - if rl == nil { - log.Info("service filter: no service with id %q", id) - - continue - } - - rls = append(rls, rl) - } - - return rls -} - -// refresh loads new service data from the index URL. -func (b *serviceBlocker) refresh( - ctx context.Context, - cacheSize int, - useCache bool, -) (err error) { - // Report the services update to prometheus. - promLabels := prometheus.Labels{ - "filter": string(agd.FilterListIDBlockedService), - } - - defer func() { - if err != nil { - agd.Collectf(ctx, b.errColl, "refreshing blocked services: %w", err) - metrics.FilterUpdatedStatus.With(promLabels).Set(0) - } - }() - - resp, err := b.loadIndex(ctx) - if err != nil { - // Don't wrap the error, because it's informative enough as is. - return err - } - - services, err := resp.toInternal(cacheSize, useCache) - if err != nil { - // Don't wrap the error, because it's informative enough as is. - return err - } - - b.mu.Lock() - defer b.mu.Unlock() - - b.services = services - - count := 0 - for _, s := range services { - count += s.engine.RulesCount - } - metrics.FilterRulesTotal.With(promLabels).Set(float64(count)) - metrics.FilterUpdatedTime.With(promLabels).SetToCurrentTime() - metrics.FilterUpdatedStatus.With(promLabels).Set(1) - - return nil -} - -// loadIndex fetches, decodes, and returns the blocked service index data. -func (b *serviceBlocker) loadIndex(ctx context.Context) (resp *svcIndexResp, err error) { - defer func() { err = errors.Annotate(err, "loading blocked service index from %q: %w", b.url) }() - - httpResp, err := b.http.Get(ctx, b.url) - if err != nil { - return nil, fmt.Errorf("requesting: %w", err) - } - defer func() { err = errors.WithDeferred(err, httpResp.Body.Close()) }() - - err = agdhttp.CheckStatus(httpResp, http.StatusOK) - if err != nil { - // Don't wrap the error, because it's informative enough as is. - return nil, err - } - - resp = &svcIndexResp{} - err = json.NewDecoder(httpResp.Body).Decode(resp) - if err != nil { - return nil, agdhttp.WrapServerError( - fmt.Errorf("decoding: %w", err), - httpResp, - ) - } - - log.Debug("service filter: loaded index with %d blocked services", len(resp.BlockedServices)) - - return resp, nil -} - -// svcIndexResp is the struct for the JSON response from a blocked service index -// API. -type svcIndexResp struct { - BlockedServices []*svcIndexRespService `json:"blocked_services"` -} - -// toInternal converts the services from the index to serviceRuleLists. -func (r *svcIndexResp) toInternal( - cacheSize int, - useCache bool, -) (services serviceRuleLists, err error) { - l := len(r.BlockedServices) - if l == 0 { - return nil, nil - } - - services = make(serviceRuleLists, l) - errs := make([]error, len(r.BlockedServices)) - for i, svc := range r.BlockedServices { - var id agd.BlockedServiceID - id, err = agd.NewBlockedServiceID(svc.ID) - if err != nil { - errs[i] = fmt.Errorf("service at index %d: validating id: %w", i, err) - - continue - } - - if len(svc.Rules) == 0 { - log.Info("service filter: no rules for service with id %s", id) - - continue - } - - text := strings.Join(svc.Rules, "\n") - - var rl *ruleListFilter - rl, err = newRuleListFltFromStr( - text, - agd.FilterListIDBlockedService, - svc.ID, - cacheSize, - useCache, - ) - if err != nil { - errs[i] = fmt.Errorf("compiling %s: %w", svc.ID, err) - - continue - } - - services[id] = rl - } - - err = errors.Join(errs...) - if err != nil { - return nil, fmt.Errorf("converting blocked services: %w", err) - } - - return services, nil -} - -// svcIndexRespService is the struct for a filter from the JSON response from -// a blocked service index API. -type svcIndexRespService struct { - ID string `json:"id"` - Rules []string `json:"rules"` -} diff --git a/internal/filter/storage.go b/internal/filter/storage.go index 3a3b114..d4f9fe3 100644 --- a/internal/filter/storage.go +++ b/internal/filter/storage.go @@ -12,11 +12,17 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" "github.com/AdguardTeam/AdGuardDNS/internal/agdnet" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/composite" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/custom" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/rulelist" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/safesearch" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/serviceblock" "github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/bluele/gcache" - "github.com/prometheus/client_golang/prometheus" ) // Filter storage @@ -48,22 +54,22 @@ type DefaultStorage struct { http *agdhttp.Client // ruleLists are the filter list ID to a rule list filter map. - ruleLists map[agd.FilterListID]*ruleListFilter + ruleLists map[agd.FilterListID]*rulelist.Refreshable // services is the service blocking filter. - services *serviceBlocker + services *serviceblock.Filter // safeBrowsing is the general safe browsing filter. - safeBrowsing *HashPrefix + safeBrowsing *hashprefix.Filter // adultBlocking is the adult content blocking safe browsing filter. - adultBlocking *HashPrefix + adultBlocking *hashprefix.Filter // genSafeSearch is the general safe search filter. - genSafeSearch *safeSearch + genSafeSearch *safesearch.Filter // ytSafeSearch is the YouTube safe search filter. - ytSafeSearch *safeSearch + ytSafeSearch *safesearch.Filter // now returns the current time. now func() (t time.Time) @@ -73,7 +79,7 @@ type DefaultStorage struct { errColl agd.ErrorCollector // customFilters is the storage of custom filters for profiles. - customFilters *customFilters + customFilters *custom.Filters // cacheDir is the path to the directory where the cached filter files are // put. The directory must exist. @@ -84,7 +90,7 @@ type DefaultStorage struct { refreshIvl time.Duration // RuleListCacheSize defines the size of the LRU cache of rule-list - // filteirng results. + // filtering results. ruleListCacheSize int // useRuleListCache, if true, enables rule list cache. @@ -110,11 +116,11 @@ type DefaultStorageConfig struct { // SafeBrowsing is the configuration for the default safe browsing filter. // It must not be nil. - SafeBrowsing *HashPrefix + SafeBrowsing *hashprefix.Filter // AdultBlocking is the configuration for the adult content blocking safe // browsing filter. It must not be nil. - AdultBlocking *HashPrefix + AdultBlocking *hashprefix.Filter // Now is a function that returns current time. Now func() (now time.Time) @@ -142,7 +148,7 @@ type DefaultStorageConfig struct { SafeSearchCacheTTL time.Duration // RuleListCacheSize defines the size of the LRU cache of rule-list - // filteirng results. + // filtering results. RuleListCacheSize int // RefreshIvl is the refresh interval for this storage. It defines how @@ -155,49 +161,49 @@ type DefaultStorageConfig struct { // NewDefaultStorage returns a new filter storage. c must not be nil. func NewDefaultStorage(c *DefaultStorageConfig) (s *DefaultStorage, err error) { - genSafeSearch := newSafeSearch(&safeSearchConfig{ - resolver: c.Resolver, - errColl: c.ErrColl, - list: &agd.FilterList{ + genSafeSearch := safesearch.New(&safesearch.Config{ + List: &agd.FilterList{ URL: c.GeneralSafeSearchRulesURL, ID: agd.FilterListIDGeneralSafeSearch, RefreshIvl: c.RefreshIvl, }, - cacheDir: c.CacheDir, - ttl: c.SafeSearchCacheTTL, - cacheSize: c.SafeSearchCacheSize, + Resolver: c.Resolver, + ErrColl: c.ErrColl, + CacheDir: c.CacheDir, + CacheTTL: c.SafeSearchCacheTTL, + CacheSize: c.SafeSearchCacheSize, }) - ytSafeSearch := newSafeSearch(&safeSearchConfig{ - resolver: c.Resolver, - errColl: c.ErrColl, - list: &agd.FilterList{ + ytSafeSearch := safesearch.New(&safesearch.Config{ + List: &agd.FilterList{ URL: c.YoutubeSafeSearchRulesURL, ID: agd.FilterListIDYoutubeSafeSearch, RefreshIvl: c.RefreshIvl, }, - cacheDir: c.CacheDir, - ttl: c.SafeSearchCacheTTL, - cacheSize: c.SafeSearchCacheSize, + Resolver: c.Resolver, + ErrColl: c.ErrColl, + CacheDir: c.CacheDir, + CacheTTL: c.SafeSearchCacheTTL, + CacheSize: c.SafeSearchCacheSize, }) s = &DefaultStorage{ mu: &sync.RWMutex{}, url: c.FilterIndexURL, http: agdhttp.NewClient(&agdhttp.ClientConfig{ - Timeout: defaultFilterRefreshTimeout, + Timeout: internal.DefaultFilterRefreshTimeout, }), - services: newServiceBlocker(c.BlockedServiceIndexURL, c.ErrColl), + services: serviceblock.New(c.BlockedServiceIndexURL, c.ErrColl), safeBrowsing: c.SafeBrowsing, adultBlocking: c.AdultBlocking, genSafeSearch: genSafeSearch, ytSafeSearch: ytSafeSearch, now: c.Now, errColl: c.ErrColl, - customFilters: &customFilters{ - cache: gcache.New(c.CustomFilterCacheSize).LRU().Build(), - errColl: c.ErrColl, - }, + customFilters: custom.New( + gcache.New(c.CustomFilterCacheSize).LRU().Build(), + c.ErrColl, + ), cacheDir: c.CacheDir, refreshIvl: c.RefreshIvl, ruleListCacheSize: c.RuleListCacheSize, @@ -221,71 +227,63 @@ func (s *DefaultStorage) FilterFromContext(ctx context.Context, ri *agd.RequestI return s.filterForProfile(ctx, ri) } - flt := &compFilter{} + c := &composite.Config{} g := ri.FilteringGroup if g.RuleListsEnabled { - flt.ruleLists = append(flt.ruleLists, s.filters(g.RuleListIDs)...) + c.RuleLists = s.filters(g.RuleListIDs) } - flt.safeBrowsing, flt.adultBlocking = s.safeBrowsingForGroup(g) - flt.genSafeSearch, flt.ytSafeSearch = s.safeSearchForGroup(g) + c.SafeBrowsing, c.AdultBlocking = s.safeBrowsingForGroup(g) + c.GeneralSafeSearch, c.YouTubeSafeSearch = s.safeSearchForGroup(g) - return flt + return composite.New(c) } // filterForProfile returns a composite filter for profile. func (s *DefaultStorage) filterForProfile(ctx context.Context, ri *agd.RequestInfo) (f Interface) { - flt := &compFilter{} - p := ri.Profile if !p.FilteringEnabled { // According to the current requirements, this means that the profile // should receive no filtering at all. - return flt + return composite.New(nil) } d := ri.Device if d != nil && !d.FilteringEnabled { // According to the current requirements, this means that the device // should receive no filtering at all. - return flt + return composite.New(nil) } // Assume that if we have a profile then we also have a device. - flt.ruleLists = s.filters(p.RuleListIDs) - flt.ruleLists = s.customFilters.appendRuleLists(ctx, flt.ruleLists, p) + c := &composite.Config{} + c.RuleLists = s.filters(p.RuleListIDs) + c.Custom = s.customFilters.Get(ctx, p) pp := p.Parental parentalEnabled := pp != nil && pp.Enabled && s.pcBySchedule(pp.Schedule) - flt.ruleLists = append(flt.ruleLists, s.serviceFilters(p, parentalEnabled)...) + c.ServiceLists = s.serviceFilters(ctx, p, parentalEnabled) - flt.safeBrowsing, flt.adultBlocking = s.safeBrowsingForProfile(p, parentalEnabled) - flt.genSafeSearch, flt.ytSafeSearch = s.safeSearchForProfile(p, parentalEnabled) + c.SafeBrowsing, c.AdultBlocking = s.safeBrowsingForProfile(p, parentalEnabled) + c.GeneralSafeSearch, c.YouTubeSafeSearch = s.safeSearchForProfile(p, parentalEnabled) - return flt + return composite.New(c) } // serviceFilters returns the blocked service rule lists for the profile. -// -// TODO(a.garipov): Consider not using ruleListFilter for service filters. Due -// to performance reasons, it would be better to simply go through all enabled -// rules sequentially instead. Alternatively, rework the urlfilter.DNSEngine -// and make it use the sequential scan if the number of rules is less than some -// constant value. -// -// See AGDNS-342. func (s *DefaultStorage) serviceFilters( + ctx context.Context, p *agd.Profile, parentalEnabled bool, -) (rls []*ruleListFilter) { +) (rls []*rulelist.Immutable) { if !parentalEnabled || len(p.Parental.BlockedServices) == 0 { return nil } - return s.services.ruleLists(p.Parental.BlockedServices) + return s.services.RuleLists(ctx, p.Parental.BlockedServices) } // pcBySchedule returns true if the profile's schedule allows parental control @@ -304,7 +302,7 @@ func (s *DefaultStorage) pcBySchedule(sch *agd.ParentalProtectionSchedule) (ok b func (s *DefaultStorage) safeBrowsingForProfile( p *agd.Profile, parentalEnabled bool, -) (safeBrowsing, adultBlocking *HashPrefix) { +) (safeBrowsing, adultBlocking *hashprefix.Filter) { if p.SafeBrowsingEnabled { safeBrowsing = s.safeBrowsing } @@ -321,7 +319,7 @@ func (s *DefaultStorage) safeBrowsingForProfile( func (s *DefaultStorage) safeSearchForProfile( p *agd.Profile, parentalEnabled bool, -) (gen, yt *safeSearch) { +) (gen, yt *safesearch.Filter) { if !parentalEnabled { return nil, nil } @@ -341,7 +339,7 @@ func (s *DefaultStorage) safeSearchForProfile( // in the filtering group. g must not be nil. func (s *DefaultStorage) safeBrowsingForGroup( g *agd.FilteringGroup, -) (safeBrowsing, adultBlocking *HashPrefix) { +) (safeBrowsing, adultBlocking *hashprefix.Filter) { if g.SafeBrowsingEnabled { safeBrowsing = s.safeBrowsing } @@ -355,7 +353,7 @@ func (s *DefaultStorage) safeBrowsingForGroup( // safeSearchForGroup returns safe search filters based on the information in // the filtering group. g must not be nil. -func (s *DefaultStorage) safeSearchForGroup(g *agd.FilteringGroup) (gen, yt *safeSearch) { +func (s *DefaultStorage) safeSearchForGroup(g *agd.FilteringGroup) (gen, yt *safesearch.Filter) { if !g.ParentalEnabled { return nil, nil } @@ -372,7 +370,7 @@ func (s *DefaultStorage) safeSearchForGroup(g *agd.FilteringGroup) (gen, yt *saf } // filters returns all rule list filters with the given filtering rule list IDs. -func (s *DefaultStorage) filters(ids []agd.FilterListID) (rls []*ruleListFilter) { +func (s *DefaultStorage) filters(ids []agd.FilterListID) (rls []*rulelist.Refreshable) { if len(ids) == 0 { return nil } @@ -430,7 +428,7 @@ func (s *DefaultStorage) refresh(ctx context.Context, acceptStale bool) (err err log.Info("%s: got %d filter lists from index after validations", strgLogPrefix, len(fls)) - ruleLists := make(map[agd.FilterListID]*ruleListFilter, len(resp.Filters)) + ruleLists := make(map[agd.FilterListID]*rulelist.Refreshable, len(resp.Filters)) for _, fl := range fls { if _, ok := ruleLists[fl.ID]; ok { agd.Collectf(ctx, s.errColl, "%s: duplicated id %q", strgLogPrefix, fl.ID) @@ -438,23 +436,26 @@ func (s *DefaultStorage) refresh(ctx context.Context, acceptStale bool) (err err continue } - // TODO(a.garipov): Cache these. - promLabels := prometheus.Labels{"filter": string(fl.ID)} - - rl := newRuleListFilter(fl, s.cacheDir, s.ruleListCacheSize, s.useRuleListCache) - err = rl.refresh(ctx, acceptStale) + fltIDStr := string(fl.ID) + rl := rulelist.NewRefreshable( + fl, + s.cacheDir, + s.ruleListCacheSize, + s.useRuleListCache, + ) + err = rl.Refresh(ctx, acceptStale) if err == nil { ruleLists[fl.ID] = rl - metrics.FilterUpdatedStatus.With(promLabels).Set(1) - metrics.FilterUpdatedTime.With(promLabels).SetToCurrentTime() - metrics.FilterRulesTotal.With(promLabels).Set(float64(rl.engine.RulesCount)) + metrics.FilterUpdatedStatus.WithLabelValues(fltIDStr).Set(1) + metrics.FilterUpdatedTime.WithLabelValues(fltIDStr).SetToCurrentTime() + metrics.FilterRulesTotal.WithLabelValues(fltIDStr).Set(float64(rl.RulesCount())) continue } agd.Collectf(ctx, s.errColl, "%s: refreshing %q: %w", strgLogPrefix, fl.ID, err) - metrics.FilterUpdatedStatus.With(promLabels).Set(0) + metrics.FilterUpdatedStatus.WithLabelValues(fltIDStr).Set(0) // If we can't get the new filter, and there is an old version of the // same rule list, use it. @@ -466,17 +467,20 @@ func (s *DefaultStorage) refresh(ctx context.Context, acceptStale bool) (err err log.Info("%s: got %d filter lists from index after compilation", strgLogPrefix, len(ruleLists)) - err = s.services.refresh(ctx, s.ruleListCacheSize, s.useRuleListCache) + err = s.services.Refresh(ctx, s.ruleListCacheSize, s.useRuleListCache) if err != nil { - return fmt.Errorf("refreshing service blocker: %w", err) + const errFmt = "refreshing blocked services: %w" + agd.Collectf(ctx, s.errColl, errFmt, err) + + return fmt.Errorf(errFmt, err) } - err = s.genSafeSearch.refresh(ctx, acceptStale) + err = s.genSafeSearch.Refresh(ctx, acceptStale) if err != nil { return fmt.Errorf("refreshing safe search: %w", err) } - err = s.ytSafeSearch.refresh(ctx, acceptStale) + err = s.ytSafeSearch.Refresh(ctx, acceptStale) if err != nil { return fmt.Errorf("refreshing safe search: %w", err) } @@ -515,17 +519,10 @@ func (s *DefaultStorage) loadIndex(ctx context.Context) (resp *filterIndexResp, } // setRuleLists replaces the storage's rule lists. -func (s *DefaultStorage) setRuleLists(ruleLists map[agd.FilterListID]*ruleListFilter) { +func (s *DefaultStorage) setRuleLists(ruleLists map[agd.FilterListID]*rulelist.Refreshable) { s.mu.Lock() defer s.mu.Unlock() - for id, rl := range s.ruleLists { - err := rl.Close() - if err != nil { - log.Error("%s: closing rule list %q: %s", strgLogPrefix, id, err) - } - } - s.ruleLists = ruleLists } diff --git a/internal/filter/storage_test.go b/internal/filter/storage_test.go index 9839abc..dcb4199 100644 --- a/internal/filter/storage_test.go +++ b/internal/filter/storage_test.go @@ -5,13 +5,16 @@ import ( "io" "net" "os" + "path/filepath" "testing" "time" "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" + "github.com/AdguardTeam/AdGuardDNS/internal/agdtime" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest" "github.com/AdguardTeam/AdGuardDNS/internal/filter" - "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage" + "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" @@ -62,7 +65,6 @@ func TestStorage_FilterFromContext(t *testing.T) { f := s.FilterFromContext(ctx, ri) require.NotNil(t, f) - testutil.CleanupAndRequireSuccess(t, f.Close) var r filter.Result r, err = f.FilterRequest(ctx, req, ri) @@ -88,7 +90,6 @@ func TestStorage_FilterFromContext(t *testing.T) { f := s.FilterFromContext(ctx, ri) require.NotNil(t, f) - testutil.CleanupAndRequireSuccess(t, f.Close) var r filter.Result r, err = f.FilterRequest(ctx, req, ri) @@ -114,7 +115,6 @@ func TestStorage_FilterFromContext(t *testing.T) { f := s.FilterFromContext(ctx, ri) require.NotNil(t, f) - testutil.CleanupAndRequireSuccess(t, f.Close) var r filter.Result r, err = f.FilterRequest(ctx, req, ri) @@ -147,12 +147,12 @@ func TestStorage_FilterFromContext_customAllow(t *testing.T) { _, err = io.WriteString(tmpFile, safeBrowsingHost+"\n") require.NoError(t, err) - hashes, err := hashstorage.New(safeBrowsingHost) + hashes, err := hashprefix.NewStorage(safeBrowsingHost) require.NoError(t, err) c := prepareConf(t) - c.SafeBrowsing, err = filter.NewHashPrefix(&filter.HashPrefixConfig{ + c.SafeBrowsing, err = hashprefix.NewFilter(&hashprefix.FilterConfig{ Hashes: hashes, ErrColl: errColl, Resolver: resolver, @@ -202,7 +202,6 @@ func TestStorage_FilterFromContext_customAllow(t *testing.T) { f := s.FilterFromContext(ctx, ri) require.NotNil(t, f) - testutil.CleanupAndRequireSuccess(t, f.Close) r, err := f.FilterRequest(ctx, req, ri) require.NoError(t, err) @@ -240,13 +239,13 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) { _, err = io.WriteString(tmpFile, safeBrowsingHost+"\n") require.NoError(t, err) - hashes, err := hashstorage.New(safeBrowsingHost) + hashes, err := hashprefix.NewStorage(safeBrowsingHost) require.NoError(t, err) c := prepareConf(t) // Use AdultBlocking, because SafeBrowsing is NOT affected by the schedule. - c.AdultBlocking, err = filter.NewHashPrefix(&filter.HashPrefixConfig{ + c.AdultBlocking, err = hashprefix.NewFilter(&hashprefix.FilterConfig{ Hashes: hashes, ErrColl: errColl, Resolver: resolver, @@ -272,7 +271,7 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) { // Set up our profile with the schedule that disables filtering at the // current moment. sch := &agd.ParentalProtectionSchedule{ - TimeZone: time.UTC, + TimeZone: agdtime.UTC(), Week: &agd.WeeklySchedule{ time.Sunday: agd.ZeroLengthDayRange(), time.Monday: agd.ZeroLengthDayRange(), @@ -320,7 +319,6 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) { // schedule. f := s.FilterFromContext(ctx, ri) require.NotNil(t, f) - testutil.CleanupAndRequireSuccess(t, f.Close) r, err := f.FilterRequest(ctx, req, ri) require.NoError(t, err) @@ -332,7 +330,6 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) { f = s.FilterFromContext(ctx, ri) require.NotNil(t, f) - testutil.CleanupAndRequireSuccess(t, f.Close) r, err = f.FilterRequest(ctx, req, ri) require.NoError(t, err) @@ -343,6 +340,578 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) { assert.Equal(t, rm.List, agd.FilterListIDAdultBlocking) } +func TestStorage_FilterFromContext_ruleList_request(t *testing.T) { + c := prepareConf(t) + + c.ErrColl = &agdtest.ErrorCollector{ + OnCollect: func(_ context.Context, err error) { panic("not implemented") }, + } + + s, err := filter.NewDefaultStorage(c) + require.NoError(t, err) + + g := &agd.FilteringGroup{ + ID: "default", + RuleListIDs: []agd.FilterListID{testFilterID}, + RuleListsEnabled: true, + } + + p := &agd.Profile{ + RuleListIDs: []agd.FilterListID{testFilterID}, + FilteringEnabled: true, + RuleListsEnabled: true, + } + + t.Run("blocked", func(t *testing.T) { + req := &dns.Msg{ + Question: []dns.Question{{ + Name: blockedFQDN, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }}, + } + + ri := newReqInfo(g, nil, blockedHost, clientIP, dns.TypeA) + ctx := agd.ContextWithRequestInfo(context.Background(), ri) + + f := s.FilterFromContext(ctx, ri) + require.NotNil(t, f) + + var r filter.Result + r, err = f.FilterRequest(ctx, req, ri) + require.NoError(t, err) + + rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r) + + assert.Contains(t, rb.Rule, blockedHost) + assert.Equal(t, rb.List, testFilterID) + }) + + t.Run("allowed", func(t *testing.T) { + req := &dns.Msg{ + Question: []dns.Question{{ + Name: allowedFQDN, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }}, + } + + ri := newReqInfo(g, nil, allowedHost, clientIP, dns.TypeA) + ctx := agd.ContextWithRequestInfo(context.Background(), ri) + + f := s.FilterFromContext(ctx, ri) + require.NotNil(t, f) + + var r filter.Result + r, err = f.FilterRequest(ctx, req, ri) + require.NoError(t, err) + + ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r) + + assert.Contains(t, ra.Rule, allowedHost) + assert.Equal(t, ra.List, testFilterID) + }) + + t.Run("blocked_client", func(t *testing.T) { + req := &dns.Msg{ + Question: []dns.Question{{ + Name: blockedClientFQDN, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }}, + } + + ri := newReqInfo(g, nil, blockedClientHost, clientIP, dns.TypeA) + ctx := agd.ContextWithRequestInfo(context.Background(), ri) + + f := s.FilterFromContext(ctx, ri) + require.NotNil(t, f) + + var r filter.Result + r, err = f.FilterRequest(ctx, req, ri) + require.NoError(t, err) + + rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r) + + assert.Contains(t, rb.Rule, blockedClientHost) + assert.Equal(t, rb.List, testFilterID) + }) + + t.Run("allowed_client", func(t *testing.T) { + req := &dns.Msg{ + Question: []dns.Question{{ + Name: allowedClientFQDN, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }}, + } + + ri := newReqInfo(g, nil, allowedClientHost, clientIP, dns.TypeA) + ctx := agd.ContextWithRequestInfo(context.Background(), ri) + + f := s.FilterFromContext(ctx, ri) + require.NotNil(t, f) + + var r filter.Result + r, err = f.FilterRequest(ctx, req, ri) + require.NoError(t, err) + + ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r) + + assert.Contains(t, ra.Rule, allowedClientHost) + assert.Equal(t, ra.List, testFilterID) + }) + + t.Run("blocked_device", func(t *testing.T) { + req := &dns.Msg{ + Question: []dns.Question{{ + Name: blockedDeviceFQDN, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }}, + } + + ri := newReqInfo(g, p, blockedDeviceHost, deviceIP, dns.TypeA) + ctx := agd.ContextWithRequestInfo(context.Background(), ri) + + f := s.FilterFromContext(ctx, ri) + require.NotNil(t, f) + + var r filter.Result + r, err = f.FilterRequest(ctx, req, ri) + require.NoError(t, err) + + rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r) + + assert.Contains(t, rb.Rule, blockedDeviceHost) + assert.Equal(t, rb.List, testFilterID) + }) + + t.Run("allowed_device", func(t *testing.T) { + req := &dns.Msg{ + Question: []dns.Question{{ + Name: allowedDeviceFQDN, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }}, + } + + ri := newReqInfo(g, p, allowedDeviceHost, deviceIP, dns.TypeA) + ctx := agd.ContextWithRequestInfo(context.Background(), ri) + + f := s.FilterFromContext(ctx, ri) + require.NotNil(t, f) + + var r filter.Result + r, err = f.FilterRequest(ctx, req, ri) + require.NoError(t, err) + + ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r) + + assert.Contains(t, ra.Rule, allowedDeviceHost) + assert.Equal(t, ra.List, testFilterID) + }) + + t.Run("none", func(t *testing.T) { + req := &dns.Msg{ + Question: []dns.Question{{ + Name: otherNetFQDN, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }}, + } + + ri := newReqInfo(g, nil, otherNetHost, clientIP, dns.TypeA) + ctx := agd.ContextWithRequestInfo(context.Background(), ri) + + f := s.FilterFromContext(ctx, ri) + require.NotNil(t, f) + + var r filter.Result + r, err = f.FilterRequest(ctx, req, ri) + require.NoError(t, err) + + assert.Nil(t, r) + }) +} + +func TestStorage_FilterFromContext_ruleList_response(t *testing.T) { + c := prepareConf(t) + + c.ErrColl = &agdtest.ErrorCollector{ + OnCollect: func(_ context.Context, err error) { panic("not implemented") }, + } + + s, err := filter.NewDefaultStorage(c) + require.NoError(t, err) + + g := &agd.FilteringGroup{ + ID: "default", + RuleListIDs: []agd.FilterListID{testFilterID}, + RuleListsEnabled: true, + } + + ri := newReqInfo(g, nil, otherNetHost, clientIP, dns.TypeA) + ctx := agd.ContextWithRequestInfo(context.Background(), ri) + + f := s.FilterFromContext(ctx, ri) + require.NotNil(t, f) + + question := []dns.Question{{ + Name: otherNetFQDN, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }} + + t.Run("blocked_a", func(t *testing.T) { + resp := &dns.Msg{ + Question: question, + Answer: []dns.RR{&dns.A{ + A: blockedIP4, + }}, + } + + var r filter.Result + r, err = f.FilterResponse(ctx, resp, ri) + require.NoError(t, err) + + rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r) + + assert.Contains(t, rb.Rule, blockedIP4.String()) + assert.Equal(t, rb.List, testFilterID) + }) + + t.Run("allowed_a", func(t *testing.T) { + resp := &dns.Msg{ + Question: question, + Answer: []dns.RR{&dns.A{ + A: allowedIP4, + }}, + } + + var r filter.Result + r, err = f.FilterResponse(ctx, resp, ri) + require.NoError(t, err) + + ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r) + + assert.Contains(t, ra.Rule, allowedIP4.String()) + assert.Equal(t, ra.List, testFilterID) + }) + + t.Run("blocked_cname", func(t *testing.T) { + resp := &dns.Msg{ + Question: question, + Answer: []dns.RR{&dns.CNAME{ + Target: blockedFQDN, + }}, + } + + var r filter.Result + r, err = f.FilterResponse(ctx, resp, ri) + require.NoError(t, err) + + rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r) + + assert.Contains(t, rb.Rule, blockedHost) + assert.Equal(t, rb.List, testFilterID) + }) + + t.Run("allowed_cname", func(t *testing.T) { + resp := &dns.Msg{ + Question: question, + Answer: []dns.RR{&dns.CNAME{ + Target: allowedFQDN, + }}, + } + + var r filter.Result + r, err = f.FilterResponse(ctx, resp, ri) + require.NoError(t, err) + + ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r) + + assert.Contains(t, ra.Rule, allowedHost) + assert.Equal(t, ra.List, testFilterID) + }) + + t.Run("blocked_client", func(t *testing.T) { + resp := &dns.Msg{ + Question: question, + Answer: []dns.RR{&dns.CNAME{ + Target: blockedClientFQDN, + }}, + } + + var r filter.Result + r, err = f.FilterResponse(ctx, resp, ri) + require.NoError(t, err) + + rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r) + + assert.Contains(t, rb.Rule, blockedClientHost) + assert.Equal(t, rb.List, testFilterID) + }) + + t.Run("allowed_client", func(t *testing.T) { + req := &dns.Msg{ + Question: question, + Answer: []dns.RR{&dns.CNAME{ + Target: allowedClientFQDN, + }}, + } + + var r filter.Result + r, err = f.FilterResponse(ctx, req, ri) + require.NoError(t, err) + + ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r) + + assert.Contains(t, ra.Rule, allowedClientHost) + assert.Equal(t, ra.List, testFilterID) + }) + + t.Run("exception_cname", func(t *testing.T) { + req := &dns.Msg{ + Question: question, + Answer: []dns.RR{&dns.CNAME{ + Target: "cname.exception.", + }}, + } + + var r filter.Result + r, err = f.FilterResponse(ctx, req, ri) + require.NoError(t, err) + + assert.Nil(t, r) + }) + + t.Run("exception_cname_blocked", func(t *testing.T) { + req := &dns.Msg{ + Question: question, + Answer: []dns.RR{&dns.CNAME{ + Target: "cname.blocked.", + }}, + } + + var r filter.Result + r, err = f.FilterResponse(ctx, req, ri) + require.NoError(t, err) + + rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r) + + assert.Contains(t, rb.Rule, "cname.blocked") + assert.Equal(t, rb.List, testFilterID) + }) + + t.Run("none", func(t *testing.T) { + req := &dns.Msg{ + Question: question, + Answer: []dns.RR{&dns.CNAME{ + Target: otherOrgFQDN, + }}, + } + + var r filter.Result + r, err = f.FilterRequest(ctx, req, ri) + require.NoError(t, err) + + assert.Nil(t, r) + }) +} + +func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) { + cacheDir := t.TempDir() + cachePath := filepath.Join(cacheDir, string(agd.FilterListIDSafeBrowsing)) + err := os.WriteFile(cachePath, []byte(safeBrowsingHost+"\n"), 0o644) + require.NoError(t, err) + + hashes, err := hashprefix.NewStorage("") + require.NoError(t, err) + + errColl := &agdtest.ErrorCollector{ + OnCollect: func(_ context.Context, err error) { + panic("not implemented") + }, + } + + resolver := &agdtest.Resolver{ + OnLookupIP: func( + _ context.Context, + _ netutil.AddrFamily, + _ string, + ) (ips []net.IP, err error) { + return []net.IP{safeBrowsingSafeIP4}, nil + }, + } + + c := prepareConf(t) + + c.SafeBrowsing, err = hashprefix.NewFilter(&hashprefix.FilterConfig{ + Hashes: hashes, + ErrColl: errColl, + Resolver: resolver, + ID: agd.FilterListIDSafeBrowsing, + CachePath: cachePath, + ReplacementHost: safeBrowsingSafeHost, + Staleness: 1 * time.Hour, + CacheTTL: 10 * time.Second, + CacheSize: 100, + }) + require.NoError(t, err) + + c.ErrColl = errColl + c.Resolver = resolver + + s, err := filter.NewDefaultStorage(c) + require.NoError(t, err) + + g := &agd.FilteringGroup{ + ID: "default", + RuleListIDs: []agd.FilterListID{}, + ParentalEnabled: true, + SafeBrowsingEnabled: true, + } + + // Test + + req := &dns.Msg{ + Question: []dns.Question{{ + Name: safeBrowsingSubFQDN, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }}, + } + + ri := newReqInfo(g, nil, safeBrowsingSubHost, clientIP, dns.TypeA) + ctx := agd.ContextWithRequestInfo(context.Background(), ri) + + f := s.FilterFromContext(ctx, ri) + require.NotNil(t, f) + + var r filter.Result + r, err = f.FilterRequest(ctx, req, ri) + require.NoError(t, err) + + rm := testutil.RequireTypeAssert[*filter.ResultModified](t, r) + + assert.Equal(t, rm.Rule, agd.FilterRuleText(safeBrowsingHost)) + assert.Equal(t, rm.List, agd.FilterListIDSafeBrowsing) +} + +func TestStorage_FilterFromContext_safeSearch(t *testing.T) { + numLookupIP := 0 + resolver := &agdtest.Resolver{ + OnLookupIP: func( + _ context.Context, + fam netutil.AddrFamily, + _ string, + ) (ips []net.IP, err error) { + numLookupIP++ + + if fam == netutil.AddrFamilyIPv4 { + return []net.IP{safeSearchIPRespIP4}, nil + } + + return []net.IP{safeSearchIPRespIP6}, nil + }, + } + + c := prepareConf(t) + + c.ErrColl = &agdtest.ErrorCollector{ + OnCollect: func(_ context.Context, err error) { panic("not implemented") }, + } + + c.Resolver = resolver + + s, err := filter.NewDefaultStorage(c) + require.NoError(t, err) + + g := &agd.FilteringGroup{ + ID: "default", + ParentalEnabled: true, + GeneralSafeSearch: true, + } + + testCases := []struct { + name string + host string + wantIP net.IP + rrtype uint16 + wantLookups int + }{{ + name: "ip4", + host: safeSearchIPHost, + wantIP: safeSearchIPRespIP4, + rrtype: dns.TypeA, + wantLookups: 1, + }, { + name: "ip6", + host: safeSearchIPHost, + wantIP: safeSearchIPRespIP6, + rrtype: dns.TypeAAAA, + wantLookups: 1, + }, { + name: "host_ip4", + host: safeSearchHost, + wantIP: safeSearchIPRespIP4, + rrtype: dns.TypeA, + wantLookups: 1, + }, { + name: "host_ip6", + host: safeSearchHost, + wantIP: safeSearchIPRespIP6, + rrtype: dns.TypeAAAA, + wantLookups: 1, + }} + + for _, tc := range testCases { + numLookupIP = 0 + req := dnsservertest.CreateMessage(tc.host, tc.rrtype) + + t.Run(tc.name, func(t *testing.T) { + ri := newReqInfo(g, nil, tc.host, clientIP, tc.rrtype) + ctx := agd.ContextWithRequestInfo(context.Background(), ri) + + f := s.FilterFromContext(ctx, ri) + require.NotNil(t, f) + + var r filter.Result + r, err = f.FilterRequest(ctx, req, ri) + require.NoError(t, err) + + assert.Equal(t, tc.wantLookups, numLookupIP) + + rm := testutil.RequireTypeAssert[*filter.ResultModified](t, r) + assert.Contains(t, rm.Rule, tc.host) + assert.Equal(t, rm.List, agd.FilterListIDGeneralSafeSearch) + + res := rm.Msg + require.NotNil(t, res) + + if tc.wantIP == nil { + assert.Nil(t, res.Answer) + + return + } + + require.Len(t, res.Answer, 1) + + switch ans := res.Answer[0]; ans := ans.(type) { + case *dns.A: + assert.Equal(t, tc.rrtype, ans.Hdr.Rrtype) + assert.Equal(t, tc.wantIP, ans.A) + case *dns.AAAA: + assert.Equal(t, tc.rrtype, ans.Hdr.Rrtype) + assert.Equal(t, tc.wantIP, ans.AAAA) + default: + t.Fatalf("unexpected answer type %T(%[1]v)", ans) + } + }) + } +} + var ( defaultStorageSink *filter.DefaultStorage errSink error diff --git a/internal/geoip/asntops.go b/internal/geoip/asntops.go index c5d9fcf..d69f7bf 100644 --- a/internal/geoip/asntops.go +++ b/internal/geoip/asntops.go @@ -11,16 +11,13 @@ var allTopASNs = map[agd.ASN]struct{}{ 812: {}, 852: {}, 1136: {}, - 1213: {}, 1221: {}, 1241: {}, 1257: {}, 1267: {}, 1547: {}, 1680: {}, - 1955: {}, - 2107: {}, - 2108: {}, + 1901: {}, 2110: {}, 2116: {}, 2119: {}, @@ -35,7 +32,6 @@ var allTopASNs = map[agd.ASN]struct{}{ 3209: {}, 3212: {}, 3215: {}, - 3216: {}, 3238: {}, 3243: {}, 3249: {}, @@ -54,7 +50,6 @@ var allTopASNs = map[agd.ASN]struct{}{ 3855: {}, 4007: {}, 4134: {}, - 4230: {}, 4609: {}, 4638: {}, 4648: {}, @@ -96,6 +91,7 @@ var allTopASNs = map[agd.ASN]struct{}{ 6306: {}, 6327: {}, 6400: {}, + 6535: {}, 6568: {}, 6639: {}, 6661: {}, @@ -135,16 +131,17 @@ var allTopASNs = map[agd.ASN]struct{}{ 8359: {}, 8374: {}, 8376: {}, + 8386: {}, 8400: {}, 8402: {}, 8412: {}, 8447: {}, 8452: {}, 8473: {}, - 8544: {}, 8551: {}, 8560: {}, 8585: {}, + 8632: {}, 8661: {}, 8680: {}, 8681: {}, @@ -208,7 +205,6 @@ var allTopASNs = map[agd.ASN]struct{}{ 11315: {}, 11427: {}, 11556: {}, - 11562: {}, 11594: {}, 11664: {}, 11816: {}, @@ -230,6 +226,7 @@ var allTopASNs = map[agd.ASN]struct{}{ 12709: {}, 12716: {}, 12735: {}, + 12741: {}, 12764: {}, 12793: {}, 12810: {}, @@ -244,8 +241,8 @@ var allTopASNs = map[agd.ASN]struct{}{ 12997: {}, 13036: {}, 13046: {}, - 13092: {}, 13122: {}, + 13124: {}, 13194: {}, 13280: {}, 13285: {}, @@ -255,6 +252,7 @@ var allTopASNs = map[agd.ASN]struct{}{ 13999: {}, 14061: {}, 14080: {}, + 14117: {}, 14434: {}, 14522: {}, 14638: {}, @@ -272,6 +270,7 @@ var allTopASNs = map[agd.ASN]struct{}{ 15480: {}, 15502: {}, 15557: {}, + 15659: {}, 15704: {}, 15706: {}, 15735: {}, @@ -283,6 +282,7 @@ var allTopASNs = map[agd.ASN]struct{}{ 15958: {}, 15962: {}, 15964: {}, + 15994: {}, 16010: {}, 16019: {}, 16028: {}, @@ -291,9 +291,9 @@ var allTopASNs = map[agd.ASN]struct{}{ 16135: {}, 16232: {}, 16276: {}, + 16322: {}, 16345: {}, 16437: {}, - 16509: {}, 16637: {}, 16705: {}, 17072: {}, @@ -330,12 +330,10 @@ var allTopASNs = map[agd.ASN]struct{}{ 20294: {}, 20473: {}, 20634: {}, - 20661: {}, 20776: {}, 20845: {}, 20875: {}, 20910: {}, - 20963: {}, 20978: {}, 21003: {}, 21183: {}, @@ -350,6 +348,7 @@ var allTopASNs = map[agd.ASN]struct{}{ 21450: {}, 21497: {}, 21575: {}, + 21744: {}, 21826: {}, 21928: {}, 21996: {}, @@ -357,7 +356,6 @@ var allTopASNs = map[agd.ASN]struct{}{ 22069: {}, 22085: {}, 22351: {}, - 22581: {}, 22724: {}, 22773: {}, 22927: {}, @@ -371,11 +369,11 @@ var allTopASNs = map[agd.ASN]struct{}{ 23693: {}, 23752: {}, 23889: {}, + 23917: {}, 23955: {}, 23969: {}, 24158: {}, 24203: {}, - 24337: {}, 24378: {}, 24389: {}, 24432: {}, @@ -386,6 +384,7 @@ var allTopASNs = map[agd.ASN]struct{}{ 24722: {}, 24757: {}, 24835: {}, + 24852: {}, 24921: {}, 24940: {}, 25019: {}, @@ -419,6 +418,7 @@ var allTopASNs = map[agd.ASN]struct{}{ 27781: {}, 27800: {}, 27831: {}, + 27839: {}, 27882: {}, 27884: {}, 27895: {}, @@ -449,6 +449,7 @@ var allTopASNs = map[agd.ASN]struct{}{ 29247: {}, 29256: {}, 29310: {}, + 29314: {}, 29355: {}, 29357: {}, 29447: {}, @@ -463,16 +464,15 @@ var allTopASNs = map[agd.ASN]struct{}{ 29975: {}, 30689: {}, 30722: {}, - 30844: {}, 30873: {}, 30969: {}, 30985: {}, 30986: {}, + 30987: {}, 30990: {}, 30992: {}, 30999: {}, 31012: {}, - 31027: {}, 31037: {}, 31042: {}, 31122: {}, @@ -483,12 +483,11 @@ var allTopASNs = map[agd.ASN]struct{}{ 31213: {}, 31224: {}, 31252: {}, - 31404: {}, 31452: {}, + 31549: {}, 31615: {}, 31721: {}, 32020: {}, - 32189: {}, 33363: {}, 33392: {}, 33567: {}, @@ -513,14 +512,12 @@ var allTopASNs = map[agd.ASN]struct{}{ 35228: {}, 35432: {}, 35444: {}, - 35725: {}, 35805: {}, + 35807: {}, 35819: {}, - 35892: {}, 35900: {}, 36290: {}, 36549: {}, - 36866: {}, 36873: {}, 36884: {}, 36890: {}, @@ -541,6 +538,7 @@ var allTopASNs = map[agd.ASN]struct{}{ 36974: {}, 36988: {}, 36992: {}, + 36994: {}, 36996: {}, 36998: {}, 36999: {}, @@ -574,16 +572,16 @@ var allTopASNs = map[agd.ASN]struct{}{ 37284: {}, 37287: {}, 37294: {}, + 37303: {}, 37309: {}, 37323: {}, + 37336: {}, 37337: {}, 37342: {}, 37343: {}, - 37349: {}, 37371: {}, 37376: {}, 37385: {}, - 37406: {}, 37410: {}, 37424: {}, 37440: {}, @@ -593,6 +591,7 @@ var allTopASNs = map[agd.ASN]struct{}{ 37457: {}, 37460: {}, 37461: {}, + 37463: {}, 37473: {}, 37492: {}, 37508: {}, @@ -601,8 +600,8 @@ var allTopASNs = map[agd.ASN]struct{}{ 37526: {}, 37529: {}, 37531: {}, - 37541: {}, 37550: {}, + 37552: {}, 37559: {}, 37563: {}, 37575: {}, @@ -610,13 +609,13 @@ var allTopASNs = map[agd.ASN]struct{}{ 37594: {}, 37611: {}, 37612: {}, - 37614: {}, 37616: {}, 37645: {}, 37649: {}, 37671: {}, 37693: {}, 37705: {}, + 38008: {}, 38009: {}, 38077: {}, 38195: {}, @@ -629,12 +628,15 @@ var allTopASNs = map[agd.ASN]struct{}{ 38742: {}, 38800: {}, 38819: {}, + 38875: {}, 38901: {}, 38999: {}, 39010: {}, 39232: {}, 39603: {}, + 39611: {}, 39642: {}, + 39737: {}, 39891: {}, 40945: {}, 41164: {}, @@ -642,9 +644,6 @@ var allTopASNs = map[agd.ASN]struct{}{ 41329: {}, 41330: {}, 41557: {}, - 41564: {}, - 41653: {}, - 41697: {}, 41738: {}, 41750: {}, 41897: {}, @@ -654,15 +653,19 @@ var allTopASNs = map[agd.ASN]struct{}{ 42298: {}, 42313: {}, 42437: {}, - 42532: {}, 42560: {}, + 42610: {}, 42772: {}, 42779: {}, + 42837: {}, + 42841: {}, 42863: {}, + 42960: {}, 42961: {}, + 43019: {}, 43197: {}, - 43242: {}, 43447: {}, + 43513: {}, 43557: {}, 43571: {}, 43612: {}, @@ -675,7 +678,10 @@ var allTopASNs = map[agd.ASN]struct{}{ 44087: {}, 44143: {}, 44244: {}, + 44395: {}, + 44489: {}, 44558: {}, + 44575: {}, 44869: {}, 45143: {}, 45168: {}, @@ -702,7 +708,7 @@ var allTopASNs = map[agd.ASN]struct{}{ 47331: {}, 47377: {}, 47394: {}, - 47588: {}, + 47524: {}, 47589: {}, 47883: {}, 47956: {}, @@ -715,7 +721,6 @@ var allTopASNs = map[agd.ASN]struct{}{ 48728: {}, 48832: {}, 48887: {}, - 48953: {}, 49273: {}, 49800: {}, 49902: {}, @@ -725,14 +730,14 @@ var allTopASNs = map[agd.ASN]struct{}{ 50251: {}, 50266: {}, 50616: {}, + 50810: {}, 50973: {}, - 50979: {}, 51207: {}, 51375: {}, 51407: {}, 51495: {}, - 51684: {}, 51765: {}, + 51852: {}, 51896: {}, 52228: {}, 52233: {}, @@ -744,8 +749,8 @@ var allTopASNs = map[agd.ASN]struct{}{ 52341: {}, 52362: {}, 52398: {}, - 52433: {}, 52468: {}, + 53667: {}, 55330: {}, 55430: {}, 55805: {}, @@ -753,10 +758,12 @@ var allTopASNs = map[agd.ASN]struct{}{ 55850: {}, 55943: {}, 55944: {}, + 56017: {}, 56055: {}, 56089: {}, 56167: {}, 56300: {}, + 56369: {}, 56653: {}, 56665: {}, 56696: {}, @@ -766,18 +773,16 @@ var allTopASNs = map[agd.ASN]struct{}{ 57388: {}, 57513: {}, 57704: {}, - 58065: {}, 58224: {}, 58460: {}, 58731: {}, + 58952: {}, 59257: {}, - 59588: {}, 59989: {}, 60068: {}, 60258: {}, 61143: {}, 61461: {}, - 62563: {}, 63949: {}, 64466: {}, 131178: {}, @@ -791,29 +796,32 @@ var allTopASNs = map[agd.ASN]struct{}{ 132167: {}, 132199: {}, 132471: {}, + 132486: {}, 132618: {}, - 132831: {}, 133385: {}, 133481: {}, 133579: {}, + 133606: {}, 133612: {}, - 133897: {}, 134783: {}, 135409: {}, 136255: {}, + 136950: {}, 137412: {}, + 137824: {}, 138179: {}, 139759: {}, 139831: {}, 139898: {}, 139922: {}, + 140504: {}, 196838: {}, 197207: {}, 197830: {}, 198279: {}, 198605: {}, 199140: {}, - 199155: {}, + 199276: {}, 199731: {}, 200134: {}, 201167: {}, @@ -822,52 +830,58 @@ var allTopASNs = map[agd.ASN]struct{}{ 202087: {}, 202254: {}, 202422: {}, + 202448: {}, 203214: {}, 203953: {}, 203995: {}, 204170: {}, - 204317: {}, - 204342: {}, - 204649: {}, + 204279: {}, + 205110: {}, 205368: {}, 205714: {}, 206026: {}, 206067: {}, 206206: {}, 206262: {}, + 207369: {}, 207569: {}, 207651: {}, 207810: {}, + 209424: {}, + 210003: {}, 210315: {}, 210542: {}, 211144: {}, + 212238: {}, + 212370: {}, 213155: {}, - 213373: {}, + 213371: {}, 262145: {}, 262186: {}, 262197: {}, 262202: {}, 262210: {}, 262239: {}, + 262589: {}, 263238: {}, 263725: {}, 263783: {}, 263824: {}, 264628: {}, 264645: {}, + 264663: {}, 264668: {}, 264731: {}, + 266673: {}, 269729: {}, 271773: {}, 327697: {}, - 327707: {}, 327712: {}, 327725: {}, 327738: {}, 327756: {}, 327765: {}, 327769: {}, - 327776: {}, 327799: {}, 327802: {}, 327885: {}, @@ -876,7 +890,6 @@ var allTopASNs = map[agd.ASN]struct{}{ 328061: {}, 328079: {}, 328088: {}, - 328136: {}, 328140: {}, 328169: {}, 328191: {}, @@ -886,16 +899,13 @@ var allTopASNs = map[agd.ASN]struct{}{ 328286: {}, 328297: {}, 328309: {}, - 328411: {}, 328453: {}, 328469: {}, 328488: {}, - 328586: {}, - 328605: {}, - 328708: {}, + 328539: {}, 328755: {}, 328943: {}, - 329020: {}, + 393275: {}, 394311: {}, 395561: {}, 396304: {}, @@ -908,9 +918,9 @@ var allTopASNs = map[agd.ASN]struct{}{ var countryTopASNs = map[agd.Country]agd.ASN{ agd.CountryAD: 6752, agd.CountryAE: 5384, - agd.CountryAF: 55330, + agd.CountryAF: 132471, agd.CountryAG: 11594, - agd.CountryAI: 11139, + agd.CountryAI: 2740, agd.CountryAL: 21183, agd.CountryAM: 12297, agd.CountryAO: 37119, @@ -931,55 +941,54 @@ var countryTopASNs = map[agd.Country]agd.ASN{ agd.CountryBI: 327799, agd.CountryBJ: 37424, agd.CountryBL: 3215, - agd.CountryBM: 3855, + agd.CountryBM: 32020, agd.CountryBN: 10094, agd.CountryBO: 6568, agd.CountryBQ: 27745, - agd.CountryBR: 26599, + agd.CountryBR: 28573, agd.CountryBS: 15146, agd.CountryBT: 18024, agd.CountryBW: 14988, - agd.CountryBY: 25106, - agd.CountryBZ: 10269, + agd.CountryBY: 6697, + agd.CountryBZ: 212370, agd.CountryCA: 812, - agd.CountryCC: 198605, agd.CountryCD: 37020, agd.CountryCF: 37460, - agd.CountryCG: 37451, + agd.CountryCG: 36924, agd.CountryCH: 3303, agd.CountryCI: 29571, agd.CountryCK: 10131, - agd.CountryCL: 27651, + agd.CountryCL: 7418, agd.CountryCM: 30992, agd.CountryCN: 4134, - agd.CountryCO: 26611, + agd.CountryCO: 10620, agd.CountryCR: 11830, agd.CountryCU: 27725, agd.CountryCV: 37517, agd.CountryCW: 52233, - agd.CountryCX: 198605, - agd.CountryCY: 6866, + agd.CountryCY: 202448, agd.CountryCZ: 5610, agd.CountryDE: 3320, agd.CountryDJ: 30990, - agd.CountryDK: 3292, + agd.CountryDK: 9009, agd.CountryDM: 40945, agd.CountryDO: 6400, agd.CountryDZ: 36947, agd.CountryEC: 27947, agd.CountryEE: 3249, agd.CountryEG: 8452, + agd.CountryEH: 6713, agd.CountryER: 24757, agd.CountryES: 12479, agd.CountryET: 24757, agd.CountryFI: 51765, agd.CountryFJ: 38442, - agd.CountryFK: 204649, + agd.CountryFK: 198605, agd.CountryFM: 139759, agd.CountryFO: 15389, agd.CountryFR: 3215, - agd.CountryGA: 16058, - agd.CountryGB: 2856, + agd.CountryGA: 36924, + agd.CountryGB: 60068, agd.CountryGD: 46650, agd.CountryGE: 16010, agd.CountryGF: 3215, @@ -987,7 +996,7 @@ var countryTopASNs = map[agd.Country]agd.ASN{ agd.CountryGH: 30986, agd.CountryGI: 8301, agd.CountryGL: 8818, - agd.CountryGM: 37309, + agd.CountryGM: 25250, agd.CountryGN: 37461, agd.CountryGP: 3215, agd.CountryGQ: 37173, @@ -1002,13 +1011,13 @@ var countryTopASNs = map[agd.Country]agd.ASN{ agd.CountryHT: 27653, agd.CountryHU: 5483, agd.CountryID: 7713, - agd.CountryIE: 15502, + agd.CountryIE: 6830, agd.CountryIL: 1680, agd.CountryIM: 13122, agd.CountryIN: 55836, agd.CountryIO: 17458, agd.CountryIQ: 203214, - agd.CountryIR: 44244, + agd.CountryIR: 58224, agd.CountryIS: 43571, agd.CountryIT: 1267, agd.CountryJE: 8680, @@ -1019,21 +1028,21 @@ var countryTopASNs = map[agd.Country]agd.ASN{ agd.CountryKG: 50223, agd.CountryKH: 38623, agd.CountryKI: 134783, - agd.CountryKM: 328061, + agd.CountryKM: 36939, agd.CountryKN: 11139, agd.CountryKR: 4766, agd.CountryKW: 29357, agd.CountryKY: 6639, agd.CountryKZ: 206026, agd.CountryLA: 9873, - agd.CountryLB: 38999, + agd.CountryLB: 42003, agd.CountryLC: 15344, agd.CountryLI: 20634, agd.CountryLK: 18001, - agd.CountryLR: 37094, + agd.CountryLR: 37410, agd.CountryLS: 33567, agd.CountryLT: 8764, - agd.CountryLU: 9009, + agd.CountryLU: 6661, agd.CountryLV: 24921, agd.CountryLY: 21003, agd.CountryMA: 36903, @@ -1068,37 +1077,35 @@ var countryTopASNs = map[agd.Country]agd.ASN{ agd.CountryNL: 1136, agd.CountryNO: 2119, agd.CountryNP: 17501, - agd.CountryNR: 45355, - agd.CountryNU: 198605, - agd.CountryNZ: 4771, + agd.CountryNR: 140504, + agd.CountryNZ: 9790, agd.CountryOM: 28885, - agd.CountryPA: 11556, + agd.CountryPA: 18809, agd.CountryPE: 12252, agd.CountryPF: 9471, agd.CountryPG: 139898, agd.CountryPH: 9299, agd.CountryPK: 45669, - agd.CountryPL: 43447, + agd.CountryPL: 5617, agd.CountryPM: 3695, - agd.CountryPR: 21928, + agd.CountryPR: 14638, agd.CountryPS: 12975, - agd.CountryPT: 12353, + agd.CountryPT: 3243, agd.CountryPW: 17893, agd.CountryPY: 23201, agd.CountryQA: 42298, - agd.CountryRE: 37002, + agd.CountryRE: 3215, agd.CountryRO: 8708, agd.CountryRS: 8400, agd.CountryRU: 8359, agd.CountryRW: 36924, agd.CountrySA: 39891, agd.CountrySB: 45891, - agd.CountrySC: 131267, + agd.CountrySC: 36958, agd.CountrySD: 15706, - agd.CountrySE: 1257, + agd.CountrySE: 60068, agd.CountrySG: 4773, agd.CountrySI: 3212, - agd.CountrySJ: 198605, agd.CountrySK: 6855, agd.CountrySL: 37164, agd.CountrySM: 15433, @@ -1113,7 +1120,7 @@ var countryTopASNs = map[agd.Country]agd.ASN{ agd.CountrySZ: 328169, agd.CountryTC: 394311, agd.CountryTD: 327802, - agd.CountryTG: 24691, + agd.CountryTG: 36924, agd.CountryTH: 131445, agd.CountryTJ: 43197, agd.CountryTK: 4648, @@ -1121,8 +1128,9 @@ var countryTopASNs = map[agd.Country]agd.ASN{ agd.CountryTM: 51495, agd.CountryTN: 37705, agd.CountryTO: 38201, - agd.CountryTR: 16135, + agd.CountryTR: 47331, agd.CountryTT: 27800, + agd.CountryTV: 23917, agd.CountryTW: 3462, agd.CountryTZ: 36908, agd.CountryUA: 15895, diff --git a/internal/geoip/asntops_generate.go b/internal/geoip/asntops_generate.go index 60eee58..9a32879 100644 --- a/internal/geoip/asntops_generate.go +++ b/internal/geoip/asntops_generate.go @@ -11,6 +11,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" + "github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/log" ) @@ -22,7 +23,7 @@ func main() { req, err := http.NewRequest(http.MethodGet, countriesASNURL, nil) check(err) - req.Header.Add("User-Agent", agdhttp.UserAgent()) + req.Header.Add(httphdr.UserAgent, agdhttp.UserAgent()) resp, err := c.Do(req) check(err) diff --git a/internal/geoip/file.go b/internal/geoip/file.go index 432e3ca..5d29b81 100644 --- a/internal/geoip/file.go +++ b/internal/geoip/file.go @@ -23,7 +23,8 @@ type FileConfig struct { // ASNPath is the path to the GeoIP database of ASNs. ASNPath string - // CountryPath is the path to the GeoIP database of countries. + // CountryPath is the path to the GeoIP database of countries. The + // databases containing subdivisions and cities info are also supported. CountryPath string // HostCacheSize is how many lookups are cached by hostname. Zero means no @@ -207,18 +208,16 @@ func (f *File) Data(host string, ip netip.Addr) (l *agd.Location, err error) { return nil, fmt.Errorf("looking up asn: %w", err) } - ctry, cont, err := f.lookupCtry(ip) + l = &agd.Location{ + ASN: asn, + } + + err = f.setCtry(l, ip) if err != nil { // Don't wrap the error, because it's informative enough as is. return nil, err } - l = &agd.Location{ - Country: ctry, - Continent: cont, - ASN: asn, - } - f.setCaches(host, cacheKey, l) return l, nil @@ -271,29 +270,36 @@ type countryResult struct { Country struct { ISOCode string `maxminddb:"iso_code"` } `maxminddb:"country"` + Subdivisions []struct { + ISOCode string `maxminddb:"iso_code"` + } `maxminddb:"subdivisions"` } -// lookupCtry looks up and returns the country and continent parts of the GeoIP -// data for ip. -func (f *File) lookupCtry(ip netip.Addr) (ctry agd.Country, cont agd.Continent, err error) { +// setCtry looks up and sets the country, continent and the subdivision parts +// of the GeoIP data for ip into loc. loc must not be nil. +func (f *File) setCtry(loc *agd.Location, ip netip.Addr) (err error) { // TODO(a.garipov): Remove AsSlice if oschwald/maxminddb-golang#88 is done. var res countryResult err = f.country.Lookup(ip.AsSlice(), &res) if err != nil { - return ctry, cont, fmt.Errorf("looking up country: %w", err) + return fmt.Errorf("looking up country: %w", err) } - ctry, err = agd.NewCountry(res.Country.ISOCode) + loc.Country, err = agd.NewCountry(res.Country.ISOCode) if err != nil { - return ctry, cont, fmt.Errorf("converting country: %w", err) + return fmt.Errorf("converting country: %w", err) } - cont, err = agd.NewContinent(res.Continent.Code) + loc.Continent, err = agd.NewContinent(res.Continent.Code) if err != nil { - return ctry, cont, fmt.Errorf("converting continent: %w", err) + return fmt.Errorf("converting continent: %w", err) } - return ctry, cont, nil + if len(res.Subdivisions) > 0 { + loc.TopSubdivision = res.Subdivisions[0].ISOCode + } + + return nil } // setCaches sets the GeoIP data into the caches. diff --git a/internal/geoip/file_test.go b/internal/geoip/file_test.go index 58c629c..1715546 100644 --- a/internal/geoip/file_test.go +++ b/internal/geoip/file_test.go @@ -11,7 +11,31 @@ import ( "github.com/stretchr/testify/require" ) -func TestFile_Data(t *testing.T) { +func TestFile_Data_cityDB(t *testing.T) { + conf := &geoip.FileConfig{ + ASNPath: asnPath, + CountryPath: cityPath, + HostCacheSize: 0, + IPCacheSize: 1, + } + + g, err := geoip.NewFile(conf) + require.NoError(t, err) + + d, err := g.Data(testHost, testIPWithASN) + require.NoError(t, err) + + assert.Equal(t, testASN, d.ASN) + + d, err = g.Data(testHost, testIPWithSubdiv) + require.NoError(t, err) + + assert.Equal(t, testCtry, d.Country) + assert.Equal(t, testCont, d.Continent) + assert.Equal(t, testSubdiv, d.TopSubdivision) +} + +func TestFile_Data_countryDB(t *testing.T) { conf := &geoip.FileConfig{ ASNPath: asnPath, CountryPath: countryPath, @@ -27,17 +51,18 @@ func TestFile_Data(t *testing.T) { assert.Equal(t, testASN, d.ASN) - d, err = g.Data(testHost, testIPWithCountry) + d, err = g.Data(testHost, testIPWithSubdiv) require.NoError(t, err) assert.Equal(t, testCtry, d.Country) assert.Equal(t, testCont, d.Continent) + assert.Empty(t, d.TopSubdivision) } func TestFile_Data_hostCache(t *testing.T) { conf := &geoip.FileConfig{ ASNPath: asnPath, - CountryPath: countryPath, + CountryPath: cityPath, HostCacheSize: 1, IPCacheSize: 1, } @@ -64,7 +89,7 @@ func TestFile_Data_hostCache(t *testing.T) { func TestFile_SubnetByLocation(t *testing.T) { conf := &geoip.FileConfig{ ASNPath: asnPath, - CountryPath: countryPath, + CountryPath: cityPath, HostCacheSize: 0, IPCacheSize: 1, } @@ -91,7 +116,7 @@ var errSink error func BenchmarkFile_Data(b *testing.B) { conf := &geoip.FileConfig{ ASNPath: asnPath, - CountryPath: countryPath, + CountryPath: cityPath, HostCacheSize: 0, IPCacheSize: 1, } @@ -103,7 +128,7 @@ func BenchmarkFile_Data(b *testing.B) { // Change the eighth byte in testIPWithCountry to create a different address // in the same network. - ipSlice := testIPWithCountry.AsSlice() + ipSlice := ipCountry1.AsSlice() ipSlice[7] = 1 ipCountry2, ok := netip.AddrFromSlice(ipSlice) require.True(b, ok) @@ -142,7 +167,7 @@ var fileSink *geoip.File func BenchmarkNewFile(b *testing.B) { conf := &geoip.FileConfig{ ASNPath: asnPath, - CountryPath: countryPath, + CountryPath: cityPath, HostCacheSize: 0, IPCacheSize: 1, } diff --git a/internal/geoip/geoip_test.go b/internal/geoip/geoip_test.go index b6d2d50..2768867 100644 --- a/internal/geoip/geoip_test.go +++ b/internal/geoip/geoip_test.go @@ -15,6 +15,7 @@ func TestMain(m *testing.M) { // Paths to test data. const ( asnPath = "./testdata/GeoLite2-ASN-Test.mmdb" + cityPath = "./testdata/GeoIP2-City-Test.mmdb" countryPath = "./testdata/GeoIP2-Country-Test.mmdb" ) @@ -24,12 +25,16 @@ const ( testOtherHost = "other.example.com" ) -// Test data. See https://github.com/maxmind/MaxMind-DB/blob/2bf1713b3b5adcb022cf4bb77eb0689beaadcfef/source-data/GeoLite2-ASN-Test.json -// and https://github.com/maxmind/MaxMind-DB/blob/2bf1713b3b5adcb022cf4bb77eb0689beaadcfef/source-data/GeoIP2-Country-Test.json. +// Test data. See [ASN], [city], and [country] testing datum. +// +// [ASN]: https://github.com/maxmind/MaxMind-DB/blob/2bf1713b3b5adcb022cf4bb77eb0689beaadcfef/source-data/GeoLite2-ASN-Test.json +// [city]: https://github.com/maxmind/MaxMind-DB/blob/2bf1713b3b5adcb022cf4bb77eb0689beaadcfef/source-data/GeoIP2-City-Test.json +// [country]: https://github.com/maxmind/MaxMind-DB/blob/2bf1713b3b5adcb022cf4bb77eb0689beaadcfef/source-data/GeoIP2-Country-Test.json const ( - testASN agd.ASN = 1221 - testCtry agd.Country = agd.CountryJP - testCont agd.Continent = agd.ContinentAS + testASN agd.ASN = 1221 + testCtry agd.Country = agd.CountryUS + testCont agd.Continent = agd.ContinentNA + testSubdiv string = "WA" testIPv4SubnetCtry = agd.CountryUS testIPv6SubnetCtry = agd.CountryJP @@ -38,7 +43,13 @@ const ( // testIPWithASN has ASN set to 1221 in the test database. var testIPWithASN = netip.MustParseAddr("1.128.0.0") -// testIPWithCountry has country set to Japan in the test database. +// testIPWithSubdiv has country set to USA and the subdivision set to Washington +// in the city-aware test database. It has no subdivision in the country-aware +// test database but resolves into USA as well. +var testIPWithSubdiv = netip.MustParseAddr("216.160.83.56") + +// testIPWithCountry has country set to Japan in the country-aware test +// database. var testIPWithCountry = netip.MustParseAddr("2001:218::") // Subnets for CountrySubnet tests. diff --git a/internal/geoip/testdata/GeoIP2-City-Test.mmdb b/internal/geoip/testdata/GeoIP2-City-Test.mmdb new file mode 100644 index 0000000..43ab5ed Binary files /dev/null and b/internal/geoip/testdata/GeoIP2-City-Test.mmdb differ diff --git a/internal/metrics/datastorage.go b/internal/metrics/backend.go similarity index 74% rename from internal/metrics/datastorage.go rename to internal/metrics/backend.go index 8874d52..87761f4 100644 --- a/internal/metrics/datastorage.go +++ b/internal/metrics/backend.go @@ -5,6 +5,24 @@ import ( "github.com/prometheus/client_golang/prometheus/promauto" ) +// DevicesCountGauge is a gauge with the total number of user devices loaded +// from the backend. +var DevicesCountGauge = promauto.NewGauge(prometheus.GaugeOpts{ + Name: "devices_total", + Subsystem: subsystemBackend, + Namespace: namespace, + Help: "The total number of user devices loaded from the backend.", +}) + +// DevicesNewCountGauge is a gauge with the number of user devices downloaded +// during the last sync. +var DevicesNewCountGauge = promauto.NewGauge(prometheus.GaugeOpts{ + Name: "devices_newly_synced_total", + Subsystem: subsystemBackend, + Namespace: namespace, + Help: "The number of user devices that were changed or added since the previous sync.", +}) + // ProfilesCountGauge is a gauge with the total number of user profiles loaded // from the backend. var ProfilesCountGauge = promauto.NewGauge(prometheus.GaugeOpts{ diff --git a/internal/metrics/connlimiter.go b/internal/metrics/connlimiter.go new file mode 100644 index 0000000..154d8d1 --- /dev/null +++ b/internal/metrics/connlimiter.go @@ -0,0 +1,45 @@ +package metrics + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +// ConnLimiterLimits is the gauge vector for showing the configured limits of +// the number of active stream-connections. +var ConnLimiterLimits = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Name: "limits", + Namespace: namespace, + Subsystem: subsystemConnLimiter, + Help: `The current limits of the number of active stream-connections: ` + + `kind="stop" for the stopping limit and kind="resume" for the resuming one.`, +}, []string{"kind"}) + +// ConnLimiterActiveStreamConns is the gauge vector for the number of active +// stream-connections. +var ConnLimiterActiveStreamConns = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Name: "active_stream_conns", + Namespace: namespace, + Subsystem: subsystemConnLimiter, + Help: `The number of currently active stream-connections.`, +}, []string{"name", "proto", "addr"}) + +// StreamConnWaitDuration is a histogram with the duration of waiting times for +// accepting stream connections. +var StreamConnWaitDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Name: "stream_conn_wait_duration_seconds", + Subsystem: subsystemConnLimiter, + Namespace: namespace, + Help: "How long a stream connection waits for an accept, in seconds.", + Buckets: []float64{0.00001, 0.01, 0.1, 1, 10, 30, 60}, +}, []string{"name", "proto", "addr"}) + +// StreamConnLifeDuration is a histogram with the duration of lives of stream +// connections. +var StreamConnLifeDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Name: "stream_conn_life_duration_seconds", + Subsystem: subsystemConnLimiter, + Namespace: namespace, + Help: "How long a stream connection lives, in seconds.", + Buckets: []float64{0.1, 1, 5, 10, 30, 60}, +}, []string{"name", "proto", "addr"}) diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 5371034..012d721 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -1,5 +1,6 @@ // Package metrics contains definitions of most of the prometheus metrics // that we use in AdGuard DNS. +// // TODO(ameshkov): consider not using promauto. package metrics @@ -16,6 +17,7 @@ const ( subsystemApplication = "app" subsystemBackend = "backend" subsystemBillStat = "billstat" + subsystemConnLimiter = "connlimiter" subsystemConsul = "consul" subsystemDNSCheck = "dnscheck" subsystemDNSDB = "dnsdb" @@ -30,8 +32,8 @@ const ( subsystemWebSvc = "websvc" ) -// SetUpGauge signals that the server has been started. -// We're using a function here to avoid circular dependencies. +// SetUpGauge signals that the server has been started. Use a function here to +// avoid circular dependencies. func SetUpGauge(version, buildtime, branch, revision, goversion string) { upGauge := promauto.NewGauge( prometheus.GaugeOpts{ diff --git a/internal/metrics/research.go b/internal/metrics/research.go index 50646a6..5c4b34e 100644 --- a/internal/metrics/research.go +++ b/internal/metrics/research.go @@ -1,8 +1,10 @@ package metrics import ( + "github.com/AdguardTeam/AdGuardDNS/internal/agd" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/prometheus/common/model" ) // ResearchRequestsPerCountryTotal counts the total number of queries per @@ -23,39 +25,80 @@ var ResearchBlockedRequestsPerCountryTotal = promauto.NewCounterVec(prometheus.C Help: "The number of blocked DNS queries per country from anonymous users.", }, []string{"filter", "country"}) +// ResearchRequestsPerSubdivTotal counts the total number of queries per country +// from anonymous users. +var ResearchRequestsPerSubdivTotal = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "requests_per_subdivision_total", + Namespace: namespace, + Subsystem: subsystemResearch, + Help: `The total number of DNS queries per countries with top ` + + `subdivision from anonymous users.`, +}, []string{"country", "subdivision"}) + +// ResearchBlockedRequestsPerSubdivTotal counts the number of blocked queries +// per country from anonymous users. +var ResearchBlockedRequestsPerSubdivTotal = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "blocked_per_subdivision_total", + Namespace: namespace, + Subsystem: subsystemResearch, + Help: `The number of blocked DNS queries per countries with top ` + + `subdivision from anonymous users.`, +}, []string{"filter", "country", "subdivision"}) + // ReportResearchMetrics reports metrics to prometheus that we may need to // conduct researches. -// -// TODO(ameshkov): use [agd.Profile] arg when recursive dependency is resolved. func ReportResearchMetrics( - anonymous bool, - filteringEnabled bool, - asn string, - ctry string, - filterID string, + ri *agd.RequestInfo, + filterID agd.FilterListID, blocked bool, ) { - // The current research metrics only count queries that come to public - // DNS servers where filtering is enabled. - if !filteringEnabled || !anonymous { + filteringEnabled := ri.FilteringGroup != nil && + ri.FilteringGroup.RuleListsEnabled && + len(ri.FilteringGroup.RuleListIDs) > 0 + + // The current research metrics only count queries that come to public DNS + // servers where filtering is enabled. + if !filteringEnabled || ri.Profile != nil { return } - // Ignore AdGuard ASN specifically in order to avoid counting queries that - // come from the monitoring. This part is ugly, but since these metrics - // are a one-time deal, this is acceptable. - // - // TODO(ameshkov): think of a better way later if we need to do that again. - if asn == "212772" { - return + var ctry, subdiv string + if l := ri.Location; l != nil { + // Ignore AdGuard ASN specifically in order to avoid counting queries + // that come from the monitoring. This part is ugly, but since these + // metrics are a one-time deal, this is acceptable. + // + // TODO(ameshkov): Think of a better way later if we need to do that + // again. + if l.ASN == 212772 { + return + } + + ctry = string(l.Country) + if model.LabelValue(l.TopSubdivision).IsValid() { + subdiv = l.TopSubdivision + } } if blocked { - ResearchBlockedRequestsPerCountryTotal.WithLabelValues( - filterID, - ctry, - ).Inc() + reportResearchBlocked(string(filterID), ctry, subdiv) } - ResearchRequestsPerCountryTotal.WithLabelValues(ctry).Inc() + reportResearchRequest(ctry, subdiv) +} + +// reportResearchBlocked reports on a blocked request to the research metrics. +func reportResearchBlocked(fltID, ctry, subdiv string) { + ResearchBlockedRequestsPerCountryTotal.WithLabelValues(fltID, ctry).Inc() + if subdiv != "" { + ResearchBlockedRequestsPerSubdivTotal.WithLabelValues(fltID, ctry, subdiv).Inc() + } +} + +// reportResearchBlocked reports on a request to the research metrics. +func reportResearchRequest(ctry, subdiv string) { + ResearchRequestsPerCountryTotal.WithLabelValues(ctry).Inc() + if subdiv != "" { + ResearchRequestsPerSubdivTotal.WithLabelValues(ctry, subdiv).Inc() + } } diff --git a/internal/metrics/tls.go b/internal/metrics/tls.go index b619b9c..0f6baa7 100644 --- a/internal/metrics/tls.go +++ b/internal/metrics/tls.go @@ -93,6 +93,8 @@ func TLSMetricsAfterHandshake( tlsVersionToString(state.Version), BoolString(state.DidResume), tls.CipherSuiteName(state.CipherSuite), + // Don't validate the negotiated protocol since it's expected to + // contain only ASCII after negotiation itself. state.NegotiatedProtocol, sLabel, ).Inc() @@ -112,12 +114,17 @@ func TLSMetricsBeforeHandshake(proto string) (f func(*tls.ClientHelloInfo) (*tls } } + supProtos := make([]string, len(info.SupportedProtos)) + for i := range info.SupportedProtos { + supProtos[i] = strings.ToValidUTF8(info.SupportedProtos[i], "") + } + // Stick to using WithLabelValues instead of With in order to avoid // extra allocations on prometheus.Labels. The labels order is VERY // important here. TLSHandshakeAttemptsTotal.WithLabelValues( proto, - strings.Join(info.SupportedProtos, ","), + strings.Join(supProtos, ","), tlsVersionToString(maxVersion), ).Inc() @@ -127,7 +134,6 @@ func TLSMetricsBeforeHandshake(proto string) (f func(*tls.ClientHelloInfo) (*tls // tlsVersionToString converts TLS version to string. func tlsVersionToString(ver uint16) (tlsVersion string) { - tlsVersion = "unknown" switch ver { case tls.VersionTLS13: tlsVersion = "tls1.3" @@ -137,7 +143,10 @@ func tlsVersionToString(ver uint16) (tlsVersion string) { tlsVersion = "tls1.1" case tls.VersionTLS10: tlsVersion = "tls1.0" + default: + tlsVersion = "unknown" } + return tlsVersion } @@ -151,8 +160,8 @@ func serverNameToLabel( srvCerts []tls.Certificate, ) (label string) { if sni == "" { - // SNI is not provided, so the request is probably made on the - // IP address. + // SNI is not provided, so the request is probably made on the IP + // address. return fmt.Sprintf("%s: other", srvName) } diff --git a/internal/metrics/tls_test.go b/internal/metrics/tls_test.go index e3ccda7..9fd28f7 100644 --- a/internal/metrics/tls_test.go +++ b/internal/metrics/tls_test.go @@ -7,7 +7,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_model/go" + io_prometheus_client "github.com/prometheus/client_model/go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -135,3 +135,18 @@ outerLoop: return assert.Truef(t, ok, "%s not found in server name labels", wantLabel) } + +func TestTLSMetricsBeforeHandshake(t *testing.T) { + f := metrics.TLSMetricsBeforeHandshake("srv-name") + + var conf *tls.Config + var err error + require.NotPanics(t, func() { + conf, err = f(&tls.ClientHelloInfo{ + SupportedProtos: []string{"\xC0\xC1\xF5\xF6\xF7\xF8\xF9\xFA\xFB\xFC\xFD\xFE\xFF"}, + }) + }) + require.NoError(t, err) + + assert.Nil(t, conf) +} diff --git a/internal/profiledb/error.go b/internal/profiledb/error.go new file mode 100644 index 0000000..40814ec --- /dev/null +++ b/internal/profiledb/error.go @@ -0,0 +1,7 @@ +package profiledb + +import "github.com/AdguardTeam/golibs/errors" + +// ErrDeviceNotFound is an error returned by lookup methods when a device +// couldn't be found. +const ErrDeviceNotFound errors.Error = "device not found" diff --git a/internal/profiledb/internal/filecachejson/filecachejson.go b/internal/profiledb/internal/filecachejson/filecachejson.go new file mode 100644 index 0000000..3c3110e --- /dev/null +++ b/internal/profiledb/internal/filecachejson/filecachejson.go @@ -0,0 +1,117 @@ +// Package filecachejson contains an implementation of the file-cache storage +// that encodes data using JSON. +package filecachejson + +import ( + "encoding/json" + "fmt" + "os" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" + "github.com/google/renameio" +) + +// Storage is the file-cache storage that encodes data using JSON. +type Storage struct { + path string +} + +// New returns a new JSON-encoded file-cache storage. +func New(cachePath string) (s *Storage) { + return &Storage{ + path: cachePath, + } +} + +// fileCache is the structure for the JSON filesystem cache of a profile +// database. +// +// NOTE: Do not change fields of this structure without incrementing +// [internal.FileCacheVersion]. +type fileCache struct { + SyncTime time.Time `json:"sync_time"` + Profiles []*agd.Profile `json:"profiles"` + Devices []*agd.Device `json:"devices"` + Version int32 `json:"version"` +} + +// logPrefix is the logging prefix for the JSON-encoded file-cache. +const logPrefix = "profiledb json cache" + +var _ internal.FileCacheStorage = (*Storage)(nil) + +// Load implements the [internal.FileCacheStorage] interface for *Storage. +func (s *Storage) Load() (c *internal.FileCache, err error) { + log.Info("%s: loading", logPrefix) + + data, err := s.loadFromFile() + if err != nil { + return nil, fmt.Errorf("loading from file: %w", err) + } + + if data == nil { + log.Info("%s: file not present", logPrefix) + + return nil, nil + } + + return &internal.FileCache{ + SyncTime: data.SyncTime, + Profiles: data.Profiles, + Devices: data.Devices, + Version: data.Version, + }, nil +} + +// loadFromFile loads the profile data from cache file. +func (s *Storage) loadFromFile() (data *fileCache, err error) { + file, err := os.Open(s.path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + // File could be deleted or not yet created, go on. + return nil, nil + } + + return nil, err + } + defer func() { err = errors.WithDeferred(err, file.Close()) }() + + data = &fileCache{} + err = json.NewDecoder(file).Decode(data) + if err != nil { + return nil, fmt.Errorf("decoding json: %w", err) + } + + return data, nil +} + +// Store implements the [internal.FileCacheStorage] interface for *Storage. +func (s *Storage) Store(c *internal.FileCache) (err error) { + profNum := len(c.Profiles) + log.Info("%s: saving %d profiles to %q", logPrefix, profNum, s.path) + defer log.Info("%s: saved %d profiles to %q", logPrefix, profNum, s.path) + + data := &fileCache{ + SyncTime: c.SyncTime, + Profiles: c.Profiles, + Devices: c.Devices, + Version: c.Version, + } + + cache, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("encoding json: %w", err) + } + + err = renameio.WriteFile(s.path, cache, 0o600) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } + + return nil +} diff --git a/internal/profiledb/internal/filecachejson/filecachejson_test.go b/internal/profiledb/internal/filecachejson/filecachejson_test.go new file mode 100644 index 0000000..7c017f1 --- /dev/null +++ b/internal/profiledb/internal/filecachejson/filecachejson_test.go @@ -0,0 +1,53 @@ +package filecachejson_test + +import ( + "path/filepath" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal/filecachejson" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal/profiledbtest" + "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMain(m *testing.M) { + testutil.DiscardLogOutput(m) +} + +func TestStorage(t *testing.T) { + prof, dev := profiledbtest.NewProfile(t) + cachePath := filepath.Join(t.TempDir(), "profiles.json") + s := filecachejson.New(cachePath) + require.NotNil(t, s) + + fc := &internal.FileCache{ + SyncTime: time.Now().Round(0).UTC(), + Profiles: []*agd.Profile{prof}, + Devices: []*agd.Device{dev}, + Version: internal.FileCacheVersion, + } + + err := s.Store(fc) + require.NoError(t, err) + + gotFC, err := s.Load() + require.NoError(t, err) + require.NotNil(t, gotFC) + require.NotEmpty(t, *gotFC) + + assert.Equal(t, fc, gotFC) +} + +func TestStorage_Load_noFile(t *testing.T) { + cachePath := filepath.Join(t.TempDir(), "profiles.json") + s := filecachejson.New(cachePath) + require.NotNil(t, s) + + fc, err := s.Load() + assert.NoError(t, err) + assert.Nil(t, fc) +} diff --git a/internal/profiledb/internal/filecachepb/filecache.pb.go b/internal/profiledb/internal/filecachepb/filecache.pb.go new file mode 100644 index 0000000..2c72dfc --- /dev/null +++ b/internal/profiledb/internal/filecachepb/filecache.pb.go @@ -0,0 +1,1165 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.30.0 +// protoc v4.22.3 +// source: filecache.proto + +package filecachepb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + durationpb "google.golang.org/protobuf/types/known/durationpb" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type FileCache struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + SyncTime *timestamppb.Timestamp `protobuf:"bytes,1,opt,name=sync_time,json=syncTime,proto3" json:"sync_time,omitempty"` + Profiles []*Profile `protobuf:"bytes,2,rep,name=profiles,proto3" json:"profiles,omitempty"` + Devices []*Device `protobuf:"bytes,3,rep,name=devices,proto3" json:"devices,omitempty"` + Version int32 `protobuf:"varint,4,opt,name=version,proto3" json:"version,omitempty"` +} + +func (x *FileCache) Reset() { + *x = FileCache{} + if protoimpl.UnsafeEnabled { + mi := &file_filecache_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *FileCache) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*FileCache) ProtoMessage() {} + +func (x *FileCache) ProtoReflect() protoreflect.Message { + mi := &file_filecache_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use FileCache.ProtoReflect.Descriptor instead. +func (*FileCache) Descriptor() ([]byte, []int) { + return file_filecache_proto_rawDescGZIP(), []int{0} +} + +func (x *FileCache) GetSyncTime() *timestamppb.Timestamp { + if x != nil { + return x.SyncTime + } + return nil +} + +func (x *FileCache) GetProfiles() []*Profile { + if x != nil { + return x.Profiles + } + return nil +} + +func (x *FileCache) GetDevices() []*Device { + if x != nil { + return x.Devices + } + return nil +} + +func (x *FileCache) GetVersion() int32 { + if x != nil { + return x.Version + } + return 0 +} + +type Profile struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Parental *ParentalProtectionSettings `protobuf:"bytes,1,opt,name=parental,proto3" json:"parental,omitempty"` + // Types that are assignable to BlockingMode: + // + // *Profile_BlockingModeCustomIp + // *Profile_BlockingModeNxdomain + // *Profile_BlockingModeNullIp + // *Profile_BlockingModeRefused + BlockingMode isProfile_BlockingMode `protobuf_oneof:"blocking_mode"` + ProfileId string `protobuf:"bytes,6,opt,name=profile_id,json=profileId,proto3" json:"profile_id,omitempty"` + UpdateTime *timestamppb.Timestamp `protobuf:"bytes,7,opt,name=update_time,json=updateTime,proto3" json:"update_time,omitempty"` + DeviceIds []string `protobuf:"bytes,8,rep,name=device_ids,json=deviceIds,proto3" json:"device_ids,omitempty"` + RuleListIds []string `protobuf:"bytes,9,rep,name=rule_list_ids,json=ruleListIds,proto3" json:"rule_list_ids,omitempty"` + CustomRules []string `protobuf:"bytes,10,rep,name=custom_rules,json=customRules,proto3" json:"custom_rules,omitempty"` + FilteredResponseTtl *durationpb.Duration `protobuf:"bytes,11,opt,name=filtered_response_ttl,json=filteredResponseTtl,proto3" json:"filtered_response_ttl,omitempty"` + FilteringEnabled bool `protobuf:"varint,12,opt,name=filtering_enabled,json=filteringEnabled,proto3" json:"filtering_enabled,omitempty"` + SafeBrowsingEnabled bool `protobuf:"varint,13,opt,name=safe_browsing_enabled,json=safeBrowsingEnabled,proto3" json:"safe_browsing_enabled,omitempty"` + RuleListsEnabled bool `protobuf:"varint,14,opt,name=rule_lists_enabled,json=ruleListsEnabled,proto3" json:"rule_lists_enabled,omitempty"` + QueryLogEnabled bool `protobuf:"varint,15,opt,name=query_log_enabled,json=queryLogEnabled,proto3" json:"query_log_enabled,omitempty"` + Deleted bool `protobuf:"varint,16,opt,name=deleted,proto3" json:"deleted,omitempty"` + BlockPrivateRelay bool `protobuf:"varint,17,opt,name=block_private_relay,json=blockPrivateRelay,proto3" json:"block_private_relay,omitempty"` + BlockFirefoxCanary bool `protobuf:"varint,18,opt,name=block_firefox_canary,json=blockFirefoxCanary,proto3" json:"block_firefox_canary,omitempty"` +} + +func (x *Profile) Reset() { + *x = Profile{} + if protoimpl.UnsafeEnabled { + mi := &file_filecache_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Profile) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Profile) ProtoMessage() {} + +func (x *Profile) ProtoReflect() protoreflect.Message { + mi := &file_filecache_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Profile.ProtoReflect.Descriptor instead. +func (*Profile) Descriptor() ([]byte, []int) { + return file_filecache_proto_rawDescGZIP(), []int{1} +} + +func (x *Profile) GetParental() *ParentalProtectionSettings { + if x != nil { + return x.Parental + } + return nil +} + +func (m *Profile) GetBlockingMode() isProfile_BlockingMode { + if m != nil { + return m.BlockingMode + } + return nil +} + +func (x *Profile) GetBlockingModeCustomIp() *BlockingModeCustomIP { + if x, ok := x.GetBlockingMode().(*Profile_BlockingModeCustomIp); ok { + return x.BlockingModeCustomIp + } + return nil +} + +func (x *Profile) GetBlockingModeNxdomain() *BlockingModeNXDOMAIN { + if x, ok := x.GetBlockingMode().(*Profile_BlockingModeNxdomain); ok { + return x.BlockingModeNxdomain + } + return nil +} + +func (x *Profile) GetBlockingModeNullIp() *BlockingModeNullIP { + if x, ok := x.GetBlockingMode().(*Profile_BlockingModeNullIp); ok { + return x.BlockingModeNullIp + } + return nil +} + +func (x *Profile) GetBlockingModeRefused() *BlockingModeREFUSED { + if x, ok := x.GetBlockingMode().(*Profile_BlockingModeRefused); ok { + return x.BlockingModeRefused + } + return nil +} + +func (x *Profile) GetProfileId() string { + if x != nil { + return x.ProfileId + } + return "" +} + +func (x *Profile) GetUpdateTime() *timestamppb.Timestamp { + if x != nil { + return x.UpdateTime + } + return nil +} + +func (x *Profile) GetDeviceIds() []string { + if x != nil { + return x.DeviceIds + } + return nil +} + +func (x *Profile) GetRuleListIds() []string { + if x != nil { + return x.RuleListIds + } + return nil +} + +func (x *Profile) GetCustomRules() []string { + if x != nil { + return x.CustomRules + } + return nil +} + +func (x *Profile) GetFilteredResponseTtl() *durationpb.Duration { + if x != nil { + return x.FilteredResponseTtl + } + return nil +} + +func (x *Profile) GetFilteringEnabled() bool { + if x != nil { + return x.FilteringEnabled + } + return false +} + +func (x *Profile) GetSafeBrowsingEnabled() bool { + if x != nil { + return x.SafeBrowsingEnabled + } + return false +} + +func (x *Profile) GetRuleListsEnabled() bool { + if x != nil { + return x.RuleListsEnabled + } + return false +} + +func (x *Profile) GetQueryLogEnabled() bool { + if x != nil { + return x.QueryLogEnabled + } + return false +} + +func (x *Profile) GetDeleted() bool { + if x != nil { + return x.Deleted + } + return false +} + +func (x *Profile) GetBlockPrivateRelay() bool { + if x != nil { + return x.BlockPrivateRelay + } + return false +} + +func (x *Profile) GetBlockFirefoxCanary() bool { + if x != nil { + return x.BlockFirefoxCanary + } + return false +} + +type isProfile_BlockingMode interface { + isProfile_BlockingMode() +} + +type Profile_BlockingModeCustomIp struct { + BlockingModeCustomIp *BlockingModeCustomIP `protobuf:"bytes,2,opt,name=blocking_mode_custom_ip,json=blockingModeCustomIp,proto3,oneof"` +} + +type Profile_BlockingModeNxdomain struct { + BlockingModeNxdomain *BlockingModeNXDOMAIN `protobuf:"bytes,3,opt,name=blocking_mode_nxdomain,json=blockingModeNxdomain,proto3,oneof"` +} + +type Profile_BlockingModeNullIp struct { + BlockingModeNullIp *BlockingModeNullIP `protobuf:"bytes,4,opt,name=blocking_mode_null_ip,json=blockingModeNullIp,proto3,oneof"` +} + +type Profile_BlockingModeRefused struct { + BlockingModeRefused *BlockingModeREFUSED `protobuf:"bytes,5,opt,name=blocking_mode_refused,json=blockingModeRefused,proto3,oneof"` +} + +func (*Profile_BlockingModeCustomIp) isProfile_BlockingMode() {} + +func (*Profile_BlockingModeNxdomain) isProfile_BlockingMode() {} + +func (*Profile_BlockingModeNullIp) isProfile_BlockingMode() {} + +func (*Profile_BlockingModeRefused) isProfile_BlockingMode() {} + +type ParentalProtectionSettings struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Schedule *ParentalProtectionSchedule `protobuf:"bytes,1,opt,name=schedule,proto3" json:"schedule,omitempty"` + BlockedServices []string `protobuf:"bytes,2,rep,name=blocked_services,json=blockedServices,proto3" json:"blocked_services,omitempty"` + Enabled bool `protobuf:"varint,3,opt,name=enabled,proto3" json:"enabled,omitempty"` + BlockAdult bool `protobuf:"varint,4,opt,name=block_adult,json=blockAdult,proto3" json:"block_adult,omitempty"` + GeneralSafeSearch bool `protobuf:"varint,5,opt,name=general_safe_search,json=generalSafeSearch,proto3" json:"general_safe_search,omitempty"` + YoutubeSafeSearch bool `protobuf:"varint,6,opt,name=youtube_safe_search,json=youtubeSafeSearch,proto3" json:"youtube_safe_search,omitempty"` +} + +func (x *ParentalProtectionSettings) Reset() { + *x = ParentalProtectionSettings{} + if protoimpl.UnsafeEnabled { + mi := &file_filecache_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ParentalProtectionSettings) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ParentalProtectionSettings) ProtoMessage() {} + +func (x *ParentalProtectionSettings) ProtoReflect() protoreflect.Message { + mi := &file_filecache_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ParentalProtectionSettings.ProtoReflect.Descriptor instead. +func (*ParentalProtectionSettings) Descriptor() ([]byte, []int) { + return file_filecache_proto_rawDescGZIP(), []int{2} +} + +func (x *ParentalProtectionSettings) GetSchedule() *ParentalProtectionSchedule { + if x != nil { + return x.Schedule + } + return nil +} + +func (x *ParentalProtectionSettings) GetBlockedServices() []string { + if x != nil { + return x.BlockedServices + } + return nil +} + +func (x *ParentalProtectionSettings) GetEnabled() bool { + if x != nil { + return x.Enabled + } + return false +} + +func (x *ParentalProtectionSettings) GetBlockAdult() bool { + if x != nil { + return x.BlockAdult + } + return false +} + +func (x *ParentalProtectionSettings) GetGeneralSafeSearch() bool { + if x != nil { + return x.GeneralSafeSearch + } + return false +} + +func (x *ParentalProtectionSettings) GetYoutubeSafeSearch() bool { + if x != nil { + return x.YoutubeSafeSearch + } + return false +} + +type ParentalProtectionSchedule struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + TimeZone string `protobuf:"bytes,1,opt,name=time_zone,json=timeZone,proto3" json:"time_zone,omitempty"` + Mon *DayRange `protobuf:"bytes,2,opt,name=mon,proto3" json:"mon,omitempty"` + Tue *DayRange `protobuf:"bytes,3,opt,name=tue,proto3" json:"tue,omitempty"` + Wed *DayRange `protobuf:"bytes,4,opt,name=wed,proto3" json:"wed,omitempty"` + Thu *DayRange `protobuf:"bytes,5,opt,name=thu,proto3" json:"thu,omitempty"` + Fri *DayRange `protobuf:"bytes,6,opt,name=fri,proto3" json:"fri,omitempty"` + Sat *DayRange `protobuf:"bytes,7,opt,name=sat,proto3" json:"sat,omitempty"` + Sun *DayRange `protobuf:"bytes,8,opt,name=sun,proto3" json:"sun,omitempty"` +} + +func (x *ParentalProtectionSchedule) Reset() { + *x = ParentalProtectionSchedule{} + if protoimpl.UnsafeEnabled { + mi := &file_filecache_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ParentalProtectionSchedule) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ParentalProtectionSchedule) ProtoMessage() {} + +func (x *ParentalProtectionSchedule) ProtoReflect() protoreflect.Message { + mi := &file_filecache_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ParentalProtectionSchedule.ProtoReflect.Descriptor instead. +func (*ParentalProtectionSchedule) Descriptor() ([]byte, []int) { + return file_filecache_proto_rawDescGZIP(), []int{3} +} + +func (x *ParentalProtectionSchedule) GetTimeZone() string { + if x != nil { + return x.TimeZone + } + return "" +} + +func (x *ParentalProtectionSchedule) GetMon() *DayRange { + if x != nil { + return x.Mon + } + return nil +} + +func (x *ParentalProtectionSchedule) GetTue() *DayRange { + if x != nil { + return x.Tue + } + return nil +} + +func (x *ParentalProtectionSchedule) GetWed() *DayRange { + if x != nil { + return x.Wed + } + return nil +} + +func (x *ParentalProtectionSchedule) GetThu() *DayRange { + if x != nil { + return x.Thu + } + return nil +} + +func (x *ParentalProtectionSchedule) GetFri() *DayRange { + if x != nil { + return x.Fri + } + return nil +} + +func (x *ParentalProtectionSchedule) GetSat() *DayRange { + if x != nil { + return x.Sat + } + return nil +} + +func (x *ParentalProtectionSchedule) GetSun() *DayRange { + if x != nil { + return x.Sun + } + return nil +} + +type DayRange struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` + End uint32 `protobuf:"varint,2,opt,name=end,proto3" json:"end,omitempty"` +} + +func (x *DayRange) Reset() { + *x = DayRange{} + if protoimpl.UnsafeEnabled { + mi := &file_filecache_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DayRange) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DayRange) ProtoMessage() {} + +func (x *DayRange) ProtoReflect() protoreflect.Message { + mi := &file_filecache_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DayRange.ProtoReflect.Descriptor instead. +func (*DayRange) Descriptor() ([]byte, []int) { + return file_filecache_proto_rawDescGZIP(), []int{4} +} + +func (x *DayRange) GetStart() uint32 { + if x != nil { + return x.Start + } + return 0 +} + +func (x *DayRange) GetEnd() uint32 { + if x != nil { + return x.End + } + return 0 +} + +type BlockingModeCustomIP struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Ipv4 []byte `protobuf:"bytes,1,opt,name=ipv4,proto3" json:"ipv4,omitempty"` + Ipv6 []byte `protobuf:"bytes,2,opt,name=ipv6,proto3" json:"ipv6,omitempty"` +} + +func (x *BlockingModeCustomIP) Reset() { + *x = BlockingModeCustomIP{} + if protoimpl.UnsafeEnabled { + mi := &file_filecache_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *BlockingModeCustomIP) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*BlockingModeCustomIP) ProtoMessage() {} + +func (x *BlockingModeCustomIP) ProtoReflect() protoreflect.Message { + mi := &file_filecache_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use BlockingModeCustomIP.ProtoReflect.Descriptor instead. +func (*BlockingModeCustomIP) Descriptor() ([]byte, []int) { + return file_filecache_proto_rawDescGZIP(), []int{5} +} + +func (x *BlockingModeCustomIP) GetIpv4() []byte { + if x != nil { + return x.Ipv4 + } + return nil +} + +func (x *BlockingModeCustomIP) GetIpv6() []byte { + if x != nil { + return x.Ipv6 + } + return nil +} + +type BlockingModeNXDOMAIN struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *BlockingModeNXDOMAIN) Reset() { + *x = BlockingModeNXDOMAIN{} + if protoimpl.UnsafeEnabled { + mi := &file_filecache_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *BlockingModeNXDOMAIN) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*BlockingModeNXDOMAIN) ProtoMessage() {} + +func (x *BlockingModeNXDOMAIN) ProtoReflect() protoreflect.Message { + mi := &file_filecache_proto_msgTypes[6] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use BlockingModeNXDOMAIN.ProtoReflect.Descriptor instead. +func (*BlockingModeNXDOMAIN) Descriptor() ([]byte, []int) { + return file_filecache_proto_rawDescGZIP(), []int{6} +} + +type BlockingModeNullIP struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *BlockingModeNullIP) Reset() { + *x = BlockingModeNullIP{} + if protoimpl.UnsafeEnabled { + mi := &file_filecache_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *BlockingModeNullIP) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*BlockingModeNullIP) ProtoMessage() {} + +func (x *BlockingModeNullIP) ProtoReflect() protoreflect.Message { + mi := &file_filecache_proto_msgTypes[7] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use BlockingModeNullIP.ProtoReflect.Descriptor instead. +func (*BlockingModeNullIP) Descriptor() ([]byte, []int) { + return file_filecache_proto_rawDescGZIP(), []int{7} +} + +type BlockingModeREFUSED struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *BlockingModeREFUSED) Reset() { + *x = BlockingModeREFUSED{} + if protoimpl.UnsafeEnabled { + mi := &file_filecache_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *BlockingModeREFUSED) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*BlockingModeREFUSED) ProtoMessage() {} + +func (x *BlockingModeREFUSED) ProtoReflect() protoreflect.Message { + mi := &file_filecache_proto_msgTypes[8] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use BlockingModeREFUSED.ProtoReflect.Descriptor instead. +func (*BlockingModeREFUSED) Descriptor() ([]byte, []int) { + return file_filecache_proto_rawDescGZIP(), []int{8} +} + +type Device struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + DeviceId string `protobuf:"bytes,1,opt,name=device_id,json=deviceId,proto3" json:"device_id,omitempty"` + LinkedIp []byte `protobuf:"bytes,2,opt,name=linked_ip,json=linkedIp,proto3" json:"linked_ip,omitempty"` + DeviceName string `protobuf:"bytes,3,opt,name=device_name,json=deviceName,proto3" json:"device_name,omitempty"` + DedicatedIps [][]byte `protobuf:"bytes,4,rep,name=dedicated_ips,json=dedicatedIps,proto3" json:"dedicated_ips,omitempty"` + FilteringEnabled bool `protobuf:"varint,5,opt,name=filtering_enabled,json=filteringEnabled,proto3" json:"filtering_enabled,omitempty"` +} + +func (x *Device) Reset() { + *x = Device{} + if protoimpl.UnsafeEnabled { + mi := &file_filecache_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Device) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Device) ProtoMessage() {} + +func (x *Device) ProtoReflect() protoreflect.Message { + mi := &file_filecache_proto_msgTypes[9] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Device.ProtoReflect.Descriptor instead. +func (*Device) Descriptor() ([]byte, []int) { + return file_filecache_proto_rawDescGZIP(), []int{9} +} + +func (x *Device) GetDeviceId() string { + if x != nil { + return x.DeviceId + } + return "" +} + +func (x *Device) GetLinkedIp() []byte { + if x != nil { + return x.LinkedIp + } + return nil +} + +func (x *Device) GetDeviceName() string { + if x != nil { + return x.DeviceName + } + return "" +} + +func (x *Device) GetDedicatedIps() [][]byte { + if x != nil { + return x.DedicatedIps + } + return nil +} + +func (x *Device) GetFilteringEnabled() bool { + if x != nil { + return x.FilteringEnabled + } + return false +} + +var File_filecache_proto protoreflect.FileDescriptor + +var file_filecache_proto_rawDesc = []byte{ + 0x0a, 0x0f, 0x66, 0x69, 0x6c, 0x65, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x12, 0x09, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, 0x1a, 0x1e, 0x67, 0x6f, + 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75, + 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1f, 0x67, 0x6f, + 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, + 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xbb, 0x01, + 0x0a, 0x09, 0x46, 0x69, 0x6c, 0x65, 0x43, 0x61, 0x63, 0x68, 0x65, 0x12, 0x37, 0x0a, 0x09, 0x73, + 0x79, 0x6e, 0x63, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, + 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x08, 0x73, 0x79, 0x6e, 0x63, + 0x54, 0x69, 0x6d, 0x65, 0x12, 0x2e, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x73, + 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, + 0x64, 0x62, 0x2e, 0x50, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x66, + 0x69, 0x6c, 0x65, 0x73, 0x12, 0x2b, 0x0a, 0x07, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x73, 0x18, + 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, + 0x62, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x52, 0x07, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, + 0x73, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x82, 0x08, 0x0a, 0x07, + 0x50, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x12, 0x41, 0x0a, 0x08, 0x70, 0x61, 0x72, 0x65, 0x6e, + 0x74, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x25, 0x2e, 0x70, 0x72, 0x6f, 0x66, + 0x69, 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x50, 0x61, 0x72, 0x65, 0x6e, 0x74, 0x61, 0x6c, 0x50, 0x72, + 0x6f, 0x74, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, + 0x52, 0x08, 0x70, 0x61, 0x72, 0x65, 0x6e, 0x74, 0x61, 0x6c, 0x12, 0x58, 0x0a, 0x17, 0x62, 0x6c, + 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x5f, 0x63, 0x75, 0x73, 0x74, + 0x6f, 0x6d, 0x5f, 0x69, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x70, 0x72, + 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x42, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, + 0x4d, 0x6f, 0x64, 0x65, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x49, 0x50, 0x48, 0x00, 0x52, 0x14, + 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x43, 0x75, 0x73, 0x74, + 0x6f, 0x6d, 0x49, 0x70, 0x12, 0x57, 0x0a, 0x16, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, + 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x5f, 0x6e, 0x78, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, + 0x2e, 0x42, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x4e, 0x58, 0x44, + 0x4f, 0x4d, 0x41, 0x49, 0x4e, 0x48, 0x00, 0x52, 0x14, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, + 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x4e, 0x78, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x52, 0x0a, + 0x15, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x5f, 0x6e, + 0x75, 0x6c, 0x6c, 0x5f, 0x69, 0x70, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x70, + 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x42, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, + 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x4e, 0x75, 0x6c, 0x6c, 0x49, 0x50, 0x48, 0x00, 0x52, 0x12, 0x62, + 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x4e, 0x75, 0x6c, 0x6c, 0x49, + 0x70, 0x12, 0x54, 0x0a, 0x15, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x5f, 0x6d, 0x6f, + 0x64, 0x65, 0x5f, 0x72, 0x65, 0x66, 0x75, 0x73, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x1e, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x42, 0x6c, 0x6f, + 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x52, 0x45, 0x46, 0x55, 0x53, 0x45, 0x44, + 0x48, 0x00, 0x52, 0x13, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65, + 0x52, 0x65, 0x66, 0x75, 0x73, 0x65, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x6f, 0x66, 0x69, + 0x6c, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x70, 0x72, 0x6f, + 0x66, 0x69, 0x6c, 0x65, 0x49, 0x64, 0x12, 0x3b, 0x0a, 0x0b, 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, + 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, + 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, + 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x0a, 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, + 0x69, 0x6d, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, + 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x09, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x49, + 0x64, 0x73, 0x12, 0x22, 0x0a, 0x0d, 0x72, 0x75, 0x6c, 0x65, 0x5f, 0x6c, 0x69, 0x73, 0x74, 0x5f, + 0x69, 0x64, 0x73, 0x18, 0x09, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0b, 0x72, 0x75, 0x6c, 0x65, 0x4c, + 0x69, 0x73, 0x74, 0x49, 0x64, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, + 0x5f, 0x72, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0b, 0x63, 0x75, + 0x73, 0x74, 0x6f, 0x6d, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x4d, 0x0a, 0x15, 0x66, 0x69, 0x6c, + 0x74, 0x65, 0x72, 0x65, 0x64, 0x5f, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5f, 0x74, + 0x74, 0x6c, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x52, 0x13, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x65, 0x64, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x54, 0x74, 0x6c, 0x12, 0x2b, 0x0a, 0x11, 0x66, 0x69, 0x6c, 0x74, + 0x65, 0x72, 0x69, 0x6e, 0x67, 0x5f, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0c, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x10, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x69, 0x6e, 0x67, 0x45, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x32, 0x0a, 0x15, 0x73, 0x61, 0x66, 0x65, 0x5f, 0x62, 0x72, + 0x6f, 0x77, 0x73, 0x69, 0x6e, 0x67, 0x5f, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0d, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x73, 0x61, 0x66, 0x65, 0x42, 0x72, 0x6f, 0x77, 0x73, 0x69, + 0x6e, 0x67, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x2c, 0x0a, 0x12, 0x72, 0x75, 0x6c, + 0x65, 0x5f, 0x6c, 0x69, 0x73, 0x74, 0x73, 0x5f, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, + 0x0e, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x75, 0x6c, 0x65, 0x4c, 0x69, 0x73, 0x74, 0x73, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x2a, 0x0a, 0x11, 0x71, 0x75, 0x65, 0x72, 0x79, + 0x5f, 0x6c, 0x6f, 0x67, 0x5f, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x0f, 0x71, 0x75, 0x65, 0x72, 0x79, 0x4c, 0x6f, 0x67, 0x45, 0x6e, 0x61, 0x62, + 0x6c, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x18, 0x10, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x12, 0x2e, 0x0a, + 0x13, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x5f, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x5f, 0x72, + 0x65, 0x6c, 0x61, 0x79, 0x18, 0x11, 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x62, 0x6c, 0x6f, 0x63, + 0x6b, 0x50, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x12, 0x30, 0x0a, + 0x14, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x5f, 0x66, 0x69, 0x72, 0x65, 0x66, 0x6f, 0x78, 0x5f, 0x63, + 0x61, 0x6e, 0x61, 0x72, 0x79, 0x18, 0x12, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x62, 0x6c, 0x6f, + 0x63, 0x6b, 0x46, 0x69, 0x72, 0x65, 0x66, 0x6f, 0x78, 0x43, 0x61, 0x6e, 0x61, 0x72, 0x79, 0x42, + 0x0f, 0x0a, 0x0d, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x5f, 0x6d, 0x6f, 0x64, 0x65, + 0x22, 0xa5, 0x02, 0x0a, 0x1a, 0x50, 0x61, 0x72, 0x65, 0x6e, 0x74, 0x61, 0x6c, 0x50, 0x72, 0x6f, + 0x74, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x12, + 0x41, 0x0a, 0x08, 0x73, 0x63, 0x68, 0x65, 0x64, 0x75, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x25, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x50, 0x61, + 0x72, 0x65, 0x6e, 0x74, 0x61, 0x6c, 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x53, 0x63, 0x68, 0x65, 0x64, 0x75, 0x6c, 0x65, 0x52, 0x08, 0x73, 0x63, 0x68, 0x65, 0x64, 0x75, + 0x6c, 0x65, 0x12, 0x29, 0x0a, 0x10, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x65, 0x64, 0x5f, 0x73, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0f, 0x62, 0x6c, + 0x6f, 0x63, 0x6b, 0x65, 0x64, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x12, 0x18, 0x0a, + 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, + 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1f, 0x0a, 0x0b, 0x62, 0x6c, 0x6f, 0x63, 0x6b, + 0x5f, 0x61, 0x64, 0x75, 0x6c, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x62, 0x6c, + 0x6f, 0x63, 0x6b, 0x41, 0x64, 0x75, 0x6c, 0x74, 0x12, 0x2e, 0x0a, 0x13, 0x67, 0x65, 0x6e, 0x65, + 0x72, 0x61, 0x6c, 0x5f, 0x73, 0x61, 0x66, 0x65, 0x5f, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x18, + 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x53, 0x61, + 0x66, 0x65, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x12, 0x2e, 0x0a, 0x13, 0x79, 0x6f, 0x75, 0x74, + 0x75, 0x62, 0x65, 0x5f, 0x73, 0x61, 0x66, 0x65, 0x5f, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x18, + 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x79, 0x6f, 0x75, 0x74, 0x75, 0x62, 0x65, 0x53, 0x61, + 0x66, 0x65, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x22, 0xca, 0x02, 0x0a, 0x1a, 0x50, 0x61, 0x72, + 0x65, 0x6e, 0x74, 0x61, 0x6c, 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x53, + 0x63, 0x68, 0x65, 0x64, 0x75, 0x6c, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x5f, + 0x7a, 0x6f, 0x6e, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x74, 0x69, 0x6d, 0x65, + 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x25, 0x0a, 0x03, 0x6d, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x44, 0x61, + 0x79, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x52, 0x03, 0x6d, 0x6f, 0x6e, 0x12, 0x25, 0x0a, 0x03, 0x74, + 0x75, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69, + 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x44, 0x61, 0x79, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x52, 0x03, 0x74, + 0x75, 0x65, 0x12, 0x25, 0x0a, 0x03, 0x77, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x44, 0x61, 0x79, 0x52, + 0x61, 0x6e, 0x67, 0x65, 0x52, 0x03, 0x77, 0x65, 0x64, 0x12, 0x25, 0x0a, 0x03, 0x74, 0x68, 0x75, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, + 0x64, 0x62, 0x2e, 0x44, 0x61, 0x79, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x52, 0x03, 0x74, 0x68, 0x75, + 0x12, 0x25, 0x0a, 0x03, 0x66, 0x72, 0x69, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, + 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x44, 0x61, 0x79, 0x52, 0x61, 0x6e, + 0x67, 0x65, 0x52, 0x03, 0x66, 0x72, 0x69, 0x12, 0x25, 0x0a, 0x03, 0x73, 0x61, 0x74, 0x18, 0x07, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, + 0x2e, 0x44, 0x61, 0x79, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x52, 0x03, 0x73, 0x61, 0x74, 0x12, 0x25, + 0x0a, 0x03, 0x73, 0x75, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x70, 0x72, + 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x44, 0x61, 0x79, 0x52, 0x61, 0x6e, 0x67, 0x65, + 0x52, 0x03, 0x73, 0x75, 0x6e, 0x22, 0x32, 0x0a, 0x08, 0x44, 0x61, 0x79, 0x52, 0x61, 0x6e, 0x67, + 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, + 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x22, 0x3e, 0x0a, 0x14, 0x42, 0x6c, 0x6f, + 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x49, + 0x50, 0x12, 0x12, 0x0a, 0x04, 0x69, 0x70, 0x76, 0x34, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, + 0x04, 0x69, 0x70, 0x76, 0x34, 0x12, 0x12, 0x0a, 0x04, 0x69, 0x70, 0x76, 0x36, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x04, 0x69, 0x70, 0x76, 0x36, 0x22, 0x16, 0x0a, 0x14, 0x42, 0x6c, 0x6f, + 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x4e, 0x58, 0x44, 0x4f, 0x4d, 0x41, 0x49, + 0x4e, 0x22, 0x14, 0x0a, 0x12, 0x42, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, + 0x65, 0x4e, 0x75, 0x6c, 0x6c, 0x49, 0x50, 0x22, 0x15, 0x0a, 0x13, 0x42, 0x6c, 0x6f, 0x63, 0x6b, + 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x52, 0x45, 0x46, 0x55, 0x53, 0x45, 0x44, 0x22, 0xb5, + 0x01, 0x0a, 0x06, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x65, 0x76, + 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x64, 0x65, + 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x6c, 0x69, 0x6e, 0x6b, 0x65, 0x64, + 0x5f, 0x69, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x6c, 0x69, 0x6e, 0x6b, 0x65, + 0x64, 0x49, 0x70, 0x12, 0x1f, 0x0a, 0x0b, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x6e, 0x61, + 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, + 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x23, 0x0a, 0x0d, 0x64, 0x65, 0x64, 0x69, 0x63, 0x61, 0x74, 0x65, + 0x64, 0x5f, 0x69, 0x70, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x0c, 0x64, 0x65, 0x64, + 0x69, 0x63, 0x61, 0x74, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x2b, 0x0a, 0x11, 0x66, 0x69, 0x6c, + 0x74, 0x65, 0x72, 0x69, 0x6e, 0x67, 0x5f, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x69, 0x6e, 0x67, 0x45, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x42, 0x0f, 0x5a, 0x0d, 0x2e, 0x2f, 0x66, 0x69, 0x6c, 0x65, + 0x63, 0x61, 0x63, 0x68, 0x65, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_filecache_proto_rawDescOnce sync.Once + file_filecache_proto_rawDescData = file_filecache_proto_rawDesc +) + +func file_filecache_proto_rawDescGZIP() []byte { + file_filecache_proto_rawDescOnce.Do(func() { + file_filecache_proto_rawDescData = protoimpl.X.CompressGZIP(file_filecache_proto_rawDescData) + }) + return file_filecache_proto_rawDescData +} + +var file_filecache_proto_msgTypes = make([]protoimpl.MessageInfo, 10) +var file_filecache_proto_goTypes = []interface{}{ + (*FileCache)(nil), // 0: profiledb.FileCache + (*Profile)(nil), // 1: profiledb.Profile + (*ParentalProtectionSettings)(nil), // 2: profiledb.ParentalProtectionSettings + (*ParentalProtectionSchedule)(nil), // 3: profiledb.ParentalProtectionSchedule + (*DayRange)(nil), // 4: profiledb.DayRange + (*BlockingModeCustomIP)(nil), // 5: profiledb.BlockingModeCustomIP + (*BlockingModeNXDOMAIN)(nil), // 6: profiledb.BlockingModeNXDOMAIN + (*BlockingModeNullIP)(nil), // 7: profiledb.BlockingModeNullIP + (*BlockingModeREFUSED)(nil), // 8: profiledb.BlockingModeREFUSED + (*Device)(nil), // 9: profiledb.Device + (*timestamppb.Timestamp)(nil), // 10: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 11: google.protobuf.Duration +} +var file_filecache_proto_depIdxs = []int32{ + 10, // 0: profiledb.FileCache.sync_time:type_name -> google.protobuf.Timestamp + 1, // 1: profiledb.FileCache.profiles:type_name -> profiledb.Profile + 9, // 2: profiledb.FileCache.devices:type_name -> profiledb.Device + 2, // 3: profiledb.Profile.parental:type_name -> profiledb.ParentalProtectionSettings + 5, // 4: profiledb.Profile.blocking_mode_custom_ip:type_name -> profiledb.BlockingModeCustomIP + 6, // 5: profiledb.Profile.blocking_mode_nxdomain:type_name -> profiledb.BlockingModeNXDOMAIN + 7, // 6: profiledb.Profile.blocking_mode_null_ip:type_name -> profiledb.BlockingModeNullIP + 8, // 7: profiledb.Profile.blocking_mode_refused:type_name -> profiledb.BlockingModeREFUSED + 10, // 8: profiledb.Profile.update_time:type_name -> google.protobuf.Timestamp + 11, // 9: profiledb.Profile.filtered_response_ttl:type_name -> google.protobuf.Duration + 3, // 10: profiledb.ParentalProtectionSettings.schedule:type_name -> profiledb.ParentalProtectionSchedule + 4, // 11: profiledb.ParentalProtectionSchedule.mon:type_name -> profiledb.DayRange + 4, // 12: profiledb.ParentalProtectionSchedule.tue:type_name -> profiledb.DayRange + 4, // 13: profiledb.ParentalProtectionSchedule.wed:type_name -> profiledb.DayRange + 4, // 14: profiledb.ParentalProtectionSchedule.thu:type_name -> profiledb.DayRange + 4, // 15: profiledb.ParentalProtectionSchedule.fri:type_name -> profiledb.DayRange + 4, // 16: profiledb.ParentalProtectionSchedule.sat:type_name -> profiledb.DayRange + 4, // 17: profiledb.ParentalProtectionSchedule.sun:type_name -> profiledb.DayRange + 18, // [18:18] is the sub-list for method output_type + 18, // [18:18] is the sub-list for method input_type + 18, // [18:18] is the sub-list for extension type_name + 18, // [18:18] is the sub-list for extension extendee + 0, // [0:18] is the sub-list for field type_name +} + +func init() { file_filecache_proto_init() } +func file_filecache_proto_init() { + if File_filecache_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_filecache_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*FileCache); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_filecache_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Profile); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_filecache_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ParentalProtectionSettings); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_filecache_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ParentalProtectionSchedule); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_filecache_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DayRange); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_filecache_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*BlockingModeCustomIP); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_filecache_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*BlockingModeNXDOMAIN); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_filecache_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*BlockingModeNullIP); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_filecache_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*BlockingModeREFUSED); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_filecache_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Device); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_filecache_proto_msgTypes[1].OneofWrappers = []interface{}{ + (*Profile_BlockingModeCustomIp)(nil), + (*Profile_BlockingModeNxdomain)(nil), + (*Profile_BlockingModeNullIp)(nil), + (*Profile_BlockingModeRefused)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_filecache_proto_rawDesc, + NumEnums: 0, + NumMessages: 10, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_filecache_proto_goTypes, + DependencyIndexes: file_filecache_proto_depIdxs, + MessageInfos: file_filecache_proto_msgTypes, + }.Build() + File_filecache_proto = out.File + file_filecache_proto_rawDesc = nil + file_filecache_proto_goTypes = nil + file_filecache_proto_depIdxs = nil +} diff --git a/internal/profiledb/internal/filecachepb/filecache.proto b/internal/profiledb/internal/filecachepb/filecache.proto new file mode 100644 index 0000000..6fd8f68 --- /dev/null +++ b/internal/profiledb/internal/filecachepb/filecache.proto @@ -0,0 +1,82 @@ +syntax = "proto3"; + +package profiledb; + +option go_package = "./filecachepb"; + +import "google/protobuf/duration.proto"; +import "google/protobuf/timestamp.proto"; + +message FileCache { + google.protobuf.Timestamp sync_time = 1; + repeated Profile profiles = 2; + repeated Device devices = 3; + int32 version = 4; +} + +message Profile { + ParentalProtectionSettings parental = 1; + oneof blocking_mode { + BlockingModeCustomIP blocking_mode_custom_ip = 2; + BlockingModeNXDOMAIN blocking_mode_nxdomain = 3; + BlockingModeNullIP blocking_mode_null_ip = 4; + BlockingModeREFUSED blocking_mode_refused = 5; + } + string profile_id = 6; + google.protobuf.Timestamp update_time = 7; + repeated string device_ids = 8; + repeated string rule_list_ids = 9; + repeated string custom_rules = 10; + google.protobuf.Duration filtered_response_ttl = 11; + bool filtering_enabled = 12; + bool safe_browsing_enabled = 13; + bool rule_lists_enabled = 14; + bool query_log_enabled = 15; + bool deleted = 16; + bool block_private_relay = 17; + bool block_firefox_canary = 18; +} + +message ParentalProtectionSettings { + ParentalProtectionSchedule schedule = 1; + repeated string blocked_services = 2; + bool enabled = 3; + bool block_adult = 4; + bool general_safe_search = 5; + bool youtube_safe_search = 6; +} + +message ParentalProtectionSchedule { + string time_zone = 1; + DayRange mon = 2; + DayRange tue = 3; + DayRange wed = 4; + DayRange thu = 5; + DayRange fri = 6; + DayRange sat = 7; + DayRange sun = 8; +} + +message DayRange { + uint32 start = 1; + uint32 end = 2; +} + +message BlockingModeCustomIP { + bytes ipv4 = 1; + bytes ipv6 = 2; +} + +message BlockingModeNXDOMAIN {} + +message BlockingModeNullIP {} + +message BlockingModeREFUSED {} + +message Device { + string device_id = 1; + bytes linked_ip = 2; + string device_name = 3; + repeated bytes dedicated_ips = 4; + bool filtering_enabled = 5; +} diff --git a/internal/profiledb/internal/filecachepb/filecachepb.go b/internal/profiledb/internal/filecachepb/filecachepb.go new file mode 100644 index 0000000..71a8438 --- /dev/null +++ b/internal/profiledb/internal/filecachepb/filecachepb.go @@ -0,0 +1,388 @@ +// Package filecachepb contains the protobuf structures for the profile cache. +package filecachepb + +import ( + "fmt" + "net/netip" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/agdtime" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// toInternal converts the protobuf-encoded data into a cache structure. +func toInternal(fc *FileCache) (c *internal.FileCache, err error) { + profiles, err := profilesToInternal(fc.Profiles) + if err != nil { + return nil, fmt.Errorf("converting profiles: %w", err) + } + + devices, err := devicesFromProtobuf(fc.Devices) + if err != nil { + return nil, fmt.Errorf("converting devices: %w", err) + } + + return &internal.FileCache{ + SyncTime: fc.SyncTime.AsTime(), + Profiles: profiles, + Devices: devices, + Version: fc.Version, + }, nil +} + +// toProtobuf converts the cache structure into protobuf structure for encoding. +func toProtobuf(c *internal.FileCache) (pbFileCache *FileCache) { + return &FileCache{ + SyncTime: timestamppb.New(c.SyncTime), + Profiles: profilesToProtobuf(c.Profiles), + Devices: devicesToProtobuf(c.Devices), + Version: c.Version, + } +} + +// profilesToInternal converts protobuf profile structures into internal ones. +func profilesToInternal(pbProfiles []*Profile) (profiles []*agd.Profile, err error) { + profiles = make([]*agd.Profile, 0, len(pbProfiles)) + for i, pbProf := range pbProfiles { + var prof *agd.Profile + prof, err = pbProf.toInternal() + if err != nil { + return nil, fmt.Errorf("profile at index %d: %w", i, err) + } + + profiles = append(profiles, prof) + } + + return profiles, nil +} + +// toInternal converts a protobuf profile structure to an internal one. +func (x *Profile) toInternal() (prof *agd.Profile, err error) { + parental, err := x.Parental.toInternal() + if err != nil { + return nil, fmt.Errorf("parental: %w", err) + } + + m, err := blockingModeToInternal(x.BlockingMode) + if err != nil { + return nil, fmt.Errorf("blocking mode: %w", err) + } + + return &agd.Profile{ + Parental: parental, + BlockingMode: m, + ID: agd.ProfileID(x.ProfileId), + UpdateTime: x.UpdateTime.AsTime(), + // Consider device IDs to have been prevalidated. + DeviceIDs: unsafelyConvertStrSlice[string, agd.DeviceID](x.DeviceIds), + // Consider rule-list IDs to have been prevalidated. + RuleListIDs: unsafelyConvertStrSlice[string, agd.FilterListID](x.RuleListIds), + // Consider rule-list IDs to have been prevalidated. + CustomRules: unsafelyConvertStrSlice[string, agd.FilterRuleText](x.CustomRules), + FilteredResponseTTL: x.FilteredResponseTtl.AsDuration(), + FilteringEnabled: x.FilteringEnabled, + SafeBrowsingEnabled: x.SafeBrowsingEnabled, + RuleListsEnabled: x.RuleListsEnabled, + QueryLogEnabled: x.QueryLogEnabled, + Deleted: x.Deleted, + BlockPrivateRelay: x.BlockPrivateRelay, + BlockFirefoxCanary: x.BlockFirefoxCanary, + }, nil +} + +// toInternal converts a protobuf parental-settings structure to an internal +// one. +func (x *ParentalProtectionSettings) toInternal() (s *agd.ParentalProtectionSettings, err error) { + if x == nil { + return nil, nil + } + + schedule, err := x.Schedule.toInternal() + if err != nil { + return nil, fmt.Errorf("schedule: %w", err) + } + + return &agd.ParentalProtectionSettings{ + Schedule: schedule, + // Consider block service IDs to have been prevalidated. + BlockedServices: unsafelyConvertStrSlice[string, agd.BlockedServiceID]( + x.BlockedServices, + ), + Enabled: x.Enabled, + BlockAdult: x.BlockAdult, + GeneralSafeSearch: x.GeneralSafeSearch, + YoutubeSafeSearch: x.YoutubeSafeSearch, + }, nil +} + +// toInternal converts a protobuf protection-schedule structure to an internal +// one. +func (x *ParentalProtectionSchedule) toInternal() (s *agd.ParentalProtectionSchedule, err error) { + if x == nil { + return nil, nil + } + + loc, err := agdtime.LoadLocation(x.TimeZone) + if err != nil { + return nil, fmt.Errorf("time zone: %w", err) + } + + return &agd.ParentalProtectionSchedule{ + // Consider the lengths to be prevalidated. + Week: &agd.WeeklySchedule{ + time.Monday: {Start: uint16(x.Mon.Start), End: uint16(x.Mon.End)}, + time.Tuesday: {Start: uint16(x.Tue.Start), End: uint16(x.Tue.End)}, + time.Wednesday: {Start: uint16(x.Wed.Start), End: uint16(x.Wed.End)}, + time.Thursday: {Start: uint16(x.Thu.Start), End: uint16(x.Thu.End)}, + time.Friday: {Start: uint16(x.Fri.Start), End: uint16(x.Fri.End)}, + time.Saturday: {Start: uint16(x.Sat.Start), End: uint16(x.Sat.End)}, + time.Sunday: {Start: uint16(x.Sun.Start), End: uint16(x.Sun.End)}, + }, + TimeZone: loc, + }, nil +} + +// blockingModeToInternal converts a protobuf blocking-mode sum-type to an +// internal one. +func blockingModeToInternal( + pbBlockingMode isProfile_BlockingMode, +) (m dnsmsg.BlockingModeCodec, err error) { + switch pbm := pbBlockingMode.(type) { + case *Profile_BlockingModeCustomIp: + custom := &dnsmsg.BlockingModeCustomIP{} + err = custom.IPv4.UnmarshalBinary(pbm.BlockingModeCustomIp.Ipv4) + if err != nil { + return dnsmsg.BlockingModeCodec{}, fmt.Errorf("bad custom ipv4: %w", err) + } + + err = custom.IPv6.UnmarshalBinary(pbm.BlockingModeCustomIp.Ipv6) + if err != nil { + return dnsmsg.BlockingModeCodec{}, fmt.Errorf("bad custom ipv6: %w", err) + } + + m.Mode = custom + case *Profile_BlockingModeNxdomain: + m.Mode = &dnsmsg.BlockingModeNXDOMAIN{} + case *Profile_BlockingModeNullIp: + m.Mode = &dnsmsg.BlockingModeNullIP{} + case *Profile_BlockingModeRefused: + m.Mode = &dnsmsg.BlockingModeREFUSED{} + default: + // Consider unhandled type-switch cases programmer errors. + panic(fmt.Errorf("bad pb blocking mode %T(%[1]v)", m)) + } + + return m, nil +} + +// devicesToInternal converts protobuf device structures into internal ones. +func devicesFromProtobuf(pbDevices []*Device) (devices []*agd.Device, err error) { + devices = make([]*agd.Device, 0, len(pbDevices)) + for i, pbDev := range pbDevices { + var dev *agd.Device + dev, err = pbDev.toInternal() + if err != nil { + return nil, fmt.Errorf("device at index %d: %w", i, err) + } + + devices = append(devices, dev) + } + + return devices, nil +} + +// toInternal converts a protobuf device structure to an internal one. +func (x *Device) toInternal() (d *agd.Device, err error) { + var linkedIP netip.Addr + err = linkedIP.UnmarshalBinary(x.LinkedIp) + if err != nil { + return nil, fmt.Errorf("linked ip: %w", err) + } + + var dedicatedIPs []netip.Addr + dedicatedIPs, err = byteSlicesToIPs(x.DedicatedIps) + if err != nil { + return nil, fmt.Errorf("dedicated ips: %w", err) + } + + return &agd.Device{ + // Consider device IDs to have been prevalidated. + ID: agd.DeviceID(x.DeviceId), + LinkedIP: linkedIP, + // Consider device names to have been prevalidated. + Name: agd.DeviceName(x.DeviceName), + DedicatedIPs: dedicatedIPs, + FilteringEnabled: x.FilteringEnabled, + }, nil +} + +// byteSlicesToIPs converts a slice of byte slices into a slice of netip.Addrs. +func byteSlicesToIPs(data [][]byte) (ips []netip.Addr, err error) { + if data == nil { + return nil, nil + } + + ips = make([]netip.Addr, 0, len(data)) + for i, ipData := range data { + var ip netip.Addr + err = ip.UnmarshalBinary(ipData) + if err != nil { + return nil, fmt.Errorf("ip at index %d: %w", i, err) + } + + ips = append(ips, ip) + } + + return ips, nil +} + +// profilesToProtobuf converts a slice of profiles to protobuf structures. +func profilesToProtobuf(profiles []*agd.Profile) (pbProfiles []*Profile) { + pbProfiles = make([]*Profile, 0, len(profiles)) + for _, p := range profiles { + pbProfiles = append(pbProfiles, &Profile{ + Parental: parentalToProtobuf(p.Parental), + BlockingMode: blockingModeToProtobuf(p.BlockingMode), + ProfileId: string(p.ID), + UpdateTime: timestamppb.New(p.UpdateTime), + DeviceIds: unsafelyConvertStrSlice[agd.DeviceID, string](p.DeviceIDs), + RuleListIds: unsafelyConvertStrSlice[agd.FilterListID, string](p.RuleListIDs), + CustomRules: unsafelyConvertStrSlice[agd.FilterRuleText, string](p.CustomRules), + FilteredResponseTtl: durationpb.New(p.FilteredResponseTTL), + FilteringEnabled: p.FilteringEnabled, + SafeBrowsingEnabled: p.SafeBrowsingEnabled, + RuleListsEnabled: p.RuleListsEnabled, + QueryLogEnabled: p.QueryLogEnabled, + Deleted: p.Deleted, + BlockPrivateRelay: p.BlockPrivateRelay, + BlockFirefoxCanary: p.BlockFirefoxCanary, + }) + } + + return pbProfiles +} + +// parentalToProtobuf converts parental settings to protobuf structure. +func parentalToProtobuf(s *agd.ParentalProtectionSettings) (pbSetts *ParentalProtectionSettings) { + if s == nil { + return nil + } + + return &ParentalProtectionSettings{ + Schedule: scheduleToProtobuf(s.Schedule), + BlockedServices: unsafelyConvertStrSlice[agd.BlockedServiceID, string](s.BlockedServices), + Enabled: s.Enabled, + BlockAdult: s.BlockAdult, + GeneralSafeSearch: s.GeneralSafeSearch, + YoutubeSafeSearch: s.YoutubeSafeSearch, + } +} + +// parentalToProtobuf converts parental-settings schedule to protobuf structure. +func scheduleToProtobuf(s *agd.ParentalProtectionSchedule) (pbSched *ParentalProtectionSchedule) { + if s == nil { + return nil + } + + return &ParentalProtectionSchedule{ + TimeZone: s.TimeZone.String(), + Mon: &DayRange{ + Start: uint32(s.Week[time.Monday].Start), + End: uint32(s.Week[time.Monday].End), + }, + Tue: &DayRange{ + Start: uint32(s.Week[time.Tuesday].Start), + End: uint32(s.Week[time.Tuesday].End), + }, + Wed: &DayRange{ + Start: uint32(s.Week[time.Wednesday].Start), + End: uint32(s.Week[time.Wednesday].End), + }, + Thu: &DayRange{ + Start: uint32(s.Week[time.Thursday].Start), + End: uint32(s.Week[time.Thursday].End), + }, + Fri: &DayRange{ + Start: uint32(s.Week[time.Friday].Start), + End: uint32(s.Week[time.Friday].End), + }, + Sat: &DayRange{ + Start: uint32(s.Week[time.Saturday].Start), + End: uint32(s.Week[time.Saturday].End), + }, + Sun: &DayRange{ + Start: uint32(s.Week[time.Sunday].Start), + End: uint32(s.Week[time.Sunday].End), + }, + } +} + +// blockingModeToProtobuf converts a blocking-mode sum-type to a protobuf one. +func blockingModeToProtobuf(m dnsmsg.BlockingModeCodec) (pbBlockingMode isProfile_BlockingMode) { + switch m := m.Mode.(type) { + case *dnsmsg.BlockingModeCustomIP: + return &Profile_BlockingModeCustomIp{ + BlockingModeCustomIp: &BlockingModeCustomIP{ + Ipv4: ipToBytes(m.IPv4), + Ipv6: ipToBytes(m.IPv6), + }, + } + case *dnsmsg.BlockingModeNXDOMAIN: + return &Profile_BlockingModeNxdomain{ + BlockingModeNxdomain: &BlockingModeNXDOMAIN{}, + } + case *dnsmsg.BlockingModeNullIP: + return &Profile_BlockingModeNullIp{ + BlockingModeNullIp: &BlockingModeNullIP{}, + } + case *dnsmsg.BlockingModeREFUSED: + return &Profile_BlockingModeRefused{ + BlockingModeRefused: &BlockingModeREFUSED{}, + } + default: + panic(fmt.Errorf("bad blocking mode %T(%[1]v)", m)) + } +} + +// ipToBytes is a wrapper around netip.Addr.MarshalBinary that ignores the +// always-nil error. +func ipToBytes(ip netip.Addr) (b []byte) { + b, _ = ip.MarshalBinary() + + return b +} + +// devicesToProtobuf converts a slice of devices to protobuf structures. +func devicesToProtobuf(devices []*agd.Device) (pbDevices []*Device) { + pbDevices = make([]*Device, 0, len(devices)) + for _, d := range devices { + pbDevices = append(pbDevices, &Device{ + DeviceId: string(d.ID), + LinkedIp: ipToBytes(d.LinkedIP), + DeviceName: string(d.Name), + DedicatedIps: ipsToByteSlices(d.DedicatedIPs), + FilteringEnabled: d.FilteringEnabled, + }) + } + + return pbDevices +} + +// ipsToByteSlices is a wrapper around netip.Addr.MarshalBinary that ignores the +// always-nil errors. +func ipsToByteSlices(ips []netip.Addr) (data [][]byte) { + if ips == nil { + return nil + } + + data = make([][]byte, 0, len(ips)) + for _, ip := range ips { + data = append(data, ipToBytes(ip)) + } + + return data +} diff --git a/internal/profiledb/internal/filecachepb/filecachepb_internal_test.go b/internal/profiledb/internal/filecachepb/filecachepb_internal_test.go new file mode 100644 index 0000000..9658447 --- /dev/null +++ b/internal/profiledb/internal/filecachepb/filecachepb_internal_test.go @@ -0,0 +1,121 @@ +package filecachepb + +import ( + "encoding/json" + "os" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal/profiledbtest" + "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" +) + +func TestMain(m *testing.M) { + testutil.DiscardLogOutput(m) +} + +// Sinks for benchmarks +var ( + bytesSink []byte + cacheSink = &internal.FileCache{} + errSink error + fileCacheSink = &FileCache{} +) + +// envVarName is the environment variable name the presence and value of which +// define whether to run the benchmarks with the data from the given file. +// +// The path should be an absolute path. +const envVarName = "ADGUARD_DNS_TEST_PROFILEDB_JSON" + +// setCacheSink is a helper that allows using a prepared JSON file for loading +// the data for benchmarks from the environment. +func setCacheSink(tb testing.TB) { + tb.Helper() + + filePath := os.Getenv(envVarName) + if filePath == "" { + prof, dev := profiledbtest.NewProfile(tb) + cacheSink = &internal.FileCache{ + SyncTime: time.Now().Round(0).UTC(), + Profiles: []*agd.Profile{prof}, + Devices: []*agd.Device{dev}, + Version: internal.FileCacheVersion, + } + + return + } + + tb.Logf("using %q as source for profiledb data", filePath) + + data, err := os.ReadFile(filePath) + require.NoError(tb, err) + + err = json.Unmarshal(data, cacheSink) + require.NoError(tb, err) +} + +func BenchmarkCache(b *testing.B) { + setCacheSink(b) + + b.Run("to_protobuf", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + fileCacheSink = toProtobuf(cacheSink) + } + + require.NoError(b, errSink) + require.NotEmpty(b, fileCacheSink) + }) + + b.Run("from_protobuf", func(b *testing.B) { + var gotCache *internal.FileCache + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + gotCache, errSink = toInternal(fileCacheSink) + } + + require.NoError(b, errSink) + require.NotEmpty(b, gotCache) + }) + + b.Run("encode", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bytesSink, errSink = proto.Marshal(fileCacheSink) + } + + require.NoError(b, errSink) + require.NotEmpty(b, bytesSink) + }) + + b.Run("decode", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errSink = proto.Unmarshal(bytesSink, fileCacheSink) + } + + require.NoError(b, errSink) + require.NotEmpty(b, fileCacheSink) + }) + + // Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU: + // + // goos: linux + // goarch: amd64 + // pkg: github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal/filecachepb + // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics + // BenchmarkCache/to_protobuf-16 674397 1657 ns/op 1240 B/op 22 allocs/op + // BenchmarkCache/from_protobuf-16 83577 14285 ns/op 7400 B/op 29 allocs/op + // BenchmarkCache/encode-16 563797 1984 ns/op 208 B/op 1 allocs/op + // BenchmarkCache/decode-16 273951 5143 ns/op 1288 B/op 31 allocs/op +} diff --git a/internal/profiledb/internal/filecachepb/storage.go b/internal/profiledb/internal/filecachepb/storage.go new file mode 100644 index 0000000..4e2eb96 --- /dev/null +++ b/internal/profiledb/internal/filecachepb/storage.go @@ -0,0 +1,74 @@ +package filecachepb + +import ( + "fmt" + "os" + + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" + "github.com/google/renameio" + "google.golang.org/protobuf/proto" +) + +// Storage is the file-cache storage that encodes data using protobuf. +type Storage struct { + path string +} + +// New returns a new protobuf-encoded file-cache storage. +func New(cachePath string) (s *Storage) { + return &Storage{ + path: cachePath, + } +} + +// logPrefix is the logging prefix for the protobuf-encoded file-cache. +const logPrefix = "profiledb protobuf cache" + +var _ internal.FileCacheStorage = (*Storage)(nil) + +// Load implements the [internal.FileCacheStorage] interface for *Storage. +func (s *Storage) Load() (c *internal.FileCache, err error) { + log.Info("%s: loading", logPrefix) + + b, err := os.ReadFile(s.path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + log.Info("%s: file not present", logPrefix) + + return nil, nil + } + + return nil, err + } + + fc := &FileCache{} + err = proto.Unmarshal(b, fc) + if err != nil { + return nil, fmt.Errorf("decoding protobuf: %w", err) + } + + return toInternal(fc) +} + +// Store implements the [internal.FileCacheStorage] interface for *Storage. +func (s *Storage) Store(c *internal.FileCache) (err error) { + profNum := len(c.Profiles) + log.Info("%s: saving %d profiles to %q", logPrefix, profNum, s.path) + defer log.Info("%s: saved %d profiles to %q", logPrefix, profNum, s.path) + + fc := toProtobuf(c) + b, err := proto.Marshal(fc) + if err != nil { + return fmt.Errorf("encoding protobuf: %w", err) + } + + err = renameio.WriteFile(s.path, b, 0o600) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } + + return nil +} diff --git a/internal/profiledb/internal/filecachepb/storage_test.go b/internal/profiledb/internal/filecachepb/storage_test.go new file mode 100644 index 0000000..b9fc366 --- /dev/null +++ b/internal/profiledb/internal/filecachepb/storage_test.go @@ -0,0 +1,48 @@ +package filecachepb_test + +import ( + "path/filepath" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal/filecachepb" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal/profiledbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStorage(t *testing.T) { + prof, dev := profiledbtest.NewProfile(t) + cachePath := filepath.Join(t.TempDir(), "profiles.pb") + s := filecachepb.New(cachePath) + require.NotNil(t, s) + + fc := &internal.FileCache{ + SyncTime: time.Now().Round(0).UTC(), + Profiles: []*agd.Profile{prof}, + Devices: []*agd.Device{dev}, + Version: internal.FileCacheVersion, + } + + err := s.Store(fc) + require.NoError(t, err) + + gotFC, err := s.Load() + require.NoError(t, err) + require.NotNil(t, gotFC) + require.NotEmpty(t, *gotFC) + + assert.Equal(t, fc, gotFC) +} + +func TestStorage_Load_noFile(t *testing.T) { + cachePath := filepath.Join(t.TempDir(), "profiles.pb") + s := filecachepb.New(cachePath) + require.NotNil(t, s) + + fc, err := s.Load() + assert.NoError(t, err) + assert.Nil(t, fc) +} diff --git a/internal/profiledb/internal/filecachepb/unsafe.go b/internal/profiledb/internal/filecachepb/unsafe.go new file mode 100644 index 0000000..7578533 --- /dev/null +++ b/internal/profiledb/internal/filecachepb/unsafe.go @@ -0,0 +1,17 @@ +package filecachepb + +import "unsafe" + +// unsafelyConvertStrSlice checks if []T1 can be converted to []T2 at compile +// time and, if so, converts the slice using package unsafe. +// +// Slices resulting from this conversion must not be mutated. +func unsafelyConvertStrSlice[T1, T2 ~string](s []T1) (res []T2) { + if s == nil { + return nil + } + + // #nosec G103 -- Conversion between two slices with the same underlying + // element type is safe. + return *(*[]T2)(unsafe.Pointer(&s)) +} diff --git a/internal/profiledb/internal/internal.go b/internal/profiledb/internal/internal.go new file mode 100644 index 0000000..2d2705a --- /dev/null +++ b/internal/profiledb/internal/internal.go @@ -0,0 +1,48 @@ +// Package internal contains common constants and types that all implementations +// of the default profile-cache use. +package internal + +import ( + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" +) + +// FileCacheVersion is the version of cached data structure. It must be +// manually incremented on every change in [agd.Device], [agd.Profile], and any +// file-cache structures. +const FileCacheVersion = 6 + +// FileCache contains the data that is cached on the filesystem. +type FileCache struct { + SyncTime time.Time + Profiles []*agd.Profile + Devices []*agd.Device + Version int32 +} + +// FileCacheStorage is the interface for all file caches. +type FileCacheStorage interface { + // Load read the data from the cache file. If the file does not exist, Load + // must return a nil *FileCache. Load must return an informative error. + Load() (c *FileCache, err error) + + // Store writes the data to the cache file. c must not be nil. Store must + // return an informative error. + Store(c *FileCache) (err error) +} + +// EmptyFileCacheStorage is the empty file-cache storage that does nothing and +// returns nils. +type EmptyFileCacheStorage struct{} + +// type check +var _ FileCacheStorage = EmptyFileCacheStorage{} + +// Load implements the [FileCacheStorage] interface for EmptyFileCacheStorage. +// It does nothing and returns nils. +func (EmptyFileCacheStorage) Load() (_ *FileCache, _ error) { return nil, nil } + +// Store implements the [FileCacheStorage] interface for EmptyFileCacheStorage. +// It does nothing and returns nil. +func (EmptyFileCacheStorage) Store(_ *FileCache) (_ error) { return nil } diff --git a/internal/profiledb/internal/profiledbtest/profiledbtest.go b/internal/profiledb/internal/profiledbtest/profiledbtest.go new file mode 100644 index 0000000..9d0d28c --- /dev/null +++ b/internal/profiledb/internal/profiledbtest/profiledbtest.go @@ -0,0 +1,79 @@ +// Package profiledbtest contains common helpers for profile-database tests. +package profiledbtest + +import ( + "net/netip" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/agdtime" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" + "github.com/stretchr/testify/require" +) + +// ProfileID is the profile ID for tests. +// +// Keep in sync with internal/profiledb/testdata/profiles.json. +const ProfileID agd.ProfileID = "prof1234" + +// DeviceID is the profile ID for tests. +// +// Keep in sync with internal/profiledb/testdata/profiles.json. +const DeviceID agd.DeviceID = "dev1234" + +// NewProfile returns the common profile and device for tests. +// +// Keep in sync with internal/profiledb/testdata/profiles.json. +func NewProfile(tb testing.TB) (p *agd.Profile, d *agd.Device) { + tb.Helper() + + loc, err := agdtime.LoadLocation("Europe/Brussels") + require.NoError(tb, err) + + dev := &agd.Device{ + ID: DeviceID, + LinkedIP: netip.MustParseAddr("1.2.3.4"), + Name: "dev1", + DedicatedIPs: []netip.Addr{ + netip.MustParseAddr("1.2.4.5"), + }, + FilteringEnabled: true, + } + + return &agd.Profile{ + Parental: &agd.ParentalProtectionSettings{ + Schedule: &agd.ParentalProtectionSchedule{ + Week: &agd.WeeklySchedule{ + {Start: 0, End: 700}, + {Start: 0, End: 700}, + {Start: 0, End: 700}, + {Start: 0, End: 700}, + {Start: 0, End: 700}, + {Start: 0, End: 700}, + {Start: 0, End: 700}, + }, + TimeZone: loc, + }, + Enabled: true, + }, + BlockingMode: dnsmsg.BlockingModeCodec{ + Mode: &dnsmsg.BlockingModeNullIP{}, + }, + ID: ProfileID, + DeviceIDs: []agd.DeviceID{dev.ID}, + RuleListIDs: []agd.FilterListID{ + "adguard_dns_filter", + }, + CustomRules: []agd.FilterRuleText{ + "|blocked-by-custom.example", + }, + FilteredResponseTTL: 10 * time.Second, + FilteringEnabled: true, + SafeBrowsingEnabled: true, + RuleListsEnabled: true, + QueryLogEnabled: true, + BlockPrivateRelay: true, + BlockFirefoxCanary: true, + }, dev +} diff --git a/internal/profiledb/profiledb.go b/internal/profiledb/profiledb.go new file mode 100644 index 0000000..ffa4a95 --- /dev/null +++ b/internal/profiledb/profiledb.go @@ -0,0 +1,506 @@ +// Package profiledb defines interfaces for databases of user profiles. +package profiledb + +import ( + "context" + "fmt" + "net/netip" + "path/filepath" + "sync" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/metrics" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal/filecachejson" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal/filecachepb" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" +) + +// Interface is the local database of user profiles and devices. +type Interface interface { + // ProfileByDeviceID returns the profile and the device identified by id. + ProfileByDeviceID( + ctx context.Context, + id agd.DeviceID, + ) (p *agd.Profile, d *agd.Device, err error) + + // ProfileByDedicatedIP returns the profile and the device identified by its + // dedicated DNS server IP address. + ProfileByDedicatedIP( + ctx context.Context, + ip netip.Addr, + ) (p *agd.Profile, d *agd.Device, err error) + + // ProfileByLinkedIP returns the profile and the device identified by its + // linked IP address. + ProfileByLinkedIP(ctx context.Context, ip netip.Addr) (p *agd.Profile, d *agd.Device, err error) +} + +// Default is the default in-memory implementation of the [Interface] interface +// that can refresh itself from the provided storage. +type Default struct { + // mapsMu protects the profiles, devices, deviceIDToProfileID, + // linkedIPToDeviceID, and dedicatedIPToDeviceID maps. + mapsMu *sync.RWMutex + + // refreshMu protects syncTime and lastFullSync. These are only used within + // Refresh, so this is also basically a refresh serializer. + refreshMu *sync.Mutex + + // cache is the filesystem-cache storage used by this profile database. + cache internal.FileCacheStorage + + // storage returns the data for this profile DB. + storage Storage + + // profiles maps profile IDs to profile records. + profiles map[agd.ProfileID]*agd.Profile + + // devices maps device IDs to device records. + devices map[agd.DeviceID]*agd.Device + + // deviceIDToProfileID maps device IDs to the ID of their profile. + deviceIDToProfileID map[agd.DeviceID]agd.ProfileID + + // linkedIPToDeviceID maps linked IP addresses to the IDs of their devices. + linkedIPToDeviceID map[netip.Addr]agd.DeviceID + + // dedicatedIPToDeviceID maps dedicated IP addresses to the IDs of their + // devices. + dedicatedIPToDeviceID map[netip.Addr]agd.DeviceID + + // syncTime is the time of the last synchronization point. It is received + // from the storage during a refresh and is then used in consecutive + // requests to the storage, unless it's a full synchronization. + syncTime time.Time + + // lastFullSync is the time of the last full synchronization. + lastFullSync time.Time + + // fullSyncIvl is the interval between two full synchronizations with the + // storage. + fullSyncIvl time.Duration +} + +// New returns a new default in-memory profile database with a filesystem cache. +// The initial refresh is performed immediately with the constant timeout of 1 +// minute, beyond which an empty profiledb is returned. If cacheFilePath is the +// string "none", filesystem cache is disabled. db is never nil. +func New( + s Storage, + fullSyncIvl time.Duration, + cacheFilePath string, +) (db *Default, err error) { + var cacheStorage internal.FileCacheStorage + if cacheFilePath == "none" { + cacheStorage = internal.EmptyFileCacheStorage{} + } else { + switch ext := filepath.Ext(cacheFilePath); ext { + case ".json": + cacheStorage = filecachejson.New(cacheFilePath) + case ".pb": + cacheStorage = filecachepb.New(cacheFilePath) + default: + return nil, fmt.Errorf("file %q is neither json nor protobuf", cacheFilePath) + } + } + + db = &Default{ + mapsMu: &sync.RWMutex{}, + refreshMu: &sync.Mutex{}, + cache: cacheStorage, + storage: s, + syncTime: time.Time{}, + lastFullSync: time.Time{}, + profiles: make(map[agd.ProfileID]*agd.Profile), + devices: make(map[agd.DeviceID]*agd.Device), + deviceIDToProfileID: make(map[agd.DeviceID]agd.ProfileID), + linkedIPToDeviceID: make(map[netip.Addr]agd.DeviceID), + dedicatedIPToDeviceID: make(map[netip.Addr]agd.DeviceID), + fullSyncIvl: fullSyncIvl, + } + + err = db.loadFileCache() + if err != nil { + log.Error("profiledb: fs cache: loading: %s", err) + } + + // initialTimeout defines the maximum duration of the first attempt to load + // the profiledb. + const initialTimeout = 1 * time.Minute + + ctx, cancel := context.WithTimeout(context.Background(), initialTimeout) + defer cancel() + + log.Info("profiledb: initial refresh") + + err = db.Refresh(ctx) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + log.Info("profiledb: warning: initial refresh timeout: %s", err) + + return db, nil + } + + return nil, fmt.Errorf("initial refresh: %w", err) + } + + log.Info("profiledb: initial refresh succeeded") + + return db, nil +} + +// type check +var _ agd.Refresher = (*Default)(nil) + +// Refresh implements the [Refresher] interface for *Default. It updates the +// internal maps and the synchronization time using the data it receives from +// the storage. +func (db *Default) Refresh(ctx context.Context) (err error) { + var totalProfiles, totalDevices int + startTime := time.Now() + defer func() { + metrics.ProfilesSyncTime.SetToCurrentTime() + metrics.ProfilesSyncDuration.Observe(time.Since(startTime).Seconds()) + metrics.ProfilesCountGauge.Set(float64(totalProfiles)) + metrics.DevicesCountGauge.Set(float64(totalDevices)) + metrics.SetStatusGauge(metrics.ProfilesSyncStatus, err) + }() + + reqID := agd.NewRequestID() + ctx = agd.WithRequestID(ctx, reqID) + + defer func() { err = errors.Annotate(err, "req %s: %w", reqID) }() + + db.refreshMu.Lock() + defer db.refreshMu.Unlock() + + syncTime := db.syncTime + + sinceLastFullSync := time.Since(db.lastFullSync) + isFullSync := sinceLastFullSync >= db.fullSyncIvl + if isFullSync { + log.Info("profiledb: full sync, %s since %s", sinceLastFullSync, db.lastFullSync) + + syncTime = time.Time{} + } + + resp, err := db.storage.Profiles(ctx, &StorageRequest{ + SyncTime: syncTime, + }) + if err != nil { + return fmt.Errorf("updating profiles: %w", err) + } + + profiles := resp.Profiles + devices := resp.Devices + db.setProfiles(profiles, devices, isFullSync) + + profNum := len(profiles) + devNum := len(devices) + log.Debug("profiledb: req %s: got %d profiles with %d devices", reqID, profNum, devNum) + + metrics.ProfilesNewCountGauge.Set(float64(profNum)) + metrics.DevicesNewCountGauge.Set(float64(devNum)) + + db.syncTime = resp.SyncTime + if isFullSync { + db.lastFullSync = time.Now() + + err = db.cache.Store(&internal.FileCache{ + SyncTime: resp.SyncTime, + Profiles: resp.Profiles, + Devices: resp.Devices, + Version: internal.FileCacheVersion, + }) + if err != nil { + return fmt.Errorf("saving cache: %w", err) + } + } + + totalProfiles = len(db.profiles) + totalDevices = len(db.devices) + + return nil +} + +// loadFileCache loads the profiles data from the filesystem cache. +func (db *Default) loadFileCache() (err error) { + const logPrefix = "profiledb: cache" + + start := time.Now() + log.Info("%s: initial loading", logPrefix) + + c, err := db.cache.Load() + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } else if c == nil { + log.Info("%s: no cache", logPrefix) + + return nil + } + + profNum, devNum := len(c.Profiles), len(c.Devices) + log.Info( + "%s: got version %d, %d profiles, %d devices in %s", + logPrefix, + c.Version, + profNum, + devNum, + time.Since(start), + ) + + if c.Version != internal.FileCacheVersion { + log.Info( + "%s: version %d is different from %d", + logPrefix, + c.Version, + internal.FileCacheVersion, + ) + + return nil + } else if profNum == 0 || devNum == 0 { + log.Info("%s: empty", logPrefix) + + return nil + } + + db.setProfiles(c.Profiles, c.Devices, true) + db.syncTime, db.lastFullSync = c.SyncTime, c.SyncTime + + return nil +} + +// setProfiles adds or updates the data for all profiles and devices. +func (db *Default) setProfiles(profiles []*agd.Profile, devices []*agd.Device, isFullSync bool) { + db.mapsMu.Lock() + defer db.mapsMu.Unlock() + + if isFullSync { + maps.Clear(db.profiles) + maps.Clear(db.devices) + maps.Clear(db.deviceIDToProfileID) + maps.Clear(db.linkedIPToDeviceID) + maps.Clear(db.dedicatedIPToDeviceID) + } + + for _, p := range profiles { + db.profiles[p.ID] = p + + for _, devID := range p.DeviceIDs { + db.deviceIDToProfileID[devID] = p.ID + } + } + + for _, d := range devices { + devID := d.ID + db.devices[devID] = d + + if d.LinkedIP != (netip.Addr{}) { + db.linkedIPToDeviceID[d.LinkedIP] = devID + } + + for _, dedIP := range d.DedicatedIPs { + db.dedicatedIPToDeviceID[dedIP] = devID + } + } +} + +// type check +var _ Interface = (*Default)(nil) + +// ProfileByDeviceID implements the [Interface] interface for *Default. +func (db *Default) ProfileByDeviceID( + ctx context.Context, + id agd.DeviceID, +) (p *agd.Profile, d *agd.Device, err error) { + db.mapsMu.RLock() + defer db.mapsMu.RUnlock() + + return db.profileByDeviceID(ctx, id) +} + +// profileByDeviceID returns the profile and the device by the ID of the device, +// if found. It assumes that db.mapsMu is locked for reading. +func (db *Default) profileByDeviceID( + _ context.Context, + id agd.DeviceID, +) (p *agd.Profile, d *agd.Device, err error) { + // Do not use [errors.Annotate] here, because it allocates even when the + // error is nil. Also do not use fmt.Errorf in a defer, because it + // allocates when a device is not found, which is the most common case. + + profID, ok := db.deviceIDToProfileID[id] + if !ok { + return nil, nil, ErrDeviceNotFound + } + + p, ok = db.profiles[profID] + if !ok { + // We have an older device record with a deleted profile. Remove it + // from our profile DB in a goroutine, since that requires a write lock. + go db.removeDevice(id) + + return nil, nil, fmt.Errorf("empty profile: %w", ErrDeviceNotFound) + } + + // Reinspect the devices in the profile record to make sure that the device + // is still attached to this profile. + for _, profDevID := range p.DeviceIDs { + if profDevID == id { + d = db.devices[id] + + break + } + } + + if d == nil { + // Perhaps, the device has been deleted from this profile. May happen + // when the device was found by a linked IP. Remove it from our profile + // DB in a goroutine, since that requires a write lock. + go db.removeDevice(id) + + return nil, nil, fmt.Errorf("rechecking devices: %w", ErrDeviceNotFound) + } + + return p, d, nil +} + +// removeDevice removes the device with the given ID from the database. It is +// intended to be used as a goroutine. +func (db *Default) removeDevice(id agd.DeviceID) { + defer log.OnPanicAndExit("removeDevice", 1) + + db.mapsMu.Lock() + defer db.mapsMu.Unlock() + + delete(db.deviceIDToProfileID, id) +} + +// ProfileByLinkedIP implements the [Interface] interface for *Default. ip must +// be valid. +func (db *Default) ProfileByLinkedIP( + ctx context.Context, + ip netip.Addr, +) (p *agd.Profile, d *agd.Device, err error) { + // Do not use errors.Annotate here, because it allocates even when the error + // is nil. Also do not use fmt.Errorf in a defer, because it allocates when + // a device is not found, which is the most common case. + + db.mapsMu.RLock() + defer db.mapsMu.RUnlock() + + id, ok := db.linkedIPToDeviceID[ip] + if !ok { + return nil, nil, ErrDeviceNotFound + } + + const errPrefix = "profile by device linked ip" + p, d, err = db.profileByDeviceID(ctx, id) + if err != nil { + if errors.Is(err, ErrDeviceNotFound) { + // Probably, the device has been deleted. Remove it from our + // profile DB in a goroutine, since that requires a write lock. + go db.removeLinkedIP(ip) + } + + // Don't add the device ID to the error here, since it is already added + // by profileByDeviceID. + return nil, nil, fmt.Errorf("%s: %w", errPrefix, err) + } + + if d.LinkedIP == (netip.Addr{}) { + return nil, nil, fmt.Errorf( + "%s: device does not have linked ip: %w", + errPrefix, + ErrDeviceNotFound, + ) + } else if d.LinkedIP != ip { + // The linked IP has changed. Remove it from our profile DB in a + // goroutine, since that requires a write lock. + go db.removeLinkedIP(ip) + + return nil, nil, fmt.Errorf( + "%s: %q doesn't match: %w", + errPrefix, + d.LinkedIP, + ErrDeviceNotFound, + ) + } + + return p, d, nil +} + +// removeLinkedIP removes the device link for the given linked IP address from +// the profile database. It is intended to be used as a goroutine. +func (db *Default) removeLinkedIP(ip netip.Addr) { + defer log.OnPanicAndExit("removeLinkedIP", 1) + + db.mapsMu.Lock() + defer db.mapsMu.Unlock() + + delete(db.linkedIPToDeviceID, ip) +} + +// ProfileByDedicatedIP implements the [Interface] interface for *Default. ip +// must be valid. +func (db *Default) ProfileByDedicatedIP( + ctx context.Context, + ip netip.Addr, +) (p *agd.Profile, d *agd.Device, err error) { + // Do not use errors.Annotate here, because it allocates even when the error + // is nil. Also do not use fmt.Errorf in a defer, because it allocates when + // a device is not found, which is the most common case. + + db.mapsMu.RLock() + defer db.mapsMu.RUnlock() + + id, ok := db.dedicatedIPToDeviceID[ip] + if !ok { + return nil, nil, ErrDeviceNotFound + } + + const errPrefix = "profile by device dedicated ip" + p, d, err = db.profileByDeviceID(ctx, id) + if err != nil { + if errors.Is(err, ErrDeviceNotFound) { + // Probably, the device has been deleted. Remove it from our + // profile DB in a goroutine, since that requires a write lock. + go db.removeDedicatedIP(ip) + } + + // Don't add the device ID to the error here, since it is already added + // by profileByDeviceID. + return nil, nil, fmt.Errorf("%s: %w", errPrefix, err) + } + + if ipIdx := slices.Index(d.DedicatedIPs, ip); ipIdx < 0 { + // Perhaps, the device has changed its dedicated IPs. Remove it from + // our profile DB in a goroutine, since that requires a write lock. + go db.removeDedicatedIP(ip) + + return nil, nil, fmt.Errorf( + "%s: rechecking dedicated ips: %w", + errPrefix, + ErrDeviceNotFound, + ) + } + + return p, d, nil +} + +// removeDedicatedIP removes the device link for the given dedicated IP address +// from the profile database. It is intended to be used as a goroutine. +func (db *Default) removeDedicatedIP(ip netip.Addr) { + defer log.OnPanicAndExit("removeDedicatedIP", 1) + + db.mapsMu.Lock() + defer db.mapsMu.Unlock() + + delete(db.dedicatedIPToDeviceID, ip) +} diff --git a/internal/profiledb/profiledb_test.go b/internal/profiledb/profiledb_test.go new file mode 100644 index 0000000..a9345b2 --- /dev/null +++ b/internal/profiledb/profiledb_test.go @@ -0,0 +1,495 @@ +package profiledb_test + +import ( + "bytes" + "context" + "net/netip" + "os" + "path/filepath" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" + "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" + "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb" + "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal/profiledbtest" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMain(m *testing.M) { + testutil.DiscardLogOutput(m) +} + +// Common IPs for tests +// +// Keep in sync with testdata/profiles.json. +var ( + testClientIPv4 = netip.MustParseAddr("1.2.3.4") + testOtherClientIPv4 = netip.MustParseAddr("1.2.3.5") + + testDedicatedIPv4 = netip.MustParseAddr("1.2.4.5") + testOtherDedicatedIPv4 = netip.MustParseAddr("1.2.4.6") +) + +// testTimeout is the common timeout for tests. +const testTimeout = 1 * time.Second + +// newDefaultProfileDB returns a new default profile database for tests. +// devicesCh receives the devices that the storage should return in its +// response. +func newDefaultProfileDB(tb testing.TB, devices <-chan []*agd.Device) (db *profiledb.Default) { + tb.Helper() + + onProfiles := func( + _ context.Context, + _ *profiledb.StorageRequest, + ) (resp *profiledb.StorageResponse, err error) { + devices, _ := testutil.RequireReceive(tb, devices, testTimeout) + devIDs := make([]agd.DeviceID, 0, len(devices)) + for _, d := range devices { + devIDs = append(devIDs, d.ID) + } + + return &profiledb.StorageResponse{ + Profiles: []*agd.Profile{{ + BlockingMode: dnsmsg.BlockingModeCodec{ + Mode: &dnsmsg.BlockingModeNullIP{}, + }, + ID: profiledbtest.ProfileID, + DeviceIDs: devIDs, + }}, + Devices: devices, + }, nil + } + + ps := &agdtest.ProfileStorage{ + OnProfiles: onProfiles, + } + + db, err := profiledb.New(ps, 1*time.Minute, "none") + require.NoError(tb, err) + + return db +} + +func TestDefaultProfileDB(t *testing.T) { + dev := &agd.Device{ + ID: profiledbtest.DeviceID, + LinkedIP: testClientIPv4, + DedicatedIPs: []netip.Addr{ + testDedicatedIPv4, + }, + } + + devicesCh := make(chan []*agd.Device, 1) + devicesCh <- []*agd.Device{dev} + db := newDefaultProfileDB(t, devicesCh) + + t.Run("by_device_id", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + p, d, err := db.ProfileByDeviceID(ctx, profiledbtest.DeviceID) + require.NoError(t, err) + + assert.Equal(t, profiledbtest.ProfileID, p.ID) + assert.Equal(t, d, dev) + }) + + t.Run("by_dedicated_ip", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + p, d, err := db.ProfileByDedicatedIP(ctx, testDedicatedIPv4) + require.NoError(t, err) + + assert.Equal(t, profiledbtest.ProfileID, p.ID) + assert.Equal(t, d, dev) + }) + + t.Run("by_linked_ip", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + p, d, err := db.ProfileByLinkedIP(ctx, testClientIPv4) + require.NoError(t, err) + + assert.Equal(t, profiledbtest.ProfileID, p.ID) + assert.Equal(t, d, dev) + }) +} + +func TestDefaultProfileDB_ProfileByDedicatedIP_removedDevice(t *testing.T) { + dev := &agd.Device{ + ID: profiledbtest.DeviceID, + DedicatedIPs: []netip.Addr{ + testDedicatedIPv4, + }, + } + + devicesCh := make(chan []*agd.Device, 2) + + // The first response, the device is still there. + devicesCh <- []*agd.Device{dev} + + db := newDefaultProfileDB(t, devicesCh) + + ctx := context.Background() + _, d, err := db.ProfileByDedicatedIP(ctx, testDedicatedIPv4) + require.NoError(t, err) + + assert.Equal(t, d, dev) + + // The second response, the device is removed. + devicesCh <- nil + + err = db.Refresh(ctx) + require.NoError(t, err) + + assert.Eventually(t, func() (ok bool) { + _, d, err = db.ProfileByDedicatedIP(ctx, testDedicatedIPv4) + + return errors.Is(err, profiledb.ErrDeviceNotFound) + }, testTimeout, testTimeout/10) +} + +func TestDefaultProfileDB_ProfileByDedicatedIP_deviceNewIP(t *testing.T) { + dev := &agd.Device{ + ID: profiledbtest.DeviceID, + DedicatedIPs: []netip.Addr{ + testDedicatedIPv4, + }, + } + + devicesCh := make(chan []*agd.Device, 2) + + // The first response, the device is still there. + devicesCh <- []*agd.Device{dev} + + db := newDefaultProfileDB(t, devicesCh) + + ctx := context.Background() + _, d, err := db.ProfileByDedicatedIP(ctx, testDedicatedIPv4) + require.NoError(t, err) + + assert.Equal(t, d, dev) + + // The second response, the device has a new IP. + dev.DedicatedIPs[0] = testOtherDedicatedIPv4 + devicesCh <- []*agd.Device{dev} + + err = db.Refresh(ctx) + require.NoError(t, err) + + assert.Eventually(t, func() (ok bool) { + _, _, err = db.ProfileByDedicatedIP(ctx, testDedicatedIPv4) + + if !errors.Is(err, profiledb.ErrDeviceNotFound) { + return false + } + + _, d, err = db.ProfileByDedicatedIP(ctx, testOtherDedicatedIPv4) + if err != nil { + return false + } + + return d != nil && d.ID == dev.ID + }, testTimeout, testTimeout/10) +} + +func TestDefaultProfileDB_ProfileByLinkedIP_removedDevice(t *testing.T) { + dev := &agd.Device{ + ID: profiledbtest.DeviceID, + LinkedIP: testClientIPv4, + } + + devicesCh := make(chan []*agd.Device, 2) + + // The first response, the device is still there. + devicesCh <- []*agd.Device{dev} + + db := newDefaultProfileDB(t, devicesCh) + + ctx := context.Background() + _, d, err := db.ProfileByLinkedIP(ctx, testClientIPv4) + require.NoError(t, err) + + assert.Equal(t, d, dev) + + // The second response, the device is removed. + devicesCh <- nil + + err = db.Refresh(ctx) + require.NoError(t, err) + + assert.Eventually(t, func() (ok bool) { + _, d, err = db.ProfileByLinkedIP(ctx, testClientIPv4) + + return errors.Is(err, profiledb.ErrDeviceNotFound) + }, testTimeout, testTimeout/10) +} + +func TestDefaultProfileDB_ProfileByLinkedIP_deviceNewIP(t *testing.T) { + dev := &agd.Device{ + ID: profiledbtest.DeviceID, + LinkedIP: testClientIPv4, + } + + devicesCh := make(chan []*agd.Device, 2) + + // The first response, the device is still there. + devicesCh <- []*agd.Device{dev} + + db := newDefaultProfileDB(t, devicesCh) + + ctx := context.Background() + _, d, err := db.ProfileByLinkedIP(ctx, testClientIPv4) + require.NoError(t, err) + + assert.Equal(t, d, dev) + + // The second response, the device has a new IP. + dev.LinkedIP = testOtherClientIPv4 + devicesCh <- []*agd.Device{dev} + + err = db.Refresh(ctx) + require.NoError(t, err) + + assert.Eventually(t, func() (ok bool) { + _, _, err = db.ProfileByLinkedIP(ctx, testClientIPv4) + + if !errors.Is(err, profiledb.ErrDeviceNotFound) { + return false + } + + _, d, err = db.ProfileByLinkedIP(ctx, testOtherClientIPv4) + if err != nil { + return false + } + + return d != nil && d.ID == dev.ID + }, testTimeout, testTimeout/10) +} + +func TestDefaultProfileDB_fileCache_success(t *testing.T) { + var gotSyncTime time.Time + onProfiles := func( + _ context.Context, + req *profiledb.StorageRequest, + ) (resp *profiledb.StorageResponse, err error) { + gotSyncTime = req.SyncTime + + return &profiledb.StorageResponse{}, nil + } + + ps := &agdtest.ProfileStorage{ + OnProfiles: onProfiles, + } + + cacheFileTmplPath := filepath.Join("testdata", "profiles.json") + data, err := os.ReadFile(cacheFileTmplPath) + require.NoError(t, err) + + // Use the time with monotonic clocks stripped down. + wantSyncTime := time.Now().Round(0).UTC() + data = bytes.ReplaceAll( + data, + []byte("SYNC_TIME"), + []byte(wantSyncTime.Format(time.RFC3339Nano)), + ) + + cacheFilePath := filepath.Join(t.TempDir(), "profiles.json") + err = os.WriteFile(cacheFilePath, data, 0o600) + require.NoError(t, err) + + db, err := profiledb.New(ps, 1*time.Minute, cacheFilePath) + require.NoError(t, err) + require.NotNil(t, db) + + assert.Equal(t, wantSyncTime, gotSyncTime) + + prof, dev := profiledbtest.NewProfile(t) + p, d, err := db.ProfileByDeviceID(context.Background(), dev.ID) + require.NoError(t, err) + + assert.Equal(t, dev, d) + assert.Equal(t, prof, p) +} + +func TestDefaultProfileDB_fileCache_badVersion(t *testing.T) { + storageCalled := false + ps := &agdtest.ProfileStorage{ + OnProfiles: func( + _ context.Context, + _ *profiledb.StorageRequest, + ) (resp *profiledb.StorageResponse, err error) { + storageCalled = true + + return &profiledb.StorageResponse{}, nil + }, + } + + cacheFilePath := filepath.Join(t.TempDir(), "profiles.json") + err := os.WriteFile(cacheFilePath, []byte(`{"version":1000}`), 0o600) + require.NoError(t, err) + + db, err := profiledb.New(ps, 1*time.Minute, cacheFilePath) + assert.NoError(t, err) + assert.NotNil(t, db) + assert.True(t, storageCalled) +} + +// Sinks for benchmarks. +var ( + profSink *agd.Profile + devSink *agd.Device + errSink error +) + +func BenchmarkDefaultProfileDB_ProfileByDeviceID(b *testing.B) { + dev := &agd.Device{ + ID: profiledbtest.DeviceID, + } + + devicesCh := make(chan []*agd.Device, 1) + devicesCh <- []*agd.Device{dev} + db := newDefaultProfileDB(b, devicesCh) + + ctx := context.Background() + + b.Run("success", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + profSink, devSink, errSink = db.ProfileByDeviceID(ctx, profiledbtest.DeviceID) + } + + assert.NotNil(b, profSink) + assert.NotNil(b, devSink) + assert.NoError(b, errSink) + }) + + const wrongDevID = profiledbtest.DeviceID + "_bad" + + b.Run("not_found", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + profSink, devSink, errSink = db.ProfileByDeviceID(ctx, wrongDevID) + } + + assert.Nil(b, profSink) + assert.Nil(b, devSink) + assert.ErrorIs(b, errSink, profiledb.ErrDeviceNotFound) + }) + + // Most recent results, as of 2023-04-10, on a ThinkPad X13 with a Ryzen Pro + // 7 CPU: + // + // goos: linux + // goarch: amd64 + // pkg: github.com/AdguardTeam/AdGuardDNS/internal/profiledb + // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics + // BenchmarkDefaultProfileDB_ProfileByDeviceID/success-16 59396382 21.36 ns/op 0 B/op 0 allocs/op + // BenchmarkDefaultProfileDB_ProfileByDeviceID/not_found-16 74497800 16.45 ns/op 0 B/op 0 allocs/op +} + +func BenchmarkDefaultProfileDB_ProfileByLinkedIP(b *testing.B) { + dev := &agd.Device{ + ID: profiledbtest.DeviceID, + LinkedIP: testClientIPv4, + } + + devicesCh := make(chan []*agd.Device, 1) + devicesCh <- []*agd.Device{dev} + db := newDefaultProfileDB(b, devicesCh) + + ctx := context.Background() + + b.Run("success", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + profSink, devSink, errSink = db.ProfileByLinkedIP(ctx, testClientIPv4) + } + + assert.NotNil(b, profSink) + assert.NotNil(b, devSink) + assert.NoError(b, errSink) + }) + + b.Run("not_found", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + profSink, devSink, errSink = db.ProfileByLinkedIP(ctx, testOtherClientIPv4) + } + + assert.Nil(b, profSink) + assert.Nil(b, devSink) + assert.ErrorIs(b, errSink, profiledb.ErrDeviceNotFound) + }) + + // Most recent results, as of 2023-04-10, on a ThinkPad X13 with a Ryzen Pro + // 7 CPU: + // + // goos: linux + // goarch: amd64 + // pkg: github.com/AdguardTeam/AdGuardDNS/internal/profiledb + // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics + // BenchmarkDefaultProfileDB_ProfileByLinkedIP/success-16 24822542 44.11 ns/op 0 B/op 0 allocs/op + // BenchmarkDefaultProfileDB_ProfileByLinkedIP/not_found-16 63539154 20.04 ns/op 0 B/op 0 allocs/op +} + +func BenchmarkDefaultProfileDB_ProfileByDedicatedIP(b *testing.B) { + dev := &agd.Device{ + ID: profiledbtest.DeviceID, + DedicatedIPs: []netip.Addr{ + testClientIPv4, + }, + } + + devicesCh := make(chan []*agd.Device, 1) + devicesCh <- []*agd.Device{dev} + db := newDefaultProfileDB(b, devicesCh) + + ctx := context.Background() + + b.Run("success", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + profSink, devSink, errSink = db.ProfileByDedicatedIP(ctx, testClientIPv4) + } + + assert.NotNil(b, profSink) + assert.NotNil(b, devSink) + assert.NoError(b, errSink) + }) + + b.Run("not_found", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + profSink, devSink, errSink = db.ProfileByDedicatedIP(ctx, testOtherClientIPv4) + } + + assert.Nil(b, profSink) + assert.Nil(b, devSink) + assert.ErrorIs(b, errSink, profiledb.ErrDeviceNotFound) + }) + + // Most recent results, as of 2023-04-10, on a ThinkPad X13 with a Ryzen Pro + // 7 CPU: + // + // goos: linux + // goarch: amd64 + // pkg: github.com/AdguardTeam/AdGuardDNS/internal/profiledb + // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics + // BenchmarkDefaultProfileDB_ProfileByDedicatedIP/success-16 22697658 48.19 ns/op 0 B/op 0 allocs/op + // BenchmarkDefaultProfileDB_ProfileByDedicatedIP/not_found-16 61062061 19.89 ns/op 0 B/op 0 allocs/op +} diff --git a/internal/profiledb/storage.go b/internal/profiledb/storage.go new file mode 100644 index 0000000..b3b68ed --- /dev/null +++ b/internal/profiledb/storage.go @@ -0,0 +1,35 @@ +package profiledb + +import ( + "context" + "time" + + "github.com/AdguardTeam/AdGuardDNS/internal/agd" +) + +// Storage is a storage from which an [Default] receives data about profiles and +// devices. +type Storage interface { + // Profiles returns profile and device data that has changed since + // req.SyncTime. req must not be nil. + Profiles(ctx context.Context, req *StorageRequest) (resp *StorageResponse, err error) +} + +// StorageRequest is the request to [Storage] for profiles and devices. +type StorageRequest struct { + // SyncTime is the last time profiles were synced. + SyncTime time.Time +} + +// StorageResponse is the ProfileStorage.Profiles response. +type StorageResponse struct { + // SyncTime is the time that should be saved and used as the next + // [ProfilesRequest.SyncTime]. + SyncTime time.Time + + // Profiles are the profiles data from the [Storage]. + Profiles []*agd.Profile + + // Devices are the device data from the [Storage]. + Devices []*agd.Device +} diff --git a/internal/profiledb/testdata/profiles.json b/internal/profiledb/testdata/profiles.json new file mode 100644 index 0000000..d6b7350 --- /dev/null +++ b/internal/profiledb/testdata/profiles.json @@ -0,0 +1,77 @@ +{ + "sync_time": "SYNC_TIME", + "profiles": [ + { + "Parental": { + "Schedule": { + "Week": [ + { + "Start": 0, + "End": 700 + }, + { + "Start": 0, + "End": 700 + }, + { + "Start": 0, + "End": 700 + }, + { + "Start": 0, + "End": 700 + }, + { + "Start": 0, + "End": 700 + }, + { + "Start": 0, + "End": 700 + }, + { + "Start": 0, + "End": 700 + } + ], + "TimeZone": "Europe/Brussels" + }, + "Enabled": true + }, + "BlockingMode": { + "type": "null_ip" + }, + "ID": "prof1234", + "UpdateTime": "0001-01-01T00:00:00.000Z", + "DeviceIDs": [ + "dev1234" + ], + "RuleListIDs": [ + "adguard_dns_filter" + ], + "CustomRules": [ + "|blocked-by-custom.example" + ], + "FilteredResponseTTL": 10000000000, + "FilteringEnabled": true, + "SafeBrowsingEnabled": true, + "RuleListsEnabled": true, + "QueryLogEnabled": true, + "Deleted": false, + "BlockPrivateRelay": true, + "BlockFirefoxCanary": true + } + ], + "devices": [ + { + "ID": "dev1234", + "LinkedIP": "1.2.3.4", + "Name": "dev1", + "DedicatedIPs": [ + "1.2.4.5" + ], + "FilteringEnabled": true + } + ], + "version": 6 +} diff --git a/internal/tools/go.mod b/internal/tools/go.mod index 201dc77..ea99e18 100644 --- a/internal/tools/go.mod +++ b/internal/tools/go.mod @@ -8,26 +8,27 @@ require ( github.com/gordonklaus/ineffassign v0.0.0-20230107090616-13ace0543b28 github.com/kisielk/errcheck v1.6.3 github.com/kyoh86/looppointer v0.2.1 - github.com/securego/gosec/v2 v2.15.0 - golang.org/x/tools v0.7.0 - golang.org/x/vuln v0.0.0-20230308034057-d4ed0a4fab9e - honnef.co/go/tools v0.4.2 - mvdan.cc/gofumpt v0.4.0 - mvdan.cc/unparam v0.0.0-20230125043941-70a0ce6e7b95 + github.com/securego/gosec/v2 v2.16.0 + golang.org/x/tools v0.9.3 + golang.org/x/vuln v0.1.0 + google.golang.org/protobuf v1.30.0 + honnef.co/go/tools v0.4.3 + mvdan.cc/gofumpt v0.5.0 + mvdan.cc/unparam v0.0.0-20230312165513-e84e2d14e3b8 ) require ( - github.com/BurntSushi/toml v1.2.1 // indirect + github.com/BurntSushi/toml v1.3.1 // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/google/uuid v1.3.0 // indirect - github.com/gookit/color v1.5.2 // indirect + github.com/gookit/color v1.5.3 // indirect github.com/kyoh86/nolint v0.0.1 // indirect github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect golang.org/x/exp v0.0.0-20230307190834-24139beb5833 // indirect - golang.org/x/exp/typeparams v0.0.0-20230307190834-24139beb5833 // indirect - golang.org/x/mod v0.9.0 // indirect - golang.org/x/sync v0.1.0 // indirect - golang.org/x/sys v0.6.0 // indirect + golang.org/x/exp/typeparams v0.0.0-20230522175609-2e198f4a06a1 // indirect + golang.org/x/mod v0.10.0 // indirect + golang.org/x/sync v0.2.0 // indirect + golang.org/x/sys v0.8.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/internal/tools/go.sum b/internal/tools/go.sum index 5653710..9c17051 100644 --- a/internal/tools/go.sum +++ b/internal/tools/go.sum @@ -1,28 +1,31 @@ -github.com/BurntSushi/toml v1.2.1 h1:9F2/+DoOYIOksmaJFPw1tGFy1eDnIJXg+UHjuD8lTak= -github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= +github.com/BurntSushi/toml v1.3.1 h1:rHnDkSK+/g6DlREUK73PkmIs60pqrnuduK+JmP++JmU= +github.com/BurntSushi/toml v1.3.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/frankban/quicktest v1.14.3 h1:FJKSZTDHjyhriyC81FLQ0LY93eSai0ZyR/ZIkd3ZUKE= +github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY= github.com/fzipp/gocyclo v0.6.0 h1:lsblElZG7d3ALtGMx9fmxeTKZaLLpU8mET09yN4BBLo= github.com/fzipp/gocyclo v0.6.0/go.mod h1:rXPyn8fnlpa0R2csP/31uerbiVBugk5whMdlyaLkLoA= -github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= +github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golangci/misspell v0.4.0 h1:KtVB/hTK4bbL/S6bs64rYyk8adjmh1BygbBiaAiX+a0= github.com/golangci/misspell v0.4.0/go.mod h1:W6O/bwV6lGDxUCChm2ykw9NQdd5bYd1Xkjo88UcWyJc= github.com/google/go-cmdtest v0.4.1-0.20220921163831-55ab3332a786 h1:rcv+Ippz6RAtvaGgKxc+8FQIpxHgsF+HBzPyYL2cyVU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= github.com/google/renameio v0.1.0 h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gookit/color v1.5.2 h1:uLnfXcaFjlrDnQDT+NCBcfhrXqYTx/rcCa6xn01Y8yI= -github.com/gookit/color v1.5.2/go.mod h1:w8h4bGiHeeBpvQVePTutdbERIUf3oJE5lZ8HM0UgXyg= +github.com/gookit/color v1.5.3 h1:twfIhZs4QLCtimkP7MOxlF3A0U/5cDPseRT9M/+2SCE= +github.com/gookit/color v1.5.3/go.mod h1:NUzwzeehUfl7GIb36pqId+UGmRfQcU/WiiyTTeNjHtE= github.com/gordonklaus/ineffassign v0.0.0-20230107090616-13ace0543b28 h1:9alfqbrhuD+9fLZ4iaAVwhlp5PEhmnBt7yvK2Oy5C1U= github.com/gordonklaus/ineffassign v0.0.0-20230107090616-13ace0543b28/go.mod h1:Qcp2HIAYhR7mNUVSIxZww3Guk4it82ghYcEXIAk+QT0= github.com/kisielk/errcheck v1.6.3 h1:dEKh+GLHcWm2oN34nMvDzn1sqI0i0WxPvrgiJA5JuM8= github.com/kisielk/errcheck v1.6.3/go.mod h1:nXw/i/MfnvRHqXa7XXmQMUB0oNFGuBrNI8d8NLy0LPw= -github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kyoh86/looppointer v0.2.1 h1:Jx9fnkBj/JrIryBLMTYNTj9rvc2SrPS98Dg0w7fxdJg= github.com/kyoh86/looppointer v0.2.1/go.mod h1:q358WcM8cMWU+5vzqukvaZtnJi1kw/MpRHQm3xvTrjw= @@ -30,21 +33,15 @@ github.com/kyoh86/nolint v0.0.1 h1:GjNxDEkVn2wAxKHtP7iNTrRxytRZ1wXxLV5j4XzGfRU= github.com/kyoh86/nolint v0.0.1/go.mod h1:1ZiZZ7qqrZ9dZegU96phwVcdQOMKIqRzFJL3ewq9gtI= github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354 h1:4kuARK6Y6FxaNu/BnU2OAaLF86eTVhP2hjTB6iMvItA= github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354/go.mod h1:KSVJerMDfblTH7p5MZaTt+8zaT2iEk3AkVb9PQdZuE8= -github.com/onsi/ginkgo/v2 v2.8.0 h1:pAM+oBNPrpXRs+E/8spkeGx9QgekbRVyr74EUvRVOUI= -github.com/onsi/gomega v1.26.0 h1:03cDLK28U6hWvCAns6NeydX3zIm4SF3ci69ulidS32Q= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgcC4zIxODThtZNPirFr42+A= +github.com/onsi/ginkgo/v2 v2.9.4 h1:xR7vG4IXt5RWx6FfIjyAtsoMAtnc3C/rFXBBd2AjZwE= +github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= -github.com/securego/gosec/v2 v2.15.0 h1:v4Ym7FF58/jlykYmmhZ7mTm7FQvN/setNm++0fgIAtw= -github.com/securego/gosec/v2 v2.15.0/go.mod h1:VOjTrZOkUtSDt2QLSJmQBMWnvwiQPEjg0l+5juIqGk8= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/securego/gosec/v2 v2.16.0 h1:Pi0JKoasQQ3NnoRao/ww/N/XdynIB9NRYYZT5CyOs5U= +github.com/securego/gosec/v2 v2.16.0/go.mod h1:xvLcVZqUfo4aAQu56TNv7/Ltz6emAOQAEsrZrt7uGlI= github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -56,25 +53,25 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/exp v0.0.0-20230307190834-24139beb5833 h1:SChBja7BCQewoTAU7IgvucQKMIXrEpFxNMs0spT3/5s= golang.org/x/exp v0.0.0-20230307190834-24139beb5833/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= -golang.org/x/exp/typeparams v0.0.0-20230307190834-24139beb5833 h1:jWGQJV4niP+CCmFW9ekjA9Zx8vYORzOUH2/Nl5WPuLQ= -golang.org/x/exp/typeparams v0.0.0-20230307190834-24139beb5833/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= +golang.org/x/exp/typeparams v0.0.0-20230522175609-2e198f4a06a1 h1:pnP8r+W8Fm7XJ8CWtXi4S9oJmPBTrkfYN/dNbaPj6Y4= +golang.org/x/exp/typeparams v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= -golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs= -golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= +golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= +golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -84,34 +81,37 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20201007032633-0806396f153e/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= -golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4= -golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s= -golang.org/x/vuln v0.0.0-20230308034057-d4ed0a4fab9e h1:zLJSre6/6VuA0wMLQe9ShzEncb0k8E/z9wVhmlnRdNs= -golang.org/x/vuln v0.0.0-20230308034057-d4ed0a4fab9e/go.mod h1:ydpjOTRSBwOBFJRP/w5NF2HSPnFg1JxobEZQGOirxgo= +golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM= +golang.org/x/tools v0.9.3/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= +golang.org/x/vuln v0.1.0 h1:9GRdj6wAIkDrsMevuolY+SXERPjQPp2P1ysYA0jpZe0= +golang.org/x/vuln v0.1.0/go.mod h1:/YuzZYjGbwB8y19CisAppfyw3uTZnuCz3r+qgx/QRzU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.4.2 h1:6qXr+R5w+ktL5UkwEbPp+fEvfyoMPche6GkOpGHZcLc= -honnef.co/go/tools v0.4.2/go.mod h1:36ZgoUOrqOk1GxwHhyryEkq8FQWkUO2xGuSMhUCcdvA= -mvdan.cc/gofumpt v0.4.0 h1:JVf4NN1mIpHogBj7ABpgOyZc65/UUOkKQFkoURsz4MM= -mvdan.cc/gofumpt v0.4.0/go.mod h1:PljLOHDeZqgS8opHRKLzp2It2VBuSdteAgqUfzMTxlQ= -mvdan.cc/unparam v0.0.0-20230125043941-70a0ce6e7b95 h1:n/xhncJPSt0YzfOhnyn41XxUdrWQNgmLBG72FE27Fqw= -mvdan.cc/unparam v0.0.0-20230125043941-70a0ce6e7b95/go.mod h1:2vU506e8nGWodqcci641NLi4im2twWSq4Lod756epHQ= +honnef.co/go/tools v0.4.3 h1:o/n5/K5gXqk8Gozvs2cnL0F2S1/g1vcGCAx2vETjITw= +honnef.co/go/tools v0.4.3/go.mod h1:36ZgoUOrqOk1GxwHhyryEkq8FQWkUO2xGuSMhUCcdvA= +mvdan.cc/gofumpt v0.5.0 h1:0EQ+Z56k8tXjj/6TQD25BFNKQXpCvT0rnansIc7Ug5E= +mvdan.cc/gofumpt v0.5.0/go.mod h1:HBeVDtMKRZpXyxFciAirzdKklDlGu8aAy1wEbH5Y9js= +mvdan.cc/unparam v0.0.0-20230312165513-e84e2d14e3b8 h1:VuJo4Mt0EVPychre4fNlDWDuE5AjXtPJpRUWqZDQhaI= +mvdan.cc/unparam v0.0.0-20230312165513-e84e2d14e3b8/go.mod h1:Oh/d7dEtzsNHGOq1Cdv8aMm3KdKhVvPbRQcM8WFpBR8= diff --git a/internal/tools/tools.go b/internal/tools/tools.go index 230d6c4..deaa4f1 100644 --- a/internal/tools/tools.go +++ b/internal/tools/tools.go @@ -13,6 +13,7 @@ import ( _ "golang.org/x/tools/go/analysis/passes/nilness/cmd/nilness" _ "golang.org/x/tools/go/analysis/passes/shadow/cmd/shadow" _ "golang.org/x/vuln/cmd/govulncheck" + _ "google.golang.org/protobuf/cmd/protoc-gen-go" _ "honnef.co/go/tools/cmd/staticcheck" _ "mvdan.cc/gofumpt" _ "mvdan.cc/unparam" diff --git a/internal/websvc/handler.go b/internal/websvc/handler.go index ed09a80..c84273d 100644 --- a/internal/websvc/handler.go +++ b/internal/websvc/handler.go @@ -10,6 +10,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/AdGuardDNS/internal/optlog" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/log" "golang.org/x/sys/unix" ) @@ -22,7 +23,7 @@ var _ http.Handler = (*Service)(nil) // ServeHTTP implements the http.Handler interface for *Service. func (svc *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { respHdr := w.Header() - respHdr.Set(agdhttp.HdrNameServer, agdhttp.UserAgent()) + respHdr.Set(httphdr.Server, agdhttp.UserAgent()) m, p, rAddr := r.Method, r.URL.Path, r.RemoteAddr optlog.Debug3("websvc: starting req %s %s from %s", m, p, rAddr) @@ -57,7 +58,7 @@ func (svc *Service) processRec( action = "writing 404" if len(svc.error404) != 0 { body = svc.error404 - respHdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML) + respHdr.Set(httphdr.ContentType, agdhttp.HdrValTextHTML) } metrics.WebSvcError404RequestsTotal.Inc() @@ -65,7 +66,7 @@ func (svc *Service) processRec( action = "writing 500" if len(svc.error500) != 0 { body = svc.error500 - respHdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML) + respHdr.Set(httphdr.ContentType, agdhttp.HdrValTextHTML) } metrics.WebSvcError500RequestsTotal.Inc() @@ -114,7 +115,7 @@ func (svc *Service) serveHTTP(w http.ResponseWriter, r *http.Request) { func safeBrowsingHandler(name string, blockPage []byte) (h http.Handler) { f := func(w http.ResponseWriter, r *http.Request) { hdr := w.Header() - hdr.Set(agdhttp.HdrNameServer, agdhttp.UserAgent()) + hdr.Set(httphdr.Server, agdhttp.UserAgent()) switch r.URL.Path { case "/favicon.ico": @@ -125,7 +126,7 @@ func safeBrowsingHandler(name string, blockPage []byte) (h http.Handler) { // the predefined response instead. serveRobotsDisallow(hdr, w, name) default: - hdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML) + hdr.Set(httphdr.ContentType, agdhttp.HdrValTextHTML) _, err := w.Write(blockPage) if err != nil { @@ -186,7 +187,7 @@ type StaticFile struct { // serveRobotsDisallow writes predefined disallow-all response. func serveRobotsDisallow(hdr http.Header, w http.ResponseWriter, name string) { - hdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextPlain) + hdr.Set(httphdr.ContentType, agdhttp.HdrValTextPlain) _, err := io.WriteString(w, agdhttp.RobotsDisallowAll) if err != nil { diff --git a/internal/websvc/handler_test.go b/internal/websvc/handler_test.go index 24afecf..35cca21 100644 --- a/internal/websvc/handler_test.go +++ b/internal/websvc/handler_test.go @@ -10,6 +10,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" "github.com/AdguardTeam/AdGuardDNS/internal/websvc" + "github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -33,7 +34,7 @@ func TestService_ServeHTTP(t *testing.T) { "/favicon.ico": { Content: []byte{}, Headers: http.Header{ - agdhttp.HdrNameContentType: []string{"image/x-icon"}, + httphdr.ContentType: []string{"image/x-icon"}, }, }, } @@ -62,8 +63,8 @@ func TestService_ServeHTTP(t *testing.T) { // Static content path with headers. h := http.Header{ - agdhttp.HdrNameContentType: []string{"image/x-icon"}, - agdhttp.HdrNameServer: []string{"AdGuardDNS/"}, + httphdr.ContentType: []string{"image/x-icon"}, + httphdr.Server: []string{"AdGuardDNS/"}, } assertResponseWithHeaders(t, svc, "/favicon.ico", http.StatusOK, h) @@ -96,7 +97,7 @@ func assertResponse( svc.ServeHTTP(rw, r) assert.Equal(t, statusCode, rw.Code) - assert.Equal(t, agdhttp.UserAgent(), rw.Header().Get(agdhttp.HdrNameServer)) + assert.Equal(t, agdhttp.UserAgent(), rw.Header().Get(httphdr.Server)) return rw } diff --git a/internal/websvc/linkip.go b/internal/websvc/linkip.go index 66e394b..2270aa9 100644 --- a/internal/websvc/linkip.go +++ b/internal/websvc/linkip.go @@ -13,6 +13,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" "github.com/AdguardTeam/AdGuardDNS/internal/metrics" "github.com/AdguardTeam/AdGuardDNS/internal/optlog" + "github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" ) @@ -47,10 +48,10 @@ func linkedIPHandler( // Set the X-Forwarded-For header to a nil value to make sure that // the proxy doesn't add it automatically. - hdr["X-Forwarded-For"] = nil + hdr[httphdr.XForwardedFor] = nil // Make sure that all requests are marked with our user agent. - hdr.Set(agdhttp.HdrNameUserAgent, agdhttp.UserAgent()) + hdr.Set(httphdr.UserAgent, agdhttp.UserAgent()) } // Use largely the same transport as http.DefaultTransport, but with a @@ -75,10 +76,10 @@ func linkedIPHandler( // Delete the Server header value from the upstream. modifyResponse := func(r *http.Response) (err error) { - r.Header.Del(agdhttp.HdrNameServer) + r.Header.Del(httphdr.Server) // Make sure that this URL can be used from the web page. - r.Header.Set(agdhttp.HdrNameAccessControlAllowOrigin, agdhttp.HdrValWildcard) + r.Header.Set(httphdr.AccessControlAllowOrigin, agdhttp.HdrValWildcard) return nil } @@ -111,7 +112,7 @@ var _ http.Handler = (*linkedIPProxy)(nil) func (prx *linkedIPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Set the Server header here, so that all 404 and 500 responses carry it. respHdr := w.Header() - respHdr.Set(agdhttp.HdrNameServer, agdhttp.UserAgent()) + respHdr.Set(httphdr.Server, agdhttp.UserAgent()) m, p, rAddr := r.Method, r.URL.Path, r.RemoteAddr optlog.Debug3("websvc: starting req %s %s from %s", m, p, rAddr) @@ -124,9 +125,9 @@ func (prx *linkedIPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Remove all proxy headers before sending the request to proxy. hdr := r.Header - hdr.Del("Forwarded") - hdr.Del("True-Client-IP") - hdr.Del("X-Real-IP") + hdr.Del(httphdr.Forwarded) + hdr.Del(httphdr.TrueClientIP) + hdr.Del(httphdr.XRealIP) // Set the real IP. ip, err := netutil.SplitHost(rAddr) @@ -141,12 +142,12 @@ func (prx *linkedIPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - hdr.Set("CF-Connecting-IP", ip) + hdr.Set(httphdr.CFConnectingIP, ip) // Set the request ID. reqID := agd.NewRequestID() r = r.WithContext(agd.WithRequestID(r.Context(), reqID)) - hdr.Set(agdhttp.HdrNameXRequestID, string(reqID)) + hdr.Set(httphdr.XRequestID, string(reqID)) log.Debug("%s: proxying %s %s: req %s", prx.logPrefix, m, p, reqID) diff --git a/internal/websvc/linkip_internal_test.go b/internal/websvc/linkip_internal_test.go index 745fc42..e325474 100644 --- a/internal/websvc/linkip_internal_test.go +++ b/internal/websvc/linkip_internal_test.go @@ -12,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" + "github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -22,7 +23,7 @@ func TestLinkedIPProxy_ServeHTTP(t *testing.T) { upstream := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { pt := testutil.PanicT{} - rid := r.Header.Get(agdhttp.HdrNameXRequestID) + rid := r.Header.Get(httphdr.XRequestID) require.NotEmpty(pt, rid) numReq.Add(1) @@ -103,7 +104,7 @@ func TestLinkedIPProxy_ServeHTTP(t *testing.T) { assert.Equal(t, prev+tc.diff, numReq.Load(), "req was not expected") assert.Equal(t, tc.wantCode, rw.Code) - assert.Equal(t, expectedUserAgent, rw.Header().Get(agdhttp.HdrNameServer)) + assert.Equal(t, expectedUserAgent, rw.Header().Get(httphdr.Server)) }) } } diff --git a/internal/websvc/websvc_test.go b/internal/websvc/websvc_test.go index 8b663ef..1b8309b 100644 --- a/internal/websvc/websvc_test.go +++ b/internal/websvc/websvc_test.go @@ -11,6 +11,7 @@ import ( "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" "github.com/AdguardTeam/AdGuardDNS/internal/websvc" + "github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -142,7 +143,7 @@ func assertContent(t *testing.T, addr netip.AddrPort, path string, status int, e assert.Equal(t, expected, body) assert.Equal(t, status, resp.StatusCode) - assert.Equal(t, agdhttp.UserAgent(), resp.Header.Get(agdhttp.HdrNameServer)) + assert.Equal(t, agdhttp.UserAgent(), resp.Header.Get(httphdr.Server)) } func startService(t *testing.T, c *websvc.Config) { diff --git a/scripts/make/go-lint.sh b/scripts/make/go-lint.sh index 33baede..9c5b2b4 100644 --- a/scripts/make/go-lint.sh +++ b/scripts/make/go-lint.sh @@ -35,7 +35,7 @@ set -f -u go_version="$( "${GO:-go}" version )" readonly go_version -go_min_version='go1.20.2' +go_min_version='go1.20.5' go_version_msg=" warning: your go version (${go_version}) is different from the recommended minimal one (${go_min_version}). if you have the version installed, please set the GO environment variable. @@ -80,6 +80,13 @@ esac # # * Package golang.org/x/net/context has been moved into stdlib. # +# NOTE: For AdGuard DNS, there are the following exceptions: +# +# * internal/profiledb/internal/filecachepb/filecache.pb.go: a file generated +# by the protobuf compiler. +# +# * internal/profiledb/internal/filecachepb/unsafe.go: a “safe” unsafe helper +# to prevent excessive allocations. blocklist_imports() { git grep\ -e '[[:space:]]"errors"$'\ @@ -91,6 +98,8 @@ blocklist_imports() { -e '[[:space:]]"golang.org/x/net/context"$'\ -n\ -- '*.go'\ + ':!internal/profiledb/internal/filecachepb/filecache.pb.go'\ + ':!internal/profiledb/internal/filecachepb/unsafe.go'\ | sed -e 's/^\([^[:space:]]\+\)\(.*\)$/\1 blocked import:\2/'\ || exit 0 } @@ -158,7 +167,8 @@ run_linter "$GO" vet ./... "$dnssrvmod" run_linter govulncheck ./... "$dnssrvmod" -run_linter gocyclo --over 10 . +# NOTE: For AdGuard DNS, ignore the generated protobuf file. +run_linter gocyclo --ignore '\.pb\.go$' --over 10 . run_linter ineffassign ./... "$dnssrvmod" @@ -174,7 +184,28 @@ run_linter nilness ./... "$dnssrvmod" # Do not use fieldalignment on $dnssrvmod, because ameshkov likes to place # struct fields in an order that he considers more readable. -run_linter fieldalignment ./... +# +# TODO(a.garipov): Remove the loop once golang/go#60509 is fixed. +( + run_linter fieldalignment ./main.go + + set +f + for d in ./internal/*/ ./internal/*/*/ ./internal/*/*/*/ + do + case "$d" + in + (*/testdata/*|\ + ./internal/dnsserver/*|\ + ./internal/profiledb/internal/filecachepb/|\ + ./internal/tools/) + continue + ;; + (*) + run_linter fieldalignment "$d" + ;; + esac + done +) run_linter -e shadow --strict ./... "$dnssrvmod" diff --git a/scripts/make/go-tools.sh b/scripts/make/go-tools.sh index d934458..8ec6135 100644 --- a/scripts/make/go-tools.sh +++ b/scripts/make/go-tools.sh @@ -44,6 +44,7 @@ rm -f\ bin/looppointer\ bin/misspell\ bin/nilness\ + bin/protoc-gen-go\ bin/shadow\ bin/staticcheck\ bin/unparam\ @@ -71,6 +72,7 @@ env\ golang.org/x/tools/go/analysis/passes/nilness/cmd/nilness\ golang.org/x/tools/go/analysis/passes/shadow/cmd/shadow\ golang.org/x/vuln/cmd/govulncheck\ + google.golang.org/protobuf/cmd/protoc-gen-go\ honnef.co/go/tools/cmd/staticcheck\ mvdan.cc/gofumpt\ mvdan.cc/unparam\ diff --git a/scripts/make/txt-lint.sh b/scripts/make/txt-lint.sh index 48b926a..6edc8bf 100644 --- a/scripts/make/txt-lint.sh +++ b/scripts/make/txt-lint.sh @@ -27,6 +27,26 @@ set -f -u # Source the common helpers, including not_found. . ./scripts/make/helper.sh +# trailing_newlines is a simple check that makes sure that all plain-text files +# have a trailing newlines to make sure that all tools work correctly with them. +trailing_newlines() { + nl="$( printf "\n" )" + readonly nl + + # NOTE: Adjust for your project. + git ls-files\ + ':!*.mmdb'\ + | while read -r f + do + if [ "$( tail -c -1 "$f" )" != "$nl" ] + then + printf '%s: must have a trailing newline\n' "$f" + fi + done +} + +run_linter -e trailing_newlines + git ls-files -- '*.md' '*.yaml' '*.yml'\ | xargs misspell --error\ | sed -e 's/^/misspell: /'