62 lines
1.8 KiB
Go
Raw Permalink Normal View History

2022-08-26 14:18:35 +03:00
package dnssvc
import (
"context"
"fmt"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
2024-01-04 19:22:32 +03:00
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
2022-08-26 14:18:35 +03:00
)
// errCollMetricsListener extends the default prometheus.ServerMetricsListener
// and overrides OnPanic and OnError methods. The point is to collect errors
// from inside the dnsserver.Server in addition to collecting prom metrics.
type errCollMetricsListener struct {
2024-01-04 19:22:32 +03:00
errColl errcoll.Interface
2022-08-26 14:18:35 +03:00
baseListener dnsserver.MetricsListener
}
// type check
var _ dnsserver.MetricsListener = (*errCollMetricsListener)(nil)
// OnRequest implements the dnsserver.MetricsListener interface for
// *errCollMetricsListener.
func (s *errCollMetricsListener) OnRequest(
ctx context.Context,
2024-01-04 19:22:32 +03:00
info *dnsserver.QueryInfo,
2022-08-26 14:18:35 +03:00
rw dnsserver.ResponseWriter,
) {
2024-01-04 19:22:32 +03:00
s.baseListener.OnRequest(ctx, info, rw)
2022-08-26 14:18:35 +03:00
}
// OnInvalidMsg implements the dnsserver.MetricsListener interface for
// *errCollMetricsListener.  Invalid messages are not reported to the error
// collector; the event is only forwarded to baseListener.
func (s *errCollMetricsListener) OnInvalidMsg(ctx context.Context) {
	base := s.baseListener
	base.OnInvalidMsg(ctx)
}
2022-11-07 10:21:24 +03:00
// OnQUICAddressValidation implements the dnsserver.MetricsListener interface
// for *errCollMetricsListener.
func (s *errCollMetricsListener) OnQUICAddressValidation(hit bool) {
s.baseListener.OnQUICAddressValidation(hit)
}
2022-08-26 14:18:35 +03:00
// OnPanic implements the dnsserver.MetricsListener interface for
// *errCollMetricsListener.
func (s *errCollMetricsListener) OnPanic(ctx context.Context, v any) {
err, ok := v.(error)
if !ok {
err = fmt.Errorf("non-error panic: %v", v)
}
s.errColl.Collect(ctx, err)
s.baseListener.OnPanic(ctx, v)
}
// OnError implements the dnsserver.MetricsListener interface for
// *errCollMetricsListener.  err is reported to errColl first and then
// forwarded to baseListener, so both the error collector and the wrapped
// listener see every error.
func (s *errCollMetricsListener) OnError(ctx context.Context, err error) {
	s.errColl.Collect(ctx, err)

	base := s.baseListener
	base.OnError(ctx, err)
}