AdGuardDNS/internal/billstat/runtime_test.go
Andrey Meshkov b6a98906a5 Sync v2.0
2022-08-26 14:18:35 +03:00

130 lines
3.4 KiB
Go

package billstat_test
import (
"context"
"testing"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Common constants for tests.
const (
// devID is the device ID passed to Record and used as the key when looking
// up the uploaded record in billstat.Records.
devID = "dev1234"
// proto is the protocol value passed to Record by the tests below.
proto = agd.ProtoDoH
// clientCtry is the client country value passed to Record by the tests
// below.
clientCtry = agd.CountryAD
// clientASN is the client ASN value passed to Record by the tests below.
clientASN agd.ASN = 42
)
// TestRuntimeRecorder_success records several queries for a single device,
// refreshes the recorder with an uploader stub that always succeeds, and checks
// that the record uploaded for devID carries the recorded time, country, ASN,
// protocol, and a query count equal to the number of Record calls.
func TestRuntimeRecorder_success(t *testing.T) {
	// gotRecord is set by the uploader stub to the record uploaded for devID.
	var gotRecord *billstat.Record

	c := &billstat.RuntimeRecorderConfig{
		Uploader: &agdtest.BillStatUploader{
			OnUpload: func(_ context.Context, records billstat.Records) (err error) {
				gotRecord = records[devID]

				return nil
			},
		},
	}

	r := billstat.NewRuntimeRecorder(c)

	ctx := context.Background()
	start := time.Now().Truncate(1 * time.Millisecond)

	// Record two queries to make sure that the queries counter is properly
	// incremented.
	const reqNum = 2
	for i := 0; i < reqNum; i++ {
		r.Record(ctx, devID, clientCtry, clientASN, start, proto)
	}

	err := r.Refresh(ctx)
	require.NoError(t, err)
	require.NotNil(t, gotRecord)

	// testify takes the expected value first and the actual value second;
	// keeping that order makes failure messages label the values correctly.
	assert.Equal(t, start, gotRecord.Time)
	assert.Equal(t, clientCtry, gotRecord.Country)
	assert.Equal(t, clientASN, gotRecord.ASN)
	assert.Equal(t, int32(reqNum), gotRecord.Queries)
	assert.Equal(t, proto, gotRecord.Proto)
}
// TestRuntimeRecorder_fail checks the behavior of the recorder when an upload
// fails while another query is being recorded concurrently.  The first Refresh
// is made to fail with testError after a second query has been recorded during
// the upload; the second Refresh succeeds and is expected to upload a single
// record for devID with a query count of two.
func TestRuntimeRecorder_fail(t *testing.T) {
	const testError errors.Error = "test error"

	// emulateFail switches the uploader stub between the failing and the
	// succeeding modes.
	var emulateFail bool

	// gotRecord is set by the uploader stub to the record uploaded for devID,
	// but only in the succeeding mode.
	var gotRecord *billstat.Record

	// uploadSync is used for the handshake between the uploader stub and the
	// goroutine that records a query while the upload is in progress.
	uploadSync := make(chan agdtest.Signal)

	onUpload := func(_ context.Context, records billstat.Records) (err error) {
		if emulateFail {
			pt := testutil.PanicT{}

			// Give the goroutine a signal that it can now record another query
			// to emulate a situation where a query gets recorded while an
			// upload is in progress.
			agdtest.RequireSend(pt, uploadSync, testTimeout)

			// Wait for the other query in the goroutine to be recorded and
			// proceed to returning an error after that.
			agdtest.RequireReceive(pt, uploadSync, testTimeout)

			return testError
		}

		gotRecord = records[devID]

		return nil
	}

	c := &billstat.RuntimeRecorderConfig{
		Uploader: &agdtest.BillStatUploader{
			OnUpload: onUpload,
		},
	}

	r := billstat.NewRuntimeRecorder(c)

	ctx := context.Background()
	start := time.Now().Truncate(1 * time.Millisecond)

	r.Record(ctx, devID, clientCtry, clientASN, start, proto)

	// Request the backend, make a concurrent request while an upload is in
	// progress, receive the error, and expect the data to be returned to the
	// database and properly merged.
	emulateFail = true

	go func() {
		pt := testutil.PanicT{}

		// Wait until the uploader stub signals that the upload has started,
		// record one more query, and then let the stub proceed to its error.
		agdtest.RequireReceive(pt, uploadSync, testTimeout)

		r.Record(ctx, devID, clientCtry, clientASN, start, proto)

		agdtest.RequireSend(pt, uploadSync, testTimeout)
	}()

	err := r.Refresh(ctx)
	require.ErrorIs(t, err, testError)
	require.Nil(t, gotRecord)

	// Request the backend again, expect the correct, merged data.
	emulateFail = false

	err = r.Refresh(ctx)
	require.NoError(t, err)
	require.NotNil(t, gotRecord)

	// testify takes the expected value first and the actual value second;
	// keeping that order makes failure messages label the values correctly.
	assert.Equal(t, start, gotRecord.Time)
	assert.Equal(t, clientCtry, gotRecord.Country)
	assert.Equal(t, clientASN, gotRecord.ASN)
	assert.Equal(t, int32(2), gotRecord.Queries)
	assert.Equal(t, proto, gotRecord.Proto)
}