mirror of
https://github.com/AdguardTeam/AdGuardDNS.git
synced 2025-02-20 11:23:36 +08:00
160 lines
4.0 KiB
Go
160 lines
4.0 KiB
Go
package debugsvc
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
"maps"
|
|
"net/http"
|
|
"path"
|
|
"slices"
|
|
"time"
|
|
|
|
"github.com/AdguardTeam/AdGuardDNS/internal/agdhttp"
|
|
"github.com/AdguardTeam/AdGuardDNS/internal/agdservice"
|
|
"github.com/AdguardTeam/golibs/errors"
|
|
"github.com/AdguardTeam/golibs/httphdr"
|
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
|
)
|
|
|
|
// RefresherID is a type alias for strings that represent IDs of refreshers.
|
|
//
|
|
// TODO(a.garipov): Consider a newtype with validations.
|
|
type RefresherID = string
|
|
|
|
// Refreshers is a type alias for maps of refresher IDs to Refreshers
|
|
// themselves.
|
|
type Refreshers map[RefresherID]agdservice.Refresher
|
|
|
|
// refreshHandler performs debug refreshes.
|
|
type refreshHandler struct {
|
|
refrs Refreshers
|
|
}
|
|
|
|
// refreshRequest describes the request to the POST /debug/api/refresh HTTP API.
|
|
type refreshRequest struct {
|
|
// Patterns is the slice of path patterns to match the refreshers IDs.
|
|
Patterns []string `json:"ids"`
|
|
}
|
|
|
|
// refreshResponse describes the response to the POST /debug/api/refresh HTTP
|
|
// API.
|
|
type refreshResponse struct {
|
|
Results map[RefresherID]string `json:"results"`
|
|
}
|
|
|
|
// type check
|
|
var _ http.Handler = (*refreshHandler)(nil)
|
|
|
|
// ServeHTTP implements the [http.Handler] interface for *refreshHandler.
|
|
func (h *refreshHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
l := slogutil.MustLoggerFromContext(ctx)
|
|
|
|
req := &refreshRequest{}
|
|
err := json.NewDecoder(r.Body).Decode(req)
|
|
if err != nil {
|
|
l.ErrorContext(ctx, "decoding request", slogutil.KeyError, err)
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
|
|
return
|
|
}
|
|
|
|
reqIDs, err := h.idsFromReq(req.Patterns)
|
|
if err != nil {
|
|
l.ErrorContext(ctx, "validating request", slogutil.KeyError, err)
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
|
|
return
|
|
}
|
|
|
|
resp := &refreshResponse{
|
|
Results: make(map[RefresherID]string, len(reqIDs)),
|
|
}
|
|
|
|
for _, id := range reqIDs {
|
|
resp.Results[id] = h.refresh(ctx, l, id)
|
|
}
|
|
|
|
w.Header().Set(httphdr.ContentType, agdhttp.HdrValApplicationJSON)
|
|
err = json.NewEncoder(w).Encode(resp)
|
|
if err != nil {
|
|
l.ErrorContext(ctx, "writing response", slogutil.KeyError, err)
|
|
}
|
|
}
|
|
|
|
// idsFromReq validates given patterns from the request and returns the IDs of
|
|
// matching refreshers to refresh.
|
|
//
|
|
// TODO(e.burkov): Validate patterns.
|
|
func (h *refreshHandler) idsFromReq(patterns []string) (ids []RefresherID, err error) {
|
|
ok, err := isWildcard(patterns)
|
|
if err != nil {
|
|
// Don't wrap the error, because it's informative enough as is.
|
|
return nil, err
|
|
}
|
|
|
|
refrIDs := slices.Collect(maps.Keys(h.refrs))
|
|
if ok {
|
|
return refrIDs, nil
|
|
}
|
|
|
|
return matchPatterns(refrIDs, patterns), nil
|
|
}
|
|
|
|
// refresh performs a single refresh and returns the result as a string.
|
|
func (h *refreshHandler) refresh(ctx context.Context, l *slog.Logger, id RefresherID) (res string) {
|
|
r, ok := h.refrs[id]
|
|
if !ok {
|
|
return "error: refresher not found"
|
|
}
|
|
|
|
start := time.Now()
|
|
err := r.Refresh(ctx)
|
|
if err != nil {
|
|
l.ErrorContext(ctx, "refresher error", "id", id, slogutil.KeyError, err)
|
|
|
|
return fmt.Sprintf("error: %s", err)
|
|
}
|
|
|
|
l.InfoContext(ctx, "refresh finished", "id", id, "duration", time.Since(start))
|
|
|
|
return "ok"
|
|
}
|
|
|
|
// isWildcard returns true if the list of patterns contains a single wildcard
|
|
// pattern. It also returns an error if the list is empty or contains a
|
|
// wildcard pattern mixed with the others.
|
|
func isWildcard(patterns []string) (ok bool, err error) {
|
|
switch len(patterns) {
|
|
case 0:
|
|
return false, errors.Error("no ids")
|
|
case 1:
|
|
if patterns[0] == "*" {
|
|
return true, nil
|
|
} else {
|
|
return false, nil
|
|
}
|
|
default:
|
|
if slices.Contains(patterns, "*") {
|
|
return false, errors.Error(`"*" cannot be used with other ids`)
|
|
} else {
|
|
return false, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// matchPatterns matches ids against patterns and returns the resulting matches.
|
|
func matchPatterns(ids, patterns []string) (matches []string) {
|
|
for _, pattern := range patterns {
|
|
for _, id := range ids {
|
|
if match, _ := path.Match(pattern, id); match {
|
|
matches = append(matches, id)
|
|
}
|
|
}
|
|
}
|
|
|
|
return matches
|
|
}
|