diff --git a/.gitignore b/.gitignore
index a52cf5c..57dd421 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,7 +14,7 @@
/github-mirror/
AdGuardDNS
asn.mmdb
-config.yml
+config.yaml
country.mmdb
dnsdb.bolt
querylog.jsonl
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a222fcf..982a33f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,117 @@ The format is **not** based on [Keep a Changelog][kec], since the project
+## AGDNS-1498 / Build 527
+
+ * Object `ratelimit` has a new property, `connection_limit`, which allows
+ setting stream-connection limits. Example configuration:
+
+ ```yaml
+ ratelimit:
+ # …
+ connection_limit:
+ enabled: true
+ stop: 1000
+ resume: 800
+ ```
+
+
+
+## AGDNS-1383 / Build 525
+
+ * The environment variable `PROFILES_CACHE_PATH` is now sensitive to the file
+ extension. Use `.json` for the previous behavior of encoding the cache into
+ a JSON file or `.pb` for encoding it into protobuf. Other extensions are
+ invalid.
+
+
+
+## AGDNS-1381 / Build 518
+
+ * The new object `network` has been added:
+
+ ```yaml
+ network:
+ so_sndbuf: 0
+ so_rcvbuf: 0
+ ```
+
+
+
+## AGDNS-1383 / Build 515
+
+ * The environment variable `PROFILES_CACHE_PATH` now has a new special value,
+ `none`, which disables profile caching entirely. The default value of
+ `./profilecache.json` has not been changed.
+
+
+
+## AGDNS-1479 / Build 513
+
+ * The profile-cache version has been changed to `6`. Versions of the profile
+ cache from `3` to `5` are invalid and should not be reused.
+
+
+
+## AGDNS-1473 / Build 506
+
+ * The profile-cache version has been changed to `5`.
+
+
+
+## AGDNS-1247 / Build 484
+
+ * The new object `interface_listeners` has been added:
+
+ ```yaml
+    interface_listeners:
+      channel_buffer_size: 1000
+      list:
+        eth0_plain_dns:
+          interface: 'eth0'
+          port: 53
+        eth0_plain_dns_secondary:
+          interface: 'eth0'
+          port: 5353
+ ```
+
+ * The objects within the `server_groups.*.servers` array have a new optional
+ property, `bind_interfaces`:
+
+ ```yaml
+ server_groups:
+ -
+ # …
+ servers:
+ - name: 'default_dns'
+ # …
+ bind_interfaces:
+ - id: 'eth0_plain_dns'
+ subnet: '127.0.0.0/8'
+ - id: 'eth0_plain_dns_secondary'
+ subnet: '127.0.0.0/8'
+ ```
+
+ It is mutually exclusive with the current `bind_addresses` field.
+
+
+
+## AGDNS-1406 / Build 480
+
+ * The default behavior of the environment variable `DNSDB_PATH` has been
+ changed. Previously, if the variable was unset then the default value,
+ `./dnsdb.bolt`, was used, but if it was an empty string, DNSDB was disabled.
+   Now both an unset and an empty value disable DNSDB, which is consistent
+   with the documentation.
+
+ This means that DNSDB is disabled by default.
+
+ * The default configuration file path has been changed from `config.yml` to
+   `./config.yaml` for consistency with other services.
+
+
+
## AGDNS-916 / Build 456
* `ratelimit` now defines rate of requests per second for IPv4 and IPv6
@@ -181,7 +292,7 @@ The format is **not** based on [Keep a Changelog][kec], since the project
## AGDNS-842 / Build 372
- * The new environment variable `PROFILES_CACHE_PATH` has been added. Its
+ * The new environment variable `PROFILES_CACHE_PATH` has been added. Its
default value is `./profilecache.json`. Adjust the value, if necessary.
@@ -189,7 +300,7 @@ The format is **not** based on [Keep a Changelog][kec], since the project
## AGDNS-891 / Build 371
* The property `server` of `upstream` object has been changed. Now it
- is a URL optionally starting with `tcp://` or `udp://`, and then an address
+ is a URL optionally starting with `tcp://` or `udp://`, and then an address
in `ip:port` format.
```yaml
diff --git a/HACKING.md b/HACKING.md
index 1ce3bde..4e1b89b 100644
--- a/HACKING.md
+++ b/HACKING.md
@@ -1 +1 @@
-See Adguard Home [`HACKING.md`](https://github.com/AdguardTeam/AdGuardHome/blob/master/HACKING.md).
\ No newline at end of file
+See the [AdGuard Code Guidelines](https://github.com/AdguardTeam/CodeGuidelines/).
diff --git a/Makefile b/Makefile
index 78e925d..736b97d 100644
--- a/Makefile
+++ b/Makefile
@@ -59,6 +59,9 @@ go-gen:
cd ./internal/agd/ && "$(GO.MACRO)" run ./country_generate.go
cd ./internal/geoip/ && "$(GO.MACRO)" run ./asntops_generate.go
+ cd ./internal/profiledb/internal/filecachepb/ &&\
+ protoc --go_opt=paths=source_relative --go_out=. ./filecache.proto
+
go-check: go-tools go-lint go-test
# A quick check to make sure that all operating systems relevant to the
diff --git a/README.md b/README.md
index 78bbcaf..358f379 100644
--- a/README.md
+++ b/README.md
@@ -86,7 +86,7 @@ following features:
you need that.
[rules system]: https://adguard-dns.io/kb/general/dns-filtering-syntax/
-[API]: https://adguard-dns.io/kb/private-dns/api/
+[API]: https://adguard-dns.io/kb/private-dns/api/overview/
diff --git a/config.dist.yml b/config.dist.yaml
similarity index 89%
rename from config.dist.yml
rename to config.dist.yaml
index 02d83fb..10bc23d 100644
--- a/config.dist.yml
+++ b/config.dist.yaml
@@ -43,6 +43,15 @@ ratelimit:
# Time between two updates of allow list.
refresh_interval: 1h
+ # Configuration for the stream connection limiting.
+ connection_limit:
+ enabled: true
+ # The point at which the limiter stops accepting new connections. Once
+ # the number of active connections reaches this limit, new connections
+ # wait for the number to decrease below resume.
+ stop: 1000
+ resume: 800
+
# DNS cache configuration.
cache:
# The type of cache to use. Can be 'simple' (a simple LRU cache) or 'ecs'
@@ -257,6 +266,21 @@ filtering_groups:
block_private_relay: false
block_firefox_canary: true
+# The configuration for the device-listening feature. Works only on Linux with
+# SO_BINDTODEVICE support.
+interface_listeners:
+ # The size of the buffers of the channels used to dispatch TCP connections
+ # and UDP sessions.
+ channel_buffer_size: 1000
+ # List is the mapping of interface-listener IDs to their configuration.
+ list:
+ 'eth0_plain_dns':
+ interface: 'eth0'
+ port: 53
+ 'eth0_plain_dns_secondary':
+ interface: 'eth0'
+ port: 5353
+
# Server groups and servers.
server_groups:
- name: 'adguard_dns_default'
@@ -302,8 +326,13 @@ server_groups:
# See README for the list of protocol values.
protocol: 'dns'
linked_ip_enabled: true
- bind_addresses:
- - '127.0.0.1:53'
+ # Either bind_interfaces or bind_addresses (see below) can be used for
+ # the plain-DNS servers.
+ bind_interfaces:
+ - id: 'eth0_plain_dns'
+ subnet: '127.0.0.0/8'
+ - id: 'eth0_plain_dns_secondary'
+ subnet: '127.0.0.0/8'
- name: 'default_dot'
protocol: 'tls'
linked_ip_enabled: false
@@ -351,3 +380,12 @@ connectivity_check:
# Additional information to be exposed through metrics.
additional_metrics_info:
test_key: 'test_value'
+
+# Network settings.
+network:
+ # Defines the size of socket send buffer in bytes. Default is zero (uses
+ # system settings).
+ so_sndbuf: 0
+ # Defines the size of socket receive buffer in bytes. Default is zero
+ # (uses system settings).
+ so_rcvbuf: 0
diff --git a/doc/configuration.md b/doc/configuration.md
index 49fdc69..50e631c 100644
--- a/doc/configuration.md
+++ b/doc/configuration.md
@@ -6,9 +6,12 @@ configuration file with comments.
## Contents
- * [Recommended values](#recommended)
+ * [Recommended values and notes](#recommended)
* [Result cache sizes](#recommended-result_cache)
+ * [`SO_RCVBUF` and `SO_SNDBUF` on Linux](#recommended-buffers)
+ * [Connection limiter](#recommended-connection_limit)
* [Rate limiting](#ratelimit)
+ * [Stream connection limit](#ratelimit-connection_limit)
* [Cache](#cache)
* [Upstream](#upstream)
* [Healthcheck](#upstream-healthcheck)
@@ -21,11 +24,13 @@ configuration file with comments.
* [Adult-content blocking](#adult_blocking)
* [Filters](#filters)
* [Filtering groups](#filtering_groups)
+ * [Network interface listeners](#interface_listeners)
* [Server groups](#server_groups)
* [TLS](#server_groups-*-tls)
* [DDR](#server_groups-*-ddr)
* [Servers](#server_groups-*-servers-*)
* [Connectivity check](#connectivity-check)
+ * [Network settings](#network)
* [Additional metrics information](#additional_metrics_info)
[dist]: ../config.dist.yml
@@ -34,7 +39,7 @@ configuration file with comments.
-## Recommended values
+## Recommended values and notes
### Result cache sizes
@@ -59,6 +64,55 @@ from answers, you'll need to multiply the value from the statistic by 5 or 6.
+ ### `SO_RCVBUF` and `SO_SNDBUF` on Linux
+
+On Linux, the values for these socket options that come from the configuration
+file (the parameters [`network.so_rcvbuf`](#network-so_rcvbuf) and
+[`network.so_sndbuf`](#network-so_sndbuf)) are doubled, and the maximum and
+minimum values are controlled by the values in `/proc/`. See `man 7 socket`:
+
+ > `SO_RCVBUF`
+ >
+ > \[…\] The kernel doubles this value (to allow space for bookkeeping
+ > overhead) when it is set using setsockopt(2), and this doubled value is
+ > returned by getsockopt(2). The default value is set by the
+ > `/proc/sys/net/core/rmem_default` file, and the maximum allowed value is set
+ > by the `/proc/sys/net/core/rmem_max` file. The minimum (doubled) value for
+ > this option is `256`.
+ >
+ > \[…\]
+ >
+ > `SO_SNDBUF`
+ >
+ > \[…\] The default value is set by the `/proc/sys/net/core/wmem_default`
+ > file, and the maximum allowed value is set by the
+ > `/proc/sys/net/core/wmem_max` file. The minimum (doubled) value for this
+ > option is `2048`.
+
+
+
+ ### Connection limiter
+
+There are currently the following recommendations for the parameters
+[`ratelimit.connection_limit.stop`](#ratelimit-connection_limit-stop) and
+[`ratelimit.connection_limit.resume`](#ratelimit-connection_limit-resume):
+
+ * `stop` should be about 25 % above the current maximum daily number of used
+ TCP sockets. That is, if the instance currently has a maximum of 100 000
+ TCP sockets in use every day, `stop` should be set to about `125000`.
+
+ * `resume` should be about 20 % above the current maximum daily number of used
+ TCP sockets. That is, if the instance currently has a maximum of 100 000
+ TCP sockets in use every day, `resume` should be set to about `120000`.
+
+**NOTE:** The number of active stream connections includes sockets that are
+in the process of accepting new connections but have not yet accepted one. That
+means that `resume` should be greater than the number of bound addresses.
+
+These recommendations are to be revised based on the metrics.
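+
+For example, a minimal configuration sketch for the 100 000-socket scenario
+above; the numbers are illustrative, not defaults:
+
+```yaml
+ratelimit:
+    # …
+    connection_limit:
+        enabled: true
+        # About 25 % above the daily maximum of 100 000 used TCP sockets.
+        stop: 125000
+        # About 20 % above the same daily maximum.
+        resume: 120000
+```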
+
+
+
## Rate limiting
The `ratelimit` object has the following properties:
@@ -138,6 +192,30 @@ by `ipv4-subnet_key_len`) that made 15 requests in one second or 6 requests
(one above `rps`) every second for 10 seconds within one minute, the client is
blocked for `back_off_duration`.
+ ### Stream connection limit
+
+The `connection_limit` object has the following properties:
+
+ * `enabled`:
+ Whether or not the stream-connection limit should be enforced.
+
+ **Example:** `true`.
+
+ * `stop`:
+ The point at which the limiter stops accepting new connections. Once the
+ number of active connections reaches this limit, new connections wait for
+ the number to decrease to or below `resume`.
+
+ **Example:** `1000`.
+
+ * `resume`:
+ The point at which the limiter starts accepting new connections again after
+ reaching `stop`.
+
+ **Example:** `800`.
+
+See also [notes on these parameters](#recommended-connection_limit).
+
[env-consul_allowlist_url]: environment.md#CONSUL_ALLOWLIST_URL
@@ -660,6 +738,39 @@ The items of the `filtering_groups` array have the following properties:
+## Network interface listeners
+
+**NOTE:** Network-interface listening works only on Linux with
+`SO_BINDTODEVICE` support (2.0.30 and later) and properly set up IP routes.
+See the [section on testing `SO_BINDTODEVICE` using Docker][dev-btd].
+
+The `interface_listeners` object has the following properties:
+
+ * `channel_buffer_size`:
+ The size of the buffers of the channels used to dispatch TCP connections and
+ UDP sessions.
+
+ **Example:** `1000`.
+
+ * `list`:
+ The mapping of interface-listener IDs to their configuration.
+
+ **Property example:**
+
+ ```yaml
+ list:
+ 'eth0_plain_dns':
+ interface: 'eth0'
+ port: 53
+ 'eth0_plain_dns_secondary':
+ interface: 'eth0'
+ port: 5353
+ ```
+
+[dev-btd]: development.md#testing-bindtodevice
+
+
+
## Server groups
The items of the `server_groups` array have the following properties:
@@ -829,10 +940,25 @@ The items of the `servers` array have the following properties:
**Example:** `true`.
* `bind_addresses`:
- The array of `ip:port` addresses to listen on.
+ The array of `ip:port` addresses to listen on. If `bind_addresses` is set,
+ `bind_interfaces` (see below) should not be set.
**Example:** `[127.0.0.1:53, 192.168.1.1:53]`.
+ * `bind_interfaces`:
+ The array of [interface listener](#ifl-list) data. If `bind_interfaces` is
+ set, `bind_addresses` (see above) should not be set.
+
+ **Property example:**
+
+ ```yaml
+    'bind_interfaces':
+      - 'id': 'eth0_plain_dns'
+        'subnet': '172.17.0.0/16'
+      - 'id': 'eth0_plain_dns_secondary'
+        'subnet': '172.17.0.0/16'
+ ```
+
* `dnscrypt`:
The optional DNSCrypt configuration object. It has the following
properties:
@@ -886,6 +1012,28 @@ The `connectivity_check` object has the following properties:
+## Network settings
+
+The `network` object has the following properties:
+
+ * `so_rcvbuf`:
+ The size of socket receive buffer (`SO_RCVBUF`), in bytes. Default is zero,
+ which means use the default system settings.
+
+ See also [notes on these parameters](#recommended-buffers).
+
+ **Example:** `1048576`.
+
+ * `so_sndbuf`:
+ The size of socket send buffer (`SO_SNDBUF`), in bytes. Default is zero,
+ which means use the default system settings.
+
+ See also [notes on these parameters](#recommended-buffers).
+
+ **Example:** `1048576`.
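+
+As a usage sketch, the two properties can also be combined, keeping one buffer
+at the system default while raising the other; the size below is illustrative:
+
+```yaml
+network:
+    # Keep the system default for the send buffer.
+    so_sndbuf: 0
+    # Request a 1 MB receive buffer; the kernel doubles this value, and the
+    # maximum allowed value is controlled by /proc/sys/net/core/rmem_max (see
+    # the notes above).
+    so_rcvbuf: 1048576
+```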
+
+
+
## Additional metrics information
The `additional_metrics_info` object is a map of strings with extra information
diff --git a/doc/debugdns.md b/doc/debugdns.md
index a2d63ef..16e64fc 100644
--- a/doc/debugdns.md
+++ b/doc/debugdns.md
@@ -87,6 +87,17 @@ In the `ADDITIONAL SECTION`, the following debug information is returned:
```none
asn.adguard-dns.com. 10 CH TXT "1234"
```
+
+ * `subdivision`:
+    The user's location subdivision code. This field can be empty even if
+    the user's country code is present. The full name is
+    `subdivision.adguard-dns.com`.
+
+ **Example:**
+
+ ```none
+ country.adguard-dns.com. 10 CH TXT "US"
+ subdivision.adguard-dns.com. 10 CH TXT "CA"
+ ```
The following debug records can have one of two prefixes: `req` or `resp`. The
prefix depends on whether the filtering was applied to the request or the
diff --git a/doc/development.md b/doc/development.md
index c31a5e4..b925b2d 100644
--- a/doc/development.md
+++ b/doc/development.md
@@ -68,10 +68,26 @@ This is not an extensive list. See `../Makefile`.
make go-gen
-   Regenerate the automatically generated Go files. Those generated files
-   are `../internal/agd/country_generate.go` and
-   `../internal/geoip/asntops_generate.go`. They need to be
-   periodically updated.
+
+   Regenerate the automatically generated Go files that need to be
+   periodically updated. Those generated files are:
+
+   -   `../internal/agd/country_generate.go`;
+
+   -   `../internal/geoip/asntops_generate.go`;
+
+   -   `../internal/profiledb/internal/filecachepb/filecache.pb.go`.
+
+   You'll need to install `protoc` for the last one.
+
make go-lint
@@ -158,7 +174,7 @@ dnscrypt generate -p testdns -o ./dnscrypt.yml
```sh
cd ../
-cp -f config.dist.yml config.yml
+cp -f config.dist.yaml config.yaml
```
@@ -190,6 +206,7 @@ We'll use the test versions of the GeoIP databases here.
rm -f -r ./test/cache/
mkdir ./test/cache
curl 'https://raw.githubusercontent.com/maxmind/MaxMind-DB/main/test-data/GeoIP2-Country-Test.mmdb' -o ./test/GeoIP2-Country-Test.mmdb
+curl 'https://raw.githubusercontent.com/maxmind/MaxMind-DB/main/test-data/GeoIP2-City-Test.mmdb' -o ./test/GeoIP2-City-Test.mmdb
curl 'https://raw.githubusercontent.com/maxmind/MaxMind-DB/main/test-data/GeoLite2-ASN-Test.mmdb' -o ./test/GeoLite2-ASN-Test.mmdb
```
@@ -206,8 +223,8 @@ You'll need to supply the following:
See the [external HTTP API documentation][externalhttp].
-You may need to change the listen ports in `config.yml` which are less than 1024
-to some other ports. Otherwise, `sudo` or `doas` is required to run
+You may need to change the listen ports in `config.yaml` which are less than
+1024 to some other ports. Otherwise, `sudo` or `doas` is required to run
`AdGuardDNS`.
Examples below are for the configuration with the following changes:
@@ -224,14 +241,14 @@ env \
BACKEND_ENDPOINT='PUT BACKEND URL HERE' \
BLOCKED_SERVICE_INDEX_URL='https://atropnikov.github.io/HostlistsRegistry/assets/services.json'\
CONSUL_ALLOWLIST_URL='PUT CONSUL ALLOWLIST URL HERE' \
- CONFIG_PATH='./config.yml' \
+ CONFIG_PATH='./config.yaml' \
DNSDB_PATH='./test/cache/dnsdb.bolt' \
FILTER_INDEX_URL='https://atropnikov.github.io/HostlistsRegistry/assets/filters.json' \
FILTER_CACHE_PATH='./test/cache' \
PROFILES_CACHE_PATH='./test/profilecache.json' \
GENERAL_SAFE_SEARCH_URL='https://adguardteam.github.io/HostlistsRegistry/assets/engines_safe_search.txt' \
GEOIP_ASN_PATH='./test/GeoLite2-ASN-Test.mmdb' \
- GEOIP_COUNTRY_PATH='./test/GeoIP2-Country-Test.mmdb' \
+ GEOIP_COUNTRY_PATH='./test/GeoIP2-City-Test.mmdb' \
QUERYLOG_PATH='./test/cache/querylog.jsonl' \
LISTEN_ADDR='127.0.0.1' \
LISTEN_PORT='8081' \
diff --git a/doc/environment.md b/doc/environment.md
index ffc39e3..f4d8828 100644
--- a/doc/environment.md
+++ b/doc/environment.md
@@ -59,7 +59,7 @@ requirements section][ext-blocked] on the expected format of the response.
The path to the configuration file.
-**Default:** `./config.yml`.
+**Default:** `./config.yaml`.
@@ -122,8 +122,29 @@ The path to the directory with the filter lists cache.
## `PROFILES_CACHE_PATH`
-The path to the profile cache file. The profile cache is read on start and is
-later updated on every [full refresh][conf-backend-full_refresh_interval].
+The path to the profile cache file:
+
+ * `none` means that the profile caching is disabled.
+
+ * A file with the extension `.pb` means that the profiles are cached in the
+ protobuf format.
+
+ Use the following command to inspect the cache, assuming that the version is
+ correct:
+
+ ```sh
+ protoc\
+ --decode\
+ profiledb.FileCache\
+ ./internal/profiledb/internal/filecachepb/filecache.proto\
+ < /path/to/profilecache.pb
+ ```
+
+ * A file with the extension `.json` means that the profiles are cached in the
+ JSON format. This format is **deprecated** and is not recommended.
+
+The profile cache is read on start and is later updated on every
+[full refresh][conf-backend-full_refresh_interval].
**Default:** `./profilecache.json`.
diff --git a/go.mod b/go.mod
index f2a2282..6db877b 100644
--- a/go.mod
+++ b/go.mod
@@ -4,7 +4,7 @@ go 1.20
require (
github.com/AdguardTeam/AdGuardDNS/internal/dnsserver v0.100.0
- github.com/AdguardTeam/golibs v0.12.1
+ github.com/AdguardTeam/golibs v0.13.2
github.com/AdguardTeam/urlfilter v0.16.1
github.com/ameshkov/dnscrypt/v2 v2.2.5
github.com/axiomhq/hyperloglog v0.0.0-20230201085229-3ddf4bad03dc
@@ -19,13 +19,14 @@ require (
github.com/prometheus/client_golang v1.14.0
github.com/prometheus/client_model v0.3.0
github.com/prometheus/common v0.41.0
- github.com/quic-go/quic-go v0.33.0
+ github.com/quic-go/quic-go v0.35.1
github.com/stretchr/testify v1.8.2
go.etcd.io/bbolt v1.3.7
- golang.org/x/exp v0.0.0-20230307190834-24139beb5833
+ golang.org/x/exp v0.0.0-20230321023759-10a507213a29
golang.org/x/net v0.8.0
golang.org/x/sys v0.6.0
golang.org/x/time v0.3.0
+ google.golang.org/protobuf v1.30.0
gopkg.in/yaml.v2 v2.4.0
)
@@ -47,13 +48,12 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/procfs v0.9.0 // indirect
github.com/quic-go/qpack v0.4.0 // indirect
- github.com/quic-go/qtls-go1-19 v0.2.1 // indirect
- github.com/quic-go/qtls-go1-20 v0.1.1 // indirect
+ github.com/quic-go/qtls-go1-19 v0.3.2 // indirect
+ github.com/quic-go/qtls-go1-20 v0.2.2 // indirect
golang.org/x/crypto v0.7.0 // indirect
golang.org/x/mod v0.9.0 // indirect
golang.org/x/text v0.8.0 // indirect
golang.org/x/tools v0.7.0 // indirect
- google.golang.org/protobuf v1.28.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
diff --git a/go.sum b/go.sum
index 703fbdc..da53b2a 100644
--- a/go.sum
+++ b/go.sum
@@ -1,7 +1,7 @@
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
-github.com/AdguardTeam/golibs v0.12.1 h1:bJfFzCnUCl+QsP6prUltM2Sjt0fTiDBPlxuAwfKP3g8=
-github.com/AdguardTeam/golibs v0.12.1/go.mod h1:rIglKDHdLvFT1UbhumBLHO9S4cvWS9MEyT1njommI/Y=
+github.com/AdguardTeam/golibs v0.13.2 h1:BPASsyQKmb+b8VnvsNOHp7bKfcZl9Z+Z2UhPjOiupSc=
+github.com/AdguardTeam/golibs v0.13.2/go.mod h1:7ylQLv2Lqsc3UW3jHoITynYk6Y1tYtgEMkR09ppfsN8=
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
github.com/AdguardTeam/urlfilter v0.16.1 h1:ZPi0rjqo8cQf2FVdzo6cqumNoHZx2KPXj2yZa1A5BBw=
github.com/AdguardTeam/urlfilter v0.16.1/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI=
@@ -91,12 +91,12 @@ github.com/prometheus/procfs v0.9.0 h1:wzCHvIvM5SxWqYvwgVL7yJY8Lz3PKn49KQtpgMYJf
github.com/prometheus/procfs v0.9.0/go.mod h1:+pB4zwohETzFnmlpe6yd2lSc+0/46IYZRB/chUwxUZY=
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
-github.com/quic-go/qtls-go1-19 v0.2.1 h1:aJcKNMkH5ASEJB9FXNeZCyTEIHU1J7MmHyz1Q1TSG1A=
-github.com/quic-go/qtls-go1-19 v0.2.1/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
-github.com/quic-go/qtls-go1-20 v0.1.1 h1:KbChDlg82d3IHqaj2bn6GfKRj84Per2VGf5XV3wSwQk=
-github.com/quic-go/qtls-go1-20 v0.1.1/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
-github.com/quic-go/quic-go v0.33.0 h1:ItNoTDN/Fm/zBlq769lLJc8ECe9gYaW40veHCCco7y0=
-github.com/quic-go/quic-go v0.33.0/go.mod h1:YMuhaAV9/jIu0XclDXwZPAsP/2Kgr5yMYhe9oxhhOFA=
+github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U=
+github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
+github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
+github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
+github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo=
+github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g=
github.com/shirou/gopsutil/v3 v3.21.8 h1:nKct+uP0TV8DjjNiHanKf8SAuub+GNsbrOtM9Nl9biA=
github.com/shirou/gopsutil/v3 v3.21.8/go.mod h1:YWp/H8Qs5fVmf17v7JNZzA0mPJ+mS2e9JdiUF9LlKzQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@@ -120,8 +120,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
-golang.org/x/exp v0.0.0-20230307190834-24139beb5833 h1:SChBja7BCQewoTAU7IgvucQKMIXrEpFxNMs0spT3/5s=
-golang.org/x/exp v0.0.0-20230307190834-24139beb5833/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
+golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
+golang.org/x/exp v0.0.0-20230321023759-10a507213a29/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs=
golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
@@ -170,8 +170,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
-google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
-google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
+google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
+google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
diff --git a/go.work.sum b/go.work.sum
index 820e538..a688d3a 100644
--- a/go.work.sum
+++ b/go.work.sum
@@ -1,17 +1,27 @@
+cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
+cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
+cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
+cloud.google.com/go v0.37.0/go.mod h1:TS1dMSSfndXH133OKGwekG838Om/cQT0BUHV3HcBgoo=
cloud.google.com/go v0.65.0 h1:Dg9iHVQfrhq82rUNu9ZxUDrJLaxFUe/HlCVaLyRruq8=
cloud.google.com/go/bigquery v1.8.0 h1:PQcPefKFdaIzjQFbiyOgAqyx8q5djaE7x9Sqe712DPA=
cloud.google.com/go/datastore v1.1.0 h1:/May9ojXjRkPBNVrq+oWLqmWCkr4OU5uRY29bu0mRyQ=
cloud.google.com/go/pubsub v1.3.1 h1:ukjixP1wl0LpnZ6LWtZJ0mX5tBmjp1f8Sqer8Z2OMUU=
cloud.google.com/go/storage v1.10.0 h1:STgFzyU5/8miMl0//zKh2aQeTyeaUH3WN9bSUiJ09bA=
dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3 h1:hJiie5Bf3QucGRa4ymsAUOxyhYwGEz1xrsVk0P8erlw=
+dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU=
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9 h1:VpgP7xuJadIUuKccphEpTJnWhS2jkQyMt6Y7pJCD7fY=
dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0 h1:SPOUaucgtVls75mg+X7CXigS71EnsfVUK/2CgVrwqgw=
+dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBrvjyP0v+ecvNYvCpyZgu5/xkfAUhi6wJj28eUfSU=
dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412 h1:GvWw74lx5noHocd+f6HBMXK6DuggBB1dhVkuGZbv7qM=
+dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4=
dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c h1:ivON6cwHK1OH26MZyWDCnbTRZZf0IhNsENoNAKFS1g4=
+dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU=
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999 h1:OR8VhtwhcAI3U48/rzBsVOuHi0zDPzYI1xASVcdSgR8=
+git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg=
github.com/AdguardTeam/golibs v0.10.7/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
github.com/AdguardTeam/gomitmproxy v0.2.0 h1:rvCOf17pd1/CnMyMQW891zrEiIQBpQ8cIGjKN9pinUU=
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
+github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/toml v1.2.0 h1:Rt8g24XnyGTyglgET/PRUNlrUeu9F5L+7FilkXfZgs0=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802 h1:1BDTz0u9nC3//pOCMdNH+CiXJVYJh5UQNCOBG7jbELc=
github.com/CloudyKit/fastprinter v0.0.0-20200109182630-33d98a066a53 h1:sR+/8Yb4slttB4vD+b9btVEnWgL3Q00OBTzVT8B9C0c=
@@ -22,39 +32,58 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafo
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2cI+DC1Mbh0TJxzA3RcLoMsFw+aXw7E=
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239 h1:kFOfPq6dUM1hTo4JG6LR5AXSUEsOjtdm0kw0FtQtMJA=
+github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c=
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
+github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625 h1:ckJgFhFWywOx+YLEMIJsTb+NV6NexWICk5+AMSuz3ss=
+github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g=
github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23 h1:D21IyuvjDCshj1/qq+pCNd3VZOAEI9jy6Bi131YlXgI=
+github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s=
github.com/census-instrumentation/opencensus-proto v0.2.1 h1:glEXhBS5PSLLv4IXzLA5yPRVX4bilULVyxxbrfOtDAk=
github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
github.com/chzyer/logex v1.1.10 h1:Swpa1K6QvQznwJRcfTfQJmTE72DqScAa40E+fbHEXEE=
+github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/logex v1.2.0/go.mod h1:9+9sk7u7pGNWYMkh0hdiL++6OeibzJccyQU4p4MedaY=
+github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e h1:fY5BOSpyZCqRo5OhCuC+XN+r/bBCmeuuJtjz+bCNIf8=
+github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/readline v1.5.0 h1:lSwwFrbNviGePhkewF1az4oLmcwqCZijQ2/Wi3BGHAI=
github.com/chzyer/readline v1.5.0/go.mod h1:x22KAscuvRqlLoK9CsoYsmxoXZMMFVyOl86cAH8qUic=
+github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1 h1:q763qf9huN11kDQavWsoZXJNW3xEE4JJyHa5Q25/sd8=
+github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/chzyer/test v0.0.0-20210722231415-061457976a23/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
+github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI=
+github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f h1:WBZRG4aNOuI15bLRrCgN8fCq8E5Xuty6jGbmSNEvSsU=
github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0 h1:sDMmm+q/3+BukdIpxwO365v/Rbspp2Nt5XntgQRXq8Q=
github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d h1:t5Wuyh53qYyg9eqn4BbnlIT+vmhyww0TatL+zT3uWgI=
+github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/creack/pty v1.1.9 h1:uDmaGzcdjhF4i/plgjmEsriH11Y0o7RKapEf/LDaM3w=
github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
+github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/eknkc/amber v0.0.0-20171010120322-cdade1c07385 h1:clC1lXBpe2kTj2VHdaIu9ajZQe4kcEY9j0NsnDDBZ3o=
github.com/envoyproxy/go-control-plane v0.9.4 h1:rEvIZUSZ3fx39WIi3JkQqQBitGwpELBIYWeBVh6wn+E=
github.com/envoyproxy/protoc-gen-validate v0.1.0 h1:EQciDnbrYxy13PgWoY8AqoxGiPrpgBZ1R8UNe3ddc+A=
github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo=
github.com/flosch/pongo2/v4 v4.0.2 h1:gv+5Pe3vaSVmiJvh/BZa82b7/00YUGm0PIyVVLop0Hw=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ=
+github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
github.com/francoispqt/gojay v1.2.13 h1:d2m3sFjloqoIUQU3TsHBgj6qg/BVGlTBeHDUmyJnXKk=
+github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY=
+github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5Ai1i3InKU=
github.com/getsentry/sentry-go v0.13.0 h1:20dgTiUSfxRB/EhMPtxcL9ZEbM1ZdR+W/7f7NWD+xWo=
github.com/getsentry/sentry-go v0.13.0/go.mod h1:EOsfu5ZdvKPfeHYV6pTVQnsjfp30+XA7//UooKNumH0=
github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk=
+github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-gonic/gin v1.8.1 h1:4+fr/el88TOO3ewCmQr8cx/CtZ/umlIRIs5M4NTNjf8=
github.com/gliderlabs/ssh v0.1.1 h1:j3L6gSLQalDETeEg/Jg0mGY0/y/N6zI2xX1978P0Uqw=
+github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0=
github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w=
+github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1 h1:QbL/5oDUmRBzO9/Z7Seo6zf912W/a6Sr4Eu0G/3Jho0=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4 h1:WtGNWLvXpe6ZudgnXrq0barxBImvnnJoMEhXAzcbM0I=
github.com/go-kit/kit v0.9.0 h1:wDJmvq38kDhkVxi50ni9ykkdUr1PKgqKOoi01fa0Mdk=
@@ -68,42 +97,65 @@ github.com/go-playground/validator/v10 v10.11.1 h1:prmOlTVv+YjZjmRmNSF3VmspqJIxJ
github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
github.com/goccy/go-json v0.9.11 h1:/pAaQDLHEoCq/5FFmSKBswWmK6H0e8g4159Kc/X/nqk=
github.com/gogo/protobuf v1.1.1 h1:72R+M5VuhED/KujmZVcIquuo8mBgX4oVda//DQb3PXo=
+github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=
+github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY=
github.com/golang/lint v0.0.0-20180702182130-06c8688daad7 h1:2hRPrmiwPrp3fQX967rNJIhQPtiGXdlQWAxKbKw3VHA=
+github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E=
+github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
+github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
+github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
+github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
+github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY=
+github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ=
github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=
+github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw=
github.com/google/martian v2.1.0+incompatible h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no=
+github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/martian/v3 v3.0.0 h1:pMen7vLs8nvgEYhywH3KDWJIJTeEr2ULsVWHWYHQyBs=
+github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99 h1:Ak8CrdlwwXwAZxzS66vgPt4U8yUZX7JwLvVR58FN5jM=
+github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/googleapis/gax-go v2.0.0+incompatible h1:j0GKcs05QVmm7yesiZq2+9cxHkNK9YM6zKx4D2qucQU=
+github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY=
+github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg=
github.com/googleapis/gax-go/v2 v2.0.5 h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8=
+github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7 h1:pdN6V1QBWetyv/0+wjACpqVH+eVULgEjkurDLq3goeM=
+github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA=
github.com/grpc-ecosystem/grpc-gateway v1.5.0 h1:WcmKMm43DR7RdtlkEXQJyo5ws8iTp98CyhCCbOHMvNI=
+github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw=
github.com/hashicorp/golang-lru v0.5.1 h1:0hERBMJE1eitiLkihrMvRVBYAkpHzc/J3QdDN+dAcgU=
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6 h1:UDMh68UUwekSh5iP2OMhRRZJiiBccgV7axzUG8vi56c=
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639 h1:mV02weKRL81bEnm8A0HT1/CAelMQDBuQIfLw8n+d6xI=
+github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/ianlancetaylor/demangle v0.0.0-20220319035150-800ac71e25c2 h1:rcanfLhLDA8nozr/K289V1zcntHr3V+SHlXwzz1ZI2g=
github.com/ianlancetaylor/demangle v0.0.0-20220319035150-800ac71e25c2/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w=
+github.com/ianlancetaylor/demangle v0.0.0-20220517205856-0058ec4f073c/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w=
github.com/influxdata/influxdb v1.7.6 h1:8mQ7A/V+3noMGCt/P9pD09ISaiz9XvgCk303UYA3gcs=
github.com/iris-contrib/jade v1.1.4 h1:WoYdfyJFfZIUgqNAeOyRfTNQZOksSlZ6+FnXR3AEpX0=
github.com/iris-contrib/schema v0.0.6 h1:CPSBLyx2e91H2yJzPuhGuifVRnZBBJ3pCOMbOvPZaTw=
github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1 h1:ujPKutqRlJtcfWk6toYVYagwra7HQHbXOaS171b4Tg8=
+github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU=
github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGARJA=
github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA=
+github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
+github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
github.com/jstemmer/go-junit-report v0.9.1 h1:6QPYqodiu3GuPL+7mfx+NwDdp2eTkp9IfEUpgAwUN0o=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/kataras/blocks v0.0.7 h1:cF3RDY/vxnSRezc7vLFlQFTYXG/yAr1o7WImJuZbzC4=
@@ -113,20 +165,25 @@ github.com/kataras/pio v0.0.11 h1:kqreJ5KOEXGMwHAWHDwIl+mjfNCPhAwZPa8gK7MKlyw=
github.com/kataras/sitemap v0.0.6 h1:w71CRMMKYMJh6LR2wTgnk5hSgjVNB9KL60n5e2KHvLY=
github.com/kataras/tunnel v0.0.4 h1:sCAqWuJV7nPzGrlb0os3j49lk2JhILT0rID38NHNLpA=
github.com/kisielk/gotool v1.0.0 h1:AV2c/EiW3KqPNT9ZKl07ehoAGi4C5/01Cfbblndcapg=
+github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.15.11 h1:Lcadnb3RKGin4FYM/orgq0qde+nc15E5Cbqg4B9Sx9c=
github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515 h1:T+h1c/A9Gawja4Y9mFVWj2vyii2bbUNDw3kt9VxK2EY=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
+github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1 h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw=
github.com/kr/pty v1.1.3 h1:/Um6a/ZmD5tF7peoOJ5oN5KMQ0DrGVQSXLNwyckutPk=
+github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/labstack/echo/v4 v4.9.0 h1:wPOF1CE6gvt/kmbMR4dGzWvHMPT+sAEUJOwOTtvITVY=
github.com/labstack/gommon v0.3.1 h1:OomWaJXm7xR6L1HmEtGyQf26TEn7V6X88mktX9kee9o=
github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w=
github.com/lucas-clemente/quic-go v0.25.0/go.mod h1:YtzP8bxRVCBlO77yRanE264+fY/T2U9ZlW1AaHOsMOg=
github.com/lucas-clemente/quic-go v0.27.1/go.mod h1:AzgQoPda7N+3IqMMMkywBKggIFo2KT6pfnlrQ2QieeI=
github.com/lunixbochs/vtclean v1.0.0 h1:xu2sLAri4lGiovBDQKxl5mrXyESr3gUr5m5SM5+LVb8=
+github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI=
github.com/mailgun/raymond/v2 v2.0.46 h1:aOYHhvTpF5USySJ0o7cpPno/Uh2I5qg2115K25A+Ft4=
github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe h1:W/GaMY0y69G4cFlmsC6B9sbuo2fP8OFP1ABjt4kPz+w=
+github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/marten-seemann/qtls-go1-15 v0.1.4/go.mod h1:GyFwywLKkRt+6mfU99csTEY1joMZz5vmB1WNZH3P81I=
github.com/marten-seemann/qtls-go1-16 v0.1.4/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk=
@@ -137,54 +194,98 @@ github.com/marten-seemann/qtls-go1-18 v0.1.0/go.mod h1:PUhIQk19LoFt2174H4+an8TYv
github.com/marten-seemann/qtls-go1-18 v0.1.1/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ=
+github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/microcosm-cc/bluemonday v1.0.1 h1:SIYunPjnlXcW+gVfvm0IlSeR5U3WZUOLfVmqg85Go44=
+github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4=
github.com/microcosm-cc/bluemonday v1.0.21 h1:dNH3e4PSyE4vNX+KlRGHT5KrSvjeUkoNPwEORjffHJg=
github.com/miekg/dns v1.1.47/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
+github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
+github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f h1:KUppIJq7/+SVif2QVs3tOP0zanoHgBEVAwHxUSIzRqU=
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86 h1:D6paGObi5Wud7xg83MaEFyjxQB1W5bz5d0IFppr+ymk=
+github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab h1:eFXv9Nu1lGbrNbj619aWwZfVF5HBrm9Plte8aNptuTI=
+github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
+github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk=
+github.com/onsi/ginkgo/v2 v2.8.1/go.mod h1:N1/NbDngAFcSLdyZ+/aYTYGSlq9qMCS/cNKGJjy+csc=
+github.com/onsi/gomega v1.20.1/go.mod h1:DtrZpjmvpn2mPm4YWQa0/ALMDj9v4YxLgojwPeREyVo=
github.com/onsi/gomega v1.22.1/go.mod h1:x6n7VNe4hw0vkyYUM4mjIXx3JbLiPaBPNgB7PRQ1tuM=
github.com/onsi/gomega v1.24.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2vQAg=
+github.com/onsi/gomega v1.27.1/go.mod h1:aHX5xOykVYzWOV4WqQy0sy8BQptgukenXpCXfadcIAw=
github.com/openzipkin/zipkin-go v0.1.1 h1:A/ADD6HaPnAKj3yS7HjGHRK77qi41Hi0DirOOIQAeIw=
+github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8=
github.com/pelletier/go-toml/v2 v2.0.5 h1:ipoSadvV8oGUjnUbMub59IDPPwfxF694nG/jwbMiyQg=
+github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
+github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
+github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro=
+github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/rogpeppe/go-internal v1.3.0 h1:RR9dF3JtopPvtkroDZuVD7qquD0bnHlKSqaQhgwt8yk=
github.com/russross/blackfriday v1.5.2 h1:HyvC0ARfnZBqnXwABFeSZHpKvJHJJfPz81GNueLj0oo=
+github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
github.com/schollz/closestmatch v2.1.0+incompatible h1:Uel2GXEpJqOWBrlyI+oY9LTiyyjYS17cCYRqP13/SHk=
github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ=
+github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4 h1:Fth6mevc5rX7glNLpbAMJnqKlfIkcTjZCSHEeqvKbcI=
+github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY=
github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48 h1:vabduItPAIz9px5iryD5peyx7O3Ya8TBThapgXim98o=
+github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM=
github.com/shurcooL/github_flavored_markdown v0.0.0-20181002035957-2122de532470 h1:qb9IthCFBmROJ6YBS31BEMeSYjOscSiG+EO+JVNTz64=
+github.com/shurcooL/github_flavored_markdown v0.0.0-20181002035957-2122de532470/go.mod h1:2dOwnU2uBioM+SGy2aZoq1f/Sd1l9OkAeAUvjSyvgU0=
github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e h1:MZM7FHLqUHYI0Y/mQAt3d2aYa0SiNms/hFqC9qJYolM=
+github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk=
github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041 h1:llrF3Fs4018ePo4+G/HV/uQUqEI1HMDjCeOf2V6puPc=
+github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ=
github.com/shurcooL/gofontwoff v0.0.0-20180329035133-29b52fc0a18d h1:Yoy/IzG4lULT6qZg62sVC+qyBL8DQkmD2zv6i7OImrc=
+github.com/shurcooL/gofontwoff v0.0.0-20180329035133-29b52fc0a18d/go.mod h1:05UtEgK5zq39gLST6uB0cf3NEHjETfB4Fgr3Gx5R9Vw=
github.com/shurcooL/gopherjslib v0.0.0-20160914041154-feb6d3990c2c h1:UOk+nlt1BJtTcH15CT7iNO7YVWTfTv/DNwEAQHLIaDQ=
+github.com/shurcooL/gopherjslib v0.0.0-20160914041154-feb6d3990c2c/go.mod h1:8d3azKNyqcHP1GaQE/c6dDgjkgSx2BZ4IoEi4F1reUI=
github.com/shurcooL/highlight_diff v0.0.0-20170515013008-09bb4053de1b h1:vYEG87HxbU6dXj5npkeulCS96Dtz5xg3jcfCgpcvbIw=
+github.com/shurcooL/highlight_diff v0.0.0-20170515013008-09bb4053de1b/go.mod h1:ZpfEhSmds4ytuByIcDnOLkTHGUI6KNqRNPDLHDk+mUU=
github.com/shurcooL/highlight_go v0.0.0-20181028180052-98c3abbbae20 h1:7pDq9pAMCQgRohFmd25X8hIH8VxmT3TaDm+r9LHxgBk=
+github.com/shurcooL/highlight_go v0.0.0-20181028180052-98c3abbbae20/go.mod h1:UDKB5a1T23gOMUJrI+uSuH0VRDStOiUVSjBTRDVBVag=
github.com/shurcooL/home v0.0.0-20181020052607-80b7ffcb30f9 h1:MPblCbqA5+z6XARjScMfz1TqtJC7TuTRj0U9VqIBs6k=
+github.com/shurcooL/home v0.0.0-20181020052607-80b7ffcb30f9/go.mod h1:+rgNQw2P9ARFAs37qieuu7ohDNQ3gds9msbT2yn85sg=
github.com/shurcooL/htmlg v0.0.0-20170918183704-d01228ac9e50 h1:crYRwvwjdVh1biHzzciFHe8DrZcYrVcZFlJtykhRctg=
+github.com/shurcooL/htmlg v0.0.0-20170918183704-d01228ac9e50/go.mod h1:zPn1wHpTIePGnXSHpsVPWEktKXHr6+SS6x/IKRb7cpw=
github.com/shurcooL/httperror v0.0.0-20170206035902-86b7830d14cc h1:eHRtZoIi6n9Wo1uR+RU44C247msLWwyA89hVKwRLkMk=
+github.com/shurcooL/httperror v0.0.0-20170206035902-86b7830d14cc/go.mod h1:aYMfkZ6DWSJPJ6c4Wwz3QtW22G7mf/PEgaB9k/ik5+Y=
github.com/shurcooL/httpfs v0.0.0-20171119174359-809beceb2371 h1:SWV2fHctRpRrp49VXJ6UZja7gU9QLHwRpIPBN89SKEo=
+github.com/shurcooL/httpfs v0.0.0-20171119174359-809beceb2371/go.mod h1:ZY1cvUeJuFPAdZ/B6v7RHavJWZn2YPVFQ1OSXhCGOkg=
github.com/shurcooL/httpgzip v0.0.0-20180522190206-b1c53ac65af9 h1:fxoFD0in0/CBzXoyNhMTjvBZYW6ilSnTw7N7y/8vkmM=
+github.com/shurcooL/httpgzip v0.0.0-20180522190206-b1c53ac65af9/go.mod h1:919LwcH0M7/W4fcZ0/jy0qGght1GIhqyS/EgWGH2j5Q=
github.com/shurcooL/issues v0.0.0-20181008053335-6292fdc1e191 h1:T4wuULTrzCKMFlg3HmKHgXAF8oStFb/+lOIupLV2v+o=
+github.com/shurcooL/issues v0.0.0-20181008053335-6292fdc1e191/go.mod h1:e2qWDig5bLteJ4fwvDAc2NHzqFEthkqn7aOZAOpj+PQ=
github.com/shurcooL/issuesapp v0.0.0-20180602232740-048589ce2241 h1:Y+TeIabU8sJD10Qwd/zMty2/LEaT9GNDaA6nyZf+jgo=
+github.com/shurcooL/issuesapp v0.0.0-20180602232740-048589ce2241/go.mod h1:NPpHK2TI7iSaM0buivtFUc9offApnI0Alt/K8hcHy0I=
github.com/shurcooL/notifications v0.0.0-20181007000457-627ab5aea122 h1:TQVQrsyNaimGwF7bIhzoVC9QkKm4KsWd8cECGzFx8gI=
+github.com/shurcooL/notifications v0.0.0-20181007000457-627ab5aea122/go.mod h1:b5uSkrEVM1jQUspwbixRBhaIjIzL2xazXp6kntxYle0=
github.com/shurcooL/octicon v0.0.0-20181028054416-fa4f57f9efb2 h1:bu666BQci+y4S0tVRVjsHUeRon6vUXmsGBwdowgMrg4=
+github.com/shurcooL/octicon v0.0.0-20181028054416-fa4f57f9efb2/go.mod h1:eWdoE5JD4R5UVWDucdOPg1g2fqQRq78IQa9zlOV1vpQ=
github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82 h1:LneqU9PHDsg/AkPDU3AkqMxnMYL+imaqkpflHu73us8=
+github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82/go.mod h1:TCR1lToEk4d2s07G3XGfz2QrgHXg4RJBvjrOozvoWfk=
github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95 h1:/vdW8Cb7EXrkqWGufVMES1OH2sU9gKVb2n9/1y5NMBY=
+github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537 h1:YGaxtkYjb8mnTvtufv2LKLwCQu2/C7qFB7UtrOlTWOY=
+github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537/go.mod h1:QJTqeLYEDaXHZDBsXlPCDqdhQuJkuw4NOtaxYe3xii4=
github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133 h1:JtcyT0rk/9PKOdnKQzuDR+FSjh7SGtJwpgVpfZBRKlQ=
+github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5kWdCj2z2KEozexVbfEZIWiTjhE0+UjmZgPqehw=
github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I=
github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d h1:yKm7XZV6j9Ev6lojP2XaIshpT4ymkqhMeSghO5Ps00E=
+github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE=
github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e h1:qpG93cPwA5f7s/ZPBJnGOYQNK/vKsaDaseuKT5Asee8=
+github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ=
github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A=
github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
+github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07 h1:UyzmZLoiDWMRywV4DUYb9Fbt8uiOSooupjTq10vpvnU=
+github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
github.com/tdewolff/minify/v2 v2.12.4 h1:kejsHQMM17n6/gwdw53qsi6lg0TGddZADVyQOz1KMdE=
github.com/tdewolff/parse/v2 v2.6.4 h1:KCkDvNUMof10e3QExio9OPZJT8SbdKojLBumw8YZycQ=
github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0=
@@ -193,60 +294,137 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw
github.com/valyala/fasthttp v1.40.0 h1:CRq/00MfruPGFLTQKY8b+8SfdK60TxNztjRMnH0t1Yc=
github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4=
github.com/viant/assertly v0.4.8 h1:5x1GzBaRteIwTr5RAGFVG14uNeRFxVNbXPWrK2qAgpc=
+github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU=
github.com/viant/toolbox v0.24.0 h1:6TteTDQ68CjgcCe8wH3D3ZhUQQOJXMTbj/D9rkk2a1k=
+github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/yosssi/ace v0.0.5 h1:tUkIP/BLdKqrlrPwcmH0shwEEhTRHoGnc1wFIWmaBUA=
github.com/yuin/goldmark v1.4.1 h1:/vn0k+RBvwlxEmP5E7SZMqNxPhfMVFEJiykr15/0XKM=
github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE=
+github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
+go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA=
go.opencensus.io v0.22.4 h1:LYy1Hy3MJdrCdMwwzxA/dRok4ejH+RwNGbuoD9fCjto=
go4.org v0.0.0-20180809161055-417644f6feb5 h1:+hE86LblG4AyDgwMCLTE6FOlM9+qjHSYS+rKqxUVdsM=
+go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE=
golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d h1:E2M5QgjZ/Jg+ObCQAudsXxuTsLj7Nl5RV/lZcQZmKSo=
+golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw=
+golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
+golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
+golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80=
+golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20221019170559-20944726eadf/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
+golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/exp v0.0.0-20230306221820-f0f767cdffd6/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b h1:+qEpEAPhDZ1o0x3tHzZTQDArnOixOzGD9HUJfcg0mb4=
+golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
+golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
+golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k=
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 h1:4+4C/Iv2U4fMZBiMCc98MG1In4gJY5YRhtpDNeDeHWs=
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI=
+golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
+golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220516155154-20f960328961/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
+golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
+golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
+golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
+golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
+golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b h1:clP8eMhB30EHdc0bd2Twtq6kgU7yl5ub2cQLSdrv1Dg=
golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852 h1:xYq6+9AtI+xP3M4r0N1hCkHrInHDBohhquRgx9Kk6gI=
+golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw=
+golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190316082340-a2f829d7f35f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 h1:v6hYoSR9T5oet+pMXwUWkbiVqx/63mlHjefrHmxwfeY=
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw=
golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI=
golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA=
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
+golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
+golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
+golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA=
+golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0=
+google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0=
+google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y=
google.golang.org/api v0.30.0 h1:yfrXXP61wVuLb0vBcG6qaOoIoqYEzOQS8jum51jkv2w=
+google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
+google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
+google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
+google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/appengine v1.6.6 h1:lMO5rYAqUxkmaj76jAkRUvt5JZgFymx/+Q5Mzfivuhc=
+google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
+google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
+google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
+google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg=
+google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
google.golang.org/genproto v0.0.0-20200825200019-8632dd797987 h1:PDIOdWxZ8eRizhKa1AAvY53xsvLB1cWorMjslvY3VA8=
+google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
+google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio=
+google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=
+google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.31.0 h1:T7P4R73V3SSDPhH7WW7ATbfViLtmamH0DKrP3f9AuDI=
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/alecthomas/kingpin.v2 v2.2.6 h1:jMFz6MfLP0/4fUyZle81rXUoxOBFi19VUFKVDOQfozc=
+gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0 h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8=
gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
+gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
+gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
grpc.go4.org v0.0.0-20170609214715-11d0a25b4919 h1:tmXTu+dfa+d9Evp8NpJdgOy6+rt8/x4yG7qPBrtNfLY=
+grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o=
+honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
+honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
+honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8=
rsc.io/binaryregexp v0.2.0 h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE=
rsc.io/quote/v3 v3.1.0 h1:9JKUTTIUgS6kzR9mK1YuGKv6Nl+DijDNIc0ghT58FaY=
rsc.io/sampler v1.3.0 h1:7uVkIFmeBqHfdjD+gZwtXXI+RODJ2Wc4O7MPEh/QiW4=
sourcegraph.com/sourcegraph/go-diff v0.5.0 h1:eTiIR0CoWjGzJcnQ3OkhIl/b9GJovq4lSAVRt0ZFEG8=
+sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck=
sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4 h1:JPJh2pk3+X4lXAkZIk2RuE/7/FoK9maXw+TNPJhVS/c=
+sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0=
diff --git a/internal/agd/agd_test.go b/internal/agd/agd_test.go
index 960e071..b1395a5 100644
--- a/internal/agd/agd_test.go
+++ b/internal/agd/agd_test.go
@@ -1,11 +1,9 @@
package agd_test
import (
- "net/netip"
"testing"
"time"
- "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/golibs/testutil"
)
@@ -17,12 +15,3 @@ func TestMain(m *testing.M) {
// testTimeout is the timeout for common test operations.
const testTimeout = 1 * time.Second
-
-// testProfID is the profile ID for tests.
-const testProfID agd.ProfileID = "prof1234"
-
-// testDevID is the device ID for tests.
-const testDevID agd.DeviceID = "dev1234"
-
-// testClientIPv4 is the client IP for tests
-var testClientIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4})
diff --git a/internal/agd/country_generate.go b/internal/agd/country_generate.go
index db1c993..066ef1f 100644
--- a/internal/agd/country_generate.go
+++ b/internal/agd/country_generate.go
@@ -10,6 +10,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
+ "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
"golang.org/x/exp/slices"
)
@@ -22,7 +23,7 @@ func main() {
req, err := http.NewRequest(http.MethodGet, csvURL, nil)
check(err)
- req.Header.Add("User-Agent", agdhttp.UserAgent())
+ req.Header.Add(httphdr.UserAgent, agdhttp.UserAgent())
resp, err := c.Do(req)
check(err)
diff --git a/internal/agd/device.go b/internal/agd/device.go
index 221f815..a2b7c26 100644
--- a/internal/agd/device.go
+++ b/internal/agd/device.go
@@ -12,17 +12,25 @@ import (
// Devices
// Device is a device attached to a profile.
+//
+// NOTE: Do not change fields of this structure without incrementing
+// [internal/profiledb/internal.FileCacheVersion].
type Device struct {
// ID is the unique ID of the device.
ID DeviceID
- // LinkedIP, when non-nil, allows AdGuard DNS to identify a device by its IP
- // address when it can only use plain DNS.
- LinkedIP *netip.Addr
+ // LinkedIP, when non-empty, allows AdGuard DNS to identify a device by its
+ // IP address when it can only use plain DNS.
+ LinkedIP netip.Addr
// Name is the human-readable name of the device.
Name DeviceName
+ // DedicatedIPs are the destination (server) IP addresses dedicated to this
+ // device, if any. A device can use one of these addresses as a DNS server
+ // address for AdGuard DNS to recognize it.
+ DedicatedIPs []netip.Addr
+
// FilteringEnabled defines whether queries from the device should be
// filtered in any way at all.
FilteringEnabled bool
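
The switch of `LinkedIP` from `*netip.Addr` to a plain `netip.Addr` value means callers now test for the zero value with `IsValid` instead of comparing against nil, while `DedicatedIPs` is matched against the destination (server) address of a query. Below is a minimal, standard-library-only sketch of that distinction; the `device` struct and `matchesClient` helper are illustrative only, not the project's actual lookup code:

```go
package main

import (
	"fmt"
	"net/netip"
)

// device mirrors only the fields of agd.Device relevant to this sketch.
type device struct {
	LinkedIP     netip.Addr
	DedicatedIPs []netip.Addr
}

// matchesClient reports whether the client (source) address equals the
// device's linked IP, or the local (destination) address is one of its
// dedicated IPs.  This is a sketch, not the real profile-database logic.
func matchesClient(d *device, client, local netip.Addr) bool {
	// With a value type, "unset" is the zero Addr, checked via IsValid.
	if d.LinkedIP.IsValid() && d.LinkedIP == client {
		return true
	}

	for _, ip := range d.DedicatedIPs {
		if ip == local {
			return true
		}
	}

	return false
}

func main() {
	d := &device{
		LinkedIP:     netip.MustParseAddr("1.2.3.4"),
		DedicatedIPs: []netip.Addr{netip.MustParseAddr("192.0.2.1")},
	}

	fmt.Println(matchesClient(d, netip.MustParseAddr("1.2.3.4"), netip.Addr{}))      // true
	fmt.Println(matchesClient(d, netip.Addr{}, netip.MustParseAddr("192.0.2.1")))    // true
}
```
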
diff --git a/internal/agd/error.go b/internal/agd/error.go
index 0fc315c..6057f70 100644
--- a/internal/agd/error.go
+++ b/internal/agd/error.go
@@ -25,53 +25,6 @@ func (err *ArgumentError) Error() (msg string) {
return fmt.Sprintf("argument %s is invalid: %s", err.Name, err.Message)
}
-// EntityName is the type for names of entities. Currently only used in errors.
-type EntityName string
-
-// Current entity names.
-const (
- EntityNameDevice EntityName = "device"
- EntityNameProfile EntityName = "profile"
-)
-
-// NotFoundError is an error returned by lookup methods when an entity wasn't
-// found.
-//
-// We use separate types that implement a common interface instead of a single
-// structure to reduce allocations.
-type NotFoundError interface {
- error
-
- // EntityName returns the name of the entity that couldn't be found.
- EntityName() (e EntityName)
-}
-
-// DeviceNotFoundError is a NotFoundError returned by lookup methods when
-// a device wasn't found.
-type DeviceNotFoundError struct{}
-
-// type check
-var _ NotFoundError = DeviceNotFoundError{}
-
-// Error implements the NotFoundError interface for DeviceNotFoundError.
-func (DeviceNotFoundError) Error() (msg string) { return "device not found" }
-
-// EntityName implements the NotFoundError interface for DeviceNotFoundError.
-func (DeviceNotFoundError) EntityName() (e EntityName) { return EntityNameDevice }
-
-// ProfileNotFoundError is a NotFoundError returned by lookup methods when
-// a profile wasn't found.
-type ProfileNotFoundError struct{}
-
-// type check
-var _ NotFoundError = ProfileNotFoundError{}
-
-// Error implements the NotFoundError interface for ProfileNotFoundError.
-func (ProfileNotFoundError) Error() (msg string) { return "profile not found" }
-
-// EntityName implements the NotFoundError interface for ProfileNotFoundError.
-func (ProfileNotFoundError) EntityName() (e EntityName) { return EntityNameProfile }
-
// NotACountryError is returned from NewCountry when the string doesn't represent
// a valid country.
type NotACountryError struct {
diff --git a/internal/agd/location.go b/internal/agd/location.go
index 406c15b..a41780b 100644
--- a/internal/agd/location.go
+++ b/internal/agd/location.go
@@ -4,9 +4,19 @@ package agd
// Location represents the GeoIP location data about an IP address.
type Location struct {
- Country Country
+ // Country is the country whose subnets contain the IP address.
+ Country Country
+
+ // Continent is the continent whose subnets contain the IP address.
Continent Continent
- ASN ASN
+
+ // TopSubdivision is the ISO-code of the political subdivision of a country
+ // whose subnets contain the IP address. This field may be empty.
+ TopSubdivision string
+
+ // ASN is the number of the autonomous system whose subnets contain the IP
+ // address.
+ ASN ASN
}
// ASN is the autonomous system number of an IP address.
diff --git a/internal/agd/profile.go b/internal/agd/profile.go
index 1b0696a..913ea3c 100644
--- a/internal/agd/profile.go
+++ b/internal/agd/profile.go
@@ -5,6 +5,7 @@ import (
"math"
"time"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdtime"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/golibs/errors"
)
@@ -15,8 +16,8 @@ import (
// the infrastructure, a profile is also called a “DNS server”. We call it
// profile, because it's less confusing.
//
-// NOTE: Increment [defaultProfileDBCacheVersion] on any change of this
-// structure.
+// NOTE: Do not change fields of this structure without incrementing
+// [internal/profiledb/internal.FileCacheVersion].
//
// TODO(a.garipov): Consider making it closer to the config file and the backend
// response by grouping parental, rule list, and safe browsing settings into
@@ -24,63 +25,107 @@ import (
type Profile struct {
// Parental are the parental settings for this profile. They are ignored if
// FilteringEnabled is set to false.
+ //
+ // NOTE: Do not change fields of this structure without incrementing
+ // [internal/profiledb/internal.FileCacheVersion].
Parental *ParentalProtectionSettings
// BlockingMode defines the way blocked responses are constructed.
+ //
+ // NOTE: Do not change fields of this structure without incrementing
+ // [internal/profiledb/internal.FileCacheVersion].
BlockingMode dnsmsg.BlockingModeCodec
// ID is the unique ID of this profile.
+ //
+ // NOTE: Do not change fields of this structure without incrementing
+ // [internal/profiledb/internal.FileCacheVersion].
ID ProfileID
// UpdateTime shows the last time this profile was updated from the backend.
// This is NOT the time of update in the backend's database, since the
// backend doesn't send this information.
+ //
+ // NOTE: Do not change fields of this structure without incrementing
+ // [internal/profiledb/internal.FileCacheVersion].
UpdateTime time.Time
- // Devices are the devices attached to this profile. Every element of the
- // slice must be non-nil.
- Devices []*Device
+ // DeviceIDs are the IDs of devices attached to this profile.
+ //
+ // NOTE: Do not change fields of this structure without incrementing
+ // [internal/profiledb/internal.FileCacheVersion].
+ DeviceIDs []DeviceID
// RuleListIDs are the IDs of the filtering rule lists enabled for this
// profile. They are ignored if FilteringEnabled or RuleListsEnabled are
// set to false.
+ //
+ // NOTE: Do not change fields of this structure without incrementing
+ // [internal/profiledb/internal.FileCacheVersion].
RuleListIDs []FilterListID
// CustomRules are the custom filtering rules for this profile. They are
// ignored if RuleListsEnabled is set to false.
+ //
+ // NOTE: Do not change fields of this structure without incrementing
+ // [internal/profiledb/internal.FileCacheVersion].
CustomRules []FilterRuleText
// FilteredResponseTTL is the time-to-live value used for responses sent to
// the devices of this profile.
+ //
+ // NOTE: Do not change fields of this structure without incrementing
+ // [internal/profiledb/internal.FileCacheVersion].
FilteredResponseTTL time.Duration
// FilteringEnabled defines whether queries from devices of this profile
// should be filtered in any way at all.
+ //
+ // NOTE: Do not change fields of this structure without incrementing
+ // [internal/profiledb/internal.FileCacheVersion].
FilteringEnabled bool
// SafeBrowsingEnabled defines whether queries from devices of this profile
// should be filtered using the safe browsing filter. Requires
// FilteringEnabled to be set to true.
+ //
+ // NOTE: Do not change fields of this structure without incrementing
+ // [internal/profiledb/internal.FileCacheVersion].
SafeBrowsingEnabled bool
// RuleListsEnabled defines whether queries from devices of this profile
// should be filtered using the filtering rule lists in RuleListIDs.
// Requires FilteringEnabled to be set to true.
+ //
+ // NOTE: Do not change fields of this structure without incrementing
+ // [internal/profiledb/internal.FileCacheVersion].
RuleListsEnabled bool
// QueryLogEnabled defines whether query logs should be saved for the
// devices of this profile.
+ //
+ // NOTE: Do not change fields of this structure without incrementing
+ // [internal/profiledb/internal.FileCacheVersion].
QueryLogEnabled bool
// Deleted shows if this profile is deleted.
+ //
+ // NOTE: Do not change fields of this structure without incrementing
+ // [internal/profiledb/internal.FileCacheVersion].
Deleted bool
// BlockPrivateRelay shows if Apple Private Relay queries are blocked for
// requests from all devices in this profile.
+ //
+ // NOTE: Do not change fields of this structure without incrementing
+ // [internal/profiledb/internal.FileCacheVersion].
BlockPrivateRelay bool
// BlockFirefoxCanary shows if Firefox canary domain is blocked for
// requests from all devices in this profile.
+ //
+ // NOTE: Do not change fields of this structure without incrementing
+ // [internal/profiledb/internal.FileCacheVersion].
BlockFirefoxCanary bool
}
@@ -163,23 +208,26 @@ type WeeklySchedule [7]DayRange
// ParentalProtectionSchedule is the schedule of a client's parental protection.
// All fields must not be nil.
+//
+// NOTE: Do not change fields of this structure without incrementing
+// [internal/profiledb/internal.FileCacheVersion].
type ParentalProtectionSchedule struct {
// Week is the parental protection schedule for every day of the week.
Week *WeeklySchedule
// TimeZone is the profile's time zone.
- TimeZone *time.Location
+ TimeZone *agdtime.Location
}
// Contains returns true if t is within the allowed schedule.
func (s *ParentalProtectionSchedule) Contains(t time.Time) (ok bool) {
- t = t.In(s.TimeZone)
+ t = t.In(&s.TimeZone.Location)
r := s.Week[int(t.Weekday())]
if r.IsZeroLength() {
return false
}
- day := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, s.TimeZone)
+ day := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, &s.TimeZone.Location)
start := day.Add(time.Duration(r.Start) * time.Minute)
end := day.Add(time.Duration(r.End+1)*time.Minute - 1*time.Nanosecond)
@@ -187,6 +235,9 @@ func (s *ParentalProtectionSchedule) Contains(t time.Time) (ok bool) {
}
// ParentalProtectionSettings are the parental protection settings of a profile.
+//
+// NOTE: Do not change fields of this structure without incrementing
+// [internal/profiledb/internal.FileCacheVersion].
type ParentalProtectionSettings struct {
Schedule *ParentalProtectionSchedule
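
For reference, the window arithmetic in `Contains` treats `DayRange.End` as an inclusive minute: the allowed interval extends to the last nanosecond of that minute. The comparison itself falls outside the hunk above, so the following self-contained sketch assumes an inclusive check on both ends; the `within` helper and its parameters are illustrative only:

```go
package main

import (
	"fmt"
	"time"
)

// within reproduces the interval math shown in Contains: startMin and endMin
// are minutes since midnight, and endMin is inclusive up to its last
// nanosecond.  The final comparison is an assumption, not the project's code.
func within(t time.Time, startMin, endMin int, loc *time.Location) bool {
	t = t.In(loc)
	day := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
	start := day.Add(time.Duration(startMin) * time.Minute)
	end := day.Add(time.Duration(endMin+1)*time.Minute - 1*time.Nanosecond)

	return !t.Before(start) && !t.After(end)
}

func main() {
	// A 09:00–17:59 window: minutes 540 through 1079.
	loc := time.UTC
	t := time.Date(2023, time.January, 2, 17, 59, 59, 0, loc)

	fmt.Println(within(t, 540, 1079, loc))                  // true
	fmt.Println(within(t.Add(time.Second), 540, 1079, loc)) // false: 18:00:00
}
```
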
diff --git a/internal/agd/profile_test.go b/internal/agd/profile_test.go
index 7ade079..5c61018 100644
--- a/internal/agd/profile_test.go
+++ b/internal/agd/profile_test.go
@@ -5,6 +5,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdtime"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/stretchr/testify/assert"
@@ -78,7 +79,7 @@ func TestParentalProtectionSchedule_Contains(t *testing.T) {
time.Saturday: agd.ZeroLengthDayRange(),
},
- TimeZone: time.UTC,
+ TimeZone: agdtime.UTC(),
}
// allDaySchedule, 00:00:00 to 23:59:59.
@@ -95,7 +96,7 @@ func TestParentalProtectionSchedule_Contains(t *testing.T) {
time.Saturday: agd.ZeroLengthDayRange(),
},
- TimeZone: time.UTC,
+ TimeZone: agdtime.UTC(),
}
testCases := []struct {
diff --git a/internal/agd/profiledb.go b/internal/agd/profiledb.go
deleted file mode 100644
index fcb1a3d..0000000
--- a/internal/agd/profiledb.go
+++ /dev/null
@@ -1,423 +0,0 @@
-package agd
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "net/netip"
- "os"
- "sync"
- "time"
-
- "github.com/AdguardTeam/AdGuardDNS/internal/metrics"
- "github.com/AdguardTeam/golibs/errors"
- "github.com/AdguardTeam/golibs/log"
-)
-
-// Data Storage
-
-// ProfileDB is the local database of profiles and other data.
-//
-// TODO(a.garipov): move this logic to the backend package.
-type ProfileDB interface {
- ProfileByDeviceID(ctx context.Context, id DeviceID) (p *Profile, d *Device, err error)
- ProfileByIP(ctx context.Context, ip netip.Addr) (p *Profile, d *Device, err error)
-}
-
-// DefaultProfileDB is the default implementation of the ProfileDB interface
-// that can refresh itself from the provided storage.
-type DefaultProfileDB struct {
- // mapsMu protects the deviceToProfile, deviceIDToIP, and ipToDeviceID maps.
- mapsMu *sync.RWMutex
-
- // refreshMu protects syncTime and syncTimeFull. These are only used within
- // Refresh, so this is also basically a refresh serializer.
- refreshMu *sync.Mutex
-
- // storage returns the data for this profiledb.
- storage ProfileStorage
-
- // deviceToProfile maps device IDs to their profiles. It is cleared lazily
- // whenever a device is found to be missing from a profile.
- deviceToProfile map[DeviceID]*Profile
-
- // deviceIDToIP maps device IDs to their linked IP addresses. It is used to
- // take changes in IP address linking into account during refreshes. It is
- // cleared lazily whenever a device is found to be missing from a profile.
- deviceIDToIP map[DeviceID]netip.Addr
-
- // ipToDeviceID maps linked IP addresses to the IDs of their devices. It is
- // cleared lazily whenever a device is found to be missing from a profile.
- ipToDeviceID map[netip.Addr]DeviceID
-
- // syncTime is the time of the last synchronization. It is used in refresh
- // requests to the storage.
- syncTime time.Time
-
- // syncTimeFull is the time of the last full synchronization.
- syncTimeFull time.Time
-
- // cacheFilePath is the path to profiles cache file.
- cacheFilePath string
-
- // fullSyncIvl is the interval between two full synchronizations with the
- // storage
- fullSyncIvl time.Duration
-}
-
-// NewDefaultProfileDB returns a new default profile profiledb. The initial
-// refresh is performed immediately with the constant timeout of 1 minute,
-// beyond which an empty profiledb is returned. db is never nil.
-func NewDefaultProfileDB(
- ds ProfileStorage,
- fullRefreshIvl time.Duration,
- cacheFilePath string,
-) (db *DefaultProfileDB, err error) {
- db = &DefaultProfileDB{
- mapsMu: &sync.RWMutex{},
- refreshMu: &sync.Mutex{},
- storage: ds,
- syncTime: time.Time{},
- syncTimeFull: time.Time{},
- deviceToProfile: make(map[DeviceID]*Profile),
- deviceIDToIP: make(map[DeviceID]netip.Addr),
- ipToDeviceID: make(map[netip.Addr]DeviceID),
- fullSyncIvl: fullRefreshIvl,
- cacheFilePath: cacheFilePath,
- }
-
- err = db.loadProfileCache()
- if err != nil {
- log.Error("profiledb: cache: loading: %s", err)
- }
-
- // initialTimeout defines the maximum duration of the first attempt to load
- // the profiledb.
- const initialTimeout = 1 * time.Minute
-
- ctx, cancel := context.WithTimeout(context.Background(), initialTimeout)
- defer cancel()
-
- log.Info("profiledb: initial refresh")
-
- err = db.Refresh(ctx)
- if err != nil {
- if errors.Is(err, context.DeadlineExceeded) {
- log.Info("profiledb: warning: initial refresh timeout: %s", err)
-
- return db, nil
- }
-
- return nil, fmt.Errorf("initial refresh: %w", err)
- }
-
- log.Info("profiledb: initial refresh succeeded")
-
- return db, nil
-}
-
-// type check
-var _ Refresher = (*DefaultProfileDB)(nil)
-
-// Refresh implements the Refresher interface for *DefaultProfileDB. It updates
-// the internal maps from the data it receives from the storage.
-func (db *DefaultProfileDB) Refresh(ctx context.Context) (err error) {
- startTime := time.Now()
- defer func() {
- metrics.ProfilesSyncTime.SetToCurrentTime()
- metrics.ProfilesCountGauge.Set(float64(len(db.deviceToProfile)))
- metrics.ProfilesSyncDuration.Observe(time.Since(startTime).Seconds())
- metrics.SetStatusGauge(metrics.ProfilesSyncStatus, err)
- }()
-
- reqID := NewRequestID()
- ctx = WithRequestID(ctx, reqID)
-
- defer func() { err = errors.Annotate(err, "req %s: %w", reqID) }()
-
- db.refreshMu.Lock()
- defer db.refreshMu.Unlock()
-
- isFullSync := time.Since(db.syncTimeFull) >= db.fullSyncIvl
- syncTime := db.syncTime
- if isFullSync {
- syncTime = time.Time{}
- }
-
- var resp *PSProfilesResponse
- resp, err = db.storage.Profiles(ctx, &PSProfilesRequest{
- SyncTime: syncTime,
- })
- if err != nil {
- return fmt.Errorf("updating profiles: %w", err)
- }
-
- profiles := resp.Profiles
- devNum := db.setProfiles(profiles)
- log.Debug("profiledb: req %s: got %d profiles with %d devices", reqID, len(profiles), devNum)
- metrics.ProfilesNewCountGauge.Set(float64(len(profiles)))
-
- db.syncTime = resp.SyncTime
- if isFullSync {
- db.syncTimeFull = time.Now()
-
- err = db.saveProfileCache(ctx)
- if err != nil {
- return fmt.Errorf("saving cache: %w", err)
- }
- }
-
- return nil
-}
-
-// profileCache is the structure for profiles db cache.
-type profileCache struct {
- SyncTime time.Time `json:"sync_time"`
- Profiles []*Profile `json:"profiles"`
- Version int `json:"version"`
-}
-
-// saveStorageCache saves profiles data to cache file.
-func (db *DefaultProfileDB) saveProfileCache(ctx context.Context) (err error) {
- log.Info("profiledb: saving profile cache")
-
- var resp *PSProfilesResponse
- resp, err = db.storage.Profiles(ctx, &PSProfilesRequest{
- SyncTime: time.Time{},
- })
-
- if err != nil {
- return err
- }
-
- data := &profileCache{
- Profiles: resp.Profiles,
- Version: defaultProfileDBCacheVersion,
- SyncTime: time.Now(),
- }
-
- cache, err := json.Marshal(data)
- if err != nil {
- return fmt.Errorf("encoding json: %w", err)
- }
-
- err = os.WriteFile(db.cacheFilePath, cache, 0o600)
- if err != nil {
- // Don't wrap the error, because it's informative enough as is.
- return err
- }
-
- log.Info("profiledb: cache: saved %d profiles to %q", len(resp.Profiles), db.cacheFilePath)
-
- return nil
-}
-
-// defaultProfileDBCacheVersion is the version of cached data structure. It's
-// manually incremented on every change in [profileCache] structure.
-const defaultProfileDBCacheVersion = 2
-
-// loadProfileCache loads profiles data from cache file.
-func (db *DefaultProfileDB) loadProfileCache() (err error) {
- log.Info("profiledb: loading cache")
-
- data, err := db.loadStorageCache()
- if err != nil {
- return fmt.Errorf("loading cache: %w", err)
- }
-
- if data == nil {
- log.Info("profiledb: cache is empty")
-
- return nil
- }
-
- if data.Version == defaultProfileDBCacheVersion {
- profiles := data.Profiles
- devNum := db.setProfiles(profiles)
- log.Info("profiledb: cache: got %d profiles with %d devices", len(profiles), devNum)
-
- db.syncTime = data.SyncTime
- db.syncTimeFull = data.SyncTime
- } else {
- log.Info(
- "profiledb: cache version %d is different from %d",
- data.Version,
- defaultProfileDBCacheVersion,
- )
- }
-
- return nil
-}
-
-// loadStorageCache loads data from cache file.
-func (db *DefaultProfileDB) loadStorageCache() (data *profileCache, err error) {
- file, err := os.Open(db.cacheFilePath)
- if err != nil {
- if errors.Is(err, os.ErrNotExist) {
- // File could be deleted or not yet created, go on.
- return nil, nil
- }
-
- return nil, err
- }
- defer func() { err = errors.WithDeferred(err, file.Close()) }()
-
- data = &profileCache{}
- err = json.NewDecoder(file).Decode(data)
- if err != nil {
- return nil, fmt.Errorf("decoding json: %w", err)
- }
-
- return data, nil
-}
-
-// setProfiles adds or updates the data for all profiles.
-func (db *DefaultProfileDB) setProfiles(profiles []*Profile) (devNum int) {
- db.mapsMu.Lock()
- defer db.mapsMu.Unlock()
-
- for _, p := range profiles {
- devNum += len(p.Devices)
- for _, d := range p.Devices {
- db.deviceToProfile[d.ID] = p
- if d.LinkedIP == nil {
- // Delete any records from the device-to-IP map just in case
- // there used to be one.
- delete(db.deviceIDToIP, d.ID)
-
- continue
- }
-
- newIP := *d.LinkedIP
- if prevIP, ok := db.deviceIDToIP[d.ID]; !ok || prevIP != newIP {
- // The IP has changed. Remove the previous records before
- // setting the new ones.
- delete(db.ipToDeviceID, prevIP)
- delete(db.deviceIDToIP, d.ID)
- }
-
- db.ipToDeviceID[newIP] = d.ID
- db.deviceIDToIP[d.ID] = newIP
- }
- }
-
- return devNum
-}
-
-// type check
-var _ ProfileDB = (*DefaultProfileDB)(nil)
-
-// ProfileByDeviceID implements the ProfileDB interface for *DefaultProfileDB.
-func (db *DefaultProfileDB) ProfileByDeviceID(
- ctx context.Context,
- id DeviceID,
-) (p *Profile, d *Device, err error) {
- db.mapsMu.RLock()
- defer db.mapsMu.RUnlock()
-
- return db.profileByDeviceID(ctx, id)
-}
-
-// profileByDeviceID returns the profile and the device by the ID of the device,
-// if found. Any returned errors will have the underlying type of
-// NotFoundError. It assumes that db is currently locked for reading.
-func (db *DefaultProfileDB) profileByDeviceID(
- _ context.Context,
- id DeviceID,
-) (p *Profile, d *Device, err error) {
- // Do not use errors.Annotate here, because it allocates even when the error
- // is nil. Also do not use fmt.Errorf in a defer, because it allocates when
- // a device is not found, which is the most common case.
- //
- // TODO(a.garipov): Find out, why does it allocate and perhaps file an
- // issue about that in the Go issue tracker.
-
- var ok bool
- p, ok = db.deviceToProfile[id]
- if !ok {
- return nil, nil, ProfileNotFoundError{}
- }
-
- for _, pd := range p.Devices {
- if pd.ID == id {
- d = pd
-
- break
- }
- }
-
- if d == nil {
- // Perhaps, the device has been deleted. May happen when the device was
- // found by a linked IP.
- return nil, nil, fmt.Errorf("rechecking devices: %w", DeviceNotFoundError{})
- }
-
- return p, d, nil
-}
-
-// ProfileByIP implements the ProfileDB interface for *DefaultProfileDB. ip
-// must be valid.
-func (db *DefaultProfileDB) ProfileByIP(
- ctx context.Context,
- ip netip.Addr,
-) (p *Profile, d *Device, err error) {
- // Do not use errors.Annotate here, because it allocates even when the error
- // is nil. Also do not use fmt.Errorf in a defer, because it allocates when
- // a device is not found, which is the most common case.
-
- db.mapsMu.RLock()
- defer db.mapsMu.RUnlock()
-
- id, ok := db.ipToDeviceID[ip]
- if !ok {
- return nil, nil, DeviceNotFoundError{}
- }
-
- p, d, err = db.profileByDeviceID(ctx, id)
- if errors.Is(err, DeviceNotFoundError{}) {
- // Probably, the device has been deleted. Remove it from our profiledb
- // in a goroutine, since that requires a write lock.
- go db.removeDeviceByIP(id, ip)
-
- // Go on and return the error.
- }
-
- if err != nil {
- // Don't add the device ID to the error here, since it is already added
- // by profileByDeviceID.
- return nil, nil, fmt.Errorf("profile by linked device id: %w", err)
- }
-
- return p, d, nil
-}
-
-// removeDeviceByIP removes the device with the given ID and linked IP address
-// from the profiledb. It is intended to be used as a goroutine.
-func (db *DefaultProfileDB) removeDeviceByIP(id DeviceID, ip netip.Addr) {
- defer log.OnPanicAndExit("removeDeviceByIP", 1)
-
- db.mapsMu.Lock()
- defer db.mapsMu.Unlock()
-
- delete(db.ipToDeviceID, ip)
- delete(db.deviceIDToIP, id)
- delete(db.deviceToProfile, id)
-}
-
-// ProfileStorage is a storage of data about profiles and other entities.
-type ProfileStorage interface {
- // Profiles returns all profiles known to this particular data storage. req
- // must not be nil.
- Profiles(ctx context.Context, req *PSProfilesRequest) (resp *PSProfilesResponse, err error)
-}
-
-// PSProfilesRequest is the ProfileStorage.Profiles request.
-type PSProfilesRequest struct {
- SyncTime time.Time
-}
-
-// PSProfilesResponse is the ProfileStorage.Profiles response.
-type PSProfilesResponse struct {
- SyncTime time.Time
- Profiles []*Profile
-}
diff --git a/internal/agd/profiledb_test.go b/internal/agd/profiledb_test.go
deleted file mode 100644
index 37020e7..0000000
--- a/internal/agd/profiledb_test.go
+++ /dev/null
@@ -1,153 +0,0 @@
-package agd_test
-
-import (
- "context"
- "net/netip"
- "path/filepath"
- "testing"
- "time"
-
- "github.com/AdguardTeam/AdGuardDNS/internal/agd"
- "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
- "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-// newDefaultProfileDB returns a new default profile database for tests.
-func newDefaultProfileDB(tb testing.TB, dev *agd.Device) (db *agd.DefaultProfileDB) {
- tb.Helper()
-
- onProfiles := func(
- _ context.Context,
- _ *agd.PSProfilesRequest,
- ) (resp *agd.PSProfilesResponse, err error) {
- return &agd.PSProfilesResponse{
- Profiles: []*agd.Profile{{
- BlockingMode: dnsmsg.BlockingModeCodec{
- Mode: &dnsmsg.BlockingModeNullIP{},
- },
- ID: testProfID,
- Devices: []*agd.Device{dev},
- }},
- }, nil
- }
-
- ds := &agdtest.ProfileStorage{
- OnProfiles: onProfiles,
- }
-
- cacheFilePath := filepath.Join(tb.TempDir(), "profiles.json")
- db, err := agd.NewDefaultProfileDB(ds, 1*time.Minute, cacheFilePath)
- require.NoError(tb, err)
-
- return db
-}
-
-func TestDefaultProfileDB(t *testing.T) {
- dev := &agd.Device{
- ID: testDevID,
- LinkedIP: &testClientIPv4,
- }
-
- db := newDefaultProfileDB(t, dev)
-
- t.Run("by_device_id", func(t *testing.T) {
- ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
- defer cancel()
-
- p, d, err := db.ProfileByDeviceID(ctx, testDevID)
- require.NoError(t, err)
-
- assert.Equal(t, testProfID, p.ID)
- assert.Equal(t, d, dev)
- })
-
- t.Run("by_ip", func(t *testing.T) {
- ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
- defer cancel()
-
- p, d, err := db.ProfileByIP(ctx, testClientIPv4)
- require.NoError(t, err)
-
- assert.Equal(t, testProfID, p.ID)
- assert.Equal(t, d, dev)
- })
-}
-
-var profSink *agd.Profile
-
-var devSink *agd.Device
-
-var errSink error
-
-func BenchmarkDefaultProfileDB_ProfileByDeviceID(b *testing.B) {
- dev := &agd.Device{
- ID: testDevID,
- }
-
- db := newDefaultProfileDB(b, dev)
- ctx := context.Background()
-
- b.Run("success", func(b *testing.B) {
- b.ReportAllocs()
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- profSink, devSink, errSink = db.ProfileByDeviceID(ctx, testDevID)
- }
-
- assert.NotNil(b, profSink)
- assert.NotNil(b, devSink)
- assert.NoError(b, errSink)
- })
-
- const wrongDevID = testDevID + "_bad"
-
- b.Run("not_found", func(b *testing.B) {
- b.ReportAllocs()
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- profSink, devSink, errSink = db.ProfileByDeviceID(ctx, wrongDevID)
- }
-
- assert.Nil(b, profSink)
- assert.Nil(b, devSink)
- assert.ErrorAs(b, errSink, new(agd.NotFoundError))
- })
-}
-
-func BenchmarkDefaultProfileDB_ProfileByIP(b *testing.B) {
- dev := &agd.Device{
- ID: testDevID,
- LinkedIP: &testClientIPv4,
- }
-
- db := newDefaultProfileDB(b, dev)
- ctx := context.Background()
-
- b.Run("success", func(b *testing.B) {
- b.ReportAllocs()
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- profSink, devSink, errSink = db.ProfileByIP(ctx, testClientIPv4)
- }
-
- assert.NotNil(b, profSink)
- assert.NotNil(b, devSink)
- assert.NoError(b, errSink)
- })
-
- wrongClientIP := netip.MustParseAddr("5.6.7.8")
-
- b.Run("not_found", func(b *testing.B) {
- b.ReportAllocs()
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- profSink, devSink, errSink = db.ProfileByIP(ctx, wrongClientIP)
- }
-
- assert.Nil(b, profSink)
- assert.Nil(b, devSink)
- assert.ErrorAs(b, errSink, new(agd.NotFoundError))
- })
-}
diff --git a/internal/agd/server.go b/internal/agd/server.go
index 1211779..c010dab 100644
--- a/internal/agd/server.go
+++ b/internal/agd/server.go
@@ -102,8 +102,6 @@ type Server struct {
// ServerBindData are the socket binding data for a server. Either AddrPort or
// ListenConfig with Address must be set.
//
-// TODO(a.garipov): Add support for ListenConfig in the config file.
-//
// TODO(a.garipov): Consider turning this into a sum type.
//
// TODO(a.garipov): Consider renaming this and the one in websvc to something
diff --git a/internal/agd/upstream.go b/internal/agd/upstream.go
deleted file mode 100644
index 77952be..0000000
--- a/internal/agd/upstream.go
+++ /dev/null
@@ -1,26 +0,0 @@
-package agd
-
-import (
- "net/netip"
- "time"
-
- "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/forward"
-)
-
-// Upstream
-
-// Upstream module configuration.
-type Upstream struct {
- // Server is the upstream server we're using to forward DNS queries.
- Server netip.AddrPort
-
- // Network is the Server network protocol.
- Network forward.Network
-
- // FallbackServers is a list of the DNS servers we're using to fallback to
- // when the upstream server fails to respond.
- FallbackServers []netip.AddrPort
-
- // Timeout is the timeout for all outgoing DNS requests.
- Timeout time.Duration
-}
diff --git a/internal/agdhttp/agdhttp.go b/internal/agdhttp/agdhttp.go
index 53dc6b1..b0a51a8 100644
--- a/internal/agdhttp/agdhttp.go
+++ b/internal/agdhttp/agdhttp.go
@@ -12,20 +12,6 @@ import (
// Common Constants, Functions And Types
-// HTTP header name constants.
-const (
- HdrNameAcceptEncoding = "Accept-Encoding"
- HdrNameAccessControlAllowOrigin = "Access-Control-Allow-Origin"
- HdrNameContentType = "Content-Type"
- HdrNameContentEncoding = "Content-Encoding"
- HdrNameServer = "Server"
- HdrNameTrailer = "Trailer"
- HdrNameUserAgent = "User-Agent"
-
- HdrNameXError = "X-Error"
- HdrNameXRequestID = "X-Request-Id"
-)
-
// HTTP header value constants.
const (
HdrValApplicationJSON = "application/json"
diff --git a/internal/agdhttp/client.go b/internal/agdhttp/client.go
index a155756..3c46d66 100644
--- a/internal/agdhttp/client.go
+++ b/internal/agdhttp/client.go
@@ -9,6 +9,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/golibs/httphdr"
)
// Client is a wrapper around http.Client.
@@ -85,15 +86,15 @@ func (c *Client) do(
}
if contentType != "" {
- req.Header.Set(HdrNameContentType, contentType)
+ req.Header.Set(httphdr.ContentType, contentType)
}
reqID, ok := agd.RequestIDFromContext(ctx)
if ok {
- req.Header.Set(HdrNameXRequestID, string(reqID))
+ req.Header.Set(httphdr.XRequestID, string(reqID))
}
- req.Header.Set(HdrNameUserAgent, c.userAgent)
+ req.Header.Set(httphdr.UserAgent, c.userAgent)
resp, err = c.http.Do(req)
if err != nil && resp != nil && resp.Header != nil {
diff --git a/internal/agdhttp/error.go b/internal/agdhttp/error.go
index e32e57c..ed22a3a 100644
--- a/internal/agdhttp/error.go
+++ b/internal/agdhttp/error.go
@@ -5,6 +5,7 @@ import (
"net/http"
"github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/httphdr"
)
// Common HTTP Errors
@@ -40,7 +41,7 @@ func CheckStatus(resp *http.Response, expected int) (err error) {
}
return &StatusError{
- ServerName: resp.Header.Get(HdrNameServer),
+ ServerName: resp.Header.Get(httphdr.Server),
Expected: expected,
Got: resp.StatusCode,
}
@@ -73,6 +74,6 @@ func (err *ServerError) Unwrap() (unwrapped error) {
func WrapServerError(err error, resp *http.Response) (wrapped *ServerError) {
return &ServerError{
Err: err,
- ServerName: resp.Header.Get(HdrNameServer),
+ ServerName: resp.Header.Get(httphdr.Server),
}
}
diff --git a/internal/agdhttp/error_test.go b/internal/agdhttp/error_test.go
index a88f501..f9666ca 100644
--- a/internal/agdhttp/error_test.go
+++ b/internal/agdhttp/error_test.go
@@ -5,6 +5,7 @@ import (
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
+ "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
)
@@ -41,7 +42,7 @@ func TestCheckStatus(t *testing.T) {
resp := &http.Response{
StatusCode: tc.got,
Header: http.Header{
- agdhttp.HdrNameServer: []string{tc.srv},
+ httphdr.Server: []string{tc.srv},
},
}
err := agdhttp.CheckStatus(resp, tc.exp)
@@ -73,7 +74,7 @@ func TestServerError(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
resp := &http.Response{
Header: http.Header{
- agdhttp.HdrNameServer: []string{tc.srv},
+ httphdr.Server: []string{tc.srv},
},
}
err := agdhttp.WrapServerError(tc.err, resp)
diff --git a/internal/agdnet/agdnet_example_test.go b/internal/agdnet/agdnet_example_test.go
index 2f9260b..aa9dcf2 100644
--- a/internal/agdnet/agdnet_example_test.go
+++ b/internal/agdnet/agdnet_example_test.go
@@ -7,14 +7,18 @@ import (
)
func ExampleAndroidMetricDomainReplacement() {
+ printResult := func(input string) {
+ fmt.Printf("%-42q: %q\n", input, agdnet.AndroidMetricDomainReplacement(input))
+ }
+
anAndroidDomain := "12345678-dnsotls-ds.metric.gstatic.com."
- fmt.Printf("%-42q: %q\n", anAndroidDomain, agdnet.AndroidMetricDomainReplacement(anAndroidDomain))
+ printResult(anAndroidDomain)
anAndroidDomain = "123456-dnsohttps-ds.metric.gstatic.com."
- fmt.Printf("%-42q: %q\n", anAndroidDomain, agdnet.AndroidMetricDomainReplacement(anAndroidDomain))
+ printResult(anAndroidDomain)
notAndroidDomain := "example.com"
- fmt.Printf("%-42q: %q\n", notAndroidDomain, agdnet.AndroidMetricDomainReplacement(notAndroidDomain))
+ printResult(notAndroidDomain)
// Output:
// "12345678-dnsotls-ds.metric.gstatic.com." : "00000000-dnsotls-ds.metric.gstatic.com."
diff --git a/internal/agdtest/agdtest.go b/internal/agdtest/agdtest.go
index 3c8b0e7..672745c 100644
--- a/internal/agdtest/agdtest.go
+++ b/internal/agdtest/agdtest.go
@@ -10,7 +10,11 @@ import (
// FilteredResponseTTL is the common filtering response TTL for tests. It is
// also used by [NewConstructor].
-const FilteredResponseTTL = 10 * time.Second
+const FilteredResponseTTL = FilteredResponseTTLSec * time.Second
+
+// FilteredResponseTTLSec is the common filtering response TTL for tests, as a
+// number to simplify message creation.
+const FilteredResponseTTLSec = 10
// NewConstructor returns a standard dnsmsg.Constructor for tests.
func NewConstructor() (c *dnsmsg.Constructor) {
diff --git a/internal/agdtest/interface.go b/internal/agdtest/interface.go
index 43b0eb2..3036150 100644
--- a/internal/agdtest/interface.go
+++ b/internal/agdtest/interface.go
@@ -11,9 +11,11 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/AdGuardDNS/internal/rulestat"
"github.com/AdguardTeam/golibs/netutil"
@@ -22,7 +24,144 @@ import (
// Interface Mocks
//
-// Keep entities in this file in alphabetic order.
+// Keep entities within a module/package in alphabetic order.
+
+// Module std
+
+// Package net
+//
+// TODO(a.garipov): Move these to golibs?
+
+// type check
+var _ net.Conn = (*Conn)(nil)
+
+// Conn is the [net.Conn] for tests.
+type Conn struct {
+ OnClose func() (err error)
+ OnLocalAddr func() (laddr net.Addr)
+ OnRead func(b []byte) (n int, err error)
+ OnRemoteAddr func() (raddr net.Addr)
+ OnSetDeadline func(t time.Time) (err error)
+ OnSetReadDeadline func(t time.Time) (err error)
+ OnSetWriteDeadline func(t time.Time) (err error)
+ OnWrite func(b []byte) (n int, err error)
+}
+
+// Close implements the [net.Conn] interface for *Conn.
+func (c *Conn) Close() (err error) {
+ return c.OnClose()
+}
+
+// LocalAddr implements the [net.Conn] interface for *Conn.
+func (c *Conn) LocalAddr() (laddr net.Addr) {
+ return c.OnLocalAddr()
+}
+
+// Read implements the [net.Conn] interface for *Conn.
+func (c *Conn) Read(b []byte) (n int, err error) {
+ return c.OnRead(b)
+}
+
+// RemoteAddr implements the [net.Conn] interface for *Conn.
+func (c *Conn) RemoteAddr() (raddr net.Addr) {
+ return c.OnRemoteAddr()
+}
+
+// SetDeadline implements the [net.Conn] interface for *Conn.
+func (c *Conn) SetDeadline(t time.Time) (err error) {
+ return c.OnSetDeadline(t)
+}
+
+// SetReadDeadline implements the [net.Conn] interface for *Conn.
+func (c *Conn) SetReadDeadline(t time.Time) (err error) {
+ return c.OnSetReadDeadline(t)
+}
+
+// SetWriteDeadline implements the [net.Conn] interface for *Conn.
+func (c *Conn) SetWriteDeadline(t time.Time) (err error) {
+ return c.OnSetWriteDeadline(t)
+}
+
+// Write implements the [net.Conn] interface for *Conn.
+func (c *Conn) Write(b []byte) (n int, err error) {
+ return c.OnWrite(b)
+}
+
+// type check
+var _ net.Listener = (*Listener)(nil)
+
+// Listener is a [net.Listener] for tests.
+type Listener struct {
+ OnAccept func() (c net.Conn, err error)
+ OnAddr func() (addr net.Addr)
+ OnClose func() (err error)
+}
+
+// Accept implements the [net.Listener] interface for *Listener.
+func (l *Listener) Accept() (c net.Conn, err error) {
+ return l.OnAccept()
+}
+
+// Addr implements the [net.Listener] interface for *Listener.
+func (l *Listener) Addr() (addr net.Addr) {
+ return l.OnAddr()
+}
+
+// Close implements the [net.Listener] interface for *Listener.
+func (l *Listener) Close() (err error) {
+ return l.OnClose()
+}
+
+// type check
+var _ net.PacketConn = (*PacketConn)(nil)
+
+// PacketConn is the [net.PacketConn] for tests.
+type PacketConn struct {
+ OnClose func() (err error)
+ OnLocalAddr func() (laddr net.Addr)
+ OnReadFrom func(b []byte) (n int, addr net.Addr, err error)
+ OnSetDeadline func(t time.Time) (err error)
+ OnSetReadDeadline func(t time.Time) (err error)
+ OnSetWriteDeadline func(t time.Time) (err error)
+ OnWriteTo func(b []byte, addr net.Addr) (n int, err error)
+}
+
+// Close implements the [net.PacketConn] interface for *PacketConn.
+func (c *PacketConn) Close() (err error) {
+ return c.OnClose()
+}
+
+// LocalAddr implements the [net.PacketConn] interface for *PacketConn.
+func (c *PacketConn) LocalAddr() (laddr net.Addr) {
+ return c.OnLocalAddr()
+}
+
+// ReadFrom implements the [net.PacketConn] interface for *PacketConn.
+func (c *PacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
+ return c.OnReadFrom(b)
+}
+
+// SetDeadline implements the [net.PacketConn] interface for *PacketConn.
+func (c *PacketConn) SetDeadline(t time.Time) (err error) {
+ return c.OnSetDeadline(t)
+}
+
+// SetReadDeadline implements the [net.PacketConn] interface for *PacketConn.
+func (c *PacketConn) SetReadDeadline(t time.Time) (err error) {
+ return c.OnSetReadDeadline(t)
+}
+
+// SetWriteDeadline implements the [net.PacketConn] interface for *PacketConn.
+func (c *PacketConn) SetWriteDeadline(t time.Time) (err error) {
+ return c.OnSetWriteDeadline(t)
+}
+
+// WriteTo implements the [net.PacketConn] interface for *PacketConn.
+func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
+ return c.OnWriteTo(b, addr)
+}
+
+// Module AdGuardDNS
// type check
var _ agd.ErrorCollector = (*ErrorCollector)(nil)
@@ -39,56 +178,6 @@ func (c *ErrorCollector) Collect(ctx context.Context, err error) {
c.OnCollect(ctx, err)
}
-// type check
-var _ agd.ProfileDB = (*ProfileDB)(nil)
-
-// ProfileDB is an agd.ProfileDB for tests.
-type ProfileDB struct {
- OnProfileByDeviceID func(
- ctx context.Context,
- id agd.DeviceID,
- ) (p *agd.Profile, d *agd.Device, err error)
- OnProfileByIP func(
- ctx context.Context,
- ip netip.Addr,
- ) (p *agd.Profile, d *agd.Device, err error)
-}
-
-// ProfileByDeviceID implements the agd.ProfileDB interface for *ProfileDB.
-func (db *ProfileDB) ProfileByDeviceID(
- ctx context.Context,
- id agd.DeviceID,
-) (p *agd.Profile, d *agd.Device, err error) {
- return db.OnProfileByDeviceID(ctx, id)
-}
-
-// ProfileByIP implements the agd.ProfileDB interface for *ProfileDB.
-func (db *ProfileDB) ProfileByIP(
- ctx context.Context,
- ip netip.Addr,
-) (p *agd.Profile, d *agd.Device, err error) {
- return db.OnProfileByIP(ctx, ip)
-}
-
-// type check
-var _ agd.ProfileStorage = (*ProfileStorage)(nil)
-
-// ProfileStorage is a agd.ProfileStorage for tests.
-type ProfileStorage struct {
- OnProfiles func(
- ctx context.Context,
- req *agd.PSProfilesRequest,
- ) (resp *agd.PSProfilesResponse, err error)
-}
-
-// Profiles implements the agd.ProfileStorage interface for *ProfileStorage.
-func (ds *ProfileStorage) Profiles(
- ctx context.Context,
- req *agd.PSProfilesRequest,
-) (resp *agd.PSProfilesResponse, err error) {
- return ds.OnProfiles(ctx, req)
-}
-
// type check
var _ agd.Refresher = (*Refresher)(nil)
@@ -206,7 +295,7 @@ func (db *DNSDB) Record(ctx context.Context, resp *dns.Msg, ri *agd.RequestInfo)
// type check
var _ filter.Interface = (*Filter)(nil)
-// Filter is a filter.Interface for tests.
+// Filter is a [filter.Interface] for tests.
type Filter struct {
OnFilterRequest func(
ctx context.Context,
@@ -218,10 +307,9 @@ type Filter struct {
resp *dns.Msg,
ri *agd.RequestInfo,
) (r filter.Result, err error)
- OnClose func() (err error)
}
-// FilterRequest implements the filter.Interface interface for *Filter.
+// FilterRequest implements the [filter.Interface] interface for *Filter.
func (f *Filter) FilterRequest(
ctx context.Context,
req *dns.Msg,
@@ -230,7 +318,7 @@ func (f *Filter) FilterRequest(
return f.OnFilterRequest(ctx, req, ri)
}
-// FilterResponse implements the filter.Interface interface for *Filter.
+// FilterResponse implements the [filter.Interface] interface for *Filter.
func (f *Filter) FilterResponse(
ctx context.Context,
resp *dns.Msg,
@@ -239,21 +327,36 @@ func (f *Filter) FilterResponse(
return f.OnFilterResponse(ctx, resp, ri)
}
-// Close implements the filter.Interface interface for *Filter.
-func (f *Filter) Close() (err error) {
- return f.OnClose()
+// type check
+var _ filter.HashMatcher = (*HashMatcher)(nil)
+
+// HashMatcher is a [filter.HashMatcher] for tests.
+type HashMatcher struct {
+ OnMatchByPrefix func(
+ ctx context.Context,
+ host string,
+ ) (hashes []string, matched bool, err error)
+}
+
+// MatchByPrefix implements the [filter.HashMatcher] interface for *HashMatcher.
+func (m *HashMatcher) MatchByPrefix(
+ ctx context.Context,
+ host string,
+) (hashes []string, matched bool, err error) {
+ return m.OnMatchByPrefix(ctx, host)
}
// type check
var _ filter.Storage = (*FilterStorage)(nil)
-// FilterStorage is an filter.Storage for tests.
+// FilterStorage is a [filter.Storage] for tests.
type FilterStorage struct {
OnFilterFromContext func(ctx context.Context, ri *agd.RequestInfo) (f filter.Interface)
OnHasListID func(id agd.FilterListID) (ok bool)
}
-// FilterFromContext implements the filter.Storage interface for *FilterStorage.
+// FilterFromContext implements the [filter.Storage] interface for
+// *FilterStorage.
func (s *FilterStorage) FilterFromContext(
ctx context.Context,
ri *agd.RequestInfo,
@@ -261,7 +364,7 @@ func (s *FilterStorage) FilterFromContext(
return s.OnFilterFromContext(ctx, ri)
}
-// HasListID implements the filter.Storage interface for *FilterStorage.
+// HasListID implements the [filter.Storage] interface for *FilterStorage.
func (s *FilterStorage) HasListID(id agd.FilterListID) (ok bool) {
return s.OnHasListID(id)
}
@@ -295,6 +398,73 @@ func (g *GeoIP) Data(host string, ip netip.Addr) (l *agd.Location, err error) {
return g.OnData(host, ip)
}
+// Package profiledb
+
+// type check
+var _ profiledb.Interface = (*ProfileDB)(nil)
+
+// ProfileDB is a [profiledb.Interface] for tests.
+type ProfileDB struct {
+ OnProfileByDeviceID func(
+ ctx context.Context,
+ id agd.DeviceID,
+ ) (p *agd.Profile, d *agd.Device, err error)
+ OnProfileByDedicatedIP func(
+ ctx context.Context,
+ ip netip.Addr,
+ ) (p *agd.Profile, d *agd.Device, err error)
+ OnProfileByLinkedIP func(
+ ctx context.Context,
+ ip netip.Addr,
+ ) (p *agd.Profile, d *agd.Device, err error)
+}
+
+// ProfileByDeviceID implements the [profiledb.Interface] interface for
+// *ProfileDB.
+func (db *ProfileDB) ProfileByDeviceID(
+ ctx context.Context,
+ id agd.DeviceID,
+) (p *agd.Profile, d *agd.Device, err error) {
+ return db.OnProfileByDeviceID(ctx, id)
+}
+
+// ProfileByDedicatedIP implements the [profiledb.Interface] interface for
+// *ProfileDB.
+func (db *ProfileDB) ProfileByDedicatedIP(
+ ctx context.Context,
+ ip netip.Addr,
+) (p *agd.Profile, d *agd.Device, err error) {
+ return db.OnProfileByDedicatedIP(ctx, ip)
+}
+
+// ProfileByLinkedIP implements the [profiledb.Interface] interface for
+// *ProfileDB.
+func (db *ProfileDB) ProfileByLinkedIP(
+ ctx context.Context,
+ ip netip.Addr,
+) (p *agd.Profile, d *agd.Device, err error) {
+ return db.OnProfileByLinkedIP(ctx, ip)
+}
+
+// type check
+var _ profiledb.Storage = (*ProfileStorage)(nil)
+
+// ProfileStorage is a [profiledb.Storage] for tests.
+type ProfileStorage struct {
+ OnProfiles func(
+ ctx context.Context,
+ req *profiledb.StorageRequest,
+ ) (resp *profiledb.StorageResponse, err error)
+}
+
+// Profiles implements the [profiledb.Storage] interface for *ProfileStorage.
+func (s *ProfileStorage) Profiles(
+ ctx context.Context,
+ req *profiledb.StorageRequest,
+) (resp *profiledb.StorageResponse, err error) {
+ return s.OnProfiles(ctx, req)
+}
+
// Package querylog
// type check
@@ -327,6 +497,39 @@ func (s *RuleStat) Collect(ctx context.Context, id agd.FilterListID, text agd.Fi
// Module dnsserver
+// Package netext
+
+var _ netext.ListenConfig = (*ListenConfig)(nil)
+
+// ListenConfig is a [netext.ListenConfig] for tests.
+type ListenConfig struct {
+ OnListen func(ctx context.Context, network, address string) (l net.Listener, err error)
+ OnListenPacket func(
+ ctx context.Context,
+ network string,
+ address string,
+ ) (conn net.PacketConn, err error)
+}
+
+// Listen implements the [netext.ListenConfig] interface for *ListenConfig.
+func (c *ListenConfig) Listen(
+ ctx context.Context,
+ network string,
+ address string,
+) (l net.Listener, err error) {
+ return c.OnListen(ctx, network, address)
+}
+
+// ListenPacket implements the [netext.ListenConfig] interface for
+// *ListenConfig.
+func (c *ListenConfig) ListenPacket(
+ ctx context.Context,
+ network string,
+ address string,
+) (conn net.PacketConn, err error) {
+ return c.OnListenPacket(ctx, network, address)
+}
+
// Package ratelimit
// type check
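
The replacement `ProfileDB` mock follows the same `On*` callback pattern as the other mocks in this file, with the old `ProfileByIP` split into `ProfileByLinkedIP` and `ProfileByDedicatedIP`. A hedged sketch of how a test might wire it up; the test name and the placeholder profile and device IDs are illustrative, not taken from the repository:

```go
package agdtest_test

import (
	"context"
	"net/netip"
	"testing"

	"github.com/AdguardTeam/AdGuardDNS/internal/agd"
	"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestProfileDBMock sketches the wiring of the mock: every On* callback that
// the code under test may call must be set.
func TestProfileDBMock(t *testing.T) {
	wantProf := &agd.Profile{ID: "prof1234"}
	wantDev := &agd.Device{ID: "dev1234"}

	lookup := func(
		_ context.Context,
		_ netip.Addr,
	) (p *agd.Profile, d *agd.Device, err error) {
		return wantProf, wantDev, nil
	}

	db := &agdtest.ProfileDB{
		OnProfileByDeviceID: func(
			_ context.Context,
			_ agd.DeviceID,
		) (p *agd.Profile, d *agd.Device, err error) {
			return wantProf, wantDev, nil
		},
		OnProfileByDedicatedIP: lookup,
		OnProfileByLinkedIP:    lookup,
	}

	p, d, err := db.ProfileByLinkedIP(context.Background(), netip.MustParseAddr("192.0.2.1"))
	require.NoError(t, err)

	assert.Equal(t, wantProf, p)
	assert.Equal(t, wantDev, d)
}
```
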
diff --git a/internal/agdtime/agdtime.go b/internal/agdtime/agdtime.go
new file mode 100644
index 0000000..c0da4c8
--- /dev/null
+++ b/internal/agdtime/agdtime.go
@@ -0,0 +1,63 @@
+// Package agdtime contains time-related utilities.
+package agdtime
+
+import (
+ "encoding"
+ "time"
+
+ "github.com/AdguardTeam/golibs/errors"
+)
+
+// Location is a wrapper around [time.Location] that can de/serialize itself
+// from and to JSON.
+//
+// TODO(a.garipov): Move to timeutil.
+type Location struct {
+ time.Location
+}
+
+// LoadLocation is a wrapper around [time.LoadLocation] that returns a
+// *Location instead.
+func LoadLocation(name string) (l *Location, err error) {
+ tl, err := time.LoadLocation(name)
+ if err != nil {
+ // Don't wrap the error, because this function is a wrapper.
+ return nil, err
+ }
+
+ return &Location{
+ Location: *tl,
+ }, nil
+}
+
+// UTC returns [time.UTC] as *Location.
+func UTC() (l *Location) {
+ return &Location{
+ Location: *time.UTC,
+ }
+}
+
+// type check
+var _ encoding.TextMarshaler = Location{}
+
+// MarshalText implements the [encoding.TextMarshaler] interface for Location.
+func (l Location) MarshalText() (text []byte, err error) {
+ return []byte(l.String()), nil
+}
+
+var _ encoding.TextUnmarshaler = (*Location)(nil)
+
+// UnmarshalText implements the [encoding.TextUnmarshaler] interface for
+// *Location.
+func (l *Location) UnmarshalText(b []byte) (err error) {
+ defer func() { err = errors.Annotate(err, "unmarshaling location: %w") }()
+
+ tl, err := time.LoadLocation(string(b))
+ if err != nil {
+ return err
+ }
+
+ l.Location = *tl
+
+ return nil
+}
diff --git a/internal/agdtime/agdtime_example_test.go b/internal/agdtime/agdtime_example_test.go
new file mode 100644
index 0000000..d160558
--- /dev/null
+++ b/internal/agdtime/agdtime_example_test.go
@@ -0,0 +1,41 @@
+package agdtime_test
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdtime"
+)
+
+func ExampleLocation() {
+ var req struct {
+ TimeZone *agdtime.Location `json:"tmz"`
+ }
+
+ l, err := agdtime.LoadLocation("Europe/Brussels")
+ if err != nil {
+ panic(err)
+ }
+
+ req.TimeZone = l
+ buf := &bytes.Buffer{}
+ err = json.NewEncoder(buf).Encode(req)
+ if err != nil {
+ panic(err)
+ }
+
+ fmt.Print(buf)
+
+ req.TimeZone = nil
+ err = json.NewDecoder(buf).Decode(&req)
+ if err != nil {
+ panic(err)
+ }
+
+ fmt.Printf("%+v\n", req)
+
+ // Output:
+ // {"tmz":"Europe/Brussels"}
+ // {TimeZone:Europe/Brussels}
+}
diff --git a/internal/backend/profiledb.go b/internal/backend/profiledb.go
index e7b7aba..a61544e 100644
--- a/internal/backend/profiledb.go
+++ b/internal/backend/profiledb.go
@@ -12,10 +12,13 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdtime"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil"
+ "golang.org/x/exp/slices"
)
// Profile Storage
@@ -47,7 +50,7 @@ func NewProfileStorage(c *ProfileStorageConfig) (s *ProfileStorage) {
}
}
-// ProfileStorage is the implementation of the [agd.ProfileStorage] interface
+// ProfileStorage is the implementation of the [profiledb.Storage] interface
// that retrieves the profile and device information from the business logic
// backend. It is safe for concurrent use.
//
@@ -61,13 +64,13 @@ type ProfileStorage struct {
}
// type check
-var _ agd.ProfileStorage = (*ProfileStorage)(nil)
+var _ profiledb.Storage = (*ProfileStorage)(nil)
-// Profiles implements the [agd.ProfileStorage] interface for *ProfileStorage.
+// Profiles implements the [profiledb.Storage] interface for *ProfileStorage.
func (s *ProfileStorage) Profiles(
ctx context.Context,
- req *agd.PSProfilesRequest,
-) (resp *agd.PSProfilesResponse, err error) {
+ req *profiledb.StorageRequest,
+) (resp *profiledb.StorageResponse, err error) {
q := url.Values{}
if !req.SyncTime.IsZero() {
syncTimeStr := strconv.FormatInt(req.SyncTime.UnixMilli(), 10)
@@ -127,7 +130,12 @@ type v1SettingsRespSchedule struct {
Friday *[2]timeutil.Duration `json:"fri"`
Saturday *[2]timeutil.Duration `json:"sat"`
Sunday *[2]timeutil.Duration `json:"sun"`
- TimeZone string `json:"tmz"`
+
+ // TimeZone is the tzdata name of the time zone.
+ //
+ // NOTE: Do not use *agdtime.Location here so that lookup failures are
+ // properly mitigated in [v1SettingsRespParental.toInternal].
+ TimeZone string `json:"tmz"`
}
// v1SettingsRespParental is the structure for decoding the settings.*.parental
@@ -146,10 +154,11 @@ type v1SettingsRespParental struct {
// v1SettingsRespDevice is the structure for decoding the settings.devices
// property of the response from the backend.
type v1SettingsRespDevice struct {
- LinkedIP *netip.Addr `json:"linked_ip"`
- ID string `json:"id"`
- Name string `json:"name"`
- FilteringEnabled bool `json:"filtering_enabled"`
+ LinkedIP netip.Addr `json:"linked_ip"`
+ ID string `json:"id"`
+ Name string `json:"name"`
+ DedicatedIPs []netip.Addr `json:"dedicated_ips"`
+ FilteringEnabled bool `json:"filtering_enabled"`
}
// v1SettingsRespSettings is the structure for decoding the settings property of
@@ -232,12 +241,12 @@ func (p *v1SettingsRespParental) toInternal(
sch = &agd.ParentalProtectionSchedule{}
// TODO(a.garipov): Cache location lookup results.
- sch.TimeZone, err = time.LoadLocation(psch.TimeZone)
+ sch.TimeZone, err = agdtime.LoadLocation(psch.TimeZone)
if err != nil {
// Report the error and assume UTC.
reportf(ctx, errColl, "settings at index %d: schedule: time zone: %w", settIdx, err)
- sch.TimeZone = time.UTC
+ sch.TimeZone = agdtime.UTC()
}
sch.Week = &agd.WeeklySchedule{}
@@ -330,10 +339,10 @@ func devicesToInternal(
errColl agd.ErrorCollector,
settIdx int,
respDevices []*v1SettingsRespDevice,
-) (devices []*agd.Device) {
+) (devices []*agd.Device, ids []agd.DeviceID) {
l := len(respDevices)
if l == 0 {
- return nil
+ return nil, nil
}
devices = make([]*agd.Device, 0, l)
@@ -344,8 +353,11 @@ func devicesToInternal(
continue
}
+ // TODO(a.garipov): Consider validating uniqueness of linked and
+ // dedicated IPs.
dev := &agd.Device{
LinkedIP: d.LinkedIP,
+ DedicatedIPs: slices.Clone(d.DedicatedIPs),
FilteringEnabled: d.FilteringEnabled,
}
@@ -368,10 +380,11 @@ func devicesToInternal(
continue
}
+ ids = append(ids, dev.ID)
devices = append(devices, dev)
}
- return devices
+ return devices, ids
}
// filterListsToInternal is a helper that converts the filter lists from the
@@ -458,12 +471,12 @@ func (r *v1SettingsResp) toInternal(
// TODO(a.garipov): Here and in other functions, consider just adding the
// error collector to the context.
errColl agd.ErrorCollector,
-) (pr *agd.PSProfilesResponse) {
+) (pr *profiledb.StorageResponse) {
if r == nil {
return nil
}
- pr = &agd.PSProfilesResponse{
+ pr = &profiledb.StorageResponse{
SyncTime: time.Unix(0, r.SyncTime*1_000_000),
Profiles: make([]*agd.Profile, 0, len(r.Settings)),
}
@@ -476,7 +489,7 @@ func (r *v1SettingsResp) toInternal(
continue
}
- devices := devicesToInternal(ctx, errColl, i, s.Devices)
+ devices, deviceIDs := devicesToInternal(ctx, errColl, i, s.Devices)
rlEnabled, ruleLists := filterListsToInternal(ctx, errColl, i, s.RuleLists)
rules := rulesToInternal(ctx, errColl, i, s.CustomRules)
@@ -499,12 +512,14 @@ func (r *v1SettingsResp) toInternal(
sbEnabled := s.SafeBrowsing != nil && s.SafeBrowsing.Enabled
+ pr.Devices = append(pr.Devices, devices...)
+
pr.Profiles = append(pr.Profiles, &agd.Profile{
Parental: parental,
BlockingMode: s.BlockingMode,
ID: id,
UpdateTime: updTime,
- Devices: devices,
+ DeviceIDs: deviceIDs,
RuleListIDs: ruleLists,
CustomRules: rules,
FilteredResponseTTL: fltRespTTL,
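The hunks above change the shape of the storage response: devices are no longer nested inside each agd.Profile but are returned once in the top-level Devices slice, while profiles keep only DeviceIDs. A hedged sketch of how a consumer could join them back together follows; the package name, the helper name, and the agd.ProfileID map key type are assumptions made only for this example.

```go
package example

import (
	"github.com/AdguardTeam/AdGuardDNS/internal/agd"
	"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
)

// resolveDevices shows how a profiledb.StorageResponse consumer could map each
// profile back to its devices after the split into Profiles and Devices.  It
// is an illustration only and not part of this change.
func resolveDevices(resp *profiledb.StorageResponse) (byProf map[agd.ProfileID][]*agd.Device) {
	// Index all devices by their IDs once.
	devByID := make(map[agd.DeviceID]*agd.Device, len(resp.Devices))
	for _, d := range resp.Devices {
		devByID[d.ID] = d
	}

	byProf = make(map[agd.ProfileID][]*agd.Device, len(resp.Profiles))
	for _, p := range resp.Profiles {
		devs := make([]*agd.Device, 0, len(p.DeviceIDs))
		for _, id := range p.DeviceIDs {
			if d, ok := devByID[id]; ok {
				devs = append(devs, d)
			}
		}

		byProf[p.ID] = devs
	}

	return byProf
}
```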
diff --git a/internal/backend/profiledb_test.go b/internal/backend/profiledb_test.go
index 521d234..6abcb22 100644
--- a/internal/backend/profiledb_test.go
+++ b/internal/backend/profiledb_test.go
@@ -13,8 +13,10 @@ import (
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdtime"
"github.com/AdguardTeam/AdGuardDNS/internal/backend"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -61,7 +63,7 @@ func TestProfileStorage_Profiles(t *testing.T) {
require.NotNil(t, ds)
ctx := context.Background()
- req := &agd.PSProfilesRequest{
+ req := &profiledb.StorageRequest{
SyncTime: syncTime,
}
@@ -81,10 +83,10 @@ func TestProfileStorage_Profiles(t *testing.T) {
// testProfileResp returns profile resp corresponding with testdata.
//
// Keep in sync with the testdata one.
-func testProfileResp(t *testing.T) *agd.PSProfilesResponse {
+func testProfileResp(t *testing.T) (resp *profiledb.StorageResponse) {
t.Helper()
- wantLoc, err := time.LoadLocation("GMT")
+ wantLoc, err := agdtime.LoadLocation("GMT")
require.NoError(t, err)
dayRange := agd.DayRange{
@@ -121,7 +123,7 @@ func testProfileResp(t *testing.T) *agd.PSProfilesResponse {
},
}
- want := &agd.PSProfilesResponse{
+ want := &profiledb.StorageResponse{
SyncTime: syncTime,
Profiles: []*agd.Profile{{
Parental: nil,
@@ -130,15 +132,10 @@ func testProfileResp(t *testing.T) *agd.PSProfilesResponse {
},
ID: "37f97ee9",
UpdateTime: updTime,
- Devices: []*agd.Device{{
- ID: "118ffe93",
- Name: "Device 1",
- FilteringEnabled: true,
- }, {
- ID: "b9e1a762",
- Name: "Device 2",
- FilteringEnabled: true,
- }},
+ DeviceIDs: []agd.DeviceID{
+ "118ffe93",
+ "b9e1a762",
+ },
RuleListIDs: []agd.FilterListID{"1"},
CustomRules: nil,
FilteredResponseTTL: 10 * time.Second,
@@ -154,24 +151,12 @@ func testProfileResp(t *testing.T) *agd.PSProfilesResponse {
BlockingMode: wantBlockingMode,
ID: "83f3ea8f",
UpdateTime: updTime,
- Devices: []*agd.Device{{
- ID: "0d7724fa",
- Name: "Device 1",
- FilteringEnabled: true,
- }, {
- ID: "6d2ac775",
- Name: "Device 2",
- FilteringEnabled: true,
- }, {
- ID: "94d4c481",
- Name: "Device 3",
- FilteringEnabled: true,
- }, {
- ID: "ada436e3",
- LinkedIP: &wantLinkedIP,
- Name: "Device 4",
- FilteringEnabled: true,
- }},
+ DeviceIDs: []agd.DeviceID{
+ "0d7724fa",
+ "6d2ac775",
+ "94d4c481",
+ "ada436e3",
+ },
RuleListIDs: []agd.FilterListID{"1"},
CustomRules: []agd.FilterRuleText{"||example.org^"},
FilteredResponseTTL: 3600 * time.Second,
@@ -183,6 +168,35 @@ func testProfileResp(t *testing.T) *agd.PSProfilesResponse {
BlockPrivateRelay: false,
BlockFirefoxCanary: false,
}},
+ Devices: []*agd.Device{{
+ ID: "118ffe93",
+ Name: "Device 1",
+ FilteringEnabled: true,
+ }, {
+ ID: "b9e1a762",
+ Name: "Device 2",
+ FilteringEnabled: true,
+ }, {
+ ID: "0d7724fa",
+ Name: "Device 1",
+ FilteringEnabled: true,
+ }, {
+ ID: "6d2ac775",
+ Name: "Device 2",
+ FilteringEnabled: true,
+ }, {
+ ID: "94d4c481",
+ Name: "Device 3",
+ DedicatedIPs: []netip.Addr{
+ netip.MustParseAddr("1.2.3.4"),
+ },
+ FilteringEnabled: true,
+ }, {
+ ID: "ada436e3",
+ LinkedIP: wantLinkedIP,
+ Name: "Device 4",
+ FilteringEnabled: true,
+ }},
}
return want
diff --git a/internal/backend/testdata/profiles.json b/internal/backend/testdata/profiles.json
index 6a937f8..26f3dee 100644
--- a/internal/backend/testdata/profiles.json
+++ b/internal/backend/testdata/profiles.json
@@ -50,18 +50,23 @@
{
"id": "6d2ac775",
"name": "Device 2",
+ "linked_ip": null,
"filtering_enabled": true
},
{
"id": "94d4c481",
"name": "Device 3",
+ "linked_ip": "",
+ "dedicated_ips": [
+ "1.2.3.4"
+ ],
"filtering_enabled": true
},
{
"id": "ada436e3",
"name": "Device 4",
- "filtering_enabled": true,
- "linked_ip": "1.2.3.4"
+ "linked_ip": "1.2.3.4",
+ "filtering_enabled": true
}
],
"parental": {
diff --git a/internal/bindtodevice/bindtodevice.go b/internal/bindtodevice/bindtodevice.go
index d1aaa5d..41c19c2 100644
--- a/internal/bindtodevice/bindtodevice.go
+++ b/internal/bindtodevice/bindtodevice.go
@@ -1,40 +1,6 @@
// Package bindtodevice contains an implementation of the [netext.ListenConfig]
// interface that uses Linux's SO_BINDTODEVICE socket option to be able to bind
// to a device.
-//
-// TODO(a.garipov): Finish the package. The current plan is to eventually have
-// something like this:
-//
-// mgr, err := bindtodevice.New()
-// err := mgr.Add("wlp3s0_plain_dns", "wlp3s0", 53)
-// subnet := netip.MustParsePrefix("1.2.3.0/24")
-// lc, err := mgr.ListenConfig("wlp3s0_plain_dns", subnet)
-// err := mgr.Start()
-//
-// Approximate YAML configuration example:
-//
-// 'interface_listeners':
-// # Put listeners into a list so that there is space for future additional
-// # settings, such as timeouts and buffer sizes.
-// 'list':
-// 'iface0_plain_dns':
-// 'interface': 'iface0'
-// 'port': 53
-// 'iface0_plain_dns_secondary':
-// 'interface': 'iface0'
-// 'port': 5353
-// # …
-// # …
-// 'server_groups':
-// # …
-// 'servers':
-// - 'name': 'default_dns'
-// # …
-// bind_interfaces:
-// - 'id': 'iface0_plain_dns'
-// 'subnet': '1.2.3.0/24'
-// - 'id': 'iface0_plain_dns_secondary'
-// 'subnet': '1.2.3.0/24'
package bindtodevice
import (
diff --git a/internal/bindtodevice/bindtodevice_internal_test.go b/internal/bindtodevice/bindtodevice_internal_test.go
index 1925747..d79c0b3 100644
--- a/internal/bindtodevice/bindtodevice_internal_test.go
+++ b/internal/bindtodevice/bindtodevice_internal_test.go
@@ -2,9 +2,13 @@ package bindtodevice
import (
"net"
+ "net/netip"
"time"
)
+// Common timeout for tests
+const testTimeout = 1 * time.Second
+
// Common addresses for tests.
var (
testLAddr = &net.UDPAddr{
@@ -17,5 +21,8 @@ var (
}
)
-// Common timeout for tests
-const testTimeout = 1 * time.Second
+// Common subnets for tests.
+var (
+ testSubnetIPv4 = netip.MustParsePrefix("1.2.3.0/24")
+ testSubnetIPv6 = netip.MustParsePrefix("1234:5678::/64")
+)
diff --git a/internal/bindtodevice/bindtodevice_test.go b/internal/bindtodevice/bindtodevice_test.go
index 7f70d8c..f138b03 100644
--- a/internal/bindtodevice/bindtodevice_test.go
+++ b/internal/bindtodevice/bindtodevice_test.go
@@ -1,6 +1,16 @@
package bindtodevice_test
-import "github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice"
+import (
+ "net/netip"
+ "testing"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice"
+ "github.com/AdguardTeam/golibs/testutil"
+)
+
+func TestMain(m *testing.M) {
+ testutil.DiscardLogOutput(m)
+}
// Common interface listener IDs for tests
const (
@@ -18,3 +28,6 @@ const (
// testIfaceName is the common network interface name for tests.
const testIfaceName = "not_a_real_iface0"
+
+// testSubnetIPv4 is a common subnet for tests.
+var testSubnetIPv4 = netip.MustParsePrefix("1.2.3.0/24")
diff --git a/internal/bindtodevice/chanlistenconfig_linux_internal_test.go b/internal/bindtodevice/chanlistenconfig_linux_internal_test.go
index 3fb0639..c00a924 100644
--- a/internal/bindtodevice/chanlistenconfig_linux_internal_test.go
+++ b/internal/bindtodevice/chanlistenconfig_linux_internal_test.go
@@ -11,8 +11,8 @@ import (
)
func TestChanListenConfig(t *testing.T) {
- pc := newChanPacketConn(nil, nil, testLAddr)
- lsnr := newChanListener(nil, testLAddr)
+ pc := newChanPacketConn(nil, testSubnetIPv4, nil, testLAddr)
+ lsnr := newChanListener(nil, testSubnetIPv4, testLAddr)
c := chanListenConfig{
packetConn: pc,
listener: lsnr,
diff --git a/internal/bindtodevice/chanlistener_linux.go b/internal/bindtodevice/chanlistener_linux.go
index 433e63b..503dedc 100644
--- a/internal/bindtodevice/chanlistener_linux.go
+++ b/internal/bindtodevice/chanlistener_linux.go
@@ -4,6 +4,7 @@ package bindtodevice
import (
"net"
+ "net/netip"
"sync"
)
@@ -13,17 +14,22 @@ import (
// Listeners of this type are returned by [chanListenConfig.Listen] and are used
// in module dnsserver to make the bind-to-device logic work in DNS-over-TCP.
type chanListener struct {
- closeOnce *sync.Once
- conns chan net.Conn
- laddr net.Addr
+ // mu protects conns (against closure) and isClosed.
+ mu *sync.Mutex
+ conns chan net.Conn
+ laddr net.Addr
+ subnet netip.Prefix
+ isClosed bool
}
// newChanListener returns a new properly initialized *chanListener.
-func newChanListener(conns chan net.Conn, laddr net.Addr) (l *chanListener) {
+func newChanListener(conns chan net.Conn, subnet netip.Prefix, laddr net.Addr) (l *chanListener) {
return &chanListener{
- closeOnce: &sync.Once{},
- conns: conns,
- laddr: laddr,
+ mu: &sync.Mutex{},
+ conns: conns,
+ laddr: laddr,
+ subnet: subnet,
+ isClosed: false,
}
}
@@ -46,15 +52,30 @@ func (l *chanListener) Addr() (addr net.Addr) { return l.laddr }
// Close implements the [net.Listener] interface for *chanListener.
func (l *chanListener) Close() (err error) {
- closedNow := false
- l.closeOnce.Do(func() {
- close(l.conns)
- closedNow = true
- })
+ l.mu.Lock()
+ defer l.mu.Unlock()
- if !closedNow {
+ if l.isClosed {
return wrapConnError(tnChanLsnr, "Close", l.laddr, net.ErrClosed)
}
+ close(l.conns)
+ l.isClosed = true
+
return nil
}
+
+// send is a helper method to send a conn to the listener's channel. ok is
+// false if the listener is closed.
+func (l *chanListener) send(conn net.Conn) (ok bool) {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ if l.isClosed {
+ return false
+ }
+
+ l.conns <- conn
+
+ return true
+}
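The chanListener rework above drops sync.Once in favor of a mutex and an isClosed flag, so that send can refuse to write into a channel that Close has already closed instead of panicking. The stripped-down, standalone sketch below shows the same close/send pattern with generic names; it is a simplification for illustration, not the production type.

```go
package main

import (
	"fmt"
	"net"
	"sync"
)

// closableQueue demonstrates the pattern used by chanListener and
// chanPacketConn: one mutex guards both the closed flag and the channel
// closure, so a concurrent send can never hit an already-closed channel.
type closableQueue struct {
	mu       *sync.Mutex
	conns    chan net.Conn
	isClosed bool
}

// send reports false instead of panicking when the queue is already closed.
func (q *closableQueue) send(conn net.Conn) (ok bool) {
	q.mu.Lock()
	defer q.mu.Unlock()

	if q.isClosed {
		return false
	}

	q.conns <- conn

	return true
}

// close closes the underlying channel exactly once.
func (q *closableQueue) close() (err error) {
	q.mu.Lock()
	defer q.mu.Unlock()

	if q.isClosed {
		return net.ErrClosed
	}

	close(q.conns)
	q.isClosed = true

	return nil
}

func main() {
	q := &closableQueue{
		mu:    &sync.Mutex{},
		conns: make(chan net.Conn, 1),
	}

	// Prints "true <nil> false": the second send fails cleanly after close.
	fmt.Println(q.send(nil), q.close(), q.send(nil))
}
```

As in the real code, send holds the mutex while writing to the channel, so it relies on the channel being buffered or actively drained by a reader.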
diff --git a/internal/bindtodevice/chanlistener_linux_internal_test.go b/internal/bindtodevice/chanlistener_linux_internal_test.go
index da85c27..7ed4086 100644
--- a/internal/bindtodevice/chanlistener_linux_internal_test.go
+++ b/internal/bindtodevice/chanlistener_linux_internal_test.go
@@ -12,7 +12,7 @@ import (
func TestChanListener_Accept(t *testing.T) {
conns := make(chan net.Conn, 1)
- l := newChanListener(conns, testLAddr)
+ l := newChanListener(conns, testSubnetIPv4, testLAddr)
// A simple way to have a distinct net.Conn without actually implementing
// the entire interface.
@@ -32,14 +32,14 @@ func TestChanListener_Accept(t *testing.T) {
}
func TestChanListener_Addr(t *testing.T) {
- l := newChanListener(nil, testLAddr)
+ l := newChanListener(nil, testSubnetIPv4, testLAddr)
got := l.Addr()
assert.Equal(t, testLAddr, got)
}
func TestChanListener_Close(t *testing.T) {
conns := make(chan net.Conn)
- l := newChanListener(conns, testLAddr)
+ l := newChanListener(conns, testSubnetIPv4, testLAddr)
err := l.Close()
assert.NoError(t, err)
diff --git a/internal/bindtodevice/chanpacketconn_linux.go b/internal/bindtodevice/chanpacketconn_linux.go
index dac6e09..94e3051 100644
--- a/internal/bindtodevice/chanpacketconn_linux.go
+++ b/internal/bindtodevice/chanpacketconn_linux.go
@@ -5,6 +5,7 @@ package bindtodevice
import (
"fmt"
"net"
+ "net/netip"
"os"
"sync"
"time"
@@ -19,32 +20,38 @@ import (
// are used in module dnsserver to make the bind-to-device logic work in
// DNS-over-UDP.
type chanPacketConn struct {
- closeOnce *sync.Once
- sessions chan *packetSession
- laddr net.Addr
+ // mu protects sessions (against closure) and isClosed.
+ mu *sync.Mutex
+ sessions chan *packetSession
+
+ writeRequests chan *packetConnWriteReq
// deadlineMu protects readDeadline and writeDeadline.
deadlineMu *sync.RWMutex
readDeadline time.Time
writeDeadline time.Time
- writeRequests chan *packetConnWriteReq
+ laddr net.Addr
+ subnet netip.Prefix
+ isClosed bool
}
// newChanPacketConn returns a new properly initialized *chanPacketConn.
func newChanPacketConn(
sessions chan *packetSession,
+ subnet netip.Prefix,
writeRequests chan *packetConnWriteReq,
laddr net.Addr,
) (c *chanPacketConn) {
return &chanPacketConn{
- closeOnce: &sync.Once{},
- sessions: sessions,
- laddr: laddr,
+ mu: &sync.Mutex{},
+ sessions: sessions,
+ writeRequests: writeRequests,
deadlineMu: &sync.RWMutex{},
- writeRequests: writeRequests,
+ laddr: laddr,
+ subnet: subnet,
}
}
@@ -70,16 +77,16 @@ var _ netext.SessionPacketConn = (*chanPacketConn)(nil)
// Close implements the [netext.SessionPacketConn] interface for
// *chanPacketConn.
func (c *chanPacketConn) Close() (err error) {
- closedNow := false
- c.closeOnce.Do(func() {
- close(c.sessions)
- closedNow = true
- })
+ c.mu.Lock()
+ defer c.mu.Unlock()
- if !closedNow {
+ if c.isClosed {
return wrapConnError(tnChanPConn, "Close", c.laddr, net.ErrClosed)
}
+ close(c.sessions)
+ c.isClosed = true
+
return nil
}
@@ -289,3 +296,18 @@ func sendWithTimer[T any](ch chan<- T, v T, timerCh <-chan time.Time) (err error
return os.ErrDeadlineExceeded
}
}
+
+// send is a helper method to send a session to the packet connection's channel.
+// ok is false if the packet connection is closed.
+func (c *chanPacketConn) send(sess *packetSession) (ok bool) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if c.isClosed {
+ return false
+ }
+
+ c.sessions <- sess
+
+ return true
+}
diff --git a/internal/bindtodevice/chanpacketconn_linux_internal_test.go b/internal/bindtodevice/chanpacketconn_linux_internal_test.go
index 7ee22b1..a917f9a 100644
--- a/internal/bindtodevice/chanpacketconn_linux_internal_test.go
+++ b/internal/bindtodevice/chanpacketconn_linux_internal_test.go
@@ -14,7 +14,7 @@ import (
func TestChanPacketConn_Close(t *testing.T) {
sessions := make(chan *packetSession)
- c := newChanPacketConn(sessions, nil, testLAddr)
+ c := newChanPacketConn(sessions, testSubnetIPv4, nil, testLAddr)
err := c.Close()
assert.NoError(t, err)
@@ -23,14 +23,14 @@ func TestChanPacketConn_Close(t *testing.T) {
}
func TestChanPacketConn_LocalAddr(t *testing.T) {
- c := newChanPacketConn(nil, nil, testLAddr)
+ c := newChanPacketConn(nil, testSubnetIPv4, nil, testLAddr)
got := c.LocalAddr()
assert.Equal(t, testLAddr, got)
}
func TestChanPacketConn_ReadFromSession(t *testing.T) {
sessions := make(chan *packetSession, 1)
- c := newChanPacketConn(sessions, nil, testLAddr)
+ c := newChanPacketConn(sessions, testSubnetIPv4, nil, testLAddr)
body := []byte("hello")
bodyLen := len(body)
@@ -79,7 +79,7 @@ func TestChanPacketConn_ReadFromSession(t *testing.T) {
func TestChanPacketConn_WriteToSession(t *testing.T) {
sessions := make(chan *packetSession, 1)
writes := make(chan *packetConnWriteReq, 1)
- c := newChanPacketConn(sessions, writes, testLAddr)
+ c := newChanPacketConn(sessions, testSubnetIPv4, writes, testLAddr)
body := []byte("hello")
bodyLen := len(body)
@@ -148,7 +148,7 @@ func checkWriteReqAndRespond(
}
func TestChanPacketConn_deadlines(t *testing.T) {
- c := newChanPacketConn(nil, nil, testLAddr)
+ c := newChanPacketConn(nil, testSubnetIPv4, nil, testLAddr)
deadline := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)
testCases := []struct {
diff --git a/internal/bindtodevice/chanindex_linux.go b/internal/bindtodevice/connindex_linux.go
similarity index 52%
rename from internal/bindtodevice/chanindex_linux.go
rename to internal/bindtodevice/connindex_linux.go
index 3f4937f..856d376 100644
--- a/internal/bindtodevice/chanindex_linux.go
+++ b/internal/bindtodevice/connindex_linux.go
@@ -4,33 +4,20 @@ package bindtodevice
import (
"fmt"
- "net"
"net/netip"
"golang.org/x/exp/slices"
)
-// chanIndex is the data structure that contains the channels, to which the
-// [Manager] sends new connections and packets based on their protocol (TCP vs.
-// UDP), and subnet.
+// connIndex is the data structure that contains the channel listeners and
+// packet connections, to which the [Manager] sends new connections and packets
+// based on their protocol (TCP vs. UDP), and subnet.
//
// In both slices a subnet with the largest prefix (the narrowest subnet) is
// sorted closer to the beginning.
-type chanIndex struct {
- packetConns []*indexPacketConn
- listeners []*indexListener
-}
-
-// indexPacketConn contains data of a [chanPacketConn] in the index.
-type indexPacketConn struct {
- channel chan *packetSession
- subnet netip.Prefix
-}
-
-// indexListener contains data of a [chanListener] in the index.
-type indexListener struct {
- channel chan net.Conn
- subnet netip.Prefix
+type connIndex struct {
+ packetConns []*chanPacketConn
+ listeners []*chanListener
}
// subnetSortsBefore returns true if subnet x sorts before subnet y.
@@ -58,26 +45,18 @@ func subnetCompare(x, y netip.Prefix) (cmp int) {
}
}
-// addPacketConnChannel adds the channel to the subnet index. It returns an
-// error if there is already one for this subnet. subnet should be masked.
+// addPacketConn adds the channel packet connection to the index. It returns an
+// error if there is already one for this subnet. c.subnet should be masked.
//
// TODO(a.garipov): Merge with [addListenerChannel].
-func (idx *chanIndex) addPacketConnChannel(
- subnet netip.Prefix,
- ch chan *packetSession,
-) (err error) {
- c := &indexPacketConn{
- channel: ch,
- subnet: subnet,
- }
-
- cmpFunc := func(x, y *indexPacketConn) (cmp int) {
+func (idx *connIndex) addPacketConn(c *chanPacketConn) (err error) {
+ cmpFunc := func(x, y *chanPacketConn) (cmp int) {
return subnetCompare(x.subnet, y.subnet)
}
newIdx, ok := slices.BinarySearchFunc(idx.packetConns, c, cmpFunc)
if ok {
- return fmt.Errorf("packetconn channel for subnet %s already registered", subnet)
+ return fmt.Errorf("packetconn channel for subnet %s already registered", c.subnet)
}
// TODO(a.garipov): Consider using a list for idx.packetConns. Currently,
@@ -88,23 +67,18 @@ func (idx *chanIndex) addPacketConnChannel(
return nil
}
-// addListenerChannel adds the channel to the subnet index. It returns an error
-// if there is already one for this subnet. subnet should be masked.
+// addListener adds the channel listener to the index. It returns an error if
+// there is already one for this subnet. l.subnet should be masked.
//
// TODO(a.garipov): Merge with [addPacketConnChannel].
-func (idx *chanIndex) addListenerChannel(subnet netip.Prefix, ch chan net.Conn) (err error) {
- l := &indexListener{
- channel: ch,
- subnet: subnet,
- }
-
- cmpFunc := func(x, y *indexListener) (cmp int) {
+func (idx *connIndex) addListener(l *chanListener) (err error) {
+ cmpFunc := func(x, y *chanListener) (cmp int) {
return subnetCompare(x.subnet, y.subnet)
}
newIdx, ok := slices.BinarySearchFunc(idx.listeners, l, cmpFunc)
if ok {
- return fmt.Errorf("listener channel for subnet %s already registered", subnet)
+ return fmt.Errorf("listener channel for subnet %s already registered", l.subnet)
}
// TODO(a.garipov): Consider using a list for idx.listeners. Currently,
@@ -115,24 +89,24 @@ func (idx *chanIndex) addListenerChannel(subnet netip.Prefix, ch chan net.Conn)
return nil
}
-// packetConnChannel returns a packet-connection channel which accepts
-// connections to local address laddr or nil if there is no such channel
-func (idx *chanIndex) packetConnChannel(laddr netip.Addr) (ch chan *packetSession) {
- for _, c := range idx.packetConns {
+// packetConn returns a channel packet connection which accepts connections to
+// local address laddr, or nil if there is no such connection.
+func (idx *connIndex) packetConn(laddr netip.Addr) (c *chanPacketConn) {
+ for _, c = range idx.packetConns {
if c.subnet.Contains(laddr) {
- return c.channel
+ return c
}
}
return nil
}
-// listenerChannel returns a listener channel which accepts connections to local
+// listener returns a channel listener which accepts connections to local
// address laddr or nil if there is no such channel
-func (idx *chanIndex) listenerChannel(laddr netip.Addr) (ch chan net.Conn) {
- for _, l := range idx.listeners {
+func (idx *connIndex) listener(laddr netip.Addr) (l *chanListener) {
+ for _, l = range idx.listeners {
if l.subnet.Contains(laddr) {
- return l.channel
+ return l
}
}
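connIndex keeps both slices ordered so that the subnet with the largest prefix (the narrowest one) comes first, which turns the simple linear Contains scan into a narrowest-match lookup. The standalone sketch below reproduces just that ordering and lookup using standard-library sorting; it mirrors the intent of subnetCompare but is not the code from this diff.

```go
package main

import (
	"fmt"
	"net/netip"
	"sort"
)

// sortNarrowestFirst sorts subnets so that larger prefix lengths (narrower
// subnets) come first, mirroring the ordering connIndex relies on.
func sortNarrowestFirst(subnets []netip.Prefix) {
	sort.Slice(subnets, func(i, j int) bool {
		if subnets[i].Bits() != subnets[j].Bits() {
			return subnets[i].Bits() > subnets[j].Bits()
		}

		return subnets[i].Addr().Less(subnets[j].Addr())
	})
}

// match returns the first, and therefore the narrowest, subnet that contains
// addr, or the zero prefix if none do.
func match(subnets []netip.Prefix, addr netip.Addr) (p netip.Prefix) {
	for _, s := range subnets {
		if s.Contains(addr) {
			return s
		}
	}

	return netip.Prefix{}
}

func main() {
	subnets := []netip.Prefix{
		netip.MustParsePrefix("1.2.0.0/16"),
		netip.MustParsePrefix("1.2.3.0/24"),
	}
	sortNarrowestFirst(subnets)

	// Prints 1.2.3.0/24: the narrower subnet sorts first and wins.
	fmt.Println(match(subnets, netip.MustParseAddr("1.2.3.4")))
}
```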
diff --git a/internal/bindtodevice/chanindex_linux_test.go b/internal/bindtodevice/connindex_linux_test.go
similarity index 100%
rename from internal/bindtodevice/chanindex_linux_test.go
rename to internal/bindtodevice/connindex_linux_test.go
diff --git a/internal/bindtodevice/interfacelistener_linux.go b/internal/bindtodevice/interfacelistener_linux.go
index 197b357..0d2f33e 100644
--- a/internal/bindtodevice/interfacelistener_linux.go
+++ b/internal/bindtodevice/interfacelistener_linux.go
@@ -17,7 +17,7 @@ import (
// interfaceListener contains information about a single interface listener.
type interfaceListener struct {
- channels *chanIndex
+ conns *connIndex
writeRequests chan *packetConnWriteReq
done chan unit
listenConf *net.ListenConfig
@@ -64,14 +64,16 @@ func (l *interfaceListener) listenTCP(errCh chan<- error) {
}
laddr := netutil.NetAddrToAddrPort(conn.LocalAddr())
- ch := l.channels.listenerChannel(laddr.Addr())
- if ch == nil {
+ lsnr := l.conns.listener(laddr.Addr())
+ if lsnr == nil {
log.Info("%s: no channel for laddr %s", logPrefix, laddr)
continue
}
- ch <- conn
+ if !lsnr.send(conn) {
+ log.Info("%s: channel for laddr %s is closed", logPrefix, laddr)
+ }
}
}
@@ -120,14 +122,16 @@ func (l *interfaceListener) listenUDP(errCh chan<- error) {
}
laddr := sess.laddr.AddrPort().Addr()
- ch := l.channels.packetConnChannel(laddr)
- if ch == nil {
+ chanPConn := l.conns.packetConn(laddr)
+ if chanPConn == nil {
log.Info("%s: no channel for laddr %s", logPrefix, laddr)
continue
}
- ch <- sess
+ if !chanPConn.send(sess) {
+ log.Info("%s: channel for laddr %s is closed", logPrefix, laddr)
+ }
}
}
diff --git a/internal/bindtodevice/interfacestorage.go b/internal/bindtodevice/interfacestorage.go
new file mode 100644
index 0000000..58fcd76
--- /dev/null
+++ b/internal/bindtodevice/interfacestorage.go
@@ -0,0 +1,77 @@
+package bindtodevice
+
+import (
+ "fmt"
+ "net"
+ "net/netip"
+
+ "github.com/AdguardTeam/golibs/netutil"
+)
+
+// NetInterface represents a network interface (aka device).
+//
+// TODO(a.garipov): Consider moving this and InterfaceStorage to netutil.
+type NetInterface interface {
+ Subnets() (subnets []netip.Prefix, err error)
+}
+
+// type check
+var _ NetInterface = osInterface{}
+
+// osInterface is a wrapper around [*net.Interface] that implements the
+// [NetInterface] interface.
+type osInterface struct {
+ iface *net.Interface
+}
+
+// Subnets implements the [NetInterface] interface for osInterface.
+func (osIface osInterface) Subnets() (subnets []netip.Prefix, err error) {
+ name := osIface.iface.Name
+ ifaceAddrs, err := osIface.iface.Addrs()
+ if err != nil {
+ return nil, fmt.Errorf("getting addrs for interface %s: %w", name, err)
+ }
+
+ subnets = make([]netip.Prefix, 0, len(ifaceAddrs))
+ for _, addr := range ifaceAddrs {
+ ipNet, ok := addr.(*net.IPNet)
+ if !ok {
+ return nil, fmt.Errorf("addr for interface %s is %T, not *net.IPNet", name, addr)
+ }
+
+ var subnet netip.Prefix
+ subnet, err = netutil.IPNetToPrefixNoMapped(ipNet)
+ if err != nil {
+ return nil, fmt.Errorf("converting addr for interface %s: %w", name, err)
+ }
+
+ subnets = append(subnets, subnet)
+ }
+
+ return subnets, nil
+}
+
+// InterfaceStorage is the interface for storages of network interfaces (aka
+// devices). Its main implementation is [DefaultInterfaceStorage].
+type InterfaceStorage interface {
+ InterfaceByName(name string) (iface NetInterface, err error)
+}
+
+// type check
+var _ InterfaceStorage = DefaultInterfaceStorage{}
+
+// DefaultInterfaceStorage is the storage that uses the OS's network interfaces.
+type DefaultInterfaceStorage struct{}
+
+// InterfaceByName implements the [InterfaceStorage] interface for
+// DefaultInterfaceStorage.
+func (DefaultInterfaceStorage) InterfaceByName(name string) (iface NetInterface, err error) {
+ netIface, err := net.InterfaceByName(name)
+ if err != nil {
+ return nil, fmt.Errorf("looking up interface %s: %w", name, err)
+ }
+
+ return &osInterface{
+ iface: netIface,
+ }, nil
+}
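A brief, hedged usage sketch of the new abstraction: DefaultInterfaceStorage wraps net.InterfaceByName, and Subnets converts each of the interface's *net.IPNet addresses into a netip.Prefix. The interface name below is a placeholder; the lookup fails unless such an interface exists on the host.

```go
package main

import (
	"fmt"
	"log"

	"github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice"
)

func main() {
	// "wlp3s0" is only an example name.
	iface, err := bindtodevice.DefaultInterfaceStorage{}.InterfaceByName("wlp3s0")
	if err != nil {
		log.Fatal(err)
	}

	subnets, err := iface.Subnets()
	if err != nil {
		log.Fatal(err)
	}

	// Print every subnet assigned to the interface.
	for _, s := range subnets {
		fmt.Println(s)
	}
}
```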
diff --git a/internal/bindtodevice/manager.go b/internal/bindtodevice/manager.go
index 3eebd4c..598ab39 100644
--- a/internal/bindtodevice/manager.go
+++ b/internal/bindtodevice/manager.go
@@ -5,6 +5,10 @@ import "github.com/AdguardTeam/AdGuardDNS/internal/agd"
// ManagerConfig is the configuration structure for [NewManager]. All fields
// must be set.
type ManagerConfig struct {
+ // InterfaceStorage is used to get the information about the system's
+ // network interfaces. Normally, this is [DefaultInterfaceStorage].
+ InterfaceStorage InterfaceStorage
+
// ErrColl is the error collector that is used to collect non-critical
// errors.
ErrColl agd.ErrorCollector
@@ -13,3 +17,14 @@ type ManagerConfig struct {
// dispatch TCP connections and UDP sessions.
ChannelBufferSize int
}
+
+// ControlConfig is the configuration of socket options.
+type ControlConfig struct {
+ // RcvBufSize defines the size of the socket receive buffer in bytes. The
+ // default of zero means that the system settings are used.
+ RcvBufSize int
+
+ // SndBufSize defines the size of the socket send buffer in bytes. The
+ // default of zero means that the system settings are used.
+ SndBufSize int
+}
diff --git a/internal/bindtodevice/manager_linux.go b/internal/bindtodevice/manager_linux.go
index 13c79fb..567ad5f 100644
--- a/internal/bindtodevice/manager_linux.go
+++ b/internal/bindtodevice/manager_linux.go
@@ -18,6 +18,7 @@ import (
// Manager creates individual listeners and dispatches connections to them.
type Manager struct {
+ interfaces InterfaceStorage
closeOnce *sync.Once
ifaceListeners map[ID]*interfaceListener
errColl agd.ErrorCollector
@@ -28,6 +29,7 @@ type Manager struct {
// NewManager returns a new manager of interface listeners.
func NewManager(c *ManagerConfig) (m *Manager) {
return &Manager{
+ interfaces: c.InterfaceStorage,
closeOnce: &sync.Once{},
ifaceListeners: map[ID]*interfaceListener{},
errColl: c.ErrColl,
@@ -36,12 +38,25 @@ func NewManager(c *ManagerConfig) (m *Manager) {
}
}
-// Add creates a new interface-listener record in m.
+// defaultCtrlConf is the default control config. By default, don't alter
+// anything. defaultCtrlConf must not be mutated.
+var defaultCtrlConf = &ControlConfig{
+ RcvBufSize: 0,
+ SndBufSize: 0,
+}
+
+// Add creates a new interface-listener record in m. If conf is nil, a default
+// configuration is used.
//
// Add must not be called after Start is called.
-func (m *Manager) Add(id ID, ifaceName string, port uint16) (err error) {
+func (m *Manager) Add(id ID, ifaceName string, port uint16, conf *ControlConfig) (err error) {
defer func() { err = errors.Annotate(err, "adding interface listener with id %q: %w", id) }()
+ _, err = m.interfaces.InterfaceByName(ifaceName)
+ if err != nil {
+ return fmt.Errorf("looking up interface %q: %w", ifaceName, err)
+ }
+
validateDup := func(lsnrID ID, lsnr *interfaceListener) (lsnrErr error) {
lsnrIfaceName, lsnrPort := lsnr.ifaceName, lsnr.port
if lsnrID == id {
@@ -71,11 +86,15 @@ func (m *Manager) Add(id ID, ifaceName string, port uint16) (err error) {
return err
}
+ if conf == nil {
+ conf = defaultCtrlConf
+ }
+
m.ifaceListeners[id] = &interfaceListener{
- channels: &chanIndex{},
+ conns: &connIndex{},
writeRequests: make(chan *packetConnWriteReq, m.chanBufSize),
done: m.done,
- listenConf: newListenConfig(ifaceName),
+ listenConf: newListenConfig(ifaceName, conf),
errColl: m.errColl,
ifaceName: ifaceName,
port: port,
@@ -90,53 +109,96 @@ func (m *Manager) Add(id ID, ifaceName string, port uint16) (err error) {
//
// ListenConfig must not be called after Start is called.
func (m *Manager) ListenConfig(id ID, subnet netip.Prefix) (c netext.ListenConfig, err error) {
- if masked := subnet.Masked(); subnet != masked {
- return nil, fmt.Errorf(
- "subnet %s for interface listener %q not masked (expected %s)",
+ defer func() {
+ err = errors.Annotate(
+ err,
+ "creating listen config for subnet %s and listener with id %q: %w",
subnet,
id,
- masked,
)
- }
-
+ }()
l, ok := m.ifaceListeners[id]
if !ok {
- return nil, fmt.Errorf("no listener for interface %q", id)
+ return nil, errors.Error("no interface listener found")
}
- connCh := make(chan net.Conn, m.chanBufSize)
- err = l.channels.addListenerChannel(subnet, connCh)
+ err = m.validateIfaceSubnet(l.ifaceName, subnet)
if err != nil {
- return nil, fmt.Errorf("adding tcp conn channel: %w", err)
+ // Don't wrap the error, because it's informative enough as is.
+ return nil, err
+ }
+
+ lsnrCh := make(chan net.Conn, m.chanBufSize)
+ lsnr := newChanListener(lsnrCh, subnet, &prefixNetAddr{
+ prefix: subnet,
+ network: "tcp",
+ port: l.port,
+ })
+
+ err = l.conns.addListener(lsnr)
+ if err != nil {
+ return nil, fmt.Errorf("adding tcp conn: %w", err)
}
sessCh := make(chan *packetSession, m.chanBufSize)
- err = l.channels.addPacketConnChannel(subnet, sessCh)
+ pConn := newChanPacketConn(sessCh, subnet, l.writeRequests, &prefixNetAddr{
+ prefix: subnet,
+ network: "udp",
+ port: l.port,
+ })
+
+ err = l.conns.addPacketConn(pConn)
if err != nil {
// Technically shouldn't happen, since [chanIndex.addListenerChannel]
// has already checked for duplicates.
- return nil, fmt.Errorf("adding udp conn channel: %w", err)
+ return nil, fmt.Errorf("adding udp conn: %w", err)
}
return &chanListenConfig{
- packetConn: newChanPacketConn(sessCh, l.writeRequests, &prefixNetAddr{
- prefix: subnet,
- network: "udp",
- port: l.port,
- }),
- listener: newChanListener(connCh, &prefixNetAddr{
- prefix: subnet,
- network: "tcp",
- port: l.port,
- }),
+ packetConn: pConn,
+ listener: lsnr,
}, nil
}
+// validateIfaceSubnet validates that the interface with the name ifaceName
+// exists and that it can accept addresses from subnet.
+func (m *Manager) validateIfaceSubnet(ifaceName string, subnet netip.Prefix) (err error) {
+ if masked := subnet.Masked(); subnet != masked {
+ return fmt.Errorf("subnet not masked (expected %s)", masked)
+ }
+
+ iface, err := m.interfaces.InterfaceByName(ifaceName)
+ if err != nil {
+ // Don't wrap the error, because it's informative enough as is.
+ return err
+ }
+
+ ifaceSubnets, err := iface.Subnets()
+ if err != nil {
+ return fmt.Errorf("getting subnets: %w", err)
+ }
+
+ for _, s := range ifaceSubnets {
+ if s.Contains(subnet.Addr()) && s.Bits() <= subnet.Bits() {
+ return nil
+ }
+ }
+
+ return fmt.Errorf("interface %s does not contain subnet %s", ifaceName, subnet)
+}
+
// type check
var _ agd.Service = (*Manager)(nil)
-// Start implements the [agd.Service] interface for *Manager.
+// Start implements the [agd.Service] interface for *Manager. If m is nil,
+// Start returns nil, since this feature is optional.
+//
+// TODO(a.garipov): Consider an interface solution.
func (m *Manager) Start() (err error) {
+ if m == nil {
+ return nil
+ }
+
numListen := 2 * len(m.ifaceListeners)
errCh := make(chan error, numListen)
@@ -162,10 +224,17 @@ func (m *Manager) Start() (err error) {
return nil
}
-// Shutdown implements the [agd.Service] interface for *Manager.
+// Shutdown implements the [agd.Service] interface for *Manager. If m is nil,
+// Shutdown returns nil, since this feature is optional.
+//
+// TODO(a.garipov): Consider an interface solution.
//
// TODO(a.garipov): Consider waiting for all sockets to close.
func (m *Manager) Shutdown(_ context.Context) (err error) {
+ if m == nil {
+ return nil
+ }
+
closedNow := false
m.closeOnce.Do(func() {
close(m.done)
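Taken together, the Manager changes roughly implement the flow that the removed package comment in bindtodevice.go had planned: construct the manager, Add an interface listener (now with an optional ControlConfig), request a ListenConfig for a masked subnet that the interface actually owns, and then Start. A hedged end-to-end sketch follows; the listener ID, interface name, and subnet are placeholders, the ID literal assumes bindtodevice.ID is a string-based type as the test constants imply, and the error collector is the test mock, used here only to keep the example self-contained.

```go
package main

import (
	"context"
	"log"
	"net/netip"

	"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
	"github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice"
)

func main() {
	errColl := &agdtest.ErrorCollector{
		OnCollect: func(_ context.Context, err error) { log.Print(err) },
	}

	mgr := bindtodevice.NewManager(&bindtodevice.ManagerConfig{
		InterfaceStorage:  bindtodevice.DefaultInterfaceStorage{},
		ErrColl:           errColl,
		ChannelBufferSize: 1000,
	})

	// A nil ControlConfig keeps the default behavior of not touching the
	// socket buffer sizes.
	err := mgr.Add("wlp3s0_plain_dns", "wlp3s0", 53, nil)
	if err != nil {
		log.Fatal(err)
	}

	// The subnet must be masked and must belong to the interface, otherwise
	// ListenConfig returns an error.
	lc, err := mgr.ListenConfig("wlp3s0_plain_dns", netip.MustParsePrefix("1.2.3.0/24"))
	if err != nil {
		log.Fatal(err)
	}

	err = mgr.Start()
	if err != nil {
		log.Fatal(err)
	}

	// lc can now be passed to the DNS server as its netext.ListenConfig.
	_ = lc
}
```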
diff --git a/internal/bindtodevice/manager_linux_test.go b/internal/bindtodevice/manager_linux_test.go
index e33ea39..e7531d8 100644
--- a/internal/bindtodevice/manager_linux_test.go
+++ b/internal/bindtodevice/manager_linux_test.go
@@ -18,12 +18,46 @@ import (
// TODO(a.garipov): Add tests for other platforms?
+// type check
+var _ bindtodevice.InterfaceStorage = (*fakeInterfaceStorage)(nil)
+
+// fakeInterfaceStorage is a fake [bindtodevice.InterfaceStorage] for tests.
+type fakeInterfaceStorage struct {
+ OnInterfaceByName func(name string) (iface bindtodevice.NetInterface, err error)
+}
+
+// InterfaceByName implements the [bindtodevice.InterfaceStorage] interface
+// for *fakeInterfaceStorage.
+func (s *fakeInterfaceStorage) InterfaceByName(
+ name string,
+) (iface bindtodevice.NetInterface, err error) {
+ return s.OnInterfaceByName(name)
+}
+
+// type check
+var _ bindtodevice.NetInterface = (*fakeInterface)(nil)
+
+// fakeInterface is a fake [bindtodevice.NetInterface] for tests.
+type fakeInterface struct {
+ OnSubnets func() (subnets []netip.Prefix, err error)
+}
+
+// Subnets implements the [bindtodevice.NetInterface] interface for
+// *fakeInterface.
+func (iface *fakeInterface) Subnets() (subnets []netip.Prefix, err error) {
+ return iface.OnSubnets()
+}
+
func TestManager_Add(t *testing.T) {
errColl := &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, _ error) { panic("not implemented") },
}
m := bindtodevice.NewManager(&bindtodevice.ManagerConfig{
+ InterfaceStorage: &fakeInterfaceStorage{
+ OnInterfaceByName: func(_ string) (iface bindtodevice.NetInterface, err error) {
+ return nil, nil
+ },
+ },
ErrColl: errColl,
ChannelBufferSize: 1,
})
@@ -32,22 +66,22 @@ func TestManager_Add(t *testing.T) {
// Don't use a table, since the results of these subtests depend on each
// other.
t.Run("success", func(t *testing.T) {
- err := m.Add(testID1, testIfaceName, testPort1)
+ err := m.Add(testID1, testIfaceName, testPort1, nil)
assert.NoError(t, err)
})
t.Run("dup_id", func(t *testing.T) {
- err := m.Add(testID1, testIfaceName, testPort1)
+ err := m.Add(testID1, testIfaceName, testPort1, nil)
assert.Error(t, err)
})
t.Run("dup_iface_port", func(t *testing.T) {
- err := m.Add(testID2, testIfaceName, testPort1)
+ err := m.Add(testID2, testIfaceName, testPort1, nil)
assert.Error(t, err)
})
t.Run("success_other", func(t *testing.T) {
- err := m.Add(testID2, testIfaceName, testPort2)
+ err := m.Add(testID2, testIfaceName, testPort2, nil)
assert.NoError(t, err)
})
}
@@ -57,17 +91,27 @@ func TestManager_ListenConfig(t *testing.T) {
OnCollect: func(_ context.Context, _ error) { panic("not implemented") },
}
+ subnet := testSubnetIPv4
+ ifaceWithSubnet := &fakeInterface{
+ OnSubnets: func() (subnets []netip.Prefix, err error) {
+ return []netip.Prefix{subnet}, nil
+ },
+ }
+
m := bindtodevice.NewManager(&bindtodevice.ManagerConfig{
+ InterfaceStorage: &fakeInterfaceStorage{
+ OnInterfaceByName: func(_ string) (iface bindtodevice.NetInterface, err error) {
+ return ifaceWithSubnet, nil
+ },
+ },
ErrColl: errColl,
ChannelBufferSize: 1,
})
require.NotNil(t, m)
- err := m.Add(testID1, testIfaceName, testPort1)
+ err := m.Add(testID1, testIfaceName, testPort1, nil)
require.NoError(t, err)
- subnet := netip.MustParsePrefix("1.2.3.0/24")
-
// Don't use a table, since the results of these subtests depend on each
// other.
t.Run("not_found", func(t *testing.T) {
@@ -94,6 +138,60 @@ func TestManager_ListenConfig(t *testing.T) {
assert.Nil(t, lc)
assert.Error(t, lcErr)
})
+
+ t.Run("no_subnet", func(t *testing.T) {
+ ifaceWithoutSubnet := &fakeInterface{
+ OnSubnets: func() (subnets []netip.Prefix, err error) {
+ return nil, nil
+ },
+ }
+
+ noSubnetMgr := bindtodevice.NewManager(&bindtodevice.ManagerConfig{
+ InterfaceStorage: &fakeInterfaceStorage{
+ OnInterfaceByName: func(_ string) (iface bindtodevice.NetInterface, err error) {
+ return ifaceWithoutSubnet, nil
+ },
+ },
+ ErrColl: errColl,
+ ChannelBufferSize: 1,
+ })
+ require.NotNil(t, noSubnetMgr)
+
+ subTestErr := noSubnetMgr.Add(testID1, testIfaceName, testPort1, nil)
+ require.NoError(t, subTestErr)
+
+ lc, subTestErr := noSubnetMgr.ListenConfig(testID1, subnet)
+ assert.Nil(t, lc)
+ assert.Error(t, subTestErr)
+ })
+
+ t.Run("narrower_subnet", func(t *testing.T) {
+ ifaceWithNarrowerSubnet := &fakeInterface{
+ OnSubnets: func() (subnets []netip.Prefix, err error) {
+ narrowerSubnet := netip.PrefixFrom(subnet.Addr(), subnet.Bits()+4)
+
+ return []netip.Prefix{narrowerSubnet}, nil
+ },
+ }
+
+ narrowSubnetMgr := bindtodevice.NewManager(&bindtodevice.ManagerConfig{
+ InterfaceStorage: &fakeInterfaceStorage{
+ OnInterfaceByName: func(_ string) (iface bindtodevice.NetInterface, err error) {
+ return ifaceWithNarrowerSubnet, nil
+ },
+ },
+ ErrColl: errColl,
+ ChannelBufferSize: 1,
+ })
+ require.NotNil(t, narrowSubnetMgr)
+
+ subTestErr := narrowSubnetMgr.Add(testID1, testIfaceName, testPort1, nil)
+ require.NoError(t, subTestErr)
+
+ lc, subTestErr := narrowSubnetMgr.ListenConfig(testID1, subnet)
+ assert.Nil(t, lc)
+ assert.Error(t, subTestErr)
+ })
}
func TestManager(t *testing.T) {
@@ -113,13 +211,14 @@ func TestManager(t *testing.T) {
}
m := bindtodevice.NewManager(&bindtodevice.ManagerConfig{
+ InterfaceStorage: bindtodevice.DefaultInterfaceStorage{},
ErrColl: errColl,
ChannelBufferSize: 1,
})
require.NotNil(t, m)
// TODO(a.garipov): Add support for zero port.
- err := m.Add(testID1, ifaceName, testPort1)
+ err := m.Add(testID1, ifaceName, testPort1, nil)
require.NoError(t, err)
subnet, err := netutil.IPNetToPrefixNoMapped(&net.IPNet{
diff --git a/internal/bindtodevice/manager_others.go b/internal/bindtodevice/manager_others.go
index 457e481..d31f91b 100644
--- a/internal/bindtodevice/manager_others.go
+++ b/internal/bindtodevice/manager_others.go
@@ -13,12 +13,12 @@ import (
// Manager creates individual listeners and dispatches connections to them.
//
-// It is only suported on Linux.
+// It is only supported on Linux.
type Manager struct{}
// NewManager returns a new manager of interface listeners.
//
-// It is only suported on Linux.
+// It is only supported on Linux.
func NewManager(c *ManagerConfig) (m *Manager) {
return &Manager{}
}
@@ -29,14 +29,16 @@ const errUnsupported errors.Error = "bindtodevice is only supported on linux"
// Add creates a new interface-listener record in m.
//
-// It is only suported on Linux.
-func (m *Manager) Add(id ID, ifaceName string, port uint16) (err error) { return errUnsupported }
+// It is only supported on Linux.
+func (m *Manager) Add(id ID, ifaceName string, port uint16, cc *ControlConfig) (err error) {
+ return errUnsupported
+}
// ListenConfig returns a new netext.ListenConfig that receives connections from
// the interface listener with the given id and the destination addresses of
// which fall within subnet. subnet should be masked.
//
-// It is only suported on Linux.
+// It is only supported on Linux.
func (m *Manager) ListenConfig(id ID, subnet netip.Prefix) (c netext.ListenConfig, err error) {
return nil, errUnsupported
}
@@ -44,12 +46,26 @@ func (m *Manager) ListenConfig(id ID, subnet netip.Prefix) (c netext.ListenConfi
// type check
var _ agd.Service = (*Manager)(nil)
-// Start implements the [agd.Service] interface for *Manager.
+// Start implements the [agd.Service] interface for *Manager. If m is nil,
+// Start returns nil, since this feature is optional.
//
-// It is only suported on Linux.
-func (m *Manager) Start() (err error) { return errUnsupported }
+// It is only supported on Linux.
+func (m *Manager) Start() (err error) {
+ if m == nil {
+ return nil
+ }
-// Shutdown implements the [agd.Service] interface for *Manager.
+ return errUnsupported
+}
+
+// Shutdown implements the [agd.Service] interface for *Manager. If m is nil,
+// Shutdown returns nil, since this feature is optional.
//
-// It is only suported on Linux.
-func (m *Manager) Shutdown(_ context.Context) (err error) { return errUnsupported }
+// It is only supported on Linux.
+func (m *Manager) Shutdown(_ context.Context) (err error) {
+ if m == nil {
+ return nil
+ }
+
+ return errUnsupported
+}
diff --git a/internal/bindtodevice/prefixaddr_linux_internal_test.go b/internal/bindtodevice/prefixaddr_linux_internal_test.go
index bc95326..7df5be7 100644
--- a/internal/bindtodevice/prefixaddr_linux_internal_test.go
+++ b/internal/bindtodevice/prefixaddr_linux_internal_test.go
@@ -3,6 +3,7 @@
package bindtodevice
import (
+ "fmt"
"net/netip"
"testing"
@@ -11,16 +12,42 @@ import (
func TestPrefixAddr(t *testing.T) {
const (
- wantStr = "1.2.3.0:56789/24"
+ port = 56789
network = "tcp"
)
- pa := &prefixNetAddr{
- prefix: netip.MustParsePrefix("1.2.3.0/24"),
- network: network,
- port: 56789,
- }
+ testCases := []struct {
+ in *prefixNetAddr
+ want string
+ name string
+ }{{
+ in: &prefixNetAddr{
+ prefix: testSubnetIPv4,
+ network: network,
+ port: port,
+ },
+ want: fmt.Sprintf(
+ "%s/%d",
+ netip.AddrPortFrom(testSubnetIPv4.Addr(), port), testSubnetIPv4.Bits(),
+ ),
+ name: "ipv4",
+ }, {
+ in: &prefixNetAddr{
+ prefix: testSubnetIPv6,
+ network: network,
+ port: port,
+ },
+ want: fmt.Sprintf(
+ "%s/%d",
+ netip.AddrPortFrom(testSubnetIPv6.Addr(), port), testSubnetIPv6.Bits(),
+ ),
+ name: "ipv6",
+ }}
- assert.Equal(t, wantStr, pa.String())
- assert.Equal(t, network, pa.Network())
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ assert.Equal(t, tc.want, tc.in.String())
+ assert.Equal(t, network, tc.in.Network())
+ })
+ }
}
diff --git a/internal/bindtodevice/socket_linux.go b/internal/bindtodevice/socket_linux.go
index 122fa17..ad0e910 100644
--- a/internal/bindtodevice/socket_linux.go
+++ b/internal/bindtodevice/socket_linux.go
@@ -12,106 +12,100 @@ import (
"golang.org/x/sys/unix"
)
-// newListenConfig returns a [net.ListenConfig] that can bind to a network
-// interface (device) by its name.
-func newListenConfig(devName string) (lc *net.ListenConfig) {
- c := &net.ListenConfig{
- Control: func(network, address string, c syscall.RawConn) (err error) {
- return listenControl(devName, network, address, c)
- },
- }
+// setSockOptFunc is a function that sets a socket option on fd.
+type setSockOptFunc func(fd int) (err error)
- return c
+// newIntSetSockOptFunc returns an integer socket-option function with the given
+// parameters.
+func newIntSetSockOptFunc(name string, lvl, opt, val int) (o setSockOptFunc) {
+ return func(fd int) (err error) {
+ opErr := unix.SetsockoptInt(fd, lvl, opt, val)
+
+ return errors.Annotate(opErr, "setting %s: %w", name)
+ }
}
-// listenControl is used as a [net.ListenConfig.Control] function to set
-// additional socket options, including SO_BINDTODEVICE.
-func listenControl(devName, network, _ string, c syscall.RawConn) (err error) {
- var ctrlFunc func(fd uintptr, devName string) (err error)
+// newStringSetSockOptFunc returns a string socket-option function with the
+// given parameters.
+func newStringSetSockOptFunc(name string, lvl, opt int, val string) (o setSockOptFunc) {
+ return func(fd int) (err error) {
+ opErr := unix.SetsockoptString(fd, lvl, opt, val)
+
+ return errors.Annotate(opErr, "setting %s: %w", name)
+ }
+}
+
+// newListenConfig returns a [net.ListenConfig] that can bind to a network
+// interface (device) by its name. ctrlConf must not be nil.
+func newListenConfig(devName string, ctrlConf *ControlConfig) (lc *net.ListenConfig) {
+ return &net.ListenConfig{
+ Control: func(network, address string, c syscall.RawConn) (err error) {
+ return listenControlWithSO(ctrlConf, devName, network, address, c)
+ },
+ }
+}
+
+// listenControlWithSO is used as a [net.ListenConfig.Control] function to set
+// additional socket options.
+func listenControlWithSO(
+ ctrlConf *ControlConfig,
+ devName string,
+ network string,
+ _ string,
+ c syscall.RawConn,
+) (err error) {
+ opts := []setSockOptFunc{
+ newStringSetSockOptFunc("SO_BINDTODEVICE", unix.SOL_SOCKET, unix.SO_BINDTODEVICE, devName),
+ // Use SO_REUSEADDR as well, which is not technically necessary, to
+ // help with the situation of sockets hanging in TIME_WAIT for too
+ // long.
+ newIntSetSockOptFunc("SO_REUSEADDR", unix.SOL_SOCKET, unix.SO_REUSEADDR, 1),
+ newIntSetSockOptFunc("SO_REUSEPORT", unix.SOL_SOCKET, unix.SO_REUSEPORT, 1),
+ }
switch network {
case "tcp", "tcp4", "tcp6":
- ctrlFunc = setTCPSockOpt
+ // Socket options for TCP connections are already set. Go on.
case "udp", "udp4", "udp6":
- ctrlFunc = setUDPSockOpt
+ opts = append(
+ opts,
+ newIntSetSockOptFunc("IP_RECVORIGDSTADDR", unix.IPPROTO_IP, unix.IP_RECVORIGDSTADDR, 1),
+ newIntSetSockOptFunc("IP_FREEBIND", unix.IPPROTO_IP, unix.IP_FREEBIND, 1),
+ newIntSetSockOptFunc("IPV6_RECVORIGDSTADDR", unix.IPPROTO_IPV6, unix.IPV6_RECVORIGDSTADDR, 1),
+ newIntSetSockOptFunc("IPV6_FREEBIND", unix.IPPROTO_IPV6, unix.IPV6_FREEBIND, 1),
+ )
default:
return fmt.Errorf("bad network %q", network)
}
+ if ctrlConf.SndBufSize > 0 {
+ opts = append(
+ opts,
+ newIntSetSockOptFunc("SO_SNDBUF", unix.SOL_SOCKET, unix.SO_SNDBUF, ctrlConf.SndBufSize),
+ )
+ }
+
+ if ctrlConf.RcvBufSize > 0 {
+ opts = append(
+ opts,
+ newIntSetSockOptFunc("SO_RCVBUF", unix.SOL_SOCKET, unix.SO_RCVBUF, ctrlConf.RcvBufSize),
+ )
+ }
+
var opErr error
err = c.Control(func(fd uintptr) {
- opErr = ctrlFunc(fd, devName)
+ d := int(fd)
+ for _, opt := range opts {
+ opErr = opt(d)
+ if opErr != nil {
+ return
+ }
+ }
})
return errors.WithDeferred(opErr, err)
}
-// setTCPSockOpt sets the SO_BINDTODEVICE and other socket options for a TCP
-// connection.
-func setTCPSockOpt(fd uintptr, devName string) (err error) {
- defer func() { err = errors.Annotate(err, "setting tcp opts: %w") }()
-
- fdInt := int(fd)
- err = unix.SetsockoptString(fdInt, unix.SOL_SOCKET, unix.SO_BINDTODEVICE, devName)
- if err != nil {
- return fmt.Errorf("setting SO_BINDTODEVICE: %w", err)
- }
-
- err = unix.SetsockoptInt(fdInt, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
- if err != nil {
- return fmt.Errorf("setting SO_REUSEPORT: %w", err)
- }
-
- return nil
-}
-
-// setUDPSockOpt sets the SO_BINDTODEVICE and other socket options for a UDP
-// connection.
-func setUDPSockOpt(fd uintptr, devName string) (err error) {
- defer func() { err = errors.Annotate(err, "setting udp opts: %w") }()
-
- fdInt := int(fd)
- err = unix.SetsockoptString(fdInt, unix.SOL_SOCKET, unix.SO_BINDTODEVICE, devName)
- if err != nil {
- return fmt.Errorf("setting SO_BINDTODEVICE: %w", err)
- }
-
- intOpts := []struct {
- name string
- level int
- opt int
- }{{
- name: "SO_REUSEPORT",
- level: unix.SOL_SOCKET,
- opt: unix.SO_REUSEPORT,
- }, {
- name: "IP_RECVORIGDSTADDR",
- level: unix.IPPROTO_IP,
- opt: unix.IP_RECVORIGDSTADDR,
- }, {
- name: "IP_FREEBIND",
- level: unix.IPPROTO_IP,
- opt: unix.IP_FREEBIND,
- }, {
- name: "IPV6_RECVORIGDSTADDR",
- level: unix.IPPROTO_IPV6,
- opt: unix.IPV6_RECVORIGDSTADDR,
- }, {
- name: "IPV6_FREEBIND",
- level: unix.IPPROTO_IPV6,
- opt: unix.IPV6_FREEBIND,
- }}
-
- for _, o := range intOpts {
- err = unix.SetsockoptInt(fdInt, o.level, o.opt, 1)
- if err != nil {
- return fmt.Errorf("setting %s: %w", o.name, err)
- }
- }
-
- return nil
-}
-
// readPacketSession is a helper that reads a packet-session data from a UDP
// connection.
func readPacketSession(c *net.UDPConn, bodySize int) (sess *packetSession, err error) {
@@ -165,8 +159,11 @@ func sockAddrData(sockAddr unix.Sockaddr) (origDstAddr *net.UDPAddr, respOOB []b
Port: sockAddr.Port,
}
+ // Set both addresses to make sure that users receive the correct source
+ // IP address even when virtual interfaces are involved.
pktInfo := &unix.Inet4Pktinfo{
- Addr: sockAddr.Addr,
+ Addr: sockAddr.Addr,
+ Spec_dst: sockAddr.Addr,
}
respOOB = unix.PktInfo4(pktInfo)
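The rewrite above replaces the two per-protocol setters with a composable slice of setSockOptFunc callbacks that run inside a single RawConn.Control call. The minimal standalone sketch below shows the underlying mechanism, setting only SO_RCVBUF through net.ListenConfig.Control; note that Linux doubles the SO_SNDBUF/SO_RCVBUF values it is given to allow for bookkeeping overhead, which is why the tests in the next file compare against twice the configured size.

```go
//go:build linux

package main

import (
	"context"
	"log"
	"net"
	"syscall"

	"golang.org/x/sys/unix"
)

func main() {
	const rcvBufSize = 20000

	lc := &net.ListenConfig{
		// Control runs just before bind and is where socket options such as
		// SO_RCVBUF or SO_BINDTODEVICE are applied.
		Control: func(_, _ string, c syscall.RawConn) (err error) {
			var opErr error
			err = c.Control(func(fd uintptr) {
				opErr = unix.SetsockoptInt(
					int(fd),
					unix.SOL_SOCKET,
					unix.SO_RCVBUF,
					rcvBufSize,
				)
			})
			if err != nil {
				return err
			}

			return opErr
		},
	}

	conn, err := lc.ListenPacket(context.Background(), "udp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}
	defer func() { _ = conn.Close() }()

	log.Printf("listening on %s with a %d-byte receive-buffer request", conn.LocalAddr(), rcvBufSize)
}
```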
diff --git a/internal/bindtodevice/socket_linux_internal_test.go b/internal/bindtodevice/socket_linux_internal_test.go
index c2bd054..bb062f5 100644
--- a/internal/bindtodevice/socket_linux_internal_test.go
+++ b/internal/bindtodevice/socket_linux_internal_test.go
@@ -9,6 +9,7 @@ import (
"net/netip"
"os"
"strings"
+ "syscall"
"testing"
"time"
@@ -18,6 +19,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
+ "golang.org/x/sys/unix"
)
// TestInterfaceEnvVarName is the environment variable name the presence and
@@ -88,7 +90,7 @@ func TestListenControl(t *testing.T) {
}
ifaceName := iface.Name
- lc := newListenConfig(ifaceName)
+ lc := newListenConfig(ifaceName, &ControlConfig{})
require.NotNil(t, lc)
t.Run("tcp", func(t *testing.T) {
@@ -380,3 +382,81 @@ func closestIP(t testing.TB, n *net.IPNet, ip net.IP) (closest net.IP) {
return nil
}
+
+func TestListenControlWithSO(t *testing.T) {
+ const (
+ sndBufSize = 10000
+ rcvBufSize = 20000
+ )
+
+ iface, _ := InterfaceForTests(t)
+ if iface == nil {
+ t.Skipf("test %s skipped: please set env var %s", t.Name(), TestInterfaceEnvVarName)
+ }
+
+ ifaceName := iface.Name
+ lc := newListenConfig(
+ ifaceName,
+ &ControlConfig{
+ RcvBufSize: rcvBufSize,
+ SndBufSize: sndBufSize,
+ },
+ )
+ require.NotNil(t, lc)
+
+ type syscallConner interface {
+ SyscallConn() (c syscall.RawConn, err error)
+ }
+
+ t.Run("udp", func(t *testing.T) {
+ c, err := lc.ListenPacket(context.Background(), "udp", "0.0.0.0:0")
+ require.NoError(t, err)
+ require.NotNil(t, c)
+ require.Implements(t, (*syscallConner)(nil), c)
+
+ sc, err := c.(syscallConner).SyscallConn()
+ require.NoError(t, err)
+
+ err = sc.Control(func(fd uintptr) {
+ val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF)
+ require.NoError(t, opErr)
+
+ assert.Equal(t, sndBufSize*2, val)
+ })
+ require.NoError(t, err)
+
+ err = sc.Control(func(fd uintptr) {
+ val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF)
+ require.NoError(t, opErr)
+
+ assert.Equal(t, rcvBufSize*2, val)
+ })
+ require.NoError(t, err)
+ })
+
+ t.Run("tcp", func(t *testing.T) {
+ c, err := lc.Listen(context.Background(), "tcp", "0.0.0.0:0")
+ require.NoError(t, err)
+ require.NotNil(t, c)
+ require.Implements(t, (*syscallConner)(nil), c)
+
+ sc, err := c.(syscallConner).SyscallConn()
+ require.NoError(t, err)
+
+ err = sc.Control(func(fd uintptr) {
+ val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF)
+ require.NoError(t, opErr)
+
+ assert.Equal(t, sndBufSize*2, val)
+ })
+ require.NoError(t, err)
+
+ err = sc.Control(func(fd uintptr) {
+ val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF)
+ require.NoError(t, opErr)
+
+ assert.Equal(t, rcvBufSize*2, val)
+ })
+ require.NoError(t, err)
+ })
+}
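
The `sndBufSize*2` and `rcvBufSize*2` assertions above rely on a Linux kernel detail: the value passed to `setsockopt` for `SO_SNDBUF` and `SO_RCVBUF` is doubled by the kernel to leave room for bookkeeping overhead, and `getsockopt` returns the doubled value. A minimal standalone sketch of the same set-then-verify pattern, assuming Linux and the `golang.org/x/sys/unix` package; it is not part of the patch:

```go
package main

import (
	"context"
	"fmt"
	"net"
	"syscall"

	"golang.org/x/sys/unix"
)

func main() {
	lc := net.ListenConfig{
		Control: func(network, address string, c syscall.RawConn) (err error) {
			return c.Control(func(fd uintptr) {
				// Request a 10 000-byte send buffer; Linux doubles the value
				// internally, so reading it back yields 20 000.
				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, 10_000)

				val, _ := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF)
				fmt.Println("SO_SNDBUF:", val)
			})
		},
	}

	lsnr, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}
	defer func() { _ = lsnr.Close() }()
}
```
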
diff --git a/internal/cmd/backend.go b/internal/cmd/backend.go
index 78e9fc2..c1592c3 100644
--- a/internal/cmd/backend.go
+++ b/internal/cmd/backend.go
@@ -8,6 +8,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/backend"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil"
)
@@ -76,7 +77,7 @@ func setupBackend(
envs *environments,
sigHdlr signalHandler,
errColl agd.ErrorCollector,
-) (profDB *agd.DefaultProfileDB, rec *billstat.RuntimeRecorder, err error) {
+) (profDB *profiledb.Default, rec *billstat.RuntimeRecorder, err error) {
profStrgConf, billStatConf := conf.toInternal(envs, errColl)
rec = billstat.NewRuntimeRecorder(&billstat.RuntimeRecorderConfig{
Uploader: backend.NewBillStat(billStatConf),
@@ -103,7 +104,7 @@ func setupBackend(
sigHdlr.add(billStatRefr)
profStrg := backend.NewProfileStorage(profStrgConf)
- profDB, err = agd.NewDefaultProfileDB(
+ profDB, err = profiledb.New(
profStrg,
conf.FullRefreshIvl.Duration,
envs.ProfilesCachePath,
diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go
index e9de141..f8dde60 100644
--- a/internal/cmd/cmd.go
+++ b/internal/cmd/cmd.go
@@ -16,9 +16,9 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/forward"
- "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/websvc"
@@ -130,11 +130,23 @@ func Main() {
fltGroups, err := c.FilteringGroups.toInternal(fltStrg)
check(err)
+ // Network interface listener
+
+ btdCtrlConf, ctrlConf := c.Network.toInternal()
+
+ btdMgr, err := c.InterfaceListeners.toInternal(errColl, btdCtrlConf)
+ check(err)
+
+ err = btdMgr.Start()
+ check(err)
+
+ sigHdlr.add(btdMgr)
+
// Server groups
messages := dnsmsg.NewConstructor(&dnsmsg.BlockingModeNullIP{}, c.Filters.ResponseTTL.Duration)
- srvGrps, err := c.ServerGroups.toInternal(messages, fltGroups)
+ srvGrps, err := c.ServerGroups.toInternal(messages, btdMgr, fltGroups)
check(err)
// TLS keys logging
@@ -173,7 +185,7 @@ func Main() {
// Rate limiting
consulAllowlistURL := &envs.ConsulAllowlistURL.URL
- rateLimiter, err := setupRateLimiter(c.RateLimit, consulAllowlistURL, sigHdlr, errColl)
+ rateLimiter, connLimiter, err := setupRateLimiter(c.RateLimit, consulAllowlistURL, sigHdlr, errColl)
check(err)
// GeoIP database
@@ -197,24 +209,21 @@ func Main() {
// DNS service
- metricsListener := prometheus.NewForwardMetricsListener(len(c.Upstream.FallbackServers) + 1)
-
- upstream, err := c.Upstream.toInternal()
+ fwdConf, err := c.Upstream.toInternal()
check(err)
- handler := forward.NewHandler(&forward.HandlerConfig{
- Address: upstream.Server,
- Network: upstream.Network,
- MetricsListener: metricsListener,
- HealthcheckDomainTmpl: c.Upstream.Healthcheck.DomainTmpl,
- FallbackAddresses: c.Upstream.FallbackServers,
- Timeout: c.Upstream.Timeout.Duration,
- HealthcheckBackoffDuration: c.Upstream.Healthcheck.BackoffDuration.Duration,
- }, c.Upstream.Healthcheck.Enabled)
+ handler := forward.NewHandler(fwdConf)
+
+ // TODO(a.garipov): Consider making these configurable via the configuration
+ // file.
+ hashStorages := map[string]*hashprefix.Storage{
+ filter.GeneralTXTSuffix: safeBrowsingHashes,
+ filter.AdultBlockingTXTSuffix: adultBlockingHashes,
+ }
dnsConf := &dnssvc.Config{
Messages: messages,
- SafeBrowsing: filter.NewSafeBrowsingServer(safeBrowsingHashes, adultBlockingHashes),
+ SafeBrowsing: hashprefix.NewMatcher(hashStorages),
BillStat: billStatRec,
ProfileDB: profDB,
DNSCheck: dnsCk,
@@ -226,21 +235,20 @@ func Main() {
Handler: handler,
QueryLog: c.buildQueryLog(envs),
RuleStat: ruleStat,
- Upstream: upstream,
RateLimit: rateLimiter,
+ ConnLimiter: connLimiter,
FilteringGroups: fltGroups,
ServerGroups: srvGrps,
CacheSize: c.Cache.Size,
ECSCacheSize: c.Cache.ECSSize,
UseECSCache: c.Cache.Type == cacheTypeECS,
ResearchMetrics: bool(envs.ResearchMetrics),
+ ControlConf: ctrlConf,
}
dnsSvc, err := dnssvc.New(dnsConf)
check(err)
- sigHdlr.add(dnsSvc)
-
// Connectivity check
err = connectivityCheck(dnsConf, c.ConnectivityCheck)
diff --git a/internal/cmd/config.go b/internal/cmd/config.go
index af24751..f7475bc 100644
--- a/internal/cmd/config.go
+++ b/internal/cmd/config.go
@@ -61,6 +61,13 @@ type configuration struct {
// ConnectivityCheck is the connectivity check configuration.
ConnectivityCheck *connCheckConfig `yaml:"connectivity_check"`
+ // InterfaceListeners is the configuration for the network interface
+ // listeners and their common parameters.
+ InterfaceListeners *interfaceListenersConfig `yaml:"interface_listeners"`
+
+ // Network is the configuration for network listeners.
+ Network *network `yaml:"network"`
+
// AdditionalMetricsInfo is extra information, which is exposed by metrics.
AdditionalMetricsInfo additionalInfo `yaml:"additional_metrics_info"`
@@ -140,6 +147,12 @@ func (c *configuration) validate() (err error) {
}, {
validate: c.ConnectivityCheck.validate,
name: "connectivity_check",
+ }, {
+ validate: c.InterfaceListeners.validate,
+ name: "interface_listeners",
+ }, {
+ validate: c.Network.validate,
+ name: "network",
}, {
validate: c.AdditionalMetricsInfo.validate,
name: "additional_metrics_info",
diff --git a/internal/cmd/env.go b/internal/cmd/env.go
index 4a51188..fd8d26c 100644
--- a/internal/cmd/env.go
+++ b/internal/cmd/env.go
@@ -34,8 +34,8 @@ type environments struct {
YoutubeSafeSearchURL *agdhttp.URL `env:"YOUTUBE_SAFE_SEARCH_URL,notEmpty"`
RuleStatURL *agdhttp.URL `env:"RULESTAT_URL"`
- ConfPath string `env:"CONFIG_PATH" envDefault:"./config.yml"`
- DNSDBPath string `env:"DNSDB_PATH" envDefault:"./dnsdb.bolt"`
+ ConfPath string `env:"CONFIG_PATH" envDefault:"./config.yaml"`
+ DNSDBPath string `env:"DNSDB_PATH"`
FilterCachePath string `env:"FILTER_CACHE_PATH" envDefault:"./filters/"`
ProfilesCachePath string `env:"PROFILES_CACHE_PATH" envDefault:"./profilecache.json"`
GeoIPASNPath string `env:"GEOIP_ASN_PATH" envDefault:"./asn.mmdb"`
diff --git a/internal/cmd/filter.go b/internal/cmd/filter.go
index fca8414..7e8511c 100644
--- a/internal/cmd/filter.go
+++ b/internal/cmd/filter.go
@@ -8,6 +8,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil"
)
@@ -51,8 +52,8 @@ func (c *filtersConfig) toInternal(
errColl agd.ErrorCollector,
resolver agdnet.Resolver,
envs *environments,
- safeBrowsing *filter.HashPrefix,
- adultBlocking *filter.HashPrefix,
+ safeBrowsing *hashprefix.Filter,
+ adultBlocking *hashprefix.Filter,
) (conf *filter.DefaultStorageConfig) {
return &filter.DefaultStorageConfig{
FilterIndexURL: netutil.CloneURL(&envs.FilterIndexURL.URL),
diff --git a/internal/cmd/ifacelistener.go b/internal/cmd/ifacelistener.go
new file mode 100644
index 0000000..092115b
--- /dev/null
+++ b/internal/cmd/ifacelistener.go
@@ -0,0 +1,102 @@
+package cmd
+
+import (
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdmaps"
+ "github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice"
+ "github.com/AdguardTeam/golibs/errors"
+)
+
+// Network interface listener configuration
+
+// interfaceListenersConfig contains the optional configuration for the network
+// interface listeners and their common parameters.
+type interfaceListenersConfig struct {
+ // List is the ID-to-configuration mapping of network interface listeners.
+ List map[bindtodevice.ID]*interfaceListener `yaml:"list"`
+
+ // ChannelBufferSize is the size of the buffers of the channels used to
+ // dispatch TCP connections and UDP sessions.
+ ChannelBufferSize int `yaml:"channel_buffer_size"`
+}
+
+// toInternal converts c to a bindtodevice.Manager. c is assumed to be valid.
+func (c *interfaceListenersConfig) toInternal(
+ errColl agd.ErrorCollector,
+ ctrlConf *bindtodevice.ControlConfig,
+) (m *bindtodevice.Manager, err error) {
+ if c == nil {
+ return nil, nil
+ }
+
+ m = bindtodevice.NewManager(&bindtodevice.ManagerConfig{
+ InterfaceStorage: bindtodevice.DefaultInterfaceStorage{},
+ ErrColl: errColl,
+ ChannelBufferSize: c.ChannelBufferSize,
+ })
+
+ err = agdmaps.OrderedRangeError(
+ c.List,
+ func(id bindtodevice.ID, l *interfaceListener) (addErr error) {
+ return errors.Annotate(m.Add(id, l.Interface, l.Port, ctrlConf), "adding listener %q: %w", id)
+ },
+ )
+
+ if err != nil {
+ return nil, err
+ }
+
+ return m, nil
+}
+
+// validate returns an error if the network interface listeners configuration is
+// invalid.
+func (c *interfaceListenersConfig) validate() (err error) {
+ switch {
+ case c == nil:
+ // This configuration is optional.
+ //
+ // TODO(a.garipov): Consider making required or not relying on nil
+ // values.
+ return nil
+ case c.ChannelBufferSize <= 0:
+ return newMustBePositiveError("channel_buffer_size", c.ChannelBufferSize)
+ case len(c.List) == 0:
+ return errors.Error("no list")
+ default:
+ // Go on.
+ }
+
+ err = agdmaps.OrderedRangeError(
+ c.List,
+ func(id bindtodevice.ID, l *interfaceListener) (lsnrErr error) {
+ return errors.Annotate(l.validate(), "interface %q: %w", id)
+ },
+ )
+
+ return err
+}
+
+// interfaceListener contains configuration for a single network interface
+// listener.
+type interfaceListener struct {
+ // Interface is the name of the network interface in the system.
+ Interface string `yaml:"interface"`
+
+ // Port is the port number on which to listen for incoming connections.
+ Port uint16 `yaml:"port"`
+}
+
+// validate returns an error if the interface listener configuration is invalid.
+func (l *interfaceListener) validate() (err error) {
+ switch {
+ case l == nil:
+ return errNilConfig
+ case l.Port == 0:
+ return errors.Error("port must not be zero")
+ case l.Interface == "":
+ return errors.Error("no interface")
+ default:
+ return nil
+ }
+}
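
Together with the `server.go` changes below, these types wire the `interface_listeners` object into a `bindtodevice.Manager`. A rough sketch of that wiring, using only the calls visible in this patch; the listener ID, interface name, port, and subnet are illustrative values taken from the changelog example, and error handling is trimmed:

```go
package main

import (
	"net/netip"

	"github.com/AdguardTeam/AdGuardDNS/internal/agd"
	"github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice"
	"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
)

// newEth0ListenConfig mirrors interfaceListenersConfig.toInternal and
// server.bindData: it registers one interface listener and resolves it into a
// netext.ListenConfig for a bind_interfaces entry. errColl is assumed to be
// provided by the caller.
func newEth0ListenConfig(errColl agd.ErrorCollector) (lc netext.ListenConfig, err error) {
	btdMgr := bindtodevice.NewManager(&bindtodevice.ManagerConfig{
		InterfaceStorage:  bindtodevice.DefaultInterfaceStorage{},
		ErrColl:           errColl,
		ChannelBufferSize: 1000,
	})

	// Zero buffer sizes keep the system defaults, as with the network object.
	ctrlConf := &bindtodevice.ControlConfig{}

	// "eth0_plain_dns" on eth0:53 matches the changelog example.
	err = btdMgr.Add("eth0_plain_dns", "eth0", 53, ctrlConf)
	if err != nil {
		return nil, err
	}

	// The manager must also be started; see btdMgr.Start in cmd.go above.
	return btdMgr.ListenConfig("eth0_plain_dns", netip.MustParsePrefix("127.0.0.0/8"))
}
```
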
diff --git a/internal/cmd/network.go b/internal/cmd/network.go
new file mode 100644
index 0000000..232784a
--- /dev/null
+++ b/internal/cmd/network.go
@@ -0,0 +1,51 @@
+package cmd
+
+import (
+ "github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
+)
+
+// network defines the network settings.
+//
+// TODO(a.garipov): Use [datasize.ByteSize] for sizes.
+type network struct {
+ // SndBufSize defines the size of socket send buffer in bytes. Default is
+ // zero (uses system settings).
+ SndBufSize int `yaml:"so_sndbuf"`
+
+ // RcvBufSize defines the size of socket receive buffer in bytes. Default
+ // is zero (uses system settings).
+ RcvBufSize int `yaml:"so_rcvbuf"`
+}
+
+// validate returns an error if the network configuration is invalid.
+func (n *network) validate() (err error) {
+ if n == nil {
+ return errNilConfig
+ }
+
+ if n.SndBufSize < 0 {
+ return newMustBeNonNegativeError("so_sndbuf", n.SndBufSize)
+ }
+
+ if n.RcvBufSize < 0 {
+ return newMustBeNonNegativeError("so_rcvbuf", n.RcvBufSize)
+ }
+
+ return nil
+}
+
+// toInternal converts n to the bindtodevice control configuration and network
+// extension control configuration.
+func (n *network) toInternal() (bc *bindtodevice.ControlConfig, nc *netext.ControlConfig) {
+ bc = &bindtodevice.ControlConfig{
+ SndBufSize: n.SndBufSize,
+ RcvBufSize: n.RcvBufSize,
+ }
+ nc = &netext.ControlConfig{
+ SndBufSize: n.SndBufSize,
+ RcvBufSize: n.RcvBufSize,
+ }
+
+ return bc, nc
+}
diff --git a/internal/cmd/ratelimit.go b/internal/cmd/ratelimit.go
index e4834b0..ccc107a 100644
--- a/internal/cmd/ratelimit.go
+++ b/internal/cmd/ratelimit.go
@@ -6,8 +6,11 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
+ "github.com/AdguardTeam/AdGuardDNS/internal/connlimiter"
"github.com/AdguardTeam/AdGuardDNS/internal/consul"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
+ "github.com/AdguardTeam/AdGuardDNS/internal/metrics"
+ "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/c2h5oh/datasize"
)
@@ -19,6 +22,10 @@ type rateLimitConfig struct {
// AllowList is the allowlist of clients.
Allowlist *allowListConfig `yaml:"allowlist"`
+ // ConnectionLimit is the configuration for the limits on stream
+ // connections.
+ ConnectionLimit *connLimitConfig `yaml:"connection_limit"`
+
// Rate limit options for IPv4 addresses.
IPv4 *rateLimitOptions `yaml:"ipv4"`
@@ -97,12 +104,18 @@ func (c *rateLimitConfig) toInternal(al ratelimit.Allowlist) (conf *ratelimit.Ba
// validate returns an error if the safe rate limiting configuration is invalid.
func (c *rateLimitConfig) validate() (err error) {
- if c == nil {
+ switch {
+ case c == nil:
return errNilConfig
- } else if c.Allowlist == nil {
+ case c.Allowlist == nil:
return fmt.Errorf("allowlist: %w", errNilConfig)
}
+ err = c.ConnectionLimit.validate()
+ if err != nil {
+ return fmt.Errorf("connection_limit: %w", err)
+ }
+
err = c.IPv4.validate()
if err != nil {
return fmt.Errorf("ipv4: %w", err)
@@ -129,16 +142,16 @@ func setupRateLimiter(
consulAllowlist *url.URL,
sigHdlr signalHandler,
errColl agd.ErrorCollector,
-) (rateLimiter *ratelimit.BackOff, err error) {
+) (rateLimiter *ratelimit.BackOff, connLimiter *connlimiter.Limiter, err error) {
allowSubnets, err := agdnet.ParseSubnets(conf.Allowlist.List...)
if err != nil {
- return nil, fmt.Errorf("parsing allowlist subnets: %w", err)
+ return nil, nil, fmt.Errorf("parsing allowlist subnets: %w", err)
}
allowlist := ratelimit.NewDynamicAllowlist(allowSubnets, nil)
refresher, err := consul.NewAllowlistRefresher(allowlist, consulAllowlist)
if err != nil {
- return nil, fmt.Errorf("creating allowlist refresher: %w", err)
+ return nil, nil, fmt.Errorf("creating allowlist refresher: %w", err)
}
refr := agd.NewRefreshWorker(&agd.RefreshWorkerConfig{
@@ -152,10 +165,67 @@ func setupRateLimiter(
})
err = refr.Start()
if err != nil {
- return nil, fmt.Errorf("starting allowlist refresher: %w", err)
+ return nil, nil, fmt.Errorf("starting allowlist refresher: %w", err)
}
sigHdlr.add(refr)
- return ratelimit.NewBackOff(conf.toInternal(allowlist)), nil
+ return ratelimit.NewBackOff(conf.toInternal(allowlist)), conf.ConnectionLimit.toInternal(), nil
+}
+
+// connLimitConfig is the configuration structure for the stream-connection
+// limiter.
+type connLimitConfig struct {
+ // Stop is the point at which the limiter stops accepting new connections.
+ // Once the number of active connections reaches this limit, new connections
+ // wait for the number to decrease to or below Resume.
+ //
+ // Stop must be greater than zero and greater than or equal to Resume.
+ Stop uint64 `yaml:"stop"`
+
+ // Resume is the point at which the limiter starts accepting new connections
+ // again.
+ //
+ // Resume must be greater than zero and less than or equal to Stop.
+ Resume uint64 `yaml:"resume"`
+
+ // Enabled, if true, enables stream-connection limiting.
+ Enabled bool `yaml:"enabled"`
+}
+
+// toInternal converts c to the connection limiter to use. c is assumed to be
+// valid.
+func (c *connLimitConfig) toInternal() (l *connlimiter.Limiter) {
+ if !c.Enabled {
+ return nil
+ }
+
+ l, err := connlimiter.New(&connlimiter.Config{
+ Stop: c.Stop,
+ Resume: c.Resume,
+ })
+ if err != nil {
+ panic(err)
+ }
+
+ metrics.ConnLimiterLimits.WithLabelValues("stop").Set(float64(c.Stop))
+ metrics.ConnLimiterLimits.WithLabelValues("resume").Set(float64(c.Resume))
+
+ return l
+}
+
+// validate returns an error if the connection limit configuration is invalid.
+func (c *connLimitConfig) validate() (err error) {
+ switch {
+ case c == nil:
+ return errNilConfig
+ case !c.Enabled:
+ return nil
+ case c.Stop == 0:
+ return newMustBePositiveError("stop", c.Stop)
+ case c.Resume > c.Stop:
+ return errors.Error("resume: must be less than or equal to stop")
+ default:
+ return nil
+ }
}
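
For reference, the `stop`/`resume` pair maps directly onto `connlimiter.Config`: accepting stops once the number of active stream-connections reaches `Stop` and resumes when it is back at or below `Resume`. A short sketch of a valid and an invalid configuration, using only the exported `connlimiter` API added in this patch:

```go
package main

import (
	"fmt"

	"github.com/AdguardTeam/AdGuardDNS/internal/connlimiter"
)

func main() {
	// Valid: accept up to 1000 connections, resume accepting once the count
	// is back at 800 or fewer, as in the changelog example.
	l, err := connlimiter.New(&connlimiter.Config{
		Stop:   1000,
		Resume: 800,
	})
	fmt.Println(l != nil, err)

	// Invalid: Resume greater than Stop is rejected, which is also what
	// connLimitConfig.validate checks for.
	_, err = connlimiter.New(&connlimiter.Config{
		Stop:   1,
		Resume: 2,
	})
	fmt.Println(err)
}
```
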
diff --git a/internal/cmd/safebrowsing.go b/internal/cmd/safebrowsing.go
index e8cc36a..55f9503 100644
--- a/internal/cmd/safebrowsing.go
+++ b/internal/cmd/safebrowsing.go
@@ -7,8 +7,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
- "github.com/AdguardTeam/AdGuardDNS/internal/filter"
- "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil"
@@ -45,13 +44,13 @@ func (c *safeBrowsingConfig) toInternal(
resolver agdnet.Resolver,
id agd.FilterListID,
cacheDir string,
-) (fltConf *filter.HashPrefixConfig, err error) {
- hashes, err := hashstorage.New("")
+) (fltConf *hashprefix.FilterConfig, err error) {
+ hashes, err := hashprefix.NewStorage("")
if err != nil {
return nil, err
}
- return &filter.HashPrefixConfig{
+ return &hashprefix.FilterConfig{
Hashes: hashes,
URL: netutil.CloneURL(&c.URL.URL),
ErrColl: errColl,
@@ -95,13 +94,13 @@ func setupHashPrefixFilter(
cachePath string,
sigHdlr signalHandler,
errColl agd.ErrorCollector,
-) (strg *hashstorage.Storage, flt *filter.HashPrefix, err error) {
+) (strg *hashprefix.Storage, flt *hashprefix.Filter, err error) {
fltConf, err := conf.toInternal(errColl, resolver, id, cachePath)
if err != nil {
return nil, nil, fmt.Errorf("configuring hash prefix filter %s: %w", id, err)
}
- flt, err = filter.NewHashPrefix(fltConf)
+ flt, err = hashprefix.NewFilter(fltConf)
if err != nil {
return nil, nil, fmt.Errorf("creating hash prefix filter %s: %w", id, err)
}
diff --git a/internal/cmd/server.go b/internal/cmd/server.go
index 56c64f7..1aea133 100644
--- a/internal/cmd/server.go
+++ b/internal/cmd/server.go
@@ -5,6 +5,8 @@ import (
"net/netip"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/stringutil"
@@ -14,10 +16,18 @@ import (
// toInternal returns the configuration of DNS servers for a single server
// group. srvs is assumed to be valid.
-func (srvs servers) toInternal(tlsConfig *agd.TLS) (dnsSrvs []*agd.Server, err error) {
+func (srvs servers) toInternal(
+ tlsConfig *agd.TLS,
+ btdMgr *bindtodevice.Manager,
+) (dnsSrvs []*agd.Server, err error) {
dnsSrvs = make([]*agd.Server, 0, len(srvs))
for _, srv := range srvs {
- bindData := srv.bindData()
+ var bindData []*agd.ServerBindData
+ bindData, err = srv.bindData(btdMgr)
+ if err != nil {
+ return nil, fmt.Errorf("server %q: %w", srv.Name, err)
+ }
+
name := agd.ServerName(srv.Name)
switch p := srv.Protocol; p {
case srvProtoDNS:
@@ -158,27 +168,56 @@ type server struct {
// Protocol is the protocol of the server.
Protocol serverProto `yaml:"protocol"`
- // BindAddresses are addresses this server binds to.
+ // BindAddresses are addresses this server binds to. If BindAddresses is
+ // set, BindInterfaces must not be set.
BindAddresses []netip.AddrPort `yaml:"bind_addresses"`
+ // BindInterfaces are network interface data for this server to bind to. If
+ // BindInterfaces is set, BindAddresses must not be set.
+ BindInterfaces []*serverBindInterface `yaml:"bind_interfaces"`
+
// LinkedIPEnabled shows if the linked IP addresses should be used to detect
// profiles on this server.
LinkedIPEnabled bool `yaml:"linked_ip_enabled"`
}
// bindData returns the socket binding data for this server.
-func (s *server) bindData() (bindData []*agd.ServerBindData) {
- addrs := s.BindAddresses
- bindData = make([]*agd.ServerBindData, 0, len(addrs))
- for _, addr := range addrs {
+func (s *server) bindData(
+ btdMgr *bindtodevice.Manager,
+) (bindData []*agd.ServerBindData, err error) {
+ if addrs := s.BindAddresses; len(addrs) > 0 {
+ bindData = make([]*agd.ServerBindData, 0, len(addrs))
+ for _, addr := range addrs {
+ bindData = append(bindData, &agd.ServerBindData{
+ AddrPort: addr,
+ })
+ }
+
+ return bindData, nil
+ }
+
+ if btdMgr == nil {
+ err = errors.Error("bind_interfaces are only supported when interface_listeners are set")
+
+ return nil, err
+ }
+
+ ifaces := s.BindInterfaces
+ bindData = make([]*agd.ServerBindData, 0, len(ifaces))
+ for i, iface := range s.BindInterfaces {
+ var lc netext.ListenConfig
+ lc, err = btdMgr.ListenConfig(iface.ID, iface.Subnet)
+ if err != nil {
+ return nil, fmt.Errorf("bind_interface at index %d: %w", i, err)
+ }
+
bindData = append(bindData, &agd.ServerBindData{
- AddrPort: addr,
+ ListenConfig: lc,
+ Address: string(iface.ID),
})
}
- // TODO(a.garipov): Support bind_interfaces.
-
- return bindData
+ return bindData, nil
}
// validate returns an error if the configuration is invalid.
@@ -188,13 +227,12 @@ func (s *server) validate() (err error) {
return errNilConfig
case s.Name == "":
return errors.Error("no name")
- case len(s.BindAddresses) == 0:
- return errors.Error("no bind_addresses")
}
- err = validateAddrs(s.BindAddresses)
+ err = s.validateBindData()
if err != nil {
- return fmt.Errorf("bind_addresses: %w", err)
+ // Don't wrap the error, because it's informative enough as is.
+ return err
}
err = s.Protocol.validate()
@@ -209,3 +247,62 @@ func (s *server) validate() (err error) {
return nil
}
+
+// validateBindData returns an error if the server's binding data aren't valid.
+func (s *server) validateBindData() (err error) {
+ bindAddrsSet, bindIfacesSet := len(s.BindAddresses) > 0, len(s.BindInterfaces) > 0
+ if bindAddrsSet {
+ if bindIfacesSet {
+ return errors.Error("bind_addresses and bind_interfaces cannot both be set")
+ }
+
+ err = validateAddrs(s.BindAddresses)
+ if err != nil {
+ return fmt.Errorf("bind_addresses: %w", err)
+ }
+
+ return nil
+ }
+
+ if !bindIfacesSet {
+ return errors.Error("neither bind_addresses nor bind_interfaces is set")
+ }
+
+ if s.Protocol != srvProtoDNS {
+ return fmt.Errorf(
+ "bind_interfaces: only supported for protocol %q, got %q",
+ srvProtoDNS,
+ s.Protocol,
+ )
+ }
+
+ for i, bindIface := range s.BindInterfaces {
+ err = bindIface.validate()
+ if err != nil {
+ return fmt.Errorf("bind_interfaces: at index %d: %w", i, err)
+ }
+ }
+
+ return nil
+}
+
+// serverBindInterface contains the data for a network interface binding.
+type serverBindInterface struct {
+ ID bindtodevice.ID `yaml:"id"`
+ Subnet netip.Prefix `yaml:"subnet"`
+}
+
+// validate returns an error if the network interface binding configuration is
+// invalid.
+func (c *serverBindInterface) validate() (err error) {
+ switch {
+ case c == nil:
+ return errNilConfig
+ case c.ID == "":
+ return errors.Error("no id")
+ case !c.Subnet.IsValid():
+ return errors.Error("bad subnet")
+ default:
+ return nil
+ }
+}
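
The two binding modes produce `agd.ServerBindData` values of different shapes: plain `bind_addresses` entries set only the address-port pair, while `bind_interfaces` entries carry a prepared listen config from the bindtodevice manager plus the listener ID as the address string. A sketch of both shapes, assuming only the fields visible in this diff; the concrete values are illustrative:

```go
package main

import (
	"net/netip"

	"github.com/AdguardTeam/AdGuardDNS/internal/agd"
	"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
)

// bindDataExamples shows the two mutually exclusive shapes of ServerBindData
// produced by server.bindData.
func bindDataExamples(lc netext.ListenConfig) (byAddr, byIface *agd.ServerBindData) {
	// A bind_addresses entry.
	byAddr = &agd.ServerBindData{
		AddrPort: netip.MustParseAddrPort("127.0.0.1:53"),
	}

	// A bind_interfaces entry: lc comes from bindtodevice.Manager.ListenConfig.
	byIface = &agd.ServerBindData{
		ListenConfig: lc,
		Address:      "eth0_plain_dns",
	}

	return byAddr, byIface
}
```
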
diff --git a/internal/cmd/servergroup.go b/internal/cmd/servergroup.go
index aee2e32..70b35d0 100644
--- a/internal/cmd/servergroup.go
+++ b/internal/cmd/servergroup.go
@@ -4,6 +4,7 @@ import (
"fmt"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/stringutil"
@@ -19,6 +20,7 @@ type serverGroups []*serverGroup
// service. srvGrps is assumed to be valid.
func (srvGrps serverGroups) toInternal(
messages *dnsmsg.Constructor,
+ btdMgr *bindtodevice.Manager,
fltGrps map[agd.FilteringGroupID]*agd.FilteringGroup,
) (svcSrvGrps []*agd.ServerGroup, err error) {
svcSrvGrps = make([]*agd.ServerGroup, len(srvGrps))
@@ -42,7 +44,7 @@ func (srvGrps serverGroups) toInternal(
FilteringGroup: fltGrpID,
}
- svcSrvGrps[i].Servers, err = g.Servers.toInternal(tlsConf)
+ svcSrvGrps[i].Servers, err = g.Servers.toInternal(tlsConf, btdMgr)
if err != nil {
return nil, fmt.Errorf("server group %q: %w", g.Name, err)
}
diff --git a/internal/cmd/upstream.go b/internal/cmd/upstream.go
index a1171eb..9380031 100644
--- a/internal/cmd/upstream.go
+++ b/internal/cmd/upstream.go
@@ -6,9 +6,11 @@ import (
"net/netip"
"net/url"
"strings"
+ "time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/forward"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/prometheus"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/timeutil"
)
@@ -33,18 +35,32 @@ type upstreamConfig struct {
}
// toInternal converts c to the data storage configuration for the DNS server.
-func (c *upstreamConfig) toInternal() (conf *agd.Upstream, err error) {
- net, addrPort, err := splitUpstreamURL(c.Server)
+func (c *upstreamConfig) toInternal() (fwdConf *forward.HandlerConfig, err error) {
+ network, addrPort, err := splitUpstreamURL(c.Server)
if err != nil {
return nil, err
}
- return &agd.Upstream{
- Server: addrPort,
- Network: net,
- FallbackServers: c.FallbackServers,
- Timeout: c.Timeout.Duration,
- }, nil
+ fallbacks := c.FallbackServers
+ metricsListener := prometheus.NewForwardMetricsListener(len(fallbacks) + 1)
+
+ var hcInit time.Duration
+ if c.Healthcheck.Enabled {
+ hcInit = c.Healthcheck.Timeout.Duration
+ }
+
+ fwdConf = &forward.HandlerConfig{
+ Address: addrPort,
+ Network: network,
+ MetricsListener: metricsListener,
+ HealthcheckDomainTmpl: c.Healthcheck.DomainTmpl,
+ FallbackAddresses: c.FallbackServers,
+ Timeout: c.Timeout.Duration,
+ HealthcheckBackoffDuration: c.Healthcheck.BackoffDuration.Duration,
+ HealthcheckInitDuration: hcInit,
+ }
+
+ return fwdConf, nil
}
// validate returns an error if the upstream configuration is invalid.
diff --git a/internal/cmd/websvc.go b/internal/cmd/websvc.go
index 5e6834f..be3e6a0 100644
--- a/internal/cmd/websvc.go
+++ b/internal/cmd/websvc.go
@@ -13,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/websvc"
"github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil"
)
@@ -413,8 +414,8 @@ func (f *staticFile) toInternal() (file *websvc.StaticFile, err error) {
// Check Content-Type here as opposed to in validate, because we need
// all keys to be canonicalized first.
- if file.Headers.Get(agdhttp.HdrNameContentType) == "" {
- return nil, errors.Error("content: " + agdhttp.HdrNameContentType + " header is required")
+ if file.Headers.Get(httphdr.ContentType) == "" {
+ return nil, errors.Error("content: " + httphdr.ContentType + " header is required")
}
return file, nil
diff --git a/internal/connlimiter/conn.go b/internal/connlimiter/conn.go
new file mode 100644
index 0000000..eb61a9a
--- /dev/null
+++ b/internal/connlimiter/conn.go
@@ -0,0 +1,51 @@
+package connlimiter
+
+import (
+ "net"
+ "sync/atomic"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
+ "github.com/AdguardTeam/AdGuardDNS/internal/metrics"
+ "github.com/AdguardTeam/AdGuardDNS/internal/optlog"
+ "github.com/AdguardTeam/golibs/errors"
+)
+
+// limitConn is a wrapper for a stream connection that decreases the counter
+// value on close.
+//
+// See https://pkg.go.dev/golang.org/x/net/netutil#LimitListener.
+type limitConn struct {
+ net.Conn
+
+ decrement func()
+ start time.Time
+ serverInfo dnsserver.ServerInfo
+ isClosed atomic.Bool
+}
+
+// Close closes the underlying connection and decrements the counter.
+func (c *limitConn) Close() (err error) {
+ defer func() { err = errors.Annotate(err, "limit conn: %w") }()
+
+ if !c.isClosed.CompareAndSwap(false, true) {
+ return net.ErrClosed
+ }
+
+ // Close the connection immediately; decrement the counter and update the
+ // metrics later.
+ err = c.Conn.Close()
+
+ connLife := time.Since(c.start).Seconds()
+ name := c.serverInfo.Name
+ optlog.Debug3("connlimiter: %s: closed conn from %s after %fs", name, c.RemoteAddr(), connLife)
+ metrics.StreamConnLifeDuration.WithLabelValues(
+ name,
+ c.serverInfo.Proto.String(),
+ c.serverInfo.Addr,
+ ).Observe(connLife)
+
+ c.decrement()
+
+ return err
+}
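
The `CompareAndSwap` guard is what turns a second `Close` into `net.ErrClosed` instead of a double counter decrement. The same close-once idiom in isolation, standard library only:

```go
package main

import (
	"fmt"
	"net"
	"sync/atomic"
)

// closeOnce runs its close function only on the first call; later calls
// report net.ErrClosed, just like limitConn.Close above.
type closeOnce struct {
	isClosed atomic.Bool
	close    func() error
}

func (c *closeOnce) Close() (err error) {
	if !c.isClosed.CompareAndSwap(false, true) {
		return net.ErrClosed
	}

	return c.close()
}

func main() {
	c := &closeOnce{close: func() error { return nil }}

	fmt.Println(c.Close()) // <nil>
	fmt.Println(c.Close()) // use of closed network connection
}
```
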
diff --git a/internal/connlimiter/counter.go b/internal/connlimiter/counter.go
new file mode 100644
index 0000000..e91c66a
--- /dev/null
+++ b/internal/connlimiter/counter.go
@@ -0,0 +1,35 @@
+package connlimiter
+
+// counter is the simultaneous stream-connection counter. It stops accepting
+// new connections once it reaches stop and resumes when the number of active
+// connections goes back to resume.
+//
+// Note that current counts both active stream-connections and goroutines that
+// are currently in the process of accepting a new connection but haven't
+// accepted one yet.
+type counter struct {
+ current uint64
+ stop uint64
+ resume uint64
+ isAccepting bool
+}
+
+// increment tries to add the connection to the current active connection count.
+// If the counter does not accept new connections, shouldAccept is false.
+func (c *counter) increment() (shouldAccept bool) {
+ if !c.isAccepting {
+ return false
+ }
+
+ c.current++
+ c.isAccepting = c.current < c.stop
+
+ return true
+}
+
+// decrement decreases the number of current active connections.
+func (c *counter) decrement() {
+ c.current--
+
+ c.isAccepting = c.isAccepting || c.current <= c.resume
+}
diff --git a/internal/connlimiter/counter_internal_test.go b/internal/connlimiter/counter_internal_test.go
new file mode 100644
index 0000000..2d63480
--- /dev/null
+++ b/internal/connlimiter/counter_internal_test.go
@@ -0,0 +1,42 @@
+package connlimiter
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestCounter(t *testing.T) {
+ t.Run("same", func(t *testing.T) {
+ c := &counter{
+ current: 0,
+ stop: 1,
+ resume: 1,
+ isAccepting: true,
+ }
+
+ assert.True(t, c.increment())
+ assert.False(t, c.increment())
+
+ c.decrement()
+ assert.True(t, c.increment())
+ assert.False(t, c.increment())
+ })
+
+ t.Run("more", func(t *testing.T) {
+ c := &counter{
+ current: 0,
+ stop: 2,
+ resume: 1,
+ isAccepting: true,
+ }
+
+ assert.True(t, c.increment())
+ assert.True(t, c.increment())
+ assert.False(t, c.increment())
+
+ c.decrement()
+ assert.True(t, c.increment())
+ assert.False(t, c.increment())
+ })
+}
diff --git a/internal/connlimiter/limiter.go b/internal/connlimiter/limiter.go
new file mode 100644
index 0000000..797345b
--- /dev/null
+++ b/internal/connlimiter/limiter.go
@@ -0,0 +1,73 @@
+package connlimiter
+
+import (
+ "fmt"
+ "net"
+ "sync"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
+ "github.com/AdguardTeam/AdGuardDNS/internal/metrics"
+)
+
+// Config is the configuration structure for the stream-connection limiter.
+type Config struct {
+ // Stop is the point at which the limiter stops accepting new connections.
+ // Once the number of active connections reaches this limit, new connections
+ // wait for the number to decrease to or below Resume.
+ //
+ // Stop must be greater than zero and greater than or equal to Resume.
+ Stop uint64
+
+ // Resume is the point at which the limiter starts accepting new connections
+ // again.
+ //
+ // Resume must be greater than zero and less than or equal to Stop.
+ Resume uint64
+}
+
+// Limiter is the stream-connection limiter.
+type Limiter struct {
+ // counterCond is the shared condition variable that protects counter.
+ counterCond *sync.Cond
+
+ // counter is the shared counter of active stream-connections.
+ counter *counter
+}
+
+// New returns a new *Limiter.
+func New(c *Config) (l *Limiter, err error) {
+ if c == nil || c.Stop == 0 || c.Resume > c.Stop {
+ return nil, fmt.Errorf("bad limiter config: %+v", c)
+ }
+
+ return &Limiter{
+ counterCond: sync.NewCond(&sync.Mutex{}),
+ counter: &counter{
+ current: 0,
+ stop: c.Stop,
+ resume: c.Resume,
+ isAccepting: true,
+ },
+ }, nil
+}
+
+// Limit wraps lsnr to control the number of active connections. srvInfo is
+// used for logging and metrics.
+func (l *Limiter) Limit(lsnr net.Listener, srvInfo dnsserver.ServerInfo) (limited net.Listener) {
+ name, addr := srvInfo.Name, srvInfo.Addr
+ proto := srvInfo.Proto.String()
+
+ return &limitListener{
+ Listener: lsnr,
+
+ counterCond: l.counterCond,
+ counter: l.counter,
+
+ serverInfo: srvInfo,
+
+ activeGauge: metrics.ConnLimiterActiveStreamConns.WithLabelValues(name, proto, addr),
+ waitingHist: metrics.StreamConnWaitDuration.WithLabelValues(name, proto, addr),
+
+ isClosed: false,
+ }
+}
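
In the DNS service the limiter wraps the listeners created by the server's listen config, but `Limit` works with any `net.Listener`. A minimal usage sketch with a plain TCP listener; the server name and address labels are illustrative, and `agd.ProtoDoT` is simply the protocol value already used in the tests below:

```go
package main

import (
	"net"

	"github.com/AdguardTeam/AdGuardDNS/internal/agd"
	"github.com/AdguardTeam/AdGuardDNS/internal/connlimiter"
	"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
)

func main() {
	l, err := connlimiter.New(&connlimiter.Config{Stop: 1000, Resume: 800})
	if err != nil {
		panic(err)
	}

	tcpLsnr, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}

	// Accept on the wrapped listener blocks once 1000 connections are active
	// and resumes when the count drops back to 800 or fewer.
	limited := l.Limit(tcpLsnr, dnsserver.ServerInfo{
		Name:  "default_dot",
		Addr:  tcpLsnr.Addr().String(),
		Proto: agd.ProtoDoT,
	})
	defer func() { _ = limited.Close() }()
}
```
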
diff --git a/internal/connlimiter/limiter_test.go b/internal/connlimiter/limiter_test.go
new file mode 100644
index 0000000..568e974
--- /dev/null
+++ b/internal/connlimiter/limiter_test.go
@@ -0,0 +1,122 @@
+package connlimiter_test
+
+import (
+ "net"
+ "testing"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/connlimiter"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
+ "github.com/AdguardTeam/golibs/netutil"
+ "github.com/AdguardTeam/golibs/testutil"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestMain(m *testing.M) {
+ testutil.DiscardLogOutput(m)
+}
+
+// testTimeout is the common timeout for tests.
+const testTimeout = 1 * time.Second
+
+// testServerInfo is the common server information for tests.
+var testServerInfo = dnsserver.ServerInfo{
+ Name: "test_server",
+ Addr: "127.0.0.1:0",
+ Proto: agd.ProtoDoT,
+}
+
+func TestLimiter(t *testing.T) {
+ l, err := connlimiter.New(&connlimiter.Config{
+ Stop: 1,
+ Resume: 1,
+ })
+ require.NoError(t, err)
+
+ conn := &agdtest.Conn{
+ OnClose: func() (err error) { return nil },
+ OnLocalAddr: func() (laddr net.Addr) { panic("not implemented") },
+ OnRead: func(b []byte) (n int, err error) { panic("not implemented") },
+ OnRemoteAddr: func() (addr net.Addr) {
+ return &net.TCPAddr{
+ IP: netutil.IPv4Localhost().AsSlice(),
+ Port: 1234,
+ }
+ },
+ OnSetDeadline: func(t time.Time) (err error) { panic("not implemented") },
+ OnSetReadDeadline: func(t time.Time) (err error) { panic("not implemented") },
+ OnSetWriteDeadline: func(t time.Time) (err error) { panic("not implemented") },
+ OnWrite: func(b []byte) (n int, err error) { panic("not implemented") },
+ }
+
+ lsnr := &agdtest.Listener{
+ OnAccept: func() (c net.Conn, err error) { return conn, nil },
+ OnAddr: func() (addr net.Addr) {
+ return &net.TCPAddr{
+ IP: netutil.IPv4Localhost().AsSlice(),
+ Port: 853,
+ }
+ },
+ OnClose: func() (err error) { return nil },
+ }
+
+ limited := l.Limit(lsnr, testServerInfo)
+
+ // Accept one connection.
+ gotConn, err := limited.Accept()
+ require.NoError(t, err)
+
+ // Try accepting another connection. This should block until gotConn is
+ // closed.
+ otherStarted, otherListened := make(chan struct{}, 1), make(chan struct{}, 1)
+ go func() {
+ pt := &testutil.PanicT{}
+
+ otherStarted <- struct{}{}
+
+ otherConn, otherErr := limited.Accept()
+ require.NoError(pt, otherErr)
+
+ otherListened <- struct{}{}
+
+ require.NoError(pt, otherConn.Close())
+ }()
+
+ // Wait for the other goroutine to start.
+ testutil.RequireReceive(t, otherStarted, testTimeout)
+
+ // Assert that the other connection hasn't been accepted.
+ var otherAccepted bool
+ select {
+ case <-otherListened:
+ otherAccepted = true
+ default:
+ otherAccepted = false
+ }
+ assert.False(t, otherAccepted)
+
+ require.NoError(t, gotConn.Close())
+
+ // Check that double close causes an error.
+ assert.ErrorIs(t, gotConn.Close(), net.ErrClosed)
+
+ testutil.RequireReceive(t, otherListened, testTimeout)
+
+ err = limited.Close()
+ require.NoError(t, err)
+
+ // Check that double close causes an error.
+ assert.ErrorIs(t, limited.Close(), net.ErrClosed)
+}
+
+func TestLimiter_badConf(t *testing.T) {
+ l, err := connlimiter.New(&connlimiter.Config{
+ Stop: 1,
+ Resume: 2,
+ })
+ assert.Nil(t, l)
+ assert.Error(t, err)
+}
diff --git a/internal/connlimiter/listenconfig.go b/internal/connlimiter/listenconfig.go
new file mode 100644
index 0000000..a30e689
--- /dev/null
+++ b/internal/connlimiter/listenconfig.go
@@ -0,0 +1,54 @@
+package connlimiter
+
+import (
+ "context"
+ "net"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
+)
+
+// type check
+var _ netext.ListenConfig = (*ListenConfig)(nil)
+
+// ListenConfig is a [netext.ListenConfig] that uses a [*Limiter] to limit the
+// number of active stream-connections.
+type ListenConfig struct {
+ listenConfig netext.ListenConfig
+ limiter *Limiter
+}
+
+// NewListenConfig returns a new netext.ListenConfig that uses l to limit the
+// number of active stream-connections.
+func NewListenConfig(c netext.ListenConfig, l *Limiter) (limited *ListenConfig) {
+ return &ListenConfig{
+ listenConfig: c,
+ limiter: l,
+ }
+}
+
+// ListenPacket implements the [netext.ListenConfig] interface for
+// *ListenConfig.
+func (c *ListenConfig) ListenPacket(
+ ctx context.Context,
+ network string,
+ address string,
+) (conn net.PacketConn, err error) {
+ return c.listenConfig.ListenPacket(ctx, network, address)
+}
+
+// Listen implements the [netext.ListenConfig] interface for *ListenConfig.
+// Listen returns a net.Listener wrapped by c's limiter. ctx must contain a
+// [dnsserver.ServerInfo].
+func (c *ListenConfig) Listen(
+ ctx context.Context,
+ network string,
+ address string,
+) (l net.Listener, err error) {
+ l, err = c.listenConfig.Listen(ctx, network, address)
+ if err != nil {
+ return nil, err
+ }
+
+ return c.limiter.Limit(l, dnsserver.MustServerInfoFromContext(ctx)), nil
+}
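
Note the context requirement: `Listen` calls `dnsserver.MustServerInfoFromContext`, so the caller must attach a `ServerInfo` to the context first. A short sketch of the expected call shape, reusing the helpers that the test file below also uses; the name and address values are made up:

```go
package connlimiter_example

import (
	"context"
	"net"

	"github.com/AdguardTeam/AdGuardDNS/internal/agd"
	"github.com/AdguardTeam/AdGuardDNS/internal/connlimiter"
	"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
)

// listenWithLimit attaches the server information to the context before
// calling Listen; without it, MustServerInfoFromContext panics.
func listenWithLimit(lc *connlimiter.ListenConfig) (lsnr net.Listener, err error) {
	ctx := dnsserver.ContextWithServerInfo(context.Background(), dnsserver.ServerInfo{
		Name:  "default_dot",
		Addr:  "127.0.0.1:853",
		Proto: agd.ProtoDoT,
	})

	return lc.Listen(ctx, "tcp", "127.0.0.1:853")
}
```
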
diff --git a/internal/connlimiter/listenconfig_test.go b/internal/connlimiter/listenconfig_test.go
new file mode 100644
index 0000000..58a2e4e
--- /dev/null
+++ b/internal/connlimiter/listenconfig_test.go
@@ -0,0 +1,75 @@
+package connlimiter_test
+
+import (
+ "context"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/connlimiter"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestListenConfig(t *testing.T) {
+ pc := &agdtest.PacketConn{
+ OnClose: func() (err error) { panic("not implemented") },
+ OnLocalAddr: func() (laddr net.Addr) { panic("not implemented") },
+ OnReadFrom: func(b []byte) (n int, addr net.Addr, err error) { panic("not implemented") },
+ OnSetDeadline: func(t time.Time) (err error) { panic("not implemented") },
+ OnSetReadDeadline: func(t time.Time) (err error) { panic("not implemented") },
+ OnSetWriteDeadline: func(t time.Time) (err error) { panic("not implemented") },
+ OnWriteTo: func(b []byte, addr net.Addr) (n int, err error) { panic("not implemented") },
+ }
+
+ lsnr := &agdtest.Listener{
+ OnAccept: func() (c net.Conn, err error) { panic("not implemented") },
+ OnAddr: func() (addr net.Addr) { panic("not implemented") },
+ OnClose: func() (err error) { return nil },
+ }
+
+ c := &agdtest.ListenConfig{
+ OnListen: func(
+ ctx context.Context,
+ network string,
+ address string,
+ ) (l net.Listener, err error) {
+ return lsnr, nil
+ },
+ OnListenPacket: func(
+ ctx context.Context,
+ network string,
+ address string,
+ ) (conn net.PacketConn, err error) {
+ return pc, nil
+ },
+ }
+
+ l, err := connlimiter.New(&connlimiter.Config{
+ Stop: 1,
+ Resume: 1,
+ })
+ require.NoError(t, err)
+
+ limited := connlimiter.NewListenConfig(c, l)
+
+ ctx := dnsserver.ContextWithServerInfo(context.Background(), testServerInfo)
+ gotLsnr, err := limited.Listen(ctx, "", "")
+ require.NoError(t, err)
+
+ // TODO(a.garipov): Add more testing logic here if [Limiter] becomes
+ // unexported.
+ assert.NotEqual(t, lsnr, gotLsnr)
+
+ err = gotLsnr.Close()
+ require.NoError(t, err)
+
+ gotPC, err := limited.ListenPacket(ctx, "", "")
+ require.NoError(t, err)
+
+ // TODO(a.garipov): Add more testing logic here if [Limiter] becomes
+ // unexported.
+ assert.Equal(t, pc, gotPC)
+}
diff --git a/internal/connlimiter/listener.go b/internal/connlimiter/listener.go
new file mode 100644
index 0000000..59d1f97
--- /dev/null
+++ b/internal/connlimiter/listener.go
@@ -0,0 +1,139 @@
+// Package connlimiter describes a limiter of the number of active
+// stream-connections.
+package connlimiter
+
+import (
+ "net"
+ "sync"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
+ "github.com/AdguardTeam/AdGuardDNS/internal/optlog"
+ "github.com/AdguardTeam/golibs/errors"
+ "github.com/prometheus/client_golang/prometheus"
+)
+
+// limitListener is a wrapper that uses a counter to limit the number of active
+// stream-connections.
+//
+// See https://pkg.go.dev/golang.org/x/net/netutil#LimitListener.
+type limitListener struct {
+ net.Listener
+
+ // counterCond is the condition variable that protects counter and isClosed
+ // through its locker, as well as signals when connections can be accepted
+ // again or when the listener has been closed.
+ counterCond *sync.Cond
+
+ // counter is the shared counter for all listeners.
+ counter *counter
+
+ // activeGauge is the metrics gauge of currently active stream-connections.
+ activeGauge prometheus.Gauge
+
+ // waitingHist is the metrics histogram of how long a connection spends
+ // waiting for an accept.
+ waitingHist prometheus.Observer
+
+ // serverInfo is used for logging and metrics in both the listener itself
+ // and in its conns.
+ serverInfo dnsserver.ServerInfo
+
+ // isClosed shows whether this listener has been closed.
+ isClosed bool
+}
+
+// Accept returns a new connection if the counter allows it. Otherwise, it
+// waits until the counter allows it or the listener is closed.
+func (l *limitListener) Accept() (conn net.Conn, err error) {
+ defer func() { err = errors.Annotate(err, "limit listener: %w") }()
+
+ waitStart := time.Now()
+
+ isClosed := l.increment()
+ if isClosed {
+ return nil, net.ErrClosed
+ }
+
+ l.waitingHist.Observe(time.Since(waitStart).Seconds())
+ l.activeGauge.Inc()
+
+ conn, err = l.Listener.Accept()
+ if err != nil {
+ l.decrement()
+
+ return nil, err
+ }
+
+ return &limitConn{
+ Conn: conn,
+
+ decrement: l.decrement,
+ start: time.Now(),
+ serverInfo: l.serverInfo,
+ }, nil
+}
+
+// increment waits until it can increase the number of active connections
+// in the counter. If the listener is closed while waiting, increment exits and
+// returns true.
+func (l *limitListener) increment() (isClosed bool) {
+ l.counterCond.L.Lock()
+ defer l.counterCond.L.Unlock()
+
+ // Make sure to check both that the counter allows this connection and that
+ // the listener hasn't been closed. Only log about waiting for an increment
+ // when such waiting actually took place.
+ waited := false
+ for !l.counter.increment() && !l.isClosed {
+ if !waited {
+ optlog.Debug1("connlimiter: server %s: accept waiting", l.serverInfo.Name)
+
+ waited = true
+ }
+
+ l.counterCond.Wait()
+ }
+
+ if waited {
+ optlog.Debug1("connlimiter: server %s: accept stopped waiting", l.serverInfo.Name)
+ }
+
+ return l.isClosed
+}
+
+// decrement decreases the number of active connections in the counter and
+// broadcasts the change.
+func (l *limitListener) decrement() {
+ l.counterCond.L.Lock()
+ defer l.counterCond.L.Unlock()
+
+ l.activeGauge.Dec()
+
+ l.counter.decrement()
+
+ l.counterCond.Signal()
+}
+
+// Close closes the underlying listener and signals to all goroutines waiting
+// for an accept that the listener is closed now.
+func (l *limitListener) Close() (err error) {
+ defer func() { err = errors.Annotate(err, "limit listener: %w") }()
+
+ l.counterCond.L.Lock()
+ defer l.counterCond.L.Unlock()
+
+ if l.isClosed {
+ return net.ErrClosed
+ }
+
+ // Close the listener immediately; change the boolean and broadcast the
+ // change later.
+ err = l.Listener.Close()
+
+ l.isClosed = true
+
+ l.counterCond.Broadcast()
+
+ return err
+}
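
`limitListener` pairs the shared counter with a `sync.Cond`: `Accept` waits on the condition variable until either capacity frees up or `Close` broadcasts, which is why the counter check and the `isClosed` check sit inside the same loop. A stripped-down, dependency-free sketch of that wait/signal shape with a capacity of one and no metrics:

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// gate admits at most one holder at a time and can be closed, mirroring the
// condition-variable shape of limitListener.
type gate struct {
	cond     *sync.Cond
	busy     bool
	isClosed bool
}

// acquire blocks until the gate is free or closed and reports whether it was
// closed while waiting.
func (g *gate) acquire() (isClosed bool) {
	g.cond.L.Lock()
	defer g.cond.L.Unlock()

	for g.busy && !g.isClosed {
		g.cond.Wait()
	}

	if g.isClosed {
		return true
	}

	g.busy = true

	return false
}

// release frees the gate and wakes one waiter, like limitListener.decrement.
func (g *gate) release() {
	g.cond.L.Lock()
	defer g.cond.L.Unlock()

	g.busy = false
	g.cond.Signal()
}

// close marks the gate closed and wakes all waiters, like limitListener.Close.
func (g *gate) close() {
	g.cond.L.Lock()
	defer g.cond.L.Unlock()

	g.isClosed = true
	g.cond.Broadcast()
}

func main() {
	g := &gate{cond: sync.NewCond(&sync.Mutex{})}

	fmt.Println("closed:", g.acquire()) // closed: false

	go func() {
		time.Sleep(100 * time.Millisecond)
		g.release()
	}()

	// Blocks until the goroutine above releases the gate.
	fmt.Println("closed:", g.acquire()) // closed: false

	g.close()

	// The gate is closed, so acquire returns immediately.
	fmt.Println("closed:", g.acquire()) // closed: true
}
```
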
diff --git a/internal/dnscheck/consul.go b/internal/dnscheck/consul.go
index 190c3ab..d48a5f0 100644
--- a/internal/dnscheck/consul.go
+++ b/internal/dnscheck/consul.go
@@ -17,6 +17,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
@@ -303,8 +304,8 @@ func (cc *Consul) serveCheckTest(ctx context.Context, w http.ResponseWriter, r *
}
h := w.Header()
- h.Set(agdhttp.HdrNameContentType, agdhttp.HdrValApplicationJSON)
- h.Set(agdhttp.HdrNameAccessControlAllowOrigin, agdhttp.HdrValWildcard)
+ h.Set(httphdr.ContentType, agdhttp.HdrValApplicationJSON)
+ h.Set(httphdr.AccessControlAllowOrigin, agdhttp.HdrValWildcard)
err = json.NewEncoder(w).Encode(inf)
if err != nil {
diff --git a/internal/dnsdb/bolt_test.go b/internal/dnsdb/bolt_test.go
index 0ca8069..a5bc228 100644
--- a/internal/dnsdb/bolt_test.go
+++ b/internal/dnsdb/bolt_test.go
@@ -8,6 +8,7 @@ import (
"net/http/httptest"
"net/url"
"os"
+ "strings"
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
@@ -15,6 +16,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
+ "github.com/AdguardTeam/golibs/httphdr"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -39,9 +41,9 @@ func TestBolt_ServeHTTP(t *testing.T) {
const dname = "some-domain.name"
successHdr := http.Header{
- agdhttp.HdrNameContentType: []string{agdhttp.HdrValTextCSV},
- agdhttp.HdrNameTrailer: []string{agdhttp.HdrNameXError},
- agdhttp.HdrNameContentEncoding: []string{"gzip"},
+ httphdr.ContentType: []string{agdhttp.HdrValTextCSV},
+ httphdr.Trailer: []string{httphdr.XError},
+ httphdr.ContentEncoding: []string{"gzip"},
}
newMsg := func(rcode int, name string, qtype uint16) (m *dns.Msg) {
@@ -104,7 +106,10 @@ func TestBolt_ServeHTTP(t *testing.T) {
for _, m := range msgs {
ctx := context.Background()
db.Record(ctx, m, &agd.RequestInfo{
- Host: m.Question[0].Name,
+ // Emulate the logic from init middleware.
+ //
+ // See [dnssvc.initMw.newRequestInfo].
+ Host: strings.TrimSuffix(m.Question[0].Name, "."),
})
err := db.Refresh(context.Background())
@@ -117,7 +122,7 @@ func TestBolt_ServeHTTP(t *testing.T) {
(&url.URL{Scheme: "http", Host: "example.com"}).String(),
nil,
)
- r.Header.Add(agdhttp.HdrNameAcceptEncoding, "gzip")
+ r.Header.Add(httphdr.AcceptEncoding, "gzip")
for _, tc := range testCases {
db := newTmpBolt(t)
diff --git a/internal/dnsdb/bolthttp.go b/internal/dnsdb/bolthttp.go
index 2d969a0..0f32a15 100644
--- a/internal/dnsdb/bolthttp.go
+++ b/internal/dnsdb/bolthttp.go
@@ -15,6 +15,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/httphdr"
"go.etcd.io/bbolt"
)
@@ -38,7 +39,7 @@ func (db *Bolt) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
h := w.Header()
- h.Add(agdhttp.HdrNameContentType, agdhttp.HdrValTextCSV)
+ h.Add(httphdr.ContentType, agdhttp.HdrValTextCSV)
if dbPath == "" {
// No data.
@@ -47,10 +48,10 @@ func (db *Bolt) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
- h.Set(agdhttp.HdrNameTrailer, agdhttp.HdrNameXError)
+ h.Set(httphdr.Trailer, httphdr.XError)
defer func() {
if err != nil {
- h.Set(agdhttp.HdrNameXError, err.Error())
+ h.Set(httphdr.XError, err.Error())
agd.Collectf(ctx, db.errColl, "dnsdb: http handler error: %w", err)
}
}()
@@ -59,8 +60,8 @@ func (db *Bolt) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var rw io.Writer = w
// TODO(a.garipov): Consider parsing the quality value.
- if strings.Contains(r.Header.Get(agdhttp.HdrNameAcceptEncoding), "gzip") {
- h.Set(agdhttp.HdrNameContentEncoding, "gzip")
+ if strings.Contains(r.Header.Get(httphdr.AcceptEncoding), "gzip") {
+ h.Set(httphdr.ContentEncoding, "gzip")
gw := gzip.NewWriter(w)
defer func() { err = errors.WithDeferred(err, gw.Close()) }()
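
The `Trailer`/`X-Error` handling above uses the standard `net/http` trailer mechanism: the handler announces the trailer name before writing the body, then assigns the value on the header map after the body has been streamed. A minimal standalone handler showing the same pattern, with the header names written out literally instead of via `golibs/httphdr` and an invented CSV line as the payload:

```go
package main

import (
	"fmt"
	"net/http"
)

func main() {
	http.HandleFunc("/dnsdb/csv", func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("Content-Type", "text/csv")
		// Announce the trailer before any of the body is written.
		h.Set("Trailer", "X-Error")

		var streamErr error
		defer func() {
			if streamErr != nil {
				// Values set after the body are sent as trailers.
				h.Set("X-Error", streamErr.Error())
			}
		}()

		_, streamErr = fmt.Fprintln(w, `"example.org","A","NOERROR","1.2.3.4",42`)
	})

	_ = http.ListenAndServe("127.0.0.1:8080", nil)
}
```
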
diff --git a/internal/dnsmsg/constructor.go b/internal/dnsmsg/constructor.go
index ee282de..e1cd697 100644
--- a/internal/dnsmsg/constructor.go
+++ b/internal/dnsmsg/constructor.go
@@ -195,12 +195,12 @@ func (c *Constructor) newHdr(req *dns.Msg, rrType RRType) (hdr dns.RR_Header) {
}
// newHdrWithClass returns a new resource record header with specified class.
-func (c *Constructor) newHdrWithClass(req *dns.Msg, rrType RRType, class dns.Class) (hdr dns.RR_Header) {
+func (c *Constructor) newHdrWithClass(req *dns.Msg, rrType RRType, cl dns.Class) (h dns.RR_Header) {
return dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: rrType,
Ttl: uint32(c.fltRespTTL.Seconds()),
- Class: uint16(class),
+ Class: uint16(cl),
}
}
diff --git a/internal/dnsserver/cache/cache_test.go b/internal/dnsserver/cache/cache_test.go
index 6207c53..6673096 100644
--- a/internal/dnsserver/cache/cache_test.go
+++ b/internal/dnsserver/cache/cache_test.go
@@ -23,14 +23,8 @@ func TestMiddleware_Wrap(t *testing.T) {
aReq := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
cnameReq := dnsservertest.NewReq(reqHostname, dns.TypeCNAME, dns.ClassINET)
- cnameAns := dnsservertest.RRSection{
- RRs: []dns.RR{dnsservertest.NewCNAME(reqHostname, 3600, reqCname)},
- Sec: dnsservertest.SectionAnswer,
- }
- soaNs := dnsservertest.RRSection{
- RRs: []dns.RR{dnsservertest.NewSOA(reqHostname, 3600, reqNs1, reqNs2)},
- Sec: dnsservertest.SectionNs,
- }
+ cnameAns := dnsservertest.SectionAnswer{dnsservertest.NewCNAME(reqHostname, 3600, reqCname)}
+ soaNs := dnsservertest.SectionNs{dnsservertest.NewSOA(reqHostname, 3600, reqNs1, reqNs2)}
const N = 5
testCases := []struct {
@@ -40,8 +34,8 @@ func TestMiddleware_Wrap(t *testing.T) {
wantNumReq int
}{{
req: aReq,
- resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.RRSection{
- RRs: []dns.RR{dnsservertest.NewA(reqHostname, 3600, net.IP{1, 2, 3, 4})},
+ resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
+ dnsservertest.NewA(reqHostname, 3600, net.IP{1, 2, 3, 4}),
}),
name: "simple_a",
wantNumReq: 1,
@@ -67,9 +61,8 @@ func TestMiddleware_Wrap(t *testing.T) {
wantNumReq: N,
}, {
req: aReq,
- resp: dnsservertest.NewResp(dns.RcodeNameError, aReq, dnsservertest.RRSection{
- RRs: []dns.RR{dnsservertest.NewNS(reqHostname, 3600, reqNs1)},
- Sec: dnsservertest.SectionNs,
+ resp: dnsservertest.NewResp(dns.RcodeNameError, aReq, dnsservertest.SectionNs{
+ dnsservertest.NewNS(reqHostname, 3600, reqNs1),
}),
name: "non_authoritative_nxdomain",
// TODO(ameshkov): Consider https://datatracker.ietf.org/doc/html/rfc2308#section-3.
@@ -86,15 +79,15 @@ func TestMiddleware_Wrap(t *testing.T) {
wantNumReq: 1,
}, {
req: cnameReq,
- resp: dnsservertest.NewResp(dns.RcodeSuccess, cnameReq, dnsservertest.RRSection{
- RRs: []dns.RR{dnsservertest.NewCNAME(reqHostname, 3600, reqCname)},
+ resp: dnsservertest.NewResp(dns.RcodeSuccess, cnameReq, dnsservertest.SectionAnswer{
+ dnsservertest.NewCNAME(reqHostname, 3600, reqCname),
}),
name: "simple_cname_ans",
wantNumReq: 1,
}, {
req: aReq,
- resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.RRSection{
- RRs: []dns.RR{dnsservertest.NewA(reqHostname, 0, net.IP{1, 2, 3, 4})},
+ resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
+ dnsservertest.NewA(reqHostname, 0, net.IP{1, 2, 3, 4}),
}),
name: "expired_one",
wantNumReq: N,
diff --git a/internal/dnsserver/dnsservertest/handler.go b/internal/dnsserver/dnsservertest/handler.go
index 3f5a67f..18b444b 100644
--- a/internal/dnsserver/dnsservertest/handler.go
+++ b/internal/dnsserver/dnsservertest/handler.go
@@ -2,14 +2,19 @@ package dnsservertest
import (
"context"
- "net"
+ "time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
-// CreateTestHandler creates a [dnsserver.Handler] with the specified parameters.
+// AnswerTTL is the default TTL of the test handler's answers.
+const AnswerTTL time.Duration = 100 * time.Second
+
+// CreateTestHandler creates a [dnsserver.Handler] with the specified
+// parameters. All responses will have the [AnswerTTL] TTL.
func CreateTestHandler(recordsCount int) (h dnsserver.Handler) {
f := func(ctx context.Context, rw dnsserver.ResponseWriter, req *dns.Msg) (err error) {
// Check that necessary context keys are set.
@@ -20,30 +25,23 @@ func CreateTestHandler(recordsCount int) (h dnsserver.Handler) {
return errors.Error("client info does not contain server name")
}
- hostname := req.Question[0].Name
-
- resp := &dns.Msg{
- Compress: true,
+ ans := make(SectionAnswer, 0, recordsCount)
+ hdr := dns.RR_Header{
+ Name: req.Question[0].Name,
+ Rrtype: dns.TypeA,
+ Class: dns.ClassINET,
+ Ttl: uint32(AnswerTTL.Seconds()),
}
- resp.SetReply(req)
+ ip := netutil.IPv4Localhost().Prev()
for i := 0; i < recordsCount; i++ {
- hdr := dns.RR_Header{
- Name: hostname,
- Rrtype: dns.TypeA,
- Class: dns.ClassINET,
- Ttl: 100,
- }
-
- a := &dns.A{
- // Add 1 to make sure that each IP is valid.
- A: net.IP{127, 0, 0, byte(i + 1)},
- Hdr: hdr,
- }
-
- resp.Answer = append(resp.Answer, a)
+ // Add 1 to make sure that each IP is valid.
+ ip = ip.Next()
+ ans = append(ans, &dns.A{Hdr: hdr, A: ip.AsSlice()})
}
+ resp := NewResp(dns.RcodeSuccess, req, ans)
+
_ = rw.WriteMsg(ctx, req, resp)
return nil
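
The handler now derives its answer IPs with `netip` arithmetic instead of hand-built `net.IP` slices: starting one address before 127.0.0.1 and calling `Next` once per record yields 127.0.0.1, 127.0.0.2, and so on. The same sequence with just the standard library:

```go
package main

import (
	"fmt"
	"net/netip"
)

func main() {
	// Start one address before 127.0.0.1, as CreateTestHandler does via
	// netutil.IPv4Localhost().Prev().
	ip := netip.MustParseAddr("127.0.0.1").Prev()

	for i := 0; i < 3; i++ {
		ip = ip.Next()
		fmt.Println(ip)
	}

	// Output:
	// 127.0.0.1
	// 127.0.0.2
	// 127.0.0.3
}
```
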
diff --git a/internal/dnsserver/dnsservertest/msg.go b/internal/dnsserver/dnsservertest/msg.go
index 6feac03..8481ced 100644
--- a/internal/dnsserver/dnsservertest/msg.go
+++ b/internal/dnsserver/dnsservertest/msg.go
@@ -4,23 +4,17 @@ import (
"net"
"testing"
+ "github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
// CreateMessage creates a DNS message for the specified hostname and qtype.
func CreateMessage(hostname string, qtype uint16) (m *dns.Msg) {
- return &dns.Msg{
- MsgHdr: dns.MsgHdr{
- Id: dns.Id(),
- RecursionDesired: true,
- },
- Question: []dns.Question{{
- Name: dns.Fqdn(hostname),
- Qtype: qtype,
- Qclass: dns.ClassINET,
- }},
- }
+ m = NewReq(hostname, qtype, dns.ClassINET)
+ m.RecursionDesired = true
+
+ return m
}
// RequireResponse checks that the DNS response we received is what was
@@ -29,9 +23,9 @@ func RequireResponse(
t *testing.T,
req *dns.Msg,
resp *dns.Msg,
- expectedRecordsCount int,
- expectedRCode int,
- expectedTruncated bool,
+ wantAnsLen int,
+ wantRCode int,
+ wantTruncated bool,
) {
t.Helper()
@@ -40,26 +34,64 @@ func RequireResponse(
// Check that Opcode is not changed in the response
// regardless of the response status
require.Equal(t, req.Opcode, resp.Opcode)
- require.Equal(t, expectedRCode, resp.Rcode)
- require.Equal(t, expectedTruncated, resp.Truncated)
+ require.Equal(t, wantRCode, resp.Rcode)
+ require.Equal(t, wantTruncated, resp.Truncated)
require.True(t, resp.Response)
// Response must not have a Z flag set even for a query that does
// See https://github.com/miekg/dns/issues/975
require.False(t, resp.Zero)
- require.Equal(t, expectedRecordsCount, len(resp.Answer))
+ require.Len(t, resp.Answer, wantAnsLen)
// Check that there's an OPT record in the response
if len(req.Extra) > 0 {
- require.True(t, len(resp.Extra) > 0)
+ require.NotEmpty(t, resp.Extra)
}
- if expectedRecordsCount > 0 {
- a, ok := resp.Answer[0].(*dns.A)
- require.True(t, ok)
+ if wantAnsLen > 0 {
+ a := testutil.RequireTypeAssert[*dns.A](t, resp.Answer[0])
require.Equal(t, req.Question[0].Name, a.Hdr.Name)
}
}
+// RRSection is the resource record set to be appended to a new message created
+// by [NewReq] and [NewResp]. It's essentially a sum type of:
+//
+// - [SectionAnswer]
+// - [SectionNs]
+// - [SectionExtra]
+type RRSection interface {
+ // appendTo modifies m adding the resource record set into it appropriately.
+ appendTo(m *dns.Msg)
+}
+
+// type check
+var (
+ _ RRSection = SectionAnswer{}
+ _ RRSection = SectionNs{}
+ _ RRSection = SectionExtra{}
+)
+
+// SectionAnswer should wrap a resource record set for the Answer section of a
+// DNS message.
+type SectionAnswer []dns.RR
+
+// appendTo implements the [RRSection] interface for SectionAnswer.
+func (rrs SectionAnswer) appendTo(m *dns.Msg) { m.Answer = append(m.Answer, ([]dns.RR)(rrs)...) }
+
+// SectionNs should wrap a resource record set for the Ns section of a DNS
+// message.
+type SectionNs []dns.RR
+
+// appendTo implements the [RRSection] interface for SectionNs.
+func (rrs SectionNs) appendTo(m *dns.Msg) { m.Ns = append(m.Ns, ([]dns.RR)(rrs)...) }
+
+// SectionExtra should wrap a resource record set for the Extra section of a
+// DNS message.
+type SectionExtra []dns.RR
+
+// appendTo implements the [RRSection] interface for SectionExtra.
+func (rrs SectionExtra) appendTo(m *dns.Msg) { m.Extra = append(m.Extra, ([]dns.RR)(rrs)...) }
+
// NewReq returns the new DNS request with a single question for name, qtype,
// qclass, and rrs added.
func NewReq(name string, qtype, qclass uint16, rrs ...RRSection) (req *dns.Msg) {
@@ -68,13 +100,15 @@ func NewReq(name string, qtype, qclass uint16, rrs ...RRSection) (req *dns.Msg)
Id: dns.Id(),
},
Question: []dns.Question{{
- Name: name,
+ Name: dns.Fqdn(name),
Qtype: qtype,
Qclass: qclass,
}},
}
- withRRs(req, rrs...)
+ for _, rr := range rrs {
+ rr.appendTo(req)
+ }
return req
}
@@ -86,50 +120,13 @@ func NewResp(rcode int, req *dns.Msg, rrs ...RRSection) (resp *dns.Msg) {
resp.RecursionAvailable = true
resp.Compress = true
- withRRs(resp, rrs...)
+ for _, rr := range rrs {
+ rr.appendTo(resp)
+ }
return resp
}
-// MsgSection is used to specify the resource record set of the DNS message.
-type MsgSection int
-
-// Possible values of the MsgSection.
-const (
- SectionAnswer MsgSection = iota
- SectionNs
- SectionExtra
-)
-
-// RRSection is the slice of resource records to be appended to a new message
-// created by NewReq and NewResp.
-//
-// TODO(e.burkov): Use separate types for different sections of DNS message
-// instead of constants.
-type RRSection struct {
- RRs []dns.RR
- Sec MsgSection
-}
-
-// withRRs adds rrs to the m. Invalid rrs are skipped.
-func withRRs(m *dns.Msg, rrs ...RRSection) {
- for _, r := range rrs {
- var msgRR *[]dns.RR
- switch r.Sec {
- case SectionAnswer:
- msgRR = &m.Answer
- case SectionNs:
- msgRR = &m.Ns
- case SectionExtra:
- msgRR = &m.Extra
- default:
- continue
- }
-
- *msgRR = append(*msgRR, r.RRs...)
- }
-}
-
// NewCNAME constructs the new resource record of type CNAME.
func NewCNAME(name string, ttl uint32, target string) (rr dns.RR) {
return &dns.CNAME{
diff --git a/internal/dnsserver/dnsservertest/msg_test.go b/internal/dnsserver/dnsservertest/msg_test.go
new file mode 100644
index 0000000..88e9682
--- /dev/null
+++ b/internal/dnsserver/dnsservertest/msg_test.go
@@ -0,0 +1,76 @@
+package dnsservertest_test
+
+import (
+ "fmt"
+ "net"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
+ "github.com/AdguardTeam/golibs/netutil"
+ "github.com/miekg/dns"
+)
+
+func ExampleNewReq() {
+ const nonUniqueID = 1234
+
+ m := dnsservertest.NewReq("example.org.", dns.TypeA, dns.ClassINET, dnsservertest.SectionExtra{
+ dnsservertest.NewECSExtra(netutil.IPv4Zero(), uint16(netutil.AddrFamilyIPv4), 0, 0),
+ })
+ m.Id = nonUniqueID
+ fmt.Println(m)
+
+ // Output:
+ //
+ // ;; opcode: QUERY, status: NOERROR, id: 1234
+ // ;; flags:; QUERY: 1, ANSWER: 0, AUTHORITY: 0, ADDITIONAL: 1
+ //
+ // ;; OPT PSEUDOSECTION:
+ // ; EDNS: version 0; flags:; udp: 0
+ // ; SUBNET: 0.0.0.0/0/0
+ //
+ // ;; QUESTION SECTION:
+ // ;example.org. IN A
+}
+
+func ExampleNewResp() {
+ const (
+ nonUniqueID = 1234
+ testFQDN = "example.org."
+ realTestFQDN = "real." + testFQDN
+ )
+
+ m := dnsservertest.NewReq(testFQDN, dns.TypeA, dns.ClassINET, dnsservertest.SectionExtra{
+ dnsservertest.NewECSExtra(netutil.IPv4Zero(), uint16(netutil.AddrFamilyIPv4), 0, 0),
+ })
+ m.Id = nonUniqueID
+
+ m = dnsservertest.NewResp(dns.RcodeSuccess, m, dnsservertest.SectionAnswer{
+ dnsservertest.NewCNAME(testFQDN, 3600, realTestFQDN),
+ dnsservertest.NewA(realTestFQDN, 3600, net.IP{1, 2, 3, 4}),
+ }, dnsservertest.SectionNs{
+ dnsservertest.NewSOA(realTestFQDN, 1000, "ns."+realTestFQDN, "mbox."+realTestFQDN),
+ dnsservertest.NewNS(testFQDN, 1000, "ns."+testFQDN),
+ }, dnsservertest.SectionExtra{
+ m.IsEdns0(),
+ })
+ fmt.Println(m)
+
+ // Output:
+ //
+ // ;; opcode: QUERY, status: NOERROR, id: 1234
+ // ;; flags: qr ra; QUERY: 1, ANSWER: 2, AUTHORITY: 2, ADDITIONAL: 1
+ //
+ // ;; OPT PSEUDOSECTION:
+ // ; EDNS: version 0; flags:; udp: 0
+ // ; SUBNET: 0.0.0.0/0/0
+ //
+ // ;; QUESTION SECTION:
+ // ;example.org. IN A
+ //
+ // ;; ANSWER SECTION:
+ // example.org. 3600 IN CNAME real.example.org.
+ // real.example.org. 3600 IN A 1.2.3.4
+ //
+ // ;; AUTHORITY SECTION:
+ // real.example.org. 1000 IN SOA ns.real.example.org. mbox.real.example.org. 0 0 0 0 0
+ // example.org. 1000 IN NS ns.example.org.
+}
diff --git a/internal/dnsserver/example_test.go b/internal/dnsserver/example_test.go
index 71ed64f..4d058a0 100644
--- a/internal/dnsserver/example_test.go
+++ b/internal/dnsserver/example_test.go
@@ -59,7 +59,7 @@ func ExampleWithMiddlewares() {
forwarder := forward.NewHandler(&forward.HandlerConfig{
Address: netip.MustParseAddrPort("94.140.14.140:53"),
Network: forward.NetworkAny,
- }, true)
+ })
middleware := querylog.NewLogMiddleware(os.Stdout)
handler := dnsserver.WithMiddlewares(forwarder, middleware)
diff --git a/internal/dnsserver/forward/example_test.go b/internal/dnsserver/forward/example_test.go
index 897a283..b22035c 100644
--- a/internal/dnsserver/forward/example_test.go
+++ b/internal/dnsserver/forward/example_test.go
@@ -20,7 +20,7 @@ func ExampleNewHandler() {
FallbackAddresses: []netip.AddrPort{
netip.MustParseAddrPort("1.1.1.1:53"),
},
- }, false),
+ }),
},
}
diff --git a/internal/dnsserver/forward/forward.go b/internal/dnsserver/forward/forward.go
index a683ccc..c6e3fa8 100644
--- a/internal/dnsserver/forward/forward.go
+++ b/internal/dnsserver/forward/forward.go
@@ -38,20 +38,6 @@ import (
// queries to the specified upstreams. It also implements [io.Closer], allowing
// resource reuse.
type Handler struct {
- // lastFailedHealthcheck shows the last time of failed healthcheck.
- //
- // It is of type int64 to be accessed by package atomic. The field is
- // arranged for 64-bit alignment on the first position.
- lastFailedHealthcheck int64
-
- // useFallbacks is not zero if the main upstream server failed health check
- // probes and therefore the fallback upstream servers should be used for
- // resolving.
- //
- // It is of type uint64 to be accessed by package atomic. The field is
- // arranged for 64-bit alignment on the second position.
- useFallbacks uint64
-
// metrics is a listener for the handler events.
metrics MetricsListener
@@ -65,12 +51,21 @@ type Handler struct {
// fallbacks is a list of fallback DNS servers.
fallbacks []Upstream
+ // lastFailedHealthcheck contains the Unix time of the last failed
+ // healthcheck.
+ lastFailedHealthcheck atomic.Int64
+
// timeout specifies the query timeout for upstreams and fallbacks.
timeout time.Duration
// hcBackoffTime specifies the delay before returning to the main upstream
// after failed healthcheck probe.
hcBackoff time.Duration
+
+ // useFallbacks is true if the main upstream server failed health check
+ // probes and therefore the fallback upstream servers should be used for
+ // resolving.
+ useFallbacks atomic.Bool
}
// ErrNoResponse is returned from Handler's methods when the desired response
@@ -113,14 +108,21 @@ type HandlerConfig struct {
// upstream until this time has passed. If the healthcheck is still
// performed, each failed check advances the backoff.
HealthcheckBackoffDuration time.Duration
+
+ // HealthcheckInitDuration is the timeout for the initial upstream
+ // healthcheck. A non-positive value disables the initial healthcheck.
+ HealthcheckInitDuration time.Duration
}
// NewHandler initializes a new instance of Handler. It also performs a health
-// check afterwards if initialHealthcheck is true. Note, that this handler only
-// support plain DNS upstreams. c must not be nil.
-func NewHandler(c *HandlerConfig, initialHealthcheck bool) (h *Handler) {
+// check afterwards if c.HealthcheckInitDuration is not zero. Note that this
+// handler only supports plain DNS upstreams. c must not be nil.
+func NewHandler(c *HandlerConfig) (h *Handler) {
h = &Handler{
- upstream: NewUpstreamPlain(c.Address, c.Network),
+ upstream: NewUpstreamPlain(&UpstreamPlainConfig{
+ Network: c.Network,
+ Address: c.Address,
+ }),
hcDomainTmpl: c.HealthcheckDomainTmpl,
timeout: c.Timeout,
hcBackoff: c.HealthcheckBackoffDuration,
@@ -134,13 +136,19 @@ func NewHandler(c *HandlerConfig, initialHealthcheck bool) (h *Handler) {
h.fallbacks = make([]Upstream, len(c.FallbackAddresses))
for i, addr := range c.FallbackAddresses {
- h.fallbacks[i] = NewUpstreamPlain(addr, NetworkAny)
+ h.fallbacks[i] = NewUpstreamPlain(&UpstreamPlainConfig{
+ Network: NetworkAny,
+ Address: addr,
+ })
}
- if initialHealthcheck {
+ if c.HealthcheckInitDuration > 0 {
+ ctx, cancel := context.WithTimeout(context.Background(), c.HealthcheckInitDuration)
+ defer cancel()
+
// Ignore the error since it's considered non-critical and also should
// have been logged already.
- _ = h.refresh(context.Background(), true)
+ _ = h.refresh(ctx, true)
}
return h
@@ -176,7 +184,7 @@ func (h *Handler) ServeDNS(
) (err error) {
defer func() { err = annotate(err, h.upstream) }()
- useFallbacks := atomic.LoadUint64(&h.useFallbacks) != 0
+ useFallbacks := h.useFallbacks.Load()
var resp *dns.Msg
if !useFallbacks {
resp, err = h.exchange(ctx, h.upstream, req)
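
Not part of the patch, but a minimal sketch of how the reworked constructor might be called after this diff: the initial healthcheck is now requested by a non-zero `HealthcheckInitDuration` instead of the removed boolean parameter. The address and durations below are made-up illustrative values.

```go
package main

import (
	"fmt"
	"net/netip"
	"time"

	"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/forward"
)

func main() {
	// The address and durations here are illustrative values only.
	h := forward.NewHandler(&forward.HandlerConfig{
		Address: netip.MustParseAddrPort("94.140.14.140:53"),
		Network: forward.NetworkAny,
		Timeout: 5 * time.Second,

		// A non-zero duration replaces the removed initialHealthcheck
		// parameter and bounds the initial healthcheck.
		HealthcheckInitDuration: 1 * time.Second,
	})

	fmt.Println(h != nil)
}
```
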
diff --git a/internal/dnsserver/forward/forward_test.go b/internal/dnsserver/forward/forward_test.go
index 52fbe5a..7c3221e 100644
--- a/internal/dnsserver/forward/forward_test.go
+++ b/internal/dnsserver/forward/forward_test.go
@@ -4,6 +4,7 @@ import (
"context"
"net/netip"
"testing"
+ "time"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
@@ -17,6 +18,21 @@ func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
+// testTimeout is the timeout for tests.
+const testTimeout = 1 * time.Second
+
+// newTimeoutCtx is a test helper that returns a context with a timeout of
+// [testTimeout] whose cancel function is called in the test cleanup. It
+// should not be used where cancellation is expected sooner.
+func newTimeoutCtx(tb testing.TB, parent context.Context) (ctx context.Context) {
+ tb.Helper()
+
+ ctx, cancel := context.WithTimeout(parent, testTimeout)
+ tb.Cleanup(cancel)
+
+ return ctx
+}
+
func TestHandler_ServeDNS(t *testing.T) {
srv, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
@@ -24,13 +40,14 @@ func TestHandler_ServeDNS(t *testing.T) {
handler := forward.NewHandler(&forward.HandlerConfig{
Address: netip.MustParseAddrPort(addr),
Network: forward.NetworkAny,
- }, true)
+ Timeout: testTimeout,
+ })
req := dnsservertest.CreateMessage("example.org.", dns.TypeA)
rw := dnsserver.NewNonWriterResponseWriter(srv.LocalUDPAddr(), srv.LocalUDPAddr())
// Check the handler's ServeDNS method
- err := handler.ServeDNS(context.Background(), rw, req)
+ err := handler.ServeDNS(newTimeoutCtx(t, context.Background()), rw, req)
require.NoError(t, err)
res := rw.Msg()
@@ -46,7 +63,8 @@ func TestHandler_ServeDNS_fallbackNetError(t *testing.T) {
FallbackAddresses: []netip.AddrPort{
netip.MustParseAddrPort(srv.LocalUDPAddr().String()),
},
- }, true)
+ Timeout: testTimeout,
+ })
req := dnsservertest.CreateMessage("example.org.", dns.TypeA)
rw := dnsserver.NewNonWriterResponseWriter(srv.LocalUDPAddr(), srv.LocalUDPAddr())
diff --git a/internal/dnsserver/forward/healthcheck.go b/internal/dnsserver/forward/healthcheck.go
index b385dd4..d873efe 100644
--- a/internal/dnsserver/forward/healthcheck.go
+++ b/internal/dnsserver/forward/healthcheck.go
@@ -6,7 +6,6 @@ import (
"math/rand"
"strconv"
"strings"
- "sync/atomic"
"time"
"github.com/AdguardTeam/golibs/errors"
@@ -25,25 +24,25 @@ func (h *Handler) refresh(ctx context.Context, shouldReport bool) (err error) {
return nil
}
- var useFallbacks uint64
- lastFailed := atomic.LoadInt64(&h.lastFailedHealthcheck)
+ var useFallbacks bool
+ lastFailed := h.lastFailedHealthcheck.Load()
shouldReturnToMain := time.Since(time.Unix(lastFailed, 0)) >= h.hcBackoff
if !shouldReturnToMain {
// Make sure that useFallbacks is left true if the main upstream is
// still in the backoff mode.
- useFallbacks = 1
+ useFallbacks = true
log.Debug("forward: healthcheck: in backoff, will not return to main on success")
}
err = h.healthcheck(ctx)
if err != nil {
- atomic.StoreInt64(&h.lastFailedHealthcheck, time.Now().Unix())
- useFallbacks = 1
+ h.lastFailedHealthcheck.Store(time.Now().Unix())
+ useFallbacks = true
}
- statusChanged := atomic.CompareAndSwapUint64(&h.useFallbacks, 1-useFallbacks, useFallbacks)
+ statusChanged := h.useFallbacks.CompareAndSwap(!useFallbacks, useFallbacks)
if statusChanged || shouldReport {
- h.setUpstreamStatus(useFallbacks == 0)
+ h.setUpstreamStatus(!useFallbacks)
}
return errors.Annotate(err, "forward: %w")
diff --git a/internal/dnsserver/forward/healthcheck_test.go b/internal/dnsserver/forward/healthcheck_test.go
index 8a3e2aa..7c224f9 100644
--- a/internal/dnsserver/forward/healthcheck_test.go
+++ b/internal/dnsserver/forward/healthcheck_test.go
@@ -15,8 +15,10 @@ import (
)
func TestHandler_Refresh(t *testing.T) {
- var upstreamUp uint64
- var upstreamRequestsCount uint64
+ var upstreamIsUp atomic.Bool
+ var upstreamRequestsCount atomic.Int64
+
+ defaultHandler := dnsservertest.DefaultHandler()
// This handler writes an empty message if upstreamUp flag is false.
handlerFunc := dnsserver.HandlerFunc(func(
@@ -24,16 +26,15 @@ func TestHandler_Refresh(t *testing.T) {
rw dnsserver.ResponseWriter,
req *dns.Msg,
) (err error) {
- atomic.AddUint64(&upstreamRequestsCount, 1)
+ upstreamRequestsCount.Add(1)
nrw := dnsserver.NewNonWriterResponseWriter(rw.LocalAddr(), rw.RemoteAddr())
- handler := dnsservertest.DefaultHandler()
- err = handler.ServeDNS(ctx, nrw, req)
+ err = defaultHandler.ServeDNS(ctx, nrw, req)
if err != nil {
return err
}
- if atomic.LoadUint64(&upstreamUp) == 0 {
+ if !upstreamIsUp.Load() {
return rw.WriteMsg(ctx, req, &dns.Msg{})
}
@@ -41,7 +42,7 @@ func TestHandler_Refresh(t *testing.T) {
})
upstream, _ := dnsservertest.RunDNSServer(t, handlerFunc)
- fallback, _ := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
+ fallback, _ := dnsservertest.RunDNSServer(t, defaultHandler)
handler := forward.NewHandler(&forward.HandlerConfig{
Address: netip.MustParseAddrPort(upstream.LocalUDPAddr().String()),
Network: forward.NetworkAny,
@@ -49,38 +50,41 @@ func TestHandler_Refresh(t *testing.T) {
FallbackAddresses: []netip.AddrPort{
netip.MustParseAddrPort(fallback.LocalUDPAddr().String()),
},
- // Make sure that the handler routs queries back to the main upstream
+ Timeout: testTimeout,
+ // Make sure that the handler routes queries back to the main upstream
// immediately.
HealthcheckBackoffDuration: 0,
- }, false)
+ })
req := dnsservertest.CreateMessage("example.org.", dns.TypeA)
rw := dnsserver.NewNonWriterResponseWriter(fallback.LocalUDPAddr(), fallback.LocalUDPAddr())
- err := handler.ServeDNS(context.Background(), rw, req)
- require.Error(t, err)
- assert.Equal(t, uint64(1), atomic.LoadUint64(&upstreamRequestsCount))
+ ctx := context.Background()
- err = handler.Refresh(context.Background())
+ err := handler.ServeDNS(newTimeoutCtx(t, ctx), rw, req)
require.Error(t, err)
- assert.Equal(t, uint64(2), atomic.LoadUint64(&upstreamRequestsCount))
+ assert.Equal(t, int64(2), upstreamRequestsCount.Load())
- err = handler.ServeDNS(context.Background(), rw, req)
+ err = handler.Refresh(newTimeoutCtx(t, ctx))
+ require.Error(t, err)
+ assert.Equal(t, int64(4), upstreamRequestsCount.Load())
+
+ err = handler.ServeDNS(newTimeoutCtx(t, ctx), rw, req)
require.NoError(t, err)
- assert.Equal(t, uint64(2), atomic.LoadUint64(&upstreamRequestsCount))
+ assert.Equal(t, int64(4), upstreamRequestsCount.Load())
// Now, set upstream up.
- atomic.StoreUint64(&upstreamUp, 1)
+ upstreamIsUp.Store(true)
- err = handler.ServeDNS(context.Background(), rw, req)
+ err = handler.ServeDNS(newTimeoutCtx(t, ctx), rw, req)
require.NoError(t, err)
- assert.Equal(t, uint64(2), atomic.LoadUint64(&upstreamRequestsCount))
+ assert.Equal(t, int64(4), upstreamRequestsCount.Load())
- err = handler.Refresh(context.Background())
+ err = handler.Refresh(newTimeoutCtx(t, ctx))
require.NoError(t, err)
- assert.Equal(t, uint64(3), atomic.LoadUint64(&upstreamRequestsCount))
+ assert.Equal(t, int64(5), upstreamRequestsCount.Load())
- err = handler.ServeDNS(context.Background(), rw, req)
+ err = handler.ServeDNS(newTimeoutCtx(t, ctx), rw, req)
require.NoError(t, err)
- assert.Equal(t, uint64(4), atomic.LoadUint64(&upstreamRequestsCount))
+ assert.Equal(t, int64(6), upstreamRequestsCount.Load())
}
diff --git a/internal/dnsserver/forward/upstreamplain.go b/internal/dnsserver/forward/upstreamplain.go
index 97d1a32..edf3e24 100644
--- a/internal/dnsserver/forward/upstreamplain.go
+++ b/internal/dnsserver/forward/upstreamplain.go
@@ -7,6 +7,7 @@ import (
"io"
"net"
"net/netip"
+ "strings"
"sync"
"time"
@@ -15,7 +16,7 @@ import (
"github.com/miekg/dns"
)
-// Network is a enumeration of networks UpstreamPlain supports
+// Network is an enumeration of networks UpstreamPlain supports.
type Network string
const (
@@ -72,11 +73,21 @@ type UpstreamPlain struct {
// type check
var _ Upstream = (*UpstreamPlain)(nil)
-// NewUpstreamPlain creates and initializes a new instance of UpstreamPlain.
-func NewUpstreamPlain(addr netip.AddrPort, network Network) (ups *UpstreamPlain) {
+// UpstreamPlainConfig is the configuration structure for a plain-DNS upstream.
+type UpstreamPlainConfig struct {
+ // Network is the network to use for this upstream.
+ Network Network
+
+ // Address is the address of the upstream DNS server.
+ Address netip.AddrPort
+}
+
+// NewUpstreamPlain returns a new properly initialized *UpstreamPlain. c must
+// not be nil.
+func NewUpstreamPlain(c *UpstreamPlainConfig) (ups *UpstreamPlain) {
ups = &UpstreamPlain{
- addr: addr,
- network: network,
+ addr: c.Address,
+ network: c.Network,
}
ups.connsPoolUDP = pool.NewPool(poolMaxCapacity, makeConnsPoolFactory(ups, NetworkUDP))
@@ -127,33 +138,39 @@ func (u *UpstreamPlain) String() (str string) {
return fmt.Sprintf("%s://%s", u.network, u.addr)
}
-// exchangeUDP attempts to send the DNS request over UDP. It returns a
-// fallbackToTCP flag to signal if we should fallback to using TCP instead.
-// this may happen if the response received over UDP was truncated and
+// exchangeUDP attempts to send the DNS request over UDP. It returns a
+// fallbackToTCP flag to signal if the caller should fall back to using TCP
+// instead. This may happen if the response received over UDP was truncated and
// TCP is enabled for this upstream or if UDP is disabled.
func (u *UpstreamPlain) exchangeUDP(
ctx context.Context,
req *dns.Msg,
) (fallbackToTCP bool, resp *dns.Msg, err error) {
if u.network == NetworkTCP {
- // fallback to TCP immediately.
+ // Fall back to TCP immediately.
return true, nil, nil
}
resp, err = u.exchangeNet(ctx, req, NetworkUDP)
if err != nil {
- // error means that the upstream is dead, no need to fallback to TCP.
- return false, resp, err
+ // A network error always causes a retry over a fresh UDP connection, so
+ // if it happened again, the upstream is likely dead, and falling back to
+ // TCP appears pointless. See [exchangeNet].
+ //
+ // Thus, non-network errors are considered to be related to the response.
+ // The received response may also be intended for another timed-out
+ // request sent from the same source port, but falling back to TCP in this
+ // case shouldn't hurt.
+ fallbackToTCP = !isExpectedConnErr(err)
+
+ return fallbackToTCP, resp, err
}
- // If the response is truncated and we can use TCP, make sure that we'll
- // fallback to TCP. We also fallback to TCP if we received a response with
- // the wrong ID (it may happen with the servers under heavy load).
- if (resp.Truncated || resp.Id != req.Id) && u.network != NetworkUDP {
- fallbackToTCP = true
- }
+ // Also, fall back to TCP if the received response is truncated and the
+ // upstream isn't UDP-only.
+ fallbackToTCP = u.network != NetworkUDP && resp != nil && resp.Truncated
- return fallbackToTCP, resp, err
+ return fallbackToTCP, resp, nil
}
// exchangeNet sends a DNS query using the specified network (either TCP or UDP).
@@ -203,25 +220,36 @@ func (u *UpstreamPlain) exchangeNet(
return resp, err
}
-// validateResponse checks if the response is valid for the original query. For
-// instance, it is possible to receive a response to a different query, and we
-// must be sure that we received what was expected.
-func (u *UpstreamPlain) validateResponse(req, resp *dns.Msg) (err error) {
+// validatePlainResponse returns an error if the response is not valid for the
+// original request. This is required because we might receive a response to a
+// different query, e.g. when the server is under heavy load.
+func validatePlainResponse(req, resp *dns.Msg) (err error) {
if req.Id != resp.Id {
return dns.ErrId
}
- if len(resp.Question) != 1 {
- return ErrQuestion
+ if qlen := len(resp.Question); qlen != 1 {
+ return fmt.Errorf("%w: only 1 question allowed; got %d", ErrQuestion, qlen)
}
- if req.Question[0].Name != resp.Question[0].Name {
- return ErrQuestion
+ reqQ, respQ := req.Question[0], resp.Question[0]
+
+ if reqQ.Qtype != respQ.Qtype {
+ return fmt.Errorf("%w: mismatched type %s", ErrQuestion, dns.Type(respQ.Qtype))
+ }
+
+ // Compare the names case-insensitively, just like CoreDNS does.
+ if !strings.EqualFold(reqQ.Name, respQ.Name) {
+ return fmt.Errorf("%w: mismatched name %q", ErrQuestion, respQ.Name)
}
return nil
}
+// defaultUDPTimeout is the default timeout for waiting for a valid DNS message
+// or a network error.
+const defaultUDPTimeout = 1 * time.Minute
+
// processConn writes the query to the connection and then reads the response
// from it. We might be dealing with an idle dead connection so if we get
// a network error here, we'll attempt to open a new connection and call this
@@ -236,7 +264,7 @@ func (u *UpstreamPlain) processConn(
req *dns.Msg,
buf []byte,
bufLen int,
-) (msg *dns.Msg, err error) {
+) (resp *dns.Msg, err error) {
// Make sure that we return the connection to the pool in the end or close
// if there was any error.
defer func() {
@@ -248,7 +276,12 @@ func (u *UpstreamPlain) processConn(
}()
// Prepare a context with a deadline if needed.
- if deadline, ok := ctx.Deadline(); ok {
+ deadline, ok := ctx.Deadline()
+ if !ok && network == NetworkUDP {
+ deadline, ok = time.Now().Add(defaultUDPTimeout), true
+ }
+
+ if ok {
err = conn.SetDeadline(deadline)
if err != nil {
return nil, fmt.Errorf("setting deadline: %w", err)
@@ -261,19 +294,28 @@ func (u *UpstreamPlain) processConn(
return nil, fmt.Errorf("writing request: %w", err)
}
- var resp *dns.Msg
+ return u.readValidMsg(req, network, conn, buf)
+}
+
+// readValidMsg reads the response from conn to buf, parses and validates it.
+func (u *UpstreamPlain) readValidMsg(
+ req *dns.Msg,
+ network Network,
+ conn net.Conn,
+ buf []byte,
+) (resp *dns.Msg, err error) {
resp, err = u.readMsg(network, conn, buf)
if err != nil {
- // Error is already wrapped.
+ // Don't wrap the error, because it's informative enough as is.
return nil, err
}
- err = u.validateResponse(req, resp)
+ err = validatePlainResponse(req, resp)
if err != nil {
- return nil, fmt.Errorf("validating response: %w", err)
+ return resp, fmt.Errorf("validating %s response: %w", network, err)
}
- return resp, err
+ return resp, nil
}
// readMsg reads the response from the specified connection and parses it.
@@ -302,7 +344,8 @@ func (u *UpstreamPlain) readMsg(network Network, conn net.Conn, buf []byte) (*dn
if n < minDNSMessageSize {
return nil, fmt.Errorf("invalid msg: %w", dns.ErrShortRead)
}
- ret := new(dns.Msg)
+
+ ret := &dns.Msg{}
err = ret.Unpack(buf)
if err != nil {
return nil, fmt.Errorf("unpacking msg: %w", err)
diff --git a/internal/dnsserver/forward/upstreamplain_test.go b/internal/dnsserver/forward/upstreamplain_test.go
index 1d1b049..22e60a7 100644
--- a/internal/dnsserver/forward/upstreamplain_test.go
+++ b/internal/dnsserver/forward/upstreamplain_test.go
@@ -5,11 +5,14 @@ import (
"net/netip"
"testing"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/forward"
"github.com/AdguardTeam/golibs/log"
+ "github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -31,11 +34,14 @@ func TestUpstreamPlain_Exchange(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
- u := forward.NewUpstreamPlain(netip.MustParseAddrPort(addr), tc.network)
+ u := forward.NewUpstreamPlain(&forward.UpstreamPlainConfig{
+ Network: tc.network,
+ Address: netip.MustParseAddrPort(addr),
+ })
defer log.OnCloserError(u, log.DEBUG)
req := dnsservertest.CreateMessage("example.org.", dns.TypeA)
- res, err := u.Exchange(context.Background(), req)
+ res, err := u.Exchange(newTimeoutCtx(t, context.Background()), req)
require.NoError(t, err)
require.NotNil(t, res)
dnsservertest.RequireResponse(t, req, res, 1, dns.RcodeSuccess, false)
@@ -71,34 +77,199 @@ func TestUpstreamPlain_Exchange_truncated(t *testing.T) {
return rw.WriteMsg(ctx, req, res)
})
- _, addr := dnsservertest.RunDNSServer(t, handlerFunc)
+ _, addrStr := dnsservertest.RunDNSServer(t, handlerFunc)
// Create a test message.
req := dnsservertest.CreateMessage("example.org.", dns.TypeA)
// First, check that we receive truncated response over UDP.
- uAddr := netip.MustParseAddrPort(addr)
- uUDP := forward.NewUpstreamPlain(uAddr, forward.NetworkUDP)
+ addr := netip.MustParseAddrPort(addrStr)
+ uUDP := forward.NewUpstreamPlain(&forward.UpstreamPlainConfig{
+ Network: forward.NetworkUDP,
+ Address: addr,
+ })
defer log.OnCloserError(uUDP, log.DEBUG)
- res, err := uUDP.Exchange(context.Background(), req)
+ ctx := context.Background()
+
+ res, err := uUDP.Exchange(newTimeoutCtx(t, ctx), req)
require.NoError(t, err)
dnsservertest.RequireResponse(t, req, res, 0, dns.RcodeSuccess, true)
// Second, check that nothing is truncated over TCP.
- uTCP := forward.NewUpstreamPlain(uAddr, forward.NetworkTCP)
+ uTCP := forward.NewUpstreamPlain(&forward.UpstreamPlainConfig{
+ Network: forward.NetworkTCP,
+ Address: addr,
+ })
defer log.OnCloserError(uTCP, log.DEBUG)
- res, err = uTCP.Exchange(context.Background(), req)
+ res, err = uTCP.Exchange(newTimeoutCtx(t, ctx), req)
require.NoError(t, err)
dnsservertest.RequireResponse(t, req, res, 1, dns.RcodeSuccess, false)
// Now with NetworkANY response is also not truncated since the upstream
// fallbacks to TCP.
- uAny := forward.NewUpstreamPlain(uAddr, forward.NetworkAny)
+ uAny := forward.NewUpstreamPlain(&forward.UpstreamPlainConfig{
+ Network: forward.NetworkAny,
+ Address: addr,
+ })
defer log.OnCloserError(uAny, log.DEBUG)
- res, err = uAny.Exchange(context.Background(), req)
+ res, err = uAny.Exchange(newTimeoutCtx(t, ctx), req)
require.NoError(t, err)
dnsservertest.RequireResponse(t, req, res, 1, dns.RcodeSuccess, false)
}
+
+func TestUpstreamPlain_Exchange_fallbackFail(t *testing.T) {
+ pt := testutil.PanicT{}
+
+ // Use only unbuffered channels to block until received and validated.
+ netCh := make(chan string)
+ respCh := make(chan struct{})
+
+ h := dnsserver.HandlerFunc(func(
+ ctx context.Context,
+ rw dnsserver.ResponseWriter,
+ req *dns.Msg,
+ ) (err error) {
+ testutil.RequireSend(pt, netCh, rw.RemoteAddr().Network(), testTimeout)
+
+ resp := dnsservertest.NewResp(dns.RcodeSuccess, req)
+
+ // Make all responses invalid.
+ resp.Id = req.Id + 1
+
+ return rw.WriteMsg(ctx, req, resp)
+ })
+
+ _, addr := dnsservertest.RunDNSServer(t, h)
+ u := forward.NewUpstreamPlain(&forward.UpstreamPlainConfig{
+ Network: forward.NetworkUDP,
+ Address: netip.MustParseAddrPort(addr),
+ })
+ testutil.CleanupAndRequireSuccess(t, u.Close)
+
+ req := dnsservertest.CreateMessage("example.org.", dns.TypeA)
+
+ var resp *dns.Msg
+ var err error
+ go func() {
+ resp, err = u.Exchange(newTimeoutCtx(t, context.Background()), req)
+ testutil.RequireSend(pt, respCh, struct{}{}, testTimeout)
+ }()
+
+ // First attempt should use UDP and fail due to bad ID.
+ network, _ := testutil.RequireReceive(t, netCh, testTimeout)
+ require.Equal(t, string(forward.NetworkUDP), network)
+
+ // Second attempt should use TCP and fail as well, since the ID is still wrong.
+ network, _ = testutil.RequireReceive(t, netCh, testTimeout)
+ require.Equal(t, string(forward.NetworkTCP), network)
+
+ testutil.RequireReceive(t, respCh, testTimeout)
+ require.ErrorIs(t, err, dns.ErrId)
+ assert.NotNil(t, resp)
+}
+
+func TestUpstreamPlain_Exchange_fallbackSuccess(t *testing.T) {
+ const (
+ // network is set to UDP to ensure that falling back to TCP will still
+ // be performed.
+ network = forward.NetworkUDP
+
+ goodDomain = "domain.example."
+ badDomain = "bad.example."
+ )
+
+ pt := testutil.PanicT{}
+
+ req := dnsservertest.CreateMessage(goodDomain, dns.TypeA)
+ resp := dnsservertest.NewResp(dns.RcodeSuccess, req)
+
+ // Prepare malformed responses.
+
+ badIDResp := dnsmsg.Clone(resp)
+ badIDResp.Id = ^req.Id
+
+ badQNumResp := dnsmsg.Clone(resp)
+ badQNumResp.Question = append(badQNumResp.Question, req.Question[0])
+
+ badQnameResp := dnsmsg.Clone(resp)
+ badQnameResp.Question[0].Name = badDomain
+
+ badQtypeResp := dnsmsg.Clone(resp)
+ badQtypeResp.Question[0].Qtype = dns.TypeMX
+
+ testCases := []struct {
+ udpResp *dns.Msg
+ name string
+ }{{
+ udpResp: badIDResp,
+ name: "wrong_id",
+ }, {
+ udpResp: badQNumResp,
+ name:    "wrong_question_number",
+ }, {
+ udpResp: badQnameResp,
+ name: "wrong_qname",
+ }, {
+ udpResp: badQtypeResp,
+ name: "wrong_qtype",
+ }}
+
+ for _, tc := range testCases {
+ clonedReq := dnsmsg.Clone(req)
+ badResp := dnsmsg.Clone(tc.udpResp)
+ goodResp := dnsmsg.Clone(resp)
+
+ // Use only unbuffered channels to block until received and validated.
+ netCh := make(chan string)
+ respCh := make(chan struct{})
+
+ h := dnsserver.HandlerFunc(func(
+ ctx context.Context,
+ rw dnsserver.ResponseWriter,
+ req *dns.Msg,
+ ) (err error) {
+ network := rw.RemoteAddr().Network()
+ testutil.RequireSend(pt, netCh, network, testTimeout)
+
+ if network == string(forward.NetworkUDP) {
+ // Respond with invalid message via UDP.
+ return rw.WriteMsg(ctx, req, badResp)
+ }
+
+ // Respond with valid message via TCP.
+ return rw.WriteMsg(ctx, req, goodResp)
+ })
+
+ t.Run(tc.name, func(t *testing.T) {
+ _, addr := dnsservertest.RunDNSServer(t, dnsserver.HandlerFunc(h))
+
+ u := forward.NewUpstreamPlain(&forward.UpstreamPlainConfig{
+ Network: network,
+ Address: netip.MustParseAddrPort(addr),
+ })
+ testutil.CleanupAndRequireSuccess(t, u.Close)
+
+ var actualResp *dns.Msg
+ var err error
+ go func() {
+ actualResp, err = u.Exchange(newTimeoutCtx(t, context.Background()), clonedReq)
+ testutil.RequireSend(pt, respCh, struct{}{}, testTimeout)
+ }()
+
+ // First attempt should use UDP and fail response validation.
+ network, _ := testutil.RequireReceive(t, netCh, testTimeout)
+ require.Equal(t, string(forward.NetworkUDP), network)
+
+ // Second attempt should use TCP and succeed.
+ network, _ = testutil.RequireReceive(t, netCh, testTimeout)
+ require.Equal(t, string(forward.NetworkTCP), network)
+
+ testutil.RequireReceive(t, respCh, testTimeout)
+ require.NoError(t, err)
+ dnsservertest.RequireResponse(t, req, actualResp, 0, dns.RcodeSuccess, false)
+ })
+ }
+}
diff --git a/internal/dnsserver/go.mod b/internal/dnsserver/go.mod
index 3544026..38d5bd0 100644
--- a/internal/dnsserver/go.mod
+++ b/internal/dnsserver/go.mod
@@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardDNS/internal/dnsserver
go 1.20
require (
- github.com/AdguardTeam/golibs v0.12.1
+ github.com/AdguardTeam/golibs v0.13.2
github.com/ameshkov/dnscrypt/v2 v2.2.5
github.com/ameshkov/dnsstamps v1.0.3
github.com/bluele/gcache v0.0.2
@@ -13,7 +13,7 @@ require (
github.com/prometheus/client_golang v1.14.0
github.com/quic-go/quic-go v0.33.0
github.com/stretchr/testify v1.8.2
- golang.org/x/exp v0.0.0-20230307190834-24139beb5833
+ golang.org/x/exp v0.0.0-20230321023759-10a507213a29
golang.org/x/net v0.8.0
golang.org/x/sys v0.6.0
)
diff --git a/internal/dnsserver/go.sum b/internal/dnsserver/go.sum
index a0d5590..0dd980d 100644
--- a/internal/dnsserver/go.sum
+++ b/internal/dnsserver/go.sum
@@ -1,5 +1,4 @@
-github.com/AdguardTeam/golibs v0.12.1 h1:bJfFzCnUCl+QsP6prUltM2Sjt0fTiDBPlxuAwfKP3g8=
-github.com/AdguardTeam/golibs v0.12.1/go.mod h1:rIglKDHdLvFT1UbhumBLHO9S4cvWS9MEyT1njommI/Y=
+github.com/AdguardTeam/golibs v0.13.2 h1:BPASsyQKmb+b8VnvsNOHp7bKfcZl9Z+Z2UhPjOiupSc=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw=
@@ -80,8 +79,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
-golang.org/x/exp v0.0.0-20230307190834-24139beb5833 h1:SChBja7BCQewoTAU7IgvucQKMIXrEpFxNMs0spT3/5s=
-golang.org/x/exp v0.0.0-20230307190834-24139beb5833/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
+golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs=
golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
diff --git a/internal/dnsserver/netext/listenconfig.go b/internal/dnsserver/netext/listenconfig.go
index 19a47d3..d148645 100644
--- a/internal/dnsserver/netext/listenconfig.go
+++ b/internal/dnsserver/netext/listenconfig.go
@@ -5,6 +5,7 @@ import (
"context"
"fmt"
"net"
+ "syscall"
)
// ListenConfig is the interface that allows controlling options of connections
@@ -18,23 +19,43 @@ type ListenConfig interface {
ListenPacket(ctx context.Context, network, address string) (c net.PacketConn, err error)
}
+// defaultCtrlConf is the default control config. By default, don't alter
+// anything. defaultCtrlConf must not be mutated.
+var defaultCtrlConf = &ControlConfig{
+ RcvBufSize: 0,
+ SndBufSize: 0,
+}
+
// DefaultListenConfig returns the default [ListenConfig] used by the servers in
// this module except for the plain-DNS ones, which use
-// [DefaultListenConfigWithOOB].
-func DefaultListenConfig() (lc ListenConfig) {
+// [DefaultListenConfigWithOOB]. If conf is nil, a default configuration is
+// used.
+func DefaultListenConfig(conf *ControlConfig) (lc ListenConfig) {
+ if conf == nil {
+ conf = defaultCtrlConf
+ }
+
return &net.ListenConfig{
- Control: defaultListenControl,
+ Control: func(_, _ string, c syscall.RawConn) (err error) {
+ return listenControlWithSO(conf, c)
+ },
}
}
// DefaultListenConfigWithOOB returns the default [ListenConfig] used by the
// plain-DNS servers in this module. The resulting ListenConfig sets additional
// socket flags and processes the control-messages of connections created with
-// ListenPacket.
-func DefaultListenConfigWithOOB() (lc ListenConfig) {
+// ListenPacket. If conf is nil, a default configuration is used.
+func DefaultListenConfigWithOOB(conf *ControlConfig) (lc ListenConfig) {
+ if conf == nil {
+ conf = defaultCtrlConf
+ }
+
return &listenConfigOOB{
ListenConfig: net.ListenConfig{
- Control: defaultListenControl,
+ Control: func(_, _ string, c syscall.RawConn) (err error) {
+ return listenControlWithSO(conf, c)
+ },
},
}
}
@@ -71,3 +92,14 @@ func (lc *listenConfigOOB) ListenPacket(
return wrapPacketConn(c), nil
}
+
+// ControlConfig is the configuration of socket options.
+type ControlConfig struct {
+ // RcvBufSize defines the size of the socket receive buffer in bytes. The
+ // default is zero, which means that the system settings are used.
+ RcvBufSize int
+
+ // SndBufSize defines the size of the socket send buffer in bytes. The
+ // default is zero, which means that the system settings are used.
+ SndBufSize int
+}
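
Not part of the patch, but a minimal sketch of how the new `ControlConfig` might be passed to the listen-config constructors added above; the buffer sizes are made-up values, and passing `nil` keeps the previous behavior of leaving the system defaults untouched.

```go
package main

import (
	"fmt"

	"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
)

func main() {
	// Hypothetical sizes; zero values keep the system settings.
	conf := &netext.ControlConfig{
		SndBufSize: 1 * 1024 * 1024,
		RcvBufSize: 4 * 1024 * 1024,
	}

	// Plain-DNS servers also need the OOB handling.
	plainLC := netext.DefaultListenConfigWithOOB(conf)

	// Other servers only need the socket options.
	otherLC := netext.DefaultListenConfig(conf)

	fmt.Println(plainLC != nil, otherLC != nil)
}
```
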
diff --git a/internal/dnsserver/netext/listenconfig_unix.go b/internal/dnsserver/netext/listenconfig_unix.go
index 3435b0c..23fca54 100644
--- a/internal/dnsserver/netext/listenconfig_unix.go
+++ b/internal/dnsserver/netext/listenconfig_unix.go
@@ -13,17 +13,50 @@ import (
"golang.org/x/sys/unix"
)
-// defaultListenControl is used as a [net.ListenConfig.Control] function to set
-// the SO_REUSEPORT socket option on all sockets used by the DNS servers in this
-// package.
-func defaultListenControl(_, _ string, c syscall.RawConn) (err error) {
+// setSockOptFunc is a function that sets a socket option on fd.
+type setSockOptFunc func(fd int) (err error)
+
+// newSetSockOptFunc returns a socket-option function with the given parameters.
+func newSetSockOptFunc(name string, lvl, opt, val int) (o setSockOptFunc) {
+ return func(fd int) (err error) {
+ err = unix.SetsockoptInt(fd, lvl, opt, val)
+
+ return errors.Annotate(err, "setting %s: %w", name)
+ }
+}
+
+// listenControlWithSO is used as a [net.ListenConfig.Control] function to set
+// the SO_REUSEPORT, SO_SNDBUF, and SO_RCVBUF socket options on all sockets
+// used by the DNS servers in this package. conf must not be nil.
+func listenControlWithSO(conf *ControlConfig, c syscall.RawConn) (err error) {
+ opts := []setSockOptFunc{
+ newSetSockOptFunc("SO_REUSEPORT", unix.SOL_SOCKET, unix.SO_REUSEPORT, 1),
+ }
+
+ if conf.SndBufSize > 0 {
+ opts = append(
+ opts,
+ newSetSockOptFunc("SO_SNDBUF", unix.SOL_SOCKET, unix.SO_SNDBUF, conf.SndBufSize),
+ )
+ }
+
+ if conf.RcvBufSize > 0 {
+ opts = append(
+ opts,
+ newSetSockOptFunc("SO_RCVBUF", unix.SOL_SOCKET, unix.SO_RCVBUF, conf.RcvBufSize),
+ )
+ }
+
var opErr error
err = c.Control(func(fd uintptr) {
- opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
+ fdInt := int(fd)
+ for _, opt := range opts {
+ opErr = opt(fdInt)
+ if opErr != nil {
+ return
+ }
+ }
})
- if err != nil {
- return err
- }
return errors.WithDeferred(opErr, err)
}
diff --git a/internal/dnsserver/netext/listenconfig_unix_test.go b/internal/dnsserver/netext/listenconfig_unix_test.go
index caa4f6f..d9e3155 100644
--- a/internal/dnsserver/netext/listenconfig_unix_test.go
+++ b/internal/dnsserver/netext/listenconfig_unix_test.go
@@ -15,7 +15,7 @@ import (
)
func TestDefaultListenConfigWithOOB(t *testing.T) {
- lc := netext.DefaultListenConfigWithOOB()
+ lc := netext.DefaultListenConfigWithOOB(nil)
require.NotNil(t, lc)
type syscallConner interface {
@@ -65,3 +65,81 @@ func TestDefaultListenConfigWithOOB(t *testing.T) {
require.NoError(t, err)
})
}
+
+func TestDefaultListenConfigWithSO(t *testing.T) {
+ const (
+ sndBufSize = 10000
+ rcvBufSize = 20000
+ )
+
+ lc := netext.DefaultListenConfigWithOOB(&netext.ControlConfig{
+ SndBufSize: sndBufSize,
+ RcvBufSize: rcvBufSize,
+ })
+ require.NotNil(t, lc)
+
+ type syscallConner interface {
+ SyscallConn() (c syscall.RawConn, err error)
+ }
+
+ t.Run("ipv4", func(t *testing.T) {
+ c, err := lc.ListenPacket(context.Background(), "udp4", "127.0.0.1:0")
+ require.NoError(t, err)
+ require.NotNil(t, c)
+ require.Implements(t, (*syscallConner)(nil), c)
+
+ sc, err := c.(syscallConner).SyscallConn()
+ require.NoError(t, err)
+
+ err = sc.Control(func(fd uintptr) {
+ val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF)
+ require.NoError(t, opErr)
+
+ // TODO(a.garipov): Rewrite this to use actual expected values for
+ // each OS.
+ assert.LessOrEqual(t, sndBufSize, val)
+ })
+ require.NoError(t, err)
+
+ err = sc.Control(func(fd uintptr) {
+ val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF)
+ require.NoError(t, opErr)
+
+ assert.LessOrEqual(t, rcvBufSize, val)
+ })
+ require.NoError(t, err)
+ })
+
+ t.Run("ipv6", func(t *testing.T) {
+ c, err := lc.ListenPacket(context.Background(), "udp6", "[::1]:0")
+ if errors.Is(err, syscall.EADDRNOTAVAIL) {
+ // Some CI machines have IPv6 disabled.
+ t.Skipf("ipv6 seems to not be supported: %s", err)
+ }
+
+ require.NoError(t, err)
+ require.NotNil(t, c)
+ require.Implements(t, (*syscallConner)(nil), c)
+
+ sc, err := c.(syscallConner).SyscallConn()
+ require.NoError(t, err)
+
+ err = sc.Control(func(fd uintptr) {
+ val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF)
+ require.NoError(t, opErr)
+
+ // TODO(a.garipov): Rewrite this to use actual expected values for
+ // each OS.
+ assert.LessOrEqual(t, sndBufSize, val)
+ })
+ require.NoError(t, err)
+
+ err = sc.Control(func(fd uintptr) {
+ val, opErr := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF)
+ require.NoError(t, opErr)
+
+ assert.LessOrEqual(t, rcvBufSize, val)
+ })
+ require.NoError(t, err)
+ })
+}
diff --git a/internal/dnsserver/netext/listenconfig_windows.go b/internal/dnsserver/netext/listenconfig_windows.go
index 20e11bf..8172b4a 100644
--- a/internal/dnsserver/netext/listenconfig_windows.go
+++ b/internal/dnsserver/netext/listenconfig_windows.go
@@ -7,9 +7,9 @@ import (
"syscall"
)
-// defaultListenControl is nil on Windows, because it doesn't support
-// SO_REUSEPORT.
-var defaultListenControl func(_, _ string, _ syscall.RawConn) (_ error)
+// listenControlWithSO is nil on Windows, because it doesn't support socket
+// options.
+var listenControlWithSO func(_ *ControlConfig, _ syscall.RawConn) (_ error)
// setIPOpts sets the IPv4 and IPv6 options on a packet connection.
func setIPOpts(c net.PacketConn) (err error) {
diff --git a/internal/dnsserver/netext/packetconn_linux_test.go b/internal/dnsserver/netext/packetconn_linux_test.go
index 6999445..e315457 100644
--- a/internal/dnsserver/netext/packetconn_linux_test.go
+++ b/internal/dnsserver/netext/packetconn_linux_test.go
@@ -51,7 +51,7 @@ func TestSessionPacketConn(t *testing.T) {
}
func testSessionPacketConn(t *testing.T, proto, addr string, dstIP net.IP) (isTimeout bool) {
- lc := netext.DefaultListenConfigWithOOB()
+ lc := netext.DefaultListenConfigWithOOB(nil)
require.NotNil(t, lc)
c, err := lc.ListenPacket(context.Background(), proto, addr)
diff --git a/internal/dnsserver/prometheus/forward_test.go b/internal/dnsserver/prometheus/forward_test.go
index d2ee285..da5975f 100644
--- a/internal/dnsserver/prometheus/forward_test.go
+++ b/internal/dnsserver/prometheus/forward_test.go
@@ -24,7 +24,7 @@ func TestForwardMetricsListener_integration_request(t *testing.T) {
Address: netip.MustParseAddrPort(addr),
Network: forward.NetworkAny,
MetricsListener: prometheus.NewForwardMetricsListener(0),
- }, true)
+ })
// Prepare a test DNS message and call the handler's ServeDNS function.
// It will then call the metrics listener and prom metrics should be
diff --git a/internal/dnsserver/serverbench_test.go b/internal/dnsserver/serverbench_test.go
index ad59fe7..2e0d14c 100644
--- a/internal/dnsserver/serverbench_test.go
+++ b/internal/dnsserver/serverbench_test.go
@@ -308,7 +308,7 @@ func BenchmarkServeQUIC(b *testing.B) {
}
// Open QUIC session
- sess, err := quic.DialAddr(addr.String(), tlsConfig, nil)
+ sess, err := quic.DialAddr(context.Background(), addr.String(), tlsConfig, nil)
require.NoError(b, err)
defer func() {
err = sess.CloseWithError(0, "")
diff --git a/internal/dnsserver/serverdns.go b/internal/dnsserver/serverdns.go
index 88ee67e..1e568a6 100644
--- a/internal/dnsserver/serverdns.go
+++ b/internal/dnsserver/serverdns.go
@@ -112,7 +112,7 @@ func newServerDNS(proto Protocol, conf ConfigDNS) (s *ServerDNS) {
}
if conf.ListenConfig == nil {
- conf.ListenConfig = netext.DefaultListenConfigWithOOB()
+ conf.ListenConfig = netext.DefaultListenConfigWithOOB(nil)
}
s = &ServerDNS{
diff --git a/internal/dnsserver/serverdns_test.go b/internal/dnsserver/serverdns_test.go
index 0e6f3d6..c522870 100644
--- a/internal/dnsserver/serverdns_test.go
+++ b/internal/dnsserver/serverdns_test.go
@@ -411,6 +411,8 @@ func TestServerDNS_integration_udpMsgIgnore(t *testing.T) {
}
func TestServerDNS_integration_tcpMsgIgnore(t *testing.T) {
+ t.Parallel()
+
testCases := []struct {
name string
buf []byte
@@ -459,7 +461,10 @@ func TestServerDNS_integration_tcpMsgIgnore(t *testing.T) {
}
for _, tc := range testCases {
+ tc := tc
t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
_, addr := dnsservertest.RunDNSServer(t, dnsservertest.DefaultHandler())
conn, err := net.Dial("tcp", addr)
require.Nil(t, err)
diff --git a/internal/dnsserver/serverdnscrypt.go b/internal/dnsserver/serverdnscrypt.go
index 84e30d0..4297882 100644
--- a/internal/dnsserver/serverdnscrypt.go
+++ b/internal/dnsserver/serverdnscrypt.go
@@ -41,7 +41,7 @@ var _ Server = (*ServerDNSCrypt)(nil)
// NewServerDNSCrypt creates a new instance of ServerDNSCrypt.
func NewServerDNSCrypt(conf ConfigDNSCrypt) (s *ServerDNSCrypt) {
if conf.ListenConfig == nil {
- conf.ListenConfig = netext.DefaultListenConfig()
+ conf.ListenConfig = netext.DefaultListenConfig(nil)
}
return &ServerDNSCrypt{
diff --git a/internal/dnsserver/serverhttps.go b/internal/dnsserver/serverhttps.go
index 7c45652..0bc8fa2 100644
--- a/internal/dnsserver/serverhttps.go
+++ b/internal/dnsserver/serverhttps.go
@@ -16,6 +16,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
@@ -85,7 +86,7 @@ type ServerHTTPS struct {
h3Server *http3.Server
// quicListener is a listener that we use to serve DoH3 requests.
- quicListener quic.EarlyListener
+ quicListener *quic.EarlyListener
conf ConfigHTTPS
}
@@ -98,7 +99,7 @@ func NewServerHTTPS(conf ConfigHTTPS) (s *ServerHTTPS) {
if conf.ListenConfig == nil {
// Do not enable OOB here, because ListenPacket is only used by HTTP/3,
// and quic-go sets the necessary flags.
- conf.ListenConfig = netext.DefaultListenConfig()
+ conf.ListenConfig = netext.DefaultListenConfig(nil)
}
s = &ServerHTTPS{
@@ -289,7 +290,7 @@ func (s *ServerHTTPS) serveHTTPS(ctx context.Context, hs *http.Server, l net.Lis
}
// serveH3 is launched in a worker goroutine and serves HTTP/3 requests.
-func (s *ServerHTTPS) serveH3(ctx context.Context, hs *http3.Server, ql quic.EarlyListener) {
+func (s *ServerHTTPS) serveH3(ctx context.Context, hs *http3.Server, ql *quic.EarlyListener) {
defer s.wg.Done()
// Do not recover from panics here since if this goroutine panics, the
@@ -441,10 +442,10 @@ func (h *httpHandler) writeResponse(
switch ct {
case MimeTypeDoH:
buf, err = resp.Pack()
- w.Header().Set("Content-Type", MimeTypeDoH)
+ w.Header().Set(httphdr.ContentType, MimeTypeDoH)
case MimeTypeJSON:
buf, err = dnsMsgToJSON(resp)
- w.Header().Set("Content-Type", MimeTypeJSON)
+ w.Header().Set(httphdr.ContentType, MimeTypeJSON)
default:
return fmt.Errorf("invalid content type: %s", ct)
}
@@ -458,8 +459,8 @@ func (h *httpHandler) writeResponse(
// lifetime (see Section 4.2 of [RFC7234]) so that the DoH client is
// more likely to use fresh DNS data.
maxAge := minimalTTL(resp)
- w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%f", maxAge.Seconds()))
- w.Header().Set("Content-Length", strconv.Itoa(len(buf)))
+ w.Header().Set(httphdr.CacheControl, fmt.Sprintf("max-age=%f", maxAge.Seconds()))
+ w.Header().Set(httphdr.ContentLength, strconv.Itoa(len(buf)))
w.WriteHeader(http.StatusOK)
// Write the actual response
diff --git a/internal/dnsserver/serverhttps_test.go b/internal/dnsserver/serverhttps_test.go
index fd9803e..ad32426 100644
--- a/internal/dnsserver/serverhttps_test.go
+++ b/internal/dnsserver/serverhttps_test.go
@@ -17,6 +17,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
+ "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
@@ -27,6 +28,8 @@ import (
)
func TestServerHTTPS_integration_serveRequests(t *testing.T) {
+ t.Parallel()
+
testCases := []struct {
name string
method string
@@ -101,7 +104,10 @@ func TestServerHTTPS_integration_serveRequests(t *testing.T) {
}}
for _, tc := range testCases {
+ tc := tc
t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
srv, err := dnsservertest.RunLocalHTTPSServer(
dnsservertest.DefaultHandler(),
@@ -356,7 +362,7 @@ func TestServerHTTPS_0RTT(t *testing.T) {
// quicConfig with TokenStore set so that 0-RTT was enabled.
quicConfig := &quic.Config{
TokenStore: quic.NewLRUTokenStore(1, 10),
- Tracer: quicTracer,
+ Tracer: quicTracer.TracerForConnection,
}
// ClientSessionCache in the tls.Config must also be set for 0-RTT to work.
@@ -516,7 +522,7 @@ func createDoH3Client(
tlsCfg *tls.Config,
cfg *quic.Config,
) (c quic.EarlyConnection, e error) {
- return quic.DialAddrEarlyContext(ctx, httpsAddr.String(), tlsCfg, cfg)
+ return quic.DialAddrEarly(ctx, httpsAddr.String(), tlsCfg, cfg)
},
QuicConfig: quicConfig,
TLSClientConfig: tlsConfig,
@@ -550,8 +556,8 @@ func createDoHRequest(proto, method string, msg *dns.Msg) (r *http.Request, err
return nil, err
}
- r.Header.Set("Content-Type", dnsserver.MimeTypeDoH)
- r.Header.Set("Accept", dnsserver.MimeTypeDoH)
+ r.Header.Set(httphdr.ContentType, dnsserver.MimeTypeDoH)
+ r.Header.Set(httphdr.Accept, dnsserver.MimeTypeDoH)
return r, nil
}
@@ -582,8 +588,8 @@ func createJSONRequest(
return nil, err
}
- r.Header.Set("Content-Type", dnsserver.MimeTypeJSON)
- r.Header.Set("Accept", dnsserver.MimeTypeJSON)
+ r.Header.Set(httphdr.ContentType, dnsserver.MimeTypeJSON)
+ r.Header.Set(httphdr.Accept, dnsserver.MimeTypeJSON)
return r, err
}
diff --git a/internal/dnsserver/serverquic.go b/internal/dnsserver/serverquic.go
index 016f4cc..ad0fb93 100644
--- a/internal/dnsserver/serverquic.go
+++ b/internal/dnsserver/serverquic.go
@@ -81,7 +81,7 @@ type ServerQUIC struct {
pool *ants.Pool
// quicListener is a listener that we use to accept DoQ connections.
- quicListener quic.Listener
+ quicListener *quic.Listener
// bytesPool is a pool to avoid unnecessary allocations when reading
// DNS packets.
@@ -101,7 +101,7 @@ func NewServerQUIC(conf ConfigQUIC) (s *ServerQUIC) {
if conf.ListenConfig == nil {
// Do not enable OOB here as quic-go will do that on its own.
- conf.ListenConfig = netext.DefaultListenConfig()
+ conf.ListenConfig = netext.DefaultListenConfig(nil)
}
s = &ServerQUIC{
@@ -220,7 +220,7 @@ func (s *ServerQUIC) startServeQUIC(ctx context.Context) {
}
// serveQUIC listens for incoming QUIC connections.
-func (s *ServerQUIC) serveQUIC(ctx context.Context, l quic.Listener) (err error) {
+func (s *ServerQUIC) serveQUIC(ctx context.Context, l *quic.Listener) (err error) {
connWg := &sync.WaitGroup{}
// Wait until all conns are processed before exiting this method
defer connWg.Wait()
@@ -261,7 +261,7 @@ func (s *ServerQUIC) serveQUIC(ctx context.Context, l quic.Listener) (err error)
// acceptQUICConn is a wrapper around quic.Listener.Accept that makes sure that the
// timeout is handled properly.
-func acceptQUICConn(ctx context.Context, l quic.Listener) (conn quic.Connection, err error) {
+func acceptQUICConn(ctx context.Context, l *quic.Listener) (conn quic.Connection, err error) {
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(DefaultReadTimeout))
defer cancel()
@@ -662,9 +662,7 @@ func newServerQUICConfig(metrics MetricsListener) (conf *quic.Config) {
RequireAddressValidation: v.requiresValidation,
// Enable 0-RTT by default for all addresses, it's beneficial for the
// performance.
- Allow0RTT: func(net.Addr) (ok bool) {
- return true
- },
+ Allow0RTT: true,
}
}
diff --git a/internal/dnsserver/serverquic_test.go b/internal/dnsserver/serverquic_test.go
index cc0a86d..5415c4a 100644
--- a/internal/dnsserver/serverquic_test.go
+++ b/internal/dnsserver/serverquic_test.go
@@ -34,7 +34,7 @@ func TestServerQUIC_integration_query(t *testing.T) {
})
// Open a QUIC connection.
- conn, err := quic.DialAddr(addr.String(), tlsConfig, nil)
+ conn, err := quic.DialAddr(context.Background(), addr.String(), tlsConfig, nil)
require.NoError(t, err)
defer testutil.CleanupAndRequireSuccess(t, func() (err error) {
@@ -83,7 +83,7 @@ func TestServerQUIC_integration_ENDS0Padding(t *testing.T) {
})
// Open a QUIC connection.
- conn, err := quic.DialAddr(addr.String(), tlsConfig, nil)
+ conn, err := quic.DialAddr(context.Background(), addr.String(), tlsConfig, nil)
require.NoError(t, err)
defer func(conn quic.Connection, code quic.ApplicationErrorCode, s string) {
@@ -122,7 +122,7 @@ func TestServerQUIC_integration_0RTT(t *testing.T) {
// quicConfig with TokenStore set so that 0-RTT was enabled.
quicConfig := &quic.Config{
TokenStore: quic.NewLRUTokenStore(1, 10),
- Tracer: quicTracer,
+ Tracer: quicTracer.TracerForConnection,
}
// ClientSessionCache in the tls.Config must also be set for 0-RTT to work.
@@ -156,7 +156,7 @@ func TestServerQUIC_integration_largeQuery(t *testing.T) {
})
// Open a QUIC connection.
- conn, err := quic.DialAddr(addr.String(), tlsConfig, nil)
+ conn, err := quic.DialAddr(context.Background(), addr.String(), tlsConfig, nil)
require.NoError(t, err)
defer testutil.CleanupAndRequireSuccess(t, func() (err error) {
@@ -190,7 +190,7 @@ func testQUICExchange(
tlsConfig *tls.Config,
quicConfig *quic.Config,
) {
- conn, err := quic.DialAddrEarly(addr.String(), tlsConfig, quicConfig)
+ conn, err := quic.DialAddrEarly(context.Background(), addr.String(), tlsConfig, quicConfig)
require.NoError(t, err)
defer testutil.CleanupAndRequireSuccess(t, func() (err error) {
diff --git a/internal/dnsserver/servertls_test.go b/internal/dnsserver/servertls_test.go
index aa37a3c..a97688d 100644
--- a/internal/dnsserver/servertls_test.go
+++ b/internal/dnsserver/servertls_test.go
@@ -45,6 +45,8 @@ func TestServerTLS_integration_queryTLS(t *testing.T) {
}
func TestServerTLS_integration_msgIgnore(t *testing.T) {
+ t.Parallel()
+
testCases := []struct {
name string
buf []byte
@@ -88,7 +90,10 @@ func TestServerTLS_integration_msgIgnore(t *testing.T) {
}
for _, tc := range testCases {
+ tc := tc
t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
tlsConfig := dnsservertest.CreateServerTLSConfig("example.org")
h := dnsservertest.DefaultHandler()
addr := dnsservertest.RunTLSServer(t, h, tlsConfig)
diff --git a/internal/dnssvc/debug.go b/internal/dnssvc/debug.go
index 3dd60c7..e40c320 100644
--- a/internal/dnssvc/debug.go
+++ b/internal/dnssvc/debug.go
@@ -17,15 +17,16 @@ import (
// Debug header name constants.
const (
- hdrNameResType = "res-type"
- hdrNameRuleListID = "rule-list-id"
- hdrNameRule = "rule"
- hdrNameClientIP = "client-ip"
- hdrNameDeviceID = "device-id"
- hdrNameProfileID = "profile-id"
- hdrNameCountry = "country"
- hdrNameASN = "asn"
- hdrNameHost = "adguard-dns.com."
+ hdrNameResType = "res-type"
+ hdrNameRuleListID = "rule-list-id"
+ hdrNameRule = "rule"
+ hdrNameClientIP = "client-ip"
+ hdrNameDeviceID = "device-id"
+ hdrNameProfileID = "profile-id"
+ hdrNameCountry = "country"
+ hdrNameASN = "asn"
+ hdrNameSubdivision = "subdivision"
+ hdrNameHost = "adguard-dns.com."
)
// writeDebugResponse writes the debug response to rw.
@@ -97,17 +98,41 @@ func (svc *Service) appendDebugExtraFromContext(
}
}
- if d := ri.Location; d != nil {
- setQuestionName(debugReq, "", hdrNameCountry)
- err = svc.messages.AppendDebugExtra(debugReq, resp, string(d.Country))
+ if loc := ri.Location; loc != nil {
+ err = svc.appendDebugExtraFromLocation(loc, debugReq, resp)
if err != nil {
- return fmt.Errorf("adding %s extra: %w", hdrNameCountry, err)
+ // Don't wrap the error, because it's informative enough as is.
+ return err
}
+ }
- setQuestionName(debugReq, "", hdrNameASN)
- err = svc.messages.AppendDebugExtra(debugReq, resp, strconv.FormatUint(uint64(d.ASN), 10))
+ return nil
+}
+
+// appendDebugExtraFromLocation adds the debug info from the location of the
+// request info to the response. loc must not be nil.
+func (svc *Service) appendDebugExtraFromLocation(
+ loc *agd.Location,
+ debugReq *dns.Msg,
+ resp *dns.Msg,
+) (err error) {
+ setQuestionName(debugReq, "", hdrNameCountry)
+ err = svc.messages.AppendDebugExtra(debugReq, resp, string(loc.Country))
+ if err != nil {
+ return fmt.Errorf("adding %s extra: %w", hdrNameCountry, err)
+ }
+
+ setQuestionName(debugReq, "", hdrNameASN)
+ err = svc.messages.AppendDebugExtra(debugReq, resp, strconv.FormatUint(uint64(loc.ASN), 10))
+ if err != nil {
+ return fmt.Errorf("adding %s extra: %w", hdrNameASN, err)
+ }
+
+ if subdivision := loc.TopSubdivision; subdivision != "" {
+ setQuestionName(debugReq, "", hdrNameSubdivision)
+ err = svc.messages.AppendDebugExtra(debugReq, resp, subdivision)
if err != nil {
- return fmt.Errorf("adding %s extra: %w", hdrNameASN, err)
+ return fmt.Errorf("adding %s extra: %w", hdrNameSubdivision, err)
}
}
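// A minimal sketch (not part of the patch): the TXT record written for the new
// "subdivision" debug header, mirroring the existing "country" and "asn"
// records and the expectations in debug_internal_test.go below. The name,
// class, and TTL follow the conventions already used by the debug responses.
var subdivisionRR dns.RR = &dns.TXT{
	Hdr: dns.RR_Header{
		Name:   "subdivision.adguard-dns.com.",
		Rrtype: dns.TypeTXT,
		Class:  dns.ClassCHAOS,
		Ttl:    agdtest.FilteredResponseTTLSec,
	},
	Txt: []string{"CA"},
}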
diff --git a/internal/dnssvc/debug_internal_test.go b/internal/dnssvc/debug_internal_test.go
index 91c5eb5..46c9547 100644
--- a/internal/dnssvc/debug_internal_test.go
+++ b/internal/dnssvc/debug_internal_test.go
@@ -26,7 +26,7 @@ func newTXTExtra(strs [][2]string) (extra []dns.RR) {
Name: v[0],
Rrtype: dns.TypeTXT,
Class: dns.ClassCHAOS,
- Ttl: uint32(agdtest.FilteredResponseTTL.Seconds()),
+ Ttl: agdtest.FilteredResponseTTLSec,
},
Txt: []string{v[1]},
})
@@ -47,6 +47,7 @@ func TestService_writeDebugResponse(t *testing.T) {
blockRule = "||example.com^"
)
+ clientIPStr := testClientIP.String()
testCases := []struct {
name string
ri *agd.RequestInfo
@@ -59,7 +60,7 @@ func TestService_writeDebugResponse(t *testing.T) {
reqRes: nil,
respRes: nil,
wantExtra: newTXTExtra([][2]string{
- {"client-ip.adguard-dns.com.", "1.2.3.4"},
+ {"client-ip.adguard-dns.com.", clientIPStr},
{"resp.res-type.adguard-dns.com.", "normal"},
}),
}, {
@@ -68,7 +69,7 @@ func TestService_writeDebugResponse(t *testing.T) {
reqRes: &filter.ResultBlocked{List: fltListID1, Rule: blockRule},
respRes: nil,
wantExtra: newTXTExtra([][2]string{
- {"client-ip.adguard-dns.com.", "1.2.3.4"},
+ {"client-ip.adguard-dns.com.", clientIPStr},
{"req.res-type.adguard-dns.com.", "blocked"},
{"req.rule.adguard-dns.com.", "||example.com^"},
{"req.rule-list-id.adguard-dns.com.", "fl1"},
@@ -79,7 +80,7 @@ func TestService_writeDebugResponse(t *testing.T) {
reqRes: nil,
respRes: &filter.ResultBlocked{List: fltListID2, Rule: blockRule},
wantExtra: newTXTExtra([][2]string{
- {"client-ip.adguard-dns.com.", "1.2.3.4"},
+ {"client-ip.adguard-dns.com.", clientIPStr},
{"resp.res-type.adguard-dns.com.", "blocked"},
{"resp.rule.adguard-dns.com.", "||example.com^"},
{"resp.rule-list-id.adguard-dns.com.", "fl2"},
@@ -90,7 +91,7 @@ func TestService_writeDebugResponse(t *testing.T) {
reqRes: &filter.ResultAllowed{},
respRes: nil,
wantExtra: newTXTExtra([][2]string{
- {"client-ip.adguard-dns.com.", "1.2.3.4"},
+ {"client-ip.adguard-dns.com.", clientIPStr},
{"req.res-type.adguard-dns.com.", "allowed"},
{"req.rule.adguard-dns.com.", ""},
{"req.rule-list-id.adguard-dns.com.", ""},
@@ -101,7 +102,7 @@ func TestService_writeDebugResponse(t *testing.T) {
reqRes: nil,
respRes: &filter.ResultAllowed{},
wantExtra: newTXTExtra([][2]string{
- {"client-ip.adguard-dns.com.", "1.2.3.4"},
+ {"client-ip.adguard-dns.com.", clientIPStr},
{"resp.res-type.adguard-dns.com.", "allowed"},
{"resp.rule.adguard-dns.com.", ""},
{"resp.rule-list-id.adguard-dns.com.", ""},
@@ -114,31 +115,31 @@ func TestService_writeDebugResponse(t *testing.T) {
},
respRes: nil,
wantExtra: newTXTExtra([][2]string{
- {"client-ip.adguard-dns.com.", "1.2.3.4"},
+ {"client-ip.adguard-dns.com.", clientIPStr},
{"req.res-type.adguard-dns.com.", "modified"},
{"req.rule.adguard-dns.com.", "||example.com^$dnsrewrite=REFUSED"},
{"req.rule-list-id.adguard-dns.com.", ""},
}),
}, {
name: "device",
- ri: &agd.RequestInfo{Device: &agd.Device{ID: "dev1234"}},
+ ri: &agd.RequestInfo{Device: &agd.Device{ID: testDeviceID}},
reqRes: nil,
respRes: nil,
wantExtra: newTXTExtra([][2]string{
- {"client-ip.adguard-dns.com.", "1.2.3.4"},
- {"device-id.adguard-dns.com.", "dev1234"},
+ {"client-ip.adguard-dns.com.", clientIPStr},
+ {"device-id.adguard-dns.com.", testDeviceID},
{"resp.res-type.adguard-dns.com.", "normal"},
}),
}, {
name: "profile",
ri: &agd.RequestInfo{
- Profile: &agd.Profile{ID: agd.ProfileID("some-profile-id")},
+ Profile: &agd.Profile{ID: testProfileID},
},
reqRes: nil,
respRes: nil,
wantExtra: newTXTExtra([][2]string{
- {"client-ip.adguard-dns.com.", "1.2.3.4"},
- {"profile-id.adguard-dns.com.", "some-profile-id"},
+ {"client-ip.adguard-dns.com.", clientIPStr},
+ {"profile-id.adguard-dns.com.", testProfileID},
{"resp.res-type.adguard-dns.com.", "normal"},
}),
}, {
@@ -147,11 +148,25 @@ func TestService_writeDebugResponse(t *testing.T) {
reqRes: nil,
respRes: nil,
wantExtra: newTXTExtra([][2]string{
- {"client-ip.adguard-dns.com.", "1.2.3.4"},
- {"country.adguard-dns.com.", "AD"},
+ {"client-ip.adguard-dns.com.", clientIPStr},
+ {"country.adguard-dns.com.", string(agd.CountryAD)},
{"asn.adguard-dns.com.", "0"},
{"resp.res-type.adguard-dns.com.", "normal"},
}),
+ }, {
+ name: "location_subdivision",
+ ri: &agd.RequestInfo{
+ Location: &agd.Location{Country: agd.CountryAD, TopSubdivision: "CA"},
+ },
+ reqRes: nil,
+ respRes: nil,
+ wantExtra: newTXTExtra([][2]string{
+ {"client-ip.adguard-dns.com.", clientIPStr},
+ {"country.adguard-dns.com.", string(agd.CountryAD)},
+ {"asn.adguard-dns.com.", "0"},
+ {"subdivision.adguard-dns.com.", "CA"},
+ {"resp.res-type.adguard-dns.com.", "normal"},
+ }),
}}
for _, tc := range testCases {
diff --git a/internal/dnssvc/deviceid_internal_test.go b/internal/dnssvc/deviceid_internal_test.go
index f70ebe0..ebb6559 100644
--- a/internal/dnssvc/deviceid_internal_test.go
+++ b/internal/dnssvc/deviceid_internal_test.go
@@ -46,8 +46,8 @@ func TestService_Wrap_deviceID(t *testing.T) {
proto: agd.ProtoDoT,
}, {
name: "tls_device_id",
- cliSrvName: "dev.dns.example.com",
- wantDeviceID: "dev",
+ cliSrvName: testDeviceID + ".dns.example.com",
+ wantDeviceID: testDeviceID,
wantErrMsg: "",
wildcards: []string{"*.dns.example.com"},
proto: agd.ProtoDoT,
@@ -61,7 +61,7 @@ func TestService_Wrap_deviceID(t *testing.T) {
proto: agd.ProtoDoT,
}, {
name: "tls_deep_subdomain",
- cliSrvName: "abc.def.dns.example.com",
+ cliSrvName: "abc." + testDeviceID + ".dns.example.com",
wantDeviceID: "",
wantErrMsg: "",
wildcards: []string{"*.dns.example.com"},
@@ -79,8 +79,8 @@ func TestService_Wrap_deviceID(t *testing.T) {
proto: agd.ProtoDoT,
}, {
name: "quic_device_id",
- cliSrvName: "dev.dns.example.com",
- wantDeviceID: "dev",
+ cliSrvName: testDeviceID + ".dns.example.com",
+ wantDeviceID: testDeviceID,
wantErrMsg: "",
wildcards: []string{"*.dns.example.com"},
proto: agd.ProtoDoQ,
@@ -93,8 +93,8 @@ func TestService_Wrap_deviceID(t *testing.T) {
proto: agd.ProtoDoT,
}, {
name: "tls_device_id_subdomain_wildcard",
- cliSrvName: "dev.sub.dns.example.com",
- wantDeviceID: "dev",
+ cliSrvName: testDeviceID + ".sub.dns.example.com",
+ wantDeviceID: testDeviceID,
wantErrMsg: "",
wildcards: []string{
"*.dns.example.com",
@@ -135,13 +135,13 @@ func TestService_Wrap_deviceIDHTTPS(t *testing.T) {
wantErrMsg: "",
}, {
name: "device_id",
- path: "/dns-query/cli",
- wantDeviceID: "cli",
+ path: "/dns-query/" + testDeviceID,
+ wantDeviceID: testDeviceID,
wantErrMsg: "",
}, {
name: "device_id_slash",
- path: "/dns-query/cli/",
- wantDeviceID: "cli",
+ path: "/dns-query/" + testDeviceID + "/",
+ wantDeviceID: testDeviceID,
wantErrMsg: "",
}, {
name: "bad_url",
@@ -150,9 +150,9 @@ func TestService_Wrap_deviceIDHTTPS(t *testing.T) {
wantErrMsg: `http url device id check: bad path "/foo"`,
}, {
name: "extra",
- path: "/dns-query/cli/foo",
+ path: "/dns-query/" + testDeviceID + "/foo",
wantDeviceID: "",
- wantErrMsg: `http url device id check: bad path "/dns-query/cli/foo": ` +
+ wantErrMsg: `http url device id check: bad path "/dns-query/` + testDeviceID + `/foo": ` +
`extra parts`,
}, {
name: "bad_device_id",
@@ -184,11 +184,9 @@ func TestService_Wrap_deviceIDHTTPS(t *testing.T) {
}
t.Run("domain_name", func(t *testing.T) {
- const want = "dev"
-
u := &url.URL{
Scheme: "https",
- Host: want + ".dns.example.com",
+ Host: testDeviceID + ".dns.example.com",
Path: "/dns-query",
}
@@ -204,7 +202,7 @@ func TestService_Wrap_deviceIDHTTPS(t *testing.T) {
deviceID, err := deviceIDFromContext(ctx, proto, []string{"*.dns.example.com"})
require.NoError(t, err)
- assert.Equal(t, agd.DeviceID(want), deviceID)
+ assert.Equal(t, agd.DeviceID(testDeviceID), deviceID)
})
}
diff --git a/internal/dnssvc/dnssvc.go b/internal/dnssvc/dnssvc.go
index a5872ad..ca715dc 100644
--- a/internal/dnssvc/dnssvc.go
+++ b/internal/dnssvc/dnssvc.go
@@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
+ "github.com/AdguardTeam/AdGuardDNS/internal/connlimiter"
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
@@ -20,6 +21,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/AdGuardDNS/internal/rulestat"
"github.com/AdguardTeam/golibs/errors"
@@ -38,15 +40,22 @@ type Config struct {
// messages for this DNS service.
Messages *dnsmsg.Constructor
- // SafeBrowsing is the safe browsing TXT server.
- SafeBrowsing *filter.SafeBrowsingServer
+ // ControlConf is the configuration of socket options.
+ ControlConf *netext.ControlConfig
+
+ // ConnLimiter, if not nil, is used to limit the number of simultaneously
+ // active stream-connections.
+ ConnLimiter *connlimiter.Limiter
+
+ // SafeBrowsing is the safe browsing TXT hash matcher.
+ SafeBrowsing filter.HashMatcher
// BillStat is used to collect billing statistics.
BillStat billstat.Recorder
// ProfileDB is the AdGuard DNS profile database used to fetch data about
// profiles, devices, and so on.
- ProfileDB agd.ProfileDB
+ ProfileDB profiledb.Interface
// DNSCheck is used by clients to check if they use AdGuard DNS.
DNSCheck dnscheck.Interface
@@ -75,14 +84,11 @@ type Config struct {
// rule lists.
RuleStat rulestat.Interface
- // Upstream defines the upstream server and the group of fallback servers.
- Upstream *agd.Upstream
-
// NewListener, when set, is used instead of the package-level function
// NewListener when creating a DNS listener.
//
// TODO(a.garipov): The handler and service logic should really not be
- // internwined in this way. See AGDNS-1327.
+ // intertwined in this way. See AGDNS-1327.
NewListener NewListenerFunc
// Handler is used as the main DNS handler instead of a simple forwarder.
@@ -170,9 +176,9 @@ func New(c *Config) (svc *Service, err error) {
dnsHdlr := dnsserver.WithMiddlewares(
handler,
&preServiceMw{
- messages: c.Messages,
- filter: c.SafeBrowsing,
- checker: c.DNSCheck,
+ messages: c.Messages,
+ hashMatcher: c.SafeBrowsing,
+ checker: c.DNSCheck,
},
svc,
)
@@ -475,8 +481,13 @@ func newServers(
name := listenerName(s.Name, addr, s.Protocol)
+ lc := bindData.ListenConfig
+ if lc == nil {
+ lc = newListenConfig(c.ControlConf, c.ConnLimiter, s.Protocol)
+ }
+
var l Listener
- l, err = newListener(s, name, addr, h, c.NonDNS, c.ErrColl, bindData.ListenConfig)
+ l, err = newListener(s, name, addr, h, c.NonDNS, c.ErrColl, lc)
if err != nil {
return nil, fmt.Errorf("server %q: %w", s.Name, err)
}
@@ -496,3 +507,26 @@ func newServers(
return servers, nil
}
+
+// newListenConfig returns the netext.ListenConfig used by the plain-DNS
+// servers. The resulting ListenConfig sets additional socket flags and
+// processes the control messages of connections created with ListenPacket.
+// Additionally, if l is not nil, it is used to limit the number of
+// simultaneously active stream-connections.
+func newListenConfig(
+ ctrlConf *netext.ControlConfig,
+ l *connlimiter.Limiter,
+ p agd.Protocol,
+) (lc netext.ListenConfig) {
+ if p == agd.ProtoDNS {
+ lc = netext.DefaultListenConfigWithOOB(ctrlConf)
+ } else {
+ lc = netext.DefaultListenConfig(ctrlConf)
+ }
+
+ if l != nil {
+ lc = connlimiter.NewListenConfig(lc, l)
+ }
+
+ return lc
+}
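// A minimal sketch (not part of the patch), with ctrlConf and limiter standing
// for the values taken from the service configuration: how newListenConfig
// above layers the optional connection limiter on top of the base listen
// config for a plain-DNS server.
lc := netext.DefaultListenConfigWithOOB(ctrlConf)
if limiter != nil {
	// Cap the number of simultaneously active stream connections.
	lc = connlimiter.NewListenConfig(lc, limiter)
}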
diff --git a/internal/dnssvc/dnssvc_internal_test.go b/internal/dnssvc/dnssvc_internal_test.go
new file mode 100644
index 0000000..5138cd9
--- /dev/null
+++ b/internal/dnssvc/dnssvc_internal_test.go
@@ -0,0 +1,26 @@
+package dnssvc
+
+import (
+ "net"
+ "net/netip"
+)
+
+// Common addresses for tests.
+var (
+ testClientIP = net.IP{1, 2, 3, 4}
+ testRAddr = &net.TCPAddr{
+ IP: testClientIP,
+ Port: 12345,
+ }
+
+ testClientAddrPort = testRAddr.AddrPort()
+ testClientAddr = testClientAddrPort.Addr()
+
+ testServerAddr = netip.MustParseAddr("5.6.7.8")
+)
+
+// testDeviceID is the common device ID for tests.
+const testDeviceID = "dev1234"
+
+// testProfileID is the common profile ID for tests.
+const testProfileID = "prof1234"
diff --git a/internal/dnssvc/dnssvc_test.go b/internal/dnssvc/dnssvc_test.go
index c98aa78..ada8233 100644
--- a/internal/dnssvc/dnssvc_test.go
+++ b/internal/dnssvc/dnssvc_test.go
@@ -173,12 +173,6 @@ func TestService_Start(t *testing.T) {
Name: "test_group",
Servers: []*agd.Server{srv},
}},
- Upstream: &agd.Upstream{
- Server: netip.MustParseAddrPort("8.8.8.8:53"),
- FallbackServers: []netip.AddrPort{
- netip.MustParseAddrPort("1.1.1.1:53"),
- },
- },
}
svc, err := dnssvc.New(c)
@@ -246,12 +240,6 @@ func TestNew(t *testing.T) {
Name: "test_group",
Servers: srvs,
}},
- Upstream: &agd.Upstream{
- Server: netip.MustParseAddrPort("8.8.8.8:53"),
- FallbackServers: []netip.AddrPort{
- netip.MustParseAddrPort("1.1.1.1:53"),
- },
- },
}
svc, err := dnssvc.New(c)
diff --git a/internal/dnssvc/initmw.go b/internal/dnssvc/initmw.go
index c4dbd36..4272a59 100644
--- a/internal/dnssvc/initmw.go
+++ b/internal/dnssvc/initmw.go
@@ -12,6 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
@@ -27,8 +28,6 @@ import (
// one.
//
// TODO(a.garipov): Add tests.
-//
-// TODO(a.garipov): Make other middlewares more compact as well. See AGDNS-328.
type initMw struct {
// messages is used to build the responses specific for the request's
// context.
@@ -43,8 +42,8 @@ type initMw struct {
// srv is the current server which serves the request.
srv *agd.Server
- // db is the storage of user profiles.
- db agd.ProfileDB
+ // db is the database of user profiles and devices.
+ db profiledb.Interface
// geoIP detects the location of the request source.
geoIP geoip.Interface
@@ -69,6 +68,7 @@ func (mw *initMw) Wrap(h dnsserver.Handler) (wrapped dnsserver.Handler) {
func (mw *initMw) newRequestInfo(
ctx context.Context,
req *dns.Msg,
+ laddr net.Addr,
raddr net.Addr,
fqdn string,
qt dnsmsg.RRType,
@@ -100,7 +100,8 @@ func (mw *initMw) newRequestInfo(
}
// Add the profile information, if any.
- err = mw.addProfile(ctx, ri, req)
+ localIP := netutil.NetAddrToAddrPort(laddr).Addr()
+ err = mw.addProfile(ctx, ri, req, localIP)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
@@ -149,7 +150,12 @@ func (mw *initMw) locationData(ctx context.Context, ip netip.Addr, typ string) (
// addProfile adds profile and device information, if any, to the request
// information.
-func (mw *initMw) addProfile(ctx context.Context, ri *agd.RequestInfo, req *dns.Msg) (err error) {
+func (mw *initMw) addProfile(
+ ctx context.Context,
+ ri *agd.RequestInfo,
+ req *dns.Msg,
+ localIP netip.Addr,
+) (err error) {
defer func() { err = errors.Annotate(err, "getting profile from req: %w") }()
var id agd.DeviceID
@@ -170,14 +176,12 @@ func (mw *initMw) addProfile(ctx context.Context, ri *agd.RequestInfo, req *dns.
optlog.Debug2("init mw: got device id %q and ip %s", id, ri.RemoteIP)
- prof, dev, byWhat, err := mw.profile(ctx, ri.RemoteIP, id, mw.srv.Protocol)
+ prof, dev, byWhat, err := mw.profile(ctx, localIP, ri.RemoteIP, id)
if err != nil {
- // Use two errors.Is calls to prevent unnecessary allocations.
- if !errors.Is(err, agd.DeviceNotFoundError{}) &&
- !errors.Is(err, agd.ProfileNotFoundError{}) {
+ if !errors.Is(err, profiledb.ErrDeviceNotFound) {
// Very unlikely, since those two error types are the only ones
- // currently returned from the profile DB.
- return err
+ // currently returned from the default profile DB.
+ return fmt.Errorf("unexpected profiledb error: %s", err)
}
optlog.Debug1("init mw: profile or device not found: %s", err)
@@ -193,32 +197,51 @@ func (mw *initMw) addProfile(ctx context.Context, ri *agd.RequestInfo, req *dns.
return nil
}
+// Constants for the parameter by which a device has been found.
+const (
+ byDeviceID = "device id"
+ byDedicatedIP = "dedicated ip"
+ byLinkedIP = "linked ip"
+)
+
// profile finds the profile by the client data.
func (mw *initMw) profile(
ctx context.Context,
- ip netip.Addr,
+ localIP netip.Addr,
+ remoteIP netip.Addr,
id agd.DeviceID,
- p agd.Protocol,
) (prof *agd.Profile, dev *agd.Device, byWhat string, err error) {
if id != "" {
prof, dev, err = mw.db.ProfileByDeviceID(ctx, id)
+ if err != nil {
+ return nil, nil, "", err
+ }
- return prof, dev, "device id", err
+ return prof, dev, byDeviceID, nil
}
if !mw.srv.LinkedIPEnabled {
- optlog.Debug1("init mw: not matching by ip for server %s", mw.srv.Name)
+ optlog.Debug1("init mw: not matching by linked or dedicated ip for server %s", mw.srv.Name)
- return nil, nil, "", agd.ProfileNotFoundError{}
- } else if p != agd.ProtoDNS {
- optlog.Debug1("init mw: not matching by ip for proto %v", p)
+ return nil, nil, "", profiledb.ErrDeviceNotFound
+ } else if p := mw.srv.Protocol; p != agd.ProtoDNS {
+ optlog.Debug1("init mw: not matching by linked or dedicated ip for proto %v", p)
- return nil, nil, "", agd.ProfileNotFoundError{}
+ return nil, nil, "", profiledb.ErrDeviceNotFound
}
- prof, dev, err = mw.db.ProfileByIP(ctx, ip)
+ byWhat = byDedicatedIP
+ prof, dev, err = mw.db.ProfileByDedicatedIP(ctx, localIP)
+ if errors.Is(err, profiledb.ErrDeviceNotFound) {
+ byWhat = byLinkedIP
+ prof, dev, err = mw.db.ProfileByLinkedIP(ctx, remoteIP)
+ }
- return prof, dev, "linked ip", err
+ if err != nil {
+ return nil, nil, "", err
+ }
+
+ return prof, dev, byWhat, nil
}
// initMwHandler implements the [dnsserver.Handler] interface and will be used
@@ -259,7 +282,7 @@ func (mh *initMwHandler) ServeDNS(
mw := mh.mw
// Get the request's information, such as GeoIP data and user profiles.
- ri, err := mw.newRequestInfo(ctx, req, rw.RemoteAddr(), fqdn, qt, cl)
+ ri, err := mw.newRequestInfo(ctx, req, rw.LocalAddr(), rw.RemoteAddr(), fqdn, qt, cl)
if err != nil {
var ecsErr dnsmsg.BadECSError
if errors.As(err, &ecsErr) {
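// A minimal sketch (not part of the patch) of the lookup order used by
// initMw.profile above, omitting the LinkedIPEnabled and protocol checks; db,
// id, localIP, and remoteIP stand for the values available in the middleware.
if id != "" {
	// 1. By the device ID from the TLS server name or the DoH path.
	prof, dev, err = db.ProfileByDeviceID(ctx, id)
} else {
	// 2. By the dedicated (local) IP the query arrived on.
	prof, dev, err = db.ProfileByDedicatedIP(ctx, localIP)
	if errors.Is(err, profiledb.ErrDeviceNotFound) {
		// 3. By the linked (remote) IP of the client.
		prof, dev, err = db.ProfileByLinkedIP(ctx, remoteIP)
	}
}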
diff --git a/internal/dnssvc/initmw_internal_test.go b/internal/dnssvc/initmw_internal_test.go
index 74ffb13..fd10fce 100644
--- a/internal/dnssvc/initmw_internal_test.go
+++ b/internal/dnssvc/initmw_internal_test.go
@@ -3,7 +3,6 @@ package dnssvc
import (
"context"
"crypto/tls"
- "net"
"net/netip"
"testing"
@@ -11,6 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
@@ -21,10 +21,148 @@ import (
"golang.org/x/exp/maps"
)
-// testRAddr is the common remote address for tests.
-var testRAddr = &net.TCPAddr{
- IP: net.IP{1, 2, 3, 4},
- Port: 12345,
+func TestInitMw_profile(t *testing.T) {
+ prof := &agd.Profile{
+ ID: testProfileID,
+ DeviceIDs: []agd.DeviceID{
+ testDeviceID,
+ },
+ }
+ dev := &agd.Device{
+ ID: testDeviceID,
+ LinkedIP: testClientAddr,
+ DedicatedIPs: []netip.Addr{
+ testServerAddr,
+ },
+ }
+
+ testCases := []struct {
+ wantDev *agd.Device
+ wantProf *agd.Profile
+ wantByWhat string
+ wantErrMsg string
+ name string
+ id agd.DeviceID
+ proto agd.Protocol
+ linkedIPEnabled bool
+ }{{
+ wantDev: nil,
+ wantProf: nil,
+ wantByWhat: "",
+ wantErrMsg: "device not found",
+ name: "no_device_id",
+ id: "",
+ proto: agd.ProtoDNS,
+ linkedIPEnabled: true,
+ }, {
+ wantDev: dev,
+ wantProf: prof,
+ wantByWhat: byDeviceID,
+ wantErrMsg: "",
+ name: "device_id",
+ id: testDeviceID,
+ proto: agd.ProtoDNS,
+ linkedIPEnabled: true,
+ }, {
+ wantDev: dev,
+ wantProf: prof,
+ wantByWhat: byLinkedIP,
+ wantErrMsg: "",
+ name: "linked_ip",
+ id: "",
+ proto: agd.ProtoDNS,
+ linkedIPEnabled: true,
+ }, {
+ wantDev: nil,
+ wantProf: nil,
+ wantByWhat: "",
+ wantErrMsg: "device not found",
+ name: "linked_ip_dot",
+ id: "",
+ proto: agd.ProtoDoT,
+ linkedIPEnabled: true,
+ }, {
+ wantDev: nil,
+ wantProf: nil,
+ wantByWhat: "",
+ wantErrMsg: "device not found",
+ name: "linked_ip_disabled",
+ id: "",
+ proto: agd.ProtoDoT,
+ linkedIPEnabled: false,
+ }, {
+ wantDev: dev,
+ wantProf: prof,
+ wantByWhat: byDedicatedIP,
+ wantErrMsg: "",
+ name: "dedicated_ip",
+ id: "",
+ proto: agd.ProtoDNS,
+ linkedIPEnabled: true,
+ }}
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ db := &agdtest.ProfileDB{
+ OnProfileByDeviceID: func(
+ _ context.Context,
+ gotID agd.DeviceID,
+ ) (p *agd.Profile, d *agd.Device, err error) {
+ assert.Equal(t, tc.id, gotID)
+
+ if tc.wantByWhat == byDeviceID {
+ return prof, dev, nil
+ }
+
+ return nil, nil, profiledb.ErrDeviceNotFound
+ },
+ OnProfileByDedicatedIP: func(
+ _ context.Context,
+ gotLocalIP netip.Addr,
+ ) (p *agd.Profile, d *agd.Device, err error) {
+ assert.Equal(t, testServerAddr, gotLocalIP)
+
+ if tc.wantByWhat == byDedicatedIP {
+ return prof, dev, nil
+ }
+
+ return nil, nil, profiledb.ErrDeviceNotFound
+ },
+ OnProfileByLinkedIP: func(
+ _ context.Context,
+ gotRemoteIP netip.Addr,
+ ) (p *agd.Profile, d *agd.Device, err error) {
+ assert.Equal(t, testClientAddr, gotRemoteIP)
+
+ if tc.wantByWhat == byLinkedIP {
+ return prof, dev, nil
+ }
+
+ return nil, nil, profiledb.ErrDeviceNotFound
+ },
+ }
+
+ mw := &initMw{
+ srv: &agd.Server{
+ Protocol: tc.proto,
+ LinkedIPEnabled: tc.linkedIPEnabled,
+ },
+ db: db,
+ }
+ ctx := context.Background()
+
+ gotProf, gotDev, gotByWhat, err := mw.profile(
+ ctx,
+ testServerAddr,
+ testClientAddr,
+ tc.id,
+ )
+ testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
+ assert.Equal(t, tc.wantProf, gotProf)
+ assert.Equal(t, tc.wantDev, gotDev)
+ assert.Equal(t, tc.wantByWhat, gotByWhat)
+ })
+ }
}
func TestInitMw_ServeDNS_ddr(t *testing.T) {
@@ -32,15 +170,14 @@ func TestInitMw_ServeDNS_ddr(t *testing.T) {
resolverName = "dns.example.com"
resolverFQDN = resolverName + "."
- deviceID = "dev1234"
- targetWithID = deviceID + ".d." + resolverName + "."
+ targetWithID = testDeviceID + ".d." + resolverName + "."
ddrFQDN = ddrDomain + "."
dohPath = "/dns-query"
)
- testDevice := &agd.Device{ID: deviceID}
+ testDevice := &agd.Device{ID: testDeviceID}
srvs := map[agd.ServerName]*agd.Server{
"dot": {
@@ -98,15 +235,21 @@ func TestInitMw_ServeDNS_ddr(t *testing.T) {
_ context.Context,
_ agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
- p = &agd.Profile{Devices: []*agd.Device{dev}}
+ p = &agd.Profile{}
return p, dev, nil
},
- OnProfileByIP: func(
+ OnProfileByDedicatedIP: func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
- p = &agd.Profile{Devices: []*agd.Device{dev}}
+ return nil, nil, profiledb.ErrDeviceNotFound
+ },
+ OnProfileByLinkedIP: func(
+ _ context.Context,
+ _ netip.Addr,
+ ) (p *agd.Profile, d *agd.Device, err error) {
+ p = &agd.Profile{}
return p, dev, nil
},
@@ -397,12 +540,12 @@ func TestInitMw_ServeDNS_specialDomain(t *testing.T) {
return rw.WriteMsg(ctx, req, resp)
})
- onProfileByIP := func(
+ onProfileByLinkedIP := func(
_ context.Context,
_ netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
if !tc.hasProf {
- return nil, nil, agd.DeviceNotFoundError{}
+ return nil, nil, profiledb.ErrDeviceNotFound
}
prof := &agd.Profile{
@@ -419,7 +562,13 @@ func TestInitMw_ServeDNS_specialDomain(t *testing.T) {
) (p *agd.Profile, d *agd.Device, err error) {
panic("not implemented")
},
- OnProfileByIP: onProfileByIP,
+ OnProfileByDedicatedIP: func(
+ _ context.Context,
+ _ netip.Addr,
+ ) (p *agd.Profile, d *agd.Device, err error) {
+ return nil, nil, profiledb.ErrDeviceNotFound
+ },
+ OnProfileByLinkedIP: onProfileByLinkedIP,
}
geoIP := &agdtest.GeoIP{
@@ -543,7 +692,7 @@ func BenchmarkInitMw_Wrap(b *testing.B) {
ctx := context.Background()
ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{
- TLSServerName: "dev1234.dns.example.com",
+ TLSServerName: testDeviceID + ".dns.example.com",
})
req := &dns.Msg{
@@ -573,6 +722,18 @@ func BenchmarkInitMw_Wrap(b *testing.B) {
) (p *agd.Profile, d *agd.Device, err error) {
return prof, dev, nil
},
+ OnProfileByDedicatedIP: func(
+ _ context.Context,
+ _ netip.Addr,
+ ) (p *agd.Profile, d *agd.Device, err error) {
+ panic("not implemented")
+ },
+ OnProfileByLinkedIP: func(
+ _ context.Context,
+ _ netip.Addr,
+ ) (p *agd.Profile, d *agd.Device, err error) {
+ panic("not implemented")
+ },
}
b.Run("success", func(b *testing.B) {
b.ReportAllocs()
@@ -589,7 +750,19 @@ func BenchmarkInitMw_Wrap(b *testing.B) {
_ context.Context,
_ agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
- return nil, nil, agd.ProfileNotFoundError{}
+ return nil, nil, profiledb.ErrDeviceNotFound
+ },
+ OnProfileByDedicatedIP: func(
+ _ context.Context,
+ _ netip.Addr,
+ ) (p *agd.Profile, d *agd.Device, err error) {
+ panic("not implemented")
+ },
+ OnProfileByLinkedIP: func(
+ _ context.Context,
+ _ netip.Addr,
+ ) (p *agd.Profile, d *agd.Device, err error) {
+ panic("not implemented")
},
}
b.Run("not_found", func(b *testing.B) {
@@ -616,6 +789,18 @@ func BenchmarkInitMw_Wrap(b *testing.B) {
) (p *agd.Profile, d *agd.Device, err error) {
return prof, dev, nil
},
+ OnProfileByDedicatedIP: func(
+ _ context.Context,
+ _ netip.Addr,
+ ) (p *agd.Profile, d *agd.Device, err error) {
+ panic("not implemented")
+ },
+ OnProfileByLinkedIP: func(
+ _ context.Context,
+ _ netip.Addr,
+ ) (p *agd.Profile, d *agd.Device, err error) {
+ panic("not implemented")
+ },
}
b.Run("firefox_canary", func(b *testing.B) {
b.ReportAllocs()
@@ -630,13 +815,13 @@ func BenchmarkInitMw_Wrap(b *testing.B) {
ddrReq := &dns.Msg{
Question: []dns.Question{{
// Check the worst case when wildcards are checked.
- Name: "_dns.dev1234.dns.example.com.",
+ Name: "_dns." + testDeviceID + ".dns.example.com.",
Qtype: dns.TypeSVCB,
Qclass: dns.ClassINET,
}},
}
devWithID := &agd.Device{
- ID: "dev1234",
+ ID: testDeviceID,
}
mw.db = &agdtest.ProfileDB{
OnProfileByDeviceID: func(
@@ -645,6 +830,18 @@ func BenchmarkInitMw_Wrap(b *testing.B) {
) (p *agd.Profile, d *agd.Device, err error) {
return prof, devWithID, nil
},
+ OnProfileByDedicatedIP: func(
+ _ context.Context,
+ _ netip.Addr,
+ ) (p *agd.Profile, d *agd.Device, err error) {
+ panic("not implemented")
+ },
+ OnProfileByLinkedIP: func(
+ _ context.Context,
+ _ netip.Addr,
+ ) (p *agd.Profile, d *agd.Device, err error) {
+ panic("not implemented")
+ },
}
b.Run("ddr", func(b *testing.B) {
b.ReportAllocs()
diff --git a/internal/dnssvc/middleware.go b/internal/dnssvc/middleware.go
index 5d0719b..f1a3294 100644
--- a/internal/dnssvc/middleware.go
+++ b/internal/dnssvc/middleware.go
@@ -3,6 +3,7 @@ package dnssvc
import (
"context"
"strconv"
+ "strings"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
@@ -11,6 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/miekg/dns"
+ "golang.org/x/exp/slices"
)
// Middlewares
@@ -52,30 +54,58 @@ func (mh *svcHandler) ServeDNS(
optlog.Debug2("processing request %q from %s", reqID, raddr)
defer optlog.Debug2("finished processing request %q from %s", reqID, raddr)
- // Assume that the cache is always hot and that we can always send the
- // request and filter it out along with the response later if we need
- // it.
+ reqInfo := agd.MustRequestInfoFromContext(ctx)
+ flt := mh.svc.fltStrg.FilterFromContext(ctx, reqInfo)
+
+ modReq, reqRes, elapsedReq := mh.svc.filterRequest(ctx, req, flt, reqInfo)
+
nwrw := makeNonWriter(rw)
- err = mh.next.ServeDNS(ctx, nwrw, req)
+ if modReq != nil {
+ // Modified request is set only if the request was modified by a CNAME
+ // rewrite rule, so resolve the request as if it was for the rewritten
+ // name.
+
+ // Clone the request information and replace the host name with the
+ // rewritten one, since the request information from the current context
+ // must only be accessed for reading, see [agd.RequestInfo]. Shallow
+ // copy is enough, because we only change the [agd.RequestInfo.Host]
+ // field, which is a string.
+ modReqInfo := &agd.RequestInfo{}
+ *modReqInfo = *reqInfo
+ modReqInfo.Host = strings.ToLower(strings.TrimSuffix(modReq.Question[0].Name, "."))
+
+ modReqCtx := agd.ContextWithRequestInfo(ctx, modReqInfo)
+
+ optlog.Debug2(
+ "dnssvc: request for %q rewritten to %q by CNAME rewrite rule",
+ reqInfo.Host,
+ modReqInfo.Host,
+ )
+
+ err = mh.next.ServeDNS(modReqCtx, nwrw, modReq)
+ } else {
+ err = mh.next.ServeDNS(ctx, nwrw, req)
+ }
if err != nil {
return err
}
- ri := agd.MustRequestInfoFromContext(ctx)
origResp := nwrw.Msg()
- reqRes, respRes := mh.svc.filterQuery(ctx, req, origResp, ri)
+ respRes, elapsedResp := mh.svc.filterResponse(ctx, req, origResp, flt, reqInfo, modReq)
+
+ mh.svc.reportMetrics(reqInfo, reqRes, respRes, elapsedReq+elapsedResp)
if isDebug {
return mh.svc.writeDebugResponse(ctx, rw, req, origResp, reqRes, respRes)
}
- resp, err := writeFilteredResp(ctx, ri, rw, req, origResp, reqRes, respRes)
+ resp, err := writeFilteredResp(ctx, reqInfo, rw, req, origResp, reqRes, respRes)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
- mh.svc.recordQueryInfo(ctx, req, resp, origResp, ri, reqRes, respRes)
+ mh.svc.recordQueryInfo(ctx, req, resp, origResp, reqInfo, reqRes, respRes)
return nil
}
@@ -91,31 +121,73 @@ func makeNonWriter(rw dnsserver.ResponseWriter) (nwrw *dnsserver.NonWriterRespon
return dnsserver.NewNonWriterResponseWriter(rw.LocalAddr(), rw.RemoteAddr())
}
-// filterQuery is a wrapper for f.FilterRequest and f.FilterResponse that treats
-// filtering errors non-critical. It also records filtering metrics.
-func (svc *Service) filterQuery(
+// rewrittenRequest returns the request from res if res is a CNAME rewrite, and
+// returns nil otherwise. Note that the returned message is always a request,
+// since any other rewrite rule type turns into a response.
+func rewrittenRequest(res filter.Result) (req *dns.Msg) {
+ if res, ok := res.(*filter.ResultModified); ok && !res.Msg.Response {
+ return res.Msg
+ }
+
+ return nil
+}
+
+// filterRequest applies f to req and returns the result of filtering. If the
+// result is a CNAME rewrite, it also returns the modified request to resolve.
+// It also returns the time elapsed on filtering.
+func (svc *Service) filterRequest(
ctx context.Context,
req *dns.Msg,
- origResp *dns.Msg,
+ f filter.Interface,
ri *agd.RequestInfo,
-) (reqRes, respRes filter.Result) {
+) (modReq *dns.Msg, reqRes filter.Result, elapsed time.Duration) {
start := time.Now()
- defer func() {
- svc.reportMetrics(ri, reqRes, respRes, time.Since(start))
- }()
-
- f := svc.fltStrg.FilterFromContext(ctx, ri)
reqRes, err := f.FilterRequest(ctx, req, ri)
if err != nil {
svc.reportf(ctx, "filtering request: %w", err)
}
- respRes, err = f.FilterResponse(ctx, origResp, ri)
- if err != nil {
- svc.reportf(ctx, "dnssvc: filtering original response: %w", err)
+ // Consider this operation related to filtering and account for the
+ // elapsed time.
+ modReq = rewrittenRequest(reqRes)
+
+ return modReq, reqRes, time.Since(start)
+}
+
+// filterResponse applies f to resp and returns the result of filtering. If
+// modReq is not nil, the request is assumed to have been rewritten by a CNAME
+// rewrite rule: resp is not filtered, and the CNAME is prepended to the answer
+// section of resp instead. It also returns the time elapsed on filtering.
+func (svc *Service) filterResponse(
+ ctx context.Context,
+ req *dns.Msg,
+ resp *dns.Msg,
+ f filter.Interface,
+ ri *agd.RequestInfo,
+ modReq *dns.Msg,
+) (respRes filter.Result, elapsed time.Duration) {
+ start := time.Now()
+
+ if modReq != nil {
+ // Return the request name to its original state, since it was
+ // previously rewritten by CNAME rewrite rule.
+ resp.Question[0] = req.Question[0]
+
+ // Prepend the CNAME answer to the response and don't filter it.
+ var rr dns.RR = ri.Messages.NewAnswerCNAME(req, modReq.Question[0].Name)
+ resp.Answer = slices.Insert(resp.Answer, 0, rr)
+
+ // Also consider this operation related to filtering and account for the
+ // elapsed time.
+ return nil, time.Since(start)
}
- return reqRes, respRes
+ respRes, err := f.FilterResponse(ctx, resp, ri)
+ if err != nil {
+ svc.reportf(ctx, "filtering response: %w", err)
+ }
+
+ return respRes, time.Since(start)
}
// reportMetrics extracts filtering metrics data from the context and reports it
@@ -139,7 +211,7 @@ func (svc *Service) reportMetrics(
metrics.DNSSvcRequestByCountryTotal.WithLabelValues(cont, ctry).Inc()
metrics.DNSSvcRequestByASNTotal.WithLabelValues(ctry, asn).Inc()
- id, _, blocked := filteringData(reqRes, respRes)
+ id, _, isBlocked := filteringData(reqRes, respRes)
metrics.DNSSvcRequestByFilterTotal.WithLabelValues(
string(id),
metrics.BoolString(ri.Profile == nil),
@@ -149,19 +221,7 @@ func (svc *Service) reportMetrics(
metrics.DNSSvcUsersCountUpdate(ri.RemoteIP)
if svc.researchMetrics {
- anonymous := ri.Profile == nil
- filteringEnabled := ri.FilteringGroup != nil &&
- ri.FilteringGroup.RuleListsEnabled &&
- len(ri.FilteringGroup.RuleListIDs) > 0
-
- metrics.ReportResearchMetrics(
- anonymous,
- filteringEnabled,
- asn,
- ctry,
- string(id),
- blocked,
- )
+ metrics.ReportResearchMetrics(ri, id, isBlocked)
}
}
diff --git a/internal/dnssvc/middleware_test.go b/internal/dnssvc/middleware_test.go
index 7615556..b7ea00e 100644
--- a/internal/dnssvc/middleware_test.go
+++ b/internal/dnssvc/middleware_test.go
@@ -5,6 +5,7 @@ import (
"net"
"net/http"
"net/netip"
+ "strings"
"testing"
"time"
@@ -12,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
@@ -22,46 +24,95 @@ import (
"github.com/stretchr/testify/require"
)
-func TestService_Wrap_withClient(t *testing.T) {
- // Part 1. Server Configuration
- //
- // Configure a server with fakes to make sure that the wrapped handler
- // has all necessary entities and data in place.
- //
- // TODO(a.garipov): Put this thing into some kind of helper so that we
- // could create several such tests.
+const (
+ // testProfID is the [agd.ProfileID] for tests.
+ testProfID agd.ProfileID = "prof1234"
- const (
- id agd.ProfileID = "prof1234"
- devID agd.DeviceID = "dev1234"
- fltListID agd.FilterListID = "flt1234"
- )
+ // testDevID is the [agd.DeviceID] for tests.
+ testDevID agd.DeviceID = "dev1234"
+
+ // testFltListID is the [agd.FilterListID] for tests.
+ testFltListID agd.FilterListID = "flt1234"
+
+ // testSrvName is the [agd.ServerName] for tests.
+ testSrvName agd.ServerName = "test_server_dns_tls"
+
+ // testSrvGrpName is the [agd.ServerGroupName] for tests.
+ testSrvGrpName agd.ServerGroupName = "test_group"
+
+ // testDevIDWildcard is the wildcard domain for retrieving [agd.DeviceID] in
+ // tests. Use [strings.ReplaceAll] to replace the "*" symbol with the
+ // actual [agd.DeviceID].
+ testDevIDWildcard string = "*.dns.example.com"
+)
+
+// testTimeout is the common timeout for tests.
+const testTimeout time.Duration = 1 * time.Second
+
+// newTestService creates a new [dnssvc.Service] for tests. The service is
+// built of stubs that use the following data:
+//
+// - A filtering group containing a filter with [testFltListID] and enabled
+// rule lists.
+// - A device with [testDevID] and enabled filtering.
+// - A profile with [testProfID], containing the device, with filtering and
+// query logging enabled.
+// - GeoIP database always returning [agd.CountryAD], [agd.ContinentEU], and
+// ASN of 42.
+// - A server with [testSrvName] under group with [testSrvGrpName], matching
+// the DeviceID with [testDevIDWildcard].
+//
+// Each stub also uses the corresponding channels to send the data it receives
+// from the service. If the channel is [nil], the stub ignores it. Each send
+// to a channel is wrapped with [testutil.RequireSend] using [testTimeout].
+//
+// It also uses the [dnsservertest.DefaultHandler] to create the DNS handler.
+func newTestService(
+ t testing.TB,
+ flt filter.Interface,
+ errCollCh chan<- error,
+ profileDBCh chan<- agd.DeviceID,
+ querylogCh chan<- *querylog.Entry,
+ geoIPCh chan<- string,
+ dnsDBCh chan<- *agd.RequestInfo,
+ ruleStatCh chan<- agd.FilterRuleText,
+) (svc *dnssvc.Service, srvAddr netip.AddrPort) {
+ t.Helper()
+
+ pt := testutil.PanicT{}
dev := &agd.Device{
- ID: devID,
+ ID: testDevID,
FilteringEnabled: true,
}
prof := &agd.Profile{
- ID: id,
- Devices: []*agd.Device{dev},
- RuleListIDs: []agd.FilterListID{fltListID},
- FilteredResponseTTL: 10 * time.Second,
+ ID: testProfID,
+ DeviceIDs: []agd.DeviceID{testDevID},
+ RuleListIDs: []agd.FilterListID{testFltListID},
+ FilteredResponseTTL: agdtest.FilteredResponseTTL,
FilteringEnabled: true,
QueryLogEnabled: true,
}
- dbDeviceIDs := make(chan agd.DeviceID, 1)
db := &agdtest.ProfileDB{
OnProfileByDeviceID: func(
_ context.Context,
id agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
- dbDeviceIDs <- id
+ if profileDBCh != nil {
+ testutil.RequireSend(pt, profileDBCh, id, testTimeout)
+ }
return prof, dev, nil
},
- OnProfileByIP: func(
+ OnProfileByDedicatedIP: func(
+ _ context.Context,
+ _ netip.Addr,
+ ) (p *agd.Profile, d *agd.Device, err error) {
+ panic("not implemented")
+ },
+ OnProfileByLinkedIP: func(
ctx context.Context,
ip netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
@@ -71,24 +122,19 @@ func TestService_Wrap_withClient(t *testing.T) {
// Make sure that any panics and errors within handlers are caught and
// that they fail the test by panicking.
- errCh := make(chan error, 1)
- go func() {
- pt := testutil.PanicT{}
-
- err, ok := <-errCh
- if !ok {
- return
- }
-
- require.NoError(pt, err)
- }()
-
errColl := &agdtest.ErrorCollector{
OnCollect: func(_ context.Context, err error) {
- errCh <- err
+ if errCollCh != nil {
+ testutil.RequireSend(pt, errCollCh, err, testTimeout)
+ }
},
}
+ loc := &agd.Location{
+ Country: agd.CountryAD,
+ Continent: agd.ContinentEU,
+ ASN: 42,
+ }
geoIP := &agdtest.GeoIP{
OnSubnetByLocation: func(
_ agd.Country,
@@ -97,34 +143,13 @@ func TestService_Wrap_withClient(t *testing.T) {
) (n netip.Prefix, err error) {
panic("not implemented")
},
- OnData: func(_ string, _ netip.Addr) (l *agd.Location, err error) {
- return &agd.Location{
- Country: agd.CountryAD,
- Continent: agd.ContinentEU,
- ASN: 42,
- }, nil
- },
- }
+ OnData: func(host string, _ netip.Addr) (l *agd.Location, err error) {
+ if geoIPCh != nil {
+ testutil.RequireSend(pt, geoIPCh, host, testTimeout)
+ }
- fltDomainCh := make(chan string, 1)
- flt := &agdtest.Filter{
- OnFilterRequest: func(
- _ context.Context,
- req *dns.Msg,
- _ *agd.RequestInfo,
- ) (r filter.Result, err error) {
- fltDomainCh <- req.Question[0].Name
-
- return nil, nil
+ return loc, nil
},
- OnFilterResponse: func(
- _ context.Context,
- _ *dns.Msg,
- _ *agd.RequestInfo,
- ) (r filter.Result, err error) {
- return nil, nil
- },
- OnClose: func() (err error) { panic("not implemented") },
}
fltStrg := &agdtest.FilterStorage{
@@ -134,23 +159,21 @@ func TestService_Wrap_withClient(t *testing.T) {
OnHasListID: func(_ agd.FilterListID) (ok bool) { panic("not implemented") },
}
- logDomainCh := make(chan string, 1)
- logQTypeCh := make(chan dnsmsg.RRType, 1)
var ql querylog.Interface = &agdtest.QueryLog{
OnWrite: func(_ context.Context, e *querylog.Entry) (err error) {
- logDomainCh <- e.DomainFQDN
- logQTypeCh <- e.RequestType
+ if querylogCh != nil {
+ testutil.RequireSend(pt, querylogCh, e, testTimeout)
+ }
return nil
},
}
- srvAddr := netip.MustParseAddrPort("94.149.14.14:853")
- srvName := agd.ServerName("test_server_dns_tls")
+ srvAddr = netip.MustParseAddrPort("94.149.14.14:853")
srvs := []*agd.Server{{
DNSCrypt: nil,
TLS: nil,
- Name: srvName,
+ Name: testSrvName,
BindData: []*agd.ServerBindData{{
AddrPort: srvAddr,
}},
@@ -161,20 +184,6 @@ func TestService_Wrap_withClient(t *testing.T) {
tl.onStart = func(_ context.Context) (err error) { return nil }
tl.onShutdown = func(_ context.Context) (err error) { return nil }
- var h dnsserver.Handler = dnsserver.HandlerFunc(func(
- ctx context.Context,
- rw dnsserver.ResponseWriter,
- r *dns.Msg,
- ) (err error) {
- resp := &dns.Msg{}
- resp.SetReply(r)
- resp.Answer = append(resp.Answer, &dns.A{
- A: net.IP{1, 2, 3, 4},
- })
-
- return rw.WriteMsg(ctx, r, resp)
- })
-
dnsCk := &agdtest.DNSCheck{
OnCheck: func(
_ context.Context,
@@ -185,17 +194,19 @@ func TestService_Wrap_withClient(t *testing.T) {
},
}
- numDNSDBReq := 0
dnsDB := &agdtest.DNSDB{
- OnRecord: func(_ context.Context, _ *dns.Msg, _ *agd.RequestInfo) {
- numDNSDBReq++
+ OnRecord: func(_ context.Context, _ *dns.Msg, ri *agd.RequestInfo) {
+ if dnsDBCh != nil {
+ testutil.RequireSend(pt, dnsDBCh, ri, testTimeout)
+ }
},
}
- numRuleStatReq := 0
ruleStat := &agdtest.RuleStat{
- OnCollect: func(_ context.Context, _ agd.FilterListID, _ agd.FilterRuleText) {
- numRuleStatReq++
+ OnCollect: func(_ context.Context, _ agd.FilterListID, text agd.FilterRuleText) {
+ if ruleStatCh != nil {
+ testutil.RequireSend(pt, ruleStatCh, text, testTimeout)
+ }
},
}
@@ -207,11 +218,13 @@ func TestService_Wrap_withClient(t *testing.T) {
) (drop, allowlisted bool, err error) {
return true, false, nil
},
- OnCountResponses: func(_ context.Context, _ *dns.Msg, _ netip.Addr) {},
+ OnCountResponses: func(_ context.Context, _ *dns.Msg, _ netip.Addr) {
+ panic("not implemented")
+ },
}
- fltGrpID := agd.FilteringGroupID("1234")
- srvGrpName := agd.ServerGroupName("test_group")
+ testFltGrpID := agd.FilteringGroupID("1234")
+
c := &dnssvc.Config{
Messages: agdtest.NewConstructor(),
BillStat: &agdtest.BillStatRecorder{
@@ -234,31 +247,25 @@ func TestService_Wrap_withClient(t *testing.T) {
GeoIP: geoIP,
QueryLog: ql,
RuleStat: ruleStat,
- Upstream: &agd.Upstream{
- Server: netip.MustParseAddrPort("8.8.8.8:53"),
- FallbackServers: []netip.AddrPort{
- netip.MustParseAddrPort("1.1.1.1:53"),
- },
- },
- NewListener: newTestListenerFunc(tl),
- Handler: h,
- RateLimit: rl,
+ NewListener: newTestListenerFunc(tl),
+ Handler: dnsservertest.DefaultHandler(),
+ RateLimit: rl,
FilteringGroups: map[agd.FilteringGroupID]*agd.FilteringGroup{
- fltGrpID: {
- ID: fltGrpID,
- RuleListIDs: []agd.FilterListID{fltListID},
+ testFltGrpID: {
+ ID: testFltGrpID,
+ RuleListIDs: []agd.FilterListID{testFltListID},
RuleListsEnabled: true,
},
},
ServerGroups: []*agd.ServerGroup{{
TLS: &agd.TLS{
- DeviceIDWildcards: []string{"*.dns.example.com"},
+ DeviceIDWildcards: []string{testDevIDWildcard},
},
DDR: &agd.DDR{
Enabled: true,
},
- Name: srvGrpName,
- FilteringGroup: fltGrpID,
+ Name: testSrvGrpName,
+ FilteringGroup: testFltGrpID,
Servers: srvs,
}},
}
@@ -266,56 +273,190 @@ func TestService_Wrap_withClient(t *testing.T) {
svc, err := dnssvc.New(c)
require.NoError(t, err)
require.NotNil(t, svc)
+
+ err = svc.Start()
+ require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return svc.Shutdown(context.Background())
})
- err = svc.Start()
- require.NoError(t, err)
+ return svc, srvAddr
+}
- // Part 2. Testing Proper
- //
- // Create a context, a request, and a simple handler. Wrap the handler
- // and make sure that all processing went as needed.
+func TestService_Wrap(t *testing.T) {
+ profileDBCh := make(chan agd.DeviceID, 1)
+ querylogCh := make(chan *querylog.Entry, 1)
+ geoIPCh := make(chan string, 2)
+ dnsDBCh := make(chan *agd.RequestInfo, 1)
+ ruleStatCh := make(chan agd.FilterRuleText, 1)
+
+ errCollCh := make(chan error, 1)
+ go func() {
+ for err := range errCollCh {
+ require.NoError(t, err)
+ }
+ }()
+
+ const domain = "example.org"
+
+ domainFQDN := dns.Fqdn(domain)
- domain := "example.org."
reqType := dns.TypeA
- req := &dns.Msg{
- MsgHdr: dns.MsgHdr{
- Id: dns.Id(),
- RecursionDesired: true,
- },
- Question: []dns.Question{{
- Name: domain,
- Qtype: reqType,
- Qclass: dns.ClassINET,
- }},
- }
+ req := dnsservertest.CreateMessage(domain, reqType)
+
+ clientAddr := &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 12345}
ctx := context.Background()
ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{
- TLSServerName: string(devID) + ".dns.example.com",
+ TLSServerName: strings.ReplaceAll(testDevIDWildcard, "*", string(testDevID)),
})
ctx = dnsserver.ContextWithServerInfo(ctx, dnsserver.ServerInfo{
Proto: agd.ProtoDoT,
})
- ctx = dnsserver.ContextWithStartTime(ctx, time.Now())
- clientAddr := &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 12345}
- rw := &testResponseWriter{
- onLocalAddr: func() (a net.Addr) { return net.TCPAddrFromAddrPort(srvAddr) },
- onRemoteAddr: func() (a net.Addr) { return clientAddr },
- onWriteMsg: func(_ context.Context, _, _ *dns.Msg) (err error) {
- return nil
- },
- }
- err = svc.Handle(ctx, srvGrpName, srvName, rw, req)
- require.NoError(t, err)
+ t.Run("simple_success", func(t *testing.T) {
+ noMatch := func(
+ _ context.Context,
+ m *dns.Msg,
+ _ *agd.RequestInfo,
+ ) (r filter.Result, err error) {
+ pt := testutil.PanicT{}
+ require.NotEmpty(pt, m.Question)
+ require.Equal(pt, domainFQDN, m.Question[0].Name)
- assert.Equal(t, devID, <-dbDeviceIDs)
- assert.Equal(t, domain, <-fltDomainCh)
- assert.Equal(t, domain, <-logDomainCh)
- assert.Equal(t, reqType, <-logQTypeCh)
- assert.Equal(t, 1, numDNSDBReq)
- assert.Equal(t, 1, numRuleStatReq)
+ return nil, nil
+ }
+
+ flt := &agdtest.Filter{
+ OnFilterRequest: noMatch,
+ OnFilterResponse: noMatch,
+ }
+
+ svc, srvAddr := newTestService(
+ t,
+ flt,
+ errCollCh,
+ profileDBCh,
+ querylogCh,
+ geoIPCh,
+ dnsDBCh,
+ ruleStatCh,
+ )
+
+ rw := dnsserver.NewNonWriterResponseWriter(
+ net.TCPAddrFromAddrPort(srvAddr),
+ clientAddr,
+ )
+
+ ctx = dnsserver.ContextWithStartTime(ctx, time.Now())
+
+ err := svc.Handle(ctx, testSrvGrpName, testSrvName, rw, req)
+ require.NoError(t, err)
+
+ resp := rw.Msg()
+ dnsservertest.RequireResponse(t, req, resp, 1, dns.RcodeSuccess, false)
+
+ assert.Equal(t, testDevID, <-profileDBCh)
+
+ logEntry := <-querylogCh
+ assert.Equal(t, domainFQDN, logEntry.DomainFQDN)
+ assert.Equal(t, reqType, logEntry.RequestType)
+
+ assert.Equal(t, "", <-geoIPCh)
+ assert.Equal(t, domain, <-geoIPCh)
+
+ dnsDBReqInfo := <-dnsDBCh
+ assert.NotNil(t, dnsDBReqInfo)
+ assert.Equal(t, agd.FilterRuleText(""), <-ruleStatCh)
+ })
+
+ t.Run("request_cname", func(t *testing.T) {
+ const (
+ cname = "cname.example.org"
+ cnameRule agd.FilterRuleText = "||" + domain + "^$dnsrewrite=" + cname
+ )
+
+ cnameFQDN := dns.Fqdn(cname)
+
+ flt := &agdtest.Filter{
+ OnFilterRequest: func(
+ _ context.Context,
+ m *dns.Msg,
+ _ *agd.RequestInfo,
+ ) (r filter.Result, err error) {
+ // Pretend a CNAME rewrite matched the request.
+ mod := dnsmsg.Clone(m)
+ mod.Question[0].Name = cnameFQDN
+
+ return &filter.ResultModified{
+ Msg: mod,
+ List: testFltListID,
+ Rule: cnameRule,
+ }, nil
+ },
+ OnFilterResponse: func(
+ _ context.Context,
+ _ *dns.Msg,
+ _ *agd.RequestInfo,
+ ) (filter.Result, error) {
+ panic("not implemented")
+ },
+ }
+
+ svc, srvAddr := newTestService(
+ t,
+ flt,
+ errCollCh,
+ profileDBCh,
+ querylogCh,
+ geoIPCh,
+ dnsDBCh,
+ ruleStatCh,
+ )
+
+ rw := dnsserver.NewNonWriterResponseWriter(
+ net.TCPAddrFromAddrPort(srvAddr),
+ clientAddr,
+ )
+
+ ctx = dnsserver.ContextWithStartTime(ctx, time.Now())
+
+ err := svc.Handle(ctx, testSrvGrpName, testSrvName, rw, req)
+ require.NoError(t, err)
+
+ resp := rw.Msg()
+ require.NotNil(t, resp)
+ require.Len(t, resp.Answer, 2)
+
+ assert.Equal(t, []dns.RR{&dns.CNAME{
+ Hdr: dns.RR_Header{
+ Name: domainFQDN,
+ Rrtype: dns.TypeCNAME,
+ Class: dns.ClassINET,
+ Ttl: uint32(agdtest.FilteredResponseTTL.Seconds()),
+ },
+ Target: cnameFQDN,
+ }, &dns.A{
+ Hdr: dns.RR_Header{
+ Name: cnameFQDN,
+ Rrtype: dns.TypeA,
+ Class: dns.ClassINET,
+ Ttl: uint32(dnsservertest.AnswerTTL.Seconds()),
+ },
+ A: netutil.IPv4Localhost().AsSlice(),
+ }}, resp.Answer)
+
+ assert.Equal(t, testDevID, <-profileDBCh)
+
+ logEntry := <-querylogCh
+ assert.Equal(t, domainFQDN, logEntry.DomainFQDN)
+ assert.Equal(t, reqType, logEntry.RequestType)
+
+ assert.Equal(t, "", <-geoIPCh)
+ assert.Equal(t, cname, <-geoIPCh)
+
+ dnsDBReqInfo := <-dnsDBCh
+ assert.Equal(t, cname, dnsDBReqInfo.Host)
+ assert.Equal(t, cnameRule, <-ruleStatCh)
+ })
}
diff --git a/internal/dnssvc/presvcmw.go b/internal/dnssvc/presvcmw.go
index 2229632..16e0c48 100644
--- a/internal/dnssvc/presvcmw.go
+++ b/internal/dnssvc/presvcmw.go
@@ -25,8 +25,8 @@ type preServiceMw struct {
// messages is used to construct TXT responses.
messages *dnsmsg.Constructor
- // filter is the safe browsing DNS filter.
- filter *filter.SafeBrowsingServer
+ // hashMatcher is the safe browsing DNS hashMatcher.
+ hashMatcher filter.HashMatcher
// checker is used to detect and process DNS-check requests.
checker dnscheck.Interface
@@ -91,7 +91,7 @@ func (mh *preServiceMwHandler) respondWithHashes(
) (err error) {
optlog.Debug1("presvc mw: safe browsing: got txt req for %q", ri.Host)
- hashes, matched, err := mh.mw.filter.Hashes(ctx, ri.Host)
+ hashes, matched, err := mh.mw.hashMatcher.MatchByPrefix(ctx, ri.Host)
if err != nil {
// Don't return or collect this error to prevent DDoS of the error
// collector by sending bad requests.
diff --git a/internal/dnssvc/presvcmw_test.go b/internal/dnssvc/presvcmw_test.go
index 7f2071f..f6db9da 100644
--- a/internal/dnssvc/presvcmw_test.go
+++ b/internal/dnssvc/presvcmw_test.go
@@ -14,7 +14,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
- "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -30,11 +30,7 @@ func TestPreServiceMwHandler_ServeDNS(t *testing.T) {
sum := sha256.Sum256([]byte(safeBrowsingHost))
hashStr := hex.EncodeToString(sum[:])
- hashes, herr := hashstorage.New(safeBrowsingHost)
- require.NoError(t, herr)
-
- srv := filter.NewSafeBrowsingServer(hashes, nil)
- host := hashStr[:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix
+ host := hashStr[:hashprefix.PrefixEncLen] + filter.GeneralTXTSuffix
ctx := context.Background()
ctx = dnsserver.ContextWithClientInfo(ctx, dnsserver.ClientInfo{})
@@ -48,12 +44,14 @@ func TestPreServiceMwHandler_ServeDNS(t *testing.T) {
req *dns.Msg
dnscheckResp *dns.Msg
ri *agd.RequestInfo
+ hashes []string
wantAns []dns.RR
}{{
name: "normal",
req: dnsservertest.CreateMessage(name, dns.TypeA),
dnscheckResp: nil,
ri: &agd.RequestInfo{},
+ hashes: nil,
wantAns: []dns.RR{
dnsservertest.NewA(name, 100, ip),
},
@@ -63,16 +61,14 @@ func TestPreServiceMwHandler_ServeDNS(t *testing.T) {
dnscheckResp: dnsservertest.NewResp(
dns.RcodeSuccess,
dnsservertest.NewReq(name, dns.TypeA, dns.ClassINET),
- dnsservertest.RRSection{
- RRs: []dns.RR{dnsservertest.NewA(name, ttl, ip)},
- Sec: dnsservertest.SectionAnswer,
- },
+ dnsservertest.SectionAnswer{dnsservertest.NewA(name, ttl, ip)},
),
ri: &agd.RequestInfo{
Host: name,
QType: dns.TypeA,
QClass: dns.ClassINET,
},
+ hashes: nil,
wantAns: []dns.RR{
dnsservertest.NewA(name, ttl, ip),
},
@@ -81,6 +77,7 @@ func TestPreServiceMwHandler_ServeDNS(t *testing.T) {
req: dnsservertest.CreateMessage(safeBrowsingHost, dns.TypeTXT),
dnscheckResp: nil,
ri: &agd.RequestInfo{Host: host, QType: dns.TypeTXT},
+ hashes: []string{hashStr},
wantAns: []dns.RR{&dns.TXT{
Hdr: dns.RR_Header{
Name: safeBrowsingHost,
@@ -95,6 +92,7 @@ func TestPreServiceMwHandler_ServeDNS(t *testing.T) {
req: dnsservertest.CreateMessage(name, dns.TypeTXT),
dnscheckResp: nil,
ri: &agd.RequestInfo{Host: name, QType: dns.TypeTXT},
+ hashes: nil,
wantAns: []dns.RR{dnsservertest.NewA(name, 100, ip)},
}}
@@ -113,10 +111,19 @@ func TestPreServiceMwHandler_ServeDNS(t *testing.T) {
},
}
+ hashMatcher := &agdtest.HashMatcher{
+ OnMatchByPrefix: func(
+ ctx context.Context,
+ host string,
+ ) (hashes []string, matched bool, err error) {
+ return tc.hashes, len(tc.hashes) > 0, nil
+ },
+ }
+
mw := &preServiceMw{
- messages: dnsmsg.NewConstructor(&dnsmsg.BlockingModeNullIP{}, ttl*time.Second),
- filter: srv,
- checker: dnsCk,
+ messages: dnsmsg.NewConstructor(&dnsmsg.BlockingModeNullIP{}, ttl*time.Second),
+ hashMatcher: hashMatcher,
+ checker: dnsCk,
}
handler := dnsservertest.DefaultHandler()
h := mw.Wrap(handler)
diff --git a/internal/dnssvc/preupstreammw_test.go b/internal/dnssvc/preupstreammw_test.go
index 58a0c91..09a5d6a 100644
--- a/internal/dnssvc/preupstreammw_test.go
+++ b/internal/dnssvc/preupstreammw_test.go
@@ -28,9 +28,8 @@ func TestPreUpstreamMwHandler_ServeDNS_withCache(t *testing.T) {
aReq := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
respIP := remoteIP.AsSlice()
- resp := dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.RRSection{
- RRs: []dns.RR{dnsservertest.NewA(reqHostname, defaultTTL, respIP)},
- Sec: dnsservertest.SectionAnswer,
+ resp := dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
+ dnsservertest.NewA(reqHostname, defaultTTL, respIP),
})
ctx := agd.ContextWithRequestInfo(context.Background(), &agd.RequestInfo{
Host: aReq.Question[0].Name,
@@ -91,18 +90,9 @@ func TestPreUpstreamMwHandler_ServeDNS_withECSCache(t *testing.T) {
const ctry = agd.CountryAD
- resp := dnsservertest.NewResp(
- dns.RcodeSuccess,
- aReq,
- dnsservertest.RRSection{
- RRs: []dns.RR{dnsservertest.NewA(
- reqHostname,
- defaultTTL,
- net.IP{1, 2, 3, 4},
- )},
- Sec: dnsservertest.SectionAnswer,
- },
- )
+ resp := dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
+ dnsservertest.NewA(reqHostname, defaultTTL, net.IP{1, 2, 3, 4}),
+ })
numReq := 0
handler := dnsserver.HandlerFunc(
@@ -176,8 +166,16 @@ func TestPreUpstreamMwHandler_ServeDNS_androidMetric(t *testing.T) {
ctx = dnsserver.ContextWithStartTime(ctx, time.Now())
ctx = agd.ContextWithRequestInfo(ctx, &agd.RequestInfo{})
+ ipA := net.IP{1, 2, 3, 4}
+ ipB := net.IP{1, 2, 3, 5}
+
const ttl = 100
+ const (
+ httpsDomain = "-dnsohttps-ds.metric.gstatic.com."
+ tlsDomain = "-dnsotls-ds.metric.gstatic.com."
+ )
+
testCases := []struct {
name string
req *dns.Msg
@@ -191,49 +189,28 @@ func TestPreUpstreamMwHandler_ServeDNS_androidMetric(t *testing.T) {
wantName: "example.com.",
wantAns: nil,
}, {
- name: "android-tls-metric",
- req: dnsservertest.CreateMessage(
- "12345678-dnsotls-ds.metric.gstatic.com.",
- dns.TypeA,
- ),
+ name: "android-tls-metric",
+ req: dnsservertest.CreateMessage("12345678"+tlsDomain, dns.TypeA),
resp: resp,
- wantName: "00000000-dnsotls-ds.metric.gstatic.com.",
+ wantName: "00000000" + tlsDomain,
wantAns: nil,
}, {
- name: "android-https-metric",
- req: dnsservertest.CreateMessage(
- "123456-dnsohttps-ds.metric.gstatic.com.",
- dns.TypeA,
- ),
+ name: "android-https-metric",
+ req: dnsservertest.CreateMessage("123456"+httpsDomain, dns.TypeA),
resp: resp,
- wantName: "000000-dnsohttps-ds.metric.gstatic.com.",
+ wantName: "000000" + httpsDomain,
wantAns: nil,
}, {
name: "multiple_answers_metric",
- req: dnsservertest.CreateMessage(
- "123456-dnsohttps-ds.metric.gstatic.com.",
- dns.TypeA,
- ),
- resp: dnsservertest.NewResp(
- dns.RcodeSuccess,
- req,
- dnsservertest.RRSection{
- RRs: []dns.RR{dnsservertest.NewA(
- "123456-dnsohttps-ds.metric.gstatic.com.",
- ttl,
- net.IP{1, 2, 3, 4},
- ), dnsservertest.NewA(
- "654321-dnsohttps-ds.metric.gstatic.com.",
- ttl,
- net.IP{1, 2, 3, 5},
- )},
- Sec: dnsservertest.SectionAnswer,
- },
- ),
- wantName: "000000-dnsohttps-ds.metric.gstatic.com.",
+ req: dnsservertest.CreateMessage("123456"+httpsDomain, dns.TypeA),
+ resp: dnsservertest.NewResp(dns.RcodeSuccess, req, dnsservertest.SectionAnswer{
+ dnsservertest.NewA("123456"+httpsDomain, ttl, ipA),
+ dnsservertest.NewA("654321"+httpsDomain, ttl, ipB),
+ }),
+ wantName: "000000" + httpsDomain,
wantAns: []dns.RR{
- dnsservertest.NewA("123456-dnsohttps-ds.metric.gstatic.com.", ttl, net.IP{1, 2, 3, 4}),
- dnsservertest.NewA("123456-dnsohttps-ds.metric.gstatic.com.", ttl, net.IP{1, 2, 3, 5}),
+ dnsservertest.NewA("123456"+httpsDomain, ttl, ipA),
+ dnsservertest.NewA("123456"+httpsDomain, ttl, ipB),
},
}}
@@ -248,6 +225,7 @@ func TestPreUpstreamMwHandler_ServeDNS_androidMetric(t *testing.T) {
return rw.WriteMsg(ctx, req, tc.resp)
})
+
h := mw.Wrap(handler)
rw := dnsserver.NewNonWriterResponseWriter(nil, testRAddr)
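The refactored tests above use the new slice-based section helpers from dnsservertest instead of the old RRSection struct. A minimal sketch of how such helpers could be declared, assuming a hypothetical unexported appendTo method that NewResp would call for each section argument; the real dnsservertest implementation is not shown in this diff:

```go
// Package dnsservertest sketch: slice-based section helpers.
package dnsservertest

import "github.com/miekg/dns"

// SectionAnswer, SectionNs, and SectionExtra hold resource records destined
// for the answer, authority, and additional sections of a response.
type (
	SectionAnswer []dns.RR
	SectionNs     []dns.RR
	SectionExtra  []dns.RR
)

// appendTo appends the records to the answer section of msg.
func (s SectionAnswer) appendTo(msg *dns.Msg) { msg.Answer = append(msg.Answer, s...) }

// appendTo appends the records to the authority section of msg.
func (s SectionNs) appendTo(msg *dns.Msg) { msg.Ns = append(msg.Ns, s...) }

// appendTo appends the records to the additional section of msg.
func (s SectionExtra) appendTo(msg *dns.Msg) { msg.Extra = append(msg.Extra, s...) }
```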
diff --git a/internal/dnssvc/record.go b/internal/dnssvc/record.go
index f7adcb3..2da0b7f 100644
--- a/internal/dnssvc/record.go
+++ b/internal/dnssvc/record.go
@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/netip"
+ "strings"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
@@ -68,7 +69,14 @@ func (svc *Service) recordQueryInfo(
var respCtry agd.Country
if !respIP.IsUnspecified() {
- respCtry = svc.country(ctx, ri.Host, respIP)
+ host := ri.Host
+ if modReq := rewrittenRequest(reqRes); modReq != nil {
+ // If the request was modified by a CNAME rewrite rule, the actual
+ // result belongs to the hostname from that CNAME.
+ host = strings.TrimSuffix(modReq.Question[0].Name, ".")
+ }
+
+ respCtry = svc.country(ctx, host, respIP)
}
q := req.Question[0]
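For concreteness, a tiny self-contained example of the host override above; the hostname is made up, and rewrittenRequest itself is not reproduced here:

```go
package main

import (
	"fmt"
	"strings"

	"github.com/miekg/dns"
)

func main() {
	// A hypothetical request produced by a CNAME rewrite rule: the country
	// lookup above would use this hostname instead of the original ri.Host.
	modReq := &dns.Msg{Question: []dns.Question{{
		Name:   "tracker.example.net.",
		Qtype:  dns.TypeA,
		Qclass: dns.ClassINET,
	}}}

	host := strings.TrimSuffix(modReq.Question[0].Name, ".")

	// Output: tracker.example.net
	fmt.Println(host)
}
```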
diff --git a/internal/dnssvc/resp.go b/internal/dnssvc/resp.go
index 7ed198a..12bfa1c 100644
--- a/internal/dnssvc/resp.go
+++ b/internal/dnssvc/resp.go
@@ -31,16 +31,23 @@ func writeFilteredResp(
case *filter.ResultAllowed:
err = rw.WriteMsg(ctx, req, resp)
if err != nil {
- err = fmt.Errorf("writing allowed response: %w", err)
+ err = fmt.Errorf("writing response to allowed request: %w", err)
} else {
written = resp
}
case *filter.ResultModified:
- err = rw.WriteMsg(ctx, req, reqRes.Msg)
+ if reqRes.Msg.Response {
+ // Only use the request filtering result if it is already a
+ // response. Otherwise, it's a CNAME rewrite result, which isn't
+ // filtered after resolving.
+ resp = reqRes.Msg
+ }
+
+ err = rw.WriteMsg(ctx, req, resp)
if err != nil {
- err = fmt.Errorf("writing modified response: %w", err)
+ err = fmt.Errorf("writing response to modified request: %w", err)
} else {
- written = reqRes.Msg
+ written = resp
}
default:
// Consider unhandled sum type members as unrecoverable programmer
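The ResultModified branch above now distinguishes a $dnsrewrite result that is already a response from a CNAME rewrite that only produced a modified request. A standalone sketch of that decision, with hypothetical names:

```go
package sketch

import "github.com/miekg/dns"

// chooseWritten mirrors the ResultModified branch above: a modified message
// that is already a response is written as is, while a CNAME-rewritten
// request means the resolved upstream response is written instead.
func chooseWritten(modified, resolved *dns.Msg) (written *dns.Msg) {
	if modified.Response {
		return modified
	}

	return resolved
}
```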
diff --git a/internal/dnssvc/resp_internal_test.go b/internal/dnssvc/resp_internal_test.go
index d8d65e4..a1bfe7a 100644
--- a/internal/dnssvc/resp_internal_test.go
+++ b/internal/dnssvc/resp_internal_test.go
@@ -4,10 +4,9 @@ import (
"context"
"net"
"testing"
- "time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
- "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
@@ -20,10 +19,10 @@ import (
func TestWriteFilteredResp(t *testing.T) {
const (
- fltRespTTL = 42
- respTTL = 10
+ respTTL = 60
)
+ const fltRespTTL = agdtest.FilteredResponseTTLSec
respIP := net.IP{1, 2, 3, 4}
rewrIP := net.IP{5, 6, 7, 8}
blockIP := netutil.IPv4Zero()
@@ -85,14 +84,14 @@ func TestWriteFilteredResp(t *testing.T) {
ctx := context.Background()
ri := &agd.RequestInfo{
- Messages: dnsmsg.NewConstructor(&dnsmsg.BlockingModeNullIP{}, fltRespTTL*time.Second),
+ Messages: agdtest.NewConstructor(),
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
rw := dnsserver.NewNonWriterResponseWriter(nil, nil)
resp := dnsservertest.NewResp(dns.RcodeSuccess, req)
- resp.Answer = append(resp.Answer, dnsservertest.NewA(domain, 10, respIP))
+ resp.Answer = append(resp.Answer, dnsservertest.NewA(domain, respTTL, respIP))
written, err := writeFilteredResp(ctx, ri, rw, req, resp, tc.reqRes, tc.respRes)
require.NoError(t, err)
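The constructor and filtered-response TTL now come from the agdtest package. A hypothetical sketch of those helpers, modeled on the newConstructor helper that this patch deletes from the filter package; the actual agdtest values may differ:

```go
package agdtest

import (
	"time"

	"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
)

// FilteredResponseTTLSec is the TTL, in seconds, of filtered responses used
// in tests.  The value here is an assumption based on the deleted helper.
const FilteredResponseTTLSec = 10

// NewConstructor returns a standard dnsmsg.Constructor for tests.
func NewConstructor() (c *dnsmsg.Constructor) {
	return dnsmsg.NewConstructor(
		&dnsmsg.BlockingModeNullIP{},
		FilteredResponseTTLSec*time.Second,
	)
}
```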
diff --git a/internal/ecscache/ecscache_test.go b/internal/ecscache/ecscache_test.go
index 26ae13b..79309b6 100644
--- a/internal/ecscache/ecscache_test.go
+++ b/internal/ecscache/ecscache_test.go
@@ -42,13 +42,11 @@ var remoteIP = netip.MustParseAddr("1.2.3.4")
func TestMiddleware_Wrap_noECS(t *testing.T) {
aReq := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
cnameReq := dnsservertest.NewReq(reqHostname, dns.TypeCNAME, dns.ClassINET)
- cnameAns := dnsservertest.RRSection{
- RRs: []dns.RR{dnsservertest.NewCNAME(reqHostname, defaultTTL, reqCNAME)},
- Sec: dnsservertest.SectionAnswer,
+ cnameAns := dnsservertest.SectionAnswer{
+ dnsservertest.NewCNAME(reqHostname, defaultTTL, reqCNAME),
}
- soaNS := dnsservertest.RRSection{
- RRs: []dns.RR{dnsservertest.NewSOA(reqHostname, defaultTTL, reqNS1, reqNS2)},
- Sec: dnsservertest.SectionNs,
+ soaNS := dnsservertest.SectionNs{
+ dnsservertest.NewSOA(reqHostname, defaultTTL, reqNS1, reqNS2),
}
const N = 5
@@ -60,8 +58,8 @@ func TestMiddleware_Wrap_noECS(t *testing.T) {
wantTTL uint32
}{{
req: aReq,
- resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.RRSection{
- RRs: []dns.RR{dnsservertest.NewA(reqHostname, defaultTTL, net.IP{1, 2, 3, 4})},
+ resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
+ dnsservertest.NewA(reqHostname, defaultTTL, net.IP{1, 2, 3, 4}),
}),
name: "simple_a",
wantNumReq: 1,
@@ -92,9 +90,8 @@ func TestMiddleware_Wrap_noECS(t *testing.T) {
wantTTL: defaultTTL,
}, {
req: aReq,
- resp: dnsservertest.NewResp(dns.RcodeNameError, aReq, dnsservertest.RRSection{
- RRs: []dns.RR{dnsservertest.NewNS(reqHostname, defaultTTL, reqNS1)},
- Sec: dnsservertest.SectionNs,
+ resp: dnsservertest.NewResp(dns.RcodeNameError, aReq, dnsservertest.SectionNs{
+ dnsservertest.NewNS(reqHostname, defaultTTL, reqNS1),
}),
name: "non_authoritative_nxdomain",
// TODO(ameshkov): Consider https://datatracker.ietf.org/doc/html/rfc2308#section-3.
@@ -114,16 +111,16 @@ func TestMiddleware_Wrap_noECS(t *testing.T) {
wantTTL: ecscache.ServFailMaxCacheTTL,
}, {
req: cnameReq,
- resp: dnsservertest.NewResp(dns.RcodeSuccess, cnameReq, dnsservertest.RRSection{
- RRs: []dns.RR{dnsservertest.NewCNAME(reqHostname, defaultTTL, reqCNAME)},
+ resp: dnsservertest.NewResp(dns.RcodeSuccess, cnameReq, dnsservertest.SectionAnswer{
+ dnsservertest.NewCNAME(reqHostname, defaultTTL, reqCNAME),
}),
name: "simple_cname_ans",
wantNumReq: 1,
wantTTL: defaultTTL,
}, {
req: aReq,
- resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.RRSection{
- RRs: []dns.RR{dnsservertest.NewA(reqHostname, 0, net.IP{1, 2, 3, 4})},
+ resp: dnsservertest.NewResp(dns.RcodeSuccess, aReq, dnsservertest.SectionAnswer{
+ dnsservertest.NewA(reqHostname, 0, net.IP{1, 2, 3, 4}),
}),
name: "expired_one",
wantNumReq: N,
@@ -270,14 +267,12 @@ func TestMiddleware_Wrap_ecs(t *testing.T) {
resp := dnsservertest.NewResp(
dns.RcodeSuccess,
aReq,
- dnsservertest.RRSection{
- RRs: []dns.RR{dnsservertest.NewA(reqHostname, defaultTTL, net.IP{1, 2, 3, 4})},
- Sec: dnsservertest.SectionAnswer,
- },
- dnsservertest.RRSection{
- RRs: []dns.RR{tc.respECS},
- Sec: dnsservertest.SectionExtra,
- },
+ dnsservertest.SectionAnswer{dnsservertest.NewA(
+ reqHostname,
+ defaultTTL,
+ net.IP{1, 2, 3, 4},
+ )},
+ dnsservertest.SectionExtra{tc.respECS},
)
numReq := 0
@@ -337,13 +332,12 @@ func TestMiddleware_Wrap_ecsOrder(t *testing.T) {
newResp := func(t *testing.T, req *dns.Msg, answer, extra dns.RR) (resp *dns.Msg) {
t.Helper()
- return dnsservertest.NewResp(dns.RcodeSuccess, req, dnsservertest.RRSection{
- RRs: []dns.RR{answer},
- Sec: dnsservertest.SectionAnswer,
- }, dnsservertest.RRSection{
- RRs: []dns.RR{extra},
- Sec: dnsservertest.SectionExtra,
- })
+ return dnsservertest.NewResp(
+ dns.RcodeSuccess,
+ req,
+ dnsservertest.SectionAnswer{answer},
+ dnsservertest.SectionExtra{extra},
+ )
}
reqNoECS := dnsservertest.NewReq(reqHostname, dns.TypeA, dns.ClassINET)
diff --git a/internal/errcoll/sentry.go b/internal/errcoll/sentry.go
index e4eb994..86f828e 100644
--- a/internal/errcoll/sentry.go
+++ b/internal/errcoll/sentry.go
@@ -79,9 +79,9 @@ func isReportable(err error) (ok bool) {
} else if errors.As(err, &dnsWErr) {
switch dnsWErr.Protocol {
case "tcp":
- return isReportableTCP(dnsWErr.Err)
+ return isReportableWriteTCP(dnsWErr.Err)
case "udp":
- return isReportableUDP(dnsWErr.Err)
+ return isReportableWriteUDP(dnsWErr.Err)
default:
return true
}
@@ -102,25 +102,28 @@ func isReportableNetwork(err error) (ok bool) {
return errors.As(err, &netErr) && !netErr.Timeout()
}
-// isReportableTCP returns true if err is a TCP or TLS error that should be
+// isReportableWriteTCP returns true if err is a TCP or TLS error that should be
// reported.
-func isReportableTCP(err error) (ok bool) {
+func isReportableWriteTCP(err error) (ok bool) {
if isConnectionBreak(err) {
return false
}
- // Ignore the TLS errors that are probably caused by a network error and
- // errors about protocol versions.
+ // Ignore the TLS errors that are probably caused by a network error or
+ // a record-overflow attempt, as well as errors about protocol versions.
+ //
+ // See also AGDNS-1520.
//
// TODO(a.garipov): Propose exporting these from crypto/tls.
errStr := err.Error()
return !strings.Contains(errStr, "bad record MAC") &&
- !strings.Contains(errStr, "protocol version not supported")
+ !strings.Contains(errStr, "protocol version not supported") &&
+ !strings.Contains(errStr, "local error: tls: record overflow")
}
-// isReportableUDP returns true if err is a UDP error that should be reported.
-func isReportableUDP(err error) (ok bool) {
+// isReportableWriteUDP returns true if err is a UDP error that should be
+// reported.
+func isReportableWriteUDP(err error) (ok bool) {
switch {
case
errors.Is(err, io.EOF),
diff --git a/internal/filter/compfilter.go b/internal/filter/compfilter.go
deleted file mode 100644
index b579a2a..0000000
--- a/internal/filter/compfilter.go
+++ /dev/null
@@ -1,315 +0,0 @@
-package filter
-
-import (
- "context"
- "fmt"
- "strings"
-
- "github.com/AdguardTeam/AdGuardDNS/internal/agd"
- "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
- "github.com/AdguardTeam/golibs/errors"
- "github.com/AdguardTeam/golibs/log"
- "github.com/AdguardTeam/urlfilter/rules"
- "github.com/miekg/dns"
-)
-
-// Composite Filter
-
-// type check
-var _ Interface = (*compFilter)(nil)
-
-// compFilter is a composite filter based on several types of safe search
-// filters and rule lists.
-type compFilter struct {
- safeBrowsing *HashPrefix
- adultBlocking *HashPrefix
-
- genSafeSearch *safeSearch
- ytSafeSearch *safeSearch
-
- ruleLists []*ruleListFilter
-}
-
-// qtHostFilter is a filter that can filter a request based on its query type
-// and host.
-//
-// TODO(a.garipov): See if devirtualizing this interface would give us any
-// considerable performance gains.
-type qtHostFilter interface {
- filterReq(
- ctx context.Context,
- ri *agd.RequestInfo,
- req *dns.Msg,
- ) (r Result, err error)
- name() (n string)
-}
-
-// FilterRequest implements the Interface interface for *compFilter. If there
-// is a safe search result, it returns it. Otherwise, it returns the action
-// created from the filter list network rule with the highest priority. If f is
-// empty, it returns nil with no error.
-func (f *compFilter) FilterRequest(
- ctx context.Context,
- req *dns.Msg,
- ri *agd.RequestInfo,
-) (r Result, err error) {
- if f.isEmpty() {
- return nil, nil
- }
-
- // Prepare common data for filters.
- reqID := ri.ID
- log.Debug("filters: filtering req %s: %d rule lists", reqID, len(f.ruleLists))
-
- // Firstly, check the profile's filter list rules, custom rules, and the
- // rules from blocked services settings.
- host := ri.Host
- flRes := f.filterMsg(ri, host, ri.QType, req, false)
- switch flRes := flRes.(type) {
- case *ResultAllowed:
- // Skip any additional filtering if the domain is explicitly allowed by
- // user's custom rule.
- if flRes.List == agd.FilterListIDCustom {
- return flRes, nil
- }
- case *ResultBlocked:
- // Skip any additional filtering if the domain is already blocked.
- return flRes, nil
- default:
- // Go on.
- }
-
- // Secondly, apply the safe browsing and safe search filters in the
- // following order.
- //
- // DO NOT change the order of filters without necessity.
- filters := []qtHostFilter{
- f.safeBrowsing,
- f.adultBlocking,
- f.genSafeSearch,
- f.ytSafeSearch,
- }
-
- for _, flt := range filters {
- name := flt.name()
- if name == "" {
- // A nil filter, skip.
- continue
- }
-
- log.Debug("filter %s: filtering req %s", name, reqID)
- r, err = flt.filterReq(ctx, ri, req)
- log.Debug("filter %s: finished filtering req %s, errors: %v", name, reqID, err)
- if err != nil {
- return nil, err
- } else if r != nil {
- return r, nil
- }
- }
-
- // Thirdly, return the previously obtained filter list result.
- return flRes, nil
-}
-
-// FilterResponse implements the Interface interface for *compFilter. It
-// returns the action created from the filter list network rule with the highest
-// priority. If f is empty, it returns nil with no error.
-func (f *compFilter) FilterResponse(
- ctx context.Context,
- resp *dns.Msg,
- ri *agd.RequestInfo,
-) (r Result, err error) {
- if f.isEmpty() || len(resp.Answer) == 0 {
- return nil, nil
- }
-
- for _, ans := range resp.Answer {
- host, rrType, ok := parseRespAnswer(ans)
- if !ok {
- continue
- }
-
- r = f.filterMsg(ri, host, rrType, resp, true)
- if r != nil {
- break
- }
- }
-
- return r, nil
-}
-
-// parseRespAnswer parses hostname and rrType from the answer if there are any.
-// If ans is of a type that doesn't have an IP address or a hostname in it, ok
-// is false.
-func parseRespAnswer(ans dns.RR) (hostname string, rrType dnsmsg.RRType, ok bool) {
- switch ans := ans.(type) {
- case *dns.A:
- return ans.A.String(), dns.TypeA, true
- case *dns.AAAA:
- return ans.AAAA.String(), dns.TypeAAAA, true
- case *dns.CNAME:
- return strings.TrimSuffix(ans.Target, "."), dns.TypeCNAME, true
- default:
- return "", dns.TypeNone, false
- }
-}
-
-// isEmpty returns true if this composite filter is an empty filter.
-func (f *compFilter) isEmpty() (ok bool) {
- return f == nil || (f.safeBrowsing == nil &&
- f.adultBlocking == nil &&
- f.genSafeSearch == nil &&
- f.ytSafeSearch == nil &&
- len(f.ruleLists) == 0)
-}
-
-// filterMsg filters one question's or answer's information through all rule
-// list filters of the composite filter.
-func (f *compFilter) filterMsg(
- ri *agd.RequestInfo,
- host string,
- rrType dnsmsg.RRType,
- msg *dns.Msg,
- answer bool,
-) (r Result) {
- var devName agd.DeviceName
- if d := ri.Device; d != nil {
- devName = d.Name
- }
-
- var networkRules []*rules.NetworkRule
- var hostRules4 []*rules.HostRule
- var hostRules6 []*rules.HostRule
- for _, rl := range f.ruleLists {
- dr := rl.dnsResult(ri.RemoteIP, string(devName), host, rrType, answer)
- if dr == nil {
- continue
- }
-
- // Collect only the custom $dnsrewrite rules. It is much easier to
- // process dnsrewrite rules from a single list, since then there is
- // no problem with merging them among different lists.
- if !answer && rl.id() == agd.FilterListIDCustom {
- dnsRewriteResult := processDNSRewrites(ri.Messages, msg, dr.DNSRewrites(), host)
- if dnsRewriteResult != nil {
- dnsRewriteResult.List = rl.id()
-
- return dnsRewriteResult
- }
- }
-
- networkRules = append(networkRules, dr.NetworkRules...)
- hostRules4 = append(hostRules4, dr.HostRulesV4...)
- hostRules6 = append(hostRules6, dr.HostRulesV6...)
- }
-
- mr := rules.NewMatchingResult(networkRules, nil)
- if nr := mr.GetBasicResult(); nr != nil {
- return f.ruleDataToResult(nr.FilterListID, nr.RuleText, nr.Whitelist)
- }
-
- return f.hostsRulesToResult(hostRules4, hostRules6, rrType)
-}
-
-// mustRuleListDataByURLFilterID returns the rule list data by its synthetic
-// integer ID in the urlfilter engine. It panics if id is not found.
-func (f *compFilter) mustRuleListDataByURLFilterID(id int) (fltID agd.FilterListID, subID string) {
- for _, rl := range f.ruleLists {
- if rl.urlFilterID == id {
- return rl.id(), rl.subID
- }
- }
-
- // Technically shouldn't happen, since id is supposed to be among the rule
- // list filters in the composite filter.
- panic(fmt.Errorf("filter: synthetic id %d not found", id))
-}
-
-// hostsRulesToResult converts /etc/hosts-style rules into a filtering action.
-func (f *compFilter) hostsRulesToResult(
- hostRules4 []*rules.HostRule,
- hostRules6 []*rules.HostRule,
- rrType dnsmsg.RRType,
-) (r Result) {
- if len(hostRules4) == 0 && len(hostRules6) == 0 {
- return nil
- }
-
- // Only use the first matched rule, since we currently don't care about the
- // IP addresses in the rule. If the request is neither an A one nor an AAAA
- // one, or if there are no matching rules of the requested type, then use
- // whatever rule isn't empty.
- //
- // See also AGDNS-591.
- var resHostRule *rules.HostRule
- if rrType == dns.TypeA && len(hostRules4) > 0 {
- resHostRule = hostRules4[0]
- } else if rrType == dns.TypeAAAA && len(hostRules6) > 0 {
- resHostRule = hostRules6[0]
- } else {
- if len(hostRules4) > 0 {
- resHostRule = hostRules4[0]
- } else {
- resHostRule = hostRules6[0]
- }
- }
-
- return f.ruleDataToResult(resHostRule.FilterListID, resHostRule.RuleText, false)
-}
-
-// ruleDataToResult converts a urlfilter rule data into a filtering result.
-func (f *compFilter) ruleDataToResult(
- urlFilterID int,
- ruleText string,
- allowlist bool,
-) (r Result) {
- // Use the urlFilterID crutch to find the actual IDs of the filtering rule
- // list and blocked service.
- fltID, subID := f.mustRuleListDataByURLFilterID(urlFilterID)
-
- var rule agd.FilterRuleText
- if fltID == agd.FilterListIDBlockedService {
- rule = agd.FilterRuleText(subID)
- } else {
- rule = agd.FilterRuleText(ruleText)
- }
-
- if allowlist {
- log.Debug("rule list %s: allowed by rule %s", fltID, rule)
-
- return &ResultAllowed{
- List: fltID,
- Rule: rule,
- }
- }
-
- log.Debug("rule list %s: blocked by rule %s", fltID, rule)
-
- return &ResultBlocked{
- List: fltID,
- Rule: rule,
- }
-}
-
-// Close implements the Filter interface for *compFilter. It closes all
-// underlying filters.
-func (f *compFilter) Close() (err error) {
- if f.isEmpty() {
- return nil
- }
-
- errs := make([]error, len(f.ruleLists))
- for i, rl := range f.ruleLists {
- err = rl.Close()
- if err != nil {
- errs[i] = fmt.Errorf("rule list at index %d: %w", i, err)
- }
- }
-
- err = errors.Join(errs...)
- if err != nil {
- return fmt.Errorf("closing filters: %w", err)
- }
-
- return nil
-}
diff --git a/internal/filter/compfilter_internal_test.go b/internal/filter/compfilter_internal_test.go
deleted file mode 100644
index d0ab69c..0000000
--- a/internal/filter/compfilter_internal_test.go
+++ /dev/null
@@ -1,284 +0,0 @@
-package filter
-
-import (
- "context"
- "net"
- "strings"
- "testing"
-
- "github.com/AdguardTeam/AdGuardDNS/internal/agd"
- "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
- "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
- "github.com/miekg/dns"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func TestCompFilter_FilterRequest_badrequest(t *testing.T) {
- const (
- fltListID1 agd.FilterListID = "fl1"
- fltListID2 agd.FilterListID = "fl2"
-
- blockRule = "||example.com^"
- )
-
- rl1, err := newRuleListFltFromStr(blockRule, fltListID1, "", 0, false)
- require.NoError(t, err)
-
- rl2, err := newRuleListFltFromStr("||example.com^$badfilter", fltListID2, "", 0, false)
- require.NoError(t, err)
-
- req := &dns.Msg{
- Question: []dns.Question{{
- Name: testReqFQDN,
- Qtype: dns.TypeA,
- Qclass: dns.ClassINET,
- }},
- }
-
- testCases := []struct {
- name string
- wantRes Result
- ruleLists []*ruleListFilter
- }{{
- name: "block",
- wantRes: &ResultBlocked{List: fltListID1, Rule: blockRule},
- ruleLists: []*ruleListFilter{rl1},
- }, {
- name: "badfilter_no_block",
- wantRes: nil,
- ruleLists: []*ruleListFilter{rl2},
- }, {
- name: "badfilter_removes_block",
- wantRes: nil,
- ruleLists: []*ruleListFilter{rl1, rl2},
- }}
-
- ri := &agd.RequestInfo{
- Messages: newConstructor(),
- Host: testReqHost,
- RemoteIP: testRemoteIP,
- QType: dns.TypeA,
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- f := &compFilter{
- ruleLists: tc.ruleLists,
- }
-
- ctx := context.Background()
-
- res, rerr := f.FilterRequest(ctx, req, ri)
- require.NoError(t, rerr)
-
- assert.Equal(t, tc.wantRes, res)
- })
- }
-}
-
-func TestCompFilter_FilterRequest_hostsRules(t *testing.T) {
- const (
- fltListID agd.FilterListID = "fl1"
-
- reqHost4 = "www.example.com"
- reqHost6 = "www.example.net"
-
- blockRule4 = "127.0.0.1 www.example.com"
- blockRule6 = "::1 www.example.net"
- )
-
- const rules = blockRule4 + "\n" + blockRule6
-
- rl, err := newRuleListFltFromStr(rules, fltListID, "", 0, false)
- require.NoError(t, err)
-
- f := &compFilter{
- ruleLists: []*ruleListFilter{rl},
- }
-
- testCases := []struct {
- wantRes Result
- name string
- reqHost string
- reqType dnsmsg.RRType
- }{{
- wantRes: &ResultBlocked{List: fltListID, Rule: blockRule4},
- name: "a",
- reqHost: reqHost4,
- reqType: dns.TypeA,
- }, {
- wantRes: &ResultBlocked{List: fltListID, Rule: blockRule6},
- name: "aaaa",
- reqHost: reqHost6,
- reqType: dns.TypeAAAA,
- }, {
- wantRes: &ResultBlocked{List: fltListID, Rule: blockRule6},
- name: "a_with_ipv6_rule",
- reqHost: reqHost6,
- reqType: dns.TypeA,
- }, {
- wantRes: &ResultBlocked{List: fltListID, Rule: blockRule4},
- name: "aaaa_with_ipv4_rule",
- reqHost: reqHost4,
- reqType: dns.TypeAAAA,
- }, {
- wantRes: &ResultBlocked{List: fltListID, Rule: blockRule4},
- name: "mx",
- reqHost: reqHost4,
- reqType: dns.TypeMX,
- }}
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- ri := &agd.RequestInfo{
- Messages: newConstructor(),
- Host: tc.reqHost,
- RemoteIP: testRemoteIP,
- QType: tc.reqType,
- }
-
- req := &dns.Msg{
- Question: []dns.Question{{
- Name: dns.Fqdn(tc.reqHost),
- Qtype: tc.reqType,
- Qclass: dns.ClassINET,
- }},
- }
-
- ctx := context.Background()
-
- res, rerr := f.FilterRequest(ctx, req, ri)
- require.NoError(t, rerr)
-
- assert.Equal(t, tc.wantRes, res)
- assert.Equal(t, tc.wantRes, res)
- })
- }
-}
-
-func TestCompFilter_FilterRequest_dnsrewrite(t *testing.T) {
- const (
- fltListID1 agd.FilterListID = "fl1"
- fltListID2 agd.FilterListID = "fl2"
-
- fltListIDCustom = agd.FilterListIDCustom
-
- blockRule = "||example.com^"
- dnsRewriteRuleRefused = "||example.com^$dnsrewrite=REFUSED"
- dnsRewriteRuleCname = "||example.com^$dnsrewrite=cname"
- dnsRewriteRule1 = "||example.com^$dnsrewrite=1.2.3.4"
- dnsRewriteRule2 = "||example.com^$dnsrewrite=1.2.3.5"
- )
-
- rl1, err := newRuleListFltFromStr(blockRule, fltListID1, "", 0, false)
- require.NoError(t, err)
-
- rl2, err := newRuleListFltFromStr(dnsRewriteRuleRefused, fltListID2, "", 0, false)
- require.NoError(t, err)
-
- rlCustomRefused, err := newRuleListFltFromStr(
- dnsRewriteRuleRefused,
- fltListIDCustom,
- string(testProfID),
- 0,
- false,
- )
- require.NoError(t, err)
-
- rlCustomCname, err := newRuleListFltFromStr(
- dnsRewriteRuleCname,
- fltListIDCustom,
- "prof1235",
- 0,
- false,
- )
- require.NoError(t, err)
-
- rlCustom2, err := newRuleListFltFromStr(
- strings.Join([]string{dnsRewriteRule1, dnsRewriteRule2}, "\n"),
- fltListIDCustom,
- "prof1236",
- 0,
- false,
- )
- require.NoError(t, err)
-
- question := dns.Question{
- Name: testReqFQDN,
- Qtype: dns.TypeA,
- Qclass: dns.ClassINET,
- }
-
- req := &dns.Msg{
- Question: []dns.Question{question},
- }
-
- testCases := []struct {
- name string
- wantRes Result
- ruleLists []*ruleListFilter
- }{{
- name: "block",
- wantRes: &ResultBlocked{List: fltListID1, Rule: blockRule},
- ruleLists: []*ruleListFilter{rl1},
- }, {
- name: "dnsrewrite_no_effect",
- wantRes: &ResultBlocked{List: fltListID1, Rule: blockRule},
- ruleLists: []*ruleListFilter{rl1, rl2},
- }, {
- name: "dnsrewrite_block",
- wantRes: &ResultModified{
- Msg: dnsservertest.NewResp(dns.RcodeRefused, req, dnsservertest.RRSection{}),
- List: fltListIDCustom,
- Rule: dnsRewriteRuleRefused,
- },
- ruleLists: []*ruleListFilter{rl1, rl2, rlCustomRefused},
- }, {
- name: "dnsrewrite_cname",
- wantRes: &ResultModified{
- Msg: dnsservertest.NewResp(dns.RcodeSuccess, req, dnsservertest.RRSection{
- RRs: []dns.RR{
- dnsservertest.NewCNAME(testReqHost, 10, "cname"),
- },
- }),
- List: fltListIDCustom,
- Rule: dnsRewriteRuleCname,
- },
- ruleLists: []*ruleListFilter{rl1, rl2, rlCustomCname},
- }, {
- name: "dnsrewrite_answers",
- wantRes: &ResultModified{
- Msg: dnsservertest.NewResp(dns.RcodeSuccess, req, dnsservertest.RRSection{
- RRs: []dns.RR{
- dnsservertest.NewA(testReqHost, 10, net.IP{1, 2, 3, 4}),
- dnsservertest.NewA(testReqHost, 10, net.IP{1, 2, 3, 5}),
- },
- }),
- List: fltListIDCustom,
- Rule: "",
- },
- ruleLists: []*ruleListFilter{rl1, rl2, rlCustom2},
- }}
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- f := &compFilter{
- ruleLists: tc.ruleLists,
- }
-
- ctx := context.Background()
- ri := &agd.RequestInfo{
- Messages: newConstructor(),
- Host: testReqHost,
- RemoteIP: testRemoteIP,
- QType: dns.TypeA,
- }
-
- res, rerr := f.FilterRequest(ctx, req, ri)
- require.NoError(t, rerr)
-
- assert.Equal(t, tc.wantRes, res)
- })
- }
-}
diff --git a/internal/filter/custom_internal_test.go b/internal/filter/custom_internal_test.go
deleted file mode 100644
index 5a64acc..0000000
--- a/internal/filter/custom_internal_test.go
+++ /dev/null
@@ -1,70 +0,0 @@
-package filter
-
-import (
- "context"
- "testing"
- "time"
-
- "github.com/AdguardTeam/AdGuardDNS/internal/agd"
- "github.com/bluele/gcache"
-)
-
-// errorCollector is an agd.ErrorCollector for tests. This is a copy of the
-// code from package agdtest to evade an import cycle.
-type errorCollector struct {
- OnCollect func(ctx context.Context, err error)
-}
-
-// type check
-var _ agd.ErrorCollector = (*errorCollector)(nil)
-
-// Collect implements the agd.ErrorCollector interface for *errorCollector.
-func (c *errorCollector) Collect(ctx context.Context, err error) {
- c.OnCollect(ctx, err)
-}
-
-var ruleListsSink []*ruleListFilter
-
-func BenchmarkCustomFilters_ruleCache(b *testing.B) {
- f := &customFilters{
- cache: gcache.New(1).LRU().Build(),
- errColl: &errorCollector{
- OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
- },
- }
-
- p := &agd.Profile{
- ID: testProfID,
- UpdateTime: time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC),
- CustomRules: []agd.FilterRuleText{
- "||example.com",
- "||example.org",
- "||example.net",
- },
- }
-
- ctx := context.Background()
-
- b.Run("cache", func(b *testing.B) {
- rls := make([]*ruleListFilter, 0, 1)
-
- b.ReportAllocs()
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- ruleListsSink = f.appendRuleLists(ctx, rls, p)
- }
- })
-
- b.Run("no_cache", func(b *testing.B) {
- rls := make([]*ruleListFilter, 0, 1)
-
- b.ReportAllocs()
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- // Update the time on each iteration to make sure that the cache is
- // never used.
- p.UpdateTime.Add(1 * time.Millisecond)
- ruleListsSink = f.appendRuleLists(ctx, rls, p)
- }
- })
-}
diff --git a/internal/filter/dnsrewrite_internal_test.go b/internal/filter/dnsrewrite_internal_test.go
deleted file mode 100644
index 83aa3d3..0000000
--- a/internal/filter/dnsrewrite_internal_test.go
+++ /dev/null
@@ -1,137 +0,0 @@
-package filter
-
-import (
- "net"
- "testing"
-
- "github.com/AdguardTeam/golibs/testutil"
-
- "github.com/AdguardTeam/AdGuardDNS/internal/agd"
- "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
- "github.com/AdguardTeam/urlfilter/rules"
- "github.com/miekg/dns"
- "github.com/stretchr/testify/assert"
-)
-
-func Test_processDNSRewriteRules(t *testing.T) {
- cnameRule, _ := rules.NewNetworkRule("|cname^$dnsrewrite=new-cname", 1)
- aRecordRule, _ := rules.NewNetworkRule("|a-record^$dnsrewrite=127.0.0.1", 1)
- refusedRule, _ := rules.NewNetworkRule("|refused^$dnsrewrite=REFUSED", 1)
-
- testCases := []struct {
- name string
- want *DNSRewriteResult
- dnsr []*rules.NetworkRule
- }{{
- name: "empty",
- want: &DNSRewriteResult{
- Response: DNSRewriteResultResponse{},
- },
- dnsr: []*rules.NetworkRule{},
- }, {
- name: "cname",
- want: &DNSRewriteResult{
- ResRuleText: agd.FilterRuleText(cnameRule.RuleText),
- CanonName: cnameRule.DNSRewrite.NewCNAME,
- },
- dnsr: []*rules.NetworkRule{
- cnameRule,
- aRecordRule,
- refusedRule,
- },
- }, {
- name: "refused",
- want: &DNSRewriteResult{
- ResRuleText: agd.FilterRuleText(refusedRule.RuleText),
- RCode: refusedRule.DNSRewrite.RCode,
- },
- dnsr: []*rules.NetworkRule{
- aRecordRule,
- refusedRule,
- },
- }, {
- name: "a_record",
- want: &DNSRewriteResult{
- Rules: []*rules.NetworkRule{aRecordRule},
- RCode: aRecordRule.DNSRewrite.RCode,
- Response: DNSRewriteResultResponse{
- aRecordRule.DNSRewrite.RRType: []rules.RRValue{aRecordRule.DNSRewrite.Value},
- },
- },
- dnsr: []*rules.NetworkRule{aRecordRule},
- }}
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- got := processDNSRewriteRules(tc.dnsr)
- assert.Equal(t, tc.want, got)
- })
- }
-}
-
-func Test_filterDNSRewrite(t *testing.T) {
- cnameRule, _ := rules.NewNetworkRule("|cname^$dnsrewrite=new-cname", 1)
- aRecordRule, _ := rules.NewNetworkRule("|a-record^$dnsrewrite=127.0.0.1", 1)
- refusedRule, _ := rules.NewNetworkRule("|refused^$dnsrewrite=REFUSED", 1)
-
- messages := newConstructor()
-
- req := dnsservertest.NewReq(testReqFQDN, dns.TypeA, dns.ClassINET)
-
- testCases := []struct {
- dnsrr *DNSRewriteResult
- want *dns.Msg
- name string
- wantErr string
- }{{
- dnsrr: &DNSRewriteResult{
- Response: DNSRewriteResultResponse{},
- },
- want: dnsservertest.NewResp(dns.RcodeSuccess, req, dnsservertest.RRSection{}),
- name: "empty",
- wantErr: "",
- }, {
- dnsrr: &DNSRewriteResult{
- Rules: []*rules.NetworkRule{cnameRule},
- CanonName: cnameRule.DNSRewrite.NewCNAME,
- },
- want: nil,
- name: "cname",
- wantErr: "no dns rewrite rule responses",
- }, {
- dnsrr: &DNSRewriteResult{
- Rules: []*rules.NetworkRule{refusedRule},
- RCode: refusedRule.DNSRewrite.RCode,
- },
- want: nil,
- name: "refused",
- wantErr: "non-success answer",
- }, {
- dnsrr: &DNSRewriteResult{
- Rules: []*rules.NetworkRule{aRecordRule},
- RCode: aRecordRule.DNSRewrite.RCode,
- Response: DNSRewriteResultResponse{
- aRecordRule.DNSRewrite.RRType: []rules.RRValue{aRecordRule.DNSRewrite.Value},
- },
- },
- want: dnsservertest.NewResp(
- aRecordRule.DNSRewrite.RCode,
- req,
- dnsservertest.RRSection{
- RRs: []dns.RR{
- dnsservertest.NewA(testReqHost, 10, net.IP{127, 0, 0, 1}),
- },
- },
- ),
- name: "a_record",
- wantErr: "",
- }}
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- resp, err := filterDNSRewrite(messages, req, tc.dnsrr)
- testutil.AssertErrorMsg(t, tc.wantErr, err)
- assert.Equal(t, tc.want, resp)
- })
- }
-}
diff --git a/internal/filter/filter.go b/internal/filter/filter.go
index e176111..91652b5 100644
--- a/internal/filter/filter.go
+++ b/internal/filter/filter.go
@@ -5,50 +5,14 @@ package filter
import (
"context"
- "time"
- "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/internal"
- "github.com/c2h5oh/datasize"
- "github.com/miekg/dns"
)
-// Common Constants, Functions, and Types
-
-// maxFilterSize is the maximum size of downloaded filters.
-//
-// TODO(ameshkov): Consider making configurable.
-const maxFilterSize = 256 * int64(datasize.MB)
-
-// defaultFilterRefreshTimeout is the default timeout to use when fetching
-// filter lists data.
-//
-// TODO(a.garipov): Consider making timeouts where they are used configurable.
-const defaultFilterRefreshTimeout = 180 * time.Second
-
-// defaultResolveTimeout is the default timeout for resolving hosts for safe
-// search and safe browsing filters.
-//
-// TODO(ameshkov): Consider making configurable.
-const defaultResolveTimeout = 1 * time.Second
-
// Interface is the DNS request and response filter interface.
-type Interface interface {
- // FilterRequest filters the DNS request for the provided client. All
- // parameters must be non-nil. req must have exactly one question. If a is
- // nil, the request doesn't match any of the rules.
- FilterRequest(ctx context.Context, req *dns.Msg, ri *agd.RequestInfo) (r Result, err error)
+type Interface = internal.Interface
- // FilterResponse filters the DNS response for the provided client. All
- // parameters must be non-nil. If a is nil, the response doesn't match any
- // of the rules.
- FilterResponse(ctx context.Context, resp *dns.Msg, ri *agd.RequestInfo) (r Result, err error)
-
- // Close closes the filter and frees resources associated with it.
- Close() (err error)
-}
-
-// Filtering Result Aliases
+// Filtering result aliases
// Result is a sum type of all possible filtering actions. See the following
// types as implementations:
@@ -69,3 +33,17 @@ type ResultBlocked = internal.ResultBlocked
// ResultModified means that this request or response was rewritten or modified
// by a rewrite rule within the given filter list.
type ResultModified = internal.ResultModified
+
+// Hash matching for safe-browsing and adult-content blocking
+
+// HashMatcher is the interface for a safe-browsing and adult-blocking hash
+// matcher, which is used to respond to a TXT query based on the domain name.
+type HashMatcher interface {
+ MatchByPrefix(ctx context.Context, host string) (hashes []string, matched bool, err error)
+}
+
+// Default safe-browsing host suffixes.
+const (
+ GeneralTXTSuffix = ".sb.dns.adguard.com"
+ AdultBlockingTXTSuffix = ".pc.dns.adguard.com"
+)
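The HashMatcher interface above decouples TXT-query handling from the concrete hash storage. A minimal test-double sketch, assuming only the interface as declared here; the production implementation is the hashprefix.Matcher introduced later in this patch:

```go
package sketch

import (
	"context"
	"strings"
)

// staticMatcher is a trivial HashMatcher for tests: it matches any host under
// a single TXT suffix and always returns the same hashes.
type staticMatcher struct {
	suffix string
	hashes []string
}

// MatchByPrefix implements the HashMatcher interface for *staticMatcher.
func (m *staticMatcher) MatchByPrefix(
	_ context.Context,
	host string,
) (hashes []string, matched bool, err error) {
	if !strings.HasSuffix(host, m.suffix) {
		return nil, false, nil
	}

	return m.hashes, true, nil
}
```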
diff --git a/internal/filter/filter_internal_test.go b/internal/filter/filter_internal_test.go
deleted file mode 100644
index 31e659e..0000000
--- a/internal/filter/filter_internal_test.go
+++ /dev/null
@@ -1,34 +0,0 @@
-package filter
-
-import (
- "net/netip"
- "time"
-
- "github.com/AdguardTeam/AdGuardDNS/internal/agd"
- "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
-)
-
-// Common test constants.
-
-// testTimeout is the timeout for tests.
-const testTimeout = 1 * time.Second
-
-// testProfID is the profile ID for tests.
-const testProfID agd.ProfileID = "prof1234"
-
-// testReqHost is the request host for tests.
-const testReqHost = "www.example.com"
-
-// testReqFQDN is the request FQDN for tests.
-const testReqFQDN = testReqHost + "."
-
-// testRemoteIP is the client IP for tests.
-var testRemoteIP = netip.MustParseAddr("1.2.3.4")
-
-// newConstructor returns a standard dnsmsg.Constructor for tests.
-//
-// TODO(a.garipov): Use [agdtest.NewConstructor] once the package is split and
-// import cycles are resolved.
-func newConstructor() (c *dnsmsg.Constructor) {
- return dnsmsg.NewConstructor(&dnsmsg.BlockingModeNullIP{}, 10*time.Second)
-}
diff --git a/internal/filter/filter_test.go b/internal/filter/filter_test.go
index 8c360c2..de8edb6 100644
--- a/internal/filter/filter_test.go
+++ b/internal/filter/filter_test.go
@@ -17,6 +17,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/require"
)
@@ -182,8 +183,8 @@ func prepareConf(t testing.TB) (c *filter.DefaultStorageConfig) {
FilterIndexURL: fltsURL,
GeneralSafeSearchRulesURL: ssURL,
YoutubeSafeSearchRulesURL: ssURL,
- SafeBrowsing: &filter.HashPrefix{},
- AdultBlocking: &filter.HashPrefix{},
+ SafeBrowsing: &hashprefix.Filter{},
+ AdultBlocking: &hashprefix.Filter{},
Now: time.Now,
ErrColl: nil,
Resolver: nil,
diff --git a/internal/filter/hashprefix.go b/internal/filter/hashprefix/filter.go
similarity index 62%
rename from internal/filter/hashprefix.go
rename to internal/filter/hashprefix/filter.go
index 3d0dd48..c1d8708 100644
--- a/internal/filter/hashprefix.go
+++ b/internal/filter/hashprefix/filter.go
@@ -1,4 +1,4 @@
-package filter
+package hashprefix
import (
"context"
@@ -8,9 +8,8 @@ import (
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
- "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
- "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/resultcache"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/golibs/log"
@@ -20,12 +19,10 @@ import (
"golang.org/x/net/publicsuffix"
)
-// Hash-prefix filter
-
-// HashPrefixConfig is the hash-prefix filter configuration structure.
-type HashPrefixConfig struct {
+// FilterConfig is the hash-prefix filter configuration structure.
+type FilterConfig struct {
// Hashes are the hostname hashes for this filter.
- Hashes *hashstorage.Storage
+ Hashes *Storage
// URL is the URL used to update the filter.
URL *url.URL
@@ -54,71 +51,67 @@ type HashPrefixConfig struct {
// CacheTTL is the time-to-live value used to cache the results of the
// filter.
//
- // TODO(a.garipov): Currently unused.
+ // TODO(a.garipov): Currently unused. See AGDNS-398.
CacheTTL time.Duration
// CacheSize is the size of the filter's result cache.
CacheSize int
}
-// HashPrefix is a filter that matches hosts by their hashes based on a
+// Filter is a filter that matches hosts by their hashes based on a
// hash-prefix table.
-type HashPrefix struct {
- hashes *hashstorage.Storage
- refr *refreshableFilter
- resCache *resultcache.Cache[*ResultModified]
+type Filter struct {
+ hashes *Storage
+ refr *internal.Refreshable
+ resCache *resultcache.Cache[*internal.ResultModified]
resolver agdnet.Resolver
errColl agd.ErrorCollector
+ id agd.FilterListID
repHost string
}
-// NewHashPrefix returns a new hash-prefix filter. c must not be nil.
-func NewHashPrefix(c *HashPrefixConfig) (f *HashPrefix, err error) {
- f = &HashPrefix{
- hashes: c.Hashes,
- refr: &refreshableFilter{
- http: agdhttp.NewClient(&agdhttp.ClientConfig{
- Timeout: defaultFilterRefreshTimeout,
- }),
- url: c.URL,
- id: c.ID,
- cachePath: c.CachePath,
- typ: "hash storage",
- staleness: c.Staleness,
- },
- resCache: resultcache.New[*ResultModified](c.CacheSize),
+// NewFilter returns a new hash-prefix filter. c must not be nil.
+func NewFilter(c *FilterConfig) (f *Filter, err error) {
+ id := c.ID
+ f = &Filter{
+ hashes: c.Hashes,
+ resCache: resultcache.New[*internal.ResultModified](c.CacheSize),
resolver: c.Resolver,
errColl: c.ErrColl,
+ id: id,
repHost: c.ReplacementHost,
}
- f.refr.resetRules = f.resetRules
+ f.refr = internal.NewRefreshable(
+ &agd.FilterList{
+ ID: id,
+ URL: c.URL,
+ RefreshIvl: c.Staleness,
+ },
+ c.CachePath,
+ )
err = f.refresh(context.Background(), true)
if err != nil {
+ // Don't wrap the error, because it's informative enough as is.
return nil, err
}
return f, nil
}
-// id returns the ID of the hash storage.
-func (f *HashPrefix) id() (fltID agd.FilterListID) {
- return f.refr.id
-}
-
// type check
-var _ qtHostFilter = (*HashPrefix)(nil)
+var _ internal.RequestFilter = (*Filter)(nil)
-// filterReq implements the qtHostFilter interface for *hashPrefixFilter. It
-// modifies the response if host matches f.
-func (f *HashPrefix) filterReq(
+// FilterRequest implements the [internal.RequestFilter] interface for
+// *Filter. It modifies the response if host matches f.
+func (f *Filter) FilterRequest(
ctx context.Context,
- ri *agd.RequestInfo,
req *dns.Msg,
-) (r Result, err error) {
- host, qt := ri.Host, ri.QType
- cacheKey := resultcache.DefaultKey(host, qt, false)
+ ri *agd.RequestInfo,
+) (r internal.Result, err error) {
+ host, qt, cl := ri.Host, ri.QType, ri.QClass
+ cacheKey := resultcache.DefaultKey(host, qt, cl, false)
rm, ok := f.resCache.Get(cacheKey)
f.updateCacheLookupsMetrics(ok)
if ok {
@@ -152,25 +145,25 @@ func (f *HashPrefix) filterReq(
return nil, nil
}
- ctx, cancel := context.WithTimeout(ctx, defaultResolveTimeout)
+ ctx, cancel := context.WithTimeout(ctx, internal.DefaultResolveTimeout)
defer cancel()
var result *dns.Msg
ips, err := f.resolver.LookupIP(ctx, fam, f.repHost)
if err != nil {
- agd.Collectf(ctx, f.errColl, "filter %s: resolving: %w", f.id(), err)
+ agd.Collectf(ctx, f.errColl, "filter %s: resolving: %w", f.id, err)
result = ri.Messages.NewMsgSERVFAIL(req)
} else {
result, err = ri.Messages.NewIPRespMsg(req, ips...)
if err != nil {
- return nil, fmt.Errorf("filter %s: creating modified result: %w", f.id(), err)
+ return nil, fmt.Errorf("filter %s: creating modified result: %w", f.id, err)
}
}
- rm = &ResultModified{
+ rm = &internal.ResultModified{
Msg: result,
- List: f.id(),
+ List: f.id,
Rule: agd.FilterRuleText(matched),
}
@@ -185,8 +178,8 @@ func (f *HashPrefix) filterReq(
}
// updateCacheSizeMetrics updates cache size metrics.
-func (f *HashPrefix) updateCacheSizeMetrics(size int) {
- switch id := f.id(); id {
+func (f *Filter) updateCacheSizeMetrics(size int) {
+ switch id := f.id; id {
case agd.FilterListIDSafeBrowsing:
metrics.HashPrefixFilterSafeBrowsingCacheSize.Set(float64(size))
case agd.FilterListIDAdultBlocking:
@@ -197,9 +190,9 @@ func (f *HashPrefix) updateCacheSizeMetrics(size int) {
}
// updateCacheLookupsMetrics updates cache lookups metrics.
-func (f *HashPrefix) updateCacheLookupsMetrics(hit bool) {
+func (f *Filter) updateCacheLookupsMetrics(hit bool) {
var hitsMetric, missesMetric prometheus.Counter
- switch id := f.id(); id {
+ switch id := f.id; id {
case agd.FilterListIDSafeBrowsing:
hitsMetric = metrics.HashPrefixFilterCacheSafeBrowsingHits
missesMetric = metrics.HashPrefixFilterCacheSafeBrowsingMisses
@@ -207,7 +200,7 @@ func (f *HashPrefix) updateCacheLookupsMetrics(hit bool) {
hitsMetric = metrics.HashPrefixFilterCacheAdultBlockingHits
missesMetric = metrics.HashPrefixFilterCacheAdultBlockingMisses
default:
- panic(fmt.Errorf("unsupported FilterListID %s", id))
+ panic(fmt.Errorf("unsupported filter list id %s", id))
}
if hit {
@@ -217,49 +210,37 @@ func (f *HashPrefix) updateCacheLookupsMetrics(hit bool) {
}
}
-// name implements the qtHostFilter interface for *hashPrefixFilter.
-func (f *HashPrefix) name() (n string) {
- if f == nil {
- return ""
- }
-
- return string(f.id())
-}
-
// type check
-var _ agd.Refresher = (*HashPrefix)(nil)
+var _ agd.Refresher = (*Filter)(nil)
-// Refresh implements the [agd.Refresher] interface for *hashPrefixFilter.
+// Refresh implements the [agd.Refresher] interface for *Filter.
-func (f *HashPrefix) Refresh(ctx context.Context) (err error) {
+func (f *Filter) Refresh(ctx context.Context) (err error) {
return f.refresh(ctx, false)
}
-// refresh reloads the hash filter data. If acceptStale is true, do not try to
-// load the list from its URL when there is already a file in the cache
-// directory, regardless of its staleness.
-func (f *HashPrefix) refresh(ctx context.Context, acceptStale bool) (err error) {
- return f.refr.refresh(ctx, acceptStale)
-}
-
-// resetRules resets the hosts in the index.
-func (f *HashPrefix) resetRules(text string) (err error) {
- n, err := f.hashes.Reset(text)
-
- // Report the filter update to prometheus.
- promLabels := prometheus.Labels{
- "filter": string(f.id()),
- }
-
- metrics.SetStatusGauge(metrics.FilterUpdatedStatus.With(promLabels), err)
-
+// refresh reloads and resets the hash-filter data. If acceptStale is true, do
+// not try to load the list from its URL when there is already a file in the
+// cache directory, regardless of its staleness.
+func (f *Filter) refresh(ctx context.Context, acceptStale bool) (err error) {
+ text, err := f.refr.Refresh(ctx, acceptStale)
if err != nil {
+ // Don't wrap the error, because it's informative enough as is.
return err
}
- metrics.FilterUpdatedTime.With(promLabels).SetToCurrentTime()
- metrics.FilterRulesTotal.With(promLabels).Set(float64(n))
+ n, err := f.hashes.Reset(text)
+ fltIDStr := string(f.id)
+ metrics.SetStatusGauge(metrics.FilterUpdatedStatus.WithLabelValues(fltIDStr), err)
+ if err != nil {
+ return fmt.Errorf("resetting: %w", err)
+ }
- log.Info("filter %s: reset %d hosts", f.id(), n)
+ f.resCache.Clear()
+
+ metrics.FilterUpdatedTime.WithLabelValues(fltIDStr).SetToCurrentTime()
+ metrics.FilterRulesTotal.WithLabelValues(fltIDStr).Set(float64(n))
+
+ log.Info("filter %s: reset %d hosts", f.id, n)
return nil
}
diff --git a/internal/filter/hashprefix/filter_test.go b/internal/filter/hashprefix/filter_test.go
new file mode 100644
index 0000000..386ced8
--- /dev/null
+++ b/internal/filter/hashprefix/filter_test.go
@@ -0,0 +1,338 @@
+package hashprefix_test
+
+import (
+ "context"
+ "net"
+ "net/http"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/filtertest"
+ "github.com/AdguardTeam/golibs/netutil"
+ "github.com/AdguardTeam/golibs/testutil"
+ "github.com/miekg/dns"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestFilter_FilterRequest(t *testing.T) {
+ cachePath, srvURL := filtertest.PrepareRefreshable(t, nil, testHost, http.StatusOK)
+
+ strg, err := hashprefix.NewStorage("")
+ require.NoError(t, err)
+
+ replIP := net.IP{1, 2, 3, 4}
+ f, err := hashprefix.NewFilter(&hashprefix.FilterConfig{
+ Hashes: strg,
+ URL: srvURL,
+ ErrColl: &agdtest.ErrorCollector{
+ OnCollect: func(_ context.Context, _ error) {
+ panic("not implemented")
+ },
+ },
+ Resolver: &agdtest.Resolver{
+ OnLookupIP: func(
+ _ context.Context,
+ _ netutil.AddrFamily,
+ _ string,
+ ) (ips []net.IP, err error) {
+ return []net.IP{replIP}, nil
+ },
+ },
+ ID: agd.FilterListIDAdultBlocking,
+ CachePath: cachePath,
+ ReplacementHost: "repl.example",
+ Staleness: 1 * time.Minute,
+ CacheTTL: 1 * time.Minute,
+ CacheSize: 1,
+ })
+ require.NoError(t, err)
+
+ messages := agdtest.NewConstructor()
+
+ testCases := []struct {
+ name string
+ host string
+ qType dnsmsg.RRType
+ wantResult bool
+ }{{
+ name: "not_a_or_aaaa",
+ host: testHost,
+ qType: dns.TypeTXT,
+ wantResult: false,
+ }, {
+ name: "success",
+ host: testHost,
+ qType: dns.TypeA,
+ wantResult: true,
+ }, {
+ name: "success_subdomain",
+ host: "a.b.c." + testHost,
+ qType: dns.TypeA,
+ wantResult: true,
+ }, {
+ name: "no_match",
+ host: testOtherHost,
+ qType: dns.TypeA,
+ wantResult: false,
+ }}
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := dnsservertest.NewReq(
+ dns.Fqdn(tc.host),
+ tc.qType,
+ dns.ClassINET,
+ )
+ ri := &agd.RequestInfo{
+ Messages: messages,
+ Host: tc.host,
+ QType: tc.qType,
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout)
+ t.Cleanup(cancel)
+
+ var r internal.Result
+ r, err = f.FilterRequest(ctx, req, ri)
+ require.NoError(t, err)
+
+ if tc.wantResult {
+ wantRes := newModifiedResult(t, req, messages, replIP)
+ assert.Equal(t, wantRes, r)
+ } else {
+ assert.Nil(t, r)
+ }
+ })
+ }
+
+ t.Run("cached_success", func(t *testing.T) {
+ req := dnsservertest.NewReq(
+ dns.Fqdn(testHost),
+ dns.TypeA,
+ dns.ClassINET,
+ )
+ ri := &agd.RequestInfo{
+ Messages: messages,
+ Host: testHost,
+ QType: dns.TypeA,
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout)
+ t.Cleanup(cancel)
+
+ var r internal.Result
+ r, err = f.FilterRequest(ctx, req, ri)
+ require.NoError(t, err)
+
+ wantRes := newModifiedResult(t, req, messages, replIP)
+ assert.Equal(t, wantRes, r)
+ })
+
+ t.Run("cached_no_match", func(t *testing.T) {
+ req := dnsservertest.NewReq(
+ dns.Fqdn(testOtherHost),
+ dns.TypeA,
+ dns.ClassINET,
+ )
+ ri := &agd.RequestInfo{
+ Messages: messages,
+ Host: testOtherHost,
+ QType: dns.TypeA,
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout)
+ t.Cleanup(cancel)
+
+ var r internal.Result
+ r, err = f.FilterRequest(ctx, req, ri)
+ require.NoError(t, err)
+
+ assert.Nil(t, r)
+ })
+}
+
+// newModifiedResult is a helper for creating modified results for tests.
+func newModifiedResult(
+ tb testing.TB,
+ req *dns.Msg,
+ messages *dnsmsg.Constructor,
+ replIP net.IP,
+) (r *internal.ResultModified) {
+ resp, err := messages.NewIPRespMsg(req, replIP)
+ require.NoError(tb, err)
+
+ return &internal.ResultModified{
+ Msg: resp,
+ List: testFltListID,
+ Rule: testHost,
+ }
+}
+
+func TestFilter_Refresh(t *testing.T) {
+ reqCh := make(chan struct{}, 1)
+ cachePath, srvURL := filtertest.PrepareRefreshable(t, reqCh, testHost, http.StatusOK)
+
+ strg, err := hashprefix.NewStorage("")
+ require.NoError(t, err)
+
+ f, err := hashprefix.NewFilter(&hashprefix.FilterConfig{
+ Hashes: strg,
+ URL: srvURL,
+ ErrColl: &agdtest.ErrorCollector{
+ OnCollect: func(_ context.Context, _ error) {
+ panic("not implemented")
+ },
+ },
+ Resolver: &agdtest.Resolver{
+ OnLookupIP: func(
+ _ context.Context,
+ _ netutil.AddrFamily,
+ _ string,
+ ) (ips []net.IP, err error) {
+ panic("not implemented")
+ },
+ },
+ ID: agd.FilterListIDAdultBlocking,
+ CachePath: cachePath,
+ ReplacementHost: "",
+ Staleness: 1 * time.Minute,
+ CacheTTL: 1 * time.Minute,
+ CacheSize: 1,
+ })
+ require.NoError(t, err)
+
+ ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout)
+ t.Cleanup(cancel)
+
+ err = f.Refresh(ctx)
+ assert.NoError(t, err)
+
+ testutil.RequireReceive(t, reqCh, filtertest.Timeout)
+}
+
+func TestFilter_FilterRequest_staleCache(t *testing.T) {
+ refrCh := make(chan struct{}, 1)
+ cachePath, srvURL := filtertest.PrepareRefreshable(t, refrCh, testHost, http.StatusOK)
+
+ // Put some initial data into the cache to avoid the first refresh.
+
+ cf, err := os.OpenFile(cachePath, os.O_WRONLY|os.O_APPEND, os.ModeAppend)
+ require.NoError(t, err)
+
+ _, err = cf.WriteString(testOtherHost)
+ require.NoError(t, err)
+ require.NoError(t, cf.Close())
+
+ // Create the filter.
+
+ strg, err := hashprefix.NewStorage("")
+ require.NoError(t, err)
+
+ replIP := net.IP{1, 2, 3, 4}
+ fconf := &hashprefix.FilterConfig{
+ Hashes: strg,
+ URL: srvURL,
+ ErrColl: &agdtest.ErrorCollector{
+ OnCollect: func(_ context.Context, _ error) {
+ panic("not implemented")
+ },
+ },
+ Resolver: &agdtest.Resolver{
+ OnLookupIP: func(
+ _ context.Context,
+ _ netutil.AddrFamily,
+ _ string,
+ ) (ips []net.IP, err error) {
+ return []net.IP{replIP}, nil
+ },
+ },
+ ID: agd.FilterListIDAdultBlocking,
+ CachePath: cachePath,
+ ReplacementHost: "repl.example",
+ Staleness: 1 * time.Minute,
+ CacheTTL: 1 * time.Minute,
+ CacheSize: 1,
+ }
+ f, err := hashprefix.NewFilter(fconf)
+ require.NoError(t, err)
+
+ messages := agdtest.NewConstructor()
+
+ // Test the following:
+ //
+ // 1. Check that the stale rules cache is used.
+ // 2. Refresh the stale rules cache.
+ // 3. Ensure the result cache is cleared.
+ // 4. Ensure the stale rules aren't used.
+
+ testHostReq := dnsservertest.NewReq(dns.Fqdn(testHost), dns.TypeA, dns.ClassINET)
+ testReqInfo := &agd.RequestInfo{Messages: messages, Host: testHost, QType: dns.TypeA}
+
+ testOtherHostReq := dnsservertest.NewReq(dns.Fqdn(testOtherHost), dns.TypeA, dns.ClassINET)
+ testOtherReqInfo := &agd.RequestInfo{Messages: messages, Host: testOtherHost, QType: dns.TypeA}
+
+ t.Run("hit_cached_host", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout)
+ t.Cleanup(cancel)
+
+ var r internal.Result
+ r, err = f.FilterRequest(ctx, testOtherHostReq, testOtherReqInfo)
+ require.NoError(t, err)
+
+ var resp *dns.Msg
+ resp, err = messages.NewIPRespMsg(testOtherHostReq, replIP)
+ require.NoError(t, err)
+
+ assert.Equal(t, &internal.ResultModified{
+ Msg: resp,
+ List: testFltListID,
+ Rule: testOtherHost,
+ }, r)
+ })
+
+ t.Run("refresh", func(t *testing.T) {
+ // Make the cache stale.
+ now := time.Now()
+ err = os.Chtimes(cachePath, now, now.Add(-2*fconf.Staleness))
+ require.NoError(t, err)
+
+ ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout)
+ t.Cleanup(cancel)
+
+ err = f.Refresh(ctx)
+ assert.NoError(t, err)
+
+ testutil.RequireReceive(t, refrCh, filtertest.Timeout)
+ })
+
+ t.Run("previously_cached", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout)
+ t.Cleanup(cancel)
+
+ var r internal.Result
+ r, err = f.FilterRequest(ctx, testOtherHostReq, testOtherReqInfo)
+ require.NoError(t, err)
+
+ assert.Nil(t, r)
+ })
+
+ t.Run("new_host", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout)
+ t.Cleanup(cancel)
+
+ var r internal.Result
+ r, err = f.FilterRequest(ctx, testHostReq, testReqInfo)
+ require.NoError(t, err)
+
+ wantRes := newModifiedResult(t, testHostReq, messages, replIP)
+ assert.Equal(t, wantRes, r)
+ })
+}
diff --git a/internal/filter/hashprefix/hashprefix.go b/internal/filter/hashprefix/hashprefix.go
new file mode 100644
index 0000000..34dfcd1
--- /dev/null
+++ b/internal/filter/hashprefix/hashprefix.go
@@ -0,0 +1,32 @@
+// Package hashprefix defines a storage of hashes of domain names used for
+// filtering and serving TXT records with domain-name hashes.
+package hashprefix
+
+import "crypto/sha256"
+
+// Hash and hash part length constants.
+const (
+ // PrefixLen is the length of the hash prefix of the filtered hostname.
+ PrefixLen = 2
+
+ // PrefixEncLen is the encoded length of the hash prefix. Two text
+ // bytes per one binary byte.
+ PrefixEncLen = PrefixLen * 2
+
+ // hashLen is the length of the whole hash of the checked hostname.
+ hashLen = sha256.Size
+
+ // suffixLen is the length of the hash suffix of the filtered hostname.
+ suffixLen = hashLen - PrefixLen
+
+ // hashEncLen is the encoded length of the hash. Two text bytes per one
+ // binary byte.
+ hashEncLen = hashLen * 2
+)
+
+// Prefix is the type of the SHA256 hash prefix used to match against the
+// domain-name database.
+type Prefix [PrefixLen]byte
+
+// suffix is the type of the rest of a SHA256 hash of the filtered domain names.
+type suffix [suffixLen]byte
diff --git a/internal/filter/hashprefix/hashprefix_test.go b/internal/filter/hashprefix/hashprefix_test.go
new file mode 100644
index 0000000..de2de4f
--- /dev/null
+++ b/internal/filter/hashprefix/hashprefix_test.go
@@ -0,0 +1,21 @@
+package hashprefix_test
+
+import (
+ "testing"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/golibs/testutil"
+)
+
+func TestMain(m *testing.M) {
+ testutil.DiscardLogOutput(m)
+}
+
+// testFltListID is the common filtering-list ID for tests.
+const testFltListID = agd.FilterListIDAdultBlocking
+
+// Common hostnames for tests.
+const (
+ testHost = "porn.example"
+ testOtherHost = "otherporn.example"
+)
diff --git a/internal/filter/hashprefix/matcher.go b/internal/filter/hashprefix/matcher.go
new file mode 100644
index 0000000..5d8f752
--- /dev/null
+++ b/internal/filter/hashprefix/matcher.go
@@ -0,0 +1,105 @@
+package hashprefix
+
+import (
+ "context"
+ "encoding/hex"
+ "fmt"
+ "strings"
+
+ "github.com/AdguardTeam/golibs/log"
+ "github.com/AdguardTeam/golibs/stringutil"
+)
+
+// Matcher is a hash-prefix matcher that uses the hash-prefix storages as the
+// source of its data.
+type Matcher struct {
+ // storages is a mapping of domain-name suffixes to the storage containing
+ // hashes for this domain.
+ storages map[string]*Storage
+}
+
+// NewMatcher returns a new hash-prefix matcher. storages is a mapping of
+// domain-name suffixes to the storage containing hashes for this domain.
+func NewMatcher(storages map[string]*Storage) (m *Matcher) {
+ return &Matcher{
+ storages: storages,
+ }
+}
+
+// MatchByPrefix implements the [filter.HashMatcher] interface for *Matcher. It
+// returns the matched hashes if the host matched one of the domain names in m's
+// storages.
+//
+// TODO(a.garipov): Use the context for logging etc.
+func (m *Matcher) MatchByPrefix(
+ _ context.Context,
+ host string,
+) (hashes []string, matched bool, err error) {
+ var (
+ suffix string
+ prefixesStr string
+ strg *Storage
+ )
+
+ for suffix, strg = range m.storages {
+ if strings.HasSuffix(host, suffix) {
+ prefixesStr = host[:len(host)-len(suffix)]
+ matched = true
+
+ break
+ }
+ }
+
+ if !matched {
+ return nil, false, nil
+ }
+
+ log.Debug("hashprefix matcher: got prefixes string %q", prefixesStr)
+
+ hashPrefixes, err := prefixesFromStr(prefixesStr)
+ if err != nil {
+ return nil, false, err
+ }
+
+ return strg.Hashes(hashPrefixes), true, nil
+}
+
+// legacyPrefixEncLen is the encoded length of a legacy hash prefix.
+const legacyPrefixEncLen = 8
+
+// prefixesFromStr returns hash prefixes from a dot-separated string.
+func prefixesFromStr(prefixesStr string) (hashPrefixes []Prefix, err error) {
+ if prefixesStr == "" {
+ return nil, nil
+ }
+
+ prefixSet := stringutil.NewSet()
+ prefixStrs := strings.Split(prefixesStr, ".")
+ for _, s := range prefixStrs {
+ if len(s) != PrefixEncLen {
+ // Some legacy clients send eight-character hash prefixes instead of
+ // four-character ones. For now, remove the final four characters.
+ //
+ // TODO(a.garipov): Either remove this crutch or support such
+ // prefixes better.
+ if len(s) == legacyPrefixEncLen {
+ s = s[:PrefixEncLen]
+ } else {
+ return nil, fmt.Errorf("bad hash len for %q", s)
+ }
+ }
+
+ prefixSet.Add(s)
+ }
+
+ hashPrefixes = make([]Prefix, prefixSet.Len())
+ prefixStrs = prefixSet.Values()
+ for i, s := range prefixStrs {
+ _, err = hex.Decode(hashPrefixes[i][:], []byte(s))
+ if err != nil {
+ return nil, fmt.Errorf("bad hash encoding for %q", s)
+ }
+ }
+
+ return hashPrefixes, nil
+}
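+
+// For example, with illustrative inputs:
+//
+//	prefixes, err := prefixesFromStr("a1b2.c3d4") // two prefixes: a1b2, c3d4
+//	prefixes, err = prefixesFromStr("a1b2c3d4")   // legacy form, truncated to a1b2
+//	prefixes, err = prefixesFromStr("abc")        // error: bad hash len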
diff --git a/internal/filter/safebrowsing_test.go b/internal/filter/hashprefix/matcher_test.go
similarity index 63%
rename from internal/filter/safebrowsing_test.go
rename to internal/filter/hashprefix/matcher_test.go
index 55afa30..1c57288 100644
--- a/internal/filter/safebrowsing_test.go
+++ b/internal/filter/hashprefix/matcher_test.go
@@ -1,4 +1,4 @@
-package filter_test
+package hashprefix_test
import (
"context"
@@ -8,23 +8,29 @@ import (
"testing"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
- "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
-func TestSafeBrowsingServer(t *testing.T) {
- // Hashes
+// type check
+//
+// TODO(a.garipov): Move this into the actual package instead of keeping it in
+// the test package if [filter.Storage] and [filter.compFilter] are moved.
+var _ filter.HashMatcher = (*hashprefix.Matcher)(nil)
+func TestMatcher(t *testing.T) {
const (
realisticHostIdx = iota
samePrefixHost1Idx
samePrefixHost2Idx
)
+ const suffix = filter.GeneralTXTSuffix
+
hosts := []string{
// Data closer to real world.
- realisticHostIdx: safeBrowsingHost,
+ realisticHostIdx: "scam.example.net",
// Additional data that has the same prefixes.
samePrefixHost1Idx: "3z",
@@ -37,7 +43,7 @@ func TestSafeBrowsingServer(t *testing.T) {
hashStrs[i] = hex.EncodeToString(sum[:])
}
- hashes, err := hashstorage.New(strings.Join(hosts, "\n"))
+ hashes, err := hashprefix.NewStorage(strings.Join(hosts, "\n"))
require.NoError(t, err)
ctx := context.Background()
@@ -53,14 +59,14 @@ func TestSafeBrowsingServer(t *testing.T) {
wantMatched: false,
}, {
name: "realistic",
- host: hashStrs[realisticHostIdx][:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix,
+ host: hashStrs[realisticHostIdx][:hashprefix.PrefixEncLen] + suffix,
wantHashStrs: []string{
hashStrs[realisticHostIdx],
},
wantMatched: true,
}, {
name: "same_prefix",
- host: hashStrs[samePrefixHost1Idx][:hashstorage.PrefixEncLen] + filter.GeneralTXTSuffix,
+ host: hashStrs[samePrefixHost1Idx][:hashprefix.PrefixEncLen] + suffix,
wantHashStrs: []string{
hashStrs[samePrefixHost1Idx],
hashStrs[samePrefixHost2Idx],
@@ -70,11 +76,13 @@ func TestSafeBrowsingServer(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- srv := filter.NewSafeBrowsingServer(hashes, nil)
+ srv := hashprefix.NewMatcher(map[string]*hashprefix.Storage{
+ suffix: hashes,
+ })
var gotHashStrs []string
var matched bool
- gotHashStrs, matched, err = srv.Hashes(ctx, tc.host)
+ gotHashStrs, matched, err = srv.MatchByPrefix(ctx, tc.host)
require.NoError(t, err)
assert.Equal(t, tc.wantMatched, matched)
diff --git a/internal/filter/hashstorage/hashstorage.go b/internal/filter/hashprefix/storage.go
similarity index 77%
rename from internal/filter/hashstorage/hashstorage.go
rename to internal/filter/hashprefix/storage.go
index 29afda9..49a522f 100644
--- a/internal/filter/hashstorage/hashstorage.go
+++ b/internal/filter/hashprefix/storage.go
@@ -1,6 +1,4 @@
-// Package hashstorage defines a storage of hashes of domain names used for
-// filtering.
-package hashstorage
+package hashprefix
import (
"bufio"
@@ -11,47 +9,19 @@ import (
"sync"
)
-// Hash and hash part length constants.
-const (
- // PrefixLen is the length of the prefix of the hash of the filtered
- // hostname.
- PrefixLen = 2
-
- // PrefixEncLen is the encoded length of the hash prefix. Two text
- // bytes per one binary byte.
- PrefixEncLen = PrefixLen * 2
-
- // hashLen is the length of the whole hash of the checked hostname.
- hashLen = sha256.Size
-
- // suffixLen is the length of the suffix of the hash of the filtered
- // hostname.
- suffixLen = hashLen - PrefixLen
-
- // hashEncLen is the encoded length of the hash. Two text bytes per one
- // binary byte.
- hashEncLen = hashLen * 2
-)
-
-// Prefix is the type of the 2-byte prefix of a full 32-byte SHA256 hash of a
-// host being checked.
-type Prefix [PrefixLen]byte
-
-// suffix is the type of the 30-byte suffix of a full 32-byte SHA256 hash of a
-// host being checked.
-type suffix [suffixLen]byte
-
// Storage stores hashes of the filtered hostnames. All methods are safe for
// concurrent use.
+//
+// TODO(a.garipov): See if we could unexport this.
type Storage struct {
// mu protects hashSuffixes.
mu *sync.RWMutex
hashSuffixes map[Prefix][]suffix
}
-// New returns a new hash storage containing hashes of the domain names listed
-// in hostnames, one domain name per line.
-func New(hostnames string) (s *Storage, err error) {
+// NewStorage returns a new hash storage containing hashes of the domain names
+// listed in hostnames, one domain name per line.
+func NewStorage(hostnames string) (s *Storage, err error) {
s = &Storage{
mu: &sync.RWMutex{},
hashSuffixes: map[Prefix][]suffix{},
diff --git a/internal/filter/hashstorage/hashstorage_test.go b/internal/filter/hashprefix/storage_test.go
similarity index 71%
rename from internal/filter/hashstorage/hashstorage_test.go
rename to internal/filter/hashprefix/storage_test.go
index 944d304..98f6a94 100644
--- a/internal/filter/hashstorage/hashstorage_test.go
+++ b/internal/filter/hashprefix/storage_test.go
@@ -1,4 +1,4 @@
-package hashstorage_test
+package hashprefix_test
import (
"crypto/sha256"
@@ -8,61 +8,55 @@ import (
"strings"
"testing"
- "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
-// Common hostnames for tests.
-const (
- testHost = "porn.example"
- otherHost = "otherporn.example"
-)
-
func TestStorage_Hashes(t *testing.T) {
- s, err := hashstorage.New(testHost)
+ s, err := hashprefix.NewStorage(testHost)
require.NoError(t, err)
h := sha256.Sum256([]byte(testHost))
want := []string{hex.EncodeToString(h[:])}
- p := hashstorage.Prefix{h[0], h[1]}
- got := s.Hashes([]hashstorage.Prefix{p})
+ p := hashprefix.Prefix{h[0], h[1]}
+ got := s.Hashes([]hashprefix.Prefix{p})
assert.Equal(t, want, got)
- wrong := s.Hashes([]hashstorage.Prefix{{}})
+ wrong := s.Hashes([]hashprefix.Prefix{{}})
assert.Empty(t, wrong)
}
func TestStorage_Matches(t *testing.T) {
- s, err := hashstorage.New(testHost)
+ s, err := hashprefix.NewStorage(testHost)
require.NoError(t, err)
got := s.Matches(testHost)
assert.True(t, got)
- got = s.Matches(otherHost)
+ got = s.Matches(testOtherHost)
assert.False(t, got)
}
func TestStorage_Reset(t *testing.T) {
- s, err := hashstorage.New(testHost)
+ s, err := hashprefix.NewStorage(testHost)
require.NoError(t, err)
- n, err := s.Reset(otherHost)
+ n, err := s.Reset(testOtherHost)
require.NoError(t, err)
assert.Equal(t, 1, n)
- h := sha256.Sum256([]byte(otherHost))
+ h := sha256.Sum256([]byte(testOtherHost))
want := []string{hex.EncodeToString(h[:])}
- p := hashstorage.Prefix{h[0], h[1]}
- got := s.Hashes([]hashstorage.Prefix{p})
+ p := hashprefix.Prefix{h[0], h[1]}
+ got := s.Hashes([]hashprefix.Prefix{p})
assert.Equal(t, want, got)
prevHash := sha256.Sum256([]byte(testHost))
- prev := s.Hashes([]hashstorage.Prefix{{prevHash[0], prevHash[1]}})
+ prev := s.Hashes([]hashprefix.Prefix{{prevHash[0], prevHash[1]}})
assert.Empty(t, prev)
}
@@ -80,12 +74,12 @@ func BenchmarkStorage_Hashes(b *testing.B) {
hosts = append(hosts, fmt.Sprintf("%d."+testHost, i))
}
- s, err := hashstorage.New(strings.Join(hosts, "\n"))
+ s, err := hashprefix.NewStorage(strings.Join(hosts, "\n"))
require.NoError(b, err)
- var hashPrefixes []hashstorage.Prefix
+ var hashPrefixes []hashprefix.Prefix
for i := 0; i < 4; i++ {
- hashPrefixes = append(hashPrefixes, hashstorage.Prefix{hosts[i][0], hosts[i][1]})
+ hashPrefixes = append(hashPrefixes, hashprefix.Prefix{hosts[i][0], hosts[i][1]})
}
for n := 1; n <= 4; n++ {
@@ -104,7 +98,7 @@ func BenchmarkStorage_Hashes(b *testing.B) {
//
// goos: linux
// goarch: amd64
- // pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage
+ // pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
// BenchmarkStorage_Hashes/1-16 29928834 41.76 ns/op 0 B/op 0 allocs/op
// BenchmarkStorage_Hashes/2-16 18693033 63.80 ns/op 0 B/op 0 allocs/op
@@ -121,7 +115,7 @@ func BenchmarkStorage_ResetHosts(b *testing.B) {
}
hostnames := strings.Join(hosts, "\n")
- s, err := hashstorage.New(hostnames)
+ s, err := hashprefix.NewStorage(hostnames)
require.NoError(b, err)
b.ReportAllocs()
@@ -136,7 +130,7 @@ func BenchmarkStorage_ResetHosts(b *testing.B) {
//
// goos: linux
// goarch: amd64
- // pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage
+ // pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
// BenchmarkStorage_ResetHosts-16 2212 469343 ns/op 36224 B/op 1002 allocs/op
}
diff --git a/internal/filter/hashprefix_test.go b/internal/filter/hashprefix_test.go
deleted file mode 100644
index fac4d93..0000000
--- a/internal/filter/hashprefix_test.go
+++ /dev/null
@@ -1,100 +0,0 @@
-package filter_test
-
-import (
- "context"
- "net"
- "os"
- "path/filepath"
- "testing"
- "time"
-
- "github.com/AdguardTeam/AdGuardDNS/internal/agd"
- "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
- "github.com/AdguardTeam/AdGuardDNS/internal/filter"
- "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
- "github.com/AdguardTeam/golibs/netutil"
- "github.com/AdguardTeam/golibs/testutil"
- "github.com/miekg/dns"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) {
- cacheDir := t.TempDir()
- cachePath := filepath.Join(cacheDir, string(agd.FilterListIDSafeBrowsing))
- err := os.WriteFile(cachePath, []byte(safeBrowsingHost+"\n"), 0o644)
- require.NoError(t, err)
-
- hashes, err := hashstorage.New("")
- require.NoError(t, err)
-
- errColl := &agdtest.ErrorCollector{
- OnCollect: func(_ context.Context, err error) {
- panic("not implemented")
- },
- }
-
- resolver := &agdtest.Resolver{
- OnLookupIP: func(
- _ context.Context,
- _ netutil.AddrFamily,
- _ string,
- ) (ips []net.IP, err error) {
- return []net.IP{safeBrowsingSafeIP4}, nil
- },
- }
-
- c := prepareConf(t)
-
- c.SafeBrowsing, err = filter.NewHashPrefix(&filter.HashPrefixConfig{
- Hashes: hashes,
- ErrColl: errColl,
- Resolver: resolver,
- ID: agd.FilterListIDSafeBrowsing,
- CachePath: cachePath,
- ReplacementHost: safeBrowsingSafeHost,
- Staleness: 1 * time.Hour,
- CacheTTL: 10 * time.Second,
- CacheSize: 100,
- })
- require.NoError(t, err)
-
- c.ErrColl = errColl
- c.Resolver = resolver
-
- s, err := filter.NewDefaultStorage(c)
- require.NoError(t, err)
-
- g := &agd.FilteringGroup{
- ID: "default",
- RuleListIDs: []agd.FilterListID{},
- ParentalEnabled: true,
- SafeBrowsingEnabled: true,
- }
-
- // Test
-
- req := &dns.Msg{
- Question: []dns.Question{{
- Name: safeBrowsingSubFQDN,
- Qtype: dns.TypeA,
- Qclass: dns.ClassINET,
- }},
- }
-
- ri := newReqInfo(g, nil, safeBrowsingSubHost, clientIP, dns.TypeA)
- ctx := agd.ContextWithRequestInfo(context.Background(), ri)
-
- f := s.FilterFromContext(ctx, ri)
- require.NotNil(t, f)
- testutil.CleanupAndRequireSuccess(t, f.Close)
-
- var r filter.Result
- r, err = f.FilterRequest(ctx, req, ri)
- require.NoError(t, err)
-
- rm := testutil.RequireTypeAssert[*filter.ResultModified](t, r)
-
- assert.Equal(t, rm.Rule, agd.FilterRuleText(safeBrowsingHost))
- assert.Equal(t, rm.List, agd.FilterListIDSafeBrowsing)
-}
diff --git a/internal/filter/internal/composite/composite.go b/internal/filter/internal/composite/composite.go
new file mode 100644
index 0000000..e480c6f
--- /dev/null
+++ b/internal/filter/internal/composite/composite.go
@@ -0,0 +1,369 @@
+// Package composite implements a composite filter based on several types of
+// filters and the logic of the filter application.
+package composite
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/rulelist"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/safesearch"
+ "github.com/AdguardTeam/golibs/log"
+ "github.com/AdguardTeam/urlfilter/rules"
+ "github.com/miekg/dns"
+)
+
+// Filter is a composite filter based on several types of safe-search and
+// rule-list filters.
+//
+// An empty composite filter is a filter that always returns a nil filtering
+// result.
+type Filter struct {
+ safeBrowsing *hashprefix.Filter
+ adultBlocking *hashprefix.Filter
+
+ genSafeSearch *safesearch.Filter
+ ytSafeSearch *safesearch.Filter
+
+ // custom is the custom rule-list filter of the profile, if any.
+ custom *rulelist.Immutable
+
+ // ruleLists are the enabled rule-list filters of the profile or filtering
+ // group.
+ ruleLists []*rulelist.Refreshable
+
+ // svcLists are the rule-list filters of the profile's enabled blocked
+ // services, if any.
+ svcLists []*rulelist.Immutable
+}
+
+// Config is the configuration structure for the composite filter.
+type Config struct {
+ // SafeBrowsing is the safe-browsing filter to apply, if any.
+ SafeBrowsing *hashprefix.Filter
+
+ // AdultBlocking is the adult-content filter to apply, if any.
+ AdultBlocking *hashprefix.Filter
+
+ // GeneralSafeSearch is the general safe-search filter to apply, if any.
+ GeneralSafeSearch *safesearch.Filter
+
+ // YouTubeSafeSearch is the YouTube safe-search filter to apply, if any.
+ YouTubeSafeSearch *safesearch.Filter
+
+ // Custom is the custom rule-list filter of the profile, if any.
+ Custom *rulelist.Immutable
+
+ // RuleLists are the enabled rule-list filters of the profile or filtering
+ // group.
+ RuleLists []*rulelist.Refreshable
+
+ // ServiceLists are the rule-list filters of the profile's enabled blocked
+ // services, if any.
+ ServiceLists []*rulelist.Immutable
+}
+
+// New returns a new composite filter. If c is nil or empty, New returns a
+// filter that always returns a nil filtering result.
+func New(c *Config) (f *Filter) {
+ if c == nil {
+ return &Filter{}
+ }
+
+ return &Filter{
+ safeBrowsing: c.SafeBrowsing,
+ adultBlocking: c.AdultBlocking,
+ genSafeSearch: c.GeneralSafeSearch,
+ ytSafeSearch: c.YouTubeSafeSearch,
+ custom: c.Custom,
+ ruleLists: c.RuleLists,
+ svcLists: c.ServiceLists,
+ }
+}
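+
+// A minimal usage sketch; ruleLists and custom are assumed to have been built
+// elsewhere, for example by the filter storage:
+//
+//	f := New(&Config{
+//		Custom:    custom,
+//		RuleLists: ruleLists,
+//	})
+//	res, err := f.FilterRequest(ctx, req, ri)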
+
+// type check
+var _ internal.Interface = (*Filter)(nil)
+
+// FilterRequest implements the [internal.Interface] interface for *Filter. If
+// there is a safe-search result, it returns it. Otherwise, it returns the
+// action created from the filter list network rule with the highest priority.
+// If f is empty, it returns nil with no error.
+func (f *Filter) FilterRequest(
+ ctx context.Context,
+ req *dns.Msg,
+ ri *agd.RequestInfo,
+) (r internal.Result, err error) {
+ if f.isEmpty() {
+ return nil, nil
+ }
+
+ // Prepare common data for filters.
+ reqID := ri.ID
+ log.Debug("filters: filtering req %s: %d rule lists", reqID, len(f.ruleLists))
+
+ // Firstly, check the profile's rule-list filtering, the custom rules, and
+ // the rules from blocked services settings.
+ host := ri.Host
+ rlRes := f.filterWithRuleLists(ri, host, ri.QType, req, false)
+ switch flRes := rlRes.(type) {
+ case *internal.ResultAllowed:
+ // Skip any additional filtering if the domain is explicitly allowed by
+ // the user's custom rule.
+ if flRes.List == agd.FilterListIDCustom {
+ return flRes, nil
+ }
+ case *internal.ResultBlocked:
+ // Skip any additional filtering if the domain is already blocked.
+ return flRes, nil
+ default:
+ // Go on.
+ }
+
+ // Secondly, apply the safe browsing and safe search request filters in the
+ // following order.
+ //
+ // DO NOT change the order of reqFilters without necessity.
+ reqFilters := []struct {
+ filter internal.RequestFilter
+ id agd.FilterListID
+ }{{
+ filter: nullify(f.safeBrowsing),
+ id: agd.FilterListIDSafeBrowsing,
+ }, {
+ filter: nullify(f.adultBlocking),
+ id: agd.FilterListIDAdultBlocking,
+ }, {
+ filter: nullify(f.genSafeSearch),
+ id: agd.FilterListIDGeneralSafeSearch,
+ }, {
+ filter: nullify(f.ytSafeSearch),
+ id: agd.FilterListIDYoutubeSafeSearch,
+ }}
+
+ for _, rf := range reqFilters {
+ if rf.filter == nil {
+ continue
+ }
+
+ log.Debug("filter %s: filtering req %s", rf.id, reqID)
+ r, err = rf.filter.FilterRequest(ctx, req, ri)
+ log.Debug("filter %s: finished filtering req %s, errors: %v", rf.id, reqID, err)
+ if err != nil {
+ return nil, err
+ } else if r != nil {
+ return r, nil
+ }
+ }
+
+ // Thirdly, return the previously obtained filter list result.
+ return rlRes, nil
+}
+
+// nullify returns a nil interface value if flt is a nil pointer. Otherwise, it
+// returns flt converted to the interface type. It is used to avoid situations
+// where an interface value doesn't have any data but does have a type.
+func nullify[T *safesearch.Filter | *hashprefix.Filter](flt T) (fr internal.RequestFilter) {
+ if flt == nil {
+ return nil
+ }
+
+ return internal.RequestFilter(flt)
+}
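+
+// The distinction matters because of how nil pointers interact with interface
+// values, as in this illustrative snippet:
+//
+//	var sb *hashprefix.Filter                  // nil pointer
+//	var rf internal.RequestFilter = sb         // non-nil interface holding a nil pointer
+//	println(rf == nil)                         // false
+//	println(nullify(sb) == nil)                // true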
+
+// FilterResponse implements the [internal.Interface] interface for *Filter. It
+// returns the action created from the filter list network rule with the highest
+// priority. If f is empty, it returns nil with no error.
+func (f *Filter) FilterResponse(
+ ctx context.Context,
+ resp *dns.Msg,
+ ri *agd.RequestInfo,
+) (r internal.Result, err error) {
+ if f.isEmpty() {
+ return nil, nil
+ }
+
+ for _, ans := range resp.Answer {
+ host, rrType, ok := parseRespAnswer(ans)
+ if !ok {
+ continue
+ }
+
+ r = f.filterWithRuleLists(ri, host, rrType, resp, true)
+ if r != nil {
+ break
+ }
+ }
+
+ return r, nil
+}
+
+// parseRespAnswer parses hostname and rrType from the answer if there are any.
+// If ans is of a type that doesn't have an IP address or a hostname in it, ok
+// is false.
+func parseRespAnswer(ans dns.RR) (hostname string, rrType dnsmsg.RRType, ok bool) {
+ switch ans := ans.(type) {
+ case *dns.A:
+ return ans.A.String(), dns.TypeA, true
+ case *dns.AAAA:
+ return ans.AAAA.String(), dns.TypeAAAA, true
+ case *dns.CNAME:
+ return strings.TrimSuffix(ans.Target, "."), dns.TypeCNAME, true
+ default:
+ return "", dns.TypeNone, false
+ }
+}
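+
+// For instance, a [dns.A] answer for 1.2.3.4 yields ("1.2.3.4", [dns.TypeA],
+// true), a [dns.CNAME] answer for "tracker.example." yields
+// ("tracker.example", [dns.TypeCNAME], true), and an unsupported record type,
+// such as MX, yields ("", [dns.TypeNone], false).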
+
+// isEmpty returns true if this composite filter is an empty filter.
+func (f *Filter) isEmpty() (ok bool) {
+ return f == nil ||
+ (f.safeBrowsing == nil &&
+ f.adultBlocking == nil &&
+ f.genSafeSearch == nil &&
+ f.ytSafeSearch == nil &&
+ f.custom == nil &&
+ len(f.ruleLists) == 0 &&
+ len(f.svcLists) == 0)
+}
+
+// filterWithRuleLists filters one question's or answer's information through
+// all rule list filters of the composite filter.
+func (f *Filter) filterWithRuleLists(
+ ri *agd.RequestInfo,
+ host string,
+ rrType dnsmsg.RRType,
+ msg *dns.Msg,
+ isAnswer bool,
+) (r internal.Result) {
+ var devName string
+ if d := ri.Device; d != nil {
+ devName = string(d.Name)
+ }
+
+ ufRes := &urlFilterResult{}
+ for _, rl := range f.ruleLists {
+ ufRes.add(rl.DNSResult(ri.RemoteIP, devName, host, rrType, isAnswer))
+ }
+
+ if f.custom != nil {
+ dr := f.custom.DNSResult(ri.RemoteIP, devName, host, rrType, isAnswer)
+ // Collect only the custom $dnsrewrite rules. It's much easier to process
+ // $dnsrewrite rules from a single list, since then there is no problem
+ // with merging rewrites from different lists.
+ if !isAnswer {
+ modified := processDNSRewrites(ri.Messages, msg, dr.DNSRewrites(), host)
+ if modified != nil {
+ return modified
+ }
+ }
+
+ ufRes.add(dr)
+ }
+
+ for _, rl := range f.svcLists {
+ ufRes.add(rl.DNSResult(ri.RemoteIP, devName, host, rrType, isAnswer))
+ }
+
+ mr := rules.NewMatchingResult(ufRes.networkRules, nil)
+ if nr := mr.GetBasicResult(); nr != nil {
+ return f.ruleDataToResult(nr.FilterListID, nr.RuleText, nr.Whitelist)
+ }
+
+ return f.hostsRulesToResult(ufRes.hostRules4, ufRes.hostRules6, rrType)
+}
+
+// mustRuleListDataByURLFilterID returns the rule list data by its synthetic
+// integer ID in the urlfilter engine. It panics if id is not found.
+func (f *Filter) mustRuleListDataByURLFilterID(
+ id int,
+) (fltID agd.FilterListID, svcID agd.BlockedServiceID) {
+ for _, rl := range f.ruleLists {
+ if rl.URLFilterID() == id {
+ return rl.ID()
+ }
+ }
+
+ if rl := f.custom; rl != nil && rl.URLFilterID() == id {
+ return rl.ID()
+ }
+
+ for _, rl := range f.svcLists {
+ if rl.URLFilterID() == id {
+ return rl.ID()
+ }
+ }
+
+ // Technically shouldn't happen, since id is supposed to be among the rule
+ // list filters in the composite filter.
+ panic(fmt.Errorf("filter: synthetic id %d not found", id))
+}
+
+// hostsRulesToResult converts /etc/hosts-style rules into a filtering action.
+func (f *Filter) hostsRulesToResult(
+ hostRules4 []*rules.HostRule,
+ hostRules6 []*rules.HostRule,
+ rrType dnsmsg.RRType,
+) (r internal.Result) {
+ if len(hostRules4) == 0 && len(hostRules6) == 0 {
+ return nil
+ }
+
+ // Only use the first matched rule, since we currently don't care about the
+ // IP addresses in the rule. If the request is neither an A one nor an AAAA
+ // one, or if there are no matching rules of the requested type, then use
+ // whatever rule isn't empty.
+ //
+ // See also AGDNS-591.
+ var resHostRule *rules.HostRule
+ if rrType == dns.TypeA && len(hostRules4) > 0 {
+ resHostRule = hostRules4[0]
+ } else if rrType == dns.TypeAAAA && len(hostRules6) > 0 {
+ resHostRule = hostRules6[0]
+ } else {
+ if len(hostRules4) > 0 {
+ resHostRule = hostRules4[0]
+ } else {
+ resHostRule = hostRules6[0]
+ }
+ }
+
+ return f.ruleDataToResult(resHostRule.FilterListID, resHostRule.RuleText, false)
+}
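+
+// For example, given the illustrative rules "127.0.0.1 a.example" and
+// "::1 b.example", an A query for "b.example" has no matching IPv4 rule, so
+// the IPv6 rule "::1 b.example" is used and the result is a blocked one.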
+
+// ruleDataToResult converts a urlfilter rule data into a filtering result.
+func (f *Filter) ruleDataToResult(
+ urlFilterID int,
+ ruleText string,
+ allowlist bool,
+) (r internal.Result) {
+ // Use the urlFilterID crutch to find the actual IDs of the filtering rule
+ // list and blocked service.
+ fltID, svcID := f.mustRuleListDataByURLFilterID(urlFilterID)
+
+ var rule agd.FilterRuleText
+ if fltID == agd.FilterListIDBlockedService {
+ rule = agd.FilterRuleText(svcID)
+ } else {
+ rule = agd.FilterRuleText(ruleText)
+ }
+
+ if allowlist {
+ log.Debug("rule list %s: allowed by rule %s", fltID, rule)
+
+ return &internal.ResultAllowed{
+ List: fltID,
+ Rule: rule,
+ }
+ }
+
+ log.Debug("rule list %s: blocked by rule %s", fltID, rule)
+
+ return &internal.ResultBlocked{
+ List: fltID,
+ Rule: rule,
+ }
+}
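+
+// For instance, a blocking rule that came from a blocked-service rule list
+// with the service ID "test_service" used in the tests produces a result
+// whose Rule is "test_service" rather than the raw rule text.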
diff --git a/internal/filter/internal/composite/composite_test.go b/internal/filter/internal/composite/composite_test.go
new file mode 100644
index 0000000..8e517d4
--- /dev/null
+++ b/internal/filter/internal/composite/composite_test.go
@@ -0,0 +1,521 @@
+package composite_test
+
+import (
+ "context"
+ "net"
+ "net/http"
+ "net/netip"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/composite"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/filtertest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/rulelist"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/safesearch"
+ "github.com/AdguardTeam/golibs/netutil"
+ "github.com/AdguardTeam/golibs/testutil"
+ "github.com/miekg/dns"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestMain(m *testing.M) {
+ testutil.DiscardLogOutput(m)
+}
+
+// Common filter list IDs for tests.
+const (
+ testFltListID1 agd.FilterListID = "fl1"
+ testFltListID2 agd.FilterListID = "fl2"
+)
+
+// newFromStr is a helper to create a rule-list filter from a rule text and a
+// filtering-list ID.
+func newFromStr(tb testing.TB, text string, id agd.FilterListID) (rl *rulelist.Refreshable) {
+ tb.Helper()
+
+ rl, err := rulelist.NewFromString(text, id, "", 0, false)
+ require.NoError(tb, err)
+
+ return rl
+}
+
+// newImmutable is a helper to create an immutable rule-list filter from a rule
+// text and a filtering-list ID.
+func newImmutable(tb testing.TB, text string, id agd.FilterListID) (rl *rulelist.Immutable) {
+ tb.Helper()
+
+ rl, err := rulelist.NewImmutable(text, id, "", 0, false)
+ require.NoError(tb, err)
+
+ return rl
+}
+
+// newReqData returns data for calling FilterRequest. The context uses
+// [filtertest.Timeout] and [tb.Cleanup] is used for its cancellation. Both req
+// and ri use [filtertest.ReqFQDN], [dns.TypeA], and [dns.ClassINET] for the
+// request data.
+func newReqData(tb testing.TB) (ctx context.Context, req *dns.Msg, ri *agd.RequestInfo) {
+ ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout)
+ tb.Cleanup(cancel)
+
+ req = dnsservertest.NewReq(filtertest.ReqFQDN, dns.TypeA, dns.ClassINET)
+ ri = &agd.RequestInfo{
+ Messages: agdtest.NewConstructor(),
+ Host: filtertest.ReqHost,
+ QType: dns.TypeA,
+ QClass: dns.ClassINET,
+ }
+
+ return ctx, req, ri
+}
+
+func TestFilter_nil(t *testing.T) {
+ testCases := []struct {
+ flt *composite.Filter
+ name string
+ }{{
+ flt: nil,
+ name: "nil",
+ }, {
+ flt: composite.New(nil),
+ name: "config_nil",
+ }, {
+ flt: composite.New(&composite.Config{}),
+ name: "config_empty",
+ }}
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ ctx, req, ri := newReqData(t)
+ res, err := tc.flt.FilterRequest(ctx, req, ri)
+ assert.NoError(t, err)
+ assert.Nil(t, res)
+
+ resp := dnsservertest.NewResp(dns.RcodeSuccess, req)
+ res, err = tc.flt.FilterResponse(ctx, resp, ri)
+ assert.NoError(t, err)
+ assert.Nil(t, res)
+ })
+ }
+}
+
+func TestFilter_FilterRequest_badfilter(t *testing.T) {
+ const (
+ blockRule = filtertest.BlockRule
+ badFilterRule = filtertest.BlockRule + "$badfilter"
+ )
+
+ rl1 := newFromStr(t, blockRule, testFltListID1)
+ rl2 := newFromStr(t, badFilterRule, testFltListID2)
+
+ testCases := []struct {
+ name string
+ wantRes internal.Result
+ ruleLists []*rulelist.Refreshable
+ }{{
+ name: "block",
+ wantRes: &internal.ResultBlocked{
+ List: testFltListID1,
+ Rule: blockRule,
+ },
+ ruleLists: []*rulelist.Refreshable{rl1},
+ }, {
+ name: "badfilter_no_block",
+ wantRes: nil,
+ ruleLists: []*rulelist.Refreshable{rl2},
+ }, {
+ name: "badfilter_removes_block",
+ wantRes: nil,
+ ruleLists: []*rulelist.Refreshable{rl1, rl2},
+ }}
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ f := composite.New(&composite.Config{
+ RuleLists: tc.ruleLists,
+ })
+
+ ctx, req, ri := newReqData(t)
+ res, err := f.FilterRequest(ctx, req, ri)
+ require.NoError(t, err)
+
+ assert.Equal(t, tc.wantRes, res)
+ })
+ }
+}
+
+func TestFilter_FilterRequest_customAllow(t *testing.T) {
+ const allowRule = "@@" + filtertest.BlockRule
+
+ blockingRL := newFromStr(t, filtertest.BlockRule, testFltListID1)
+ customRL := newImmutable(t, allowRule, agd.FilterListIDCustom)
+
+ f := composite.New(&composite.Config{
+ Custom: customRL,
+ RuleLists: []*rulelist.Refreshable{blockingRL},
+ })
+
+ ctx, req, ri := newReqData(t)
+ res, err := f.FilterRequest(ctx, req, ri)
+ require.NoError(t, err)
+
+ want := &internal.ResultAllowed{
+ List: agd.FilterListIDCustom,
+ Rule: allowRule,
+ }
+ assert.Equal(t, want, res)
+}
+
+func TestFilter_FilterRequest_dnsrewrite(t *testing.T) {
+ const (
+ blockRule = filtertest.BlockRule
+ dnsRewriteRuleRefused = filtertest.BlockRule + "$dnsrewrite=REFUSED"
+ dnsRewriteRuleCname = filtertest.BlockRule + "$dnsrewrite=new-cname.example"
+ dnsRewrite2Rules = filtertest.BlockRule + "$dnsrewrite=1.2.3.4\n" +
+ filtertest.BlockRule + "$dnsrewrite=1.2.3.5"
+ dnsRewriteRuleTXT = filtertest.BlockRule + "$dnsrewrite=NOERROR;TXT;abcdefg"
+ )
+
+ var (
+ rlNonRewrite = newFromStr(t, blockRule, testFltListID1)
+ rlRewriteIgnored = newFromStr(t, dnsRewriteRuleRefused, testFltListID2)
+ rlCustomRefused = newImmutable(t, dnsRewriteRuleRefused, agd.FilterListIDCustom)
+ rlCustomCname = newImmutable(t, dnsRewriteRuleCname, agd.FilterListIDCustom)
+ rlCustom2Rules = newImmutable(t, dnsRewrite2Rules, agd.FilterListIDCustom)
+ )
+
+ req := dnsservertest.NewReq(filtertest.ReqFQDN, dns.TypeA, dns.ClassINET)
+
+ // Create a CNAME-modified request.
+ modifiedReq := dnsmsg.Clone(req)
+ modifiedReq.Question[0].Name = "new-cname.example."
+
+ testCases := []struct {
+ custom *rulelist.Immutable
+ wantRes internal.Result
+ name string
+ ruleLists []*rulelist.Refreshable
+ }{{
+ custom: nil,
+ wantRes: &internal.ResultBlocked{List: testFltListID1, Rule: blockRule},
+ name: "block",
+ ruleLists: []*rulelist.Refreshable{rlNonRewrite},
+ }, {
+ custom: nil,
+ wantRes: &internal.ResultBlocked{List: testFltListID1, Rule: blockRule},
+ name: "dnsrewrite_no_effect",
+ ruleLists: []*rulelist.Refreshable{rlNonRewrite, rlRewriteIgnored},
+ }, {
+ custom: rlCustomRefused,
+ wantRes: &internal.ResultModified{
+ Msg: dnsservertest.NewResp(dns.RcodeRefused, req),
+ List: agd.FilterListIDCustom,
+ Rule: dnsRewriteRuleRefused,
+ },
+ name: "dnsrewrite_block",
+ ruleLists: []*rulelist.Refreshable{rlNonRewrite, rlRewriteIgnored},
+ }, {
+ custom: rlCustomCname,
+ wantRes: &internal.ResultModified{
+ Msg: modifiedReq,
+ List: agd.FilterListIDCustom,
+ Rule: dnsRewriteRuleCname,
+ },
+ name: "dnsrewrite_cname",
+ ruleLists: []*rulelist.Refreshable{rlNonRewrite, rlRewriteIgnored},
+ }, {
+ custom: rlCustom2Rules,
+ wantRes: &internal.ResultModified{
+ Msg: dnsservertest.NewResp(dns.RcodeSuccess, req, dnsservertest.SectionAnswer{
+ dnsservertest.NewA(
+ filtertest.ReqFQDN,
+ agdtest.FilteredResponseTTLSec,
+ net.IP{1, 2, 3, 4},
+ ),
+ dnsservertest.NewA(
+ filtertest.ReqFQDN,
+ agdtest.FilteredResponseTTLSec,
+ net.IP{1, 2, 3, 5},
+ ),
+ }),
+ List: agd.FilterListIDCustom,
+ Rule: "",
+ },
+ name: "dnsrewrite_answers",
+ ruleLists: []*rulelist.Refreshable{rlNonRewrite, rlRewriteIgnored},
+ }}
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ f := composite.New(&composite.Config{
+ Custom: tc.custom,
+ RuleLists: tc.ruleLists,
+ })
+
+ ctx := context.Background()
+ ri := &agd.RequestInfo{
+ Messages: agdtest.NewConstructor(),
+ Host: filtertest.ReqHost,
+ QType: dns.TypeA,
+ }
+
+ res, fltErr := f.FilterRequest(ctx, req, ri)
+ require.NoError(t, fltErr)
+
+ assert.Equal(t, tc.wantRes, res)
+ })
+ }
+}
+
+func TestFilter_FilterRequest_hostsRules(t *testing.T) {
+ const (
+ reqHost4 = "www.example.com"
+ reqHost6 = "www.example.net"
+ )
+
+ const (
+ blockRule4 = "127.0.0.1 www.example.com"
+ blockRule6 = "::1 www.example.net"
+ rules = blockRule4 + "\n" + blockRule6
+ )
+
+ rl := newFromStr(t, rules, testFltListID1)
+ f := composite.New(&composite.Config{
+ RuleLists: []*rulelist.Refreshable{rl},
+ })
+
+ resBlocked4 := &internal.ResultBlocked{
+ List: testFltListID1,
+ Rule: blockRule4,
+ }
+
+ resBlocked6 := &internal.ResultBlocked{
+ List: testFltListID1,
+ Rule: blockRule6,
+ }
+
+ testCases := []struct {
+ wantRes internal.Result
+ name string
+ reqHost string
+ reqType dnsmsg.RRType
+ }{{
+ wantRes: resBlocked4,
+ name: "a",
+ reqHost: reqHost4,
+ reqType: dns.TypeA,
+ }, {
+ wantRes: resBlocked6,
+ name: "aaaa",
+ reqHost: reqHost6,
+ reqType: dns.TypeAAAA,
+ }, {
+ wantRes: resBlocked6,
+ name: "a_with_ipv6_rule",
+ reqHost: reqHost6,
+ reqType: dns.TypeA,
+ }, {
+ wantRes: resBlocked4,
+ name: "aaaa_with_ipv4_rule",
+ reqHost: reqHost4,
+ reqType: dns.TypeAAAA,
+ }, {
+ wantRes: resBlocked4,
+ name: "mx",
+ reqHost: reqHost4,
+ reqType: dns.TypeMX,
+ }}
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ ri := &agd.RequestInfo{
+ Messages: agdtest.NewConstructor(),
+ Host: tc.reqHost,
+ QType: tc.reqType,
+ }
+
+ req := &dns.Msg{
+ Question: []dns.Question{{
+ Name: dns.Fqdn(tc.reqHost),
+ Qtype: tc.reqType,
+ Qclass: dns.ClassINET,
+ }},
+ }
+
+ ctx := context.Background()
+
+ res, rerr := f.FilterRequest(ctx, req, ri)
+ require.NoError(t, rerr)
+
+ assert.Equal(t, tc.wantRes, res)
+ })
+ }
+}
+
+func TestFilter_FilterRequest_safeSearch(t *testing.T) {
+ const safeSearchIPStr = "1.2.3.4"
+
+ const rewriteRule = filtertest.BlockRule + "$dnsrewrite=NOERROR;A;" + safeSearchIPStr
+
+ var safeSearchIP net.IP = netip.MustParseAddr(safeSearchIPStr).AsSlice()
+ cachePath, srvURL := filtertest.PrepareRefreshable(t, nil, rewriteRule, http.StatusOK)
+
+ const fltListID = agd.FilterListIDGeneralSafeSearch
+
+ gen := safesearch.New(&safesearch.Config{
+ List: &agd.FilterList{
+ URL: srvURL,
+ ID: fltListID,
+ },
+ Resolver: &agdtest.Resolver{
+ OnLookupIP: func(
+ _ context.Context,
+ _ netutil.AddrFamily,
+ _ string,
+ ) (ips []net.IP, err error) {
+ return []net.IP{safeSearchIP}, nil
+ },
+ },
+ ErrColl: &agdtest.ErrorCollector{
+ OnCollect: func(_ context.Context, _ error) {
+ panic("not implemented")
+ },
+ },
+ CacheDir: filepath.Dir(cachePath),
+ CacheTTL: 1 * time.Minute,
+ CacheSize: 100,
+ })
+
+ ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout)
+ t.Cleanup(cancel)
+
+ err := gen.Refresh(ctx, false)
+ require.NoError(t, err)
+
+ f := composite.New(&composite.Config{
+ GeneralSafeSearch: gen,
+ })
+
+ ctx, req, ri := newReqData(t)
+ res, err := f.FilterRequest(ctx, req, ri)
+ require.NoError(t, err)
+
+ wantResp := dnsservertest.NewResp(dns.RcodeSuccess, req, dnsservertest.SectionAnswer{
+ dnsservertest.NewA(filtertest.ReqFQDN, agdtest.FilteredResponseTTLSec, safeSearchIP),
+ })
+ want := &internal.ResultModified{
+ Msg: wantResp,
+ List: fltListID,
+ Rule: filtertest.ReqHost,
+ }
+ assert.Equal(t, want, res)
+}
+
+func TestFilter_FilterRequest_services(t *testing.T) {
+ const svcID = "test_service"
+
+ svcRL, err := rulelist.NewImmutable(
+ filtertest.BlockRule,
+ agd.FilterListIDBlockedService,
+ svcID,
+ 0,
+ false,
+ )
+ require.NoError(t, err)
+
+ f := composite.New(&composite.Config{
+ ServiceLists: []*rulelist.Immutable{svcRL},
+ })
+
+ ctx, req, ri := newReqData(t)
+ res, err := f.FilterRequest(ctx, req, ri)
+ require.NoError(t, err)
+
+ want := &internal.ResultBlocked{
+ List: agd.FilterListIDBlockedService,
+ Rule: svcID,
+ }
+ assert.Equal(t, want, res)
+}
+
+func TestFilter_FilterResponse(t *testing.T) {
+ const cnameReqFQDN = "sub." + filtertest.ReqFQDN
+
+ const (
+ blockedCNAME = filtertest.ReqHost
+ blockedIPv4Str = "1.2.3.4"
+ blockedIPv6Str = "1234::cdef"
+ blockRules = blockedCNAME + "\n" + blockedIPv4Str + "\n" + blockedIPv6Str + "\n"
+ )
+
+ var (
+ blockedIPv4 net.IP = netip.MustParseAddr(blockedIPv4Str).AsSlice()
+ blockedIPv6 net.IP = netip.MustParseAddr(blockedIPv6Str).AsSlice()
+ )
+
+ blockingRL := newFromStr(t, blockRules, testFltListID1)
+ f := composite.New(&composite.Config{
+ RuleLists: []*rulelist.Refreshable{blockingRL},
+ })
+
+ const ttl = agdtest.FilteredResponseTTLSec
+
+ testCases := []struct {
+ name string
+ reqFQDN string
+ wantRule agd.FilterRuleText
+ respAns dnsservertest.SectionAnswer
+ qType dnsmsg.RRType
+ }{{
+ name: "cname",
+ reqFQDN: cnameReqFQDN,
+ wantRule: filtertest.ReqHost,
+ respAns: dnsservertest.SectionAnswer{
+ dnsservertest.NewCNAME(cnameReqFQDN, ttl, filtertest.ReqFQDN),
+ dnsservertest.NewA(filtertest.ReqFQDN, ttl, net.IP{1, 2, 3, 4}),
+ },
+ qType: dns.TypeA,
+ }, {
+ name: "ipv4",
+ reqFQDN: filtertest.ReqFQDN,
+ wantRule: blockedIPv4Str,
+ respAns: dnsservertest.SectionAnswer{
+ dnsservertest.NewA(filtertest.ReqFQDN, ttl, blockedIPv4),
+ },
+ qType: dns.TypeA,
+ }, {
+ name: "ipv6",
+ reqFQDN: filtertest.ReqFQDN,
+ wantRule: blockedIPv6Str,
+ respAns: dnsservertest.SectionAnswer{
+ dnsservertest.NewAAAA(filtertest.ReqFQDN, ttl, blockedIPv6),
+ },
+ qType: dns.TypeAAAA,
+ }}
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ ctx, req, ri := newReqData(t)
+ req.Question[0].Name = tc.reqFQDN
+ req.Question[0].Qtype = tc.qType
+
+ resp := dnsservertest.NewResp(dns.RcodeSuccess, req, tc.respAns)
+ res, err := f.FilterResponse(ctx, resp, ri)
+ require.NoError(t, err)
+
+ want := &internal.ResultBlocked{
+ List: testFltListID1,
+ Rule: tc.wantRule,
+ }
+ assert.Equal(t, want, res)
+ })
+ }
+}
diff --git a/internal/filter/internal/composite/dnsresult.go b/internal/filter/internal/composite/dnsresult.go
new file mode 100644
index 0000000..7d4cd27
--- /dev/null
+++ b/internal/filter/internal/composite/dnsresult.go
@@ -0,0 +1,27 @@
+package composite
+
+import (
+ "github.com/AdguardTeam/urlfilter"
+ "github.com/AdguardTeam/urlfilter/rules"
+)
+
+// urlFilterResult is an entity simplifying the collection and compilation of
+// urlfilter results.
+//
+// TODO(a.garipov): Think of ways to move all urlfilter result processing to
+// ./internal/rulelist.
+type urlFilterResult struct {
+ networkRules []*rules.NetworkRule
+ hostRules4 []*rules.HostRule
+ hostRules6 []*rules.HostRule
+}
+
+// add appends the rules from dr to the slices within r. If dr is nil, add does
+// nothing.
+func (r *urlFilterResult) add(dr *urlfilter.DNSResult) {
+ if dr != nil {
+ r.networkRules = append(r.networkRules, dr.NetworkRules...)
+ r.hostRules4 = append(r.hostRules4, dr.HostRulesV4...)
+ r.hostRules6 = append(r.hostRules6, dr.HostRulesV6...)
+ }
+}
diff --git a/internal/filter/dnsrewrite.go b/internal/filter/internal/composite/dnsrewrite.go
similarity index 84%
rename from internal/filter/dnsrewrite.go
rename to internal/filter/internal/composite/dnsrewrite.go
index df5126c..8b30f45 100644
--- a/internal/filter/dnsrewrite.go
+++ b/internal/filter/internal/composite/dnsrewrite.go
@@ -1,11 +1,13 @@
-package filter
+package composite
import (
"fmt"
"net"
+ "strings"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules"
@@ -13,13 +15,13 @@ import (
)
// processDNSRewrites processes $dnsrewrite rules dnsr and creates a filtering
-// result, if necessary.
+// result, if necessary. res.List, if any, is set to [agd.FilterListIDCustom].
func processDNSRewrites(
messages *dnsmsg.Constructor,
req *dns.Msg,
dnsr []*rules.NetworkRule,
host string,
-) (res *ResultModified) {
+) (res *internal.ResultModified) {
if len(dnsr) == 0 {
return nil
}
@@ -27,16 +29,18 @@ func processDNSRewrites(
dnsRewriteResult := processDNSRewriteRules(dnsr)
if resCanonName := dnsRewriteResult.CanonName; resCanonName != "" {
- if resCanonName == host {
+ // Rewrite the question name to a matched CNAME.
+ if strings.EqualFold(resCanonName, host) {
// A rewrite of a host to itself.
return nil
}
- resp := messages.NewRespMsg(req)
- resp.Answer = append(resp.Answer, messages.NewAnswerCNAME(req, resCanonName))
+ req = dnsmsg.Clone(req)
+ req.Question[0].Name = dns.Fqdn(resCanonName)
- return &ResultModified{
- Msg: resp,
+ return &internal.ResultModified{
+ Msg: req,
+ List: agd.FilterListIDCustom,
Rule: dnsRewriteResult.ResRuleText,
}
}
@@ -45,8 +49,9 @@ func processDNSRewrites(
resp := messages.NewRespMsg(req)
resp.Rcode = dnsRewriteResult.RCode
- return &ResultModified{
+ return &internal.ResultModified{
Msg: resp,
+ List: agd.FilterListIDCustom,
Rule: dnsRewriteResult.ResRuleText,
}
}
@@ -56,36 +61,37 @@ func processDNSRewrites(
return nil
}
- return &ResultModified{
- Msg: resp,
+ return &internal.ResultModified{
+ Msg: resp,
+ List: agd.FilterListIDCustom,
}
}
-// DNSRewriteResult is the result of application of $dnsrewrite rules.
-type DNSRewriteResult struct {
- Response DNSRewriteResultResponse
+// dnsRewriteResult is the result of application of $dnsrewrite rules.
+type dnsRewriteResult struct {
+ Response dnsRewriteResultResponse
CanonName string
ResRuleText agd.FilterRuleText
Rules []*rules.NetworkRule
RCode rules.RCode
}
-// DNSRewriteResultResponse is the collection of DNS response records
+// dnsRewriteResultResponse is the collection of DNS response records
// the server returns.
-type DNSRewriteResultResponse map[rules.RRType][]rules.RRValue
+type dnsRewriteResultResponse map[rules.RRType][]rules.RRValue
// processDNSRewriteRules processes DNS rewrite rules in dnsr. The result will
// have either CanonName or RCode or Response set.
-func processDNSRewriteRules(dnsr []*rules.NetworkRule) (res *DNSRewriteResult) {
- dnsrr := &DNSRewriteResult{
- Response: DNSRewriteResultResponse{},
+func processDNSRewriteRules(dnsr []*rules.NetworkRule) (res *dnsRewriteResult) {
+ dnsrr := &dnsRewriteResult{
+ Response: dnsRewriteResultResponse{},
}
for _, rule := range dnsr {
dr := rule.DNSRewrite
if dr.NewCNAME != "" {
// NewCNAME rules have a higher priority than other rules.
- return &DNSRewriteResult{
+ return &dnsRewriteResult{
ResRuleText: agd.FilterRuleText(rule.RuleText),
CanonName: dr.NewCNAME,
}
@@ -99,7 +105,7 @@ func processDNSRewriteRules(dnsr []*rules.NetworkRule) (res *DNSRewriteResult) {
default:
// RcodeRefused and other such codes have higher priority. Return
// immediately.
- return &DNSRewriteResult{
+ return &dnsRewriteResult{
ResRuleText: agd.FilterRuleText(rule.RuleText),
RCode: dr.RCode,
}
@@ -114,7 +120,7 @@ func processDNSRewriteRules(dnsr []*rules.NetworkRule) (res *DNSRewriteResult) {
func filterDNSRewrite(
messages *dnsmsg.Constructor,
req *dns.Msg,
- dnsrr *DNSRewriteResult,
+ dnsrr *dnsRewriteResult,
) (resp *dns.Msg, err error) {
if dnsrr.RCode != dns.RcodeSuccess {
return nil, errors.Error("non-success answer")
diff --git a/internal/filter/custom.go b/internal/filter/internal/custom/custom.go
similarity index 61%
rename from internal/filter/custom.go
rename to internal/filter/internal/custom/custom.go
index ac4ce24..eccb4e9 100644
--- a/internal/filter/custom.go
+++ b/internal/filter/internal/custom/custom.go
@@ -1,4 +1,6 @@
-package filter
+// Package custom contains the caching storage of filters made from custom
+// filtering rules of profiles.
+package custom
import (
"context"
@@ -7,51 +9,50 @@ import (
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/rulelist"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
- "github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/bluele/gcache"
)
-// Custom Filters For Profiles
-
-// customFilters contains custom filters made from custom filtering rules of
-// profiles.
-type customFilters struct {
+// Filters contains custom filters made from custom filtering rules of profiles.
+type Filters struct {
cache gcache.Cache
errColl agd.ErrorCollector
}
-// appendRuleLists appends the custom rule list filter made from the profile's
-// custom rules list, if any, to rls.
-func (f *customFilters) appendRuleLists(
- ctx context.Context,
- rls []*ruleListFilter,
- p *agd.Profile,
-) (res []*ruleListFilter) {
+// New returns a new custom filter storage.
+func New(cache gcache.Cache, errColl agd.ErrorCollector) (f *Filters) {
+ return &Filters{
+ cache: cache,
+ errColl: errColl,
+ }
+}
+
+// Get returns the custom rule-list filter made from the profile's custom rules
+// list, if any.
+func (f *Filters) Get(ctx context.Context, p *agd.Profile) (rl *rulelist.Immutable) {
if len(p.CustomRules) == 0 {
// Technically, there could be an old filter left in the cache, but it
// will eventually be evicted, so don't do anything about it.
- return rls
+ return nil
}
- optlog.Debug2("%s: compiling custom filter for profile %s", strgLogPrefix, p.ID)
- defer optlog.Debug2("%s: finished compiling custom filter for profile %s", strgLogPrefix, p.ID)
-
// Report the custom filters cache lookup to prometheus so that we could
// keep track of whether the cache size is enough.
defer func() {
- if rls == nil {
+ if rl == nil {
metrics.FilterCustomCacheLookupsMisses.Inc()
} else {
metrics.FilterCustomCacheLookupsHits.Inc()
}
}()
- rl := f.get(p)
+ rl = f.get(p)
if rl != nil {
- return append(rls, rl)
+ return rl
}
// TODO(a.garipov): Consider making a copy of strings.Join for
@@ -68,8 +69,17 @@ func (f *customFilters) appendRuleLists(
stringutil.WriteToBuilder(b, string(r), "\n")
}
- // Don't use cache for users' custom filters.
- rl, err := newRuleListFltFromStr(b.String(), agd.FilterListIDCustom, string(p.ID), 0, false)
+ rl, err := rulelist.NewImmutable(
+ b.String(),
+ agd.FilterListIDCustom,
+ "",
+ // Don't use cache for users' custom filters, because resultcache
+ // doesn't take $client rules into account.
+ //
+ // TODO(a.garipov): Consider enabling caching if necessary.
+ 0,
+ false,
+ )
if err != nil {
// In a rare situation where the custom rules are so badly formed that
// we cannot even create a filtering engine, consider that there is no
@@ -80,14 +90,16 @@ func (f *customFilters) appendRuleLists(
return nil
}
+ log.Info("%s/%s: got %d rules", agd.FilterListIDCustom, p.ID, rl.RulesCount())
+
f.set(p, rl)
- return append(rls, rl)
+ return rl
}
-// get returns the cached custom rule list filter, if there is one and the
+// get returns the cached custom rule-list filter, if there is one and the
// profile hasn't changed since the filter was cached.
-func (f *customFilters) get(p *agd.Profile) (rl *ruleListFilter) {
+func (f *Filters) get(p *agd.Profile) (rl *rulelist.Immutable) {
itemVal, err := f.cache.Get(p.ID)
if errors.Is(err, gcache.KeyNotFoundError) {
return nil
@@ -104,8 +116,8 @@ func (f *customFilters) get(p *agd.Profile) (rl *ruleListFilter) {
return item.ruleList
}
-// set caches the custom rule list filter.
-func (f *customFilters) set(p *agd.Profile, rl *ruleListFilter) {
+// set caches the custom rule-list filter.
+func (f *Filters) set(p *agd.Profile, rl *rulelist.Immutable) {
item := &customFilterCacheItem{
updTime: p.UpdateTime,
ruleList: rl,
@@ -121,5 +133,5 @@ func (f *customFilters) set(p *agd.Profile, rl *ruleListFilter) {
// customFilterCacheItem is an item of the custom filter cache.
type customFilterCacheItem struct {
updTime time.Time
- ruleList *ruleListFilter
+ ruleList *rulelist.Immutable
}
diff --git a/internal/filter/internal/custom/custom_test.go b/internal/filter/internal/custom/custom_test.go
new file mode 100644
index 0000000..136a5c6
--- /dev/null
+++ b/internal/filter/internal/custom/custom_test.go
@@ -0,0 +1,102 @@
+package custom_test
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/custom"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/rulelist"
+ "github.com/AdguardTeam/golibs/testutil"
+ "github.com/bluele/gcache"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestMain(m *testing.M) {
+ testutil.DiscardLogOutput(m)
+}
+
+// testProfID is the profile ID for tests.
+const testProfID agd.ProfileID = "prof1234"
+
+func TestFilters_Get(t *testing.T) {
+ f := custom.New(
+ gcache.New(1).LRU().Build(),
+ &agdtest.ErrorCollector{
+ OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
+ },
+ )
+
+ p := &agd.Profile{
+ ID: testProfID,
+ UpdateTime: time.Now(),
+ CustomRules: []agd.FilterRuleText{
+ "||first.example",
+ },
+ }
+
+ ctx := context.Background()
+
+ rl := f.Get(ctx, p)
+ require.NotNil(t, rl)
+
+ // Recheck cached.
+ cachedRL := f.Get(ctx, p)
+ require.NotNil(t, cachedRL)
+
+ assert.Same(t, rl, cachedRL)
+}
+
+var ruleListSink *rulelist.Immutable
+
+func BenchmarkFilters_Get(b *testing.B) {
+ f := custom.New(
+ gcache.New(1).LRU().Build(),
+ &agdtest.ErrorCollector{
+ OnCollect: func(ctx context.Context, err error) { panic("not implemented") },
+ },
+ )
+
+ p := &agd.Profile{
+ ID: testProfID,
+ UpdateTime: time.Now(),
+ CustomRules: []agd.FilterRuleText{
+ "||first.example",
+ "||second.example",
+ "||third.example",
+ },
+ }
+
+ ctx := context.Background()
+
+ b.Run("cache", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ ruleListSink = f.Get(ctx, p)
+ }
+ })
+
+ b.Run("no_cache", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ // Update the time on each iteration to make sure that the cache is
+ // never used.
+ p.UpdateTime = p.UpdateTime.Add(1 * time.Millisecond)
+ ruleListSink = f.Get(ctx, p)
+ }
+ })
+
+ // Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
+ //
+ // goos: linux
+ // goarch: amd64
+ // pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/custom
+ // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
+ // BenchmarkFilters_Get/cache-16 7870251 233.4 ns/op 16 B/op 1 allocs/op
+ // BenchmarkFilters_Get/no_cache-16 53073 23490 ns/op 14610 B/op 93 allocs/op
+}
diff --git a/internal/filter/internal/filtertest/filtertest.go b/internal/filter/internal/filtertest/filtertest.go
new file mode 100644
index 0000000..61082cc
--- /dev/null
+++ b/internal/filter/internal/filtertest/filtertest.go
@@ -0,0 +1,80 @@
+// Package filtertest contains common constants and utilities for the internal
+// filtering packages.
+package filtertest
+
+import (
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "net/netip"
+ "net/url"
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
+ "github.com/AdguardTeam/golibs/httphdr"
+ "github.com/AdguardTeam/golibs/testutil"
+ "github.com/stretchr/testify/require"
+)
+
+// BlockRule is the common blocking rule for filtering tests that blocks
+// [ReqHost].
+const BlockRule = "|" + ReqHost + "^"
+
+// RemoteIP is the common client IP for filtering tests.
+var RemoteIP = netip.MustParseAddr("1.2.3.4")
+
+// ReqHost is the common request host for filtering tests.
+const ReqHost = "www.host.example"
+
+// ReqFQDN is the common request FQDN for filtering tests.
+const ReqFQDN = ReqHost + "."
+
+// ServerName is the common server name for filtering tests.
+const ServerName = "testServer/1.0"
+
+// Timeout is the common timeout for filtering tests.
+const Timeout = 1 * time.Second
+
+// PrepareRefreshable launches an HTTP server serving the given text and code
+// and creates a cache file. If code is zero, the server isn't started. If
+// reqCh is not nil, a signal is sent every time the server is called. The
+// server uses [ServerName] as the value of the Server header.
+func PrepareRefreshable(
+ tb testing.TB,
+ reqCh chan<- struct{},
+ text string,
+ code int,
+) (cachePath string, srvURL *url.URL) {
+ tb.Helper()
+
+ if code != 0 {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ pt := testutil.PanicT{}
+ if reqCh != nil {
+ testutil.RequireSend(pt, reqCh, struct{}{}, Timeout)
+ }
+
+ w.Header().Set(httphdr.Server, ServerName)
+
+ w.WriteHeader(code)
+
+ _, writeErr := io.WriteString(w, text)
+ require.NoError(pt, writeErr)
+ }))
+ tb.Cleanup(srv.Close)
+
+ var err error
+ srvURL, err = agdhttp.ParseHTTPURL(srv.URL)
+ require.NoError(tb, err)
+ }
+
+ cacheDir := tb.TempDir()
+ cacheFile, err := os.CreateTemp(cacheDir, filepath.Base(tb.Name()))
+ require.NoError(tb, err)
+ require.NoError(tb, cacheFile.Close())
+
+ return cacheFile.Name(), srvURL
+}
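+
+// A minimal usage sketch, where t is the calling test's *testing.T and the
+// rule text is illustrative:
+//
+//	cachePath, srvURL := PrepareRefreshable(t, nil, "||blocked.example^", http.StatusOK)
+//
+// srvURL can then serve as the URL of a refreshable filter and cachePath as
+// its cache file.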
diff --git a/internal/filter/internal/internal.go b/internal/filter/internal/internal.go
index 6565cd8..1db7315 100644
--- a/internal/filter/internal/internal.go
+++ b/internal/filter/internal/internal.go
@@ -3,3 +3,48 @@
//
// TODO(a.garipov): Move more code to subpackages, see AGDNS-824.
package internal
+
+import (
+ "context"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/c2h5oh/datasize"
+ "github.com/miekg/dns"
+)
+
+// Make sure that the signatures for FilterRequest match.
+var _ RequestFilter = (Interface)(nil)
+
+// Interface is the DNS request and response filter interface.
+type Interface interface {
+ // FilterRequest filters the DNS request for the provided client. All
+ // parameters must be non-nil. req must have exactly one question. If r is
+ // nil, the request doesn't match any of the rules.
+ FilterRequest(ctx context.Context, req *dns.Msg, ri *agd.RequestInfo) (r Result, err error)
+
+ // FilterResponse filters the DNS response for the provided client. All
+ // parameters must be non-nil. If r is nil, the response doesn't match any
+ // of the rules.
+ FilterResponse(ctx context.Context, resp *dns.Msg, ri *agd.RequestInfo) (r Result, err error)
+}
+
+// maxFilterSize is the maximum size of downloaded filters.
+const maxFilterSize = 256 * int64(datasize.MB)
+
+// DefaultFilterRefreshTimeout is the default timeout to use when fetching
+// filter lists data.
+//
+// TODO(a.garipov): Consider making timeouts where they are used configurable.
+const DefaultFilterRefreshTimeout = 3 * time.Minute
+
+// DefaultResolveTimeout is the default timeout for resolving hosts for
+// safe-search and safe-browsing filters.
+//
+// TODO(ameshkov): Consider making configurable.
+const DefaultResolveTimeout = 1 * time.Second
+
+// RequestFilter can filter a request based on the request info.
+type RequestFilter interface {
+ FilterRequest(ctx context.Context, req *dns.Msg, ri *agd.RequestInfo) (r Result, err error)
+}
diff --git a/internal/filter/internal/internal_test.go b/internal/filter/internal/internal_test.go
new file mode 100644
index 0000000..c76bb5b
--- /dev/null
+++ b/internal/filter/internal/internal_test.go
@@ -0,0 +1,11 @@
+package internal_test
+
+import (
+ "testing"
+
+ "github.com/AdguardTeam/golibs/testutil"
+)
+
+func TestMain(m *testing.M) {
+ testutil.DiscardLogOutput(m)
+}
diff --git a/internal/filter/refrfilter.go b/internal/filter/internal/refreshable.go
similarity index 58%
rename from internal/filter/refrfilter.go
rename to internal/filter/internal/refreshable.go
index f776266..57bc5f7 100644
--- a/internal/filter/refrfilter.go
+++ b/internal/filter/internal/refreshable.go
@@ -1,4 +1,4 @@
-package filter
+package internal
import (
"context"
@@ -16,88 +16,79 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/agdio"
"github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
"github.com/google/renameio"
)
-// Refreshable Filter
-
-// refreshableFilter contains entities common to filters that can refresh
-// themselves from a file and a URL.
-type refreshableFilter struct {
+// Refreshable contains entities common to filters that can refresh themselves
+// from a file and a URL.
+type Refreshable struct {
// http is the HTTP client used to refresh the filter.
http *agdhttp.Client
// url is the URL used to refresh the filter.
url *url.URL
- // resetRules is the function that compiles the rules and replaces the
- // filtering engine. It must not be nil and should return an informative
- // error.
- resetRules func(text string) (err error)
-
- // id is the identifier of the filter.
+ // id is the filter list ID, if any.
id agd.FilterListID
// cachePath is the path to the file containing the cached filter rules.
cachePath string
- // typ is the type of this filter used for logging and error reporting.
- typ string
-
// staleness is the time after which a file is considered stale.
staleness time.Duration
}
-// refresh reloads the filter data. If acceptStale is true, refresh doesn't try
+// NewRefreshable returns a new refreshable filter. All parameters must be
+// non-zero.
+func NewRefreshable(l *agd.FilterList, cachePath string) (f *Refreshable) {
+ return &Refreshable{
+ http: agdhttp.NewClient(&agdhttp.ClientConfig{
+ Timeout: DefaultFilterRefreshTimeout,
+ }),
+ url: l.URL,
+ id: l.ID,
+ cachePath: cachePath,
+ staleness: l.RefreshIvl,
+ }
+}
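+
+// A minimal usage sketch; the list ID, URL, refresh interval, and cache path
+// are illustrative:
+//
+//	f := NewRefreshable(&agd.FilterList{
+//		URL:        u,
+//		ID:         agd.FilterListID("example_list"),
+//		RefreshIvl: 1 * time.Hour,
+//	}, "/var/cache/filters/example_list.txt")
+//	text, err := f.Refresh(ctx, false)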
+
+// Refresh reloads the filter data. If acceptStale is true, refresh doesn't try
// to load the filter data from its URL when there is already a file in the
// cache directory, regardless of its staleness.
-func (f *refreshableFilter) refresh(
+func (f *Refreshable) Refresh(
ctx context.Context,
acceptStale bool,
-) (err error) {
- // TODO(a.garipov): Consider adding a helper for enriching errors with
- // context deadline data.
- start := time.Now()
- deadline, hasDeadline := ctx.Deadline()
+) (text string, err error) {
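+	// Use a single timestamp for both the staleness check and the mtime of
+	// the cache file, so that the two stay consistent.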
+ now := time.Now()
- defer func() {
- if err != nil && hasDeadline {
- err = fmt.Errorf("started refresh at %s, deadline at %s: %w", start, deadline, err)
- }
+ defer func() { err = errors.Annotate(err, "%s: %w", f.id) }()
- err = errors.Annotate(err, "%s %q: %w", f.typ, f.id)
- }()
-
- text, err := f.refreshFromFile(acceptStale)
+ text, err = f.refreshFromFile(acceptStale, now)
if err != nil {
- return fmt.Errorf("refreshing from file %q: %w", f.cachePath, err)
+ return "", fmt.Errorf("refreshing from file %q: %w", f.cachePath, err)
}
if text == "" {
- log.Info("filter %s: refreshing from url %q", f.id, f.url)
+ log.Info("%s: refreshing from url %q", f.id, f.url)
- text, err = f.refreshFromURL(ctx)
+ text, err = f.refreshFromURL(ctx, now)
if err != nil {
- return fmt.Errorf("refreshing from url %q: %w", f.url, err)
+ return "", fmt.Errorf("refreshing from url %q: %w", f.url, err)
}
}
- err = f.resetRules(text)
- if err != nil {
- // Don't wrap the error, because it's informative enough as is.
- return err
- }
-
- return nil
+ return text, nil
}
// refreshFromFile loads filter data from a file if the file's mtime shows that
-// it's still fresh. If acceptStale is true, and the cache file exists, the
-// data is read from there regardless of its staleness. If err is nil and text
-// is empty, a refresh from a URL is required.
-func (f *refreshableFilter) refreshFromFile(
+// it's still fresh relative to updTime. If acceptStale is true, and the cache
+// file exists, the data is read from there regardless of its staleness. If err
+// is nil and text is empty, a refresh from a URL is required.
+func (f *Refreshable) refreshFromFile(
acceptStale bool,
+ updTime time.Time,
) (text string, err error) {
// #nosec G304 -- Assume that cachePath is always cacheDir + a valid,
// no-slash filter list ID.
@@ -117,7 +108,7 @@ func (f *refreshableFilter) refreshFromFile(
return "", fmt.Errorf("reading filter file stat: %w", err)
}
- if mtime := fi.ModTime(); !mtime.Add(f.staleness).After(time.Now()) {
+ if mtime := fi.ModTime(); !mtime.Add(f.staleness).After(updTime) {
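+			// The cached file is stale; return empty text to trigger a
+			// refresh from the URL.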
return "", nil
}
}
@@ -132,15 +123,19 @@ func (f *refreshableFilter) refreshFromFile(
}
// refreshFromURL loads the filter data from u, puts it into the file specified
-// by cachePath, and also returns its content.
-func (f *refreshableFilter) refreshFromURL(ctx context.Context) (text string, err error) {
+// by cachePath, returns its content, and also sets its atime and mtime to
+// updTime.
+func (f *Refreshable) refreshFromURL(
+ ctx context.Context,
+ updTime time.Time,
+) (text string, err error) {
// TODO(a.garipov): Cache these like renameio recommends.
tmpDir := renameio.TempDir(filepath.Dir(f.cachePath))
tmpFile, err := renameio.TempFile(tmpDir, f.cachePath)
if err != nil {
return "", fmt.Errorf("creating temporary filter file: %w", err)
}
- defer func() { err = withDeferredTmpCleanup(err, tmpFile) }()
+ defer func() { err = f.withDeferredTmpCleanup(err, tmpFile, updTime) }()
resp, err := f.http.Get(ctx, f.url)
if err != nil {
@@ -148,10 +143,17 @@ func (f *refreshableFilter) refreshFromURL(ctx context.Context) (text string, er
}
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
- srv := resp.Header.Get(agdhttp.HdrNameServer)
+ srv := resp.Header.Get(httphdr.Server)
cl := resp.ContentLength
- log.Info("loading from %q: got content-length %d, code %d, srv %q", f.url, cl, resp.StatusCode, srv)
+ log.Info(
+ "%s: loading from %q: got content-length %d, code %d, srv %q",
+ f.id,
+ f.url,
+ cl,
+ resp.StatusCode,
+ srv,
+ )
err = agdhttp.CheckStatus(resp, http.StatusOK)
if err != nil {
@@ -178,12 +180,21 @@ func (f *refreshableFilter) refreshFromURL(ctx context.Context) (text string, er
// withDeferredTmpCleanup is a helper that performs the necessary cleanups and
// finalizations of the temporary files based on the returned error.
-func withDeferredTmpCleanup(returned error, tmpFile *renameio.PendingFile) (err error) {
+func (f *Refreshable) withDeferredTmpCleanup(
+ returned error,
+ tmpFile *renameio.PendingFile,
+ updTime time.Time,
+) (err error) {
+ // Make sure that any error returned from here is marked as a deferred one.
if returned != nil {
return errors.WithDeferred(returned, tmpFile.Cleanup())
}
- // Make sure that the error returned from CloseAtomicallyReplace is marked
- // as a deferred one.
- return errors.WithDeferred(nil, tmpFile.CloseAtomicallyReplace())
+ err = tmpFile.CloseAtomicallyReplace()
+ if err != nil {
+ return errors.WithDeferred(nil, err)
+ }
+
+ // Set the modification and access times to the moment the refresh started.
+ return errors.WithDeferred(nil, os.Chtimes(f.cachePath, updTime, updTime))
}
diff --git a/internal/filter/internal/refreshable_test.go b/internal/filter/internal/refreshable_test.go
new file mode 100644
index 0000000..146098a
--- /dev/null
+++ b/internal/filter/internal/refreshable_test.go
@@ -0,0 +1,198 @@
+package internal_test
+
+import (
+ "context"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/filtertest"
+ "github.com/AdguardTeam/golibs/testutil"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// refrID is the ID of the [agd.FilterList] used for testing.
+const refrID = "test_id"
+
+func TestRefreshable_Refresh(t *testing.T) {
+ const (
+ defaultFileText = "||filefilter.example\n"
+ defaultURLText = "||urlfilter.example\n"
+ )
+
+ testCases := []struct {
+ name string
+ wantText string
+ wantErrMsg string
+ srvText string
+ staleness time.Duration
+ srvCode int
+ acceptStale bool
+ expectReq bool
+ useCacheFile bool
+ }{{
+ name: "no_file",
+ wantText: defaultURLText,
+ wantErrMsg: "",
+ srvText: defaultURLText,
+ staleness: 0,
+ srvCode: http.StatusOK,
+ acceptStale: true,
+ expectReq: true,
+ useCacheFile: false,
+ }, {
+ name: "no_file_http_empty",
+ wantText: "",
+ wantErrMsg: refrID + `: refreshing from url "URL": ` +
+ `server "` + filtertest.ServerName + `": empty text, not resetting`,
+ srvText: "",
+ staleness: 0,
+ srvCode: http.StatusOK,
+ acceptStale: true,
+ expectReq: true,
+ useCacheFile: false,
+ }, {
+ name: "no_file_http_error",
+ wantText: "",
+ wantErrMsg: refrID + `: refreshing from url "URL": ` +
+ `server "` + filtertest.ServerName + `": ` +
+ `status code error: expected 200, got 500`,
+ srvText: "internal server error",
+ staleness: 0,
+ srvCode: http.StatusInternalServerError,
+ acceptStale: true,
+ expectReq: true,
+ useCacheFile: false,
+ }, {
+ name: "file",
+ wantText: defaultFileText,
+ wantErrMsg: "",
+ srvText: "",
+ staleness: 1 * time.Hour,
+ srvCode: 0,
+ acceptStale: true,
+ expectReq: false,
+ useCacheFile: true,
+ }, {
+ name: "file_stale",
+ wantText: defaultURLText,
+ wantErrMsg: "",
+ srvText: defaultURLText,
+ staleness: -1 * time.Hour,
+ srvCode: http.StatusOK,
+ acceptStale: false,
+ expectReq: true,
+ useCacheFile: true,
+ }, {
+ name: "file_stale_accept",
+ wantText: defaultFileText,
+ wantErrMsg: "",
+ srvText: "",
+ staleness: -1 * time.Hour,
+ srvCode: 0,
+ acceptStale: true,
+ expectReq: false,
+ useCacheFile: true,
+ }}
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ var err error
+
+ reqCh := make(chan struct{}, 1)
+ var cachePath string
+ realCachePath, srvURL := filtertest.PrepareRefreshable(t, reqCh, tc.srvText, tc.srvCode)
+ if tc.useCacheFile {
+ cachePath = realCachePath
+
+ err = os.WriteFile(cachePath, []byte(defaultFileText), 0o600)
+ require.NoError(t, err)
+ } else {
+ cachePath = filepath.Join(t.TempDir(), "does_not_exist")
+ }
+
+ fl := &agd.FilterList{
+ URL: srvURL,
+ ID: refrID,
+ RefreshIvl: tc.staleness,
+ }
+ f := internal.NewRefreshable(fl, cachePath)
+
+ ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout)
+ t.Cleanup(cancel)
+
+ var gotText string
+ gotText, err = f.Refresh(ctx, tc.acceptStale)
+ if tc.expectReq {
+ testutil.RequireReceive(t, reqCh, filtertest.Timeout)
+ }
+
+ // Since we only get the actual URL within the subtest, replace it
+ // here and check the error message.
+ if srvURL != nil {
+ tc.wantErrMsg = strings.ReplaceAll(tc.wantErrMsg, "URL", srvURL.String())
+ }
+
+ testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
+ assert.Equal(t, tc.wantText, gotText)
+ })
+ }
+}
+
+func TestRefreshable_Refresh_properStaleness(t *testing.T) {
+ const (
+ responseDur = time.Second / 5
+ staleness = time.Hour
+ )
+
+ reqCh := make(chan struct{})
+ cachePath, addr := filtertest.PrepareRefreshable(t, reqCh, filtertest.BlockRule, http.StatusOK)
+
+ fl := &agd.FilterList{
+ URL: addr,
+ ID: refrID,
+ RefreshIvl: staleness,
+ }
+ f := internal.NewRefreshable(fl, cachePath)
+
+ ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout)
+ t.Cleanup(cancel)
+
+ var err error
+ var now time.Time
+ go func() {
+ <-reqCh
+ now = time.Now()
+ _, err = f.Refresh(ctx, false)
+ <-reqCh
+ }()
+
+ // Start the refresh.
+ reqCh <- struct{}{}
+
+	// Hold the handler to make sure that the refresh takes some time.
+ time.Sleep(responseDur)
+
+ // Continue the refresh.
+ testutil.RequireReceive(t, reqCh, filtertest.Timeout)
+
+ // Ensure the refresh finished.
+ reqCh <- struct{}{}
+
+ require.NoError(t, err)
+
+ file, err := os.Open(cachePath)
+ require.NoError(t, err)
+ testutil.CleanupAndRequireSuccess(t, file.Close)
+
+ fi, err := file.Stat()
+ require.NoError(t, err)
+
+ assert.InDelta(t, fi.ModTime().Sub(now), 0, float64(time.Millisecond))
+}
diff --git a/internal/filter/internal/resultcache/resultcache.go b/internal/filter/internal/resultcache/resultcache.go
index f230be5..ea4b4d9 100644
--- a/internal/filter/internal/resultcache/resultcache.go
+++ b/internal/filter/internal/resultcache/resultcache.go
@@ -83,7 +83,7 @@ var hashSeed = maphash.MakeSeed()
-// DefaultKey produces a cache key based on host, qt, and isAns using the
+// DefaultKey produces a cache key based on host, qt, cl, and isAns using the
// default algorithm.
-func DefaultKey(host string, qt dnsmsg.RRType, isAns bool) (k Key) {
+func DefaultKey(host string, qt dnsmsg.RRType, cl dnsmsg.Class, isAns bool) (k Key) {
// Use maphash explicitly instead of using a key structure to reduce
// allocations and optimize interface conversion up the stack.
h := &maphash.Hash{}
@@ -92,9 +92,10 @@ func DefaultKey(host string, qt dnsmsg.RRType, isAns bool) (k Key) {
_, _ = h.WriteString(host)
// Save on allocations by reusing a buffer.
- var buf [3]byte
+ var buf [5]byte
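+	// Buffer layout: RR type (2 bytes), class (2 bytes), answer flag (1 byte).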
binary.LittleEndian.PutUint16(buf[:2], qt)
- buf[2] = mathutil.BoolToNumber[byte](isAns)
+ binary.LittleEndian.PutUint16(buf[2:4], cl)
+ buf[4] = mathutil.BoolToNumber[byte](isAns)
_, _ = h.Write(buf[:])
diff --git a/internal/filter/internal/resultcache/resultcache_test.go b/internal/filter/internal/resultcache/resultcache_test.go
index 5594800..aaa0190 100644
--- a/internal/filter/internal/resultcache/resultcache_test.go
+++ b/internal/filter/internal/resultcache/resultcache_test.go
@@ -11,8 +11,8 @@ import (
// Common keys for tests.
var (
- testKey = resultcache.DefaultKey("example.com", dns.TypeA, true)
- otherKey = resultcache.DefaultKey("example.org", dns.TypeAAAA, false)
+ testKey = resultcache.DefaultKey("example.com", dns.TypeA, dns.ClassINET, true)
+ otherKey = resultcache.DefaultKey("example.org", dns.TypeAAAA, dns.ClassINET, false)
)
// val is the common value for tests.
diff --git a/internal/filter/internal/rulelist/immutable.go b/internal/filter/internal/rulelist/immutable.go
new file mode 100644
index 0000000..8126a34
--- /dev/null
+++ b/internal/filter/internal/rulelist/immutable.go
@@ -0,0 +1,40 @@
+package rulelist
+
+import (
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+)
+
+// Immutable is a rule-list filter that doesn't refresh or change.
+// It is used for users' custom rule-lists as well as in service blocking.
+//
+// TODO(a.garipov): Consider not using rule-list engines for service and custom
+// filters at all. It could be faster to simply go through all enabled rules
+// sequentially instead. Alternatively, rework the urlfilter.DNSEngine and make
+// it use the sequential scan if the number of rules is less than some constant
+// value.
+//
+// See AGDNS-342.
+type Immutable struct {
+ // TODO(a.garipov): Find ways to embed it in a way that shows the methods,
+ // doesn't result in double dereferences, and doesn't cause naming issues.
+ *filter
+}
+
+// NewImmutable returns a new immutable DNS request and response filter using
+// the provided rule text and ID.
+func NewImmutable(
+ text string,
+ id agd.FilterListID,
+ svcID agd.BlockedServiceID,
+ memCacheSize int,
+ useMemCache bool,
+) (f *Immutable, err error) {
+ f = &Immutable{}
+ f.filter, err = newFilter(text, id, svcID, memCacheSize, useMemCache)
+ if err != nil {
+ // Don't wrap the error, because it's informative enough as is.
+ return nil, err
+ }
+
+ return f, nil
+}
diff --git a/internal/filter/internal/rulelist/refreshable.go b/internal/filter/internal/rulelist/refreshable.go
new file mode 100644
index 0000000..52d7031
--- /dev/null
+++ b/internal/filter/internal/rulelist/refreshable.go
@@ -0,0 +1,136 @@
+package rulelist
+
+import (
+ "context"
+ "fmt"
+ "net/netip"
+ "path/filepath"
+ "sync"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal"
+ "github.com/AdguardTeam/golibs/log"
+ "github.com/AdguardTeam/urlfilter"
+ "github.com/AdguardTeam/urlfilter/filterlist"
+)
+
+// Refreshable is a refreshable DNS request and response filter based on filter
+// rule lists.
+//
+// TODO(a.garipov): Consider adding a separate version that uses a single engine
+// for multiple rule lists and using it to optimize the filtering using default
+// filtering groups.
+type Refreshable struct {
+ *filter
+
+ // mu protects filter.engine.
+ mu *sync.RWMutex
+
+ // refr contains data for refreshing the filter.
+ refr *internal.Refreshable
+}
+
+// NewRefreshable returns a new refreshable DNS request and response filter
+// based on the provided rule list. l must be non-nil. The initial refresh
+// should be called explicitly if necessary.
+func NewRefreshable(
+ l *agd.FilterList,
+ fileCacheDir string,
+ memCacheSize int,
+ useMemCache bool,
+) (f *Refreshable) {
+ f = &Refreshable{
+ mu: &sync.RWMutex{},
+ refr: internal.NewRefreshable(l, filepath.Join(fileCacheDir, string(l.ID))),
+ }
+
+ var err error
+ f.filter, err = newFilter("", l.ID, "", memCacheSize, useMemCache)
+ if err != nil {
+ // Should never happen, since text is empty.
+ panic(fmt.Errorf("unexpected filter error: %w", err))
+ }
+
+ return f
+}
+
+// NewFromString returns a new DNS request and response filter using the
+// provided rule text and ID.
+//
+// TODO(a.garipov): Only used in tests. Consider removing later.
+func NewFromString(
+ text string,
+ id agd.FilterListID,
+ svcID agd.BlockedServiceID,
+ memCacheSize int,
+ useMemCache bool,
+) (f *Refreshable, err error) {
+ f = &Refreshable{
+ mu: &sync.RWMutex{},
+ }
+
+ f.filter, err = newFilter(text, id, svcID, memCacheSize, useMemCache)
+ if err != nil {
+ // Don't wrap the error, because it's informative enough as is.
+ return nil, err
+ }
+
+ return f, nil
+}
+
+// DNSResult returns the result of applying the urlfilter DNS filtering engine.
+// If the request is not filtered, DNSResult returns nil.
+func (f *Refreshable) DNSResult(
+ clientIP netip.Addr,
+ clientName string,
+ host string,
+ rrType dnsmsg.RRType,
+ isAns bool,
+) (res *urlfilter.DNSResult) {
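+	// Hold the read lock, since Refresh may replace the engine concurrently.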
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+
+ return f.filter.DNSResult(clientIP, clientName, host, rrType, isAns)
+}
+
+// Refresh reloads the rule list data. If acceptStale is true, do not try to
+// load the list from its URL when there is already a file in the cache
+// directory, regardless of its staleness.
+func (f *Refreshable) Refresh(ctx context.Context, acceptStale bool) (err error) {
+ text, err := f.refr.Refresh(ctx, acceptStale)
+ if err != nil {
+ // Don't wrap the error, because it's informative enough as is.
+ return err
+ }
+
+ // TODO(a.garipov): Add filterlist.BytesRuleList.
+ strList := &filterlist.StringRuleList{
+ ID: f.urlFilterID,
+ RulesText: text,
+ IgnoreCosmetic: true,
+ }
+
+ s, err := filterlist.NewRuleStorage([]filterlist.RuleList{strList})
+ if err != nil {
+ return fmt.Errorf("creating rule storage: %w", err)
+ }
+
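+	// Replace the engine and clear the cache under the write lock, so that
+	// concurrent DNSResult calls don't get results produced by the old rules.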
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ f.cache.Clear()
+ f.engine = urlfilter.NewDNSEngine(s)
+
+ log.Info("%s: reset %d rules", f.id, f.engine.RulesCount)
+
+ return nil
+}
+
+// RulesCount returns the number of rules in the filter's engine.
+func (f *Refreshable) RulesCount() (n int) {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+
+ return f.filter.RulesCount()
+}
diff --git a/internal/filter/internal/rulelist/refreshable_test.go b/internal/filter/internal/rulelist/refreshable_test.go
new file mode 100644
index 0000000..01086ee
--- /dev/null
+++ b/internal/filter/internal/rulelist/refreshable_test.go
@@ -0,0 +1,107 @@
+package rulelist_test
+
+import (
+ "context"
+ "net/http"
+ "net/netip"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/filtertest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/rulelist"
+ "github.com/AdguardTeam/golibs/testutil"
+ "github.com/miekg/dns"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestMain(m *testing.M) {
+ testutil.DiscardLogOutput(m)
+}
+
+// testReqHost is the request host for tests.
+const testReqHost = "blocked.example"
+
+// testRemoteIP is the client IP for tests.
+var testRemoteIP = netip.MustParseAddr("1.2.3.4")
+
+// testFltListID is the common filter list ID for tests.
+const testFltListID agd.FilterListID = "fl1"
+
+// testBlockRule is the common blocking rule for tests.
+const testBlockRule = "||" + testReqHost + "\n"
+
+func TestRefreshable_RulesCount(t *testing.T) {
+ rl, err := rulelist.NewFromString(testBlockRule, testFltListID, "", 0, false)
+ require.NoError(t, err)
+
+ assert.Equal(t, 1, rl.RulesCount())
+}
+
+func TestRefreshable_DNSResult_cache(t *testing.T) {
+ rl, err := rulelist.NewFromString(testBlockRule, testFltListID, "", 100, true)
+ require.NoError(t, err)
+
+ const qt = dns.TypeA
+
+ t.Run("blocked", func(t *testing.T) {
+ dr := rl.DNSResult(testRemoteIP, "", testReqHost, qt, false)
+ require.NotNil(t, dr)
+
+ assert.Len(t, dr.NetworkRules, 1)
+
+ cachedDR := rl.DNSResult(testRemoteIP, "", testReqHost, qt, false)
+ require.NotNil(t, cachedDR)
+
+ assert.Same(t, dr, cachedDR)
+ })
+
+ t.Run("none", func(t *testing.T) {
+ const otherHost = "other.example"
+
+ dr := rl.DNSResult(testRemoteIP, "", otherHost, qt, false)
+ assert.Nil(t, dr)
+
+ cachedDR := rl.DNSResult(testRemoteIP, "", otherHost, dns.TypeA, false)
+ assert.Nil(t, cachedDR)
+ })
+}
+
+func TestRefreshable_ID(t *testing.T) {
+ const svcID = agd.BlockedServiceID("test_service")
+ rl, err := rulelist.NewFromString(testBlockRule, testFltListID, svcID, 0, false)
+ require.NoError(t, err)
+
+ gotID, gotSvcID := rl.ID()
+ assert.Equal(t, testFltListID, gotID)
+ assert.Equal(t, svcID, gotSvcID)
+}
+
+func TestRefreshable_Refresh(t *testing.T) {
+ cachePath, srvURL := filtertest.PrepareRefreshable(t, nil, testBlockRule, http.StatusOK)
+ rl := rulelist.NewRefreshable(
+ &agd.FilterList{
+ URL: srvURL,
+ ID: testFltListID,
+ RefreshIvl: 1 * time.Hour,
+ },
+ filepath.Dir(cachePath),
+ 100,
+ true,
+ )
+
+ ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout)
+ t.Cleanup(cancel)
+
+ err := rl.Refresh(ctx, false)
+ require.NoError(t, err)
+
+ assert.Equal(t, 1, rl.RulesCount())
+
+ dr := rl.DNSResult(testRemoteIP, "", testReqHost, dns.TypeA, false)
+ require.NotNil(t, dr)
+
+ assert.Len(t, dr.NetworkRules, 1)
+}
diff --git a/internal/filter/internal/rulelist/rulelist.go b/internal/filter/internal/rulelist/rulelist.go
new file mode 100644
index 0000000..114e50e
--- /dev/null
+++ b/internal/filter/internal/rulelist/rulelist.go
@@ -0,0 +1,154 @@
+// Package rulelist contains the implementation of the standard rule-list
+// filter that wraps an urlfilter filtering-engine.
+package rulelist
+
+import (
+ "fmt"
+ "math/rand"
+ "net/netip"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/resultcache"
+ "github.com/AdguardTeam/urlfilter"
+ "github.com/AdguardTeam/urlfilter/filterlist"
+ "github.com/miekg/dns"
+)
+
+// newURLFilterID returns a new random ID for the urlfilter DNS engine to use.
+func newURLFilterID() (id int) {
+ // #nosec G404 -- Do not use cryptographically random ID generation, since
+ // these are only used in one place, internal/filter.compFilter.filterMsg,
+ // and are not used in any security-sensitive context.
+ //
+ // Despite the fact that the type of integer filter list IDs in module
+ // urlfilter is int, the module actually assumes that the ID is a
+ // non-negative integer, or at least not a largely negative one. Otherwise,
+ // some of its low-level optimizations seem to break.
+ return int(rand.Int31())
+}
+
+// filter is the basic rule-list filter that doesn't refresh or change in any
+// other way.
+type filter struct {
+ // engine is the DNS filtering engine.
+ //
+ // NOTE: We do not save the [filterlist.RuleList] used to create the engine
+ // to close it, because we exclusively use [filterlist.StringRuleList],
+ // which doesn't require closing.
+ engine *urlfilter.DNSEngine
+
+ // cache contains cached results of filtering.
+ //
+ // TODO(ameshkov): Add metrics for these caches.
+ cache *resultcache.Cache[*urlfilter.DNSResult]
+
+ // id is the filter list ID, if any.
+ id agd.FilterListID
+
+	// svcID is the additional identifier for blocked service lists. If id
+	// is [agd.FilterListIDBlockedService], it contains the ID of the blocked
+	// service; otherwise, it is empty.
+ svcID agd.BlockedServiceID
+
+ // urlFilterID is the synthetic integer identifier for the urlfilter engine.
+ //
+ // TODO(a.garipov): Change the type to a string in module urlfilter and
+ // remove this crutch.
+ urlFilterID int
+}
+
+// newFilter returns a new basic DNS request and response filter using the
+// provided rule text and ID.
+func newFilter(
+ text string,
+ id agd.FilterListID,
+ svcID agd.BlockedServiceID,
+ memCacheSize int,
+ useMemCache bool,
+) (f *filter, err error) {
+ f = &filter{
+ id: id,
+ svcID: svcID,
+ urlFilterID: newURLFilterID(),
+ }
+
+ if useMemCache {
+ f.cache = resultcache.New[*urlfilter.DNSResult](memCacheSize)
+ }
+
+ // TODO(a.garipov): Add filterlist.BytesRuleList.
+ strList := &filterlist.StringRuleList{
+ ID: f.urlFilterID,
+ RulesText: text,
+ IgnoreCosmetic: true,
+ }
+
+ s, err := filterlist.NewRuleStorage([]filterlist.RuleList{strList})
+ if err != nil {
+ return nil, fmt.Errorf("creating rule storage: %w", err)
+ }
+
+ f.engine = urlfilter.NewDNSEngine(s)
+
+ return f, nil
+}
+
+// DNSResult returns the result of applying the urlfilter DNS filtering engine.
+// If the request is not filtered, DNSResult returns nil.
+func (f *filter) DNSResult(
+ clientIP netip.Addr,
+ clientName string,
+ host string,
+ rrType dnsmsg.RRType,
+ isAns bool,
+) (res *urlfilter.DNSResult) {
+ var ok bool
+ var cacheKey resultcache.Key
+
+ // Don't waste resources on computing the cache key if the cache is not
+ // enabled.
+ useCache := f.cache != nil
+ if useCache {
+ // TODO(a.garipov): Add real class here.
+ cacheKey = resultcache.DefaultKey(host, rrType, dns.ClassINET, isAns)
+ res, ok = f.cache.Get(cacheKey)
+ if ok {
+ return res
+ }
+ }
+
+ dnsReq := &urlfilter.DNSRequest{
+ Hostname: host,
+ // TODO(a.garipov): Make this a net.IP in module urlfilter.
+ ClientIP: clientIP.String(),
+ ClientName: clientName,
+ DNSType: rrType,
+ Answer: isAns,
+ }
+
+ res, ok = f.engine.MatchRequest(dnsReq)
+ if !ok && len(res.NetworkRules) == 0 {
+ res = nil
+ }
+
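+	// Cache the result, including nil ones, so that requests for hosts that
+	// aren't filtered don't hit the engine again.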
+ if useCache {
+ f.cache.Set(cacheKey, res)
+ }
+
+ return res
+}
+
+// ID returns the filter list ID of this rule list filter, as well as the ID of
+// the blocked service, if any.
+func (f *filter) ID() (id agd.FilterListID, svcID agd.BlockedServiceID) {
+ return f.id, f.svcID
+}
+
+// RulesCount returns the number of rules in the filter's engine.
+func (f *filter) RulesCount() (n int) {
+ return f.engine.RulesCount
+}
+
+// URLFilterID returns the synthetic ID used for the urlfilter module.
+func (f *filter) URLFilterID() (n int) {
+ return f.urlFilterID
+}
diff --git a/internal/filter/internal/safesearch/safesearch.go b/internal/filter/internal/safesearch/safesearch.go
new file mode 100644
index 0000000..1a50da1
--- /dev/null
+++ b/internal/filter/internal/safesearch/safesearch.go
@@ -0,0 +1,185 @@
+// Package safesearch contains the implementation of the safe-search filter
+// that uses lists of DNS rewrite rules to enforce safe search.
+package safesearch
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/netip"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/resultcache"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/rulelist"
+ "github.com/AdguardTeam/AdGuardDNS/internal/optlog"
+ "github.com/AdguardTeam/golibs/netutil"
+ "github.com/miekg/dns"
+)
+
+// Filter modifies the results of queries to search-engine addresses and
+// rewrites them to the IP addresses of their safe versions.
+type Filter struct {
+ resCache *resultcache.Cache[*internal.ResultModified]
+ flt *rulelist.Refreshable
+ resolver agdnet.Resolver
+ errColl agd.ErrorCollector
+ id agd.FilterListID
+}
+
+// Config contains configuration for the safe-search filter.
+type Config struct {
+ // List is the filtering-rule list used to filter requests.
+ List *agd.FilterList
+
+ // Resolver is used to resolve the IP addresses of replacement hosts.
+ Resolver agdnet.Resolver
+
+ // ErrColl is used to report errors of replacement-host resolving.
+ ErrColl agd.ErrorCollector
+
+ // CacheDir is the path to the directory where the cached filter files are
+ // put. The directory must exist.
+ CacheDir string
+
+ // CacheTTL is the time to live of the result cache-items.
+	// CacheTTL is the time to live of the result-cache items.
+ //lint:ignore U1000 TODO(a.garipov): Currently unused. See AGDNS-398.
+ CacheTTL time.Duration
+
+ // CacheSize is the number of items in the result cache.
+ CacheSize int
+}
+
+// New returns a new safe-search filter. c must not be nil. The initial
+// refresh should be called explicitly if necessary.
+func New(c *Config) (f *Filter) {
+ id := c.List.ID
+
+ return &Filter{
+ resCache: resultcache.New[*internal.ResultModified](c.CacheSize),
+		// Don't use the rule-list cache, since the safe-search filter has its
+		// own result cache.
+ flt: rulelist.NewRefreshable(c.List, c.CacheDir, 0, false),
+ resolver: c.Resolver,
+ errColl: c.ErrColl,
+ id: id,
+ }
+}
+
+// type check
+var _ internal.RequestFilter = (*Filter)(nil)
+
+// FilterRequest implements the [internal.RequestFilter] interface for *Filter.
+// It modifies the response if host matches f.
+func (f *Filter) FilterRequest(
+ ctx context.Context,
+ req *dns.Msg,
+ ri *agd.RequestInfo,
+) (r internal.Result, err error) {
+ qt := ri.QType
+ fam := netutil.AddrFamilyFromRRType(qt)
+ if fam == netutil.AddrFamilyNone {
+ return nil, nil
+ }
+
+ host := ri.Host
+ cacheKey := resultcache.DefaultKey(host, qt, ri.QClass, false)
+ rm, ok := f.resCache.Get(cacheKey)
+ if ok {
+ if rm == nil {
+ // Return nil explicitly instead of modifying CloneForReq to return
+ // nil if the result is nil to avoid a “non-nil nil” value.
+ return nil, nil
+ }
+
+ return rm.CloneForReq(req), nil
+ }
+
+ repHost, ok := f.safeSearchHost(host, qt)
+ if !ok {
+ optlog.Debug2("filter %s: host %q is not on the list", f.id, host)
+
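+		// Cache the negative result as well to avoid checking the rules for
+		// this host again.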
+ f.resCache.Set(cacheKey, nil)
+
+ return nil, nil
+ }
+
+ optlog.Debug2("filter %s: found host %q", f.id, repHost)
+
+ ctx, cancel := context.WithTimeout(ctx, internal.DefaultResolveTimeout)
+ defer cancel()
+
+ var result *dns.Msg
+ ips, err := f.resolver.LookupIP(ctx, fam, repHost)
+ if err != nil {
+ agd.Collectf(ctx, f.errColl, "filter %s: resolving: %w", f.id, err)
+
+ result = ri.Messages.NewMsgSERVFAIL(req)
+ } else {
+ result, err = ri.Messages.NewIPRespMsg(req, ips...)
+ if err != nil {
+ return nil, fmt.Errorf("filter %s: creating modified result: %w", f.id, err)
+ }
+ }
+
+ rm = &internal.ResultModified{
+ Msg: result,
+ List: f.id,
+ Rule: agd.FilterRuleText(host),
+ }
+
+ // Copy the result to make sure that modifications to the result message
+ // down the pipeline don't interfere with the cached value.
+ //
+ // See AGDNS-359.
+ f.resCache.Set(cacheKey, rm.Clone())
+
+ return rm, nil
+}
+
+// safeSearchHost returns the replacement host for the given host and question
+// type, if any. qt should be either [dns.TypeA] or [dns.TypeAAAA].
+func (f *Filter) safeSearchHost(host string, qt dnsmsg.RRType) (ssHost string, ok bool) {
+ dr := f.flt.DNSResult(netip.Addr{}, "", host, qt, false)
+ if dr == nil {
+ return "", false
+ }
+
+ for _, nr := range dr.DNSRewrites() {
+ drw := nr.DNSRewrite
+ if drw.RCode != dns.RcodeSuccess {
+ continue
+ }
+
+ if nc := drw.NewCNAME; nc != "" {
+ return nc, true
+ }
+
+ // All the rules in safe search rule lists are expected to have either
+ // A/AAAA or CNAME type.
+ switch drw.RRType {
+ case dns.TypeA, dns.TypeAAAA:
+ return drw.Value.(net.IP).String(), true
+ default:
+ continue
+ }
+ }
+
+ return "", false
+}
+
+// Refresh reloads the rule list data. If acceptStale is true, and the cache
+// file exists, the data is read from there regardless of its staleness.
+func (f *Filter) Refresh(ctx context.Context, acceptStale bool) (err error) {
+ err = f.flt.Refresh(ctx, acceptStale)
+ if err != nil {
+ return err
+ }
+
+ f.resCache.Clear()
+
+ return nil
+}
diff --git a/internal/filter/internal/safesearch/safesearch_test.go b/internal/filter/internal/safesearch/safesearch_test.go
new file mode 100644
index 0000000..1bb4015
--- /dev/null
+++ b/internal/filter/internal/safesearch/safesearch_test.go
@@ -0,0 +1,198 @@
+package safesearch_test
+
+import (
+ "context"
+ "net"
+ "net/http"
+ "net/netip"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/filtertest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/safesearch"
+ "github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/netutil"
+ "github.com/AdguardTeam/golibs/testutil"
+ "github.com/miekg/dns"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestMain(m *testing.M) {
+ testutil.DiscardLogOutput(m)
+}
+
+// testSafeIPStr is the string representation of the IP address of the safe
+// version of [testEngineWithIP].
+const testSafeIPStr = "1.2.3.4"
+
+// testIPOfEngineWithIP is the IP address of the safe version of
+// search-engine-ip.example.
+var testIPOfEngineWithIP net.IP = netip.MustParseAddr(testSafeIPStr).AsSlice()
+
+// testIPOfEngineWithDomain is the IP address of the safe version of
+// search-engine-domain.example.
+var testIPOfEngineWithDomain = net.IP{1, 2, 3, 5}
+
+// Common domain names for tests.
+const (
+ testOther = "other.example"
+ testEngineWithIP = "search-engine-ip.example"
+ testEngineWithDomain = "search-engine-domain.example"
+ testSafeDomain = "safe-search-engine-domain.example"
+)
+
+// testFilterRules are the common filtering rules for tests.
+const testFilterRules = `|` + testEngineWithIP + `^$dnsrewrite=NOERROR;A;` + testSafeIPStr + "\n" +
+ `|` + testEngineWithDomain + `^$dnsrewrite=NOERROR;CNAME;` + testSafeDomain
+
+func TestFilter(t *testing.T) {
+ reqCh := make(chan struct{}, 1)
+ cachePath, srvURL := filtertest.PrepareRefreshable(t, reqCh, testFilterRules, http.StatusOK)
+
+ id, err := agd.NewFilterListID(filepath.Base(cachePath))
+ require.NoError(t, err)
+
+ f := safesearch.New(&safesearch.Config{
+ List: &agd.FilterList{
+ ID: id,
+ URL: srvURL,
+ },
+ Resolver: &agdtest.Resolver{
+ OnLookupIP: func(
+ _ context.Context,
+ _ netutil.AddrFamily,
+ host string,
+ ) (ips []net.IP, err error) {
+ switch host {
+ case testSafeIPStr:
+ return []net.IP{testIPOfEngineWithIP}, nil
+ case testSafeDomain:
+ return []net.IP{testIPOfEngineWithDomain}, nil
+ default:
+ return nil, errors.Error("test resolver error")
+ }
+ },
+ },
+ ErrColl: &agdtest.ErrorCollector{
+ OnCollect: func(ctx context.Context, err error) {
+ panic("not implemented")
+ },
+ },
+ CacheDir: filepath.Dir(cachePath),
+ CacheTTL: 1 * time.Minute,
+ CacheSize: 100,
+ })
+
+ refrCtx, refrCancel := context.WithTimeout(context.Background(), filtertest.Timeout)
+ t.Cleanup(refrCancel)
+
+ err = f.Refresh(refrCtx, true)
+ require.NoError(t, err)
+
+ testutil.RequireReceive(t, reqCh, filtertest.Timeout)
+
+ t.Run("no_match", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout)
+ t.Cleanup(cancel)
+
+ req, ri := newReq(t, testOther, dns.TypeA)
+ res, fltErr := f.FilterRequest(ctx, req, ri)
+ require.NoError(t, fltErr)
+
+ assert.Nil(t, res)
+
+ t.Run("cached", func(t *testing.T) {
+ res, fltErr = f.FilterRequest(ctx, req, ri)
+ require.NoError(t, fltErr)
+
+ // TODO(a.garipov): Find a way to make caches more inspectable.
+ assert.Nil(t, res)
+ })
+ })
+
+ t.Run("txt", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout)
+ t.Cleanup(cancel)
+
+ req, ri := newReq(t, testEngineWithIP, dns.TypeTXT)
+ res, fltErr := f.FilterRequest(ctx, req, ri)
+ require.NoError(t, fltErr)
+
+ assert.Nil(t, res)
+ })
+
+ t.Run("ip", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout)
+ t.Cleanup(cancel)
+
+ req, ri := newReq(t, testEngineWithIP, dns.TypeA)
+ res, fltErr := f.FilterRequest(ctx, req, ri)
+ require.NoError(t, fltErr)
+
+ rm := testutil.RequireTypeAssert[*internal.ResultModified](t, res)
+ require.Len(t, rm.Msg.Answer, 1)
+
+ assert.Equal(t, rm.Rule, agd.FilterRuleText(testEngineWithIP))
+
+ a := testutil.RequireTypeAssert[*dns.A](t, rm.Msg.Answer[0])
+ assert.Equal(t, testIPOfEngineWithIP, a.A)
+
+ t.Run("cached", func(t *testing.T) {
+ newReq, newRI := newReq(t, testEngineWithIP, dns.TypeA)
+
+ var cachedRes internal.Result
+ cachedRes, fltErr = f.FilterRequest(ctx, newReq, newRI)
+ require.NoError(t, fltErr)
+
+ // Do not assert that the results are the same, since a modified
+ // result of a safe search is always cloned. But assert that the
+ // non-clonable fields are equal and that the message has reply
+ // fields set properly.
+ cachedRM := testutil.RequireTypeAssert[*internal.ResultModified](t, cachedRes)
+ assert.NotSame(t, cachedRM, rm)
+ assert.Equal(t, cachedRM.Msg.Id, newReq.Id)
+ assert.Equal(t, cachedRM.List, rm.List)
+ assert.Equal(t, cachedRM.Rule, rm.Rule)
+ })
+ })
+
+ t.Run("domain", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), filtertest.Timeout)
+ t.Cleanup(cancel)
+
+ req, ri := newReq(t, testEngineWithDomain, dns.TypeA)
+ res, fltErr := f.FilterRequest(ctx, req, ri)
+ require.NoError(t, fltErr)
+
+ rm := testutil.RequireTypeAssert[*internal.ResultModified](t, res)
+ require.Len(t, rm.Msg.Answer, 1)
+
+ assert.Equal(t, rm.Rule, agd.FilterRuleText(testEngineWithDomain))
+
+ a := testutil.RequireTypeAssert[*dns.A](t, rm.Msg.Answer[0])
+ assert.Equal(t, testIPOfEngineWithDomain, a.A)
+ })
+}
+
+// newReq is a test helper that returns the DNS request and its accompanying
+// request info with the given data.
+func newReq(tb testing.TB, host string, qt dnsmsg.RRType) (req *dns.Msg, ri *agd.RequestInfo) {
+ tb.Helper()
+
+ req = dnsservertest.NewReq(host, qt, dns.ClassINET)
+ ri = &agd.RequestInfo{
+ Messages: agdtest.NewConstructor(),
+ Host: host,
+ QType: qt,
+ QClass: dns.ClassINET,
+ }
+
+ return req, ri
+}
diff --git a/internal/filter/internal/serviceblock/index.go b/internal/filter/internal/serviceblock/index.go
new file mode 100644
index 0000000..3524ea4
--- /dev/null
+++ b/internal/filter/internal/serviceblock/index.go
@@ -0,0 +1,96 @@
+package serviceblock
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/rulelist"
+ "github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/log"
+)
+
+// indexResp is the struct for the JSON response from a blocked service index
+// API.
+type indexResp struct {
+ BlockedServices []*indexRespService `json:"blocked_services"`
+}
+
+// toInternal converts the services from the index to serviceRuleLists.
+func (r *indexResp) toInternal(
+ ctx context.Context,
+ errColl agd.ErrorCollector,
+ cacheSize int,
+ useCache bool,
+) (services serviceRuleLists, err error) {
+ l := len(r.BlockedServices)
+ if l == 0 {
+ return nil, nil
+ }
+
+ services = make(serviceRuleLists, l)
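+	// Convert the services one by one, collecting the errors to report them
+	// together below.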
+ errs := make([]error, len(r.BlockedServices))
+ for i, svc := range r.BlockedServices {
+ var (
+ svcID agd.BlockedServiceID
+ rl *rulelist.Immutable
+ )
+ svcID, rl, err = svc.toInternal(ctx, errColl, cacheSize, useCache)
+ if err != nil {
+ errs[i] = fmt.Errorf("service at index %d: %w", i, err)
+
+ continue
+ }
+
+ services[svcID] = rl
+ }
+
+ err = errors.Join(errs...)
+ if err != nil {
+ return nil, fmt.Errorf("converting blocked services: %w", err)
+ }
+
+ return services, nil
+}
+
+// indexRespService is the struct for a single blocked service in the JSON
+// response from a blocked service index API.
+type indexRespService struct {
+ ID string `json:"id"`
+ Rules []string `json:"rules"`
+}
+
+// toInternal converts the service from the index to a rule-list filter.
+func (svc *indexRespService) toInternal(
+ ctx context.Context,
+ errColl agd.ErrorCollector,
+ cacheSize int,
+ useCache bool,
+) (svcID agd.BlockedServiceID, rl *rulelist.Immutable, err error) {
+ svcID, err = agd.NewBlockedServiceID(svc.ID)
+ if err != nil {
+ return "", nil, fmt.Errorf("validating id: %w", err)
+ }
+
+ if len(svc.Rules) == 0 {
+ reportErr := fmt.Errorf("service filter: no rules for service with id %s", svcID)
+ errColl.Collect(ctx, reportErr)
+ log.Info("warning: %s", reportErr)
+ }
+
+ rl, err = rulelist.NewImmutable(
+ strings.Join(svc.Rules, "\n"),
+ agd.FilterListIDBlockedService,
+ svcID,
+ cacheSize,
+ useCache,
+ )
+ if err != nil {
+ return "", nil, fmt.Errorf("compiling %s: %w", svc.ID, err)
+ }
+
+ log.Info("%s/%s: got %d rules", agd.FilterListIDBlockedService, svcID, rl.RulesCount())
+
+ return svcID, rl, nil
+}
diff --git a/internal/filter/internal/serviceblock/serviceblock.go b/internal/filter/internal/serviceblock/serviceblock.go
new file mode 100644
index 0000000..de378fd
--- /dev/null
+++ b/internal/filter/internal/serviceblock/serviceblock.go
@@ -0,0 +1,149 @@
+// Package serviceblock contains an implementation of a filter that blocks
+// services using rule lists. The blocking is based on the parental-control
+// settings in the profile.
+package serviceblock
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/url"
+ "sync"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/rulelist"
+ "github.com/AdguardTeam/AdGuardDNS/internal/metrics"
+ "github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/log"
+)
+
+// Filter is a service-blocking filter that uses rule lists loaded from a
+// blocked-service index.
+type Filter struct {
+ // url is the URL from which the services are fetched.
+ url *url.URL
+
+ // http is the HTTP client used to refresh the filter.
+ http *agdhttp.Client
+
+ // mu protects services.
+ mu *sync.RWMutex
+
+ // services is an ID to filter mapping.
+ services serviceRuleLists
+
+	// errColl is used to collect non-critical and rare errors.
+ errColl agd.ErrorCollector
+}
+
+// serviceRuleLists is a convenient alias for an ID-to-filter mapping.
+type serviceRuleLists = map[agd.BlockedServiceID]*rulelist.Immutable
+
+// New returns a fully initialized service blocker.
+func New(indexURL *url.URL, errColl agd.ErrorCollector) (f *Filter) {
+ return &Filter{
+ url: indexURL,
+ http: agdhttp.NewClient(&agdhttp.ClientConfig{
+ Timeout: internal.DefaultFilterRefreshTimeout,
+ }),
+ mu: &sync.RWMutex{},
+ errColl: errColl,
+ }
+}
+
+// RuleLists returns the rule-list filters for the given blocked service IDs.
+// The order of the elements in rls is undefined.
+func (f *Filter) RuleLists(
+ ctx context.Context,
+ ids []agd.BlockedServiceID,
+) (rls []*rulelist.Immutable) {
+ if len(ids) == 0 {
+ return nil
+ }
+
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+
+ for _, id := range ids {
+ rl := f.services[id]
+ if rl != nil {
+ rls = append(rls, rl)
+
+ continue
+ }
+
+ reportErr := fmt.Errorf("service filter: no service with id %s", id)
+ f.errColl.Collect(ctx, reportErr)
+ log.Info("warning: %s", reportErr)
+ }
+
+ return rls
+}
+
+// Refresh loads new service data from the index URL.
+func (f *Filter) Refresh(ctx context.Context, cacheSize int, useCache bool) (err error) {
+ fltIDStr := string(agd.FilterListIDBlockedService)
+ defer func() {
+ if err != nil {
+ metrics.FilterUpdatedStatus.WithLabelValues(fltIDStr).Set(0)
+ }
+ }()
+
+ resp, err := f.loadIndex(ctx)
+ if err != nil {
+ // Don't wrap the error, because it's informative enough as is.
+ return err
+ }
+
+ services, err := resp.toInternal(ctx, f.errColl, cacheSize, useCache)
+ if err != nil {
+ // Don't wrap the error, because it's informative enough as is.
+ return err
+ }
+
+ count := 0
+ for _, s := range services {
+ count += s.RulesCount()
+ }
+
+ metrics.FilterRulesTotal.WithLabelValues(fltIDStr).Set(float64(count))
+ metrics.FilterUpdatedTime.WithLabelValues(fltIDStr).SetToCurrentTime()
+ metrics.FilterUpdatedStatus.WithLabelValues(fltIDStr).Set(1)
+
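+	// Update the services under the write lock, since RuleLists may be
+	// reading them concurrently.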
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ f.services = services
+
+ return nil
+}
+
+// loadIndex fetches, decodes, and returns the blocked service index data.
+func (f *Filter) loadIndex(ctx context.Context) (resp *indexResp, err error) {
+ defer func() { err = errors.Annotate(err, "loading blocked service index from %q: %w", f.url) }()
+
+ httpResp, err := f.http.Get(ctx, f.url)
+ if err != nil {
+ return nil, fmt.Errorf("requesting: %w", err)
+ }
+ defer func() { err = errors.WithDeferred(err, httpResp.Body.Close()) }()
+
+ err = agdhttp.CheckStatus(httpResp, http.StatusOK)
+ if err != nil {
+ // Don't wrap the error, because it's informative enough as is.
+ return nil, err
+ }
+
+ resp = &indexResp{}
+ err = json.NewDecoder(httpResp.Body).Decode(resp)
+ if err != nil {
+ return nil, agdhttp.WrapServerError(fmt.Errorf("decoding: %w", err), httpResp)
+ }
+
+ log.Debug("service filter: loaded index with %d blocked services", len(resp.BlockedServices))
+
+ return resp, nil
+}
diff --git a/internal/filter/internal/serviceblock/serviceblock_test.go b/internal/filter/internal/serviceblock/serviceblock_test.go
new file mode 100644
index 0000000..e93d24b
--- /dev/null
+++ b/internal/filter/internal/serviceblock/serviceblock_test.go
@@ -0,0 +1,74 @@
+package serviceblock_test
+
+import (
+ "context"
+ "net/http"
+ "testing"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/filtertest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/serviceblock"
+ "github.com/AdguardTeam/golibs/testutil"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// Common blocked service IDs for tests.
+const (
+ testSvcID1 agd.BlockedServiceID = "svc_1"
+ testSvcID2 agd.BlockedServiceID = "svc_2"
+)
+
+// testData is a sample of a service index response.
+//
+// See https://github.com/atropnikov/HostlistsRegistry/blob/main/assets/services.json.
+const testData string = `{
+ "blocked_services": [
+ {
+ "id": "` + string(testSvcID1) + `",
+ "name": "Service 1",
+ "rules": [
+ "||service-1.example^"
+ ]
+ },
+ {
+ "id": "` + string(testSvcID2) + `",
+ "name": "Service 2",
+ "rules": [
+ "||service-2.example^"
+ ]
+ }
+ ]
+}`
+
+func TestFilter(t *testing.T) {
+ reqCh := make(chan struct{}, 1)
+ _, srvURL := filtertest.PrepareRefreshable(t, reqCh, testData, http.StatusOK)
+
+ errColl := &agdtest.ErrorCollector{
+ OnCollect: func(ctx context.Context, err error) {
+ panic("not implemented")
+ },
+ }
+
+ f := serviceblock.New(srvURL, errColl)
+
+ ctx := context.Background()
+ err := f.Refresh(ctx, 0, false)
+ require.NoError(t, err)
+
+ testutil.RequireReceive(t, reqCh, filtertest.Timeout)
+
+ svcIDs := []agd.BlockedServiceID{testSvcID1, testSvcID2}
+ rls := f.RuleLists(ctx, svcIDs)
+ require.Len(t, rls, 2)
+
+ gotFltIDs := make([]agd.FilterListID, 2)
+ gotSvcIDs := make([]agd.BlockedServiceID, 2)
+ gotFltIDs[0], gotSvcIDs[0] = rls[0].ID()
+ gotFltIDs[1], gotSvcIDs[1] = rls[1].ID()
+ assert.Equal(t, agd.FilterListIDBlockedService, gotFltIDs[0])
+ assert.Equal(t, agd.FilterListIDBlockedService, gotFltIDs[1])
+ assert.ElementsMatch(t, svcIDs, gotSvcIDs)
+}
diff --git a/internal/filter/refrfilter_internal_test.go b/internal/filter/refrfilter_internal_test.go
deleted file mode 100644
index 8f33eb2..0000000
--- a/internal/filter/refrfilter_internal_test.go
+++ /dev/null
@@ -1,186 +0,0 @@
-package filter
-
-import (
- "context"
- "io"
- "net/http"
- "net/http/httptest"
- "os"
- "testing"
- "time"
-
- "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
- "github.com/AdguardTeam/golibs/testutil"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func TestRefreshableFilter_RefreshFromFile(t *testing.T) {
- dir := t.TempDir()
- f, err := os.CreateTemp(dir, t.Name())
- require.NoError(t, err)
-
- const defaultText = "||example.com\n"
- _, err = io.WriteString(f, defaultText)
- require.NoError(t, err)
-
- cachePath := f.Name()
-
- testCases := []struct {
- name string
- cachePath string
- wantText string
- staleness time.Duration
- acceptStale bool
- }{{
- name: "no_file",
- cachePath: "does_not_exist",
- wantText: "",
- staleness: 0,
- acceptStale: true,
- }, {
- name: "file",
- cachePath: cachePath,
- wantText: defaultText,
- staleness: 0,
- acceptStale: true,
- }, {
- name: "file_stale",
- cachePath: cachePath,
- wantText: "",
- staleness: -1 * time.Second,
- acceptStale: false,
- }, {
- name: "file_stale_accept",
- cachePath: cachePath,
- wantText: defaultText,
- staleness: -1 * time.Second,
- acceptStale: true,
- }}
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- f := &refreshableFilter{
- http: nil,
- url: nil,
- id: "test_filter",
- cachePath: tc.cachePath,
- typ: "test filter",
- staleness: tc.staleness,
- }
-
- var text string
- text, err = f.refreshFromFile(tc.acceptStale)
- require.NoError(t, err)
-
- assert.Equal(t, tc.wantText, text)
- })
- }
-}
-
-func TestRefreshableFilter_RefreshFromURL(t *testing.T) {
- const defaultText = "||example.com\n"
-
- codeCh := make(chan int, 1)
- textCh := make(chan string, 1)
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
- pt := testutil.PanicT{}
-
- w.WriteHeader(<-codeCh)
-
- _, err := io.WriteString(w, <-textCh)
- require.NoError(pt, err)
- }))
- t.Cleanup(srv.Close)
-
- u, err := agdhttp.ParseHTTPURL(srv.URL)
- require.NoError(t, err)
-
- httpCli := agdhttp.NewClient(&agdhttp.ClientConfig{
- Timeout: testTimeout,
- })
-
- dir := t.TempDir()
- f, err := os.CreateTemp(dir, t.Name())
- require.NoError(t, err)
-
- _, err = io.WriteString(f, defaultText)
- require.NoError(t, err)
-
- cachePath := f.Name()
-
- testCases := []struct {
- name string
- cachePath string
- text string
- wantText string
- wantErrMsg string
- timeout time.Duration
- code int
- expectReq bool
- }{{
- name: "success",
- cachePath: cachePath,
- text: defaultText,
- wantText: defaultText,
- wantErrMsg: "",
- timeout: testTimeout,
- code: http.StatusOK,
- expectReq: true,
- }, {
- name: "not_found",
- cachePath: cachePath,
- text: defaultText,
- wantText: "",
- wantErrMsg: `server "": status code error: expected 200, got 404`,
- timeout: testTimeout,
- code: http.StatusNotFound,
- expectReq: true,
- }, {
- name: "timeout",
- cachePath: cachePath,
- text: defaultText,
- wantText: "",
- wantErrMsg: `requesting: Get "` + u.String() + `": context deadline exceeded`,
- timeout: 0,
- code: http.StatusOK,
- // Context deadline errors are returned before any actual HTTP
- // requesting happens.
- expectReq: false,
- }, {
- name: "empty",
- cachePath: cachePath,
- text: "",
- wantText: "",
- wantErrMsg: `server "": empty text, not resetting`,
- timeout: testTimeout,
- code: http.StatusOK,
- expectReq: true,
- }}
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- f := &refreshableFilter{
- http: httpCli,
- url: u,
- id: "test_filter",
- cachePath: tc.cachePath,
- typ: "test filter",
- staleness: testTimeout,
- }
-
- if tc.expectReq {
- codeCh <- tc.code
- textCh <- tc.text
- }
-
- ctx, cancel := context.WithTimeout(context.Background(), tc.timeout)
- defer cancel()
-
- var text string
- text, err = f.refreshFromURL(ctx)
- assert.Equal(t, tc.wantText, text)
- testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
- })
- }
-}
diff --git a/internal/filter/rulelist.go b/internal/filter/rulelist.go
deleted file mode 100644
index ba67be2..0000000
--- a/internal/filter/rulelist.go
+++ /dev/null
@@ -1,239 +0,0 @@
-package filter
-
-import (
- "context"
- "fmt"
- "math/rand"
- "net/netip"
- "path/filepath"
- "sync"
-
- "github.com/AdguardTeam/AdGuardDNS/internal/agd"
- "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
- "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
- "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/resultcache"
- "github.com/AdguardTeam/golibs/log"
- "github.com/AdguardTeam/urlfilter"
- "github.com/AdguardTeam/urlfilter/filterlist"
-)
-
-// Rule-list filter
-
-// ruleListFilter is a DNS request and response filter based on filter rule
-// lists.
-//
-// TODO(a.garipov): Consider adding a separate version that uses a single engine
-// for multiple rule lists and using it to optimize the filtering using default
-// filtering groups.
-type ruleListFilter struct {
- // mu protects engine and ruleList.
- mu *sync.RWMutex
-
- // refr contains data for refreshing the filter.
- refr *refreshableFilter
-
- // engine is the DNS filtering engine.
- engine *urlfilter.DNSEngine
-
- // dnsResCache contains cached results of filtering.
- //
- // TODO(ameshkov): Add metrics for these caches.
- dnsResCache *resultcache.Cache[*urlfilter.DNSResult]
-
- // ruleList is the filtering rule ruleList used by the engine.
- //
- // TODO(a.garipov): Consider making engines in module urlfilter closeable
- // and remove this crutch.
- ruleList filterlist.RuleList
-
- // subID is the additional identifier for custom rule lists and blocked
- // service lists. If id is [agd.FilterListIDBlockedService], it contains an
- // [agd.BlockedServiceID], and if id is [agd.FilterListIDCustom], it
- // contains an [agd.ProfileID].
- subID string
-
- // urlFilterID is the synthetic integer identifier for the urlfilter engine.
- //
- // TODO(a.garipov): Change the type to a string in module urlfilter and
- // remove this crutch.
- urlFilterID int
-}
-
-// newRuleListFilter returns a new DNS request and response filter based on the
-// provided rule list. l must be non-nil. The initial refresh should be called
-// explicitly if necessary.
-func newRuleListFilter(
- l *agd.FilterList,
- fileCacheDir string,
- cacheSize int,
- useCache bool,
-) (flt *ruleListFilter) {
- flt = &ruleListFilter{
- mu: &sync.RWMutex{},
- refr: &refreshableFilter{
- http: agdhttp.NewClient(&agdhttp.ClientConfig{
- Timeout: defaultFilterRefreshTimeout,
- }),
- url: l.URL,
- id: l.ID,
- cachePath: filepath.Join(fileCacheDir, string(l.ID)),
- typ: "rule list",
- staleness: l.RefreshIvl,
- },
- urlFilterID: newURLFilterID(),
- }
-
- if useCache {
- flt.dnsResCache = resultcache.New[*urlfilter.DNSResult](cacheSize)
- }
-
- // Do not set this in the literal above, since flt is nil there.
- flt.refr.resetRules = flt.resetRules
-
- return flt
-}
-
-// newRuleListFltFromStr returns a new DNS request and response filter using the
-// provided rule text and ID.
-func newRuleListFltFromStr(
- text string,
- id agd.FilterListID,
- subID string,
- cacheSize int,
- useCache bool,
-) (flt *ruleListFilter, err error) {
- flt = &ruleListFilter{
- mu: &sync.RWMutex{},
- refr: &refreshableFilter{
- id: id,
- typ: "rule list",
- },
- subID: subID,
- urlFilterID: newURLFilterID(),
- }
-
- if useCache {
- flt.dnsResCache = resultcache.New[*urlfilter.DNSResult](cacheSize)
- }
-
- err = flt.resetRules(text)
- if err != nil {
- return nil, err
- }
-
- return flt, nil
-}
-
-// newURLFilterID returns a new random ID for the urlfilter DNS engine to use.
-func newURLFilterID() (id int) {
- // #nosec G404 -- Do not use cryptographically random ID generation, since
- // these are only used in one place, compFilter.filterMsg, and are not used
- // in any security-sensitive context.
- //
- // Despite the fact that the type of integer filter list IDs in module
- // urlfilter is int, the module actually assumes that the ID is
- // a non-negative integer, or at least not a largely negative one.
- // Otherwise, some of its low-level optimizations seem to break.
- return int(rand.Int31())
-}
-
-// dnsResult returns the result of applying the urlfilter DNS filtering engine.
-// If the request is not filtered, dnsResult returns nil.
-func (f *ruleListFilter) dnsResult(
- cliIP netip.Addr,
- cliName string,
- host string,
- rrType dnsmsg.RRType,
- isAns bool,
-) (dr *urlfilter.DNSResult) {
- var ok bool
- var cacheKey resultcache.Key
-
- // Don't waste resources on computing the cache key if the cache is not
- // enabled.
- useCache := f.dnsResCache != nil
- if useCache {
- cacheKey = resultcache.DefaultKey(host, rrType, isAns)
- dr, ok = f.dnsResCache.Get(cacheKey)
- if ok {
- return dr
- }
- }
-
- dnsReq := &urlfilter.DNSRequest{
- Hostname: host,
- // TODO(a.garipov): Make this a net.IP in module urlfilter.
- ClientIP: cliIP.String(),
- ClientName: cliName,
- DNSType: rrType,
- Answer: isAns,
- }
-
- f.mu.RLock()
- defer f.mu.RUnlock()
-
- dr, ok = f.engine.MatchRequest(dnsReq)
- if !ok && len(dr.NetworkRules) == 0 {
- dr = nil
- }
-
- if useCache {
- f.dnsResCache.Set(cacheKey, dr)
- }
-
- return dr
-}
-
-// Close implements the [io.Closer] interface for *ruleListFilter.
-func (f *ruleListFilter) Close() (err error) {
- f.mu.Lock()
- defer f.mu.Unlock()
-
- if err = f.ruleList.Close(); err != nil {
- return fmt.Errorf("closing rule list %q: %w", f.id(), err)
- }
-
- return nil
-}
-
-// id returns the ID of the rule list.
-func (f *ruleListFilter) id() (fltID agd.FilterListID) {
- return f.refr.id
-}
-
-// refresh reloads the rule list data. If acceptStale is true, do not try to
-// load the list from its URL when there is already a file in the cache
-// directory, regardless of its staleness.
-func (f *ruleListFilter) refresh(ctx context.Context, acceptStale bool) (err error) {
- return f.refr.refresh(ctx, acceptStale)
-}
-
-// resetRules resets the filtering rules.
-func (f *ruleListFilter) resetRules(text string) (err error) {
- // TODO(a.garipov): Add filterlist.BytesRuleList.
- strList := &filterlist.StringRuleList{
- ID: f.urlFilterID,
- RulesText: text,
- IgnoreCosmetic: true,
- }
-
- s, err := filterlist.NewRuleStorage([]filterlist.RuleList{strList})
- if err != nil {
- return fmt.Errorf("creating list storage: %w", err)
- }
-
- f.mu.Lock()
- defer f.mu.Unlock()
-
- f.dnsResCache.Clear()
- f.ruleList = strList
- f.engine = urlfilter.NewDNSEngine(s)
-
- if f.subID != "" {
- log.Info("filter %s/%s: reset %d rules", f.id(), f.subID, f.engine.RulesCount)
- } else {
- log.Info("filter %s: reset %d rules", f.id(), f.engine.RulesCount)
- }
-
- return nil
-}
diff --git a/internal/filter/rulelist_internal_test.go b/internal/filter/rulelist_internal_test.go
deleted file mode 100644
index 9347a93..0000000
--- a/internal/filter/rulelist_internal_test.go
+++ /dev/null
@@ -1,36 +0,0 @@
-package filter
-
-import (
- "testing"
-
- "github.com/miekg/dns"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func TestRuleListFilter_dnsResult_cache(t *testing.T) {
- rl, err := newRuleListFltFromStr("||example.com^", "fl1", "", 100, true)
- require.NoError(t, err)
-
- t.Run("blocked", func(t *testing.T) {
- dr := rl.dnsResult(testRemoteIP, "", testReqHost, dns.TypeA, false)
- require.NotNil(t, dr)
-
- assert.Len(t, dr.NetworkRules, 1)
-
- cachedDR := rl.dnsResult(testRemoteIP, "", testReqHost, dns.TypeA, false)
- require.NotNil(t, cachedDR)
-
- assert.Same(t, dr, cachedDR)
- })
-
- t.Run("none", func(t *testing.T) {
- const otherHost = "other.example"
-
- dr := rl.dnsResult(testRemoteIP, "", otherHost, dns.TypeA, false)
- assert.Nil(t, dr)
-
- cachedDR := rl.dnsResult(testRemoteIP, "", otherHost, dns.TypeA, false)
- assert.Nil(t, cachedDR)
- })
-}
diff --git a/internal/filter/rulelist_test.go b/internal/filter/rulelist_test.go
deleted file mode 100644
index a5217e1..0000000
--- a/internal/filter/rulelist_test.go
+++ /dev/null
@@ -1,404 +0,0 @@
-package filter_test
-
-import (
- "context"
- "testing"
-
- "github.com/AdguardTeam/AdGuardDNS/internal/agd"
- "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
- "github.com/AdguardTeam/AdGuardDNS/internal/filter"
- "github.com/AdguardTeam/golibs/testutil"
- "github.com/miekg/dns"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-// TODO(a.garipov): Try to turn these into table-driven tests.
-
-func TestStorage_FilterFromContext_ruleList_request(t *testing.T) {
- c := prepareConf(t)
-
- c.ErrColl = &agdtest.ErrorCollector{
- OnCollect: func(_ context.Context, err error) { panic("not implemented") },
- }
-
- s, err := filter.NewDefaultStorage(c)
- require.NoError(t, err)
-
- g := &agd.FilteringGroup{
- ID: "default",
- RuleListIDs: []agd.FilterListID{testFilterID},
- RuleListsEnabled: true,
- }
-
- p := &agd.Profile{
- RuleListIDs: []agd.FilterListID{testFilterID},
- FilteringEnabled: true,
- RuleListsEnabled: true,
- }
-
- t.Run("blocked", func(t *testing.T) {
- req := &dns.Msg{
- Question: []dns.Question{{
- Name: blockedFQDN,
- Qtype: dns.TypeA,
- Qclass: dns.ClassINET,
- }},
- }
-
- ri := newReqInfo(g, nil, blockedHost, clientIP, dns.TypeA)
- ctx := agd.ContextWithRequestInfo(context.Background(), ri)
-
- f := s.FilterFromContext(ctx, ri)
- require.NotNil(t, f)
- testutil.CleanupAndRequireSuccess(t, f.Close)
-
- var r filter.Result
- r, err = f.FilterRequest(ctx, req, ri)
- require.NoError(t, err)
-
- rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r)
-
- assert.Contains(t, rb.Rule, blockedHost)
- assert.Equal(t, rb.List, testFilterID)
- })
-
- t.Run("allowed", func(t *testing.T) {
- req := &dns.Msg{
- Question: []dns.Question{{
- Name: allowedFQDN,
- Qtype: dns.TypeA,
- Qclass: dns.ClassINET,
- }},
- }
-
- ri := newReqInfo(g, nil, allowedHost, clientIP, dns.TypeA)
- ctx := agd.ContextWithRequestInfo(context.Background(), ri)
-
- f := s.FilterFromContext(ctx, ri)
- require.NotNil(t, f)
- testutil.CleanupAndRequireSuccess(t, f.Close)
-
- var r filter.Result
- r, err = f.FilterRequest(ctx, req, ri)
- require.NoError(t, err)
-
- ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r)
-
- assert.Contains(t, ra.Rule, allowedHost)
- assert.Equal(t, ra.List, testFilterID)
- })
-
- t.Run("blocked_client", func(t *testing.T) {
- req := &dns.Msg{
- Question: []dns.Question{{
- Name: blockedClientFQDN,
- Qtype: dns.TypeA,
- Qclass: dns.ClassINET,
- }},
- }
-
- ri := newReqInfo(g, nil, blockedClientHost, clientIP, dns.TypeA)
- ctx := agd.ContextWithRequestInfo(context.Background(), ri)
-
- f := s.FilterFromContext(ctx, ri)
- require.NotNil(t, f)
- testutil.CleanupAndRequireSuccess(t, f.Close)
-
- var r filter.Result
- r, err = f.FilterRequest(ctx, req, ri)
- require.NoError(t, err)
-
- rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r)
-
- assert.Contains(t, rb.Rule, blockedClientHost)
- assert.Equal(t, rb.List, testFilterID)
- })
-
- t.Run("allowed_client", func(t *testing.T) {
- req := &dns.Msg{
- Question: []dns.Question{{
- Name: allowedClientFQDN,
- Qtype: dns.TypeA,
- Qclass: dns.ClassINET,
- }},
- }
-
- ri := newReqInfo(g, nil, allowedClientHost, clientIP, dns.TypeA)
- ctx := agd.ContextWithRequestInfo(context.Background(), ri)
-
- f := s.FilterFromContext(ctx, ri)
- require.NotNil(t, f)
- testutil.CleanupAndRequireSuccess(t, f.Close)
-
- var r filter.Result
- r, err = f.FilterRequest(ctx, req, ri)
- require.NoError(t, err)
-
- ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r)
-
- assert.Contains(t, ra.Rule, allowedClientHost)
- assert.Equal(t, ra.List, testFilterID)
- })
-
- t.Run("blocked_device", func(t *testing.T) {
- req := &dns.Msg{
- Question: []dns.Question{{
- Name: blockedDeviceFQDN,
- Qtype: dns.TypeA,
- Qclass: dns.ClassINET,
- }},
- }
-
- ri := newReqInfo(g, p, blockedDeviceHost, deviceIP, dns.TypeA)
- ctx := agd.ContextWithRequestInfo(context.Background(), ri)
-
- f := s.FilterFromContext(ctx, ri)
- require.NotNil(t, f)
- testutil.CleanupAndRequireSuccess(t, f.Close)
-
- var r filter.Result
- r, err = f.FilterRequest(ctx, req, ri)
- require.NoError(t, err)
-
- rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r)
-
- assert.Contains(t, rb.Rule, blockedDeviceHost)
- assert.Equal(t, rb.List, testFilterID)
- })
-
- t.Run("allowed_device", func(t *testing.T) {
- req := &dns.Msg{
- Question: []dns.Question{{
- Name: allowedDeviceFQDN,
- Qtype: dns.TypeA,
- Qclass: dns.ClassINET,
- }},
- }
-
- ri := newReqInfo(g, p, allowedDeviceHost, deviceIP, dns.TypeA)
- ctx := agd.ContextWithRequestInfo(context.Background(), ri)
-
- f := s.FilterFromContext(ctx, ri)
- require.NotNil(t, f)
- testutil.CleanupAndRequireSuccess(t, f.Close)
-
- var r filter.Result
- r, err = f.FilterRequest(ctx, req, ri)
- require.NoError(t, err)
-
- ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r)
-
- assert.Contains(t, ra.Rule, allowedDeviceHost)
- assert.Equal(t, ra.List, testFilterID)
- })
-
- t.Run("none", func(t *testing.T) {
- req := &dns.Msg{
- Question: []dns.Question{{
- Name: otherNetFQDN,
- Qtype: dns.TypeA,
- Qclass: dns.ClassINET,
- }},
- }
-
- ri := newReqInfo(g, nil, otherNetHost, clientIP, dns.TypeA)
- ctx := agd.ContextWithRequestInfo(context.Background(), ri)
-
- f := s.FilterFromContext(ctx, ri)
- require.NotNil(t, f)
- testutil.CleanupAndRequireSuccess(t, f.Close)
-
- var r filter.Result
- r, err = f.FilterRequest(ctx, req, ri)
- require.NoError(t, err)
-
- assert.Nil(t, r)
- })
-}
-
-func TestStorage_FilterFromContext_ruleList_response(t *testing.T) {
- c := prepareConf(t)
-
- c.ErrColl = &agdtest.ErrorCollector{
- OnCollect: func(_ context.Context, err error) { panic("not implemented") },
- }
-
- s, err := filter.NewDefaultStorage(c)
- require.NoError(t, err)
-
- g := &agd.FilteringGroup{
- ID: "default",
- RuleListIDs: []agd.FilterListID{testFilterID},
- RuleListsEnabled: true,
- }
-
- ri := newReqInfo(g, nil, otherNetHost, clientIP, dns.TypeA)
- ctx := agd.ContextWithRequestInfo(context.Background(), ri)
-
- f := s.FilterFromContext(ctx, ri)
- require.NotNil(t, f)
- testutil.CleanupAndRequireSuccess(t, f.Close)
-
- question := []dns.Question{{
- Name: otherNetFQDN,
- Qtype: dns.TypeA,
- Qclass: dns.ClassINET,
- }}
-
- t.Run("blocked_a", func(t *testing.T) {
- resp := &dns.Msg{
- Question: question,
- Answer: []dns.RR{&dns.A{
- A: blockedIP4,
- }},
- }
-
- var r filter.Result
- r, err = f.FilterResponse(ctx, resp, ri)
- require.NoError(t, err)
-
- rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r)
-
- assert.Contains(t, rb.Rule, blockedIP4.String())
- assert.Equal(t, rb.List, testFilterID)
- })
-
- t.Run("allowed_a", func(t *testing.T) {
- resp := &dns.Msg{
- Question: question,
- Answer: []dns.RR{&dns.A{
- A: allowedIP4,
- }},
- }
-
- var r filter.Result
- r, err = f.FilterResponse(ctx, resp, ri)
- require.NoError(t, err)
-
- ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r)
-
- assert.Contains(t, ra.Rule, allowedIP4.String())
- assert.Equal(t, ra.List, testFilterID)
- })
-
- t.Run("blocked_cname", func(t *testing.T) {
- resp := &dns.Msg{
- Question: question,
- Answer: []dns.RR{&dns.CNAME{
- Target: blockedFQDN,
- }},
- }
-
- var r filter.Result
- r, err = f.FilterResponse(ctx, resp, ri)
- require.NoError(t, err)
-
- rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r)
-
- assert.Contains(t, rb.Rule, blockedHost)
- assert.Equal(t, rb.List, testFilterID)
- })
-
- t.Run("allowed_cname", func(t *testing.T) {
- resp := &dns.Msg{
- Question: question,
- Answer: []dns.RR{&dns.CNAME{
- Target: allowedFQDN,
- }},
- }
-
- var r filter.Result
- r, err = f.FilterResponse(ctx, resp, ri)
- require.NoError(t, err)
-
- ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r)
-
- assert.Contains(t, ra.Rule, allowedHost)
- assert.Equal(t, ra.List, testFilterID)
- })
-
- t.Run("blocked_client", func(t *testing.T) {
- resp := &dns.Msg{
- Question: question,
- Answer: []dns.RR{&dns.CNAME{
- Target: blockedClientFQDN,
- }},
- }
-
- var r filter.Result
- r, err = f.FilterResponse(ctx, resp, ri)
- require.NoError(t, err)
-
- rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r)
-
- assert.Contains(t, rb.Rule, blockedClientHost)
- assert.Equal(t, rb.List, testFilterID)
- })
-
- t.Run("allowed_client", func(t *testing.T) {
- req := &dns.Msg{
- Question: question,
- Answer: []dns.RR{&dns.CNAME{
- Target: allowedClientFQDN,
- }},
- }
-
- var r filter.Result
- r, err = f.FilterResponse(ctx, req, ri)
- require.NoError(t, err)
-
- ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r)
-
- assert.Contains(t, ra.Rule, allowedClientHost)
- assert.Equal(t, ra.List, testFilterID)
- })
-
- t.Run("exception_cname", func(t *testing.T) {
- req := &dns.Msg{
- Question: question,
- Answer: []dns.RR{&dns.CNAME{
- Target: "cname.exception.",
- }},
- }
-
- var r filter.Result
- r, err = f.FilterResponse(ctx, req, ri)
- require.NoError(t, err)
-
- assert.Nil(t, r)
- })
-
- t.Run("exception_cname_blocked", func(t *testing.T) {
- req := &dns.Msg{
- Question: question,
- Answer: []dns.RR{&dns.CNAME{
- Target: "cname.blocked.",
- }},
- }
-
- var r filter.Result
- r, err = f.FilterResponse(ctx, req, ri)
- require.NoError(t, err)
-
- rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r)
-
- assert.Contains(t, rb.Rule, "cname.blocked")
- assert.Equal(t, rb.List, testFilterID)
- })
-
- t.Run("none", func(t *testing.T) {
- req := &dns.Msg{
- Question: question,
- Answer: []dns.RR{&dns.CNAME{
- Target: otherOrgFQDN,
- }},
- }
-
- var r filter.Result
- r, err = f.FilterRequest(ctx, req, ri)
- require.NoError(t, err)
-
- assert.Nil(t, r)
- })
-}
diff --git a/internal/filter/safebrowsing.go b/internal/filter/safebrowsing.go
deleted file mode 100644
index 2087e0a..0000000
--- a/internal/filter/safebrowsing.go
+++ /dev/null
@@ -1,114 +0,0 @@
-package filter
-
-import (
- "context"
- "encoding/hex"
- "fmt"
- "strings"
-
- "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
- "github.com/AdguardTeam/golibs/log"
- "github.com/AdguardTeam/golibs/stringutil"
-)
-
-// Safe Browsing TXT Record Server
-
-// SafeBrowsingServer is a safe browsing server that responds to TXT DNS queries
-// to known domains.
-//
-// TODO(a.garipov): Consider making an interface to simplify testing.
-type SafeBrowsingServer struct {
- generalHashes *hashstorage.Storage
- adultBlockingHashes *hashstorage.Storage
-}
-
-// NewSafeBrowsingServer returns a new safe browsing DNS server.
-func NewSafeBrowsingServer(general, adultBlocking *hashstorage.Storage) (f *SafeBrowsingServer) {
- return &SafeBrowsingServer{
- generalHashes: general,
- adultBlockingHashes: adultBlocking,
- }
-}
-
-// Default safe browsing host suffixes.
-//
-// TODO(ameshkov): Consider making these configurable.
-const (
- GeneralTXTSuffix = ".sb.dns.adguard.com"
- AdultBlockingTXTSuffix = ".pc.dns.adguard.com"
-)
-
-// Hashes returns the matched hashes if the host matched one of the domain names
-// in srv.
-//
-// TODO(a.garipov): Use the context for logging etc.
-func (srv *SafeBrowsingServer) Hashes(
- _ context.Context,
- host string,
-) (hashes []string, matched bool, err error) {
- // TODO(a.garipov): Remove this if SafeBrowsingServer becomes an interface.
- if srv == nil {
- return nil, false, nil
- }
-
- var prefixesStr string
- var strg *hashstorage.Storage
- if strings.HasSuffix(host, GeneralTXTSuffix) {
- prefixesStr = host[:len(host)-len(GeneralTXTSuffix)]
- strg = srv.generalHashes
- } else if strings.HasSuffix(host, AdultBlockingTXTSuffix) {
- prefixesStr = host[:len(host)-len(AdultBlockingTXTSuffix)]
- strg = srv.adultBlockingHashes
- } else {
- return nil, false, nil
- }
-
- log.Debug("safe browsing txt srv: got prefixes string %q", prefixesStr)
-
- hashPrefixes, err := hashPrefixesFromStr(prefixesStr)
- if err != nil {
- return nil, false, err
- }
-
- return strg.Hashes(hashPrefixes), true, nil
-}
-
-// legacyPrefixEncLen is the encoded length of a legacy hash.
-const legacyPrefixEncLen = 8
-
-// hashPrefixesFromStr returns hash prefixes from a dot-separated string.
-func hashPrefixesFromStr(prefixesStr string) (hashPrefixes []hashstorage.Prefix, err error) {
- if prefixesStr == "" {
- return nil, nil
- }
-
- prefixSet := stringutil.NewSet()
- prefixStrs := strings.Split(prefixesStr, ".")
- for _, s := range prefixStrs {
- if len(s) != hashstorage.PrefixEncLen {
- // Some legacy clients send eight-character hashes instead of
- // four-character ones. For now, remove the final four characters.
- //
- // TODO(a.garipov): Either remove this crutch or support such
- // prefixes better.
- if len(s) == legacyPrefixEncLen {
- s = s[:hashstorage.PrefixEncLen]
- } else {
- return nil, fmt.Errorf("bad hash len for %q", s)
- }
- }
-
- prefixSet.Add(s)
- }
-
- hashPrefixes = make([]hashstorage.Prefix, prefixSet.Len())
- prefixStrs = prefixSet.Values()
- for i, s := range prefixStrs {
- _, err = hex.Decode(hashPrefixes[i][:], []byte(s))
- if err != nil {
- return nil, fmt.Errorf("bad hash encoding for %q", s)
- }
- }
-
- return hashPrefixes, nil
-}
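
The removed `hashPrefixesFromStr` above parses the dot-separated hash prefixes that clients embed in the TXT query name, truncating legacy eight-character prefixes and deduplicating before hex-decoding. Below is a stdlib-only sketch of the same parsing steps; `prefix` and `prefixesFromStr` are hypothetical stand-ins for `hashstorage.Prefix` and the removed function.

```go
package main

import (
	"encoding/hex"
	"fmt"
	"strings"
)

// Encoded prefix lengths, mirroring the removed constants: a current prefix
// is four hex characters (two bytes), a legacy one is eight.
const (
	prefixEncLen       = 4
	legacyPrefixEncLen = 8
)

// prefix is a hypothetical stand-in for hashstorage.Prefix.
type prefix [prefixEncLen / 2]byte

// prefixesFromStr decodes a dot-separated string of hex-encoded hash
// prefixes, truncating legacy eight-character prefixes and deduplicating
// before decoding.
func prefixesFromStr(s string) (prefixes []prefix, err error) {
	if s == "" {
		return nil, nil
	}

	seen := map[string]struct{}{}
	for _, p := range strings.Split(s, ".") {
		if len(p) != prefixEncLen {
			if len(p) != legacyPrefixEncLen {
				return nil, fmt.Errorf("bad hash len for %q", p)
			}

			// Legacy clients send eight characters; keep the first four.
			p = p[:prefixEncLen]
		}

		seen[p] = struct{}{}
	}

	for p := range seen {
		var pref prefix
		if _, err = hex.Decode(pref[:], []byte(p)); err != nil {
			return nil, fmt.Errorf("bad hash encoding for %q", p)
		}

		prefixes = append(prefixes, pref)
	}

	return prefixes, nil
}

func main() {
	// "5c1e28ec" is a legacy eight-character prefix; it is truncated to
	// "5c1e" before decoding.
	got, err := prefixesFromStr("28ec.5c1e28ec.4a7b")
	fmt.Println(len(got), err) // 3 <nil>
}
```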
diff --git a/internal/filter/safesearch.go b/internal/filter/safesearch.go
deleted file mode 100644
index f691736..0000000
--- a/internal/filter/safesearch.go
+++ /dev/null
@@ -1,191 +0,0 @@
-package filter
-
-import (
- "context"
- "fmt"
- "net"
- "time"
-
- "github.com/AdguardTeam/AdGuardDNS/internal/agd"
- "github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
- "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
- "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/resultcache"
- "github.com/AdguardTeam/AdGuardDNS/internal/optlog"
- "github.com/AdguardTeam/golibs/netutil"
- "github.com/AdguardTeam/urlfilter"
- "github.com/miekg/dns"
-)
-
-// Safe search
-
-// safeSearch is a filter that enforces safe search.
-type safeSearch struct {
- // resCache contains cached results.
- resCache *resultcache.Cache[*ResultModified]
-
- // flt is used to filter requests.
- flt *ruleListFilter
-
- // resolver resolves IP addresses.
- resolver agdnet.Resolver
-
- // errColl is used to report rare errors.
- errColl agd.ErrorCollector
-}
-
-// safeSearchConfig contains configuration for the safe search filter.
-type safeSearchConfig struct {
- list *agd.FilterList
- resolver agdnet.Resolver
- errColl agd.ErrorCollector
- cacheDir string
- //lint:ignore U1000 TODO(a.garipov): Currently unused. See AGDNS-398.
- ttl time.Duration
- cacheSize int
-}
-
-// newSafeSearch returns a new safe search filter. c must not be nil. The
-// initial refresh should be called explicitly if necessary.
-func newSafeSearch(c *safeSearchConfig) (f *safeSearch) {
- return &safeSearch{
- resCache: resultcache.New[*ResultModified](c.cacheSize),
- // Don't use the rule list cache, since safeSearch already has its own.
- flt: newRuleListFilter(c.list, c.cacheDir, 0, false),
- resolver: c.resolver,
- errColl: c.errColl,
- }
-}
-
-// type check
-var _ qtHostFilter = (*safeSearch)(nil)
-
-// filterReq implements the qtHostFilter interface for *safeSearch. It modifies
-// the response if host matches f.
-func (f *safeSearch) filterReq(
- ctx context.Context,
- ri *agd.RequestInfo,
- req *dns.Msg,
-) (r Result, err error) {
- qt := ri.QType
- fam := netutil.AddrFamilyFromRRType(qt)
- if fam == netutil.AddrFamilyNone {
- return nil, nil
- }
-
- host := ri.Host
- cacheKey := resultcache.DefaultKey(host, qt, false)
- repHost, ok := f.safeSearchHost(host, qt)
- if !ok {
- optlog.Debug2("filter %s: host %q is not on the list", f.flt.id(), host)
-
- f.resCache.Set(cacheKey, nil)
-
- return nil, nil
- }
-
- optlog.Debug2("filter %s: found host %q", f.flt.id(), repHost)
-
- rm, ok := f.resCache.Get(cacheKey)
- if ok {
- if rm == nil {
- // Return nil explicitly instead of modifying CloneForReq to return
- // nil if the result is nil to avoid a “non-nil nil” value.
- return nil, nil
- }
-
- return rm.CloneForReq(req), nil
- }
-
- ctx, cancel := context.WithTimeout(ctx, defaultResolveTimeout)
- defer cancel()
-
- var result *dns.Msg
- ips, err := f.resolver.LookupIP(ctx, fam, repHost)
- if err != nil {
- agd.Collectf(ctx, f.errColl, "filter %s: resolving: %w", f.flt.id(), err)
-
- result = ri.Messages.NewMsgSERVFAIL(req)
- } else {
- result, err = ri.Messages.NewIPRespMsg(req, ips...)
- if err != nil {
- return nil, fmt.Errorf("filter %s: creating modified result: %w", f.flt.id(), err)
- }
- }
-
- rm = &ResultModified{
- Msg: result,
- List: f.flt.id(),
- Rule: agd.FilterRuleText(host),
- }
-
- // Copy the result to make sure that modifications to the result message
- // down the pipeline don't interfere with the cached value.
- //
- // See AGDNS-359.
- f.resCache.Set(cacheKey, rm.Clone())
-
- return rm, nil
-}
-
-// safeSearchHost returns the replacement host for the given host and question
-// type, if any. qt should be either [dns.TypeA] or [dns.TypeAAAA].
-func (f *safeSearch) safeSearchHost(host string, qt dnsmsg.RRType) (ssHost string, ok bool) {
- dnsReq := &urlfilter.DNSRequest{
- Hostname: host,
- DNSType: qt,
- Answer: false,
- }
-
- f.flt.mu.RLock()
- defer f.flt.mu.RUnlock()
-
- // Omit matching the result since it's always false for rewrite rules.
- dr, _ := f.flt.engine.MatchRequest(dnsReq)
- if dr == nil {
- return "", false
- }
-
- for _, nr := range dr.DNSRewrites() {
- drw := nr.DNSRewrite
- if drw.RCode != dns.RcodeSuccess {
- continue
- }
-
- if nc := drw.NewCNAME; nc != "" {
- return nc, true
- }
-
- // All the rules in safe search rule lists are expected to have either
- // A/AAAA or CNAME type.
- switch drw.RRType {
- case dns.TypeA, dns.TypeAAAA:
- return drw.Value.(net.IP).String(), true
- default:
- continue
- }
- }
-
- return "", false
-}
-
-// name implements the qtHostFilter interface for *safeSearch.
-func (f *safeSearch) name() (n string) {
- if f == nil || f.flt == nil {
- return ""
- }
-
- return string(f.flt.id())
-}
-
-// refresh reloads the rule list data. If acceptStale is true, and the cache
-// file exists, the data is read from there regardless of its staleness.
-func (f *safeSearch) refresh(ctx context.Context, acceptStale bool) (err error) {
- err = f.flt.refresh(ctx, acceptStale)
- if err != nil {
- return err
- }
-
- f.resCache.Clear()
-
- return nil
-}
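
One detail worth calling out in the removed safe search filter is the caching rule from AGDNS-359: the result is cloned before it is put into the cache, so later modifications to the message travelling down the DNS pipeline cannot corrupt the cached copy. Here is a minimal sketch of that idea with `github.com/miekg/dns`, assuming a trivial host-keyed map instead of the real LRU result cache; `respCache` and its methods are illustrative names only.

```go
package main

import (
	"fmt"

	"github.com/miekg/dns"
)

// respCache is a trivial host-keyed cache of rewritten responses; the real
// filter keys its LRU cache by host, question type, and answer flag.
type respCache map[string]*dns.Msg

// set stores a deep copy of msg so that later modifications to the message
// handed back to the DNS pipeline cannot corrupt the cached value.
func (c respCache) set(host string, msg *dns.Msg) {
	c[host] = msg.Copy()
}

// get returns a fresh copy of the cached message, if any, for the same
// reason: the caller is free to mutate what it receives.
func (c respCache) get(host string) (msg *dns.Msg, ok bool) {
	msg, ok = c[host]
	if !ok {
		return nil, false
	}

	return msg.Copy(), true
}

func main() {
	cache := respCache{}

	resp := &dns.Msg{}
	resp.SetQuestion("forcesafesearch.example.", dns.TypeA)
	cache.set("www.example.", resp)

	// Mutating the original after caching does not affect the cached copy.
	resp.Question = nil

	cached, _ := cache.get("www.example.")
	fmt.Println(len(cached.Question)) // 1
}
```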
diff --git a/internal/filter/safesearch_test.go b/internal/filter/safesearch_test.go
deleted file mode 100644
index a295e81..0000000
--- a/internal/filter/safesearch_test.go
+++ /dev/null
@@ -1,137 +0,0 @@
-package filter_test
-
-import (
- "context"
- "net"
- "testing"
-
- "github.com/AdguardTeam/AdGuardDNS/internal/agd"
- "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
- "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
- "github.com/AdguardTeam/AdGuardDNS/internal/filter"
- "github.com/AdguardTeam/golibs/netutil"
- "github.com/AdguardTeam/golibs/testutil"
- "github.com/miekg/dns"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func TestStorage_FilterFromContext_safeSearch(t *testing.T) {
- numLookupIP := 0
- resolver := &agdtest.Resolver{
- OnLookupIP: func(
- _ context.Context,
- fam netutil.AddrFamily,
- _ string,
- ) (ips []net.IP, err error) {
- numLookupIP++
-
- if fam == netutil.AddrFamilyIPv4 {
- return []net.IP{safeSearchIPRespIP4}, nil
- }
-
- return []net.IP{safeSearchIPRespIP6}, nil
- },
- }
-
- c := prepareConf(t)
-
- c.ErrColl = &agdtest.ErrorCollector{
- OnCollect: func(_ context.Context, err error) { panic("not implemented") },
- }
-
- c.Resolver = resolver
-
- s, err := filter.NewDefaultStorage(c)
- require.NoError(t, err)
-
- g := &agd.FilteringGroup{
- ID: "default",
- ParentalEnabled: true,
- GeneralSafeSearch: true,
- }
-
- testCases := []struct {
- name string
- host string
- wantIP net.IP
- rrtype uint16
- wantLookups int
- }{{
- name: "ip4",
- host: safeSearchIPHost,
- wantIP: safeSearchIPRespIP4,
- rrtype: dns.TypeA,
- wantLookups: 1,
- }, {
- name: "ip6",
- host: safeSearchIPHost,
- wantIP: safeSearchIPRespIP6,
- rrtype: dns.TypeAAAA,
- wantLookups: 1,
- }, {
- name: "host_ip4",
- host: safeSearchHost,
- wantIP: safeSearchIPRespIP4,
- rrtype: dns.TypeA,
- wantLookups: 1,
- }, {
- name: "host_ip6",
- host: safeSearchHost,
- wantIP: safeSearchIPRespIP6,
- rrtype: dns.TypeAAAA,
- wantLookups: 1,
- }}
-
- for _, tc := range testCases {
- numLookupIP = 0
- req := dnsservertest.CreateMessage(tc.host, tc.rrtype)
-
- t.Run(tc.name, func(t *testing.T) {
- ri := newReqInfo(g, nil, tc.host, clientIP, tc.rrtype)
- ctx := agd.ContextWithRequestInfo(context.Background(), ri)
-
- f := s.FilterFromContext(ctx, ri)
- require.NotNil(t, f)
- testutil.CleanupAndRequireSuccess(t, f.Close)
-
- var r filter.Result
- r, err = f.FilterRequest(ctx, req, ri)
- require.NoError(t, err)
-
- assert.Equal(t, tc.wantLookups, numLookupIP)
-
- rm, ok := r.(*filter.ResultModified)
- require.True(t, ok)
-
- assert.Contains(t, rm.Rule, tc.host)
- assert.Equal(t, rm.List, agd.FilterListIDGeneralSafeSearch)
-
- res := rm.Msg
- require.NotNil(t, res)
-
- if tc.wantIP == nil {
- assert.Nil(t, res.Answer)
-
- return
- }
-
- require.Len(t, res.Answer, 1)
-
- switch tc.rrtype {
- case dns.TypeA:
- a, aok := res.Answer[0].(*dns.A)
- require.True(t, aok)
-
- assert.Equal(t, tc.wantIP, a.A)
- case dns.TypeAAAA:
- aaaa, aaaaok := res.Answer[0].(*dns.AAAA)
- require.True(t, aaaaok)
-
- assert.Equal(t, tc.wantIP, aaaa.AAAA)
- default:
- t.Fatalf("unexpected question type %d", tc.rrtype)
- }
- })
- }
-}
diff --git a/internal/filter/serviceblocker.go b/internal/filter/serviceblocker.go
deleted file mode 100644
index ccb1f60..0000000
--- a/internal/filter/serviceblocker.go
+++ /dev/null
@@ -1,223 +0,0 @@
-package filter
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "net/http"
- "net/url"
- "strings"
- "sync"
-
- "github.com/AdguardTeam/AdGuardDNS/internal/agd"
- "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
- "github.com/AdguardTeam/AdGuardDNS/internal/metrics"
- "github.com/AdguardTeam/golibs/errors"
- "github.com/AdguardTeam/golibs/log"
- "github.com/prometheus/client_golang/prometheus"
-)
-
-// Service Blocking Filter
-
-// serviceBlocker is a filter that blocks services based on the settings in
-// profile.
-//
-// TODO(a.garipov): Add tests.
-type serviceBlocker struct {
- // url is the URL from which the services are fetched.
- url *url.URL
-
- // http is the HTTP client used to refresh the filter.
- http *agdhttp.Client
-
- // mu protects services.
- mu *sync.RWMutex
-
- // services is an ID to filter mapping.
- services serviceRuleLists
-
- // errColl used to collect non-critical and rare errors.
- errColl agd.ErrorCollector
-}
-
-// serviceRuleLists is convenient alias for a ID to filter mapping.
-type serviceRuleLists = map[agd.BlockedServiceID]*ruleListFilter
-
-// newServiceBlocker returns a fully initialized service blocker.
-func newServiceBlocker(indexURL *url.URL, errColl agd.ErrorCollector) (b *serviceBlocker) {
- return &serviceBlocker{
- url: indexURL,
- http: agdhttp.NewClient(&agdhttp.ClientConfig{
- Timeout: defaultFilterRefreshTimeout,
- }),
- mu: &sync.RWMutex{},
- errColl: errColl,
- }
-}
-
-// ruleLists returns the rule list filters for the given blocked service IDs.
-// The order of the elements in rls is undefined.
-func (b *serviceBlocker) ruleLists(ids []agd.BlockedServiceID) (rls []*ruleListFilter) {
- if len(ids) == 0 {
- return nil
- }
-
- b.mu.RLock()
- defer b.mu.RUnlock()
-
- for _, id := range ids {
- rl := b.services[id]
- if rl == nil {
- log.Info("service filter: no service with id %q", id)
-
- continue
- }
-
- rls = append(rls, rl)
- }
-
- return rls
-}
-
-// refresh loads new service data from the index URL.
-func (b *serviceBlocker) refresh(
- ctx context.Context,
- cacheSize int,
- useCache bool,
-) (err error) {
- // Report the services update to prometheus.
- promLabels := prometheus.Labels{
- "filter": string(agd.FilterListIDBlockedService),
- }
-
- defer func() {
- if err != nil {
- agd.Collectf(ctx, b.errColl, "refreshing blocked services: %w", err)
- metrics.FilterUpdatedStatus.With(promLabels).Set(0)
- }
- }()
-
- resp, err := b.loadIndex(ctx)
- if err != nil {
- // Don't wrap the error, because it's informative enough as is.
- return err
- }
-
- services, err := resp.toInternal(cacheSize, useCache)
- if err != nil {
- // Don't wrap the error, because it's informative enough as is.
- return err
- }
-
- b.mu.Lock()
- defer b.mu.Unlock()
-
- b.services = services
-
- count := 0
- for _, s := range services {
- count += s.engine.RulesCount
- }
- metrics.FilterRulesTotal.With(promLabels).Set(float64(count))
- metrics.FilterUpdatedTime.With(promLabels).SetToCurrentTime()
- metrics.FilterUpdatedStatus.With(promLabels).Set(1)
-
- return nil
-}
-
-// loadIndex fetches, decodes, and returns the blocked service index data.
-func (b *serviceBlocker) loadIndex(ctx context.Context) (resp *svcIndexResp, err error) {
- defer func() { err = errors.Annotate(err, "loading blocked service index from %q: %w", b.url) }()
-
- httpResp, err := b.http.Get(ctx, b.url)
- if err != nil {
- return nil, fmt.Errorf("requesting: %w", err)
- }
- defer func() { err = errors.WithDeferred(err, httpResp.Body.Close()) }()
-
- err = agdhttp.CheckStatus(httpResp, http.StatusOK)
- if err != nil {
- // Don't wrap the error, because it's informative enough as is.
- return nil, err
- }
-
- resp = &svcIndexResp{}
- err = json.NewDecoder(httpResp.Body).Decode(resp)
- if err != nil {
- return nil, agdhttp.WrapServerError(
- fmt.Errorf("decoding: %w", err),
- httpResp,
- )
- }
-
- log.Debug("service filter: loaded index with %d blocked services", len(resp.BlockedServices))
-
- return resp, nil
-}
-
-// svcIndexResp is the struct for the JSON response from a blocked service index
-// API.
-type svcIndexResp struct {
- BlockedServices []*svcIndexRespService `json:"blocked_services"`
-}
-
-// toInternal converts the services from the index to serviceRuleLists.
-func (r *svcIndexResp) toInternal(
- cacheSize int,
- useCache bool,
-) (services serviceRuleLists, err error) {
- l := len(r.BlockedServices)
- if l == 0 {
- return nil, nil
- }
-
- services = make(serviceRuleLists, l)
- errs := make([]error, len(r.BlockedServices))
- for i, svc := range r.BlockedServices {
- var id agd.BlockedServiceID
- id, err = agd.NewBlockedServiceID(svc.ID)
- if err != nil {
- errs[i] = fmt.Errorf("service at index %d: validating id: %w", i, err)
-
- continue
- }
-
- if len(svc.Rules) == 0 {
- log.Info("service filter: no rules for service with id %s", id)
-
- continue
- }
-
- text := strings.Join(svc.Rules, "\n")
-
- var rl *ruleListFilter
- rl, err = newRuleListFltFromStr(
- text,
- agd.FilterListIDBlockedService,
- svc.ID,
- cacheSize,
- useCache,
- )
- if err != nil {
- errs[i] = fmt.Errorf("compiling %s: %w", svc.ID, err)
-
- continue
- }
-
- services[id] = rl
- }
-
- err = errors.Join(errs...)
- if err != nil {
- return nil, fmt.Errorf("converting blocked services: %w", err)
- }
-
- return services, nil
-}
-
-// svcIndexRespService is the struct for a filter from the JSON response from
-// a blocked service index API.
-type svcIndexRespService struct {
- ID string `json:"id"`
- Rules []string `json:"rules"`
-}
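
The removed `toInternal` above turns the blocked-service index JSON into per-service rule lists, skipping services without rules and accumulating per-service errors instead of aborting on the first one. The sketch below reproduces only the decoding and aggregation steps with the standard library (`errors.Join` instead of the golibs helper); `svcIndex` and `rulesByID` are illustrative names, and the real code additionally validates IDs and compiles each rule text into a rule-list filter.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"
)

// svcIndex mirrors the shape of the blocked-service index JSON decoded by the
// removed serviceblocker code.
type svcIndex struct {
	BlockedServices []struct {
		ID    string   `json:"id"`
		Rules []string `json:"rules"`
	} `json:"blocked_services"`
}

// rulesByID converts the decoded index into an ID-to-rule-text map.  Errors
// for individual services are collected and joined instead of aborting the
// whole conversion.
func rulesByID(data []byte) (services map[string]string, err error) {
	idx := &svcIndex{}
	if err = json.Unmarshal(data, idx); err != nil {
		return nil, fmt.Errorf("decoding index: %w", err)
	}

	services = map[string]string{}
	var errs []error
	for i, svc := range idx.BlockedServices {
		if svc.ID == "" {
			errs = append(errs, fmt.Errorf("service at index %d: empty id", i))

			continue
		}

		if len(svc.Rules) == 0 {
			// Services without rules are skipped, as in the original.
			continue
		}

		services[svc.ID] = strings.Join(svc.Rules, "\n")
	}

	if err = errors.Join(errs...); err != nil {
		return nil, fmt.Errorf("converting blocked services: %w", err)
	}

	return services, nil
}

func main() {
	const data = `{"blocked_services":[{"id":"example_svc","rules":["||example.org^"]}]}`

	services, err := rulesByID([]byte(data))
	fmt.Println(services["example_svc"], err) // ||example.org^ <nil>
}
```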
diff --git a/internal/filter/storage.go b/internal/filter/storage.go
index 3a3b114..d4f9fe3 100644
--- a/internal/filter/storage.go
+++ b/internal/filter/storage.go
@@ -12,11 +12,17 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/agdnet"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/composite"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/custom"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/rulelist"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/safesearch"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/serviceblock"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/bluele/gcache"
- "github.com/prometheus/client_golang/prometheus"
)
// Filter storage
@@ -48,22 +54,22 @@ type DefaultStorage struct {
http *agdhttp.Client
// ruleLists are the filter list ID to a rule list filter map.
- ruleLists map[agd.FilterListID]*ruleListFilter
+ ruleLists map[agd.FilterListID]*rulelist.Refreshable
// services is the service blocking filter.
- services *serviceBlocker
+ services *serviceblock.Filter
// safeBrowsing is the general safe browsing filter.
- safeBrowsing *HashPrefix
+ safeBrowsing *hashprefix.Filter
// adultBlocking is the adult content blocking safe browsing filter.
- adultBlocking *HashPrefix
+ adultBlocking *hashprefix.Filter
// genSafeSearch is the general safe search filter.
- genSafeSearch *safeSearch
+ genSafeSearch *safesearch.Filter
// ytSafeSearch is the YouTube safe search filter.
- ytSafeSearch *safeSearch
+ ytSafeSearch *safesearch.Filter
// now returns the current time.
now func() (t time.Time)
@@ -73,7 +79,7 @@ type DefaultStorage struct {
errColl agd.ErrorCollector
// customFilters is the storage of custom filters for profiles.
- customFilters *customFilters
+ customFilters *custom.Filters
// cacheDir is the path to the directory where the cached filter files are
// put. The directory must exist.
@@ -84,7 +90,7 @@ type DefaultStorage struct {
refreshIvl time.Duration
// RuleListCacheSize defines the size of the LRU cache of rule-list
- // filteirng results.
+ // filtering results.
ruleListCacheSize int
// useRuleListCache, if true, enables rule list cache.
@@ -110,11 +116,11 @@ type DefaultStorageConfig struct {
// SafeBrowsing is the configuration for the default safe browsing filter.
// It must not be nil.
- SafeBrowsing *HashPrefix
+ SafeBrowsing *hashprefix.Filter
// AdultBlocking is the configuration for the adult content blocking safe
// browsing filter. It must not be nil.
- AdultBlocking *HashPrefix
+ AdultBlocking *hashprefix.Filter
// Now is a function that returns current time.
Now func() (now time.Time)
@@ -142,7 +148,7 @@ type DefaultStorageConfig struct {
SafeSearchCacheTTL time.Duration
// RuleListCacheSize defines the size of the LRU cache of rule-list
- // filteirng results.
+ // filtering results.
RuleListCacheSize int
// RefreshIvl is the refresh interval for this storage. It defines how
@@ -155,49 +161,49 @@ type DefaultStorageConfig struct {
// NewDefaultStorage returns a new filter storage. c must not be nil.
func NewDefaultStorage(c *DefaultStorageConfig) (s *DefaultStorage, err error) {
- genSafeSearch := newSafeSearch(&safeSearchConfig{
- resolver: c.Resolver,
- errColl: c.ErrColl,
- list: &agd.FilterList{
+ genSafeSearch := safesearch.New(&safesearch.Config{
+ List: &agd.FilterList{
URL: c.GeneralSafeSearchRulesURL,
ID: agd.FilterListIDGeneralSafeSearch,
RefreshIvl: c.RefreshIvl,
},
- cacheDir: c.CacheDir,
- ttl: c.SafeSearchCacheTTL,
- cacheSize: c.SafeSearchCacheSize,
+ Resolver: c.Resolver,
+ ErrColl: c.ErrColl,
+ CacheDir: c.CacheDir,
+ CacheTTL: c.SafeSearchCacheTTL,
+ CacheSize: c.SafeSearchCacheSize,
})
- ytSafeSearch := newSafeSearch(&safeSearchConfig{
- resolver: c.Resolver,
- errColl: c.ErrColl,
- list: &agd.FilterList{
+ ytSafeSearch := safesearch.New(&safesearch.Config{
+ List: &agd.FilterList{
URL: c.YoutubeSafeSearchRulesURL,
ID: agd.FilterListIDYoutubeSafeSearch,
RefreshIvl: c.RefreshIvl,
},
- cacheDir: c.CacheDir,
- ttl: c.SafeSearchCacheTTL,
- cacheSize: c.SafeSearchCacheSize,
+ Resolver: c.Resolver,
+ ErrColl: c.ErrColl,
+ CacheDir: c.CacheDir,
+ CacheTTL: c.SafeSearchCacheTTL,
+ CacheSize: c.SafeSearchCacheSize,
})
s = &DefaultStorage{
mu: &sync.RWMutex{},
url: c.FilterIndexURL,
http: agdhttp.NewClient(&agdhttp.ClientConfig{
- Timeout: defaultFilterRefreshTimeout,
+ Timeout: internal.DefaultFilterRefreshTimeout,
}),
- services: newServiceBlocker(c.BlockedServiceIndexURL, c.ErrColl),
+ services: serviceblock.New(c.BlockedServiceIndexURL, c.ErrColl),
safeBrowsing: c.SafeBrowsing,
adultBlocking: c.AdultBlocking,
genSafeSearch: genSafeSearch,
ytSafeSearch: ytSafeSearch,
now: c.Now,
errColl: c.ErrColl,
- customFilters: &customFilters{
- cache: gcache.New(c.CustomFilterCacheSize).LRU().Build(),
- errColl: c.ErrColl,
- },
+ customFilters: custom.New(
+ gcache.New(c.CustomFilterCacheSize).LRU().Build(),
+ c.ErrColl,
+ ),
cacheDir: c.CacheDir,
refreshIvl: c.RefreshIvl,
ruleListCacheSize: c.RuleListCacheSize,
@@ -221,71 +227,63 @@ func (s *DefaultStorage) FilterFromContext(ctx context.Context, ri *agd.RequestI
return s.filterForProfile(ctx, ri)
}
- flt := &compFilter{}
+ c := &composite.Config{}
g := ri.FilteringGroup
if g.RuleListsEnabled {
- flt.ruleLists = append(flt.ruleLists, s.filters(g.RuleListIDs)...)
+ c.RuleLists = s.filters(g.RuleListIDs)
}
- flt.safeBrowsing, flt.adultBlocking = s.safeBrowsingForGroup(g)
- flt.genSafeSearch, flt.ytSafeSearch = s.safeSearchForGroup(g)
+ c.SafeBrowsing, c.AdultBlocking = s.safeBrowsingForGroup(g)
+ c.GeneralSafeSearch, c.YouTubeSafeSearch = s.safeSearchForGroup(g)
- return flt
+ return composite.New(c)
}
// filterForProfile returns a composite filter for profile.
func (s *DefaultStorage) filterForProfile(ctx context.Context, ri *agd.RequestInfo) (f Interface) {
- flt := &compFilter{}
-
p := ri.Profile
if !p.FilteringEnabled {
// According to the current requirements, this means that the profile
// should receive no filtering at all.
- return flt
+ return composite.New(nil)
}
d := ri.Device
if d != nil && !d.FilteringEnabled {
// According to the current requirements, this means that the device
// should receive no filtering at all.
- return flt
+ return composite.New(nil)
}
// Assume that if we have a profile then we also have a device.
- flt.ruleLists = s.filters(p.RuleListIDs)
- flt.ruleLists = s.customFilters.appendRuleLists(ctx, flt.ruleLists, p)
+ c := &composite.Config{}
+ c.RuleLists = s.filters(p.RuleListIDs)
+ c.Custom = s.customFilters.Get(ctx, p)
pp := p.Parental
parentalEnabled := pp != nil && pp.Enabled && s.pcBySchedule(pp.Schedule)
- flt.ruleLists = append(flt.ruleLists, s.serviceFilters(p, parentalEnabled)...)
+ c.ServiceLists = s.serviceFilters(ctx, p, parentalEnabled)
- flt.safeBrowsing, flt.adultBlocking = s.safeBrowsingForProfile(p, parentalEnabled)
- flt.genSafeSearch, flt.ytSafeSearch = s.safeSearchForProfile(p, parentalEnabled)
+ c.SafeBrowsing, c.AdultBlocking = s.safeBrowsingForProfile(p, parentalEnabled)
+ c.GeneralSafeSearch, c.YouTubeSafeSearch = s.safeSearchForProfile(p, parentalEnabled)
- return flt
+ return composite.New(c)
}
// serviceFilters returns the blocked service rule lists for the profile.
-//
-// TODO(a.garipov): Consider not using ruleListFilter for service filters. Due
-// to performance reasons, it would be better to simply go through all enabled
-// rules sequentially instead. Alternatively, rework the urlfilter.DNSEngine
-// and make it use the sequential scan if the number of rules is less than some
-// constant value.
-//
-// See AGDNS-342.
func (s *DefaultStorage) serviceFilters(
+ ctx context.Context,
p *agd.Profile,
parentalEnabled bool,
-) (rls []*ruleListFilter) {
+) (rls []*rulelist.Immutable) {
if !parentalEnabled || len(p.Parental.BlockedServices) == 0 {
return nil
}
- return s.services.ruleLists(p.Parental.BlockedServices)
+ return s.services.RuleLists(ctx, p.Parental.BlockedServices)
}
// pcBySchedule returns true if the profile's schedule allows parental control
@@ -304,7 +302,7 @@ func (s *DefaultStorage) pcBySchedule(sch *agd.ParentalProtectionSchedule) (ok b
func (s *DefaultStorage) safeBrowsingForProfile(
p *agd.Profile,
parentalEnabled bool,
-) (safeBrowsing, adultBlocking *HashPrefix) {
+) (safeBrowsing, adultBlocking *hashprefix.Filter) {
if p.SafeBrowsingEnabled {
safeBrowsing = s.safeBrowsing
}
@@ -321,7 +319,7 @@ func (s *DefaultStorage) safeBrowsingForProfile(
func (s *DefaultStorage) safeSearchForProfile(
p *agd.Profile,
parentalEnabled bool,
-) (gen, yt *safeSearch) {
+) (gen, yt *safesearch.Filter) {
if !parentalEnabled {
return nil, nil
}
@@ -341,7 +339,7 @@ func (s *DefaultStorage) safeSearchForProfile(
// in the filtering group. g must not be nil.
func (s *DefaultStorage) safeBrowsingForGroup(
g *agd.FilteringGroup,
-) (safeBrowsing, adultBlocking *HashPrefix) {
+) (safeBrowsing, adultBlocking *hashprefix.Filter) {
if g.SafeBrowsingEnabled {
safeBrowsing = s.safeBrowsing
}
@@ -355,7 +353,7 @@ func (s *DefaultStorage) safeBrowsingForGroup(
// safeSearchForGroup returns safe search filters based on the information in
// the filtering group. g must not be nil.
-func (s *DefaultStorage) safeSearchForGroup(g *agd.FilteringGroup) (gen, yt *safeSearch) {
+func (s *DefaultStorage) safeSearchForGroup(g *agd.FilteringGroup) (gen, yt *safesearch.Filter) {
if !g.ParentalEnabled {
return nil, nil
}
@@ -372,7 +370,7 @@ func (s *DefaultStorage) safeSearchForGroup(g *agd.FilteringGroup) (gen, yt *saf
}
// filters returns all rule list filters with the given filtering rule list IDs.
-func (s *DefaultStorage) filters(ids []agd.FilterListID) (rls []*ruleListFilter) {
+func (s *DefaultStorage) filters(ids []agd.FilterListID) (rls []*rulelist.Refreshable) {
if len(ids) == 0 {
return nil
}
@@ -430,7 +428,7 @@ func (s *DefaultStorage) refresh(ctx context.Context, acceptStale bool) (err err
log.Info("%s: got %d filter lists from index after validations", strgLogPrefix, len(fls))
- ruleLists := make(map[agd.FilterListID]*ruleListFilter, len(resp.Filters))
+ ruleLists := make(map[agd.FilterListID]*rulelist.Refreshable, len(resp.Filters))
for _, fl := range fls {
if _, ok := ruleLists[fl.ID]; ok {
agd.Collectf(ctx, s.errColl, "%s: duplicated id %q", strgLogPrefix, fl.ID)
@@ -438,23 +436,26 @@ func (s *DefaultStorage) refresh(ctx context.Context, acceptStale bool) (err err
continue
}
- // TODO(a.garipov): Cache these.
- promLabels := prometheus.Labels{"filter": string(fl.ID)}
-
- rl := newRuleListFilter(fl, s.cacheDir, s.ruleListCacheSize, s.useRuleListCache)
- err = rl.refresh(ctx, acceptStale)
+ fltIDStr := string(fl.ID)
+ rl := rulelist.NewRefreshable(
+ fl,
+ s.cacheDir,
+ s.ruleListCacheSize,
+ s.useRuleListCache,
+ )
+ err = rl.Refresh(ctx, acceptStale)
if err == nil {
ruleLists[fl.ID] = rl
- metrics.FilterUpdatedStatus.With(promLabels).Set(1)
- metrics.FilterUpdatedTime.With(promLabels).SetToCurrentTime()
- metrics.FilterRulesTotal.With(promLabels).Set(float64(rl.engine.RulesCount))
+ metrics.FilterUpdatedStatus.WithLabelValues(fltIDStr).Set(1)
+ metrics.FilterUpdatedTime.WithLabelValues(fltIDStr).SetToCurrentTime()
+ metrics.FilterRulesTotal.WithLabelValues(fltIDStr).Set(float64(rl.RulesCount()))
continue
}
agd.Collectf(ctx, s.errColl, "%s: refreshing %q: %w", strgLogPrefix, fl.ID, err)
- metrics.FilterUpdatedStatus.With(promLabels).Set(0)
+ metrics.FilterUpdatedStatus.WithLabelValues(fltIDStr).Set(0)
// If we can't get the new filter, and there is an old version of the
// same rule list, use it.
@@ -466,17 +467,20 @@ func (s *DefaultStorage) refresh(ctx context.Context, acceptStale bool) (err err
log.Info("%s: got %d filter lists from index after compilation", strgLogPrefix, len(ruleLists))
- err = s.services.refresh(ctx, s.ruleListCacheSize, s.useRuleListCache)
+ err = s.services.Refresh(ctx, s.ruleListCacheSize, s.useRuleListCache)
if err != nil {
- return fmt.Errorf("refreshing service blocker: %w", err)
+ const errFmt = "refreshing blocked services: %w"
+ agd.Collectf(ctx, s.errColl, errFmt, err)
+
+ return fmt.Errorf(errFmt, err)
}
- err = s.genSafeSearch.refresh(ctx, acceptStale)
+ err = s.genSafeSearch.Refresh(ctx, acceptStale)
if err != nil {
return fmt.Errorf("refreshing safe search: %w", err)
}
- err = s.ytSafeSearch.refresh(ctx, acceptStale)
+ err = s.ytSafeSearch.Refresh(ctx, acceptStale)
if err != nil {
return fmt.Errorf("refreshing safe search: %w", err)
}
@@ -515,17 +519,10 @@ func (s *DefaultStorage) loadIndex(ctx context.Context) (resp *filterIndexResp,
}
// setRuleLists replaces the storage's rule lists.
-func (s *DefaultStorage) setRuleLists(ruleLists map[agd.FilterListID]*ruleListFilter) {
+func (s *DefaultStorage) setRuleLists(ruleLists map[agd.FilterListID]*rulelist.Refreshable) {
s.mu.Lock()
defer s.mu.Unlock()
- for id, rl := range s.ruleLists {
- err := rl.Close()
- if err != nil {
- log.Error("%s: closing rule list %q: %s", strgLogPrefix, id, err)
- }
- }
-
s.ruleLists = ruleLists
}
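
The storage refresh above also switches the Prometheus calls from `With(promLabels)` to `WithLabelValues(fltIDStr)`. The two are equivalent for a vector with a single `filter` label; `WithLabelValues` passes the value positionally and avoids building a `prometheus.Labels` map on every refresh. A small self-contained illustration, using a locally declared gauge vector rather than the project's `metrics` package:

```go
package main

import (
	"fmt"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/testutil"
)

// filterUpdatedStatus is a locally declared gauge vector with the same single
// "filter" label as the metrics touched above; it exists only for
// illustration and is not the project's metrics.FilterUpdatedStatus.
var filterUpdatedStatus = prometheus.NewGaugeVec(prometheus.GaugeOpts{
	Name: "filter_updated_status",
	Help: "Whether the last filter update succeeded.",
}, []string{"filter"})

func main() {
	// The two calls below address the same time series.  With requires
	// building a Labels map per call, while WithLabelValues passes the label
	// value positionally, which is what the refactored refresh code does.
	filterUpdatedStatus.With(prometheus.Labels{"filter": "example_filter"}).Set(1)
	filterUpdatedStatus.WithLabelValues("example_filter").Set(0)

	fmt.Println(testutil.ToFloat64(
		filterUpdatedStatus.WithLabelValues("example_filter"),
	)) // 0
}
```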
diff --git a/internal/filter/storage_test.go b/internal/filter/storage_test.go
index 9839abc..dcb4199 100644
--- a/internal/filter/storage_test.go
+++ b/internal/filter/storage_test.go
@@ -5,13 +5,16 @@ import (
"io"
"net"
"os"
+ "path/filepath"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdtime"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
- "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashstorage"
+ "github.com/AdguardTeam/AdGuardDNS/internal/filter/hashprefix"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
@@ -62,7 +65,6 @@ func TestStorage_FilterFromContext(t *testing.T) {
f := s.FilterFromContext(ctx, ri)
require.NotNil(t, f)
- testutil.CleanupAndRequireSuccess(t, f.Close)
var r filter.Result
r, err = f.FilterRequest(ctx, req, ri)
@@ -88,7 +90,6 @@ func TestStorage_FilterFromContext(t *testing.T) {
f := s.FilterFromContext(ctx, ri)
require.NotNil(t, f)
- testutil.CleanupAndRequireSuccess(t, f.Close)
var r filter.Result
r, err = f.FilterRequest(ctx, req, ri)
@@ -114,7 +115,6 @@ func TestStorage_FilterFromContext(t *testing.T) {
f := s.FilterFromContext(ctx, ri)
require.NotNil(t, f)
- testutil.CleanupAndRequireSuccess(t, f.Close)
var r filter.Result
r, err = f.FilterRequest(ctx, req, ri)
@@ -147,12 +147,12 @@ func TestStorage_FilterFromContext_customAllow(t *testing.T) {
_, err = io.WriteString(tmpFile, safeBrowsingHost+"\n")
require.NoError(t, err)
- hashes, err := hashstorage.New(safeBrowsingHost)
+ hashes, err := hashprefix.NewStorage(safeBrowsingHost)
require.NoError(t, err)
c := prepareConf(t)
- c.SafeBrowsing, err = filter.NewHashPrefix(&filter.HashPrefixConfig{
+ c.SafeBrowsing, err = hashprefix.NewFilter(&hashprefix.FilterConfig{
Hashes: hashes,
ErrColl: errColl,
Resolver: resolver,
@@ -202,7 +202,6 @@ func TestStorage_FilterFromContext_customAllow(t *testing.T) {
f := s.FilterFromContext(ctx, ri)
require.NotNil(t, f)
- testutil.CleanupAndRequireSuccess(t, f.Close)
r, err := f.FilterRequest(ctx, req, ri)
require.NoError(t, err)
@@ -240,13 +239,13 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) {
_, err = io.WriteString(tmpFile, safeBrowsingHost+"\n")
require.NoError(t, err)
- hashes, err := hashstorage.New(safeBrowsingHost)
+ hashes, err := hashprefix.NewStorage(safeBrowsingHost)
require.NoError(t, err)
c := prepareConf(t)
// Use AdultBlocking, because SafeBrowsing is NOT affected by the schedule.
- c.AdultBlocking, err = filter.NewHashPrefix(&filter.HashPrefixConfig{
+ c.AdultBlocking, err = hashprefix.NewFilter(&hashprefix.FilterConfig{
Hashes: hashes,
ErrColl: errColl,
Resolver: resolver,
@@ -272,7 +271,7 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) {
// Set up our profile with the schedule that disables filtering at the
// current moment.
sch := &agd.ParentalProtectionSchedule{
- TimeZone: time.UTC,
+ TimeZone: agdtime.UTC(),
Week: &agd.WeeklySchedule{
time.Sunday: agd.ZeroLengthDayRange(),
time.Monday: agd.ZeroLengthDayRange(),
@@ -320,7 +319,6 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) {
// schedule.
f := s.FilterFromContext(ctx, ri)
require.NotNil(t, f)
- testutil.CleanupAndRequireSuccess(t, f.Close)
r, err := f.FilterRequest(ctx, req, ri)
require.NoError(t, err)
@@ -332,7 +330,6 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) {
f = s.FilterFromContext(ctx, ri)
require.NotNil(t, f)
- testutil.CleanupAndRequireSuccess(t, f.Close)
r, err = f.FilterRequest(ctx, req, ri)
require.NoError(t, err)
@@ -343,6 +340,578 @@ func TestStorage_FilterFromContext_schedule(t *testing.T) {
assert.Equal(t, rm.List, agd.FilterListIDAdultBlocking)
}
+func TestStorage_FilterFromContext_ruleList_request(t *testing.T) {
+ c := prepareConf(t)
+
+ c.ErrColl = &agdtest.ErrorCollector{
+ OnCollect: func(_ context.Context, err error) { panic("not implemented") },
+ }
+
+ s, err := filter.NewDefaultStorage(c)
+ require.NoError(t, err)
+
+ g := &agd.FilteringGroup{
+ ID: "default",
+ RuleListIDs: []agd.FilterListID{testFilterID},
+ RuleListsEnabled: true,
+ }
+
+ p := &agd.Profile{
+ RuleListIDs: []agd.FilterListID{testFilterID},
+ FilteringEnabled: true,
+ RuleListsEnabled: true,
+ }
+
+ t.Run("blocked", func(t *testing.T) {
+ req := &dns.Msg{
+ Question: []dns.Question{{
+ Name: blockedFQDN,
+ Qtype: dns.TypeA,
+ Qclass: dns.ClassINET,
+ }},
+ }
+
+ ri := newReqInfo(g, nil, blockedHost, clientIP, dns.TypeA)
+ ctx := agd.ContextWithRequestInfo(context.Background(), ri)
+
+ f := s.FilterFromContext(ctx, ri)
+ require.NotNil(t, f)
+
+ var r filter.Result
+ r, err = f.FilterRequest(ctx, req, ri)
+ require.NoError(t, err)
+
+ rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r)
+
+ assert.Contains(t, rb.Rule, blockedHost)
+ assert.Equal(t, rb.List, testFilterID)
+ })
+
+ t.Run("allowed", func(t *testing.T) {
+ req := &dns.Msg{
+ Question: []dns.Question{{
+ Name: allowedFQDN,
+ Qtype: dns.TypeA,
+ Qclass: dns.ClassINET,
+ }},
+ }
+
+ ri := newReqInfo(g, nil, allowedHost, clientIP, dns.TypeA)
+ ctx := agd.ContextWithRequestInfo(context.Background(), ri)
+
+ f := s.FilterFromContext(ctx, ri)
+ require.NotNil(t, f)
+
+ var r filter.Result
+ r, err = f.FilterRequest(ctx, req, ri)
+ require.NoError(t, err)
+
+ ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r)
+
+ assert.Contains(t, ra.Rule, allowedHost)
+ assert.Equal(t, ra.List, testFilterID)
+ })
+
+ t.Run("blocked_client", func(t *testing.T) {
+ req := &dns.Msg{
+ Question: []dns.Question{{
+ Name: blockedClientFQDN,
+ Qtype: dns.TypeA,
+ Qclass: dns.ClassINET,
+ }},
+ }
+
+ ri := newReqInfo(g, nil, blockedClientHost, clientIP, dns.TypeA)
+ ctx := agd.ContextWithRequestInfo(context.Background(), ri)
+
+ f := s.FilterFromContext(ctx, ri)
+ require.NotNil(t, f)
+
+ var r filter.Result
+ r, err = f.FilterRequest(ctx, req, ri)
+ require.NoError(t, err)
+
+ rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r)
+
+ assert.Contains(t, rb.Rule, blockedClientHost)
+ assert.Equal(t, rb.List, testFilterID)
+ })
+
+ t.Run("allowed_client", func(t *testing.T) {
+ req := &dns.Msg{
+ Question: []dns.Question{{
+ Name: allowedClientFQDN,
+ Qtype: dns.TypeA,
+ Qclass: dns.ClassINET,
+ }},
+ }
+
+ ri := newReqInfo(g, nil, allowedClientHost, clientIP, dns.TypeA)
+ ctx := agd.ContextWithRequestInfo(context.Background(), ri)
+
+ f := s.FilterFromContext(ctx, ri)
+ require.NotNil(t, f)
+
+ var r filter.Result
+ r, err = f.FilterRequest(ctx, req, ri)
+ require.NoError(t, err)
+
+ ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r)
+
+ assert.Contains(t, ra.Rule, allowedClientHost)
+ assert.Equal(t, ra.List, testFilterID)
+ })
+
+ t.Run("blocked_device", func(t *testing.T) {
+ req := &dns.Msg{
+ Question: []dns.Question{{
+ Name: blockedDeviceFQDN,
+ Qtype: dns.TypeA,
+ Qclass: dns.ClassINET,
+ }},
+ }
+
+ ri := newReqInfo(g, p, blockedDeviceHost, deviceIP, dns.TypeA)
+ ctx := agd.ContextWithRequestInfo(context.Background(), ri)
+
+ f := s.FilterFromContext(ctx, ri)
+ require.NotNil(t, f)
+
+ var r filter.Result
+ r, err = f.FilterRequest(ctx, req, ri)
+ require.NoError(t, err)
+
+ rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r)
+
+ assert.Contains(t, rb.Rule, blockedDeviceHost)
+ assert.Equal(t, rb.List, testFilterID)
+ })
+
+ t.Run("allowed_device", func(t *testing.T) {
+ req := &dns.Msg{
+ Question: []dns.Question{{
+ Name: allowedDeviceFQDN,
+ Qtype: dns.TypeA,
+ Qclass: dns.ClassINET,
+ }},
+ }
+
+ ri := newReqInfo(g, p, allowedDeviceHost, deviceIP, dns.TypeA)
+ ctx := agd.ContextWithRequestInfo(context.Background(), ri)
+
+ f := s.FilterFromContext(ctx, ri)
+ require.NotNil(t, f)
+
+ var r filter.Result
+ r, err = f.FilterRequest(ctx, req, ri)
+ require.NoError(t, err)
+
+ ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r)
+
+ assert.Contains(t, ra.Rule, allowedDeviceHost)
+ assert.Equal(t, ra.List, testFilterID)
+ })
+
+ t.Run("none", func(t *testing.T) {
+ req := &dns.Msg{
+ Question: []dns.Question{{
+ Name: otherNetFQDN,
+ Qtype: dns.TypeA,
+ Qclass: dns.ClassINET,
+ }},
+ }
+
+ ri := newReqInfo(g, nil, otherNetHost, clientIP, dns.TypeA)
+ ctx := agd.ContextWithRequestInfo(context.Background(), ri)
+
+ f := s.FilterFromContext(ctx, ri)
+ require.NotNil(t, f)
+
+ var r filter.Result
+ r, err = f.FilterRequest(ctx, req, ri)
+ require.NoError(t, err)
+
+ assert.Nil(t, r)
+ })
+}
+
+func TestStorage_FilterFromContext_ruleList_response(t *testing.T) {
+ c := prepareConf(t)
+
+ c.ErrColl = &agdtest.ErrorCollector{
+ OnCollect: func(_ context.Context, err error) { panic("not implemented") },
+ }
+
+ s, err := filter.NewDefaultStorage(c)
+ require.NoError(t, err)
+
+ g := &agd.FilteringGroup{
+ ID: "default",
+ RuleListIDs: []agd.FilterListID{testFilterID},
+ RuleListsEnabled: true,
+ }
+
+ ri := newReqInfo(g, nil, otherNetHost, clientIP, dns.TypeA)
+ ctx := agd.ContextWithRequestInfo(context.Background(), ri)
+
+ f := s.FilterFromContext(ctx, ri)
+ require.NotNil(t, f)
+
+ question := []dns.Question{{
+ Name: otherNetFQDN,
+ Qtype: dns.TypeA,
+ Qclass: dns.ClassINET,
+ }}
+
+ t.Run("blocked_a", func(t *testing.T) {
+ resp := &dns.Msg{
+ Question: question,
+ Answer: []dns.RR{&dns.A{
+ A: blockedIP4,
+ }},
+ }
+
+ var r filter.Result
+ r, err = f.FilterResponse(ctx, resp, ri)
+ require.NoError(t, err)
+
+ rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r)
+
+ assert.Contains(t, rb.Rule, blockedIP4.String())
+ assert.Equal(t, rb.List, testFilterID)
+ })
+
+ t.Run("allowed_a", func(t *testing.T) {
+ resp := &dns.Msg{
+ Question: question,
+ Answer: []dns.RR{&dns.A{
+ A: allowedIP4,
+ }},
+ }
+
+ var r filter.Result
+ r, err = f.FilterResponse(ctx, resp, ri)
+ require.NoError(t, err)
+
+ ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r)
+
+ assert.Contains(t, ra.Rule, allowedIP4.String())
+ assert.Equal(t, ra.List, testFilterID)
+ })
+
+ t.Run("blocked_cname", func(t *testing.T) {
+ resp := &dns.Msg{
+ Question: question,
+ Answer: []dns.RR{&dns.CNAME{
+ Target: blockedFQDN,
+ }},
+ }
+
+ var r filter.Result
+ r, err = f.FilterResponse(ctx, resp, ri)
+ require.NoError(t, err)
+
+ rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r)
+
+ assert.Contains(t, rb.Rule, blockedHost)
+ assert.Equal(t, rb.List, testFilterID)
+ })
+
+ t.Run("allowed_cname", func(t *testing.T) {
+ resp := &dns.Msg{
+ Question: question,
+ Answer: []dns.RR{&dns.CNAME{
+ Target: allowedFQDN,
+ }},
+ }
+
+ var r filter.Result
+ r, err = f.FilterResponse(ctx, resp, ri)
+ require.NoError(t, err)
+
+ ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r)
+
+ assert.Contains(t, ra.Rule, allowedHost)
+ assert.Equal(t, ra.List, testFilterID)
+ })
+
+ t.Run("blocked_client", func(t *testing.T) {
+ resp := &dns.Msg{
+ Question: question,
+ Answer: []dns.RR{&dns.CNAME{
+ Target: blockedClientFQDN,
+ }},
+ }
+
+ var r filter.Result
+ r, err = f.FilterResponse(ctx, resp, ri)
+ require.NoError(t, err)
+
+ rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r)
+
+ assert.Contains(t, rb.Rule, blockedClientHost)
+ assert.Equal(t, rb.List, testFilterID)
+ })
+
+ t.Run("allowed_client", func(t *testing.T) {
+ req := &dns.Msg{
+ Question: question,
+ Answer: []dns.RR{&dns.CNAME{
+ Target: allowedClientFQDN,
+ }},
+ }
+
+ var r filter.Result
+ r, err = f.FilterResponse(ctx, req, ri)
+ require.NoError(t, err)
+
+ ra := testutil.RequireTypeAssert[*filter.ResultAllowed](t, r)
+
+ assert.Contains(t, ra.Rule, allowedClientHost)
+ assert.Equal(t, ra.List, testFilterID)
+ })
+
+ t.Run("exception_cname", func(t *testing.T) {
+ req := &dns.Msg{
+ Question: question,
+ Answer: []dns.RR{&dns.CNAME{
+ Target: "cname.exception.",
+ }},
+ }
+
+ var r filter.Result
+ r, err = f.FilterResponse(ctx, req, ri)
+ require.NoError(t, err)
+
+ assert.Nil(t, r)
+ })
+
+ t.Run("exception_cname_blocked", func(t *testing.T) {
+ req := &dns.Msg{
+ Question: question,
+ Answer: []dns.RR{&dns.CNAME{
+ Target: "cname.blocked.",
+ }},
+ }
+
+ var r filter.Result
+ r, err = f.FilterResponse(ctx, req, ri)
+ require.NoError(t, err)
+
+ rb := testutil.RequireTypeAssert[*filter.ResultBlocked](t, r)
+
+ assert.Contains(t, rb.Rule, "cname.blocked")
+ assert.Equal(t, rb.List, testFilterID)
+ })
+
+ t.Run("none", func(t *testing.T) {
+ req := &dns.Msg{
+ Question: question,
+ Answer: []dns.RR{&dns.CNAME{
+ Target: otherOrgFQDN,
+ }},
+ }
+
+ var r filter.Result
+ r, err = f.FilterRequest(ctx, req, ri)
+ require.NoError(t, err)
+
+ assert.Nil(t, r)
+ })
+}
+
+func TestStorage_FilterFromContext_safeBrowsing(t *testing.T) {
+ cacheDir := t.TempDir()
+ cachePath := filepath.Join(cacheDir, string(agd.FilterListIDSafeBrowsing))
+ err := os.WriteFile(cachePath, []byte(safeBrowsingHost+"\n"), 0o644)
+ require.NoError(t, err)
+
+ hashes, err := hashprefix.NewStorage("")
+ require.NoError(t, err)
+
+ errColl := &agdtest.ErrorCollector{
+ OnCollect: func(_ context.Context, err error) {
+ panic("not implemented")
+ },
+ }
+
+ resolver := &agdtest.Resolver{
+ OnLookupIP: func(
+ _ context.Context,
+ _ netutil.AddrFamily,
+ _ string,
+ ) (ips []net.IP, err error) {
+ return []net.IP{safeBrowsingSafeIP4}, nil
+ },
+ }
+
+ c := prepareConf(t)
+
+ c.SafeBrowsing, err = hashprefix.NewFilter(&hashprefix.FilterConfig{
+ Hashes: hashes,
+ ErrColl: errColl,
+ Resolver: resolver,
+ ID: agd.FilterListIDSafeBrowsing,
+ CachePath: cachePath,
+ ReplacementHost: safeBrowsingSafeHost,
+ Staleness: 1 * time.Hour,
+ CacheTTL: 10 * time.Second,
+ CacheSize: 100,
+ })
+ require.NoError(t, err)
+
+ c.ErrColl = errColl
+ c.Resolver = resolver
+
+ s, err := filter.NewDefaultStorage(c)
+ require.NoError(t, err)
+
+ g := &agd.FilteringGroup{
+ ID: "default",
+ RuleListIDs: []agd.FilterListID{},
+ ParentalEnabled: true,
+ SafeBrowsingEnabled: true,
+ }
+
+ // Test
+
+ req := &dns.Msg{
+ Question: []dns.Question{{
+ Name: safeBrowsingSubFQDN,
+ Qtype: dns.TypeA,
+ Qclass: dns.ClassINET,
+ }},
+ }
+
+ ri := newReqInfo(g, nil, safeBrowsingSubHost, clientIP, dns.TypeA)
+ ctx := agd.ContextWithRequestInfo(context.Background(), ri)
+
+ f := s.FilterFromContext(ctx, ri)
+ require.NotNil(t, f)
+
+ var r filter.Result
+ r, err = f.FilterRequest(ctx, req, ri)
+ require.NoError(t, err)
+
+ rm := testutil.RequireTypeAssert[*filter.ResultModified](t, r)
+
+ assert.Equal(t, rm.Rule, agd.FilterRuleText(safeBrowsingHost))
+ assert.Equal(t, rm.List, agd.FilterListIDSafeBrowsing)
+}
+
+func TestStorage_FilterFromContext_safeSearch(t *testing.T) {
+ numLookupIP := 0
+ resolver := &agdtest.Resolver{
+ OnLookupIP: func(
+ _ context.Context,
+ fam netutil.AddrFamily,
+ _ string,
+ ) (ips []net.IP, err error) {
+ numLookupIP++
+
+ if fam == netutil.AddrFamilyIPv4 {
+ return []net.IP{safeSearchIPRespIP4}, nil
+ }
+
+ return []net.IP{safeSearchIPRespIP6}, nil
+ },
+ }
+
+ c := prepareConf(t)
+
+ c.ErrColl = &agdtest.ErrorCollector{
+ OnCollect: func(_ context.Context, err error) { panic("not implemented") },
+ }
+
+ c.Resolver = resolver
+
+ s, err := filter.NewDefaultStorage(c)
+ require.NoError(t, err)
+
+ g := &agd.FilteringGroup{
+ ID: "default",
+ ParentalEnabled: true,
+ GeneralSafeSearch: true,
+ }
+
+ testCases := []struct {
+ name string
+ host string
+ wantIP net.IP
+ rrtype uint16
+ wantLookups int
+ }{{
+ name: "ip4",
+ host: safeSearchIPHost,
+ wantIP: safeSearchIPRespIP4,
+ rrtype: dns.TypeA,
+ wantLookups: 1,
+ }, {
+ name: "ip6",
+ host: safeSearchIPHost,
+ wantIP: safeSearchIPRespIP6,
+ rrtype: dns.TypeAAAA,
+ wantLookups: 1,
+ }, {
+ name: "host_ip4",
+ host: safeSearchHost,
+ wantIP: safeSearchIPRespIP4,
+ rrtype: dns.TypeA,
+ wantLookups: 1,
+ }, {
+ name: "host_ip6",
+ host: safeSearchHost,
+ wantIP: safeSearchIPRespIP6,
+ rrtype: dns.TypeAAAA,
+ wantLookups: 1,
+ }}
+
+ for _, tc := range testCases {
+ numLookupIP = 0
+ req := dnsservertest.CreateMessage(tc.host, tc.rrtype)
+
+ t.Run(tc.name, func(t *testing.T) {
+ ri := newReqInfo(g, nil, tc.host, clientIP, tc.rrtype)
+ ctx := agd.ContextWithRequestInfo(context.Background(), ri)
+
+ f := s.FilterFromContext(ctx, ri)
+ require.NotNil(t, f)
+
+ var r filter.Result
+ r, err = f.FilterRequest(ctx, req, ri)
+ require.NoError(t, err)
+
+ assert.Equal(t, tc.wantLookups, numLookupIP)
+
+ rm := testutil.RequireTypeAssert[*filter.ResultModified](t, r)
+ assert.Contains(t, rm.Rule, tc.host)
+ assert.Equal(t, rm.List, agd.FilterListIDGeneralSafeSearch)
+
+ res := rm.Msg
+ require.NotNil(t, res)
+
+ if tc.wantIP == nil {
+ assert.Nil(t, res.Answer)
+
+ return
+ }
+
+ require.Len(t, res.Answer, 1)
+
+ switch ans := res.Answer[0]; ans := ans.(type) {
+ case *dns.A:
+ assert.Equal(t, tc.rrtype, ans.Hdr.Rrtype)
+ assert.Equal(t, tc.wantIP, ans.A)
+ case *dns.AAAA:
+ assert.Equal(t, tc.rrtype, ans.Hdr.Rrtype)
+ assert.Equal(t, tc.wantIP, ans.AAAA)
+ default:
+ t.Fatalf("unexpected answer type %T(%[1]v)", ans)
+ }
+ })
+ }
+}
+
var (
defaultStorageSink *filter.DefaultStorage
errSink error
diff --git a/internal/geoip/asntops.go b/internal/geoip/asntops.go
index c5d9fcf..d69f7bf 100644
--- a/internal/geoip/asntops.go
+++ b/internal/geoip/asntops.go
@@ -11,16 +11,13 @@ var allTopASNs = map[agd.ASN]struct{}{
812: {},
852: {},
1136: {},
- 1213: {},
1221: {},
1241: {},
1257: {},
1267: {},
1547: {},
1680: {},
- 1955: {},
- 2107: {},
- 2108: {},
+ 1901: {},
2110: {},
2116: {},
2119: {},
@@ -35,7 +32,6 @@ var allTopASNs = map[agd.ASN]struct{}{
3209: {},
3212: {},
3215: {},
- 3216: {},
3238: {},
3243: {},
3249: {},
@@ -54,7 +50,6 @@ var allTopASNs = map[agd.ASN]struct{}{
3855: {},
4007: {},
4134: {},
- 4230: {},
4609: {},
4638: {},
4648: {},
@@ -96,6 +91,7 @@ var allTopASNs = map[agd.ASN]struct{}{
6306: {},
6327: {},
6400: {},
+ 6535: {},
6568: {},
6639: {},
6661: {},
@@ -135,16 +131,17 @@ var allTopASNs = map[agd.ASN]struct{}{
8359: {},
8374: {},
8376: {},
+ 8386: {},
8400: {},
8402: {},
8412: {},
8447: {},
8452: {},
8473: {},
- 8544: {},
8551: {},
8560: {},
8585: {},
+ 8632: {},
8661: {},
8680: {},
8681: {},
@@ -208,7 +205,6 @@ var allTopASNs = map[agd.ASN]struct{}{
11315: {},
11427: {},
11556: {},
- 11562: {},
11594: {},
11664: {},
11816: {},
@@ -230,6 +226,7 @@ var allTopASNs = map[agd.ASN]struct{}{
12709: {},
12716: {},
12735: {},
+ 12741: {},
12764: {},
12793: {},
12810: {},
@@ -244,8 +241,8 @@ var allTopASNs = map[agd.ASN]struct{}{
12997: {},
13036: {},
13046: {},
- 13092: {},
13122: {},
+ 13124: {},
13194: {},
13280: {},
13285: {},
@@ -255,6 +252,7 @@ var allTopASNs = map[agd.ASN]struct{}{
13999: {},
14061: {},
14080: {},
+ 14117: {},
14434: {},
14522: {},
14638: {},
@@ -272,6 +270,7 @@ var allTopASNs = map[agd.ASN]struct{}{
15480: {},
15502: {},
15557: {},
+ 15659: {},
15704: {},
15706: {},
15735: {},
@@ -283,6 +282,7 @@ var allTopASNs = map[agd.ASN]struct{}{
15958: {},
15962: {},
15964: {},
+ 15994: {},
16010: {},
16019: {},
16028: {},
@@ -291,9 +291,9 @@ var allTopASNs = map[agd.ASN]struct{}{
16135: {},
16232: {},
16276: {},
+ 16322: {},
16345: {},
16437: {},
- 16509: {},
16637: {},
16705: {},
17072: {},
@@ -330,12 +330,10 @@ var allTopASNs = map[agd.ASN]struct{}{
20294: {},
20473: {},
20634: {},
- 20661: {},
20776: {},
20845: {},
20875: {},
20910: {},
- 20963: {},
20978: {},
21003: {},
21183: {},
@@ -350,6 +348,7 @@ var allTopASNs = map[agd.ASN]struct{}{
21450: {},
21497: {},
21575: {},
+ 21744: {},
21826: {},
21928: {},
21996: {},
@@ -357,7 +356,6 @@ var allTopASNs = map[agd.ASN]struct{}{
22069: {},
22085: {},
22351: {},
- 22581: {},
22724: {},
22773: {},
22927: {},
@@ -371,11 +369,11 @@ var allTopASNs = map[agd.ASN]struct{}{
23693: {},
23752: {},
23889: {},
+ 23917: {},
23955: {},
23969: {},
24158: {},
24203: {},
- 24337: {},
24378: {},
24389: {},
24432: {},
@@ -386,6 +384,7 @@ var allTopASNs = map[agd.ASN]struct{}{
24722: {},
24757: {},
24835: {},
+ 24852: {},
24921: {},
24940: {},
25019: {},
@@ -419,6 +418,7 @@ var allTopASNs = map[agd.ASN]struct{}{
27781: {},
27800: {},
27831: {},
+ 27839: {},
27882: {},
27884: {},
27895: {},
@@ -449,6 +449,7 @@ var allTopASNs = map[agd.ASN]struct{}{
29247: {},
29256: {},
29310: {},
+ 29314: {},
29355: {},
29357: {},
29447: {},
@@ -463,16 +464,15 @@ var allTopASNs = map[agd.ASN]struct{}{
29975: {},
30689: {},
30722: {},
- 30844: {},
30873: {},
30969: {},
30985: {},
30986: {},
+ 30987: {},
30990: {},
30992: {},
30999: {},
31012: {},
- 31027: {},
31037: {},
31042: {},
31122: {},
@@ -483,12 +483,11 @@ var allTopASNs = map[agd.ASN]struct{}{
31213: {},
31224: {},
31252: {},
- 31404: {},
31452: {},
+ 31549: {},
31615: {},
31721: {},
32020: {},
- 32189: {},
33363: {},
33392: {},
33567: {},
@@ -513,14 +512,12 @@ var allTopASNs = map[agd.ASN]struct{}{
35228: {},
35432: {},
35444: {},
- 35725: {},
35805: {},
+ 35807: {},
35819: {},
- 35892: {},
35900: {},
36290: {},
36549: {},
- 36866: {},
36873: {},
36884: {},
36890: {},
@@ -541,6 +538,7 @@ var allTopASNs = map[agd.ASN]struct{}{
36974: {},
36988: {},
36992: {},
+ 36994: {},
36996: {},
36998: {},
36999: {},
@@ -574,16 +572,16 @@ var allTopASNs = map[agd.ASN]struct{}{
37284: {},
37287: {},
37294: {},
+ 37303: {},
37309: {},
37323: {},
+ 37336: {},
37337: {},
37342: {},
37343: {},
- 37349: {},
37371: {},
37376: {},
37385: {},
- 37406: {},
37410: {},
37424: {},
37440: {},
@@ -593,6 +591,7 @@ var allTopASNs = map[agd.ASN]struct{}{
37457: {},
37460: {},
37461: {},
+ 37463: {},
37473: {},
37492: {},
37508: {},
@@ -601,8 +600,8 @@ var allTopASNs = map[agd.ASN]struct{}{
37526: {},
37529: {},
37531: {},
- 37541: {},
37550: {},
+ 37552: {},
37559: {},
37563: {},
37575: {},
@@ -610,13 +609,13 @@ var allTopASNs = map[agd.ASN]struct{}{
37594: {},
37611: {},
37612: {},
- 37614: {},
37616: {},
37645: {},
37649: {},
37671: {},
37693: {},
37705: {},
+ 38008: {},
38009: {},
38077: {},
38195: {},
@@ -629,12 +628,15 @@ var allTopASNs = map[agd.ASN]struct{}{
38742: {},
38800: {},
38819: {},
+ 38875: {},
38901: {},
38999: {},
39010: {},
39232: {},
39603: {},
+ 39611: {},
39642: {},
+ 39737: {},
39891: {},
40945: {},
41164: {},
@@ -642,9 +644,6 @@ var allTopASNs = map[agd.ASN]struct{}{
41329: {},
41330: {},
41557: {},
- 41564: {},
- 41653: {},
- 41697: {},
41738: {},
41750: {},
41897: {},
@@ -654,15 +653,19 @@ var allTopASNs = map[agd.ASN]struct{}{
42298: {},
42313: {},
42437: {},
- 42532: {},
42560: {},
+ 42610: {},
42772: {},
42779: {},
+ 42837: {},
+ 42841: {},
42863: {},
+ 42960: {},
42961: {},
+ 43019: {},
43197: {},
- 43242: {},
43447: {},
+ 43513: {},
43557: {},
43571: {},
43612: {},
@@ -675,7 +678,10 @@ var allTopASNs = map[agd.ASN]struct{}{
44087: {},
44143: {},
44244: {},
+ 44395: {},
+ 44489: {},
44558: {},
+ 44575: {},
44869: {},
45143: {},
45168: {},
@@ -702,7 +708,7 @@ var allTopASNs = map[agd.ASN]struct{}{
47331: {},
47377: {},
47394: {},
- 47588: {},
+ 47524: {},
47589: {},
47883: {},
47956: {},
@@ -715,7 +721,6 @@ var allTopASNs = map[agd.ASN]struct{}{
48728: {},
48832: {},
48887: {},
- 48953: {},
49273: {},
49800: {},
49902: {},
@@ -725,14 +730,14 @@ var allTopASNs = map[agd.ASN]struct{}{
50251: {},
50266: {},
50616: {},
+ 50810: {},
50973: {},
- 50979: {},
51207: {},
51375: {},
51407: {},
51495: {},
- 51684: {},
51765: {},
+ 51852: {},
51896: {},
52228: {},
52233: {},
@@ -744,8 +749,8 @@ var allTopASNs = map[agd.ASN]struct{}{
52341: {},
52362: {},
52398: {},
- 52433: {},
52468: {},
+ 53667: {},
55330: {},
55430: {},
55805: {},
@@ -753,10 +758,12 @@ var allTopASNs = map[agd.ASN]struct{}{
55850: {},
55943: {},
55944: {},
+ 56017: {},
56055: {},
56089: {},
56167: {},
56300: {},
+ 56369: {},
56653: {},
56665: {},
56696: {},
@@ -766,18 +773,16 @@ var allTopASNs = map[agd.ASN]struct{}{
57388: {},
57513: {},
57704: {},
- 58065: {},
58224: {},
58460: {},
58731: {},
+ 58952: {},
59257: {},
- 59588: {},
59989: {},
60068: {},
60258: {},
61143: {},
61461: {},
- 62563: {},
63949: {},
64466: {},
131178: {},
@@ -791,29 +796,32 @@ var allTopASNs = map[agd.ASN]struct{}{
132167: {},
132199: {},
132471: {},
+ 132486: {},
132618: {},
- 132831: {},
133385: {},
133481: {},
133579: {},
+ 133606: {},
133612: {},
- 133897: {},
134783: {},
135409: {},
136255: {},
+ 136950: {},
137412: {},
+ 137824: {},
138179: {},
139759: {},
139831: {},
139898: {},
139922: {},
+ 140504: {},
196838: {},
197207: {},
197830: {},
198279: {},
198605: {},
199140: {},
- 199155: {},
+ 199276: {},
199731: {},
200134: {},
201167: {},
@@ -822,52 +830,58 @@ var allTopASNs = map[agd.ASN]struct{}{
202087: {},
202254: {},
202422: {},
+ 202448: {},
203214: {},
203953: {},
203995: {},
204170: {},
- 204317: {},
- 204342: {},
- 204649: {},
+ 204279: {},
+ 205110: {},
205368: {},
205714: {},
206026: {},
206067: {},
206206: {},
206262: {},
+ 207369: {},
207569: {},
207651: {},
207810: {},
+ 209424: {},
+ 210003: {},
210315: {},
210542: {},
211144: {},
+ 212238: {},
+ 212370: {},
213155: {},
- 213373: {},
+ 213371: {},
262145: {},
262186: {},
262197: {},
262202: {},
262210: {},
262239: {},
+ 262589: {},
263238: {},
263725: {},
263783: {},
263824: {},
264628: {},
264645: {},
+ 264663: {},
264668: {},
264731: {},
+ 266673: {},
269729: {},
271773: {},
327697: {},
- 327707: {},
327712: {},
327725: {},
327738: {},
327756: {},
327765: {},
327769: {},
- 327776: {},
327799: {},
327802: {},
327885: {},
@@ -876,7 +890,6 @@ var allTopASNs = map[agd.ASN]struct{}{
328061: {},
328079: {},
328088: {},
- 328136: {},
328140: {},
328169: {},
328191: {},
@@ -886,16 +899,13 @@ var allTopASNs = map[agd.ASN]struct{}{
328286: {},
328297: {},
328309: {},
- 328411: {},
328453: {},
328469: {},
328488: {},
- 328586: {},
- 328605: {},
- 328708: {},
+ 328539: {},
328755: {},
328943: {},
- 329020: {},
+ 393275: {},
394311: {},
395561: {},
396304: {},
@@ -908,9 +918,9 @@ var allTopASNs = map[agd.ASN]struct{}{
var countryTopASNs = map[agd.Country]agd.ASN{
agd.CountryAD: 6752,
agd.CountryAE: 5384,
- agd.CountryAF: 55330,
+ agd.CountryAF: 132471,
agd.CountryAG: 11594,
- agd.CountryAI: 11139,
+ agd.CountryAI: 2740,
agd.CountryAL: 21183,
agd.CountryAM: 12297,
agd.CountryAO: 37119,
@@ -931,55 +941,54 @@ var countryTopASNs = map[agd.Country]agd.ASN{
agd.CountryBI: 327799,
agd.CountryBJ: 37424,
agd.CountryBL: 3215,
- agd.CountryBM: 3855,
+ agd.CountryBM: 32020,
agd.CountryBN: 10094,
agd.CountryBO: 6568,
agd.CountryBQ: 27745,
- agd.CountryBR: 26599,
+ agd.CountryBR: 28573,
agd.CountryBS: 15146,
agd.CountryBT: 18024,
agd.CountryBW: 14988,
- agd.CountryBY: 25106,
- agd.CountryBZ: 10269,
+ agd.CountryBY: 6697,
+ agd.CountryBZ: 212370,
agd.CountryCA: 812,
- agd.CountryCC: 198605,
agd.CountryCD: 37020,
agd.CountryCF: 37460,
- agd.CountryCG: 37451,
+ agd.CountryCG: 36924,
agd.CountryCH: 3303,
agd.CountryCI: 29571,
agd.CountryCK: 10131,
- agd.CountryCL: 27651,
+ agd.CountryCL: 7418,
agd.CountryCM: 30992,
agd.CountryCN: 4134,
- agd.CountryCO: 26611,
+ agd.CountryCO: 10620,
agd.CountryCR: 11830,
agd.CountryCU: 27725,
agd.CountryCV: 37517,
agd.CountryCW: 52233,
- agd.CountryCX: 198605,
- agd.CountryCY: 6866,
+ agd.CountryCY: 202448,
agd.CountryCZ: 5610,
agd.CountryDE: 3320,
agd.CountryDJ: 30990,
- agd.CountryDK: 3292,
+ agd.CountryDK: 9009,
agd.CountryDM: 40945,
agd.CountryDO: 6400,
agd.CountryDZ: 36947,
agd.CountryEC: 27947,
agd.CountryEE: 3249,
agd.CountryEG: 8452,
+ agd.CountryEH: 6713,
agd.CountryER: 24757,
agd.CountryES: 12479,
agd.CountryET: 24757,
agd.CountryFI: 51765,
agd.CountryFJ: 38442,
- agd.CountryFK: 204649,
+ agd.CountryFK: 198605,
agd.CountryFM: 139759,
agd.CountryFO: 15389,
agd.CountryFR: 3215,
- agd.CountryGA: 16058,
- agd.CountryGB: 2856,
+ agd.CountryGA: 36924,
+ agd.CountryGB: 60068,
agd.CountryGD: 46650,
agd.CountryGE: 16010,
agd.CountryGF: 3215,
@@ -987,7 +996,7 @@ var countryTopASNs = map[agd.Country]agd.ASN{
agd.CountryGH: 30986,
agd.CountryGI: 8301,
agd.CountryGL: 8818,
- agd.CountryGM: 37309,
+ agd.CountryGM: 25250,
agd.CountryGN: 37461,
agd.CountryGP: 3215,
agd.CountryGQ: 37173,
@@ -1002,13 +1011,13 @@ var countryTopASNs = map[agd.Country]agd.ASN{
agd.CountryHT: 27653,
agd.CountryHU: 5483,
agd.CountryID: 7713,
- agd.CountryIE: 15502,
+ agd.CountryIE: 6830,
agd.CountryIL: 1680,
agd.CountryIM: 13122,
agd.CountryIN: 55836,
agd.CountryIO: 17458,
agd.CountryIQ: 203214,
- agd.CountryIR: 44244,
+ agd.CountryIR: 58224,
agd.CountryIS: 43571,
agd.CountryIT: 1267,
agd.CountryJE: 8680,
@@ -1019,21 +1028,21 @@ var countryTopASNs = map[agd.Country]agd.ASN{
agd.CountryKG: 50223,
agd.CountryKH: 38623,
agd.CountryKI: 134783,
- agd.CountryKM: 328061,
+ agd.CountryKM: 36939,
agd.CountryKN: 11139,
agd.CountryKR: 4766,
agd.CountryKW: 29357,
agd.CountryKY: 6639,
agd.CountryKZ: 206026,
agd.CountryLA: 9873,
- agd.CountryLB: 38999,
+ agd.CountryLB: 42003,
agd.CountryLC: 15344,
agd.CountryLI: 20634,
agd.CountryLK: 18001,
- agd.CountryLR: 37094,
+ agd.CountryLR: 37410,
agd.CountryLS: 33567,
agd.CountryLT: 8764,
- agd.CountryLU: 9009,
+ agd.CountryLU: 6661,
agd.CountryLV: 24921,
agd.CountryLY: 21003,
agd.CountryMA: 36903,
@@ -1068,37 +1077,35 @@ var countryTopASNs = map[agd.Country]agd.ASN{
agd.CountryNL: 1136,
agd.CountryNO: 2119,
agd.CountryNP: 17501,
- agd.CountryNR: 45355,
- agd.CountryNU: 198605,
- agd.CountryNZ: 4771,
+ agd.CountryNR: 140504,
+ agd.CountryNZ: 9790,
agd.CountryOM: 28885,
- agd.CountryPA: 11556,
+ agd.CountryPA: 18809,
agd.CountryPE: 12252,
agd.CountryPF: 9471,
agd.CountryPG: 139898,
agd.CountryPH: 9299,
agd.CountryPK: 45669,
- agd.CountryPL: 43447,
+ agd.CountryPL: 5617,
agd.CountryPM: 3695,
- agd.CountryPR: 21928,
+ agd.CountryPR: 14638,
agd.CountryPS: 12975,
- agd.CountryPT: 12353,
+ agd.CountryPT: 3243,
agd.CountryPW: 17893,
agd.CountryPY: 23201,
agd.CountryQA: 42298,
- agd.CountryRE: 37002,
+ agd.CountryRE: 3215,
agd.CountryRO: 8708,
agd.CountryRS: 8400,
agd.CountryRU: 8359,
agd.CountryRW: 36924,
agd.CountrySA: 39891,
agd.CountrySB: 45891,
- agd.CountrySC: 131267,
+ agd.CountrySC: 36958,
agd.CountrySD: 15706,
- agd.CountrySE: 1257,
+ agd.CountrySE: 60068,
agd.CountrySG: 4773,
agd.CountrySI: 3212,
- agd.CountrySJ: 198605,
agd.CountrySK: 6855,
agd.CountrySL: 37164,
agd.CountrySM: 15433,
@@ -1113,7 +1120,7 @@ var countryTopASNs = map[agd.Country]agd.ASN{
agd.CountrySZ: 328169,
agd.CountryTC: 394311,
agd.CountryTD: 327802,
- agd.CountryTG: 24691,
+ agd.CountryTG: 36924,
agd.CountryTH: 131445,
agd.CountryTJ: 43197,
agd.CountryTK: 4648,
@@ -1121,8 +1128,9 @@ var countryTopASNs = map[agd.Country]agd.ASN{
agd.CountryTM: 51495,
agd.CountryTN: 37705,
agd.CountryTO: 38201,
- agd.CountryTR: 16135,
+ agd.CountryTR: 47331,
agd.CountryTT: 27800,
+ agd.CountryTV: 23917,
agd.CountryTW: 3462,
agd.CountryTZ: 36908,
agd.CountryUA: 15895,
diff --git a/internal/geoip/asntops_generate.go b/internal/geoip/asntops_generate.go
index 60eee58..9a32879 100644
--- a/internal/geoip/asntops_generate.go
+++ b/internal/geoip/asntops_generate.go
@@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
+ "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
)
@@ -22,7 +23,7 @@ func main() {
req, err := http.NewRequest(http.MethodGet, countriesASNURL, nil)
check(err)
- req.Header.Add("User-Agent", agdhttp.UserAgent())
+ req.Header.Add(httphdr.UserAgent, agdhttp.UserAgent())
resp, err := c.Do(req)
check(err)
diff --git a/internal/geoip/file.go b/internal/geoip/file.go
index 432e3ca..5d29b81 100644
--- a/internal/geoip/file.go
+++ b/internal/geoip/file.go
@@ -23,7 +23,8 @@ type FileConfig struct {
// ASNPath is the path to the GeoIP database of ASNs.
ASNPath string
- // CountryPath is the path to the GeoIP database of countries.
+	// CountryPath is the path to the GeoIP database of countries. Databases
+	// that contain subdivision and city information are also supported.
CountryPath string
// HostCacheSize is how many lookups are cached by hostname. Zero means no
@@ -207,18 +208,16 @@ func (f *File) Data(host string, ip netip.Addr) (l *agd.Location, err error) {
return nil, fmt.Errorf("looking up asn: %w", err)
}
- ctry, cont, err := f.lookupCtry(ip)
+ l = &agd.Location{
+ ASN: asn,
+ }
+
+ err = f.setCtry(l, ip)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
- l = &agd.Location{
- Country: ctry,
- Continent: cont,
- ASN: asn,
- }
-
f.setCaches(host, cacheKey, l)
return l, nil
@@ -271,29 +270,36 @@ type countryResult struct {
Country struct {
ISOCode string `maxminddb:"iso_code"`
} `maxminddb:"country"`
+ Subdivisions []struct {
+ ISOCode string `maxminddb:"iso_code"`
+ } `maxminddb:"subdivisions"`
}
-// lookupCtry looks up and returns the country and continent parts of the GeoIP
-// data for ip.
-func (f *File) lookupCtry(ip netip.Addr) (ctry agd.Country, cont agd.Continent, err error) {
+// setCtry looks up and sets the country, continent, and top subdivision
+// parts of the GeoIP data for ip into loc. loc must not be nil.
+func (f *File) setCtry(loc *agd.Location, ip netip.Addr) (err error) {
// TODO(a.garipov): Remove AsSlice if oschwald/maxminddb-golang#88 is done.
var res countryResult
err = f.country.Lookup(ip.AsSlice(), &res)
if err != nil {
- return ctry, cont, fmt.Errorf("looking up country: %w", err)
+ return fmt.Errorf("looking up country: %w", err)
}
- ctry, err = agd.NewCountry(res.Country.ISOCode)
+ loc.Country, err = agd.NewCountry(res.Country.ISOCode)
if err != nil {
- return ctry, cont, fmt.Errorf("converting country: %w", err)
+ return fmt.Errorf("converting country: %w", err)
}
- cont, err = agd.NewContinent(res.Continent.Code)
+ loc.Continent, err = agd.NewContinent(res.Continent.Code)
if err != nil {
- return ctry, cont, fmt.Errorf("converting continent: %w", err)
+ return fmt.Errorf("converting continent: %w", err)
}
- return ctry, cont, nil
+ if len(res.Subdivisions) > 0 {
+ loc.TopSubdivision = res.Subdivisions[0].ISOCode
+ }
+
+ return nil
}
// setCaches sets the GeoIP data into the caches.
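The reworked `setCtry` above keeps only the topmost entry of the `subdivisions` array. As a rough standalone sketch (not the project's code), the same lookup can be reproduced directly with the oschwald/maxminddb-golang reader against the city-aware test database; the database path and the decoded struct shape are assumptions based on the hunk above:

```go
package main

import (
	"fmt"
	"log"
	"net"

	"github.com/oschwald/maxminddb-golang"
)

func main() {
	// Path to a city-aware database; an assumption for this sketch.
	db, err := maxminddb.Open("./GeoIP2-City-Test.mmdb")
	if err != nil {
		log.Fatal(err)
	}
	defer func() { _ = db.Close() }()

	// The same shape as countryResult above, with the optional subdivisions
	// array.
	var res struct {
		Country struct {
			ISOCode string `maxminddb:"iso_code"`
		} `maxminddb:"country"`
		Subdivisions []struct {
			ISOCode string `maxminddb:"iso_code"`
		} `maxminddb:"subdivisions"`
	}

	// 216.160.83.56 is the address used by the geoip tests below.
	err = db.Lookup(net.ParseIP("216.160.83.56"), &res)
	if err != nil {
		log.Fatal(err)
	}

	// Keep only the topmost subdivision, as setCtry does.
	subdiv := ""
	if len(res.Subdivisions) > 0 {
		subdiv = res.Subdivisions[0].ISOCode
	}

	fmt.Println(res.Country.ISOCode, subdiv)
}
```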
diff --git a/internal/geoip/file_test.go b/internal/geoip/file_test.go
index 58c629c..1715546 100644
--- a/internal/geoip/file_test.go
+++ b/internal/geoip/file_test.go
@@ -11,7 +11,31 @@ import (
"github.com/stretchr/testify/require"
)
-func TestFile_Data(t *testing.T) {
+func TestFile_Data_cityDB(t *testing.T) {
+ conf := &geoip.FileConfig{
+ ASNPath: asnPath,
+ CountryPath: cityPath,
+ HostCacheSize: 0,
+ IPCacheSize: 1,
+ }
+
+ g, err := geoip.NewFile(conf)
+ require.NoError(t, err)
+
+ d, err := g.Data(testHost, testIPWithASN)
+ require.NoError(t, err)
+
+ assert.Equal(t, testASN, d.ASN)
+
+ d, err = g.Data(testHost, testIPWithSubdiv)
+ require.NoError(t, err)
+
+ assert.Equal(t, testCtry, d.Country)
+ assert.Equal(t, testCont, d.Continent)
+ assert.Equal(t, testSubdiv, d.TopSubdivision)
+}
+
+func TestFile_Data_countryDB(t *testing.T) {
conf := &geoip.FileConfig{
ASNPath: asnPath,
CountryPath: countryPath,
@@ -27,17 +51,18 @@ func TestFile_Data(t *testing.T) {
assert.Equal(t, testASN, d.ASN)
- d, err = g.Data(testHost, testIPWithCountry)
+ d, err = g.Data(testHost, testIPWithSubdiv)
require.NoError(t, err)
assert.Equal(t, testCtry, d.Country)
assert.Equal(t, testCont, d.Continent)
+ assert.Empty(t, d.TopSubdivision)
}
func TestFile_Data_hostCache(t *testing.T) {
conf := &geoip.FileConfig{
ASNPath: asnPath,
- CountryPath: countryPath,
+ CountryPath: cityPath,
HostCacheSize: 1,
IPCacheSize: 1,
}
@@ -64,7 +89,7 @@ func TestFile_Data_hostCache(t *testing.T) {
func TestFile_SubnetByLocation(t *testing.T) {
conf := &geoip.FileConfig{
ASNPath: asnPath,
- CountryPath: countryPath,
+ CountryPath: cityPath,
HostCacheSize: 0,
IPCacheSize: 1,
}
@@ -91,7 +116,7 @@ var errSink error
func BenchmarkFile_Data(b *testing.B) {
conf := &geoip.FileConfig{
ASNPath: asnPath,
- CountryPath: countryPath,
+ CountryPath: cityPath,
HostCacheSize: 0,
IPCacheSize: 1,
}
@@ -103,7 +128,7 @@ func BenchmarkFile_Data(b *testing.B) {
// Change the eighth byte in testIPWithCountry to create a different address
// in the same network.
- ipSlice := testIPWithCountry.AsSlice()
+ ipSlice := ipCountry1.AsSlice()
ipSlice[7] = 1
ipCountry2, ok := netip.AddrFromSlice(ipSlice)
require.True(b, ok)
@@ -142,7 +167,7 @@ var fileSink *geoip.File
func BenchmarkNewFile(b *testing.B) {
conf := &geoip.FileConfig{
ASNPath: asnPath,
- CountryPath: countryPath,
+ CountryPath: cityPath,
HostCacheSize: 0,
IPCacheSize: 1,
}
diff --git a/internal/geoip/geoip_test.go b/internal/geoip/geoip_test.go
index b6d2d50..2768867 100644
--- a/internal/geoip/geoip_test.go
+++ b/internal/geoip/geoip_test.go
@@ -15,6 +15,7 @@ func TestMain(m *testing.M) {
// Paths to test data.
const (
asnPath = "./testdata/GeoLite2-ASN-Test.mmdb"
+ cityPath = "./testdata/GeoIP2-City-Test.mmdb"
countryPath = "./testdata/GeoIP2-Country-Test.mmdb"
)
@@ -24,12 +25,16 @@ const (
testOtherHost = "other.example.com"
)
-// Test data. See https://github.com/maxmind/MaxMind-DB/blob/2bf1713b3b5adcb022cf4bb77eb0689beaadcfef/source-data/GeoLite2-ASN-Test.json
-// and https://github.com/maxmind/MaxMind-DB/blob/2bf1713b3b5adcb022cf4bb77eb0689beaadcfef/source-data/GeoIP2-Country-Test.json.
+// Test data. See the [ASN], [city], and [country] test data.
+//
+// [ASN]: https://github.com/maxmind/MaxMind-DB/blob/2bf1713b3b5adcb022cf4bb77eb0689beaadcfef/source-data/GeoLite2-ASN-Test.json
+// [city]: https://github.com/maxmind/MaxMind-DB/blob/2bf1713b3b5adcb022cf4bb77eb0689beaadcfef/source-data/GeoIP2-City-Test.json
+// [country]: https://github.com/maxmind/MaxMind-DB/blob/2bf1713b3b5adcb022cf4bb77eb0689beaadcfef/source-data/GeoIP2-Country-Test.json
const (
- testASN agd.ASN = 1221
- testCtry agd.Country = agd.CountryJP
- testCont agd.Continent = agd.ContinentAS
+ testASN agd.ASN = 1221
+ testCtry agd.Country = agd.CountryUS
+ testCont agd.Continent = agd.ContinentNA
+ testSubdiv string = "WA"
testIPv4SubnetCtry = agd.CountryUS
testIPv6SubnetCtry = agd.CountryJP
@@ -38,7 +43,13 @@ const (
// testIPWithASN has ASN set to 1221 in the test database.
var testIPWithASN = netip.MustParseAddr("1.128.0.0")
-// testIPWithCountry has country set to Japan in the test database.
+// testIPWithSubdiv has country set to USA and the subdivision set to Washington
+// in the city-aware test database. It has no subdivision in the country-aware
+// test database but resolves to the USA as well.
+var testIPWithSubdiv = netip.MustParseAddr("216.160.83.56")
+
+// testIPWithCountry has country set to Japan in the country-aware test
+// database.
var testIPWithCountry = netip.MustParseAddr("2001:218::")
// Subnets for CountrySubnet tests.
diff --git a/internal/geoip/testdata/GeoIP2-City-Test.mmdb b/internal/geoip/testdata/GeoIP2-City-Test.mmdb
new file mode 100644
index 0000000..43ab5ed
Binary files /dev/null and b/internal/geoip/testdata/GeoIP2-City-Test.mmdb differ
diff --git a/internal/metrics/datastorage.go b/internal/metrics/backend.go
similarity index 74%
rename from internal/metrics/datastorage.go
rename to internal/metrics/backend.go
index 8874d52..87761f4 100644
--- a/internal/metrics/datastorage.go
+++ b/internal/metrics/backend.go
@@ -5,6 +5,24 @@ import (
"github.com/prometheus/client_golang/prometheus/promauto"
)
+// DevicesCountGauge is a gauge with the total number of user devices loaded
+// from the backend.
+var DevicesCountGauge = promauto.NewGauge(prometheus.GaugeOpts{
+ Name: "devices_total",
+ Subsystem: subsystemBackend,
+ Namespace: namespace,
+ Help: "The total number of user devices loaded from the backend.",
+})
+
+// DevicesNewCountGauge is a gauge with the number of user devices that were
+// changed or added during the last sync.
+var DevicesNewCountGauge = promauto.NewGauge(prometheus.GaugeOpts{
+ Name: "devices_newly_synced_total",
+ Subsystem: subsystemBackend,
+ Namespace: namespace,
+ Help: "The number of user devices that were changed or added since the previous sync.",
+})
+
// ProfilesCountGauge is a gauge with the total number of user profiles loaded
// from the backend.
var ProfilesCountGauge = promauto.NewGauge(prometheus.GaugeOpts{
diff --git a/internal/metrics/connlimiter.go b/internal/metrics/connlimiter.go
new file mode 100644
index 0000000..154d8d1
--- /dev/null
+++ b/internal/metrics/connlimiter.go
@@ -0,0 +1,45 @@
+package metrics
+
+import (
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+)
+
+// ConnLimiterLimits is the gauge vector for showing the configured limits of
+// the number of active stream-connections.
+var ConnLimiterLimits = promauto.NewGaugeVec(prometheus.GaugeOpts{
+ Name: "limits",
+ Namespace: namespace,
+ Subsystem: subsystemConnLimiter,
+ Help: `The current limits of the number of active stream-connections: ` +
+ `kind="stop" for the stopping limit and kind="resume" for the resuming one.`,
+}, []string{"kind"})
+
+// ConnLimiterActiveStreamConns is the gauge vector for the number of active
+// stream-connections.
+var ConnLimiterActiveStreamConns = promauto.NewGaugeVec(prometheus.GaugeOpts{
+ Name: "active_stream_conns",
+ Namespace: namespace,
+ Subsystem: subsystemConnLimiter,
+ Help: `The number of currently active stream-connections.`,
+}, []string{"name", "proto", "addr"})
+
+// StreamConnWaitDuration is a histogram with the duration of waiting times for
+// accepting stream connections.
+var StreamConnWaitDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
+ Name: "stream_conn_wait_duration_seconds",
+ Subsystem: subsystemConnLimiter,
+ Namespace: namespace,
+ Help: "How long a stream connection waits for an accept, in seconds.",
+ Buckets: []float64{0.00001, 0.01, 0.1, 1, 10, 30, 60},
+}, []string{"name", "proto", "addr"})
+
+// StreamConnLifeDuration is a histogram with the lifetime durations of
+// stream connections.
+var StreamConnLifeDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
+ Name: "stream_conn_life_duration_seconds",
+ Subsystem: subsystemConnLimiter,
+ Namespace: namespace,
+ Help: "How long a stream connection lives, in seconds.",
+ Buckets: []float64{0.1, 1, 5, 10, 30, 60},
+}, []string{"name", "proto", "addr"})
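For context on how these vectors are meant to be used, here is a hypothetical reporting helper, not the actual connection-limiter implementation; the package name, listener labels, and limit values are assumptions, and only the metric variables come from the file above:

```go
package connlimiter

import (
	"time"

	"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
)

// reportConn is a hypothetical helper that updates the connection-limiter
// metrics for one stream connection that started waiting at start, was
// accepted at accepted, and was closed at closed.
func reportConn(start, accepted, closed time.Time) {
	// Mirror the configured stop and resume limits; the values are assumed.
	metrics.ConnLimiterLimits.WithLabelValues("stop").Set(1000)
	metrics.ConnLimiterLimits.WithLabelValues("resume").Set(800)

	// Per-listener labels; assumed values.
	name, proto, addr := "default_dns", "tcp", "127.0.0.1:853"

	active := metrics.ConnLimiterActiveStreamConns.WithLabelValues(name, proto, addr)
	active.Inc()
	defer active.Dec()

	// How long the connection waited for an accept and how long it lived
	// afterwards.
	metrics.StreamConnWaitDuration.WithLabelValues(name, proto, addr).
		Observe(accepted.Sub(start).Seconds())
	metrics.StreamConnLifeDuration.WithLabelValues(name, proto, addr).
		Observe(closed.Sub(accepted).Seconds())
}
```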
diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go
index 5371034..012d721 100644
--- a/internal/metrics/metrics.go
+++ b/internal/metrics/metrics.go
@@ -1,5 +1,6 @@
// Package metrics contains definitions of most of the prometheus metrics
// that we use in AdGuard DNS.
+//
// TODO(ameshkov): consider not using promauto.
package metrics
@@ -16,6 +17,7 @@ const (
subsystemApplication = "app"
subsystemBackend = "backend"
subsystemBillStat = "billstat"
+ subsystemConnLimiter = "connlimiter"
subsystemConsul = "consul"
subsystemDNSCheck = "dnscheck"
subsystemDNSDB = "dnsdb"
@@ -30,8 +32,8 @@ const (
subsystemWebSvc = "websvc"
)
-// SetUpGauge signals that the server has been started.
-// We're using a function here to avoid circular dependencies.
+// SetUpGauge signals that the server has been started. Use a function here to
+// avoid circular dependencies.
func SetUpGauge(version, buildtime, branch, revision, goversion string) {
upGauge := promauto.NewGauge(
prometheus.GaugeOpts{
diff --git a/internal/metrics/research.go b/internal/metrics/research.go
index 50646a6..5c4b34e 100644
--- a/internal/metrics/research.go
+++ b/internal/metrics/research.go
@@ -1,8 +1,10 @@
package metrics
import (
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
+ "github.com/prometheus/common/model"
)
// ResearchRequestsPerCountryTotal counts the total number of queries per
@@ -23,39 +25,80 @@ var ResearchBlockedRequestsPerCountryTotal = promauto.NewCounterVec(prometheus.C
Help: "The number of blocked DNS queries per country from anonymous users.",
}, []string{"filter", "country"})
+// ResearchRequestsPerSubdivTotal counts the total number of queries per
+// country and top subdivision from anonymous users.
+var ResearchRequestsPerSubdivTotal = promauto.NewCounterVec(prometheus.CounterOpts{
+ Name: "requests_per_subdivision_total",
+ Namespace: namespace,
+ Subsystem: subsystemResearch,
+	Help: `The total number of DNS queries per country and top ` +
+		`subdivision from anonymous users.`,
+}, []string{"country", "subdivision"})
+
+// ResearchBlockedRequestsPerSubdivTotal counts the number of blocked queries
+// per country and top subdivision from anonymous users.
+var ResearchBlockedRequestsPerSubdivTotal = promauto.NewCounterVec(prometheus.CounterOpts{
+ Name: "blocked_per_subdivision_total",
+ Namespace: namespace,
+ Subsystem: subsystemResearch,
+	Help: `The number of blocked DNS queries per country and top ` +
+		`subdivision from anonymous users.`,
+}, []string{"filter", "country", "subdivision"})
+
// ReportResearchMetrics reports metrics to prometheus that we may need to
// conduct researches.
-//
-// TODO(ameshkov): use [agd.Profile] arg when recursive dependency is resolved.
func ReportResearchMetrics(
- anonymous bool,
- filteringEnabled bool,
- asn string,
- ctry string,
- filterID string,
+ ri *agd.RequestInfo,
+ filterID agd.FilterListID,
blocked bool,
) {
- // The current research metrics only count queries that come to public
- // DNS servers where filtering is enabled.
- if !filteringEnabled || !anonymous {
+ filteringEnabled := ri.FilteringGroup != nil &&
+ ri.FilteringGroup.RuleListsEnabled &&
+ len(ri.FilteringGroup.RuleListIDs) > 0
+
+ // The current research metrics only count queries that come to public DNS
+ // servers where filtering is enabled.
+ if !filteringEnabled || ri.Profile != nil {
return
}
- // Ignore AdGuard ASN specifically in order to avoid counting queries that
- // come from the monitoring. This part is ugly, but since these metrics
- // are a one-time deal, this is acceptable.
- //
- // TODO(ameshkov): think of a better way later if we need to do that again.
- if asn == "212772" {
- return
+ var ctry, subdiv string
+ if l := ri.Location; l != nil {
+ // Ignore AdGuard ASN specifically in order to avoid counting queries
+ // that come from the monitoring. This part is ugly, but since these
+ // metrics are a one-time deal, this is acceptable.
+ //
+ // TODO(ameshkov): Think of a better way later if we need to do that
+ // again.
+ if l.ASN == 212772 {
+ return
+ }
+
+ ctry = string(l.Country)
+ if model.LabelValue(l.TopSubdivision).IsValid() {
+ subdiv = l.TopSubdivision
+ }
}
if blocked {
- ResearchBlockedRequestsPerCountryTotal.WithLabelValues(
- filterID,
- ctry,
- ).Inc()
+ reportResearchBlocked(string(filterID), ctry, subdiv)
}
- ResearchRequestsPerCountryTotal.WithLabelValues(ctry).Inc()
+ reportResearchRequest(ctry, subdiv)
+}
+
+// reportResearchBlocked reports on a blocked request to the research metrics.
+func reportResearchBlocked(fltID, ctry, subdiv string) {
+ ResearchBlockedRequestsPerCountryTotal.WithLabelValues(fltID, ctry).Inc()
+ if subdiv != "" {
+ ResearchBlockedRequestsPerSubdivTotal.WithLabelValues(fltID, ctry, subdiv).Inc()
+ }
+}
+
+// reportResearchRequest reports on a request to the research metrics.
+func reportResearchRequest(ctry, subdiv string) {
+ ResearchRequestsPerCountryTotal.WithLabelValues(ctry).Inc()
+ if subdiv != "" {
+ ResearchRequestsPerSubdivTotal.WithLabelValues(ctry, subdiv).Inc()
+ }
}
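The subdivision label above is guarded with `model.LabelValue.IsValid`, which only checks that the value is valid UTF-8. A tiny self-contained illustration of that guard (the byte values are made up):

```go
package main

import (
	"fmt"

	"github.com/prometheus/common/model"
)

func main() {
	// A normal ISO subdivision code passes the check and becomes the
	// "subdivision" label value.
	fmt.Println(model.LabelValue("WA").IsValid()) // true

	// An invalid UTF-8 sequence fails it, so the per-subdivision counters
	// are simply not incremented for such data.
	fmt.Println(model.LabelValue("\xC0\xC1").IsValid()) // false
}
```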
diff --git a/internal/metrics/tls.go b/internal/metrics/tls.go
index b619b9c..0f6baa7 100644
--- a/internal/metrics/tls.go
+++ b/internal/metrics/tls.go
@@ -93,6 +93,8 @@ func TLSMetricsAfterHandshake(
tlsVersionToString(state.Version),
BoolString(state.DidResume),
tls.CipherSuiteName(state.CipherSuite),
+		// Don't validate the negotiated protocol, since after a successful
+		// handshake it's expected to contain only ASCII.
state.NegotiatedProtocol,
sLabel,
).Inc()
@@ -112,12 +114,17 @@ func TLSMetricsBeforeHandshake(proto string) (f func(*tls.ClientHelloInfo) (*tls
}
}
+ supProtos := make([]string, len(info.SupportedProtos))
+ for i := range info.SupportedProtos {
+ supProtos[i] = strings.ToValidUTF8(info.SupportedProtos[i], "")
+ }
+
// Stick to using WithLabelValues instead of With in order to avoid
// extra allocations on prometheus.Labels. The labels order is VERY
// important here.
TLSHandshakeAttemptsTotal.WithLabelValues(
proto,
- strings.Join(info.SupportedProtos, ","),
+ strings.Join(supProtos, ","),
tlsVersionToString(maxVersion),
).Inc()
@@ -127,7 +134,6 @@ func TLSMetricsBeforeHandshake(proto string) (f func(*tls.ClientHelloInfo) (*tls
// tlsVersionToString converts TLS version to string.
func tlsVersionToString(ver uint16) (tlsVersion string) {
- tlsVersion = "unknown"
switch ver {
case tls.VersionTLS13:
tlsVersion = "tls1.3"
@@ -137,7 +143,10 @@ func tlsVersionToString(ver uint16) (tlsVersion string) {
tlsVersion = "tls1.1"
case tls.VersionTLS10:
tlsVersion = "tls1.0"
+ default:
+ tlsVersion = "unknown"
}
+
return tlsVersion
}
@@ -151,8 +160,8 @@ func serverNameToLabel(
srvCerts []tls.Certificate,
) (label string) {
if sni == "" {
- // SNI is not provided, so the request is probably made on the
- // IP address.
+ // SNI is not provided, so the request is probably made on the IP
+ // address.
return fmt.Sprintf("%s: other", srvName)
}
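The offered ALPN values are now passed through `strings.ToValidUTF8` before being joined into a label, which is what the new test below exercises with deliberately invalid bytes. A minimal illustration with a made-up protocol string:

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	// Runs of invalid UTF-8 bytes are replaced with the empty string, so the
	// resulting Prometheus label value is always valid UTF-8.
	proto := "\xC0\xC1h2"
	fmt.Printf("%q\n", strings.ToValidUTF8(proto, "")) // "h2"
}
```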
diff --git a/internal/metrics/tls_test.go b/internal/metrics/tls_test.go
index e3ccda7..9fd28f7 100644
--- a/internal/metrics/tls_test.go
+++ b/internal/metrics/tls_test.go
@@ -7,7 +7,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/prometheus/client_golang/prometheus"
- "github.com/prometheus/client_model/go"
+ io_prometheus_client "github.com/prometheus/client_model/go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -135,3 +135,18 @@ outerLoop:
return assert.Truef(t, ok, "%s not found in server name labels", wantLabel)
}
+
+func TestTLSMetricsBeforeHandshake(t *testing.T) {
+ f := metrics.TLSMetricsBeforeHandshake("srv-name")
+
+ var conf *tls.Config
+ var err error
+ require.NotPanics(t, func() {
+ conf, err = f(&tls.ClientHelloInfo{
+ SupportedProtos: []string{"\xC0\xC1\xF5\xF6\xF7\xF8\xF9\xFA\xFB\xFC\xFD\xFE\xFF"},
+ })
+ })
+ require.NoError(t, err)
+
+ assert.Nil(t, conf)
+}
diff --git a/internal/profiledb/error.go b/internal/profiledb/error.go
new file mode 100644
index 0000000..40814ec
--- /dev/null
+++ b/internal/profiledb/error.go
@@ -0,0 +1,7 @@
+package profiledb
+
+import "github.com/AdguardTeam/golibs/errors"
+
+// ErrDeviceNotFound is an error returned by lookup methods when a device
+// couldn't be found.
+const ErrDeviceNotFound errors.Error = "device not found"
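A hedged usage sketch of the new sentinel: the lookup function below is a stand-in invented for illustration, and only `profiledb.ErrDeviceNotFound` itself comes from the file above.

```go
package main

import (
	"errors"
	"fmt"

	"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
)

// lookupDevice is a stand-in for a real profile-database lookup; it always
// fails with the sentinel error to keep the sketch self-contained.
func lookupDevice() (err error) {
	return fmt.Errorf("looking up device: %w", profiledb.ErrDeviceNotFound)
}

func main() {
	err := lookupDevice()
	if errors.Is(err, profiledb.ErrDeviceNotFound) {
		// The device is unknown; callers can treat the query as anonymous
		// instead of failing it.
		fmt.Println("device not found, treating the query as anonymous")
	}
}
```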
diff --git a/internal/profiledb/internal/filecachejson/filecachejson.go b/internal/profiledb/internal/filecachejson/filecachejson.go
new file mode 100644
index 0000000..3c3110e
--- /dev/null
+++ b/internal/profiledb/internal/filecachejson/filecachejson.go
@@ -0,0 +1,117 @@
+// Package filecachejson contains an implementation of the file-cache storage
+// that encodes data using JSON.
+package filecachejson
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal"
+ "github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/log"
+ "github.com/google/renameio"
+)
+
+// Storage is the file-cache storage that encodes data using JSON.
+type Storage struct {
+ path string
+}
+
+// New returns a new JSON-encoded file-cache storage.
+func New(cachePath string) (s *Storage) {
+ return &Storage{
+ path: cachePath,
+ }
+}
+
+// fileCache is the structure for the JSON filesystem cache of a profile
+// database.
+//
+// NOTE: Do not change fields of this structure without incrementing
+// [internal.FileCacheVersion].
+type fileCache struct {
+ SyncTime time.Time `json:"sync_time"`
+ Profiles []*agd.Profile `json:"profiles"`
+ Devices []*agd.Device `json:"devices"`
+ Version int32 `json:"version"`
+}
+
+// logPrefix is the logging prefix for the JSON-encoded file-cache.
+const logPrefix = "profiledb json cache"
+
+var _ internal.FileCacheStorage = (*Storage)(nil)
+
+// Load implements the [internal.FileCacheStorage] interface for *Storage.
+func (s *Storage) Load() (c *internal.FileCache, err error) {
+ log.Info("%s: loading", logPrefix)
+
+ data, err := s.loadFromFile()
+ if err != nil {
+ return nil, fmt.Errorf("loading from file: %w", err)
+ }
+
+ if data == nil {
+ log.Info("%s: file not present", logPrefix)
+
+ return nil, nil
+ }
+
+ return &internal.FileCache{
+ SyncTime: data.SyncTime,
+ Profiles: data.Profiles,
+ Devices: data.Devices,
+ Version: data.Version,
+ }, nil
+}
+
+// loadFromFile loads the profile data from the cache file.
+func (s *Storage) loadFromFile() (data *fileCache, err error) {
+ file, err := os.Open(s.path)
+ if err != nil {
+ if errors.Is(err, os.ErrNotExist) {
+ // File could be deleted or not yet created, go on.
+ return nil, nil
+ }
+
+ return nil, err
+ }
+ defer func() { err = errors.WithDeferred(err, file.Close()) }()
+
+ data = &fileCache{}
+ err = json.NewDecoder(file).Decode(data)
+ if err != nil {
+ return nil, fmt.Errorf("decoding json: %w", err)
+ }
+
+ return data, nil
+}
+
+// Store implements the [internal.FileCacheStorage] interface for *Storage.
+func (s *Storage) Store(c *internal.FileCache) (err error) {
+ profNum := len(c.Profiles)
+ log.Info("%s: saving %d profiles to %q", logPrefix, profNum, s.path)
+ defer log.Info("%s: saved %d profiles to %q", logPrefix, profNum, s.path)
+
+ data := &fileCache{
+ SyncTime: c.SyncTime,
+ Profiles: c.Profiles,
+ Devices: c.Devices,
+ Version: c.Version,
+ }
+
+ cache, err := json.Marshal(data)
+ if err != nil {
+ return fmt.Errorf("encoding json: %w", err)
+ }
+
+ err = renameio.WriteFile(s.path, cache, 0o600)
+ if err != nil {
+ // Don't wrap the error, because it's informative enough as is.
+ return err
+ }
+
+ return nil
+}
diff --git a/internal/profiledb/internal/filecachejson/filecachejson_test.go b/internal/profiledb/internal/filecachejson/filecachejson_test.go
new file mode 100644
index 0000000..7c017f1
--- /dev/null
+++ b/internal/profiledb/internal/filecachejson/filecachejson_test.go
@@ -0,0 +1,53 @@
+package filecachejson_test
+
+import (
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal/filecachejson"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal/profiledbtest"
+ "github.com/AdguardTeam/golibs/testutil"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestMain(m *testing.M) {
+ testutil.DiscardLogOutput(m)
+}
+
+func TestStorage(t *testing.T) {
+ prof, dev := profiledbtest.NewProfile(t)
+ cachePath := filepath.Join(t.TempDir(), "profiles.json")
+ s := filecachejson.New(cachePath)
+ require.NotNil(t, s)
+
+ fc := &internal.FileCache{
+ SyncTime: time.Now().Round(0).UTC(),
+ Profiles: []*agd.Profile{prof},
+ Devices: []*agd.Device{dev},
+ Version: internal.FileCacheVersion,
+ }
+
+ err := s.Store(fc)
+ require.NoError(t, err)
+
+ gotFC, err := s.Load()
+ require.NoError(t, err)
+ require.NotNil(t, gotFC)
+ require.NotEmpty(t, *gotFC)
+
+ assert.Equal(t, fc, gotFC)
+}
+
+func TestStorage_Load_noFile(t *testing.T) {
+ cachePath := filepath.Join(t.TempDir(), "profiles.json")
+ s := filecachejson.New(cachePath)
+ require.NotNil(t, s)
+
+ fc, err := s.Load()
+ assert.NoError(t, err)
+ assert.Nil(t, fc)
+}
diff --git a/internal/profiledb/internal/filecachepb/filecache.pb.go b/internal/profiledb/internal/filecachepb/filecache.pb.go
new file mode 100644
index 0000000..2c72dfc
--- /dev/null
+++ b/internal/profiledb/internal/filecachepb/filecache.pb.go
@@ -0,0 +1,1165 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// versions:
+// protoc-gen-go v1.30.0
+// protoc v4.22.3
+// source: filecache.proto
+
+package filecachepb
+
+import (
+ protoreflect "google.golang.org/protobuf/reflect/protoreflect"
+ protoimpl "google.golang.org/protobuf/runtime/protoimpl"
+ durationpb "google.golang.org/protobuf/types/known/durationpb"
+ timestamppb "google.golang.org/protobuf/types/known/timestamppb"
+ reflect "reflect"
+ sync "sync"
+)
+
+const (
+ // Verify that this generated code is sufficiently up-to-date.
+ _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
+ // Verify that runtime/protoimpl is sufficiently up-to-date.
+ _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
+)
+
+type FileCache struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ SyncTime *timestamppb.Timestamp `protobuf:"bytes,1,opt,name=sync_time,json=syncTime,proto3" json:"sync_time,omitempty"`
+ Profiles []*Profile `protobuf:"bytes,2,rep,name=profiles,proto3" json:"profiles,omitempty"`
+ Devices []*Device `protobuf:"bytes,3,rep,name=devices,proto3" json:"devices,omitempty"`
+ Version int32 `protobuf:"varint,4,opt,name=version,proto3" json:"version,omitempty"`
+}
+
+func (x *FileCache) Reset() {
+ *x = FileCache{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_filecache_proto_msgTypes[0]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *FileCache) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*FileCache) ProtoMessage() {}
+
+func (x *FileCache) ProtoReflect() protoreflect.Message {
+ mi := &file_filecache_proto_msgTypes[0]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use FileCache.ProtoReflect.Descriptor instead.
+func (*FileCache) Descriptor() ([]byte, []int) {
+ return file_filecache_proto_rawDescGZIP(), []int{0}
+}
+
+func (x *FileCache) GetSyncTime() *timestamppb.Timestamp {
+ if x != nil {
+ return x.SyncTime
+ }
+ return nil
+}
+
+func (x *FileCache) GetProfiles() []*Profile {
+ if x != nil {
+ return x.Profiles
+ }
+ return nil
+}
+
+func (x *FileCache) GetDevices() []*Device {
+ if x != nil {
+ return x.Devices
+ }
+ return nil
+}
+
+func (x *FileCache) GetVersion() int32 {
+ if x != nil {
+ return x.Version
+ }
+ return 0
+}
+
+type Profile struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ Parental *ParentalProtectionSettings `protobuf:"bytes,1,opt,name=parental,proto3" json:"parental,omitempty"`
+ // Types that are assignable to BlockingMode:
+ //
+ // *Profile_BlockingModeCustomIp
+ // *Profile_BlockingModeNxdomain
+ // *Profile_BlockingModeNullIp
+ // *Profile_BlockingModeRefused
+ BlockingMode isProfile_BlockingMode `protobuf_oneof:"blocking_mode"`
+ ProfileId string `protobuf:"bytes,6,opt,name=profile_id,json=profileId,proto3" json:"profile_id,omitempty"`
+ UpdateTime *timestamppb.Timestamp `protobuf:"bytes,7,opt,name=update_time,json=updateTime,proto3" json:"update_time,omitempty"`
+ DeviceIds []string `protobuf:"bytes,8,rep,name=device_ids,json=deviceIds,proto3" json:"device_ids,omitempty"`
+ RuleListIds []string `protobuf:"bytes,9,rep,name=rule_list_ids,json=ruleListIds,proto3" json:"rule_list_ids,omitempty"`
+ CustomRules []string `protobuf:"bytes,10,rep,name=custom_rules,json=customRules,proto3" json:"custom_rules,omitempty"`
+ FilteredResponseTtl *durationpb.Duration `protobuf:"bytes,11,opt,name=filtered_response_ttl,json=filteredResponseTtl,proto3" json:"filtered_response_ttl,omitempty"`
+ FilteringEnabled bool `protobuf:"varint,12,opt,name=filtering_enabled,json=filteringEnabled,proto3" json:"filtering_enabled,omitempty"`
+ SafeBrowsingEnabled bool `protobuf:"varint,13,opt,name=safe_browsing_enabled,json=safeBrowsingEnabled,proto3" json:"safe_browsing_enabled,omitempty"`
+ RuleListsEnabled bool `protobuf:"varint,14,opt,name=rule_lists_enabled,json=ruleListsEnabled,proto3" json:"rule_lists_enabled,omitempty"`
+ QueryLogEnabled bool `protobuf:"varint,15,opt,name=query_log_enabled,json=queryLogEnabled,proto3" json:"query_log_enabled,omitempty"`
+ Deleted bool `protobuf:"varint,16,opt,name=deleted,proto3" json:"deleted,omitempty"`
+ BlockPrivateRelay bool `protobuf:"varint,17,opt,name=block_private_relay,json=blockPrivateRelay,proto3" json:"block_private_relay,omitempty"`
+ BlockFirefoxCanary bool `protobuf:"varint,18,opt,name=block_firefox_canary,json=blockFirefoxCanary,proto3" json:"block_firefox_canary,omitempty"`
+}
+
+func (x *Profile) Reset() {
+ *x = Profile{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_filecache_proto_msgTypes[1]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *Profile) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*Profile) ProtoMessage() {}
+
+func (x *Profile) ProtoReflect() protoreflect.Message {
+ mi := &file_filecache_proto_msgTypes[1]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use Profile.ProtoReflect.Descriptor instead.
+func (*Profile) Descriptor() ([]byte, []int) {
+ return file_filecache_proto_rawDescGZIP(), []int{1}
+}
+
+func (x *Profile) GetParental() *ParentalProtectionSettings {
+ if x != nil {
+ return x.Parental
+ }
+ return nil
+}
+
+func (m *Profile) GetBlockingMode() isProfile_BlockingMode {
+ if m != nil {
+ return m.BlockingMode
+ }
+ return nil
+}
+
+func (x *Profile) GetBlockingModeCustomIp() *BlockingModeCustomIP {
+ if x, ok := x.GetBlockingMode().(*Profile_BlockingModeCustomIp); ok {
+ return x.BlockingModeCustomIp
+ }
+ return nil
+}
+
+func (x *Profile) GetBlockingModeNxdomain() *BlockingModeNXDOMAIN {
+ if x, ok := x.GetBlockingMode().(*Profile_BlockingModeNxdomain); ok {
+ return x.BlockingModeNxdomain
+ }
+ return nil
+}
+
+func (x *Profile) GetBlockingModeNullIp() *BlockingModeNullIP {
+ if x, ok := x.GetBlockingMode().(*Profile_BlockingModeNullIp); ok {
+ return x.BlockingModeNullIp
+ }
+ return nil
+}
+
+func (x *Profile) GetBlockingModeRefused() *BlockingModeREFUSED {
+ if x, ok := x.GetBlockingMode().(*Profile_BlockingModeRefused); ok {
+ return x.BlockingModeRefused
+ }
+ return nil
+}
+
+func (x *Profile) GetProfileId() string {
+ if x != nil {
+ return x.ProfileId
+ }
+ return ""
+}
+
+func (x *Profile) GetUpdateTime() *timestamppb.Timestamp {
+ if x != nil {
+ return x.UpdateTime
+ }
+ return nil
+}
+
+func (x *Profile) GetDeviceIds() []string {
+ if x != nil {
+ return x.DeviceIds
+ }
+ return nil
+}
+
+func (x *Profile) GetRuleListIds() []string {
+ if x != nil {
+ return x.RuleListIds
+ }
+ return nil
+}
+
+func (x *Profile) GetCustomRules() []string {
+ if x != nil {
+ return x.CustomRules
+ }
+ return nil
+}
+
+func (x *Profile) GetFilteredResponseTtl() *durationpb.Duration {
+ if x != nil {
+ return x.FilteredResponseTtl
+ }
+ return nil
+}
+
+func (x *Profile) GetFilteringEnabled() bool {
+ if x != nil {
+ return x.FilteringEnabled
+ }
+ return false
+}
+
+func (x *Profile) GetSafeBrowsingEnabled() bool {
+ if x != nil {
+ return x.SafeBrowsingEnabled
+ }
+ return false
+}
+
+func (x *Profile) GetRuleListsEnabled() bool {
+ if x != nil {
+ return x.RuleListsEnabled
+ }
+ return false
+}
+
+func (x *Profile) GetQueryLogEnabled() bool {
+ if x != nil {
+ return x.QueryLogEnabled
+ }
+ return false
+}
+
+func (x *Profile) GetDeleted() bool {
+ if x != nil {
+ return x.Deleted
+ }
+ return false
+}
+
+func (x *Profile) GetBlockPrivateRelay() bool {
+ if x != nil {
+ return x.BlockPrivateRelay
+ }
+ return false
+}
+
+func (x *Profile) GetBlockFirefoxCanary() bool {
+ if x != nil {
+ return x.BlockFirefoxCanary
+ }
+ return false
+}
+
+type isProfile_BlockingMode interface {
+ isProfile_BlockingMode()
+}
+
+type Profile_BlockingModeCustomIp struct {
+ BlockingModeCustomIp *BlockingModeCustomIP `protobuf:"bytes,2,opt,name=blocking_mode_custom_ip,json=blockingModeCustomIp,proto3,oneof"`
+}
+
+type Profile_BlockingModeNxdomain struct {
+ BlockingModeNxdomain *BlockingModeNXDOMAIN `protobuf:"bytes,3,opt,name=blocking_mode_nxdomain,json=blockingModeNxdomain,proto3,oneof"`
+}
+
+type Profile_BlockingModeNullIp struct {
+ BlockingModeNullIp *BlockingModeNullIP `protobuf:"bytes,4,opt,name=blocking_mode_null_ip,json=blockingModeNullIp,proto3,oneof"`
+}
+
+type Profile_BlockingModeRefused struct {
+ BlockingModeRefused *BlockingModeREFUSED `protobuf:"bytes,5,opt,name=blocking_mode_refused,json=blockingModeRefused,proto3,oneof"`
+}
+
+func (*Profile_BlockingModeCustomIp) isProfile_BlockingMode() {}
+
+func (*Profile_BlockingModeNxdomain) isProfile_BlockingMode() {}
+
+func (*Profile_BlockingModeNullIp) isProfile_BlockingMode() {}
+
+func (*Profile_BlockingModeRefused) isProfile_BlockingMode() {}
+
+type ParentalProtectionSettings struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ Schedule *ParentalProtectionSchedule `protobuf:"bytes,1,opt,name=schedule,proto3" json:"schedule,omitempty"`
+ BlockedServices []string `protobuf:"bytes,2,rep,name=blocked_services,json=blockedServices,proto3" json:"blocked_services,omitempty"`
+ Enabled bool `protobuf:"varint,3,opt,name=enabled,proto3" json:"enabled,omitempty"`
+ BlockAdult bool `protobuf:"varint,4,opt,name=block_adult,json=blockAdult,proto3" json:"block_adult,omitempty"`
+ GeneralSafeSearch bool `protobuf:"varint,5,opt,name=general_safe_search,json=generalSafeSearch,proto3" json:"general_safe_search,omitempty"`
+ YoutubeSafeSearch bool `protobuf:"varint,6,opt,name=youtube_safe_search,json=youtubeSafeSearch,proto3" json:"youtube_safe_search,omitempty"`
+}
+
+func (x *ParentalProtectionSettings) Reset() {
+ *x = ParentalProtectionSettings{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_filecache_proto_msgTypes[2]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *ParentalProtectionSettings) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*ParentalProtectionSettings) ProtoMessage() {}
+
+func (x *ParentalProtectionSettings) ProtoReflect() protoreflect.Message {
+ mi := &file_filecache_proto_msgTypes[2]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use ParentalProtectionSettings.ProtoReflect.Descriptor instead.
+func (*ParentalProtectionSettings) Descriptor() ([]byte, []int) {
+ return file_filecache_proto_rawDescGZIP(), []int{2}
+}
+
+func (x *ParentalProtectionSettings) GetSchedule() *ParentalProtectionSchedule {
+ if x != nil {
+ return x.Schedule
+ }
+ return nil
+}
+
+func (x *ParentalProtectionSettings) GetBlockedServices() []string {
+ if x != nil {
+ return x.BlockedServices
+ }
+ return nil
+}
+
+func (x *ParentalProtectionSettings) GetEnabled() bool {
+ if x != nil {
+ return x.Enabled
+ }
+ return false
+}
+
+func (x *ParentalProtectionSettings) GetBlockAdult() bool {
+ if x != nil {
+ return x.BlockAdult
+ }
+ return false
+}
+
+func (x *ParentalProtectionSettings) GetGeneralSafeSearch() bool {
+ if x != nil {
+ return x.GeneralSafeSearch
+ }
+ return false
+}
+
+func (x *ParentalProtectionSettings) GetYoutubeSafeSearch() bool {
+ if x != nil {
+ return x.YoutubeSafeSearch
+ }
+ return false
+}
+
+type ParentalProtectionSchedule struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ TimeZone string `protobuf:"bytes,1,opt,name=time_zone,json=timeZone,proto3" json:"time_zone,omitempty"`
+ Mon *DayRange `protobuf:"bytes,2,opt,name=mon,proto3" json:"mon,omitempty"`
+ Tue *DayRange `protobuf:"bytes,3,opt,name=tue,proto3" json:"tue,omitempty"`
+ Wed *DayRange `protobuf:"bytes,4,opt,name=wed,proto3" json:"wed,omitempty"`
+ Thu *DayRange `protobuf:"bytes,5,opt,name=thu,proto3" json:"thu,omitempty"`
+ Fri *DayRange `protobuf:"bytes,6,opt,name=fri,proto3" json:"fri,omitempty"`
+ Sat *DayRange `protobuf:"bytes,7,opt,name=sat,proto3" json:"sat,omitempty"`
+ Sun *DayRange `protobuf:"bytes,8,opt,name=sun,proto3" json:"sun,omitempty"`
+}
+
+func (x *ParentalProtectionSchedule) Reset() {
+ *x = ParentalProtectionSchedule{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_filecache_proto_msgTypes[3]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *ParentalProtectionSchedule) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*ParentalProtectionSchedule) ProtoMessage() {}
+
+func (x *ParentalProtectionSchedule) ProtoReflect() protoreflect.Message {
+ mi := &file_filecache_proto_msgTypes[3]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use ParentalProtectionSchedule.ProtoReflect.Descriptor instead.
+func (*ParentalProtectionSchedule) Descriptor() ([]byte, []int) {
+ return file_filecache_proto_rawDescGZIP(), []int{3}
+}
+
+func (x *ParentalProtectionSchedule) GetTimeZone() string {
+ if x != nil {
+ return x.TimeZone
+ }
+ return ""
+}
+
+func (x *ParentalProtectionSchedule) GetMon() *DayRange {
+ if x != nil {
+ return x.Mon
+ }
+ return nil
+}
+
+func (x *ParentalProtectionSchedule) GetTue() *DayRange {
+ if x != nil {
+ return x.Tue
+ }
+ return nil
+}
+
+func (x *ParentalProtectionSchedule) GetWed() *DayRange {
+ if x != nil {
+ return x.Wed
+ }
+ return nil
+}
+
+func (x *ParentalProtectionSchedule) GetThu() *DayRange {
+ if x != nil {
+ return x.Thu
+ }
+ return nil
+}
+
+func (x *ParentalProtectionSchedule) GetFri() *DayRange {
+ if x != nil {
+ return x.Fri
+ }
+ return nil
+}
+
+func (x *ParentalProtectionSchedule) GetSat() *DayRange {
+ if x != nil {
+ return x.Sat
+ }
+ return nil
+}
+
+func (x *ParentalProtectionSchedule) GetSun() *DayRange {
+ if x != nil {
+ return x.Sun
+ }
+ return nil
+}
+
+type DayRange struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
+ End uint32 `protobuf:"varint,2,opt,name=end,proto3" json:"end,omitempty"`
+}
+
+func (x *DayRange) Reset() {
+ *x = DayRange{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_filecache_proto_msgTypes[4]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *DayRange) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*DayRange) ProtoMessage() {}
+
+func (x *DayRange) ProtoReflect() protoreflect.Message {
+ mi := &file_filecache_proto_msgTypes[4]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use DayRange.ProtoReflect.Descriptor instead.
+func (*DayRange) Descriptor() ([]byte, []int) {
+ return file_filecache_proto_rawDescGZIP(), []int{4}
+}
+
+func (x *DayRange) GetStart() uint32 {
+ if x != nil {
+ return x.Start
+ }
+ return 0
+}
+
+func (x *DayRange) GetEnd() uint32 {
+ if x != nil {
+ return x.End
+ }
+ return 0
+}
+
+type BlockingModeCustomIP struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ Ipv4 []byte `protobuf:"bytes,1,opt,name=ipv4,proto3" json:"ipv4,omitempty"`
+ Ipv6 []byte `protobuf:"bytes,2,opt,name=ipv6,proto3" json:"ipv6,omitempty"`
+}
+
+func (x *BlockingModeCustomIP) Reset() {
+ *x = BlockingModeCustomIP{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_filecache_proto_msgTypes[5]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *BlockingModeCustomIP) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*BlockingModeCustomIP) ProtoMessage() {}
+
+func (x *BlockingModeCustomIP) ProtoReflect() protoreflect.Message {
+ mi := &file_filecache_proto_msgTypes[5]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use BlockingModeCustomIP.ProtoReflect.Descriptor instead.
+func (*BlockingModeCustomIP) Descriptor() ([]byte, []int) {
+ return file_filecache_proto_rawDescGZIP(), []int{5}
+}
+
+func (x *BlockingModeCustomIP) GetIpv4() []byte {
+ if x != nil {
+ return x.Ipv4
+ }
+ return nil
+}
+
+func (x *BlockingModeCustomIP) GetIpv6() []byte {
+ if x != nil {
+ return x.Ipv6
+ }
+ return nil
+}
+
+type BlockingModeNXDOMAIN struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+}
+
+func (x *BlockingModeNXDOMAIN) Reset() {
+ *x = BlockingModeNXDOMAIN{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_filecache_proto_msgTypes[6]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *BlockingModeNXDOMAIN) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*BlockingModeNXDOMAIN) ProtoMessage() {}
+
+func (x *BlockingModeNXDOMAIN) ProtoReflect() protoreflect.Message {
+ mi := &file_filecache_proto_msgTypes[6]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use BlockingModeNXDOMAIN.ProtoReflect.Descriptor instead.
+func (*BlockingModeNXDOMAIN) Descriptor() ([]byte, []int) {
+ return file_filecache_proto_rawDescGZIP(), []int{6}
+}
+
+type BlockingModeNullIP struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+}
+
+func (x *BlockingModeNullIP) Reset() {
+ *x = BlockingModeNullIP{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_filecache_proto_msgTypes[7]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *BlockingModeNullIP) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*BlockingModeNullIP) ProtoMessage() {}
+
+func (x *BlockingModeNullIP) ProtoReflect() protoreflect.Message {
+ mi := &file_filecache_proto_msgTypes[7]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use BlockingModeNullIP.ProtoReflect.Descriptor instead.
+func (*BlockingModeNullIP) Descriptor() ([]byte, []int) {
+ return file_filecache_proto_rawDescGZIP(), []int{7}
+}
+
+type BlockingModeREFUSED struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+}
+
+func (x *BlockingModeREFUSED) Reset() {
+ *x = BlockingModeREFUSED{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_filecache_proto_msgTypes[8]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *BlockingModeREFUSED) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*BlockingModeREFUSED) ProtoMessage() {}
+
+func (x *BlockingModeREFUSED) ProtoReflect() protoreflect.Message {
+ mi := &file_filecache_proto_msgTypes[8]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use BlockingModeREFUSED.ProtoReflect.Descriptor instead.
+func (*BlockingModeREFUSED) Descriptor() ([]byte, []int) {
+ return file_filecache_proto_rawDescGZIP(), []int{8}
+}
+
+type Device struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ DeviceId string `protobuf:"bytes,1,opt,name=device_id,json=deviceId,proto3" json:"device_id,omitempty"`
+ LinkedIp []byte `protobuf:"bytes,2,opt,name=linked_ip,json=linkedIp,proto3" json:"linked_ip,omitempty"`
+ DeviceName string `protobuf:"bytes,3,opt,name=device_name,json=deviceName,proto3" json:"device_name,omitempty"`
+ DedicatedIps [][]byte `protobuf:"bytes,4,rep,name=dedicated_ips,json=dedicatedIps,proto3" json:"dedicated_ips,omitempty"`
+ FilteringEnabled bool `protobuf:"varint,5,opt,name=filtering_enabled,json=filteringEnabled,proto3" json:"filtering_enabled,omitempty"`
+}
+
+func (x *Device) Reset() {
+ *x = Device{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_filecache_proto_msgTypes[9]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *Device) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*Device) ProtoMessage() {}
+
+func (x *Device) ProtoReflect() protoreflect.Message {
+ mi := &file_filecache_proto_msgTypes[9]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use Device.ProtoReflect.Descriptor instead.
+func (*Device) Descriptor() ([]byte, []int) {
+ return file_filecache_proto_rawDescGZIP(), []int{9}
+}
+
+func (x *Device) GetDeviceId() string {
+ if x != nil {
+ return x.DeviceId
+ }
+ return ""
+}
+
+func (x *Device) GetLinkedIp() []byte {
+ if x != nil {
+ return x.LinkedIp
+ }
+ return nil
+}
+
+func (x *Device) GetDeviceName() string {
+ if x != nil {
+ return x.DeviceName
+ }
+ return ""
+}
+
+func (x *Device) GetDedicatedIps() [][]byte {
+ if x != nil {
+ return x.DedicatedIps
+ }
+ return nil
+}
+
+func (x *Device) GetFilteringEnabled() bool {
+ if x != nil {
+ return x.FilteringEnabled
+ }
+ return false
+}
+
+var File_filecache_proto protoreflect.FileDescriptor
+
+var file_filecache_proto_rawDesc = []byte{
+ 0x0a, 0x0f, 0x66, 0x69, 0x6c, 0x65, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74,
+ 0x6f, 0x12, 0x09, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, 0x1a, 0x1e, 0x67, 0x6f,
+ 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75,
+ 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1f, 0x67, 0x6f,
+ 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69,
+ 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xbb, 0x01,
+ 0x0a, 0x09, 0x46, 0x69, 0x6c, 0x65, 0x43, 0x61, 0x63, 0x68, 0x65, 0x12, 0x37, 0x0a, 0x09, 0x73,
+ 0x79, 0x6e, 0x63, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a,
+ 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66,
+ 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x08, 0x73, 0x79, 0x6e, 0x63,
+ 0x54, 0x69, 0x6d, 0x65, 0x12, 0x2e, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x73,
+ 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65,
+ 0x64, 0x62, 0x2e, 0x50, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x66,
+ 0x69, 0x6c, 0x65, 0x73, 0x12, 0x2b, 0x0a, 0x07, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x73, 0x18,
+ 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64,
+ 0x62, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x52, 0x07, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65,
+ 0x73, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01,
+ 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x82, 0x08, 0x0a, 0x07,
+ 0x50, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x12, 0x41, 0x0a, 0x08, 0x70, 0x61, 0x72, 0x65, 0x6e,
+ 0x74, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x25, 0x2e, 0x70, 0x72, 0x6f, 0x66,
+ 0x69, 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x50, 0x61, 0x72, 0x65, 0x6e, 0x74, 0x61, 0x6c, 0x50, 0x72,
+ 0x6f, 0x74, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73,
+ 0x52, 0x08, 0x70, 0x61, 0x72, 0x65, 0x6e, 0x74, 0x61, 0x6c, 0x12, 0x58, 0x0a, 0x17, 0x62, 0x6c,
+ 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x5f, 0x63, 0x75, 0x73, 0x74,
+ 0x6f, 0x6d, 0x5f, 0x69, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x70, 0x72,
+ 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x42, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67,
+ 0x4d, 0x6f, 0x64, 0x65, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x49, 0x50, 0x48, 0x00, 0x52, 0x14,
+ 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x43, 0x75, 0x73, 0x74,
+ 0x6f, 0x6d, 0x49, 0x70, 0x12, 0x57, 0x0a, 0x16, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67,
+ 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x5f, 0x6e, 0x78, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03,
+ 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62,
+ 0x2e, 0x42, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x4e, 0x58, 0x44,
+ 0x4f, 0x4d, 0x41, 0x49, 0x4e, 0x48, 0x00, 0x52, 0x14, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e,
+ 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x4e, 0x78, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x52, 0x0a,
+ 0x15, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x5f, 0x6e,
+ 0x75, 0x6c, 0x6c, 0x5f, 0x69, 0x70, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x70,
+ 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x42, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e,
+ 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x4e, 0x75, 0x6c, 0x6c, 0x49, 0x50, 0x48, 0x00, 0x52, 0x12, 0x62,
+ 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x4e, 0x75, 0x6c, 0x6c, 0x49,
+ 0x70, 0x12, 0x54, 0x0a, 0x15, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x5f, 0x6d, 0x6f,
+ 0x64, 0x65, 0x5f, 0x72, 0x65, 0x66, 0x75, 0x73, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b,
+ 0x32, 0x1e, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x42, 0x6c, 0x6f,
+ 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x52, 0x45, 0x46, 0x55, 0x53, 0x45, 0x44,
+ 0x48, 0x00, 0x52, 0x13, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65,
+ 0x52, 0x65, 0x66, 0x75, 0x73, 0x65, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x6f, 0x66, 0x69,
+ 0x6c, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x70, 0x72, 0x6f,
+ 0x66, 0x69, 0x6c, 0x65, 0x49, 0x64, 0x12, 0x3b, 0x0a, 0x0b, 0x75, 0x70, 0x64, 0x61, 0x74, 0x65,
+ 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f,
+ 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69,
+ 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x0a, 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54,
+ 0x69, 0x6d, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64,
+ 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x09, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x49,
+ 0x64, 0x73, 0x12, 0x22, 0x0a, 0x0d, 0x72, 0x75, 0x6c, 0x65, 0x5f, 0x6c, 0x69, 0x73, 0x74, 0x5f,
+ 0x69, 0x64, 0x73, 0x18, 0x09, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0b, 0x72, 0x75, 0x6c, 0x65, 0x4c,
+ 0x69, 0x73, 0x74, 0x49, 0x64, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d,
+ 0x5f, 0x72, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0b, 0x63, 0x75,
+ 0x73, 0x74, 0x6f, 0x6d, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x4d, 0x0a, 0x15, 0x66, 0x69, 0x6c,
+ 0x74, 0x65, 0x72, 0x65, 0x64, 0x5f, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5f, 0x74,
+ 0x74, 0x6c, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c,
+ 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74,
+ 0x69, 0x6f, 0x6e, 0x52, 0x13, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x65, 0x64, 0x52, 0x65, 0x73,
+ 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x54, 0x74, 0x6c, 0x12, 0x2b, 0x0a, 0x11, 0x66, 0x69, 0x6c, 0x74,
+ 0x65, 0x72, 0x69, 0x6e, 0x67, 0x5f, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0c, 0x20,
+ 0x01, 0x28, 0x08, 0x52, 0x10, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x69, 0x6e, 0x67, 0x45, 0x6e,
+ 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x32, 0x0a, 0x15, 0x73, 0x61, 0x66, 0x65, 0x5f, 0x62, 0x72,
+ 0x6f, 0x77, 0x73, 0x69, 0x6e, 0x67, 0x5f, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0d,
+ 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x73, 0x61, 0x66, 0x65, 0x42, 0x72, 0x6f, 0x77, 0x73, 0x69,
+ 0x6e, 0x67, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x2c, 0x0a, 0x12, 0x72, 0x75, 0x6c,
+ 0x65, 0x5f, 0x6c, 0x69, 0x73, 0x74, 0x73, 0x5f, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18,
+ 0x0e, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x75, 0x6c, 0x65, 0x4c, 0x69, 0x73, 0x74, 0x73,
+ 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x2a, 0x0a, 0x11, 0x71, 0x75, 0x65, 0x72, 0x79,
+ 0x5f, 0x6c, 0x6f, 0x67, 0x5f, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, 0x01,
+ 0x28, 0x08, 0x52, 0x0f, 0x71, 0x75, 0x65, 0x72, 0x79, 0x4c, 0x6f, 0x67, 0x45, 0x6e, 0x61, 0x62,
+ 0x6c, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x18, 0x10,
+ 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x12, 0x2e, 0x0a,
+ 0x13, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x5f, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x5f, 0x72,
+ 0x65, 0x6c, 0x61, 0x79, 0x18, 0x11, 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x62, 0x6c, 0x6f, 0x63,
+ 0x6b, 0x50, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x12, 0x30, 0x0a,
+ 0x14, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x5f, 0x66, 0x69, 0x72, 0x65, 0x66, 0x6f, 0x78, 0x5f, 0x63,
+ 0x61, 0x6e, 0x61, 0x72, 0x79, 0x18, 0x12, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x62, 0x6c, 0x6f,
+ 0x63, 0x6b, 0x46, 0x69, 0x72, 0x65, 0x66, 0x6f, 0x78, 0x43, 0x61, 0x6e, 0x61, 0x72, 0x79, 0x42,
+ 0x0f, 0x0a, 0x0d, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x5f, 0x6d, 0x6f, 0x64, 0x65,
+ 0x22, 0xa5, 0x02, 0x0a, 0x1a, 0x50, 0x61, 0x72, 0x65, 0x6e, 0x74, 0x61, 0x6c, 0x50, 0x72, 0x6f,
+ 0x74, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x12,
+ 0x41, 0x0a, 0x08, 0x73, 0x63, 0x68, 0x65, 0x64, 0x75, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28,
+ 0x0b, 0x32, 0x25, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x50, 0x61,
+ 0x72, 0x65, 0x6e, 0x74, 0x61, 0x6c, 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e,
+ 0x53, 0x63, 0x68, 0x65, 0x64, 0x75, 0x6c, 0x65, 0x52, 0x08, 0x73, 0x63, 0x68, 0x65, 0x64, 0x75,
+ 0x6c, 0x65, 0x12, 0x29, 0x0a, 0x10, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x65, 0x64, 0x5f, 0x73, 0x65,
+ 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0f, 0x62, 0x6c,
+ 0x6f, 0x63, 0x6b, 0x65, 0x64, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x12, 0x18, 0x0a,
+ 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07,
+ 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1f, 0x0a, 0x0b, 0x62, 0x6c, 0x6f, 0x63, 0x6b,
+ 0x5f, 0x61, 0x64, 0x75, 0x6c, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x62, 0x6c,
+ 0x6f, 0x63, 0x6b, 0x41, 0x64, 0x75, 0x6c, 0x74, 0x12, 0x2e, 0x0a, 0x13, 0x67, 0x65, 0x6e, 0x65,
+ 0x72, 0x61, 0x6c, 0x5f, 0x73, 0x61, 0x66, 0x65, 0x5f, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x18,
+ 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x53, 0x61,
+ 0x66, 0x65, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x12, 0x2e, 0x0a, 0x13, 0x79, 0x6f, 0x75, 0x74,
+ 0x75, 0x62, 0x65, 0x5f, 0x73, 0x61, 0x66, 0x65, 0x5f, 0x73, 0x65, 0x61, 0x72, 0x63, 0x68, 0x18,
+ 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x79, 0x6f, 0x75, 0x74, 0x75, 0x62, 0x65, 0x53, 0x61,
+ 0x66, 0x65, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x22, 0xca, 0x02, 0x0a, 0x1a, 0x50, 0x61, 0x72,
+ 0x65, 0x6e, 0x74, 0x61, 0x6c, 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x53,
+ 0x63, 0x68, 0x65, 0x64, 0x75, 0x6c, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x5f,
+ 0x7a, 0x6f, 0x6e, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x74, 0x69, 0x6d, 0x65,
+ 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x25, 0x0a, 0x03, 0x6d, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28,
+ 0x0b, 0x32, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x44, 0x61,
+ 0x79, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x52, 0x03, 0x6d, 0x6f, 0x6e, 0x12, 0x25, 0x0a, 0x03, 0x74,
+ 0x75, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69,
+ 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x44, 0x61, 0x79, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x52, 0x03, 0x74,
+ 0x75, 0x65, 0x12, 0x25, 0x0a, 0x03, 0x77, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32,
+ 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x44, 0x61, 0x79, 0x52,
+ 0x61, 0x6e, 0x67, 0x65, 0x52, 0x03, 0x77, 0x65, 0x64, 0x12, 0x25, 0x0a, 0x03, 0x74, 0x68, 0x75,
+ 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65,
+ 0x64, 0x62, 0x2e, 0x44, 0x61, 0x79, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x52, 0x03, 0x74, 0x68, 0x75,
+ 0x12, 0x25, 0x0a, 0x03, 0x66, 0x72, 0x69, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e,
+ 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x44, 0x61, 0x79, 0x52, 0x61, 0x6e,
+ 0x67, 0x65, 0x52, 0x03, 0x66, 0x72, 0x69, 0x12, 0x25, 0x0a, 0x03, 0x73, 0x61, 0x74, 0x18, 0x07,
+ 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62,
+ 0x2e, 0x44, 0x61, 0x79, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x52, 0x03, 0x73, 0x61, 0x74, 0x12, 0x25,
+ 0x0a, 0x03, 0x73, 0x75, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x70, 0x72,
+ 0x6f, 0x66, 0x69, 0x6c, 0x65, 0x64, 0x62, 0x2e, 0x44, 0x61, 0x79, 0x52, 0x61, 0x6e, 0x67, 0x65,
+ 0x52, 0x03, 0x73, 0x75, 0x6e, 0x22, 0x32, 0x0a, 0x08, 0x44, 0x61, 0x79, 0x52, 0x61, 0x6e, 0x67,
+ 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d,
+ 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02,
+ 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x22, 0x3e, 0x0a, 0x14, 0x42, 0x6c, 0x6f,
+ 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x49,
+ 0x50, 0x12, 0x12, 0x0a, 0x04, 0x69, 0x70, 0x76, 0x34, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52,
+ 0x04, 0x69, 0x70, 0x76, 0x34, 0x12, 0x12, 0x0a, 0x04, 0x69, 0x70, 0x76, 0x36, 0x18, 0x02, 0x20,
+ 0x01, 0x28, 0x0c, 0x52, 0x04, 0x69, 0x70, 0x76, 0x36, 0x22, 0x16, 0x0a, 0x14, 0x42, 0x6c, 0x6f,
+ 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x4e, 0x58, 0x44, 0x4f, 0x4d, 0x41, 0x49,
+ 0x4e, 0x22, 0x14, 0x0a, 0x12, 0x42, 0x6c, 0x6f, 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64,
+ 0x65, 0x4e, 0x75, 0x6c, 0x6c, 0x49, 0x50, 0x22, 0x15, 0x0a, 0x13, 0x42, 0x6c, 0x6f, 0x63, 0x6b,
+ 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x52, 0x45, 0x46, 0x55, 0x53, 0x45, 0x44, 0x22, 0xb5,
+ 0x01, 0x0a, 0x06, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x65, 0x76,
+ 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x64, 0x65,
+ 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x6c, 0x69, 0x6e, 0x6b, 0x65, 0x64,
+ 0x5f, 0x69, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x6c, 0x69, 0x6e, 0x6b, 0x65,
+ 0x64, 0x49, 0x70, 0x12, 0x1f, 0x0a, 0x0b, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x6e, 0x61,
+ 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65,
+ 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x23, 0x0a, 0x0d, 0x64, 0x65, 0x64, 0x69, 0x63, 0x61, 0x74, 0x65,
+ 0x64, 0x5f, 0x69, 0x70, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x0c, 0x64, 0x65, 0x64,
+ 0x69, 0x63, 0x61, 0x74, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x2b, 0x0a, 0x11, 0x66, 0x69, 0x6c,
+ 0x74, 0x65, 0x72, 0x69, 0x6e, 0x67, 0x5f, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05,
+ 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x69, 0x6e, 0x67, 0x45,
+ 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x42, 0x0f, 0x5a, 0x0d, 0x2e, 0x2f, 0x66, 0x69, 0x6c, 0x65,
+ 0x63, 0x61, 0x63, 0x68, 0x65, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
+}
+
+var (
+ file_filecache_proto_rawDescOnce sync.Once
+ file_filecache_proto_rawDescData = file_filecache_proto_rawDesc
+)
+
+func file_filecache_proto_rawDescGZIP() []byte {
+ file_filecache_proto_rawDescOnce.Do(func() {
+ file_filecache_proto_rawDescData = protoimpl.X.CompressGZIP(file_filecache_proto_rawDescData)
+ })
+ return file_filecache_proto_rawDescData
+}
+
+var file_filecache_proto_msgTypes = make([]protoimpl.MessageInfo, 10)
+var file_filecache_proto_goTypes = []interface{}{
+ (*FileCache)(nil), // 0: profiledb.FileCache
+ (*Profile)(nil), // 1: profiledb.Profile
+ (*ParentalProtectionSettings)(nil), // 2: profiledb.ParentalProtectionSettings
+ (*ParentalProtectionSchedule)(nil), // 3: profiledb.ParentalProtectionSchedule
+ (*DayRange)(nil), // 4: profiledb.DayRange
+ (*BlockingModeCustomIP)(nil), // 5: profiledb.BlockingModeCustomIP
+ (*BlockingModeNXDOMAIN)(nil), // 6: profiledb.BlockingModeNXDOMAIN
+ (*BlockingModeNullIP)(nil), // 7: profiledb.BlockingModeNullIP
+ (*BlockingModeREFUSED)(nil), // 8: profiledb.BlockingModeREFUSED
+ (*Device)(nil), // 9: profiledb.Device
+ (*timestamppb.Timestamp)(nil), // 10: google.protobuf.Timestamp
+ (*durationpb.Duration)(nil), // 11: google.protobuf.Duration
+}
+var file_filecache_proto_depIdxs = []int32{
+ 10, // 0: profiledb.FileCache.sync_time:type_name -> google.protobuf.Timestamp
+ 1, // 1: profiledb.FileCache.profiles:type_name -> profiledb.Profile
+ 9, // 2: profiledb.FileCache.devices:type_name -> profiledb.Device
+ 2, // 3: profiledb.Profile.parental:type_name -> profiledb.ParentalProtectionSettings
+ 5, // 4: profiledb.Profile.blocking_mode_custom_ip:type_name -> profiledb.BlockingModeCustomIP
+ 6, // 5: profiledb.Profile.blocking_mode_nxdomain:type_name -> profiledb.BlockingModeNXDOMAIN
+ 7, // 6: profiledb.Profile.blocking_mode_null_ip:type_name -> profiledb.BlockingModeNullIP
+ 8, // 7: profiledb.Profile.blocking_mode_refused:type_name -> profiledb.BlockingModeREFUSED
+ 10, // 8: profiledb.Profile.update_time:type_name -> google.protobuf.Timestamp
+ 11, // 9: profiledb.Profile.filtered_response_ttl:type_name -> google.protobuf.Duration
+ 3, // 10: profiledb.ParentalProtectionSettings.schedule:type_name -> profiledb.ParentalProtectionSchedule
+ 4, // 11: profiledb.ParentalProtectionSchedule.mon:type_name -> profiledb.DayRange
+ 4, // 12: profiledb.ParentalProtectionSchedule.tue:type_name -> profiledb.DayRange
+ 4, // 13: profiledb.ParentalProtectionSchedule.wed:type_name -> profiledb.DayRange
+ 4, // 14: profiledb.ParentalProtectionSchedule.thu:type_name -> profiledb.DayRange
+ 4, // 15: profiledb.ParentalProtectionSchedule.fri:type_name -> profiledb.DayRange
+ 4, // 16: profiledb.ParentalProtectionSchedule.sat:type_name -> profiledb.DayRange
+ 4, // 17: profiledb.ParentalProtectionSchedule.sun:type_name -> profiledb.DayRange
+ 18, // [18:18] is the sub-list for method output_type
+ 18, // [18:18] is the sub-list for method input_type
+ 18, // [18:18] is the sub-list for extension type_name
+ 18, // [18:18] is the sub-list for extension extendee
+ 0, // [0:18] is the sub-list for field type_name
+}
+
+func init() { file_filecache_proto_init() }
+func file_filecache_proto_init() {
+ if File_filecache_proto != nil {
+ return
+ }
+ if !protoimpl.UnsafeEnabled {
+ file_filecache_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*FileCache); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_filecache_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*Profile); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_filecache_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*ParentalProtectionSettings); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_filecache_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*ParentalProtectionSchedule); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_filecache_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*DayRange); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_filecache_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*BlockingModeCustomIP); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_filecache_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*BlockingModeNXDOMAIN); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_filecache_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*BlockingModeNullIP); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_filecache_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*BlockingModeREFUSED); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_filecache_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*Device); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ }
+ file_filecache_proto_msgTypes[1].OneofWrappers = []interface{}{
+ (*Profile_BlockingModeCustomIp)(nil),
+ (*Profile_BlockingModeNxdomain)(nil),
+ (*Profile_BlockingModeNullIp)(nil),
+ (*Profile_BlockingModeRefused)(nil),
+ }
+ type x struct{}
+ out := protoimpl.TypeBuilder{
+ File: protoimpl.DescBuilder{
+ GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
+ RawDescriptor: file_filecache_proto_rawDesc,
+ NumEnums: 0,
+ NumMessages: 10,
+ NumExtensions: 0,
+ NumServices: 0,
+ },
+ GoTypes: file_filecache_proto_goTypes,
+ DependencyIndexes: file_filecache_proto_depIdxs,
+ MessageInfos: file_filecache_proto_msgTypes,
+ }.Build()
+ File_filecache_proto = out.File
+ file_filecache_proto_rawDesc = nil
+ file_filecache_proto_goTypes = nil
+ file_filecache_proto_depIdxs = nil
+}
diff --git a/internal/profiledb/internal/filecachepb/filecache.proto b/internal/profiledb/internal/filecachepb/filecache.proto
new file mode 100644
index 0000000..6fd8f68
--- /dev/null
+++ b/internal/profiledb/internal/filecachepb/filecache.proto
@@ -0,0 +1,82 @@
+syntax = "proto3";
+
+package profiledb;
+
+option go_package = "./filecachepb";
+
+import "google/protobuf/duration.proto";
+import "google/protobuf/timestamp.proto";
+
+message FileCache {
+ google.protobuf.Timestamp sync_time = 1;
+ repeated Profile profiles = 2;
+ repeated Device devices = 3;
+ int32 version = 4;
+}
+
+message Profile {
+ ParentalProtectionSettings parental = 1;
+ oneof blocking_mode {
+ BlockingModeCustomIP blocking_mode_custom_ip = 2;
+ BlockingModeNXDOMAIN blocking_mode_nxdomain = 3;
+ BlockingModeNullIP blocking_mode_null_ip = 4;
+ BlockingModeREFUSED blocking_mode_refused = 5;
+ }
+ string profile_id = 6;
+ google.protobuf.Timestamp update_time = 7;
+ repeated string device_ids = 8;
+ repeated string rule_list_ids = 9;
+ repeated string custom_rules = 10;
+ google.protobuf.Duration filtered_response_ttl = 11;
+ bool filtering_enabled = 12;
+ bool safe_browsing_enabled = 13;
+ bool rule_lists_enabled = 14;
+ bool query_log_enabled = 15;
+ bool deleted = 16;
+ bool block_private_relay = 17;
+ bool block_firefox_canary = 18;
+}
+
+message ParentalProtectionSettings {
+ ParentalProtectionSchedule schedule = 1;
+ repeated string blocked_services = 2;
+ bool enabled = 3;
+ bool block_adult = 4;
+ bool general_safe_search = 5;
+ bool youtube_safe_search = 6;
+}
+
+message ParentalProtectionSchedule {
+ string time_zone = 1;
+ DayRange mon = 2;
+ DayRange tue = 3;
+ DayRange wed = 4;
+ DayRange thu = 5;
+ DayRange fri = 6;
+ DayRange sat = 7;
+ DayRange sun = 8;
+}
+
+message DayRange {
+ uint32 start = 1;
+ uint32 end = 2;
+}
+
+message BlockingModeCustomIP {
+ bytes ipv4 = 1;
+ bytes ipv6 = 2;
+}
+
+message BlockingModeNXDOMAIN {}
+
+message BlockingModeNullIP {}
+
+message BlockingModeREFUSED {}
+
+message Device {
+ string device_id = 1;
+ bytes linked_ip = 2;
+ string device_name = 3;
+ repeated bytes dedicated_ips = 4;
+ bool filtering_enabled = 5;
+}
diff --git a/internal/profiledb/internal/filecachepb/filecachepb.go b/internal/profiledb/internal/filecachepb/filecachepb.go
new file mode 100644
index 0000000..71a8438
--- /dev/null
+++ b/internal/profiledb/internal/filecachepb/filecachepb.go
@@ -0,0 +1,388 @@
+// Package filecachepb contains the protobuf structures for the profile cache.
+package filecachepb
+
+import (
+ "fmt"
+ "net/netip"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdtime"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal"
+ "google.golang.org/protobuf/types/known/durationpb"
+ "google.golang.org/protobuf/types/known/timestamppb"
+)
+
+// toInternal converts the protobuf-encoded data into a cache structure.
+func toInternal(fc *FileCache) (c *internal.FileCache, err error) {
+ profiles, err := profilesToInternal(fc.Profiles)
+ if err != nil {
+ return nil, fmt.Errorf("converting profiles: %w", err)
+ }
+
+ devices, err := devicesFromProtobuf(fc.Devices)
+ if err != nil {
+ return nil, fmt.Errorf("converting devices: %w", err)
+ }
+
+ return &internal.FileCache{
+ SyncTime: fc.SyncTime.AsTime(),
+ Profiles: profiles,
+ Devices: devices,
+ Version: fc.Version,
+ }, nil
+}
+
+// toProtobuf converts the cache structure into a protobuf structure for encoding.
+func toProtobuf(c *internal.FileCache) (pbFileCache *FileCache) {
+ return &FileCache{
+ SyncTime: timestamppb.New(c.SyncTime),
+ Profiles: profilesToProtobuf(c.Profiles),
+ Devices: devicesToProtobuf(c.Devices),
+ Version: c.Version,
+ }
+}
+
+// profilesToInternal converts protobuf profile structures into internal ones.
+func profilesToInternal(pbProfiles []*Profile) (profiles []*agd.Profile, err error) {
+ profiles = make([]*agd.Profile, 0, len(pbProfiles))
+ for i, pbProf := range pbProfiles {
+ var prof *agd.Profile
+ prof, err = pbProf.toInternal()
+ if err != nil {
+ return nil, fmt.Errorf("profile at index %d: %w", i, err)
+ }
+
+ profiles = append(profiles, prof)
+ }
+
+ return profiles, nil
+}
+
+// toInternal converts a protobuf profile structure to an internal one.
+func (x *Profile) toInternal() (prof *agd.Profile, err error) {
+ parental, err := x.Parental.toInternal()
+ if err != nil {
+ return nil, fmt.Errorf("parental: %w", err)
+ }
+
+ m, err := blockingModeToInternal(x.BlockingMode)
+ if err != nil {
+ return nil, fmt.Errorf("blocking mode: %w", err)
+ }
+
+ return &agd.Profile{
+ Parental: parental,
+ BlockingMode: m,
+ ID: agd.ProfileID(x.ProfileId),
+ UpdateTime: x.UpdateTime.AsTime(),
+ // Consider device IDs to have been prevalidated.
+ DeviceIDs: unsafelyConvertStrSlice[string, agd.DeviceID](x.DeviceIds),
+ // Consider rule-list IDs to have been prevalidated.
+ RuleListIDs: unsafelyConvertStrSlice[string, agd.FilterListID](x.RuleListIds),
+ // Consider custom rules to have been prevalidated.
+ CustomRules: unsafelyConvertStrSlice[string, agd.FilterRuleText](x.CustomRules),
+ FilteredResponseTTL: x.FilteredResponseTtl.AsDuration(),
+ FilteringEnabled: x.FilteringEnabled,
+ SafeBrowsingEnabled: x.SafeBrowsingEnabled,
+ RuleListsEnabled: x.RuleListsEnabled,
+ QueryLogEnabled: x.QueryLogEnabled,
+ Deleted: x.Deleted,
+ BlockPrivateRelay: x.BlockPrivateRelay,
+ BlockFirefoxCanary: x.BlockFirefoxCanary,
+ }, nil
+}
+
+// toInternal converts a protobuf parental-settings structure to an internal
+// one.
+func (x *ParentalProtectionSettings) toInternal() (s *agd.ParentalProtectionSettings, err error) {
+ if x == nil {
+ return nil, nil
+ }
+
+ schedule, err := x.Schedule.toInternal()
+ if err != nil {
+ return nil, fmt.Errorf("schedule: %w", err)
+ }
+
+ return &agd.ParentalProtectionSettings{
+ Schedule: schedule,
+ // Consider blocked-service IDs to have been prevalidated.
+ BlockedServices: unsafelyConvertStrSlice[string, agd.BlockedServiceID](
+ x.BlockedServices,
+ ),
+ Enabled: x.Enabled,
+ BlockAdult: x.BlockAdult,
+ GeneralSafeSearch: x.GeneralSafeSearch,
+ YoutubeSafeSearch: x.YoutubeSafeSearch,
+ }, nil
+}
+
+// toInternal converts a protobuf protection-schedule structure to an internal
+// one.
+func (x *ParentalProtectionSchedule) toInternal() (s *agd.ParentalProtectionSchedule, err error) {
+ if x == nil {
+ return nil, nil
+ }
+
+ loc, err := agdtime.LoadLocation(x.TimeZone)
+ if err != nil {
+ return nil, fmt.Errorf("time zone: %w", err)
+ }
+
+ return &agd.ParentalProtectionSchedule{
+ // Consider the day-range values to have been prevalidated, so the
+ // narrowing conversions to uint16 below are safe.
+ Week: &agd.WeeklySchedule{
+ time.Monday: {Start: uint16(x.Mon.Start), End: uint16(x.Mon.End)},
+ time.Tuesday: {Start: uint16(x.Tue.Start), End: uint16(x.Tue.End)},
+ time.Wednesday: {Start: uint16(x.Wed.Start), End: uint16(x.Wed.End)},
+ time.Thursday: {Start: uint16(x.Thu.Start), End: uint16(x.Thu.End)},
+ time.Friday: {Start: uint16(x.Fri.Start), End: uint16(x.Fri.End)},
+ time.Saturday: {Start: uint16(x.Sat.Start), End: uint16(x.Sat.End)},
+ time.Sunday: {Start: uint16(x.Sun.Start), End: uint16(x.Sun.End)},
+ },
+ TimeZone: loc,
+ }, nil
+}
+
+// blockingModeToInternal converts a protobuf blocking-mode sum-type to an
+// internal one.
+func blockingModeToInternal(
+ pbBlockingMode isProfile_BlockingMode,
+) (m dnsmsg.BlockingModeCodec, err error) {
+ switch pbm := pbBlockingMode.(type) {
+ case *Profile_BlockingModeCustomIp:
+ custom := &dnsmsg.BlockingModeCustomIP{}
+ err = custom.IPv4.UnmarshalBinary(pbm.BlockingModeCustomIp.Ipv4)
+ if err != nil {
+ return dnsmsg.BlockingModeCodec{}, fmt.Errorf("bad custom ipv4: %w", err)
+ }
+
+ err = custom.IPv6.UnmarshalBinary(pbm.BlockingModeCustomIp.Ipv6)
+ if err != nil {
+ return dnsmsg.BlockingModeCodec{}, fmt.Errorf("bad custom ipv6: %w", err)
+ }
+
+ m.Mode = custom
+ case *Profile_BlockingModeNxdomain:
+ m.Mode = &dnsmsg.BlockingModeNXDOMAIN{}
+ case *Profile_BlockingModeNullIp:
+ m.Mode = &dnsmsg.BlockingModeNullIP{}
+ case *Profile_BlockingModeRefused:
+ m.Mode = &dnsmsg.BlockingModeREFUSED{}
+ default:
+ // Consider unhandled type-switch cases programmer errors.
+ panic(fmt.Errorf("bad pb blocking mode %T(%[1]v)", m))
+ }
+
+ return m, nil
+}
+
+// devicesFromProtobuf converts protobuf device structures into internal ones.
+func devicesFromProtobuf(pbDevices []*Device) (devices []*agd.Device, err error) {
+ devices = make([]*agd.Device, 0, len(pbDevices))
+ for i, pbDev := range pbDevices {
+ var dev *agd.Device
+ dev, err = pbDev.toInternal()
+ if err != nil {
+ return nil, fmt.Errorf("device at index %d: %w", i, err)
+ }
+
+ devices = append(devices, dev)
+ }
+
+ return devices, nil
+}
+
+// toInternal converts a protobuf device structure to an internal one.
+func (x *Device) toInternal() (d *agd.Device, err error) {
+ var linkedIP netip.Addr
+ err = linkedIP.UnmarshalBinary(x.LinkedIp)
+ if err != nil {
+ return nil, fmt.Errorf("linked ip: %w", err)
+ }
+
+ var dedicatedIPs []netip.Addr
+ dedicatedIPs, err = byteSlicesToIPs(x.DedicatedIps)
+ if err != nil {
+ return nil, fmt.Errorf("dedicated ips: %w", err)
+ }
+
+ return &agd.Device{
+ // Consider device IDs to have been prevalidated.
+ ID: agd.DeviceID(x.DeviceId),
+ LinkedIP: linkedIP,
+ // Consider device names to have been prevalidated.
+ Name: agd.DeviceName(x.DeviceName),
+ DedicatedIPs: dedicatedIPs,
+ FilteringEnabled: x.FilteringEnabled,
+ }, nil
+}
+
+// byteSlicesToIPs converts a slice of byte slices into a slice of netip.Addrs.
+func byteSlicesToIPs(data [][]byte) (ips []netip.Addr, err error) {
+ if data == nil {
+ return nil, nil
+ }
+
+ ips = make([]netip.Addr, 0, len(data))
+ for i, ipData := range data {
+ var ip netip.Addr
+ err = ip.UnmarshalBinary(ipData)
+ if err != nil {
+ return nil, fmt.Errorf("ip at index %d: %w", i, err)
+ }
+
+ ips = append(ips, ip)
+ }
+
+ return ips, nil
+}
+
+// profilesToProtobuf converts a slice of profiles to protobuf structures.
+func profilesToProtobuf(profiles []*agd.Profile) (pbProfiles []*Profile) {
+ pbProfiles = make([]*Profile, 0, len(profiles))
+ for _, p := range profiles {
+ pbProfiles = append(pbProfiles, &Profile{
+ Parental: parentalToProtobuf(p.Parental),
+ BlockingMode: blockingModeToProtobuf(p.BlockingMode),
+ ProfileId: string(p.ID),
+ UpdateTime: timestamppb.New(p.UpdateTime),
+ DeviceIds: unsafelyConvertStrSlice[agd.DeviceID, string](p.DeviceIDs),
+ RuleListIds: unsafelyConvertStrSlice[agd.FilterListID, string](p.RuleListIDs),
+ CustomRules: unsafelyConvertStrSlice[agd.FilterRuleText, string](p.CustomRules),
+ FilteredResponseTtl: durationpb.New(p.FilteredResponseTTL),
+ FilteringEnabled: p.FilteringEnabled,
+ SafeBrowsingEnabled: p.SafeBrowsingEnabled,
+ RuleListsEnabled: p.RuleListsEnabled,
+ QueryLogEnabled: p.QueryLogEnabled,
+ Deleted: p.Deleted,
+ BlockPrivateRelay: p.BlockPrivateRelay,
+ BlockFirefoxCanary: p.BlockFirefoxCanary,
+ })
+ }
+
+ return pbProfiles
+}
+
+// parentalToProtobuf converts parental settings to a protobuf structure.
+func parentalToProtobuf(s *agd.ParentalProtectionSettings) (pbSetts *ParentalProtectionSettings) {
+ if s == nil {
+ return nil
+ }
+
+ return &ParentalProtectionSettings{
+ Schedule: scheduleToProtobuf(s.Schedule),
+ BlockedServices: unsafelyConvertStrSlice[agd.BlockedServiceID, string](s.BlockedServices),
+ Enabled: s.Enabled,
+ BlockAdult: s.BlockAdult,
+ GeneralSafeSearch: s.GeneralSafeSearch,
+ YoutubeSafeSearch: s.YoutubeSafeSearch,
+ }
+}
+
+// scheduleToProtobuf converts a parental-protection schedule to a protobuf
+// structure.
+func scheduleToProtobuf(s *agd.ParentalProtectionSchedule) (pbSched *ParentalProtectionSchedule) {
+ if s == nil {
+ return nil
+ }
+
+ return &ParentalProtectionSchedule{
+ TimeZone: s.TimeZone.String(),
+ Mon: &DayRange{
+ Start: uint32(s.Week[time.Monday].Start),
+ End: uint32(s.Week[time.Monday].End),
+ },
+ Tue: &DayRange{
+ Start: uint32(s.Week[time.Tuesday].Start),
+ End: uint32(s.Week[time.Tuesday].End),
+ },
+ Wed: &DayRange{
+ Start: uint32(s.Week[time.Wednesday].Start),
+ End: uint32(s.Week[time.Wednesday].End),
+ },
+ Thu: &DayRange{
+ Start: uint32(s.Week[time.Thursday].Start),
+ End: uint32(s.Week[time.Thursday].End),
+ },
+ Fri: &DayRange{
+ Start: uint32(s.Week[time.Friday].Start),
+ End: uint32(s.Week[time.Friday].End),
+ },
+ Sat: &DayRange{
+ Start: uint32(s.Week[time.Saturday].Start),
+ End: uint32(s.Week[time.Saturday].End),
+ },
+ Sun: &DayRange{
+ Start: uint32(s.Week[time.Sunday].Start),
+ End: uint32(s.Week[time.Sunday].End),
+ },
+ }
+}
+
+// blockingModeToProtobuf converts a blocking-mode sum-type to a protobuf one.
+func blockingModeToProtobuf(m dnsmsg.BlockingModeCodec) (pbBlockingMode isProfile_BlockingMode) {
+ switch m := m.Mode.(type) {
+ case *dnsmsg.BlockingModeCustomIP:
+ return &Profile_BlockingModeCustomIp{
+ BlockingModeCustomIp: &BlockingModeCustomIP{
+ Ipv4: ipToBytes(m.IPv4),
+ Ipv6: ipToBytes(m.IPv6),
+ },
+ }
+ case *dnsmsg.BlockingModeNXDOMAIN:
+ return &Profile_BlockingModeNxdomain{
+ BlockingModeNxdomain: &BlockingModeNXDOMAIN{},
+ }
+ case *dnsmsg.BlockingModeNullIP:
+ return &Profile_BlockingModeNullIp{
+ BlockingModeNullIp: &BlockingModeNullIP{},
+ }
+ case *dnsmsg.BlockingModeREFUSED:
+ return &Profile_BlockingModeRefused{
+ BlockingModeRefused: &BlockingModeREFUSED{},
+ }
+ default:
+ panic(fmt.Errorf("bad blocking mode %T(%[1]v)", m))
+ }
+}
+
+// ipToBytes is a wrapper around netip.Addr.MarshalBinary that ignores the
+// always-nil error.
+func ipToBytes(ip netip.Addr) (b []byte) {
+ b, _ = ip.MarshalBinary()
+
+ return b
+}
+
+// devicesToProtobuf converts a slice of devices to protobuf structures.
+func devicesToProtobuf(devices []*agd.Device) (pbDevices []*Device) {
+ pbDevices = make([]*Device, 0, len(devices))
+ for _, d := range devices {
+ pbDevices = append(pbDevices, &Device{
+ DeviceId: string(d.ID),
+ LinkedIp: ipToBytes(d.LinkedIP),
+ DeviceName: string(d.Name),
+ DedicatedIps: ipsToByteSlices(d.DedicatedIPs),
+ FilteringEnabled: d.FilteringEnabled,
+ })
+ }
+
+ return pbDevices
+}
+
+// ipsToByteSlices converts a slice of netip.Addr values into byte slices,
+// ignoring the always-nil errors from netip.Addr.MarshalBinary.
+func ipsToByteSlices(ips []netip.Addr) (data [][]byte) {
+ if ips == nil {
+ return nil
+ }
+
+ data = make([][]byte, 0, len(ips))
+ for _, ip := range ips {
+ data = append(data, ipToBytes(ip))
+ }
+
+ return data
+}
diff --git a/internal/profiledb/internal/filecachepb/filecachepb_internal_test.go b/internal/profiledb/internal/filecachepb/filecachepb_internal_test.go
new file mode 100644
index 0000000..9658447
--- /dev/null
+++ b/internal/profiledb/internal/filecachepb/filecachepb_internal_test.go
@@ -0,0 +1,121 @@
+package filecachepb
+
+import (
+ "encoding/json"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal/profiledbtest"
+ "github.com/AdguardTeam/golibs/testutil"
+ "github.com/stretchr/testify/require"
+ "google.golang.org/protobuf/proto"
+)
+
+func TestMain(m *testing.M) {
+ testutil.DiscardLogOutput(m)
+}
+
+// Sinks for benchmarks
+var (
+ bytesSink []byte
+ cacheSink = &internal.FileCache{}
+ errSink error
+ fileCacheSink = &FileCache{}
+)
+
+// envVarName is the name of the environment variable that, if set, contains
+// the path to a JSON file with the profile data to use in the benchmarks.
+//
+// The path should be an absolute path.
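+//
+// For example, the benchmarks could be pointed at a prepared dump roughly
+// like this (the file path here is only an illustration):
+//
+//	ADGUARD_DNS_TEST_PROFILEDB_JSON=/tmp/profiles.json go test -run='^$' -bench BenchmarkCache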
+const envVarName = "ADGUARD_DNS_TEST_PROFILEDB_JSON"
+
+// setCacheSink fills cacheSink either with newly generated test data or, if
+// envVarName is set, with the data from the prepared JSON file.
+func setCacheSink(tb testing.TB) {
+ tb.Helper()
+
+ filePath := os.Getenv(envVarName)
+ if filePath == "" {
+ prof, dev := profiledbtest.NewProfile(tb)
+ cacheSink = &internal.FileCache{
+ SyncTime: time.Now().Round(0).UTC(),
+ Profiles: []*agd.Profile{prof},
+ Devices: []*agd.Device{dev},
+ Version: internal.FileCacheVersion,
+ }
+
+ return
+ }
+
+ tb.Logf("using %q as source for profiledb data", filePath)
+
+ data, err := os.ReadFile(filePath)
+ require.NoError(tb, err)
+
+ err = json.Unmarshal(data, cacheSink)
+ require.NoError(tb, err)
+}
+
+func BenchmarkCache(b *testing.B) {
+ setCacheSink(b)
+
+ b.Run("to_protobuf", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ fileCacheSink = toProtobuf(cacheSink)
+ }
+
+ require.NoError(b, errSink)
+ require.NotEmpty(b, fileCacheSink)
+ })
+
+ b.Run("from_protobuf", func(b *testing.B) {
+ var gotCache *internal.FileCache
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ gotCache, errSink = toInternal(fileCacheSink)
+ }
+
+ require.NoError(b, errSink)
+ require.NotEmpty(b, gotCache)
+ })
+
+ b.Run("encode", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ bytesSink, errSink = proto.Marshal(fileCacheSink)
+ }
+
+ require.NoError(b, errSink)
+ require.NotEmpty(b, bytesSink)
+ })
+
+ b.Run("decode", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ errSink = proto.Unmarshal(bytesSink, fileCacheSink)
+ }
+
+ require.NoError(b, errSink)
+ require.NotEmpty(b, fileCacheSink)
+ })
+
+ // Most recent results, on a ThinkPad X13 with a Ryzen 7 PRO CPU:
+ //
+ // goos: linux
+ // goarch: amd64
+ // pkg: github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal/filecachepb
+ // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
+ // BenchmarkCache/to_protobuf-16 674397 1657 ns/op 1240 B/op 22 allocs/op
+ // BenchmarkCache/from_protobuf-16 83577 14285 ns/op 7400 B/op 29 allocs/op
+ // BenchmarkCache/encode-16 563797 1984 ns/op 208 B/op 1 allocs/op
+ // BenchmarkCache/decode-16 273951 5143 ns/op 1288 B/op 31 allocs/op
+}
diff --git a/internal/profiledb/internal/filecachepb/storage.go b/internal/profiledb/internal/filecachepb/storage.go
new file mode 100644
index 0000000..4e2eb96
--- /dev/null
+++ b/internal/profiledb/internal/filecachepb/storage.go
@@ -0,0 +1,74 @@
+package filecachepb
+
+import (
+ "fmt"
+ "os"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal"
+ "github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/log"
+ "github.com/google/renameio"
+ "google.golang.org/protobuf/proto"
+)
+
+// Storage is the file-cache storage that encodes data using protobuf.
+type Storage struct {
+ path string
+}
+
+// New returns a new protobuf-encoded file-cache storage.
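+//
+// A minimal usage sketch (the cache path and variable name are only
+// examples):
+//
+//	strg := filecachepb.New("./profilecache.pb")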
+func New(cachePath string) (s *Storage) {
+ return &Storage{
+ path: cachePath,
+ }
+}
+
+// logPrefix is the logging prefix for the protobuf-encoded file-cache.
+const logPrefix = "profiledb protobuf cache"
+
+var _ internal.FileCacheStorage = (*Storage)(nil)
+
+// Load implements the [internal.FileCacheStorage] interface for *Storage.
+func (s *Storage) Load() (c *internal.FileCache, err error) {
+ log.Info("%s: loading", logPrefix)
+
+ b, err := os.ReadFile(s.path)
+ if err != nil {
+ if errors.Is(err, os.ErrNotExist) {
+ log.Info("%s: file not present", logPrefix)
+
+ return nil, nil
+ }
+
+ return nil, err
+ }
+
+ fc := &FileCache{}
+ err = proto.Unmarshal(b, fc)
+ if err != nil {
+ return nil, fmt.Errorf("decoding protobuf: %w", err)
+ }
+
+ return toInternal(fc)
+}
+
+// Store implements the [internal.FileCacheStorage] interface for *Storage.
+func (s *Storage) Store(c *internal.FileCache) (err error) {
+ profNum := len(c.Profiles)
+ log.Info("%s: saving %d profiles to %q", logPrefix, profNum, s.path)
+ defer log.Info("%s: saved %d profiles to %q", logPrefix, profNum, s.path)
+
+ fc := toProtobuf(c)
+ b, err := proto.Marshal(fc)
+ if err != nil {
+ return fmt.Errorf("encoding protobuf: %w", err)
+ }
+
+ err = renameio.WriteFile(s.path, b, 0o600)
+ if err != nil {
+ // Don't wrap the error, because it's informative enough as is.
+ return err
+ }
+
+ return nil
+}
diff --git a/internal/profiledb/internal/filecachepb/storage_test.go b/internal/profiledb/internal/filecachepb/storage_test.go
new file mode 100644
index 0000000..b9fc366
--- /dev/null
+++ b/internal/profiledb/internal/filecachepb/storage_test.go
@@ -0,0 +1,48 @@
+package filecachepb_test
+
+import (
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal/filecachepb"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal/profiledbtest"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestStorage(t *testing.T) {
+ prof, dev := profiledbtest.NewProfile(t)
+ cachePath := filepath.Join(t.TempDir(), "profiles.pb")
+ s := filecachepb.New(cachePath)
+ require.NotNil(t, s)
+
+ fc := &internal.FileCache{
+ SyncTime: time.Now().Round(0).UTC(),
+ Profiles: []*agd.Profile{prof},
+ Devices: []*agd.Device{dev},
+ Version: internal.FileCacheVersion,
+ }
+
+ err := s.Store(fc)
+ require.NoError(t, err)
+
+ gotFC, err := s.Load()
+ require.NoError(t, err)
+ require.NotNil(t, gotFC)
+ require.NotEmpty(t, *gotFC)
+
+ assert.Equal(t, fc, gotFC)
+}
+
+func TestStorage_Load_noFile(t *testing.T) {
+ cachePath := filepath.Join(t.TempDir(), "profiles.pb")
+ s := filecachepb.New(cachePath)
+ require.NotNil(t, s)
+
+ fc, err := s.Load()
+ assert.NoError(t, err)
+ assert.Nil(t, fc)
+}
diff --git a/internal/profiledb/internal/filecachepb/unsafe.go b/internal/profiledb/internal/filecachepb/unsafe.go
new file mode 100644
index 0000000..7578533
--- /dev/null
+++ b/internal/profiledb/internal/filecachepb/unsafe.go
@@ -0,0 +1,17 @@
+package filecachepb
+
+import "unsafe"
+
+// unsafelyConvertStrSlice checks if []T1 can be converted to []T2 at compile
+// time and, if so, converts the slice using package unsafe.
+//
+// Slices resulting from this conversion must not be mutated.
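+//
+// For example (an illustrative sketch; agd.DeviceID is one such string-based
+// type used elsewhere in this package):
+//
+//	ids := unsafelyConvertStrSlice[string, agd.DeviceID]([]string{"dev1234"})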
+func unsafelyConvertStrSlice[T1, T2 ~string](s []T1) (res []T2) {
+ if s == nil {
+ return nil
+ }
+
+ // #nosec G103 -- Conversion between two slices with the same underlying
+ // element type is safe.
+ return *(*[]T2)(unsafe.Pointer(&s))
+}
diff --git a/internal/profiledb/internal/internal.go b/internal/profiledb/internal/internal.go
new file mode 100644
index 0000000..2d2705a
--- /dev/null
+++ b/internal/profiledb/internal/internal.go
@@ -0,0 +1,48 @@
+// Package internal contains common constants and types that all implementations
+// of the default profile-cache use.
+package internal
+
+import (
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+)
+
+// FileCacheVersion is the version of the cached data structure. It must be
+// manually incremented on every change in [agd.Device], [agd.Profile], and any
+// file-cache structures.
+const FileCacheVersion = 6
+
+// FileCache contains the data that is cached on the filesystem.
+type FileCache struct {
+ SyncTime time.Time
+ Profiles []*agd.Profile
+ Devices []*agd.Device
+ Version int32
+}
+
+// FileCacheStorage is the interface for all file caches.
+type FileCacheStorage interface {
+ // Load reads the data from the cache file. If the file does not exist, Load
+ // must return a nil *FileCache. Load must return an informative error.
+ Load() (c *FileCache, err error)
+
+ // Store writes the data to the cache file. c must not be nil. Store must
+ // return an informative error.
+ Store(c *FileCache) (err error)
+}
+
+// EmptyFileCacheStorage is the empty file-cache storage that does nothing and
+// returns nils.
+type EmptyFileCacheStorage struct{}
+
+// type check
+var _ FileCacheStorage = EmptyFileCacheStorage{}
+
+// Load implements the [FileCacheStorage] interface for EmptyFileCacheStorage.
+// It does nothing and returns nils.
+func (EmptyFileCacheStorage) Load() (_ *FileCache, _ error) { return nil, nil }
+
+// Store implements the [FileCacheStorage] interface for EmptyFileCacheStorage.
+// It does nothing and returns nil.
+func (EmptyFileCacheStorage) Store(_ *FileCache) (_ error) { return nil }
diff --git a/internal/profiledb/internal/profiledbtest/profiledbtest.go b/internal/profiledb/internal/profiledbtest/profiledbtest.go
new file mode 100644
index 0000000..9d0d28c
--- /dev/null
+++ b/internal/profiledb/internal/profiledbtest/profiledbtest.go
@@ -0,0 +1,79 @@
+// Package profiledbtest contains common helpers for profile-database tests.
+package profiledbtest
+
+import (
+ "net/netip"
+ "testing"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdtime"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
+ "github.com/stretchr/testify/require"
+)
+
+// ProfileID is the profile ID for tests.
+//
+// Keep in sync with internal/profiledb/testdata/profiles.json.
+const ProfileID agd.ProfileID = "prof1234"
+
+// DeviceID is the device ID for tests.
+//
+// Keep in sync with internal/profiledb/testdata/profiles.json.
+const DeviceID agd.DeviceID = "dev1234"
+
+// NewProfile returns the common profile and device for tests.
+//
+// Keep in sync with internal/profiledb/testdata/profiles.json.
+func NewProfile(tb testing.TB) (p *agd.Profile, d *agd.Device) {
+ tb.Helper()
+
+ loc, err := agdtime.LoadLocation("Europe/Brussels")
+ require.NoError(tb, err)
+
+ dev := &agd.Device{
+ ID: DeviceID,
+ LinkedIP: netip.MustParseAddr("1.2.3.4"),
+ Name: "dev1",
+ DedicatedIPs: []netip.Addr{
+ netip.MustParseAddr("1.2.4.5"),
+ },
+ FilteringEnabled: true,
+ }
+
+ return &agd.Profile{
+ Parental: &agd.ParentalProtectionSettings{
+ Schedule: &agd.ParentalProtectionSchedule{
+ Week: &agd.WeeklySchedule{
+ {Start: 0, End: 700},
+ {Start: 0, End: 700},
+ {Start: 0, End: 700},
+ {Start: 0, End: 700},
+ {Start: 0, End: 700},
+ {Start: 0, End: 700},
+ {Start: 0, End: 700},
+ },
+ TimeZone: loc,
+ },
+ Enabled: true,
+ },
+ BlockingMode: dnsmsg.BlockingModeCodec{
+ Mode: &dnsmsg.BlockingModeNullIP{},
+ },
+ ID: ProfileID,
+ DeviceIDs: []agd.DeviceID{dev.ID},
+ RuleListIDs: []agd.FilterListID{
+ "adguard_dns_filter",
+ },
+ CustomRules: []agd.FilterRuleText{
+ "|blocked-by-custom.example",
+ },
+ FilteredResponseTTL: 10 * time.Second,
+ FilteringEnabled: true,
+ SafeBrowsingEnabled: true,
+ RuleListsEnabled: true,
+ QueryLogEnabled: true,
+ BlockPrivateRelay: true,
+ BlockFirefoxCanary: true,
+ }, dev
+}
diff --git a/internal/profiledb/profiledb.go b/internal/profiledb/profiledb.go
new file mode 100644
index 0000000..ffa4a95
--- /dev/null
+++ b/internal/profiledb/profiledb.go
@@ -0,0 +1,506 @@
+// Package profiledb defines interfaces for databases of user profiles.
+package profiledb
+
+import (
+ "context"
+ "fmt"
+ "net/netip"
+ "path/filepath"
+ "sync"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/metrics"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal/filecachejson"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal/filecachepb"
+ "github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/log"
+ "golang.org/x/exp/maps"
+ "golang.org/x/exp/slices"
+)
+
+// Interface is the local database of user profiles and devices.
+type Interface interface {
+ // ProfileByDeviceID returns the profile and the device identified by id.
+ ProfileByDeviceID(
+ ctx context.Context,
+ id agd.DeviceID,
+ ) (p *agd.Profile, d *agd.Device, err error)
+
+ // ProfileByDedicatedIP returns the profile and the device identified by its
+ // dedicated DNS server IP address.
+ ProfileByDedicatedIP(
+ ctx context.Context,
+ ip netip.Addr,
+ ) (p *agd.Profile, d *agd.Device, err error)
+
+ // ProfileByLinkedIP returns the profile and the device identified by its
+ // linked IP address.
+ ProfileByLinkedIP(ctx context.Context, ip netip.Addr) (p *agd.Profile, d *agd.Device, err error)
+}
+
+// Default is the default in-memory implementation of the [Interface] interface
+// that can refresh itself from the provided storage.
+type Default struct {
+ // mapsMu protects the profiles, devices, deviceIDToProfileID,
+ // linkedIPToDeviceID, and dedicatedIPToDeviceID maps.
+ mapsMu *sync.RWMutex
+
+ // refreshMu protects syncTime and lastFullSync. These are only used within
+ // Refresh, so this is also basically a refresh serializer.
+ refreshMu *sync.Mutex
+
+ // cache is the filesystem-cache storage used by this profile database.
+ cache internal.FileCacheStorage
+
+ // storage is the storage that provides the profile data for this DB.
+ storage Storage
+
+ // profiles maps profile IDs to profile records.
+ profiles map[agd.ProfileID]*agd.Profile
+
+ // devices maps device IDs to device records.
+ devices map[agd.DeviceID]*agd.Device
+
+ // deviceIDToProfileID maps device IDs to the ID of their profile.
+ deviceIDToProfileID map[agd.DeviceID]agd.ProfileID
+
+ // linkedIPToDeviceID maps linked IP addresses to the IDs of their devices.
+ linkedIPToDeviceID map[netip.Addr]agd.DeviceID
+
+ // dedicatedIPToDeviceID maps dedicated IP addresses to the IDs of their
+ // devices.
+ dedicatedIPToDeviceID map[netip.Addr]agd.DeviceID
+
+ // syncTime is the time of the last synchronization point. It is received
+ // from the storage during a refresh and is then used in consecutive
+ // requests to the storage, unless it's a full synchronization.
+ syncTime time.Time
+
+ // lastFullSync is the time of the last full synchronization.
+ lastFullSync time.Time
+
+ // fullSyncIvl is the interval between two full synchronizations with the
+ // storage.
+ fullSyncIvl time.Duration
+}
+
+// New returns a new default in-memory profile database with a filesystem
+// cache. The initial refresh is performed immediately with a constant timeout
+// of 1 minute; if it doesn't finish in time, an empty profile database is
+// returned. If cacheFilePath is the string "none", the filesystem cache is
+// disabled. db is never nil.
+func New(
+ s Storage,
+ fullSyncIvl time.Duration,
+ cacheFilePath string,
+) (db *Default, err error) {
+ var cacheStorage internal.FileCacheStorage
+ if cacheFilePath == "none" {
+ cacheStorage = internal.EmptyFileCacheStorage{}
+ } else {
+ switch ext := filepath.Ext(cacheFilePath); ext {
+ case ".json":
+ cacheStorage = filecachejson.New(cacheFilePath)
+ case ".pb":
+ cacheStorage = filecachepb.New(cacheFilePath)
+ default:
+ return nil, fmt.Errorf("file %q is neither json nor protobuf", cacheFilePath)
+ }
+ }
+
+ db = &Default{
+ mapsMu: &sync.RWMutex{},
+ refreshMu: &sync.Mutex{},
+ cache: cacheStorage,
+ storage: s,
+ syncTime: time.Time{},
+ lastFullSync: time.Time{},
+ profiles: make(map[agd.ProfileID]*agd.Profile),
+ devices: make(map[agd.DeviceID]*agd.Device),
+ deviceIDToProfileID: make(map[agd.DeviceID]agd.ProfileID),
+ linkedIPToDeviceID: make(map[netip.Addr]agd.DeviceID),
+ dedicatedIPToDeviceID: make(map[netip.Addr]agd.DeviceID),
+ fullSyncIvl: fullSyncIvl,
+ }
+
+ err = db.loadFileCache()
+ if err != nil {
+ log.Error("profiledb: fs cache: loading: %s", err)
+ }
+
+ // initialTimeout defines the maximum duration of the first attempt to load
+ // the profiledb.
+ const initialTimeout = 1 * time.Minute
+
+ ctx, cancel := context.WithTimeout(context.Background(), initialTimeout)
+ defer cancel()
+
+ log.Info("profiledb: initial refresh")
+
+ err = db.Refresh(ctx)
+ if err != nil {
+ if errors.Is(err, context.DeadlineExceeded) {
+ log.Info("profiledb: warning: initial refresh timeout: %s", err)
+
+ return db, nil
+ }
+
+ return nil, fmt.Errorf("initial refresh: %w", err)
+ }
+
+ log.Info("profiledb: initial refresh succeeded")
+
+ return db, nil
+}
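+
+// A minimal usage sketch, not part of the production wiring; the storage
+// value, the refresh interval, and the cache path below are assumptions:
+//
+//	db, err := profiledb.New(storage, 15*time.Minute, "./profilecache.pb")
+//	if err != nil {
+//		// Handle the constructor error.
+//	}
+//
+//	// devID is an assumed agd.DeviceID, e.g. one extracted from a query.
+//	p, d, err := db.ProfileByDeviceID(ctx, devID)
+//	if errors.Is(err, profiledb.ErrDeviceNotFound) {
+//		// No profile is attached to this device; proceed without one.
+//	}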
+
+// type check
+var _ agd.Refresher = (*Default)(nil)
+
+// Refresh implements the [agd.Refresher] interface for *Default. It updates
+// the internal maps and the synchronization time using the data it receives
+// from the storage.
+func (db *Default) Refresh(ctx context.Context) (err error) {
+ var totalProfiles, totalDevices int
+ startTime := time.Now()
+ defer func() {
+ metrics.ProfilesSyncTime.SetToCurrentTime()
+ metrics.ProfilesSyncDuration.Observe(time.Since(startTime).Seconds())
+ metrics.ProfilesCountGauge.Set(float64(totalProfiles))
+ metrics.DevicesCountGauge.Set(float64(totalDevices))
+ metrics.SetStatusGauge(metrics.ProfilesSyncStatus, err)
+ }()
+
+ reqID := agd.NewRequestID()
+ ctx = agd.WithRequestID(ctx, reqID)
+
+ defer func() { err = errors.Annotate(err, "req %s: %w", reqID) }()
+
+ db.refreshMu.Lock()
+ defer db.refreshMu.Unlock()
+
+ syncTime := db.syncTime
+
+ sinceLastFullSync := time.Since(db.lastFullSync)
+ isFullSync := sinceLastFullSync >= db.fullSyncIvl
+ if isFullSync {
+ log.Info("profiledb: full sync, %s since %s", sinceLastFullSync, db.lastFullSync)
+
+ syncTime = time.Time{}
+ }
+
+ resp, err := db.storage.Profiles(ctx, &StorageRequest{
+ SyncTime: syncTime,
+ })
+ if err != nil {
+ return fmt.Errorf("updating profiles: %w", err)
+ }
+
+ profiles := resp.Profiles
+ devices := resp.Devices
+ db.setProfiles(profiles, devices, isFullSync)
+
+ profNum := len(profiles)
+ devNum := len(devices)
+ log.Debug("profiledb: req %s: got %d profiles with %d devices", reqID, profNum, devNum)
+
+ metrics.ProfilesNewCountGauge.Set(float64(profNum))
+ metrics.DevicesNewCountGauge.Set(float64(devNum))
+
+ db.syncTime = resp.SyncTime
+ if isFullSync {
+ db.lastFullSync = time.Now()
+
+ err = db.cache.Store(&internal.FileCache{
+ SyncTime: resp.SyncTime,
+ Profiles: resp.Profiles,
+ Devices: resp.Devices,
+ Version: internal.FileCacheVersion,
+ })
+ if err != nil {
+ return fmt.Errorf("saving cache: %w", err)
+ }
+ }
+
+ totalProfiles = len(db.profiles)
+ totalDevices = len(db.devices)
+
+ return nil
+}
+
+// loadFileCache loads the profiles data from the filesystem cache.
+func (db *Default) loadFileCache() (err error) {
+ const logPrefix = "profiledb: cache"
+
+ start := time.Now()
+ log.Info("%s: initial loading", logPrefix)
+
+ c, err := db.cache.Load()
+ if err != nil {
+ // Don't wrap the error, because it's informative enough as is.
+ return err
+ } else if c == nil {
+ log.Info("%s: no cache", logPrefix)
+
+ return nil
+ }
+
+ profNum, devNum := len(c.Profiles), len(c.Devices)
+ log.Info(
+ "%s: got version %d, %d profiles, %d devices in %s",
+ logPrefix,
+ c.Version,
+ profNum,
+ devNum,
+ time.Since(start),
+ )
+
+ if c.Version != internal.FileCacheVersion {
+ log.Info(
+ "%s: version %d is different from %d",
+ logPrefix,
+ c.Version,
+ internal.FileCacheVersion,
+ )
+
+ return nil
+ } else if profNum == 0 || devNum == 0 {
+ log.Info("%s: empty", logPrefix)
+
+ return nil
+ }
+
+ db.setProfiles(c.Profiles, c.Devices, true)
+ db.syncTime, db.lastFullSync = c.SyncTime, c.SyncTime
+
+ return nil
+}
+
+// setProfiles adds or updates the data for all profiles and devices.
+func (db *Default) setProfiles(profiles []*agd.Profile, devices []*agd.Device, isFullSync bool) {
+ db.mapsMu.Lock()
+ defer db.mapsMu.Unlock()
+
+ if isFullSync {
+ maps.Clear(db.profiles)
+ maps.Clear(db.devices)
+ maps.Clear(db.deviceIDToProfileID)
+ maps.Clear(db.linkedIPToDeviceID)
+ maps.Clear(db.dedicatedIPToDeviceID)
+ }
+
+ for _, p := range profiles {
+ db.profiles[p.ID] = p
+
+ for _, devID := range p.DeviceIDs {
+ db.deviceIDToProfileID[devID] = p.ID
+ }
+ }
+
+ for _, d := range devices {
+ devID := d.ID
+ db.devices[devID] = d
+
+ if d.LinkedIP != (netip.Addr{}) {
+ db.linkedIPToDeviceID[d.LinkedIP] = devID
+ }
+
+ for _, dedIP := range d.DedicatedIPs {
+ db.dedicatedIPToDeviceID[dedIP] = devID
+ }
+ }
+}
+
+// type check
+var _ Interface = (*Default)(nil)
+
+// ProfileByDeviceID implements the [Interface] interface for *Default.
+func (db *Default) ProfileByDeviceID(
+ ctx context.Context,
+ id agd.DeviceID,
+) (p *agd.Profile, d *agd.Device, err error) {
+ db.mapsMu.RLock()
+ defer db.mapsMu.RUnlock()
+
+ return db.profileByDeviceID(ctx, id)
+}
+
+// profileByDeviceID returns the profile and the device by the ID of the device,
+// if found. It assumes that db.mapsMu is locked for reading.
+func (db *Default) profileByDeviceID(
+ _ context.Context,
+ id agd.DeviceID,
+) (p *agd.Profile, d *agd.Device, err error) {
+ // Do not use [errors.Annotate] here, because it allocates even when the
+ // error is nil. Also do not use fmt.Errorf in a defer, because it
+ // allocates when a device is not found, which is the most common case.
+
+ profID, ok := db.deviceIDToProfileID[id]
+ if !ok {
+ return nil, nil, ErrDeviceNotFound
+ }
+
+ p, ok = db.profiles[profID]
+ if !ok {
+ // We have an older device record with a deleted profile. Remove it
+ // from our profile DB in a goroutine, since that requires a write lock.
+ go db.removeDevice(id)
+
+ return nil, nil, fmt.Errorf("empty profile: %w", ErrDeviceNotFound)
+ }
+
+ // Reinspect the devices in the profile record to make sure that the device
+ // is still attached to this profile.
+ for _, profDevID := range p.DeviceIDs {
+ if profDevID == id {
+ d = db.devices[id]
+
+ break
+ }
+ }
+
+ if d == nil {
+ // Perhaps, the device has been deleted from this profile. May happen
+ // when the device was found by a linked IP. Remove it from our profile
+ // DB in a goroutine, since that requires a write lock.
+ go db.removeDevice(id)
+
+ return nil, nil, fmt.Errorf("rechecking devices: %w", ErrDeviceNotFound)
+ }
+
+ return p, d, nil
+}
+
+// removeDevice removes the device with the given ID from the database. It is
+// intended to be used as a goroutine.
+func (db *Default) removeDevice(id agd.DeviceID) {
+ defer log.OnPanicAndExit("removeDevice", 1)
+
+ db.mapsMu.Lock()
+ defer db.mapsMu.Unlock()
+
+ delete(db.deviceIDToProfileID, id)
+}
+
+// ProfileByLinkedIP implements the [Interface] interface for *Default. ip must
+// be valid.
+func (db *Default) ProfileByLinkedIP(
+ ctx context.Context,
+ ip netip.Addr,
+) (p *agd.Profile, d *agd.Device, err error) {
+ // Do not use errors.Annotate here, because it allocates even when the error
+ // is nil. Also do not use fmt.Errorf in a defer, because it allocates when
+ // a device is not found, which is the most common case.
+
+ db.mapsMu.RLock()
+ defer db.mapsMu.RUnlock()
+
+ id, ok := db.linkedIPToDeviceID[ip]
+ if !ok {
+ return nil, nil, ErrDeviceNotFound
+ }
+
+ const errPrefix = "profile by device linked ip"
+ p, d, err = db.profileByDeviceID(ctx, id)
+ if err != nil {
+ if errors.Is(err, ErrDeviceNotFound) {
+ // Probably, the device has been deleted. Remove it from our
+ // profile DB in a goroutine, since that requires a write lock.
+ go db.removeLinkedIP(ip)
+ }
+
+ // Don't add the device ID to the error here, since it is already added
+ // by profileByDeviceID.
+ return nil, nil, fmt.Errorf("%s: %w", errPrefix, err)
+ }
+
+ if d.LinkedIP == (netip.Addr{}) {
+ return nil, nil, fmt.Errorf(
+ "%s: device does not have linked ip: %w",
+ errPrefix,
+ ErrDeviceNotFound,
+ )
+ } else if d.LinkedIP != ip {
+ // The linked IP has changed. Remove it from our profile DB in a
+ // goroutine, since that requires a write lock.
+ go db.removeLinkedIP(ip)
+
+ return nil, nil, fmt.Errorf(
+ "%s: %q doesn't match: %w",
+ errPrefix,
+ d.LinkedIP,
+ ErrDeviceNotFound,
+ )
+ }
+
+ return p, d, nil
+}
+
+// removeLinkedIP removes the device link for the given linked IP address from
+// the profile database. It is intended to be used as a goroutine.
+func (db *Default) removeLinkedIP(ip netip.Addr) {
+ defer log.OnPanicAndExit("removeLinkedIP", 1)
+
+ db.mapsMu.Lock()
+ defer db.mapsMu.Unlock()
+
+ delete(db.linkedIPToDeviceID, ip)
+}
+
+// ProfileByDedicatedIP implements the [Interface] interface for *Default. ip
+// must be valid.
+func (db *Default) ProfileByDedicatedIP(
+ ctx context.Context,
+ ip netip.Addr,
+) (p *agd.Profile, d *agd.Device, err error) {
+ // Do not use errors.Annotate here, because it allocates even when the error
+ // is nil. Also do not use fmt.Errorf in a defer, because it allocates when
+ // a device is not found, which is the most common case.
+
+ db.mapsMu.RLock()
+ defer db.mapsMu.RUnlock()
+
+ id, ok := db.dedicatedIPToDeviceID[ip]
+ if !ok {
+ return nil, nil, ErrDeviceNotFound
+ }
+
+ const errPrefix = "profile by device dedicated ip"
+ p, d, err = db.profileByDeviceID(ctx, id)
+ if err != nil {
+ if errors.Is(err, ErrDeviceNotFound) {
+ // Probably, the device has been deleted. Remove it from our
+ // profile DB in a goroutine, since that requires a write lock.
+ go db.removeDedicatedIP(ip)
+ }
+
+ // Don't add the device ID to the error here, since it is already added
+ // by profileByDeviceID.
+ return nil, nil, fmt.Errorf("%s: %w", errPrefix, err)
+ }
+
+ if ipIdx := slices.Index(d.DedicatedIPs, ip); ipIdx < 0 {
+ // Perhaps, the device has changed its dedicated IPs. Remove it from
+ // our profile DB in a goroutine, since that requires a write lock.
+ go db.removeDedicatedIP(ip)
+
+ return nil, nil, fmt.Errorf(
+ "%s: rechecking dedicated ips: %w",
+ errPrefix,
+ ErrDeviceNotFound,
+ )
+ }
+
+ return p, d, nil
+}
+
+// removeDedicatedIP removes the device link for the given dedicated IP address
+// from the profile database. It is intended to be used as a goroutine.
+func (db *Default) removeDedicatedIP(ip netip.Addr) {
+ defer log.OnPanicAndExit("removeDedicatedIP", 1)
+
+ db.mapsMu.Lock()
+ defer db.mapsMu.Unlock()
+
+ delete(db.dedicatedIPToDeviceID, ip)
+}
diff --git a/internal/profiledb/profiledb_test.go b/internal/profiledb/profiledb_test.go
new file mode 100644
index 0000000..a9345b2
--- /dev/null
+++ b/internal/profiledb/profiledb_test.go
@@ -0,0 +1,495 @@
+package profiledb_test
+
+import (
+ "bytes"
+ "context"
+ "net/netip"
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+ "github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
+ "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
+ "github.com/AdguardTeam/AdGuardDNS/internal/profiledb/internal/profiledbtest"
+ "github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/testutil"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestMain(m *testing.M) {
+ testutil.DiscardLogOutput(m)
+}
+
+// Common IPs for tests
+//
+// Keep in sync with testdata/profiles.json.
+var (
+ testClientIPv4 = netip.MustParseAddr("1.2.3.4")
+ testOtherClientIPv4 = netip.MustParseAddr("1.2.3.5")
+
+ testDedicatedIPv4 = netip.MustParseAddr("1.2.4.5")
+ testOtherDedicatedIPv4 = netip.MustParseAddr("1.2.4.6")
+)
+
+// testTimeout is the common timeout for tests.
+const testTimeout = 1 * time.Second
+
+// newDefaultProfileDB returns a new default profile database for tests.
+// devices is the channel from which the storage mock reads the devices to
+// return in its response.
+func newDefaultProfileDB(tb testing.TB, devices <-chan []*agd.Device) (db *profiledb.Default) {
+ tb.Helper()
+
+ onProfiles := func(
+ _ context.Context,
+ _ *profiledb.StorageRequest,
+ ) (resp *profiledb.StorageResponse, err error) {
+ devices, _ := testutil.RequireReceive(tb, devices, testTimeout)
+ devIDs := make([]agd.DeviceID, 0, len(devices))
+ for _, d := range devices {
+ devIDs = append(devIDs, d.ID)
+ }
+
+ return &profiledb.StorageResponse{
+ Profiles: []*agd.Profile{{
+ BlockingMode: dnsmsg.BlockingModeCodec{
+ Mode: &dnsmsg.BlockingModeNullIP{},
+ },
+ ID: profiledbtest.ProfileID,
+ DeviceIDs: devIDs,
+ }},
+ Devices: devices,
+ }, nil
+ }
+
+ ps := &agdtest.ProfileStorage{
+ OnProfiles: onProfiles,
+ }
+
+ db, err := profiledb.New(ps, 1*time.Minute, "none")
+ require.NoError(tb, err)
+
+ return db
+}
+
+func TestDefaultProfileDB(t *testing.T) {
+ dev := &agd.Device{
+ ID: profiledbtest.DeviceID,
+ LinkedIP: testClientIPv4,
+ DedicatedIPs: []netip.Addr{
+ testDedicatedIPv4,
+ },
+ }
+
+ devicesCh := make(chan []*agd.Device, 1)
+ devicesCh <- []*agd.Device{dev}
+ db := newDefaultProfileDB(t, devicesCh)
+
+ t.Run("by_device_id", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
+ defer cancel()
+
+ p, d, err := db.ProfileByDeviceID(ctx, profiledbtest.DeviceID)
+ require.NoError(t, err)
+
+ assert.Equal(t, profiledbtest.ProfileID, p.ID)
+ assert.Equal(t, d, dev)
+ })
+
+ t.Run("by_dedicated_ip", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
+ defer cancel()
+
+ p, d, err := db.ProfileByDedicatedIP(ctx, testDedicatedIPv4)
+ require.NoError(t, err)
+
+ assert.Equal(t, profiledbtest.ProfileID, p.ID)
+ assert.Equal(t, d, dev)
+ })
+
+ t.Run("by_linked_ip", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
+ defer cancel()
+
+ p, d, err := db.ProfileByLinkedIP(ctx, testClientIPv4)
+ require.NoError(t, err)
+
+ assert.Equal(t, profiledbtest.ProfileID, p.ID)
+ assert.Equal(t, d, dev)
+ })
+}
+
+func TestDefaultProfileDB_ProfileByDedicatedIP_removedDevice(t *testing.T) {
+ dev := &agd.Device{
+ ID: profiledbtest.DeviceID,
+ DedicatedIPs: []netip.Addr{
+ testDedicatedIPv4,
+ },
+ }
+
+ devicesCh := make(chan []*agd.Device, 2)
+
+ // The first response, the device is still there.
+ devicesCh <- []*agd.Device{dev}
+
+ db := newDefaultProfileDB(t, devicesCh)
+
+ ctx := context.Background()
+ _, d, err := db.ProfileByDedicatedIP(ctx, testDedicatedIPv4)
+ require.NoError(t, err)
+
+ assert.Equal(t, d, dev)
+
+ // The second response, the device is removed.
+ devicesCh <- nil
+
+ err = db.Refresh(ctx)
+ require.NoError(t, err)
+
+ assert.Eventually(t, func() (ok bool) {
+ _, d, err = db.ProfileByDedicatedIP(ctx, testDedicatedIPv4)
+
+ return errors.Is(err, profiledb.ErrDeviceNotFound)
+ }, testTimeout, testTimeout/10)
+}
+
+func TestDefaultProfileDB_ProfileByDedicatedIP_deviceNewIP(t *testing.T) {
+ dev := &agd.Device{
+ ID: profiledbtest.DeviceID,
+ DedicatedIPs: []netip.Addr{
+ testDedicatedIPv4,
+ },
+ }
+
+ devicesCh := make(chan []*agd.Device, 2)
+
+ // The first response, the device is still there.
+ devicesCh <- []*agd.Device{dev}
+
+ db := newDefaultProfileDB(t, devicesCh)
+
+ ctx := context.Background()
+ _, d, err := db.ProfileByDedicatedIP(ctx, testDedicatedIPv4)
+ require.NoError(t, err)
+
+ assert.Equal(t, d, dev)
+
+ // The second response, the device has a new IP.
+ dev.DedicatedIPs[0] = testOtherDedicatedIPv4
+ devicesCh <- []*agd.Device{dev}
+
+ err = db.Refresh(ctx)
+ require.NoError(t, err)
+
+ assert.Eventually(t, func() (ok bool) {
+ _, _, err = db.ProfileByDedicatedIP(ctx, testDedicatedIPv4)
+
+ if !errors.Is(err, profiledb.ErrDeviceNotFound) {
+ return false
+ }
+
+ _, d, err = db.ProfileByDedicatedIP(ctx, testOtherDedicatedIPv4)
+ if err != nil {
+ return false
+ }
+
+ return d != nil && d.ID == dev.ID
+ }, testTimeout, testTimeout/10)
+}
+
+func TestDefaultProfileDB_ProfileByLinkedIP_removedDevice(t *testing.T) {
+ dev := &agd.Device{
+ ID: profiledbtest.DeviceID,
+ LinkedIP: testClientIPv4,
+ }
+
+ devicesCh := make(chan []*agd.Device, 2)
+
+ // The first response, the device is still there.
+ devicesCh <- []*agd.Device{dev}
+
+ db := newDefaultProfileDB(t, devicesCh)
+
+ ctx := context.Background()
+ _, d, err := db.ProfileByLinkedIP(ctx, testClientIPv4)
+ require.NoError(t, err)
+
+ assert.Equal(t, d, dev)
+
+ // The second response, the device is removed.
+ devicesCh <- nil
+
+ err = db.Refresh(ctx)
+ require.NoError(t, err)
+
+ assert.Eventually(t, func() (ok bool) {
+ _, d, err = db.ProfileByLinkedIP(ctx, testClientIPv4)
+
+ return errors.Is(err, profiledb.ErrDeviceNotFound)
+ }, testTimeout, testTimeout/10)
+}
+
+func TestDefaultProfileDB_ProfileByLinkedIP_deviceNewIP(t *testing.T) {
+ dev := &agd.Device{
+ ID: profiledbtest.DeviceID,
+ LinkedIP: testClientIPv4,
+ }
+
+ devicesCh := make(chan []*agd.Device, 2)
+
+ // The first response, the device is still there.
+ devicesCh <- []*agd.Device{dev}
+
+ db := newDefaultProfileDB(t, devicesCh)
+
+ ctx := context.Background()
+ _, d, err := db.ProfileByLinkedIP(ctx, testClientIPv4)
+ require.NoError(t, err)
+
+ assert.Equal(t, d, dev)
+
+ // The second response, the device has a new IP.
+ dev.LinkedIP = testOtherClientIPv4
+ devicesCh <- []*agd.Device{dev}
+
+ err = db.Refresh(ctx)
+ require.NoError(t, err)
+
+ assert.Eventually(t, func() (ok bool) {
+ _, _, err = db.ProfileByLinkedIP(ctx, testClientIPv4)
+
+ if !errors.Is(err, profiledb.ErrDeviceNotFound) {
+ return false
+ }
+
+ _, d, err = db.ProfileByLinkedIP(ctx, testOtherClientIPv4)
+ if err != nil {
+ return false
+ }
+
+ return d != nil && d.ID == dev.ID
+ }, testTimeout, testTimeout/10)
+}
+
+func TestDefaultProfileDB_fileCache_success(t *testing.T) {
+ var gotSyncTime time.Time
+ onProfiles := func(
+ _ context.Context,
+ req *profiledb.StorageRequest,
+ ) (resp *profiledb.StorageResponse, err error) {
+ gotSyncTime = req.SyncTime
+
+ return &profiledb.StorageResponse{}, nil
+ }
+
+ ps := &agdtest.ProfileStorage{
+ OnProfiles: onProfiles,
+ }
+
+ cacheFileTmplPath := filepath.Join("testdata", "profiles.json")
+ data, err := os.ReadFile(cacheFileTmplPath)
+ require.NoError(t, err)
+
+	// Use a time value with the monotonic clock reading stripped.
+ wantSyncTime := time.Now().Round(0).UTC()
+ data = bytes.ReplaceAll(
+ data,
+ []byte("SYNC_TIME"),
+ []byte(wantSyncTime.Format(time.RFC3339Nano)),
+ )
+
+ cacheFilePath := filepath.Join(t.TempDir(), "profiles.json")
+ err = os.WriteFile(cacheFilePath, data, 0o600)
+ require.NoError(t, err)
+
+ db, err := profiledb.New(ps, 1*time.Minute, cacheFilePath)
+ require.NoError(t, err)
+ require.NotNil(t, db)
+
+ assert.Equal(t, wantSyncTime, gotSyncTime)
+
+ prof, dev := profiledbtest.NewProfile(t)
+ p, d, err := db.ProfileByDeviceID(context.Background(), dev.ID)
+ require.NoError(t, err)
+
+ assert.Equal(t, dev, d)
+ assert.Equal(t, prof, p)
+}
+
+func TestDefaultProfileDB_fileCache_badVersion(t *testing.T) {
+ storageCalled := false
+ ps := &agdtest.ProfileStorage{
+ OnProfiles: func(
+ _ context.Context,
+ _ *profiledb.StorageRequest,
+ ) (resp *profiledb.StorageResponse, err error) {
+ storageCalled = true
+
+ return &profiledb.StorageResponse{}, nil
+ },
+ }
+
+ cacheFilePath := filepath.Join(t.TempDir(), "profiles.json")
+ err := os.WriteFile(cacheFilePath, []byte(`{"version":1000}`), 0o600)
+ require.NoError(t, err)
+
+ db, err := profiledb.New(ps, 1*time.Minute, cacheFilePath)
+ assert.NoError(t, err)
+ assert.NotNil(t, db)
+ assert.True(t, storageCalled)
+}
+
+// Sinks for benchmarks.
+var (
+ profSink *agd.Profile
+ devSink *agd.Device
+ errSink error
+)
+
+func BenchmarkDefaultProfileDB_ProfileByDeviceID(b *testing.B) {
+ dev := &agd.Device{
+ ID: profiledbtest.DeviceID,
+ }
+
+ devicesCh := make(chan []*agd.Device, 1)
+ devicesCh <- []*agd.Device{dev}
+ db := newDefaultProfileDB(b, devicesCh)
+
+ ctx := context.Background()
+
+ b.Run("success", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ profSink, devSink, errSink = db.ProfileByDeviceID(ctx, profiledbtest.DeviceID)
+ }
+
+ assert.NotNil(b, profSink)
+ assert.NotNil(b, devSink)
+ assert.NoError(b, errSink)
+ })
+
+ const wrongDevID = profiledbtest.DeviceID + "_bad"
+
+ b.Run("not_found", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ profSink, devSink, errSink = db.ProfileByDeviceID(ctx, wrongDevID)
+ }
+
+ assert.Nil(b, profSink)
+ assert.Nil(b, devSink)
+ assert.ErrorIs(b, errSink, profiledb.ErrDeviceNotFound)
+ })
+
+ // Most recent results, as of 2023-04-10, on a ThinkPad X13 with a Ryzen Pro
+ // 7 CPU:
+ //
+ // goos: linux
+ // goarch: amd64
+ // pkg: github.com/AdguardTeam/AdGuardDNS/internal/profiledb
+ // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
+ // BenchmarkDefaultProfileDB_ProfileByDeviceID/success-16 59396382 21.36 ns/op 0 B/op 0 allocs/op
+ // BenchmarkDefaultProfileDB_ProfileByDeviceID/not_found-16 74497800 16.45 ns/op 0 B/op 0 allocs/op
+}
+
+func BenchmarkDefaultProfileDB_ProfileByLinkedIP(b *testing.B) {
+ dev := &agd.Device{
+ ID: profiledbtest.DeviceID,
+ LinkedIP: testClientIPv4,
+ }
+
+ devicesCh := make(chan []*agd.Device, 1)
+ devicesCh <- []*agd.Device{dev}
+ db := newDefaultProfileDB(b, devicesCh)
+
+ ctx := context.Background()
+
+ b.Run("success", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ profSink, devSink, errSink = db.ProfileByLinkedIP(ctx, testClientIPv4)
+ }
+
+ assert.NotNil(b, profSink)
+ assert.NotNil(b, devSink)
+ assert.NoError(b, errSink)
+ })
+
+ b.Run("not_found", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ profSink, devSink, errSink = db.ProfileByLinkedIP(ctx, testOtherClientIPv4)
+ }
+
+ assert.Nil(b, profSink)
+ assert.Nil(b, devSink)
+ assert.ErrorIs(b, errSink, profiledb.ErrDeviceNotFound)
+ })
+
+ // Most recent results, as of 2023-04-10, on a ThinkPad X13 with a Ryzen Pro
+ // 7 CPU:
+ //
+ // goos: linux
+ // goarch: amd64
+ // pkg: github.com/AdguardTeam/AdGuardDNS/internal/profiledb
+ // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
+ // BenchmarkDefaultProfileDB_ProfileByLinkedIP/success-16 24822542 44.11 ns/op 0 B/op 0 allocs/op
+ // BenchmarkDefaultProfileDB_ProfileByLinkedIP/not_found-16 63539154 20.04 ns/op 0 B/op 0 allocs/op
+}
+
+func BenchmarkDefaultProfileDB_ProfileByDedicatedIP(b *testing.B) {
+ dev := &agd.Device{
+ ID: profiledbtest.DeviceID,
+ DedicatedIPs: []netip.Addr{
+ testClientIPv4,
+ },
+ }
+
+ devicesCh := make(chan []*agd.Device, 1)
+ devicesCh <- []*agd.Device{dev}
+ db := newDefaultProfileDB(b, devicesCh)
+
+ ctx := context.Background()
+
+ b.Run("success", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ profSink, devSink, errSink = db.ProfileByDedicatedIP(ctx, testClientIPv4)
+ }
+
+ assert.NotNil(b, profSink)
+ assert.NotNil(b, devSink)
+ assert.NoError(b, errSink)
+ })
+
+ b.Run("not_found", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ profSink, devSink, errSink = db.ProfileByDedicatedIP(ctx, testOtherClientIPv4)
+ }
+
+ assert.Nil(b, profSink)
+ assert.Nil(b, devSink)
+ assert.ErrorIs(b, errSink, profiledb.ErrDeviceNotFound)
+ })
+
+ // Most recent results, as of 2023-04-10, on a ThinkPad X13 with a Ryzen Pro
+ // 7 CPU:
+ //
+ // goos: linux
+ // goarch: amd64
+ // pkg: github.com/AdguardTeam/AdGuardDNS/internal/profiledb
+ // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
+ // BenchmarkDefaultProfileDB_ProfileByDedicatedIP/success-16 22697658 48.19 ns/op 0 B/op 0 allocs/op
+ // BenchmarkDefaultProfileDB_ProfileByDedicatedIP/not_found-16 61062061 19.89 ns/op 0 B/op 0 allocs/op
+}
diff --git a/internal/profiledb/storage.go b/internal/profiledb/storage.go
new file mode 100644
index 0000000..b3b68ed
--- /dev/null
+++ b/internal/profiledb/storage.go
@@ -0,0 +1,35 @@
+package profiledb
+
+import (
+ "context"
+ "time"
+
+ "github.com/AdguardTeam/AdGuardDNS/internal/agd"
+)
+
+// Storage is a storage from which a [Default] receives data about profiles and
+// devices.
+type Storage interface {
+ // Profiles returns profile and device data that has changed since
+ // req.SyncTime. req must not be nil.
+ Profiles(ctx context.Context, req *StorageRequest) (resp *StorageResponse, err error)
+}
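+
+// A minimal sketch of a [Storage] implementation, illustrative only; the
+// backend type and its internals are assumptions and not part of this package:
+//
+//	type apiStorage struct {
+//		// Some backend client, e.g. an HTTP API client.
+//	}
+//
+//	func (s *apiStorage) Profiles(
+//		ctx context.Context,
+//		req *StorageRequest,
+//	) (resp *StorageResponse, err error) {
+//		// Fetch the profiles and devices changed since req.SyncTime and
+//		// convert them into agd.Profile and agd.Device values.
+//		return &StorageResponse{SyncTime: time.Now()}, nil
+//	}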
+
+// StorageRequest is the request to [Storage] for profiles and devices.
+type StorageRequest struct {
+ // SyncTime is the last time profiles were synced.
+ SyncTime time.Time
+}
+
+// StorageResponse is the [Storage.Profiles] response.
+type StorageResponse struct {
+ // SyncTime is the time that should be saved and used as the next
+	// [StorageRequest.SyncTime].
+ SyncTime time.Time
+
+ // Profiles are the profiles data from the [Storage].
+ Profiles []*agd.Profile
+
+ // Devices are the device data from the [Storage].
+ Devices []*agd.Device
+}
diff --git a/internal/profiledb/testdata/profiles.json b/internal/profiledb/testdata/profiles.json
new file mode 100644
index 0000000..d6b7350
--- /dev/null
+++ b/internal/profiledb/testdata/profiles.json
@@ -0,0 +1,77 @@
+{
+ "sync_time": "SYNC_TIME",
+ "profiles": [
+ {
+ "Parental": {
+ "Schedule": {
+ "Week": [
+ {
+ "Start": 0,
+ "End": 700
+ },
+ {
+ "Start": 0,
+ "End": 700
+ },
+ {
+ "Start": 0,
+ "End": 700
+ },
+ {
+ "Start": 0,
+ "End": 700
+ },
+ {
+ "Start": 0,
+ "End": 700
+ },
+ {
+ "Start": 0,
+ "End": 700
+ },
+ {
+ "Start": 0,
+ "End": 700
+ }
+ ],
+ "TimeZone": "Europe/Brussels"
+ },
+ "Enabled": true
+ },
+ "BlockingMode": {
+ "type": "null_ip"
+ },
+ "ID": "prof1234",
+ "UpdateTime": "0001-01-01T00:00:00.000Z",
+ "DeviceIDs": [
+ "dev1234"
+ ],
+ "RuleListIDs": [
+ "adguard_dns_filter"
+ ],
+ "CustomRules": [
+ "|blocked-by-custom.example"
+ ],
+ "FilteredResponseTTL": 10000000000,
+ "FilteringEnabled": true,
+ "SafeBrowsingEnabled": true,
+ "RuleListsEnabled": true,
+ "QueryLogEnabled": true,
+ "Deleted": false,
+ "BlockPrivateRelay": true,
+ "BlockFirefoxCanary": true
+ }
+ ],
+ "devices": [
+ {
+ "ID": "dev1234",
+ "LinkedIP": "1.2.3.4",
+ "Name": "dev1",
+ "DedicatedIPs": [
+ "1.2.4.5"
+ ],
+ "FilteringEnabled": true
+ }
+ ],
+ "version": 6
+}
diff --git a/internal/tools/go.mod b/internal/tools/go.mod
index 201dc77..ea99e18 100644
--- a/internal/tools/go.mod
+++ b/internal/tools/go.mod
@@ -8,26 +8,27 @@ require (
github.com/gordonklaus/ineffassign v0.0.0-20230107090616-13ace0543b28
github.com/kisielk/errcheck v1.6.3
github.com/kyoh86/looppointer v0.2.1
- github.com/securego/gosec/v2 v2.15.0
- golang.org/x/tools v0.7.0
- golang.org/x/vuln v0.0.0-20230308034057-d4ed0a4fab9e
- honnef.co/go/tools v0.4.2
- mvdan.cc/gofumpt v0.4.0
- mvdan.cc/unparam v0.0.0-20230125043941-70a0ce6e7b95
+ github.com/securego/gosec/v2 v2.16.0
+ golang.org/x/tools v0.9.3
+ golang.org/x/vuln v0.1.0
+ google.golang.org/protobuf v1.30.0
+ honnef.co/go/tools v0.4.3
+ mvdan.cc/gofumpt v0.5.0
+ mvdan.cc/unparam v0.0.0-20230312165513-e84e2d14e3b8
)
require (
- github.com/BurntSushi/toml v1.2.1 // indirect
+ github.com/BurntSushi/toml v1.3.1 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/uuid v1.3.0 // indirect
- github.com/gookit/color v1.5.2 // indirect
+ github.com/gookit/color v1.5.3 // indirect
github.com/kyoh86/nolint v0.0.1 // indirect
github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
golang.org/x/exp v0.0.0-20230307190834-24139beb5833 // indirect
- golang.org/x/exp/typeparams v0.0.0-20230307190834-24139beb5833 // indirect
- golang.org/x/mod v0.9.0 // indirect
- golang.org/x/sync v0.1.0 // indirect
- golang.org/x/sys v0.6.0 // indirect
+ golang.org/x/exp/typeparams v0.0.0-20230522175609-2e198f4a06a1 // indirect
+ golang.org/x/mod v0.10.0 // indirect
+ golang.org/x/sync v0.2.0 // indirect
+ golang.org/x/sys v0.8.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
diff --git a/internal/tools/go.sum b/internal/tools/go.sum
index 5653710..9c17051 100644
--- a/internal/tools/go.sum
+++ b/internal/tools/go.sum
@@ -1,28 +1,31 @@
-github.com/BurntSushi/toml v1.2.1 h1:9F2/+DoOYIOksmaJFPw1tGFy1eDnIJXg+UHjuD8lTak=
-github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
+github.com/BurntSushi/toml v1.3.1 h1:rHnDkSK+/g6DlREUK73PkmIs60pqrnuduK+JmP++JmU=
+github.com/BurntSushi/toml v1.3.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
-github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/frankban/quicktest v1.14.3 h1:FJKSZTDHjyhriyC81FLQ0LY93eSai0ZyR/ZIkd3ZUKE=
+github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY=
github.com/fzipp/gocyclo v0.6.0 h1:lsblElZG7d3ALtGMx9fmxeTKZaLLpU8mET09yN4BBLo=
github.com/fzipp/gocyclo v0.6.0/go.mod h1:rXPyn8fnlpa0R2csP/31uerbiVBugk5whMdlyaLkLoA=
-github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
+github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
+github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
+github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golangci/misspell v0.4.0 h1:KtVB/hTK4bbL/S6bs64rYyk8adjmh1BygbBiaAiX+a0=
github.com/golangci/misspell v0.4.0/go.mod h1:W6O/bwV6lGDxUCChm2ykw9NQdd5bYd1Xkjo88UcWyJc=
github.com/google/go-cmdtest v0.4.1-0.20220921163831-55ab3332a786 h1:rcv+Ippz6RAtvaGgKxc+8FQIpxHgsF+HBzPyYL2cyVU=
+github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
github.com/google/renameio v0.1.0 h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/gookit/color v1.5.2 h1:uLnfXcaFjlrDnQDT+NCBcfhrXqYTx/rcCa6xn01Y8yI=
-github.com/gookit/color v1.5.2/go.mod h1:w8h4bGiHeeBpvQVePTutdbERIUf3oJE5lZ8HM0UgXyg=
+github.com/gookit/color v1.5.3 h1:twfIhZs4QLCtimkP7MOxlF3A0U/5cDPseRT9M/+2SCE=
+github.com/gookit/color v1.5.3/go.mod h1:NUzwzeehUfl7GIb36pqId+UGmRfQcU/WiiyTTeNjHtE=
github.com/gordonklaus/ineffassign v0.0.0-20230107090616-13ace0543b28 h1:9alfqbrhuD+9fLZ4iaAVwhlp5PEhmnBt7yvK2Oy5C1U=
github.com/gordonklaus/ineffassign v0.0.0-20230107090616-13ace0543b28/go.mod h1:Qcp2HIAYhR7mNUVSIxZww3Guk4it82ghYcEXIAk+QT0=
github.com/kisielk/errcheck v1.6.3 h1:dEKh+GLHcWm2oN34nMvDzn1sqI0i0WxPvrgiJA5JuM8=
github.com/kisielk/errcheck v1.6.3/go.mod h1:nXw/i/MfnvRHqXa7XXmQMUB0oNFGuBrNI8d8NLy0LPw=
-github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
+github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kyoh86/looppointer v0.2.1 h1:Jx9fnkBj/JrIryBLMTYNTj9rvc2SrPS98Dg0w7fxdJg=
github.com/kyoh86/looppointer v0.2.1/go.mod h1:q358WcM8cMWU+5vzqukvaZtnJi1kw/MpRHQm3xvTrjw=
@@ -30,21 +33,15 @@ github.com/kyoh86/nolint v0.0.1 h1:GjNxDEkVn2wAxKHtP7iNTrRxytRZ1wXxLV5j4XzGfRU=
github.com/kyoh86/nolint v0.0.1/go.mod h1:1ZiZZ7qqrZ9dZegU96phwVcdQOMKIqRzFJL3ewq9gtI=
github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354 h1:4kuARK6Y6FxaNu/BnU2OAaLF86eTVhP2hjTB6iMvItA=
github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354/go.mod h1:KSVJerMDfblTH7p5MZaTt+8zaT2iEk3AkVb9PQdZuE8=
-github.com/onsi/ginkgo/v2 v2.8.0 h1:pAM+oBNPrpXRs+E/8spkeGx9QgekbRVyr74EUvRVOUI=
-github.com/onsi/gomega v1.26.0 h1:03cDLK28U6hWvCAns6NeydX3zIm4SF3ci69ulidS32Q=
-github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgcC4zIxODThtZNPirFr42+A=
+github.com/onsi/ginkgo/v2 v2.9.4 h1:xR7vG4IXt5RWx6FfIjyAtsoMAtnc3C/rFXBBd2AjZwE=
+github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
-github.com/securego/gosec/v2 v2.15.0 h1:v4Ym7FF58/jlykYmmhZ7mTm7FQvN/setNm++0fgIAtw=
-github.com/securego/gosec/v2 v2.15.0/go.mod h1:VOjTrZOkUtSDt2QLSJmQBMWnvwiQPEjg0l+5juIqGk8=
-github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
+github.com/securego/gosec/v2 v2.16.0 h1:Pi0JKoasQQ3NnoRao/ww/N/XdynIB9NRYYZT5CyOs5U=
+github.com/securego/gosec/v2 v2.16.0/go.mod h1:xvLcVZqUfo4aAQu56TNv7/Ltz6emAOQAEsrZrt7uGlI=
github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
-github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
-github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
-github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -56,25 +53,25 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/exp v0.0.0-20230307190834-24139beb5833 h1:SChBja7BCQewoTAU7IgvucQKMIXrEpFxNMs0spT3/5s=
golang.org/x/exp v0.0.0-20230307190834-24139beb5833/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
-golang.org/x/exp/typeparams v0.0.0-20230307190834-24139beb5833 h1:jWGQJV4niP+CCmFW9ekjA9Zx8vYORzOUH2/Nl5WPuLQ=
-golang.org/x/exp/typeparams v0.0.0-20230307190834-24139beb5833/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
+golang.org/x/exp/typeparams v0.0.0-20230522175609-2e198f4a06a1 h1:pnP8r+W8Fm7XJ8CWtXi4S9oJmPBTrkfYN/dNbaPj6Y4=
+golang.org/x/exp/typeparams v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
-golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs=
-golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
+golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk=
+golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
-golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
+golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
-golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
+golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -84,34 +81,37 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
-golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
+golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
-golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k=
+golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20201007032633-0806396f153e/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU=
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
-golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4=
-golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s=
-golang.org/x/vuln v0.0.0-20230308034057-d4ed0a4fab9e h1:zLJSre6/6VuA0wMLQe9ShzEncb0k8E/z9wVhmlnRdNs=
-golang.org/x/vuln v0.0.0-20230308034057-d4ed0a4fab9e/go.mod h1:ydpjOTRSBwOBFJRP/w5NF2HSPnFg1JxobEZQGOirxgo=
+golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM=
+golang.org/x/tools v0.9.3/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
+golang.org/x/vuln v0.1.0 h1:9GRdj6wAIkDrsMevuolY+SXERPjQPp2P1ysYA0jpZe0=
+golang.org/x/vuln v0.1.0/go.mod h1:/YuzZYjGbwB8y19CisAppfyw3uTZnuCz3r+qgx/QRzU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
+google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
+google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-honnef.co/go/tools v0.4.2 h1:6qXr+R5w+ktL5UkwEbPp+fEvfyoMPche6GkOpGHZcLc=
-honnef.co/go/tools v0.4.2/go.mod h1:36ZgoUOrqOk1GxwHhyryEkq8FQWkUO2xGuSMhUCcdvA=
-mvdan.cc/gofumpt v0.4.0 h1:JVf4NN1mIpHogBj7ABpgOyZc65/UUOkKQFkoURsz4MM=
-mvdan.cc/gofumpt v0.4.0/go.mod h1:PljLOHDeZqgS8opHRKLzp2It2VBuSdteAgqUfzMTxlQ=
-mvdan.cc/unparam v0.0.0-20230125043941-70a0ce6e7b95 h1:n/xhncJPSt0YzfOhnyn41XxUdrWQNgmLBG72FE27Fqw=
-mvdan.cc/unparam v0.0.0-20230125043941-70a0ce6e7b95/go.mod h1:2vU506e8nGWodqcci641NLi4im2twWSq4Lod756epHQ=
+honnef.co/go/tools v0.4.3 h1:o/n5/K5gXqk8Gozvs2cnL0F2S1/g1vcGCAx2vETjITw=
+honnef.co/go/tools v0.4.3/go.mod h1:36ZgoUOrqOk1GxwHhyryEkq8FQWkUO2xGuSMhUCcdvA=
+mvdan.cc/gofumpt v0.5.0 h1:0EQ+Z56k8tXjj/6TQD25BFNKQXpCvT0rnansIc7Ug5E=
+mvdan.cc/gofumpt v0.5.0/go.mod h1:HBeVDtMKRZpXyxFciAirzdKklDlGu8aAy1wEbH5Y9js=
+mvdan.cc/unparam v0.0.0-20230312165513-e84e2d14e3b8 h1:VuJo4Mt0EVPychre4fNlDWDuE5AjXtPJpRUWqZDQhaI=
+mvdan.cc/unparam v0.0.0-20230312165513-e84e2d14e3b8/go.mod h1:Oh/d7dEtzsNHGOq1Cdv8aMm3KdKhVvPbRQcM8WFpBR8=
diff --git a/internal/tools/tools.go b/internal/tools/tools.go
index 230d6c4..deaa4f1 100644
--- a/internal/tools/tools.go
+++ b/internal/tools/tools.go
@@ -13,6 +13,7 @@ import (
_ "golang.org/x/tools/go/analysis/passes/nilness/cmd/nilness"
_ "golang.org/x/tools/go/analysis/passes/shadow/cmd/shadow"
_ "golang.org/x/vuln/cmd/govulncheck"
+ _ "google.golang.org/protobuf/cmd/protoc-gen-go"
_ "honnef.co/go/tools/cmd/staticcheck"
_ "mvdan.cc/gofumpt"
_ "mvdan.cc/unparam"
diff --git a/internal/websvc/handler.go b/internal/websvc/handler.go
index ed09a80..c84273d 100644
--- a/internal/websvc/handler.go
+++ b/internal/websvc/handler.go
@@ -10,6 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
"github.com/AdguardTeam/golibs/errors"
+ "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
"golang.org/x/sys/unix"
)
@@ -22,7 +23,7 @@ var _ http.Handler = (*Service)(nil)
// ServeHTTP implements the http.Handler interface for *Service.
func (svc *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
respHdr := w.Header()
- respHdr.Set(agdhttp.HdrNameServer, agdhttp.UserAgent())
+ respHdr.Set(httphdr.Server, agdhttp.UserAgent())
m, p, rAddr := r.Method, r.URL.Path, r.RemoteAddr
optlog.Debug3("websvc: starting req %s %s from %s", m, p, rAddr)
@@ -57,7 +58,7 @@ func (svc *Service) processRec(
action = "writing 404"
if len(svc.error404) != 0 {
body = svc.error404
- respHdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML)
+ respHdr.Set(httphdr.ContentType, agdhttp.HdrValTextHTML)
}
metrics.WebSvcError404RequestsTotal.Inc()
@@ -65,7 +66,7 @@ func (svc *Service) processRec(
action = "writing 500"
if len(svc.error500) != 0 {
body = svc.error500
- respHdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML)
+ respHdr.Set(httphdr.ContentType, agdhttp.HdrValTextHTML)
}
metrics.WebSvcError500RequestsTotal.Inc()
@@ -114,7 +115,7 @@ func (svc *Service) serveHTTP(w http.ResponseWriter, r *http.Request) {
func safeBrowsingHandler(name string, blockPage []byte) (h http.Handler) {
f := func(w http.ResponseWriter, r *http.Request) {
hdr := w.Header()
- hdr.Set(agdhttp.HdrNameServer, agdhttp.UserAgent())
+ hdr.Set(httphdr.Server, agdhttp.UserAgent())
switch r.URL.Path {
case "/favicon.ico":
@@ -125,7 +126,7 @@ func safeBrowsingHandler(name string, blockPage []byte) (h http.Handler) {
// the predefined response instead.
serveRobotsDisallow(hdr, w, name)
default:
- hdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextHTML)
+ hdr.Set(httphdr.ContentType, agdhttp.HdrValTextHTML)
_, err := w.Write(blockPage)
if err != nil {
@@ -186,7 +187,7 @@ type StaticFile struct {
// serveRobotsDisallow writes predefined disallow-all response.
func serveRobotsDisallow(hdr http.Header, w http.ResponseWriter, name string) {
- hdr.Set(agdhttp.HdrNameContentType, agdhttp.HdrValTextPlain)
+ hdr.Set(httphdr.ContentType, agdhttp.HdrValTextPlain)
_, err := io.WriteString(w, agdhttp.RobotsDisallowAll)
if err != nil {
diff --git a/internal/websvc/handler_test.go b/internal/websvc/handler_test.go
index 24afecf..35cca21 100644
--- a/internal/websvc/handler_test.go
+++ b/internal/websvc/handler_test.go
@@ -10,6 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/websvc"
+ "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -33,7 +34,7 @@ func TestService_ServeHTTP(t *testing.T) {
"/favicon.ico": {
Content: []byte{},
Headers: http.Header{
- agdhttp.HdrNameContentType: []string{"image/x-icon"},
+ httphdr.ContentType: []string{"image/x-icon"},
},
},
}
@@ -62,8 +63,8 @@ func TestService_ServeHTTP(t *testing.T) {
// Static content path with headers.
h := http.Header{
- agdhttp.HdrNameContentType: []string{"image/x-icon"},
- agdhttp.HdrNameServer: []string{"AdGuardDNS/"},
+ httphdr.ContentType: []string{"image/x-icon"},
+ httphdr.Server: []string{"AdGuardDNS/"},
}
assertResponseWithHeaders(t, svc, "/favicon.ico", http.StatusOK, h)
@@ -96,7 +97,7 @@ func assertResponse(
svc.ServeHTTP(rw, r)
assert.Equal(t, statusCode, rw.Code)
- assert.Equal(t, agdhttp.UserAgent(), rw.Header().Get(agdhttp.HdrNameServer))
+ assert.Equal(t, agdhttp.UserAgent(), rw.Header().Get(httphdr.Server))
return rw
}
diff --git a/internal/websvc/linkip.go b/internal/websvc/linkip.go
index 66e394b..2270aa9 100644
--- a/internal/websvc/linkip.go
+++ b/internal/websvc/linkip.go
@@ -13,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/metrics"
"github.com/AdguardTeam/AdGuardDNS/internal/optlog"
+ "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
)
@@ -47,10 +48,10 @@ func linkedIPHandler(
// Set the X-Forwarded-For header to a nil value to make sure that
// the proxy doesn't add it automatically.
- hdr["X-Forwarded-For"] = nil
+ hdr[httphdr.XForwardedFor] = nil
// Make sure that all requests are marked with our user agent.
- hdr.Set(agdhttp.HdrNameUserAgent, agdhttp.UserAgent())
+ hdr.Set(httphdr.UserAgent, agdhttp.UserAgent())
}
// Use largely the same transport as http.DefaultTransport, but with a
@@ -75,10 +76,10 @@ func linkedIPHandler(
// Delete the Server header value from the upstream.
modifyResponse := func(r *http.Response) (err error) {
- r.Header.Del(agdhttp.HdrNameServer)
+ r.Header.Del(httphdr.Server)
// Make sure that this URL can be used from the web page.
- r.Header.Set(agdhttp.HdrNameAccessControlAllowOrigin, agdhttp.HdrValWildcard)
+ r.Header.Set(httphdr.AccessControlAllowOrigin, agdhttp.HdrValWildcard)
return nil
}
@@ -111,7 +112,7 @@ var _ http.Handler = (*linkedIPProxy)(nil)
func (prx *linkedIPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Set the Server header here, so that all 404 and 500 responses carry it.
respHdr := w.Header()
- respHdr.Set(agdhttp.HdrNameServer, agdhttp.UserAgent())
+ respHdr.Set(httphdr.Server, agdhttp.UserAgent())
m, p, rAddr := r.Method, r.URL.Path, r.RemoteAddr
optlog.Debug3("websvc: starting req %s %s from %s", m, p, rAddr)
@@ -124,9 +125,9 @@ func (prx *linkedIPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Remove all proxy headers before sending the request to proxy.
hdr := r.Header
- hdr.Del("Forwarded")
- hdr.Del("True-Client-IP")
- hdr.Del("X-Real-IP")
+ hdr.Del(httphdr.Forwarded)
+ hdr.Del(httphdr.TrueClientIP)
+ hdr.Del(httphdr.XRealIP)
// Set the real IP.
ip, err := netutil.SplitHost(rAddr)
@@ -141,12 +142,12 @@ func (prx *linkedIPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
- hdr.Set("CF-Connecting-IP", ip)
+ hdr.Set(httphdr.CFConnectingIP, ip)
// Set the request ID.
reqID := agd.NewRequestID()
r = r.WithContext(agd.WithRequestID(r.Context(), reqID))
- hdr.Set(agdhttp.HdrNameXRequestID, string(reqID))
+ hdr.Set(httphdr.XRequestID, string(reqID))
log.Debug("%s: proxying %s %s: req %s", prx.logPrefix, m, p, reqID)
diff --git a/internal/websvc/linkip_internal_test.go b/internal/websvc/linkip_internal_test.go
index 745fc42..e325474 100644
--- a/internal/websvc/linkip_internal_test.go
+++ b/internal/websvc/linkip_internal_test.go
@@ -12,6 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
+ "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -22,7 +23,7 @@ func TestLinkedIPProxy_ServeHTTP(t *testing.T) {
upstream := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
pt := testutil.PanicT{}
- rid := r.Header.Get(agdhttp.HdrNameXRequestID)
+ rid := r.Header.Get(httphdr.XRequestID)
require.NotEmpty(pt, rid)
numReq.Add(1)
@@ -103,7 +104,7 @@ func TestLinkedIPProxy_ServeHTTP(t *testing.T) {
assert.Equal(t, prev+tc.diff, numReq.Load(), "req was not expected")
assert.Equal(t, tc.wantCode, rw.Code)
- assert.Equal(t, expectedUserAgent, rw.Header().Get(agdhttp.HdrNameServer))
+ assert.Equal(t, expectedUserAgent, rw.Header().Get(httphdr.Server))
})
}
}
diff --git a/internal/websvc/websvc_test.go b/internal/websvc/websvc_test.go
index 8b663ef..1b8309b 100644
--- a/internal/websvc/websvc_test.go
+++ b/internal/websvc/websvc_test.go
@@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/websvc"
+ "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -142,7 +143,7 @@ func assertContent(t *testing.T, addr netip.AddrPort, path string, status int, e
assert.Equal(t, expected, body)
assert.Equal(t, status, resp.StatusCode)
- assert.Equal(t, agdhttp.UserAgent(), resp.Header.Get(agdhttp.HdrNameServer))
+ assert.Equal(t, agdhttp.UserAgent(), resp.Header.Get(httphdr.Server))
}
func startService(t *testing.T, c *websvc.Config) {
diff --git a/scripts/make/go-lint.sh b/scripts/make/go-lint.sh
index 33baede..9c5b2b4 100644
--- a/scripts/make/go-lint.sh
+++ b/scripts/make/go-lint.sh
@@ -35,7 +35,7 @@ set -f -u
go_version="$( "${GO:-go}" version )"
readonly go_version
-go_min_version='go1.20.2'
+go_min_version='go1.20.5'
go_version_msg="
warning: your go version (${go_version}) is different from the recommended minimal one (${go_min_version}).
if you have the version installed, please set the GO environment variable.
@@ -80,6 +80,13 @@ esac
#
# * Package golang.org/x/net/context has been moved into stdlib.
#
+# NOTE: For AdGuard DNS, there are the following exceptions:
+#
+# * internal/profiledb/internal/filecachepb/filecache.pb.go: a file generated
+# by the protobuf compiler.
+#
+# * internal/profiledb/internal/filecachepb/unsafe.go: a “safe” unsafe helper
+# to prevent excessive allocations.
blocklist_imports() {
git grep\
-e '[[:space:]]"errors"$'\
@@ -91,6 +98,8 @@ blocklist_imports() {
-e '[[:space:]]"golang.org/x/net/context"$'\
-n\
-- '*.go'\
+ ':!internal/profiledb/internal/filecachepb/filecache.pb.go'\
+ ':!internal/profiledb/internal/filecachepb/unsafe.go'\
| sed -e 's/^\([^[:space:]]\+\)\(.*\)$/\1 blocked import:\2/'\
|| exit 0
}
@@ -158,7 +167,8 @@ run_linter "$GO" vet ./... "$dnssrvmod"
run_linter govulncheck ./... "$dnssrvmod"
-run_linter gocyclo --over 10 .
+# NOTE: For AdGuard DNS, ignore the generated protobuf file.
+run_linter gocyclo --ignore '\.pb\.go$' --over 10 .
run_linter ineffassign ./... "$dnssrvmod"
@@ -174,7 +184,28 @@ run_linter nilness ./... "$dnssrvmod"
# Do not use fieldalignment on $dnssrvmod, because ameshkov likes to place
# struct fields in an order that he considers more readable.
-run_linter fieldalignment ./...
+#
+# TODO(a.garipov): Remove the loop once golang/go#60509 is fixed.
+(
+ run_linter fieldalignment ./main.go
+
+ set +f
+ for d in ./internal/*/ ./internal/*/*/ ./internal/*/*/*/
+ do
+ case "$d"
+ in
+ (*/testdata/*|\
+ ./internal/dnsserver/*|\
+ ./internal/profiledb/internal/filecachepb/|\
+ ./internal/tools/)
+ continue
+ ;;
+ (*)
+ run_linter fieldalignment "$d"
+ ;;
+ esac
+ done
+)
run_linter -e shadow --strict ./... "$dnssrvmod"
diff --git a/scripts/make/go-tools.sh b/scripts/make/go-tools.sh
index d934458..8ec6135 100644
--- a/scripts/make/go-tools.sh
+++ b/scripts/make/go-tools.sh
@@ -44,6 +44,7 @@ rm -f\
bin/looppointer\
bin/misspell\
bin/nilness\
+ bin/protoc-gen-go\
bin/shadow\
bin/staticcheck\
bin/unparam\
@@ -71,6 +72,7 @@ env\
golang.org/x/tools/go/analysis/passes/nilness/cmd/nilness\
golang.org/x/tools/go/analysis/passes/shadow/cmd/shadow\
golang.org/x/vuln/cmd/govulncheck\
+ google.golang.org/protobuf/cmd/protoc-gen-go\
honnef.co/go/tools/cmd/staticcheck\
mvdan.cc/gofumpt\
mvdan.cc/unparam\
diff --git a/scripts/make/txt-lint.sh b/scripts/make/txt-lint.sh
index 48b926a..6edc8bf 100644
--- a/scripts/make/txt-lint.sh
+++ b/scripts/make/txt-lint.sh
@@ -27,6 +27,26 @@ set -f -u
# Source the common helpers, including not_found.
. ./scripts/make/helper.sh
+# trailing_newlines is a simple check that makes sure that all plain-text files
+# have a trailing newline so that all tools work correctly with them.
+trailing_newlines() {
+ nl="$( printf "\n" )"
+ readonly nl
+
+ # NOTE: Adjust for your project.
+ git ls-files\
+ ':!*.mmdb'\
+ | while read -r f
+ do
+ if [ "$( tail -c -1 "$f" )" != "$nl" ]
+ then
+ printf '%s: must have a trailing newline\n' "$f"
+ fi
+ done
+}
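+
+# For example, a file that is missing its final newline would be reported as
+# follows (the file name is purely illustrative):
+#
+#   README.md: must have a trailing newline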
+
+run_linter -e trailing_newlines
+
git ls-files -- '*.md' '*.yaml' '*.yml'\
| xargs misspell --error\
| sed -e 's/^/misspell: /'