Andrey Meshkov da0cb6fd0e Sync v2.9.0
2024-10-14 17:44:24 +03:00

160 lines
4.0 KiB
Go

package debugsvc
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"maps"
"net/http"
"path"
"slices"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
"github.com/AdguardTeam/AdGuardDNS/internal/agdservice"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// RefresherID is a type alias for strings that represent IDs of refreshers.
// Being an alias, it is interchangeable with plain string values.
//
// TODO(a.garipov): Consider a newtype with validations.
type RefresherID = string
// Refreshers is a map of refresher IDs to the refreshers themselves.  Unlike
// [RefresherID], it is a distinct defined type and not an alias.
type Refreshers map[RefresherID]agdservice.Refresher
// refreshHandler performs debug refreshes.  It implements the [http.Handler]
// interface.
type refreshHandler struct {
// refrs contains the refreshers that the handler can run, keyed by their
// IDs.
refrs Refreshers
}
// refreshRequest describes the request to the POST /debug/api/refresh HTTP API.
type refreshRequest struct {
// Patterns is the slice of path patterns to match the refreshers IDs.  It
// is read from the "ids" JSON property.  The patterns are matched against
// the IDs using [path.Match], and a single "*" pattern selects all
// refreshers.
Patterns []string `json:"ids"`
}
// refreshResponse describes the response to the POST /debug/api/refresh HTTP
// API.
type refreshResponse struct {
// Results maps the ID of each refresher selected by the request to the
// outcome of its refresh, which is either "ok" or a string starting with
// "error: ".
Results map[RefresherID]string `json:"results"`
}
// type check: *refreshHandler must implement [http.Handler].
var _ http.Handler = (*refreshHandler)(nil)
// ServeHTTP implements the [http.Handler] interface for *refreshHandler.  It
// decodes a JSON [refreshRequest] from the request body, resolves the requested
// patterns into refresher IDs, runs every selected refresher one after another,
// and writes the per-refresher outcomes as a JSON [refreshResponse].
//
// A body that cannot be decoded and a pattern list that fails validation are
// both answered with 400 Bad Request carrying the error text.  The logger is
// taken from the request context.
func (h *refreshHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	logger := slogutil.MustLoggerFromContext(ctx)

	var body refreshRequest
	if decErr := json.NewDecoder(r.Body).Decode(&body); decErr != nil {
		logger.ErrorContext(ctx, "decoding request", slogutil.KeyError, decErr)
		http.Error(w, decErr.Error(), http.StatusBadRequest)

		return
	}

	selected, err := h.idsFromReq(body.Patterns)
	if err != nil {
		logger.ErrorContext(ctx, "validating request", slogutil.KeyError, err)
		http.Error(w, err.Error(), http.StatusBadRequest)

		return
	}

	// Refreshers are run sequentially; each outcome is stored under its ID.
	outcomes := make(map[RefresherID]string, len(selected))
	for _, refrID := range selected {
		outcomes[refrID] = h.refresh(ctx, logger, refrID)
	}

	w.Header().Set(httphdr.ContentType, agdhttp.HdrValApplicationJSON)

	// The status line may already have been sent at this point, so an encoding
	// failure can only be logged.
	if err = json.NewEncoder(w).Encode(&refreshResponse{Results: outcomes}); err != nil {
		logger.ErrorContext(ctx, "writing response", slogutil.KeyError, err)
	}
}
// idsFromReq validates given patterns from the request and returns the IDs of
// matching refreshers to refresh.  When patterns is the single wildcard "*",
// the IDs of all known refreshers are returned; otherwise only the IDs that
// match at least one pattern are.  Since the IDs are collected from a map, their
// order is unspecified.
//
// TODO(e.burkov): Validate patterns.
func (h *refreshHandler) idsFromReq(patterns []string) (ids []RefresherID, err error) {
	matchAll, err := isWildcard(patterns)
	if err != nil {
		// Don't wrap the error, because it's informative enough as is.
		return nil, err
	}

	known := slices.Collect(maps.Keys(h.refrs))
	if !matchAll {
		known = matchPatterns(known, patterns)
	}

	return known, nil
}
// refresh performs a single refresh and returns the result as a string.  The
// result is "ok" on success, "error: refresher not found" if no refresher is
// registered under id, and "error: " followed by the error text if the
// refresher fails.  Failures and the duration of successful refreshes are
// reported through l.
func (h *refreshHandler) refresh(ctx context.Context, l *slog.Logger, id RefresherID) (res string) {
	refresher, found := h.refrs[id]
	if !found {
		return "error: refresher not found"
	}

	startedAt := time.Now()
	if refrErr := refresher.Refresh(ctx); refrErr != nil {
		l.ErrorContext(ctx, "refresher error", "id", id, slogutil.KeyError, refrErr)

		return fmt.Sprintf("error: %s", refrErr)
	}

	l.InfoContext(ctx, "refresh finished", "id", id, "duration", time.Since(startedAt))

	return "ok"
}
// isWildcard returns true if the list of patterns contains a single wildcard
// pattern, "*".  It also returns an error if the list is empty or contains a
// wildcard pattern mixed with the others.  In all other cases it returns false
// and a nil error.
func isWildcard(patterns []string) (ok bool, err error) {
	switch len(patterns) {
	case 0:
		return false, errors.Error("no ids")
	case 1:
		return patterns[0] == "*", nil
	default:
		// With two or more patterns a wildcard is always an error, even if
		// every pattern is a wildcard.
		if slices.Contains(patterns, "*") {
			return false, errors.Error(`"*" cannot be used with other ids`)
		}

		return false, nil
	}
}
// matchPatterns matches ids against patterns and returns the resulting matches.
// Patterns use the [path.Match] syntax.  The result is ordered by pattern first
// and by position within ids second.  Every ID appears in it at most once, even
// when it matches several patterns, so that callers iterating over the result
// don't process the same ID repeatedly.  matches is nil if nothing matched.
func matchPatterns(ids, patterns []string) (matches []string) {
	// seen contains the IDs that have already been added to matches.
	seen := make(map[string]struct{}, len(ids))

	for _, pattern := range patterns {
		for _, id := range ids {
			if _, dup := seen[id]; dup {
				continue
			}

			// The error is deliberately discarded: this function has no error
			// result, and path.Match reports false together with
			// path.ErrBadPattern, so a malformed pattern matches nothing.
			match, _ := path.Match(pattern, id)
			if !match {
				continue
			}

			seen[id] = struct{}{}
			matches = append(matches, id)
		}
	}

	return matches
}