Pull request 2347: AGDNS-2690-global-context
Some checks failed
build / test (macOS-latest) (push) Has been cancelled
build / test (ubuntu-latest) (push) Has been cancelled
build / test (windows-latest) (push) Has been cancelled
lint / go-lint (push) Has been cancelled
lint / eslint (push) Has been cancelled
build / build-release (push) Has been cancelled
build / notify (push) Has been cancelled
lint / notify (push) Has been cancelled

Merge in DNS/adguard-home from AGDNS-2690-global-context to master

Squashed commit of the following:

commit 58d5999e5d9112b3391f988ed76e87eff2919d6b
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Feb 19 18:51:41 2025 +0300

    home: imp naming

commit cfb371df59c816be1022d499cc41ffaf2b72d124
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Feb 19 18:42:52 2025 +0300

    home: global context
This commit is contained in:
Stanislav Chzhen 2025-02-19 19:02:56 +03:00
parent a5b073d070
commit 1e0873aa71
16 changed files with 203 additions and 203 deletions

View File

@ -356,7 +356,7 @@ func (a *Auth) getCurrentUser(r *http.Request) (u webUser) {
// There's no Cookie, check Basic authentication. // There's no Cookie, check Basic authentication.
user, pass, ok := r.BasicAuth() user, pass, ok := r.BasicAuth()
if ok { if ok {
u, _ = Context.auth.findUser(user, pass) u, _ = globalContext.auth.findUser(user, pass)
return u return u
} }

View File

@ -155,7 +155,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
return return
} }
if rateLimiter := Context.auth.rateLimiter; rateLimiter != nil { if rateLimiter := globalContext.auth.rateLimiter; rateLimiter != nil {
if left := rateLimiter.check(remoteIP); left > 0 { if left := rateLimiter.check(remoteIP); left > 0 {
w.Header().Set(httphdr.RetryAfter, strconv.Itoa(int(left.Seconds()))) w.Header().Set(httphdr.RetryAfter, strconv.Itoa(int(left.Seconds())))
writeErrorWithIP( writeErrorWithIP(
@ -176,10 +176,10 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
log.Error("auth: getting real ip from request with remote ip %s: %s", remoteIP, err) log.Error("auth: getting real ip from request with remote ip %s: %s", remoteIP, err)
} }
cookie, err := Context.auth.newCookie(req, remoteIP) cookie, err := globalContext.auth.newCookie(req, remoteIP)
if err != nil { if err != nil {
logIP := remoteIP logIP := remoteIP
if Context.auth.trustedProxies.Contains(ip.Unmap()) { if globalContext.auth.trustedProxies.Contains(ip.Unmap()) {
logIP = ip.String() logIP = ip.String()
} }
@ -213,7 +213,7 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
return return
} }
Context.auth.removeSession(c.Value) globalContext.auth.removeSession(c.Value)
c = &http.Cookie{ c = &http.Cookie{
Name: sessionCookieName, Name: sessionCookieName,
@ -232,7 +232,7 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
// RegisterAuthHandlers - register handlers // RegisterAuthHandlers - register handlers
func RegisterAuthHandlers() { func RegisterAuthHandlers() {
Context.mux.Handle("/control/login", postInstallHandler(ensureHandler(http.MethodPost, handleLogin))) globalContext.mux.Handle("/control/login", postInstallHandler(ensureHandler(http.MethodPost, handleLogin)))
httpRegister(http.MethodGet, "/control/logout", handleLogout) httpRegister(http.MethodGet, "/control/logout", handleLogout)
} }
@ -254,13 +254,13 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (mustAuth bool) {
// Check Basic authentication. // Check Basic authentication.
user, pass, hasBasic := r.BasicAuth() user, pass, hasBasic := r.BasicAuth()
if hasBasic { if hasBasic {
_, isAuthenticated = Context.auth.findUser(user, pass) _, isAuthenticated = globalContext.auth.findUser(user, pass)
if !isAuthenticated { if !isAuthenticated {
log.Info("%s: invalid basic authorization value", pref) log.Info("%s: invalid basic authorization value", pref)
} }
} }
} else { } else {
res := Context.auth.checkSession(cookie.Value) res := globalContext.auth.checkSession(cookie.Value)
isAuthenticated = res == checkSessionOK isAuthenticated = res == checkSessionOK
if !isAuthenticated { if !isAuthenticated {
log.Debug("%s: invalid cookie value: %q", pref, cookie) log.Debug("%s: invalid cookie value: %q", pref, cookie)
@ -294,12 +294,12 @@ func optionalAuth(
) (wrapped func(http.ResponseWriter, *http.Request)) { ) (wrapped func(http.ResponseWriter, *http.Request)) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
p := r.URL.Path p := r.URL.Path
authRequired := Context.auth != nil && Context.auth.authRequired() authRequired := globalContext.auth != nil && globalContext.auth.authRequired()
if p == "/login.html" { if p == "/login.html" {
cookie, err := r.Cookie(sessionCookieName) cookie, err := r.Cookie(sessionCookieName)
if authRequired && err == nil { if authRequired && err == nil {
// Redirect to the dashboard if already authenticated. // Redirect to the dashboard if already authenticated.
res := Context.auth.checkSession(cookie.Value) res := globalContext.auth.checkSession(cookie.Value)
if res == checkSessionOK { if res == checkSessionOK {
http.Redirect(w, r, "", http.StatusFound) http.Redirect(w, r, "", http.StatusFound)

View File

@ -39,7 +39,7 @@ func TestAuthHTTP(t *testing.T) {
users := []webUser{ users := []webUser{
{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"}, {Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
} }
Context.auth = InitAuth(fn, users, 60, nil, nil) globalContext.auth = InitAuth(fn, users, 60, nil, nil)
handlerCalled := false handlerCalled := false
handler := func(_ http.ResponseWriter, _ *http.Request) { handler := func(_ http.ResponseWriter, _ *http.Request) {
@ -68,7 +68,7 @@ func TestAuthHTTP(t *testing.T) {
assert.True(t, handlerCalled) assert.True(t, handlerCalled)
// perform login // perform login
cookie, err := Context.auth.newCookie(loginJSON{Name: "name", Password: "password"}, "") cookie, err := globalContext.auth.newCookie(loginJSON{Name: "name", Password: "password"}, "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, cookie) require.NotNil(t, cookie)
@ -114,7 +114,7 @@ func TestAuthHTTP(t *testing.T) {
assert.True(t, handlerCalled) assert.True(t, handlerCalled)
r.Header.Del(httphdr.Cookie) r.Header.Del(httphdr.Cookie)
Context.auth.Close() globalContext.auth.Close()
} }
func TestRealIP(t *testing.T) { func TestRealIP(t *testing.T) {

View File

@ -486,9 +486,9 @@ var config = &configuration{
// configFilePath returns the absolute path to the symlink-evaluated path to the // configFilePath returns the absolute path to the symlink-evaluated path to the
// current config file. // current config file.
func configFilePath() (confPath string) { func configFilePath() (confPath string) {
confPath, err := filepath.EvalSymlinks(Context.confFilePath) confPath, err := filepath.EvalSymlinks(globalContext.confFilePath)
if err != nil { if err != nil {
confPath = Context.confFilePath confPath = globalContext.confFilePath
logFunc := log.Error logFunc := log.Error
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
logFunc = log.Debug logFunc = log.Debug
@ -498,7 +498,7 @@ func configFilePath() (confPath string) {
} }
if !filepath.IsAbs(confPath) { if !filepath.IsAbs(confPath) {
confPath = filepath.Join(Context.workDir, confPath) confPath = filepath.Join(globalContext.workDir, confPath)
} }
return confPath return confPath
@ -530,8 +530,8 @@ func parseConfig() (err error) {
} }
migrator := configmigrate.New(&configmigrate.Config{ migrator := configmigrate.New(&configmigrate.Config{
WorkingDir: Context.workDir, WorkingDir: globalContext.workDir,
DataDir: Context.getDataDir(), DataDir: globalContext.getDataDir(),
}) })
var upgraded bool var upgraded bool
@ -644,27 +644,27 @@ func (c *configuration) write() (err error) {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
if Context.auth != nil { if globalContext.auth != nil {
config.Users = Context.auth.usersList() config.Users = globalContext.auth.usersList()
} }
if Context.tls != nil { if globalContext.tls != nil {
tlsConf := tlsConfigSettings{} tlsConf := tlsConfigSettings{}
Context.tls.WriteDiskConfig(&tlsConf) globalContext.tls.WriteDiskConfig(&tlsConf)
config.TLS = tlsConf config.TLS = tlsConf
} }
if Context.stats != nil { if globalContext.stats != nil {
statsConf := stats.Config{} statsConf := stats.Config{}
Context.stats.WriteDiskConfig(&statsConf) globalContext.stats.WriteDiskConfig(&statsConf)
config.Stats.Interval = timeutil.Duration(statsConf.Limit) config.Stats.Interval = timeutil.Duration(statsConf.Limit)
config.Stats.Enabled = statsConf.Enabled config.Stats.Enabled = statsConf.Enabled
config.Stats.Ignored = statsConf.Ignored.Values() config.Stats.Ignored = statsConf.Ignored.Values()
} }
if Context.queryLog != nil { if globalContext.queryLog != nil {
dc := querylog.Config{} dc := querylog.Config{}
Context.queryLog.WriteDiskConfig(&dc) globalContext.queryLog.WriteDiskConfig(&dc)
config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP
config.QueryLog.Enabled = dc.Enabled config.QueryLog.Enabled = dc.Enabled
config.QueryLog.FileEnabled = dc.FileEnabled config.QueryLog.FileEnabled = dc.FileEnabled
@ -673,14 +673,14 @@ func (c *configuration) write() (err error) {
config.QueryLog.Ignored = dc.Ignored.Values() config.QueryLog.Ignored = dc.Ignored.Values()
} }
if Context.filters != nil { if globalContext.filters != nil {
Context.filters.WriteDiskConfig(config.Filtering) globalContext.filters.WriteDiskConfig(config.Filtering)
config.Filters = config.Filtering.Filters config.Filters = config.Filtering.Filters
config.WhitelistFilters = config.Filtering.WhitelistFilters config.WhitelistFilters = config.Filtering.WhitelistFilters
config.UserRules = config.Filtering.UserRules config.UserRules = config.Filtering.UserRules
} }
if s := Context.dnsServer; s != nil { if s := globalContext.dnsServer; s != nil {
c := dnsforward.Config{} c := dnsforward.Config{}
s.WriteDiskConfig(&c) s.WriteDiskConfig(&c)
dns := &config.DNS dns := &config.DNS
@ -695,11 +695,11 @@ func (c *configuration) write() (err error) {
dns.UpstreamTimeout = timeutil.Duration(s.UpstreamTimeout()) dns.UpstreamTimeout = timeutil.Duration(s.UpstreamTimeout())
} }
if Context.dhcpServer != nil { if globalContext.dhcpServer != nil {
Context.dhcpServer.WriteDiskConfig(config.DHCP) globalContext.dhcpServer.WriteDiskConfig(config.DHCP)
} }
config.Clients.Persistent = Context.clients.forConfig() config.Clients.Persistent = globalContext.clients.forConfig()
confPath := configFilePath() confPath := configFilePath()
log.Debug("writing config file %q", confPath) log.Debug("writing config file %q", confPath)
@ -726,14 +726,14 @@ func setContextTLSCipherIDs() (err error) {
if len(config.TLS.OverrideTLSCiphers) == 0 { if len(config.TLS.OverrideTLSCiphers) == 0 {
log.Info("tls: using default ciphers") log.Info("tls: using default ciphers")
Context.tlsCipherIDs = aghtls.SaferCipherSuites() globalContext.tlsCipherIDs = aghtls.SaferCipherSuites()
return nil return nil
} }
log.Info("tls: overriding ciphers: %s", config.TLS.OverrideTLSCiphers) log.Info("tls: overriding ciphers: %s", config.TLS.OverrideTLSCiphers)
Context.tlsCipherIDs, err = aghtls.ParseCiphers(config.TLS.OverrideTLSCiphers) globalContext.tlsCipherIDs, err = aghtls.ParseCiphers(config.TLS.OverrideTLSCiphers)
if err != nil { if err != nil {
return fmt.Errorf("parsing override ciphers: %w", err) return fmt.Errorf("parsing override ciphers: %w", err)
} }

View File

@ -129,10 +129,10 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
protectionDisabledUntil *time.Time protectionDisabledUntil *time.Time
protectionEnabled bool protectionEnabled bool
) )
if Context.dnsServer != nil { if globalContext.dnsServer != nil {
fltConf = &dnsforward.Config{} fltConf = &dnsforward.Config{}
Context.dnsServer.WriteDiskConfig(fltConf) globalContext.dnsServer.WriteDiskConfig(fltConf)
protectionEnabled, protectionDisabledUntil = Context.dnsServer.UpdatedProtectionStatus() protectionEnabled, protectionDisabledUntil = globalContext.dnsServer.UpdatedProtectionStatus()
} }
var resp statusResponse var resp statusResponse
@ -162,7 +162,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
// IsDHCPAvailable field is now false by default for Windows. // IsDHCPAvailable field is now false by default for Windows.
if runtime.GOOS != "windows" { if runtime.GOOS != "windows" {
resp.IsDHCPAvailable = Context.dhcpServer != nil resp.IsDHCPAvailable = globalContext.dhcpServer != nil
} }
aghhttp.WriteJSONResponseOK(w, r, resp) aghhttp.WriteJSONResponseOK(w, r, resp)
@ -172,7 +172,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
// registration of handlers // registration of handlers
// ------------------------ // ------------------------
func registerControlHandlers(web *webAPI) { func registerControlHandlers(web *webAPI) {
Context.mux.HandleFunc( globalContext.mux.HandleFunc(
"/control/version.json", "/control/version.json",
postInstall(optionalAuth(web.handleVersionJSON)), postInstall(optionalAuth(web.handleVersionJSON)),
) )
@ -185,19 +185,19 @@ func registerControlHandlers(web *webAPI) {
httpRegister(http.MethodPut, "/control/profile/update", handlePutProfile) httpRegister(http.MethodPut, "/control/profile/update", handlePutProfile)
// No auth is necessary for DoH/DoT configurations // No auth is necessary for DoH/DoT configurations
Context.mux.HandleFunc("/apple/doh.mobileconfig", postInstall(handleMobileConfigDoH)) globalContext.mux.HandleFunc("/apple/doh.mobileconfig", postInstall(handleMobileConfigDoH))
Context.mux.HandleFunc("/apple/dot.mobileconfig", postInstall(handleMobileConfigDoT)) globalContext.mux.HandleFunc("/apple/dot.mobileconfig", postInstall(handleMobileConfigDoT))
RegisterAuthHandlers() RegisterAuthHandlers()
} }
func httpRegister(method, url string, handler http.HandlerFunc) { func httpRegister(method, url string, handler http.HandlerFunc) {
if method == "" { if method == "" {
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method // "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
Context.mux.HandleFunc(url, postInstall(handler)) globalContext.mux.HandleFunc(url, postInstall(handler))
return return
} }
Context.mux.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler))))) globalContext.mux.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
} }
// ensure returns a wrapped handler that makes sure that the request has the // ensure returns a wrapped handler that makes sure that the request has the
@ -223,8 +223,8 @@ func ensure(
return return
} }
Context.controlLock.Lock() globalContext.controlLock.Lock()
defer Context.controlLock.Unlock() defer globalContext.controlLock.Unlock()
} }
handler(w, r) handler(w, r)
@ -293,7 +293,7 @@ func ensureHandler(method string, handler func(http.ResponseWriter, *http.Reques
// preInstall lets the handler run only if firstRun is true, no redirects // preInstall lets the handler run only if firstRun is true, no redirects
func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if !Context.firstRun { if !globalContext.firstRun {
// if it's not first run, don't let users access it (for example /install.html when configuration is done) // if it's not first run, don't let users access it (for example /install.html when configuration is done)
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return return
@ -320,7 +320,7 @@ func preInstallHandler(handler http.Handler) http.Handler {
// HTTPS-related headers. If proceed is true, the middleware must continue // HTTPS-related headers. If proceed is true, the middleware must continue
// handling the request. // handling the request.
func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool) { func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool) {
web := Context.web web := globalContext.web
if web.httpsServer.server == nil { if web.httpsServer.server == nil {
return true return true
} }
@ -409,7 +409,7 @@ func httpsURL(u *url.URL, host string, portHTTPS uint16) (redirectURL *url.URL)
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path path := r.URL.Path
if Context.firstRun && !strings.HasPrefix(path, "/install.") && if globalContext.firstRun && !strings.HasPrefix(path, "/install.") &&
!strings.HasPrefix(path, "/assets/") { !strings.HasPrefix(path, "/assets/") {
http.Redirect(w, r, "install.html", http.StatusFound) http.Redirect(w, r, "install.html", http.StatusFound)

View File

@ -428,20 +428,20 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
curConfig := &configuration{} curConfig := &configuration{}
copyInstallSettings(curConfig, config) copyInstallSettings(curConfig, config)
Context.firstRun = false globalContext.firstRun = false
config.DNS.BindHosts = []netip.Addr{req.DNS.IP} config.DNS.BindHosts = []netip.Addr{req.DNS.IP}
config.DNS.Port = req.DNS.Port config.DNS.Port = req.DNS.Port
config.Filtering.SafeFSPatterns = []string{ config.Filtering.SafeFSPatterns = []string{
filepath.Join(Context.workDir, userFilterDataDir, "*"), filepath.Join(globalContext.workDir, userFilterDataDir, "*"),
} }
config.HTTPConfig.Address = netip.AddrPortFrom(req.Web.IP, req.Web.Port) config.HTTPConfig.Address = netip.AddrPortFrom(req.Web.IP, req.Web.Port)
u := &webUser{ u := &webUser{
Name: req.Username, Name: req.Username,
} }
err = Context.auth.addUser(u, req.Password) err = globalContext.auth.addUser(u, req.Password)
if err != nil { if err != nil {
Context.firstRun = true globalContext.firstRun = true
copyInstallSettings(config, curConfig) copyInstallSettings(config, curConfig)
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "%s", err) aghhttp.Error(r, w, http.StatusUnprocessableEntity, "%s", err)
@ -454,7 +454,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
// functions potentially restart the HTTPS server. // functions potentially restart the HTTPS server.
err = startMods(web.baseLogger) err = startMods(web.baseLogger)
if err != nil { if err != nil {
Context.firstRun = true globalContext.firstRun = true
copyInstallSettings(config, curConfig) copyInstallSettings(config, curConfig)
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
@ -463,7 +463,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
err = config.write() err = config.write()
if err != nil { if err != nil {
Context.firstRun = true globalContext.firstRun = true
copyInstallSettings(config, curConfig) copyInstallSettings(config, curConfig)
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write config: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write config: %s", err)
@ -528,7 +528,7 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
} }
func (web *webAPI) registerInstallHandlers() { func (web *webAPI) registerInstallHandlers() {
Context.mux.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses))) globalContext.mux.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses)))
Context.mux.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig))) globalContext.mux.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
Context.mux.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure))) globalContext.mux.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))
} }

View File

@ -165,7 +165,7 @@ func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
} }
tlsConf := &tlsConfigSettings{} tlsConf := &tlsConfigSettings{}
Context.tls.WriteDiskConfig(tlsConf) globalContext.tls.WriteDiskConfig(tlsConf)
canUpdate := true canUpdate := true
if tlsConfUsesPrivilegedPorts(tlsConf) || if tlsConfUsesPrivilegedPorts(tlsConf) ||

View File

@ -45,9 +45,9 @@ func onConfigModified() {
} }
} }
// initDNS updates all the fields of the [Context] needed to initialize the DNS // initDNS updates all the fields of the [globalContext] needed to initialize the DNS
// server and initializes it at last. It also must not be called unless // server and initializes it at last. It also must not be called unless
// [config] and [Context] are initialized. baseLogger must not be nil. // [config] and [globalContext] are initialized. baseLogger must not be nil.
func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error) { func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error) {
anonymizer := config.anonymizer() anonymizer := config.anonymizer()
@ -58,7 +58,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
ConfigModified: onConfigModified, ConfigModified: onConfigModified,
HTTPRegister: httpRegister, HTTPRegister: httpRegister,
Enabled: config.Stats.Enabled, Enabled: config.Stats.Enabled,
ShouldCountClient: Context.clients.shouldCountClient, ShouldCountClient: globalContext.clients.shouldCountClient,
} }
engine, err := aghnet.NewIgnoreEngine(config.Stats.Ignored) engine, err := aghnet.NewIgnoreEngine(config.Stats.Ignored)
@ -67,7 +67,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
} }
statsConf.Ignored = engine statsConf.Ignored = engine
Context.stats, err = stats.New(statsConf) globalContext.stats, err = stats.New(statsConf)
if err != nil { if err != nil {
return fmt.Errorf("init stats: %w", err) return fmt.Errorf("init stats: %w", err)
} }
@ -77,7 +77,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
Anonymizer: anonymizer, Anonymizer: anonymizer,
ConfigModified: onConfigModified, ConfigModified: onConfigModified,
HTTPRegister: httpRegister, HTTPRegister: httpRegister,
FindClient: Context.clients.findMultiple, FindClient: globalContext.clients.findMultiple,
BaseDir: querylogDir, BaseDir: querylogDir,
AnonymizeClientIP: config.DNS.AnonymizeClientIP, AnonymizeClientIP: config.DNS.AnonymizeClientIP,
RotationIvl: time.Duration(config.QueryLog.Interval), RotationIvl: time.Duration(config.QueryLog.Interval),
@ -92,25 +92,25 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
} }
conf.Ignored = engine conf.Ignored = engine
Context.queryLog, err = querylog.New(conf) globalContext.queryLog, err = querylog.New(conf)
if err != nil { if err != nil {
return fmt.Errorf("init querylog: %w", err) return fmt.Errorf("init querylog: %w", err)
} }
Context.filters, err = filtering.New(config.Filtering, nil) globalContext.filters, err = filtering.New(config.Filtering, nil)
if err != nil { if err != nil {
// Don't wrap the error, since it's informative enough as is. // Don't wrap the error, since it's informative enough as is.
return err return err
} }
tlsConf := &tlsConfigSettings{} tlsConf := &tlsConfigSettings{}
Context.tls.WriteDiskConfig(tlsConf) globalContext.tls.WriteDiskConfig(tlsConf)
return initDNSServer( return initDNSServer(
Context.filters, globalContext.filters,
Context.stats, globalContext.stats,
Context.queryLog, globalContext.queryLog,
Context.dhcpServer, globalContext.dhcpServer,
anonymizer, anonymizer,
httpRegister, httpRegister,
tlsConf, tlsConf,
@ -121,7 +121,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
// initDNSServer initializes the [context.dnsServer]. To only use the internal // initDNSServer initializes the [context.dnsServer]. To only use the internal
// proxy, none of the arguments are required, but tlsConf and l still must not // proxy, none of the arguments are required, but tlsConf and l still must not
// be nil, in other cases all the arguments also must not be nil. It also must // be nil, in other cases all the arguments also must not be nil. It also must
// not be called unless [config] and [Context] are initialized. // not be called unless [config] and [globalContext] are initialized.
// //
// TODO(e.burkov): Use [dnsforward.DNSCreateParams] as a parameter. // TODO(e.burkov): Use [dnsforward.DNSCreateParams] as a parameter.
func initDNSServer( func initDNSServer(
@ -134,7 +134,7 @@ func initDNSServer(
tlsConf *tlsConfigSettings, tlsConf *tlsConfigSettings,
l *slog.Logger, l *slog.Logger,
) (err error) { ) (err error) {
Context.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{ globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
Logger: l, Logger: l,
DNSFilter: filters, DNSFilter: filters,
Stats: sts, Stats: sts,
@ -142,7 +142,7 @@ func initDNSServer(
PrivateNets: parseSubnetSet(config.DNS.PrivateNets), PrivateNets: parseSubnetSet(config.DNS.PrivateNets),
Anonymizer: anonymizer, Anonymizer: anonymizer,
DHCPServer: dhcpSrv, DHCPServer: dhcpSrv,
EtcHosts: Context.etcHosts, EtcHosts: globalContext.etcHosts,
LocalDomain: config.DHCP.LocalDomainName, LocalDomain: config.DHCP.LocalDomainName,
}) })
defer func() { defer func() {
@ -154,7 +154,7 @@ func initDNSServer(
return fmt.Errorf("dnsforward.NewServer: %w", err) return fmt.Errorf("dnsforward.NewServer: %w", err)
} }
Context.clients.clientChecker = Context.dnsServer globalContext.clients.clientChecker = globalContext.dnsServer
dnsConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpReg) dnsConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpReg)
if err != nil { if err != nil {
@ -163,12 +163,12 @@ func initDNSServer(
// Try to prepare the server with disabled private RDNS resolution if it // Try to prepare the server with disabled private RDNS resolution if it
// failed to prepare as is. See TODO on [dnsforward.PrivateRDNSError]. // failed to prepare as is. See TODO on [dnsforward.PrivateRDNSError].
err = Context.dnsServer.Prepare(dnsConf) err = globalContext.dnsServer.Prepare(dnsConf)
if privRDNSErr := (&dnsforward.PrivateRDNSError{}); errors.As(err, &privRDNSErr) { if privRDNSErr := (&dnsforward.PrivateRDNSError{}); errors.As(err, &privRDNSErr) {
log.Info("WARNING: %s; trying to disable private RDNS resolution", err) log.Info("WARNING: %s; trying to disable private RDNS resolution", err)
dnsConf.UsePrivateRDNS = false dnsConf.UsePrivateRDNS = false
err = Context.dnsServer.Prepare(dnsConf) err = globalContext.dnsServer.Prepare(dnsConf)
} }
if err != nil { if err != nil {
@ -194,7 +194,7 @@ func parseSubnetSet(nets []netutil.Prefix) (s netutil.SubnetSet) {
} }
func isRunning() bool { func isRunning() bool {
return Context.dnsServer != nil && Context.dnsServer.IsRunning() return globalContext.dnsServer != nil && globalContext.dnsServer.IsRunning()
} }
func ipsToTCPAddrs(ips []netip.Addr, port uint16) (tcpAddrs []*net.TCPAddr) { func ipsToTCPAddrs(ips []netip.Addr, port uint16) (tcpAddrs []*net.TCPAddr) {
@ -235,7 +235,7 @@ func newServerConfig(
fwdConf := dnsConf.Config fwdConf := dnsConf.Config
fwdConf.FilterHandler = applyAdditionalFiltering fwdConf.FilterHandler = applyAdditionalFiltering
fwdConf.ClientsContainer = &Context.clients fwdConf.ClientsContainer = &globalContext.clients
newConf = &dnsforward.ServerConfig{ newConf = &dnsforward.ServerConfig{
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port), UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
@ -244,7 +244,7 @@ func newServerConfig(
TLSConfig: newDNSTLSConfig(tlsConf, hosts), TLSConfig: newDNSTLSConfig(tlsConf, hosts),
TLSAllowUnencryptedDoH: tlsConf.AllowUnencryptedDoH, TLSAllowUnencryptedDoH: tlsConf.AllowUnencryptedDoH,
UpstreamTimeout: time.Duration(dnsConf.UpstreamTimeout), UpstreamTimeout: time.Duration(dnsConf.UpstreamTimeout),
TLSv12Roots: Context.tlsRoots, TLSv12Roots: globalContext.tlsRoots,
ConfigModified: onConfigModified, ConfigModified: onConfigModified,
HTTPRegister: httpReg, HTTPRegister: httpReg,
LocalPTRResolvers: dnsConf.PrivateRDNSResolvers, LocalPTRResolvers: dnsConf.PrivateRDNSResolvers,
@ -259,16 +259,16 @@ func newServerConfig(
var initialAddresses []netip.Addr var initialAddresses []netip.Addr
// Context.stats may be nil here if initDNSServer is called from // Context.stats may be nil here if initDNSServer is called from
// [cmdlineUpdate]. // [cmdlineUpdate].
if sts := Context.stats; sts != nil { if sts := globalContext.stats; sts != nil {
const initialClientsNum = 100 const initialClientsNum = 100
initialAddresses = Context.stats.TopClientsIP(initialClientsNum) initialAddresses = globalContext.stats.TopClientsIP(initialClientsNum)
} }
// Do not set DialContext, PrivateSubnets, and UsePrivateRDNS, because they // Do not set DialContext, PrivateSubnets, and UsePrivateRDNS, because they
// are set by [dnsforward.Server.Prepare]. // are set by [dnsforward.Server.Prepare].
newConf.AddrProcConf = &client.DefaultAddrProcConfig{ newConf.AddrProcConf = &client.DefaultAddrProcConfig{
Exchanger: Context.dnsServer, Exchanger: globalContext.dnsServer,
AddressUpdater: &Context.clients, AddressUpdater: &globalContext.clients,
InitialAddresses: initialAddresses, InitialAddresses: initialAddresses,
CatchPanics: true, CatchPanics: true,
UseRDNS: clientSrcConf.RDNS, UseRDNS: clientSrcConf.RDNS,
@ -359,7 +359,7 @@ type dnsEncryption struct {
func getDNSEncryption() (de dnsEncryption) { func getDNSEncryption() (de dnsEncryption) {
tlsConf := tlsConfigSettings{} tlsConf := tlsConfigSettings{}
Context.tls.WriteDiskConfig(&tlsConf) globalContext.tls.WriteDiskConfig(&tlsConf)
if !tlsConf.Enabled || len(tlsConf.ServerName) == 0 { if !tlsConf.Enabled || len(tlsConf.ServerName) == 0 {
return dnsEncryption{} return dnsEncryption{}
@ -402,7 +402,7 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
// pref is a prefix for logging messages around the scope. // pref is a prefix for logging messages around the scope.
const pref = "applying filters" const pref = "applying filters"
Context.filters.ApplyBlockedServices(setts) globalContext.filters.ApplyBlockedServices(setts)
log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID) log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID)
@ -412,9 +412,9 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
setts.ClientIP = clientIP setts.ClientIP = clientIP
c, ok := Context.clients.storage.Find(clientID) c, ok := globalContext.clients.storage.Find(clientID)
if !ok { if !ok {
c, ok = Context.clients.storage.Find(clientIP.String()) c, ok = globalContext.clients.storage.Find(clientIP.String())
if !ok { if !ok {
log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID) log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID)
@ -429,7 +429,7 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
setts.ServicesRules = nil setts.ServicesRules = nil
svcs := c.BlockedServices.IDs svcs := c.BlockedServices.IDs
if !c.BlockedServices.Schedule.Contains(time.Now()) { if !c.BlockedServices.Schedule.Contains(time.Now()) {
Context.filters.ApplyBlockedServicesList(setts, svcs) globalContext.filters.ApplyBlockedServicesList(setts, svcs)
log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs) log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs)
} }
} }
@ -455,24 +455,24 @@ func startDNSServer() error {
return fmt.Errorf("unable to start forwarding DNS server: Already running") return fmt.Errorf("unable to start forwarding DNS server: Already running")
} }
Context.filters.EnableFilters(false) globalContext.filters.EnableFilters(false)
// TODO(s.chzhen): Pass context. // TODO(s.chzhen): Pass context.
ctx := context.TODO() ctx := context.TODO()
err := Context.clients.Start(ctx) err := globalContext.clients.Start(ctx)
if err != nil { if err != nil {
return fmt.Errorf("starting clients container: %w", err) return fmt.Errorf("starting clients container: %w", err)
} }
err = Context.dnsServer.Start() err = globalContext.dnsServer.Start()
if err != nil { if err != nil {
return fmt.Errorf("starting dns server: %w", err) return fmt.Errorf("starting dns server: %w", err)
} }
Context.filters.Start() globalContext.filters.Start()
Context.stats.Start() globalContext.stats.Start()
err = Context.queryLog.Start(ctx) err = globalContext.queryLog.Start(ctx)
if err != nil { if err != nil {
return fmt.Errorf("starting query log: %w", err) return fmt.Errorf("starting query log: %w", err)
} }
@ -482,14 +482,14 @@ func startDNSServer() error {
func reconfigureDNSServer() (err error) { func reconfigureDNSServer() (err error) {
tlsConf := &tlsConfigSettings{} tlsConf := &tlsConfigSettings{}
Context.tls.WriteDiskConfig(tlsConf) globalContext.tls.WriteDiskConfig(tlsConf)
newConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpRegister) newConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpRegister)
if err != nil { if err != nil {
return fmt.Errorf("generating forwarding dns server config: %w", err) return fmt.Errorf("generating forwarding dns server config: %w", err)
} }
err = Context.dnsServer.Reconfigure(newConf) err = globalContext.dnsServer.Reconfigure(newConf)
if err != nil { if err != nil {
return fmt.Errorf("starting forwarding dns server: %w", err) return fmt.Errorf("starting forwarding dns server: %w", err)
} }
@ -502,12 +502,12 @@ func stopDNSServer() (err error) {
return nil return nil
} }
err = Context.dnsServer.Stop() err = globalContext.dnsServer.Stop()
if err != nil { if err != nil {
return fmt.Errorf("stopping forwarding dns server: %w", err) return fmt.Errorf("stopping forwarding dns server: %w", err)
} }
err = Context.clients.close(context.TODO()) err = globalContext.clients.close(context.TODO())
if err != nil { if err != nil {
return fmt.Errorf("closing clients container: %w", err) return fmt.Errorf("closing clients container: %w", err)
} }
@ -519,25 +519,25 @@ func stopDNSServer() (err error) {
func closeDNSServer() { func closeDNSServer() {
// DNS forward module must be closed BEFORE stats or queryLog because it depends on them // DNS forward module must be closed BEFORE stats or queryLog because it depends on them
if Context.dnsServer != nil { if globalContext.dnsServer != nil {
Context.dnsServer.Close() globalContext.dnsServer.Close()
Context.dnsServer = nil globalContext.dnsServer = nil
} }
if Context.filters != nil { if globalContext.filters != nil {
Context.filters.Close() globalContext.filters.Close()
} }
if Context.stats != nil { if globalContext.stats != nil {
err := Context.stats.Close() err := globalContext.stats.Close()
if err != nil { if err != nil {
log.Error("closing stats: %s", err) log.Error("closing stats: %s", err)
} }
} }
if Context.queryLog != nil { if globalContext.queryLog != nil {
// TODO(s.chzhen): Pass context. // TODO(s.chzhen): Pass context.
err := Context.queryLog.Shutdown(context.TODO()) err := globalContext.queryLog.Shutdown(context.TODO())
if err != nil { if err != nil {
log.Error("closing query log: %s", err) log.Error("closing query log: %s", err)
} }

View File

@ -37,14 +37,14 @@ func newStorage(tb testing.TB, clients []*client.Persistent) (s *client.Storage)
func TestApplyAdditionalFiltering(t *testing.T) { func TestApplyAdditionalFiltering(t *testing.T) {
var err error var err error
Context.filters, err = filtering.New(&filtering.Config{ globalContext.filters, err = filtering.New(&filtering.Config{
BlockedServices: &filtering.BlockedServices{ BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(), Schedule: schedule.EmptyWeekly(),
}, },
}, nil) }, nil)
require.NoError(t, err) require.NoError(t, err)
Context.clients.storage = newStorage(t, []*client.Persistent{{ globalContext.clients.storage = newStorage(t, []*client.Persistent{{
Name: "default", Name: "default",
ClientIDs: []string{"default"}, ClientIDs: []string{"default"},
UseOwnSettings: false, UseOwnSettings: false,
@ -124,7 +124,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
err error err error
) )
Context.filters, err = filtering.New(&filtering.Config{ globalContext.filters, err = filtering.New(&filtering.Config{
BlockedServices: &filtering.BlockedServices{ BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(), Schedule: schedule.EmptyWeekly(),
IDs: globalBlockedServices, IDs: globalBlockedServices,
@ -132,7 +132,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
}, nil) }, nil)
require.NoError(t, err) require.NoError(t, err)
Context.clients.storage = newStorage(t, []*client.Persistent{{ globalContext.clients.storage = newStorage(t, []*client.Persistent{{
Name: "default", Name: "default",
ClientIDs: []string{"default"}, ClientIDs: []string{"default"},
UseOwnBlockedServices: false, UseOwnBlockedServices: false,

View File

@ -91,10 +91,10 @@ func (c *homeContext) getDataDir() string {
return filepath.Join(c.workDir, dataDir) return filepath.Join(c.workDir, dataDir)
} }
// Context - a global context object // globalContext is a global context object.
// //
// TODO(a.garipov): Refactor. // TODO(a.garipov): Refactor.
var Context homeContext var globalContext homeContext
// Main is the entry point // Main is the entry point
func Main(clientBuildFS fs.FS) { func Main(clientBuildFS fs.FS) {
@ -120,8 +120,8 @@ func Main(clientBuildFS fs.FS) {
log.Info("Received signal %q", sig) log.Info("Received signal %q", sig)
switch sig { switch sig {
case syscall.SIGHUP: case syscall.SIGHUP:
Context.clients.storage.ReloadARP(ctx) globalContext.clients.storage.ReloadARP(ctx)
Context.tls.reload() globalContext.tls.reload()
default: default:
cleanup(ctx) cleanup(ctx)
cleanupAlways() cleanupAlways()
@ -140,13 +140,13 @@ func Main(clientBuildFS fs.FS) {
run(opts, clientBuildFS, done) run(opts, clientBuildFS, done)
} }
// setupContext initializes [Context] fields. It also reads and upgrades // setupContext initializes [globalContext] fields. It also reads and upgrades
// config file if necessary. // config file if necessary.
func setupContext(opts options) (err error) { func setupContext(opts options) (err error) {
Context.firstRun = detectFirstRun() globalContext.firstRun = detectFirstRun()
Context.tlsRoots = aghtls.SystemRootCAs() globalContext.tlsRoots = aghtls.SystemRootCAs()
Context.mux = http.NewServeMux() globalContext.mux = http.NewServeMux()
if !opts.noEtcHosts { if !opts.noEtcHosts {
err = setupHostsContainer() err = setupHostsContainer()
@ -156,7 +156,7 @@ func setupContext(opts options) (err error) {
} }
} }
if Context.firstRun { if globalContext.firstRun {
log.Info("This is the first time AdGuard Home is launched") log.Info("This is the first time AdGuard Home is launched")
checkNetworkPermissions() checkNetworkPermissions()
@ -247,7 +247,7 @@ func setupHostsContainer() (err error) {
return fmt.Errorf("getting default system hosts paths: %w", err) return fmt.Errorf("getting default system hosts paths: %w", err)
} }
Context.etcHosts, err = aghnet.NewHostsContainer(osutil.RootDirFS(), hostsWatcher, paths...) globalContext.etcHosts, err = aghnet.NewHostsContainer(osutil.RootDirFS(), hostsWatcher, paths...)
if err != nil { if err != nil {
closeErr := hostsWatcher.Close() closeErr := hostsWatcher.Close()
if errors.Is(err, aghnet.ErrNoHostsPaths) { if errors.Is(err, aghnet.ErrNoHostsPaths) {
@ -271,7 +271,7 @@ func setupOpts(opts options) (err error) {
} }
if len(opts.pidFile) != 0 && writePIDFile(opts.pidFile) { if len(opts.pidFile) != 0 && writePIDFile(opts.pidFile) {
Context.pidFileName = opts.pidFile globalContext.pidFileName = opts.pidFile
} }
return nil return nil
@ -286,13 +286,13 @@ func initContextClients(ctx context.Context, logger *slog.Logger) (err error) {
} }
//lint:ignore SA1019 Migration is not over. //lint:ignore SA1019 Migration is not over.
config.DHCP.WorkDir = Context.workDir config.DHCP.WorkDir = globalContext.workDir
config.DHCP.DataDir = Context.getDataDir() config.DHCP.DataDir = globalContext.getDataDir()
config.DHCP.HTTPRegister = httpRegister config.DHCP.HTTPRegister = httpRegister
config.DHCP.ConfigModified = onConfigModified config.DHCP.ConfigModified = onConfigModified
Context.dhcpServer, err = dhcpd.Create(config.DHCP) globalContext.dhcpServer, err = dhcpd.Create(config.DHCP)
if Context.dhcpServer == nil || err != nil { if globalContext.dhcpServer == nil || err != nil {
// TODO(a.garipov): There are a lot of places in the code right // TODO(a.garipov): There are a lot of places in the code right
// now which assume that the DHCP server can be nil despite this // now which assume that the DHCP server can be nil despite this
// condition. Inspect them and perhaps rewrite them to use // condition. Inspect them and perhaps rewrite them to use
@ -305,12 +305,12 @@ func initContextClients(ctx context.Context, logger *slog.Logger) (err error) {
arpDB = arpdb.New(logger.With(slogutil.KeyError, "arpdb")) arpDB = arpdb.New(logger.With(slogutil.KeyError, "arpdb"))
} }
return Context.clients.Init( return globalContext.clients.Init(
ctx, ctx,
logger, logger,
config.Clients.Persistent, config.Clients.Persistent,
Context.dhcpServer, globalContext.dhcpServer,
Context.etcHosts, globalContext.etcHosts,
arpDB, arpDB,
config.Filtering, config.Filtering,
) )
@ -374,15 +374,15 @@ func setupDNSFilteringConf(
pcTXTSuffix = `pc.dns.adguard.com.` pcTXTSuffix = `pc.dns.adguard.com.`
) )
conf.EtcHosts = Context.etcHosts conf.EtcHosts = globalContext.etcHosts
// TODO(s.chzhen): Use empty interface. // TODO(s.chzhen): Use empty interface.
if Context.etcHosts == nil || !config.DNS.HostsFileEnabled { if globalContext.etcHosts == nil || !config.DNS.HostsFileEnabled {
conf.EtcHosts = nil conf.EtcHosts = nil
} }
conf.ConfigModified = onConfigModified conf.ConfigModified = onConfigModified
conf.HTTPRegister = httpRegister conf.HTTPRegister = httpRegister
conf.DataDir = Context.getDataDir() conf.DataDir = globalContext.getDataDir()
conf.Filters = slices.Clone(config.Filters) conf.Filters = slices.Clone(config.Filters)
conf.WhitelistFilters = slices.Clone(config.WhitelistFilters) conf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
conf.UserRules = slices.Clone(config.UserRules) conf.UserRules = slices.Clone(config.UserRules)
@ -560,7 +560,7 @@ func initWeb(
ReadHeaderTimeout: readHdrTimeout, ReadHeaderTimeout: readHdrTimeout,
WriteTimeout: writeTimeout, WriteTimeout: writeTimeout,
firstRun: Context.firstRun, firstRun: globalContext.firstRun,
disableUpdate: disableUpdate, disableUpdate: disableUpdate,
runningAsService: opts.runningAsService, runningAsService: opts.runningAsService,
serveHTTP3: config.DNS.ServeHTTP3, serveHTTP3: config.DNS.ServeHTTP3,
@ -602,7 +602,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
// Print the first message after logger is configured. // Print the first message after logger is configured.
log.Info(version.Full()) log.Info(version.Full())
log.Debug("current working directory is %s", Context.workDir) log.Debug("current working directory is %s", globalContext.workDir)
if opts.runningAsService { if opts.runningAsService {
log.Info("AdGuard Home is running as a service") log.Info("AdGuard Home is running as a service")
} }
@ -632,13 +632,13 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
confPath := configFilePath() confPath := configFilePath()
upd, customURL := newUpdater(ctx, slogLogger, Context.workDir, confPath, execPath, config) upd, customURL := newUpdater(ctx, slogLogger, globalContext.workDir, confPath, execPath, config)
// TODO(e.burkov): This could be made earlier, probably as the option's // TODO(e.burkov): This could be made earlier, probably as the option's
// effect. // effect.
cmdlineUpdate(ctx, slogLogger, opts, upd) cmdlineUpdate(ctx, slogLogger, opts, upd)
if !Context.firstRun { if !globalContext.firstRun {
// Save the updated config. // Save the updated config.
err = config.write() err = config.write()
fatalOnError(err) fatalOnError(err)
@ -648,33 +648,33 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
} }
} }
dataDir := Context.getDataDir() dataDir := globalContext.getDataDir()
err = os.MkdirAll(dataDir, aghos.DefaultPermDir) err = os.MkdirAll(dataDir, aghos.DefaultPermDir)
fatalOnError(errors.Annotate(err, "creating DNS data dir at %s: %w", dataDir)) fatalOnError(errors.Annotate(err, "creating DNS data dir at %s: %w", dataDir))
GLMode = opts.glinetMode GLMode = opts.glinetMode
// Init auth module. // Init auth module.
Context.auth, err = initUsers() globalContext.auth, err = initUsers()
fatalOnError(err) fatalOnError(err)
Context.tls, err = newTLSManager(config.TLS, config.DNS.ServePlainDNS) globalContext.tls, err = newTLSManager(config.TLS, config.DNS.ServePlainDNS)
if err != nil { if err != nil {
log.Error("initializing tls: %s", err) log.Error("initializing tls: %s", err)
onConfigModified() onConfigModified()
} }
Context.web, err = initWeb(ctx, opts, clientBuildFS, upd, slogLogger, customURL) globalContext.web, err = initWeb(ctx, opts, clientBuildFS, upd, slogLogger, customURL)
fatalOnError(err) fatalOnError(err)
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&Context, config) statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&globalContext, config)
fatalOnError(err) fatalOnError(err)
if !Context.firstRun { if !globalContext.firstRun {
err = initDNS(slogLogger, statsDir, querylogDir) err = initDNS(slogLogger, statsDir, querylogDir)
fatalOnError(err) fatalOnError(err)
Context.tls.start() globalContext.tls.start()
go func() { go func() {
startErr := startDNSServer() startErr := startDNSServer()
@ -684,8 +684,8 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
} }
}() }()
if Context.dhcpServer != nil { if globalContext.dhcpServer != nil {
err = Context.dhcpServer.Start() err = globalContext.dhcpServer.Start()
if err != nil { if err != nil {
log.Error("starting dhcp server: %s", err) log.Error("starting dhcp server: %s", err)
} }
@ -693,10 +693,10 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
} }
if !opts.noPermCheck { if !opts.noPermCheck {
checkPermissions(ctx, slogLogger, Context.workDir, confPath, dataDir, statsDir, querylogDir) checkPermissions(ctx, slogLogger, globalContext.workDir, confPath, dataDir, statsDir, querylogDir)
} }
Context.web.start(ctx) globalContext.web.start(ctx)
// Wait for other goroutines to complete their job. // Wait for other goroutines to complete their job.
<-done <-done
@ -775,7 +775,7 @@ func checkPermissions(
// initUsers initializes context auth module. Clears config users field. // initUsers initializes context auth module. Clears config users field.
func initUsers() (auth *Auth, err error) { func initUsers() (auth *Auth, err error) {
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db") sessFilename := filepath.Join(globalContext.getDataDir(), "sessions.db")
var rateLimiter *authRateLimiter var rateLimiter *authRateLimiter
if config.AuthAttempts > 0 && config.AuthBlockMin > 0 { if config.AuthAttempts > 0 && config.AuthBlockMin > 0 {
@ -810,7 +810,7 @@ func (c *configuration) anonymizer() (ipmut *aghnet.IPMut) {
// startMods initializes and starts the DNS server after installation. // startMods initializes and starts the DNS server after installation.
// baseLogger must not be nil. // baseLogger must not be nil.
func startMods(baseLogger *slog.Logger) (err error) { func startMods(baseLogger *slog.Logger) (err error) {
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&Context, config) statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&globalContext, config)
if err != nil { if err != nil {
return err return err
} }
@ -820,7 +820,7 @@ func startMods(baseLogger *slog.Logger) (err error) {
return err return err
} }
Context.tls.start() globalContext.tls.start()
err = startDNSServer() err = startDNSServer()
if err != nil { if err != nil {
@ -883,14 +883,14 @@ func writePIDFile(fn string) bool {
func initConfigFilename(opts options) { func initConfigFilename(opts options) {
confPath := opts.confFilename confPath := opts.confFilename
if confPath == "" { if confPath == "" {
Context.confFilePath = filepath.Join(Context.workDir, "AdGuardHome.yaml") globalContext.confFilePath = filepath.Join(globalContext.workDir, "AdGuardHome.yaml")
return return
} }
log.Debug("config path overridden to %q from cmdline", confPath) log.Debug("config path overridden to %q from cmdline", confPath)
Context.confFilePath = confPath globalContext.confFilePath = confPath
} }
// initWorkingDir initializes the workDir. If no command-line arguments are // initWorkingDir initializes the workDir. If no command-line arguments are
@ -904,18 +904,18 @@ func initWorkingDir(opts options) (err error) {
if opts.workDir != "" { if opts.workDir != "" {
// If there is a custom config file, use it's directory as our working dir // If there is a custom config file, use it's directory as our working dir
Context.workDir = opts.workDir globalContext.workDir = opts.workDir
} else { } else {
Context.workDir = filepath.Dir(execPath) globalContext.workDir = filepath.Dir(execPath)
} }
workDir, err := filepath.EvalSymlinks(Context.workDir) workDir, err := filepath.EvalSymlinks(globalContext.workDir)
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
return err return err
} }
Context.workDir = workDir globalContext.workDir = workDir
return nil return nil
} }
@ -924,13 +924,13 @@ func initWorkingDir(opts options) (err error) {
func cleanup(ctx context.Context) { func cleanup(ctx context.Context) {
log.Info("stopping AdGuard Home") log.Info("stopping AdGuard Home")
if Context.web != nil { if globalContext.web != nil {
Context.web.close(ctx) globalContext.web.close(ctx)
Context.web = nil globalContext.web = nil
} }
if Context.auth != nil { if globalContext.auth != nil {
Context.auth.Close() globalContext.auth.Close()
Context.auth = nil globalContext.auth = nil
} }
err := stopDNSServer() err := stopDNSServer()
@ -938,28 +938,28 @@ func cleanup(ctx context.Context) {
log.Error("stopping dns server: %s", err) log.Error("stopping dns server: %s", err)
} }
if Context.dhcpServer != nil { if globalContext.dhcpServer != nil {
err = Context.dhcpServer.Stop() err = globalContext.dhcpServer.Stop()
if err != nil { if err != nil {
log.Error("stopping dhcp server: %s", err) log.Error("stopping dhcp server: %s", err)
} }
} }
if Context.etcHosts != nil { if globalContext.etcHosts != nil {
if err = Context.etcHosts.Close(); err != nil { if err = globalContext.etcHosts.Close(); err != nil {
log.Error("closing hosts container: %s", err) log.Error("closing hosts container: %s", err)
} }
} }
if Context.tls != nil { if globalContext.tls != nil {
Context.tls = nil globalContext.tls = nil
} }
} }
// This function is called before application exits // This function is called before application exits
func cleanupAlways() { func cleanupAlways() {
if len(Context.pidFileName) != 0 { if len(globalContext.pidFileName) != 0 {
_ = os.Remove(Context.pidFileName) _ = os.Remove(globalContext.pidFileName)
} }
log.Info("stopped") log.Info("stopped")
@ -1007,8 +1007,8 @@ func printWebAddrs(proto, addr string, port uint16) {
// admin interface. proto is either schemeHTTP or schemeHTTPS. // admin interface. proto is either schemeHTTP or schemeHTTPS.
func printHTTPAddresses(proto string) { func printHTTPAddresses(proto string) {
tlsConf := tlsConfigSettings{} tlsConf := tlsConfigSettings{}
if Context.tls != nil { if globalContext.tls != nil {
Context.tls.WriteDiskConfig(&tlsConf) globalContext.tls.WriteDiskConfig(&tlsConf)
} }
port := config.HTTPConfig.Address.Port() port := config.HTTPConfig.Address.Port()
@ -1050,9 +1050,9 @@ func printHTTPAddresses(proto string) {
// detectFirstRun returns true if this is the first run of AdGuard Home. // detectFirstRun returns true if this is the first run of AdGuard Home.
func detectFirstRun() (ok bool) { func detectFirstRun() (ok bool) {
confPath := Context.confFilePath confPath := globalContext.confFilePath
if !filepath.IsAbs(confPath) { if !filepath.IsAbs(confPath) {
confPath = filepath.Join(Context.workDir, Context.confFilePath) confPath = filepath.Join(globalContext.workDir, globalContext.confFilePath)
} }
_, err := os.Stat(confPath) _, err := os.Stat(confPath)
@ -1105,7 +1105,7 @@ func cmdlineUpdate(ctx context.Context, l *slog.Logger, opts options, upd *updat
os.Exit(osutil.ExitCodeSuccess) os.Exit(osutil.ExitCodeSuccess)
} }
err = upd.Update(Context.firstRun) err = upd.Update(globalContext.firstRun)
fatalOnError(err) fatalOnError(err)
err = restartService() err = restartService()

View File

@ -17,7 +17,7 @@ func httpClient() (c *http.Client) {
// Do not use Context.dnsServer.DialContext directly in the struct literal // Do not use Context.dnsServer.DialContext directly in the struct literal
// below, since Context.dnsServer may be nil when this function is called. // below, since Context.dnsServer may be nil when this function is called.
dialContext := func(ctx context.Context, network, addr string) (conn net.Conn, err error) { dialContext := func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
return Context.dnsServer.DialContext(ctx, network, addr) return globalContext.dnsServer.DialContext(ctx, network, addr)
} }
return &http.Client{ return &http.Client{
@ -27,8 +27,8 @@ func httpClient() (c *http.Client) {
DialContext: dialContext, DialContext: dialContext,
Proxy: httpProxy, Proxy: httpProxy,
TLSClientConfig: &tls.Config{ TLSClientConfig: &tls.Config{
RootCAs: Context.tlsRoots, RootCAs: globalContext.tlsRoots,
CipherSuites: Context.tlsCipherIDs, CipherSuites: globalContext.tlsCipherIDs,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
}, },
}, },

View File

@ -66,7 +66,7 @@ func configureLogger(ls *logSettings) (err error) {
logFilePath := ls.File logFilePath := ls.File
if !filepath.IsAbs(logFilePath) { if !filepath.IsAbs(logFilePath) {
logFilePath = filepath.Join(Context.workDir, logFilePath) logFilePath = filepath.Join(globalContext.workDir, logFilePath)
} }
log.SetOutput(&lumberjack.Logger{ log.SetOutput(&lumberjack.Logger{

View File

@ -19,10 +19,10 @@ func setupDNSIPs(t testing.TB) {
t.Helper() t.Helper()
prevConfig := config prevConfig := config
prevTLS := Context.tls prevTLS := globalContext.tls
t.Cleanup(func() { t.Cleanup(func() {
config = prevConfig config = prevConfig
Context.tls = prevTLS globalContext.tls = prevTLS
}) })
config = &configuration{ config = &configuration{
@ -32,7 +32,7 @@ func setupDNSIPs(t testing.TB) {
}, },
} }
Context.tls = &tlsManager{} globalContext.tls = &tlsManager{}
} }
func TestHandleMobileConfigDoH(t *testing.T) { func TestHandleMobileConfigDoH(t *testing.T) {
@ -62,10 +62,10 @@ func TestHandleMobileConfigDoH(t *testing.T) {
}) })
t.Run("error_no_host", func(t *testing.T) { t.Run("error_no_host", func(t *testing.T) {
oldTLSConf := Context.tls oldTLSConf := globalContext.tls
t.Cleanup(func() { Context.tls = oldTLSConf }) t.Cleanup(func() { globalContext.tls = oldTLSConf })
Context.tls = &tlsManager{conf: tlsConfigSettings{}} globalContext.tls = &tlsManager{conf: tlsConfigSettings{}}
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil) r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
require.NoError(t, err) require.NoError(t, err)
@ -134,10 +134,10 @@ func TestHandleMobileConfigDoT(t *testing.T) {
}) })
t.Run("error_no_host", func(t *testing.T) { t.Run("error_no_host", func(t *testing.T) {
oldTLSConf := Context.tls oldTLSConf := globalContext.tls
t.Cleanup(func() { Context.tls = oldTLSConf }) t.Cleanup(func() { globalContext.tls = oldTLSConf })
Context.tls = &tlsManager{conf: tlsConfigSettings{}} globalContext.tls = &tlsManager{conf: tlsConfigSettings{}}
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil) r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
require.NoError(t, err) require.NoError(t, err)

View File

@ -47,7 +47,7 @@ type profileJSON struct {
// handleGetProfile is the handler for GET /control/profile endpoint. // handleGetProfile is the handler for GET /control/profile endpoint.
func handleGetProfile(w http.ResponseWriter, r *http.Request) { func handleGetProfile(w http.ResponseWriter, r *http.Request) {
u := Context.auth.getCurrentUser(r) u := globalContext.auth.getCurrentUser(r)
var resp profileJSON var resp profileJSON
func() { func() {

View File

@ -112,7 +112,7 @@ func (m *tlsManager) start() {
// The background context is used because the TLSConfigChanged wraps context // The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current // with timeout on its own and shuts down the server, which handles current
// request. // request.
Context.web.tlsConfigChanged(context.Background(), tlsConf) globalContext.web.tlsConfigChanged(context.Background(), tlsConf)
} }
// reload updates the configuration and restarts t. // reload updates the configuration and restarts t.
@ -160,7 +160,7 @@ func (m *tlsManager) reload() {
// The background context is used because the TLSConfigChanged wraps context // The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current // with timeout on its own and shuts down the server, which handles current
// request. // request.
Context.web.tlsConfigChanged(context.Background(), tlsConf) globalContext.web.tlsConfigChanged(context.Background(), tlsConf)
} }
// loadTLSConf loads and validates the TLS configuration. The returned error is // loadTLSConf loads and validates the TLS configuration. The returned error is
@ -463,7 +463,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
// same reason. // same reason.
if restartHTTPS { if restartHTTPS {
go func() { go func() {
Context.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings) globalContext.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings)
}() }()
} }
} }
@ -539,7 +539,7 @@ func validateCertChain(certs []*x509.Certificate, srvName string) (err error) {
opts := x509.VerifyOptions{ opts := x509.VerifyOptions{
DNSName: srvName, DNSName: srvName,
Roots: Context.tlsRoots, Roots: globalContext.tlsRoots,
Intermediates: pool, Intermediates: pool,
} }
_, err = main.Verify(opts) _, err = main.Verify(opts)

View File

@ -129,7 +129,7 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
clientFS := http.FileServer(http.FS(conf.clientFS)) clientFS := http.FileServer(http.FS(conf.clientFS))
// if not configured, redirect / to /install.html, otherwise redirect /install.html to / // if not configured, redirect / to /install.html, otherwise redirect /install.html to /
Context.mux.Handle("/", withMiddlewares(clientFS, gziphandler.GzipHandler, optionalAuthHandler, postInstallHandler)) globalContext.mux.Handle("/", withMiddlewares(clientFS, gziphandler.GzipHandler, optionalAuthHandler, postInstallHandler))
// add handlers for /install paths, we only need them when we're not configured yet // add handlers for /install paths, we only need them when we're not configured yet
if conf.firstRun { if conf.firstRun {
@ -138,7 +138,7 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
"This is the first launch of AdGuard Home, redirecting everything to /install.html", "This is the first launch of AdGuard Home, redirecting everything to /install.html",
) )
Context.mux.Handle("/install.html", preInstallHandler(clientFS)) globalContext.mux.Handle("/install.html", preInstallHandler(clientFS))
w.registerInstallHandlers() w.registerInstallHandlers()
} else { } else {
registerControlHandlers(w) registerControlHandlers(w)
@ -154,7 +154,7 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
// //
// TODO(a.garipov): Adapt for HTTP/3. // TODO(a.garipov): Adapt for HTTP/3.
func webCheckPortAvailable(port uint16) (ok bool) { func webCheckPortAvailable(port uint16) (ok bool) {
if Context.web.httpsServer.server != nil { if globalContext.web.httpsServer.server != nil {
return true return true
} }
@ -224,7 +224,7 @@ func (web *webAPI) start(ctx context.Context) {
errs := make(chan error, 2) errs := make(chan error, 2)
// Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies. // Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies.
hdlr := h2c.NewHandler(withMiddlewares(Context.mux, limitRequestBody), &http2.Server{}) hdlr := h2c.NewHandler(withMiddlewares(globalContext.mux, limitRequestBody), &http2.Server{})
logger := web.baseLogger.With(loggerKeyServer, "plain") logger := web.baseLogger.With(loggerKeyServer, "plain")
@ -307,11 +307,11 @@ func (web *webAPI) tlsServerLoop(ctx context.Context) {
web.httpsServer.server = &http.Server{ web.httpsServer.server = &http.Server{
Addr: addr, Addr: addr,
Handler: withMiddlewares(Context.mux, limitRequestBody), Handler: withMiddlewares(globalContext.mux, limitRequestBody),
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
Certificates: []tls.Certificate{web.httpsServer.cert}, Certificates: []tls.Certificate{web.httpsServer.cert},
RootCAs: Context.tlsRoots, RootCAs: globalContext.tlsRoots,
CipherSuites: Context.tlsCipherIDs, CipherSuites: globalContext.tlsCipherIDs,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
}, },
ReadTimeout: web.conf.ReadTimeout, ReadTimeout: web.conf.ReadTimeout,
@ -344,11 +344,11 @@ func (web *webAPI) mustStartHTTP3(ctx context.Context, address string) {
Addr: address, Addr: address,
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
Certificates: []tls.Certificate{web.httpsServer.cert}, Certificates: []tls.Certificate{web.httpsServer.cert},
RootCAs: Context.tlsRoots, RootCAs: globalContext.tlsRoots,
CipherSuites: Context.tlsCipherIDs, CipherSuites: globalContext.tlsCipherIDs,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
}, },
Handler: withMiddlewares(Context.mux, limitRequestBody), Handler: withMiddlewares(globalContext.mux, limitRequestBody),
} }
web.logger.DebugContext(ctx, "starting http/3 server") web.logger.DebugContext(ctx, "starting http/3 server")