181 lines
5.1 KiB
Go
Raw Normal View History

2022-08-26 14:18:35 +03:00
package dnsserver
import (
"context"
2024-01-04 19:22:32 +03:00
"fmt"
2022-08-26 14:18:35 +03:00
"net/url"
"time"
)
2024-01-04 19:22:32 +03:00
// ContextConstructor is an interface for constructing contexts, possibly with
// deadlines, e.g. for request contexts.
type ContextConstructor interface {
	// New returns a new context along with the function that releases its
	// resources. Callers should call cancel once they are done with ctx.
	New() (ctx context.Context, cancel context.CancelFunc)
}

// DefaultContextConstructor is the default implementation of the
// [ContextConstructor] interface. The contexts it returns have no deadline
// and are never canceled.
type DefaultContextConstructor struct{}

// type check
var _ ContextConstructor = DefaultContextConstructor{}

// New implements the [ContextConstructor] interface for
// DefaultContextConstructor. It returns [context.Background] and a no-op
// [context.CancelFunc].
func (DefaultContextConstructor) New() (ctx context.Context, cancel context.CancelFunc) {
	return context.Background(), func() {}
}

// TimeoutContextConstructor is an implementation of the [ContextConstructor]
// interface that returns contexts which expire after a fixed timeout.
type TimeoutContextConstructor struct {
	// timeout is the lifetime of every context returned by New, measured from
	// the moment New is called.
	timeout time.Duration
}

// NewTimeoutContextConstructor returns a new properly initialized
// *TimeoutContextConstructor. timeout is not validated: a non-positive value
// makes every context returned by New already expired.
func NewTimeoutContextConstructor(timeout time.Duration) (c *TimeoutContextConstructor) {
	return &TimeoutContextConstructor{
		timeout: timeout,
	}
}

// type check
var _ ContextConstructor = (*TimeoutContextConstructor)(nil)

// New implements the [ContextConstructor] interface for
// *TimeoutContextConstructor. It returns a child of [context.Background] whose
// deadline is c.timeout from now, together with its cancelation function.
func (c *TimeoutContextConstructor) New() (ctx context.Context, cancel context.CancelFunc) {
	return context.WithTimeout(context.Background(), c.timeout)
}
2022-08-26 14:18:35 +03:00
2024-01-04 19:22:32 +03:00
// ctxKey is the type of the keys under which this package stores values in a
// context.Context.
type ctxKey int

// Context keys used by this package. Their numeric values are assigned by iota
// in declaration order.
const (
	// ctxKeyServerInfo is the key for the *ServerInfo value.
	ctxKeyServerInfo ctxKey = iota

	// ctxKeyRequestInfo is the key for the *RequestInfo value.
	ctxKeyRequestInfo
)

// type check
var _ fmt.Stringer = ctxKey(0)

// String implements the [fmt.Stringer] interface for ctxKey. It panics for any
// value other than the keys declared above.
func (k ctxKey) String() (s string) {
	if k == ctxKeyServerInfo {
		return "dnsserver.ctxKeyServerInfo"
	}

	if k == ctxKeyRequestInfo {
		return "dnsserver.ctxKeyRequestInfo"
	}

	panic(fmt.Errorf("bad ctx key value %d", k))
}
2022-08-26 14:18:35 +03:00
// ServerInfo is a structure that contains basic server information. It is
// attached to every context.Context created inside dnsserver. Use
// ContextWithServerInfo to attach it and ServerInfoFromContext or
// MustServerInfoFromContext to retrieve it.
type ServerInfo struct {
	// Name is the name of the server (Server.Name).
	Name string

	// Addr is the address that the server is configured to listen on.
	Addr string

	// Proto is the protocol of the server (Server.Proto).
	Proto Protocol
}
2024-01-04 19:22:32 +03:00
// ContextWithServerInfo attaches ServerInfo to the specified context. s should
// not be nil.
func ContextWithServerInfo(parent context.Context, si *ServerInfo) (ctx context.Context) {
return context.WithValue(parent, ctxKeyServerInfo, si)
2022-08-26 14:18:35 +03:00
}
// ServerInfoFromContext gets ServerInfo attached to the context.
2024-01-04 19:22:32 +03:00
func ServerInfoFromContext(ctx context.Context) (si *ServerInfo, found bool) {
v := ctx.Value(ctxKeyServerInfo)
if v == nil {
return nil, false
}
2022-08-26 14:18:35 +03:00
2024-01-04 19:22:32 +03:00
ri, ok := v.(*ServerInfo)
if !ok {
panicBadType(ctxKeyServerInfo, v)
}
return ri, true
2022-08-26 14:18:35 +03:00
}
// MustServerInfoFromContext gets ServerInfo attached to the context and panics
// if it is not found.
2024-01-04 19:22:32 +03:00
func MustServerInfoFromContext(ctx context.Context) (si *ServerInfo) {
si, found := ServerInfoFromContext(ctx)
2022-08-26 14:18:35 +03:00
if !found {
panic("server info not found in the context")
}
2024-01-04 19:22:32 +03:00
return si
2022-08-26 14:18:35 +03:00
}
2022-11-07 10:21:24 +03:00
// RequestInfo is a structure that contains basic request information. It is
// attached to every context.Context linked to processing a DNS request. Use
// ContextWithRequestInfo to attach it and RequestInfoFromContext or
// MustRequestInfoFromContext to retrieve it.
type RequestInfo struct {
	// URL is the request URL. It is set only if the protocol of the server is
	// DoH.
	URL *url.URL

	// Userinfo is the userinfo from the basic authentication header. It is set
	// only if the protocol of the server is DoH.
	Userinfo *url.Userinfo

	// StartTime is the request's start time. It's never zero value.
	StartTime time.Time

	// TLSServerName is the original, non-lowercased server name field of the
	// client's TLS hello request. It is set only if the protocol of the server
	// is either DoQ, DoT or DoH.
	//
	// TODO(ameshkov): use r.TLS with DoH3 (see addRequestInfo).
	TLSServerName string
}
2022-08-26 14:18:35 +03:00
2024-01-04 19:22:32 +03:00
// ContextWithRequestInfo attaches RequestInfo to the specified context. ri
// should not be nil.
func ContextWithRequestInfo(parent context.Context, ri *RequestInfo) (ctx context.Context) {
2022-11-07 10:21:24 +03:00
return context.WithValue(parent, ctxKeyRequestInfo, ri)
2022-08-26 14:18:35 +03:00
}
2022-11-07 10:21:24 +03:00
// RequestInfoFromContext gets RequestInfo from the specified context.
2024-01-04 19:22:32 +03:00
func RequestInfoFromContext(ctx context.Context) (ri *RequestInfo, found bool) {
v := ctx.Value(ctxKeyRequestInfo)
if v == nil {
return nil, false
}
ri, ok := v.(*RequestInfo)
if !ok {
panicBadType(ctxKeyRequestInfo, v)
}
2022-11-07 10:21:24 +03:00
2024-01-04 19:22:32 +03:00
return ri, true
2022-08-26 14:18:35 +03:00
}
2022-11-07 10:21:24 +03:00
// MustRequestInfoFromContext gets RequestInfo attached to the context and
// panics if it is not found.
2024-01-04 19:22:32 +03:00
func MustRequestInfoFromContext(ctx context.Context) (ri *RequestInfo) {
2022-11-07 10:21:24 +03:00
ri, found := RequestInfoFromContext(ctx)
if !found {
panic("request info not found in the context")
}
return ri
2022-08-26 14:18:35 +03:00
}
2024-01-04 19:22:32 +03:00
// panicBadType is a helper that panics with a message about the context key and
// the expected type.
func panicBadType(key ctxKey, v any) {
panic(fmt.Errorf("bad type for %s: %T(%[2]v)", key, v))
2022-08-26 14:18:35 +03:00
}