mirror of
https://github.com/AdguardTeam/AdGuardDNS.git
synced 2025-02-20 11:23:36 +08:00
110 lines
2.8 KiB
Go
110 lines
2.8 KiB
Go
package backendpb_test
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"net/netip"
|
|
"net/url"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/AdguardTeam/AdGuardDNS/internal/backendpb"
|
|
"github.com/AdguardTeam/AdGuardDNS/internal/consul"
|
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
|
|
"github.com/AdguardTeam/golibs/testutil"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
)
|
|
|
|
// testRateLimitServiceServer is the [backendpb.RateLimitServiceServer] for
|
|
// tests.
|
|
type testRateLimitServiceServer struct {
|
|
backendpb.UnimplementedRateLimitServiceServer
|
|
|
|
OnGetRateLimitSettings func(
|
|
ctx context.Context,
|
|
req *backendpb.RateLimitSettingsRequest,
|
|
) (resp *backendpb.RateLimitSettingsResponse, err error)
|
|
}
|
|
|
|
// type check
|
|
var _ backendpb.DNSServiceServer = (*testDNSServiceServer)(nil)
|
|
|
|
// GetRateLimitSettings implements the [backendpb.RateLimitServiceServer]
|
|
// interface for *testRateLimitServiceServer.
|
|
func (s *testRateLimitServiceServer) GetRateLimitSettings(
|
|
ctx context.Context,
|
|
req *backendpb.RateLimitSettingsRequest,
|
|
) (resp *backendpb.RateLimitSettingsResponse, err error) {
|
|
return s.OnGetRateLimitSettings(ctx, req)
|
|
}
|
|
|
|
func TestRateLimiter_Refresh(t *testing.T) {
|
|
var (
|
|
allowedIP = netip.MustParseAddr("1.2.3.4")
|
|
notAllowedIP = netip.MustParseAddr("4.3.2.1")
|
|
|
|
cidr = &backendpb.CidrRange{
|
|
Address: allowedIP.AsSlice(),
|
|
Prefix: 32,
|
|
}
|
|
)
|
|
|
|
srv := &testRateLimitServiceServer{
|
|
OnGetRateLimitSettings: func(
|
|
ctx context.Context,
|
|
req *backendpb.RateLimitSettingsRequest,
|
|
) (resp *backendpb.RateLimitSettingsResponse, err error) {
|
|
return &backendpb.RateLimitSettingsResponse{
|
|
AllowedSubnets: []*backendpb.CidrRange{cidr},
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
ln, err := net.Listen("tcp", "localhost:0")
|
|
require.NoError(t, err)
|
|
|
|
grpcSrv := grpc.NewServer(
|
|
grpc.ConnectionTimeout(1*time.Second),
|
|
grpc.Creds(insecure.NewCredentials()),
|
|
)
|
|
backendpb.RegisterRateLimitServiceServer(grpcSrv, srv)
|
|
|
|
go func() {
|
|
pt := testutil.PanicT{}
|
|
|
|
srvErr := grpcSrv.Serve(ln)
|
|
require.NoError(pt, srvErr)
|
|
}()
|
|
t.Cleanup(grpcSrv.GracefulStop)
|
|
|
|
allowlist := ratelimit.NewDynamicAllowlist(nil, nil)
|
|
l, err := backendpb.NewRateLimiter(&backendpb.RateLimiterConfig{
|
|
Logger: backendpb.TestLogger,
|
|
Metrics: consul.EmptyMetrics{},
|
|
GRPCMetrics: backendpb.EmptyGRPCMetrics{},
|
|
Allowlist: allowlist,
|
|
Endpoint: &url.URL{
|
|
Scheme: "grpc",
|
|
Host: ln.Addr().String(),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
|
err = l.Refresh(ctx)
|
|
require.NoError(t, err)
|
|
|
|
ok, err := allowlist.IsAllowed(ctx, allowedIP)
|
|
require.NoError(t, err)
|
|
|
|
assert.True(t, ok)
|
|
|
|
ok, err = allowlist.IsAllowed(ctx, notAllowedIP)
|
|
require.NoError(t, err)
|
|
|
|
assert.False(t, ok)
|
|
}
|