mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2025-02-20 11:44:09 +08:00
Pull request 2347: AGDNS-2690-global-context
Some checks failed
build / test (macOS-latest) (push) Has been cancelled
build / test (ubuntu-latest) (push) Has been cancelled
build / test (windows-latest) (push) Has been cancelled
lint / go-lint (push) Has been cancelled
lint / eslint (push) Has been cancelled
build / build-release (push) Has been cancelled
build / notify (push) Has been cancelled
lint / notify (push) Has been cancelled
Some checks failed
build / test (macOS-latest) (push) Has been cancelled
build / test (ubuntu-latest) (push) Has been cancelled
build / test (windows-latest) (push) Has been cancelled
lint / go-lint (push) Has been cancelled
lint / eslint (push) Has been cancelled
build / build-release (push) Has been cancelled
build / notify (push) Has been cancelled
lint / notify (push) Has been cancelled
Merge in DNS/adguard-home from AGDNS-2690-global-context to master Squashed commit of the following: commit 58d5999e5d9112b3391f988ed76e87eff2919d6b Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Feb 19 18:51:41 2025 +0300 home: imp naming commit cfb371df59c816be1022d499cc41ffaf2b72d124 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Feb 19 18:42:52 2025 +0300 home: global context
This commit is contained in:
parent
a5b073d070
commit
1e0873aa71
@ -356,7 +356,7 @@ func (a *Auth) getCurrentUser(r *http.Request) (u webUser) {
|
|||||||
// There's no Cookie, check Basic authentication.
|
// There's no Cookie, check Basic authentication.
|
||||||
user, pass, ok := r.BasicAuth()
|
user, pass, ok := r.BasicAuth()
|
||||||
if ok {
|
if ok {
|
||||||
u, _ = Context.auth.findUser(user, pass)
|
u, _ = globalContext.auth.findUser(user, pass)
|
||||||
|
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
@ -155,7 +155,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if rateLimiter := Context.auth.rateLimiter; rateLimiter != nil {
|
if rateLimiter := globalContext.auth.rateLimiter; rateLimiter != nil {
|
||||||
if left := rateLimiter.check(remoteIP); left > 0 {
|
if left := rateLimiter.check(remoteIP); left > 0 {
|
||||||
w.Header().Set(httphdr.RetryAfter, strconv.Itoa(int(left.Seconds())))
|
w.Header().Set(httphdr.RetryAfter, strconv.Itoa(int(left.Seconds())))
|
||||||
writeErrorWithIP(
|
writeErrorWithIP(
|
||||||
@ -176,10 +176,10 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
|||||||
log.Error("auth: getting real ip from request with remote ip %s: %s", remoteIP, err)
|
log.Error("auth: getting real ip from request with remote ip %s: %s", remoteIP, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cookie, err := Context.auth.newCookie(req, remoteIP)
|
cookie, err := globalContext.auth.newCookie(req, remoteIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logIP := remoteIP
|
logIP := remoteIP
|
||||||
if Context.auth.trustedProxies.Contains(ip.Unmap()) {
|
if globalContext.auth.trustedProxies.Contains(ip.Unmap()) {
|
||||||
logIP = ip.String()
|
logIP = ip.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -213,7 +213,7 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.auth.removeSession(c.Value)
|
globalContext.auth.removeSession(c.Value)
|
||||||
|
|
||||||
c = &http.Cookie{
|
c = &http.Cookie{
|
||||||
Name: sessionCookieName,
|
Name: sessionCookieName,
|
||||||
@ -232,7 +232,7 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// RegisterAuthHandlers - register handlers
|
// RegisterAuthHandlers - register handlers
|
||||||
func RegisterAuthHandlers() {
|
func RegisterAuthHandlers() {
|
||||||
Context.mux.Handle("/control/login", postInstallHandler(ensureHandler(http.MethodPost, handleLogin)))
|
globalContext.mux.Handle("/control/login", postInstallHandler(ensureHandler(http.MethodPost, handleLogin)))
|
||||||
httpRegister(http.MethodGet, "/control/logout", handleLogout)
|
httpRegister(http.MethodGet, "/control/logout", handleLogout)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -254,13 +254,13 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (mustAuth bool) {
|
|||||||
// Check Basic authentication.
|
// Check Basic authentication.
|
||||||
user, pass, hasBasic := r.BasicAuth()
|
user, pass, hasBasic := r.BasicAuth()
|
||||||
if hasBasic {
|
if hasBasic {
|
||||||
_, isAuthenticated = Context.auth.findUser(user, pass)
|
_, isAuthenticated = globalContext.auth.findUser(user, pass)
|
||||||
if !isAuthenticated {
|
if !isAuthenticated {
|
||||||
log.Info("%s: invalid basic authorization value", pref)
|
log.Info("%s: invalid basic authorization value", pref)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
res := Context.auth.checkSession(cookie.Value)
|
res := globalContext.auth.checkSession(cookie.Value)
|
||||||
isAuthenticated = res == checkSessionOK
|
isAuthenticated = res == checkSessionOK
|
||||||
if !isAuthenticated {
|
if !isAuthenticated {
|
||||||
log.Debug("%s: invalid cookie value: %q", pref, cookie)
|
log.Debug("%s: invalid cookie value: %q", pref, cookie)
|
||||||
@ -294,12 +294,12 @@ func optionalAuth(
|
|||||||
) (wrapped func(http.ResponseWriter, *http.Request)) {
|
) (wrapped func(http.ResponseWriter, *http.Request)) {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
p := r.URL.Path
|
p := r.URL.Path
|
||||||
authRequired := Context.auth != nil && Context.auth.authRequired()
|
authRequired := globalContext.auth != nil && globalContext.auth.authRequired()
|
||||||
if p == "/login.html" {
|
if p == "/login.html" {
|
||||||
cookie, err := r.Cookie(sessionCookieName)
|
cookie, err := r.Cookie(sessionCookieName)
|
||||||
if authRequired && err == nil {
|
if authRequired && err == nil {
|
||||||
// Redirect to the dashboard if already authenticated.
|
// Redirect to the dashboard if already authenticated.
|
||||||
res := Context.auth.checkSession(cookie.Value)
|
res := globalContext.auth.checkSession(cookie.Value)
|
||||||
if res == checkSessionOK {
|
if res == checkSessionOK {
|
||||||
http.Redirect(w, r, "", http.StatusFound)
|
http.Redirect(w, r, "", http.StatusFound)
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ func TestAuthHTTP(t *testing.T) {
|
|||||||
users := []webUser{
|
users := []webUser{
|
||||||
{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
|
{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
|
||||||
}
|
}
|
||||||
Context.auth = InitAuth(fn, users, 60, nil, nil)
|
globalContext.auth = InitAuth(fn, users, 60, nil, nil)
|
||||||
|
|
||||||
handlerCalled := false
|
handlerCalled := false
|
||||||
handler := func(_ http.ResponseWriter, _ *http.Request) {
|
handler := func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
@ -68,7 +68,7 @@ func TestAuthHTTP(t *testing.T) {
|
|||||||
assert.True(t, handlerCalled)
|
assert.True(t, handlerCalled)
|
||||||
|
|
||||||
// perform login
|
// perform login
|
||||||
cookie, err := Context.auth.newCookie(loginJSON{Name: "name", Password: "password"}, "")
|
cookie, err := globalContext.auth.newCookie(loginJSON{Name: "name", Password: "password"}, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, cookie)
|
require.NotNil(t, cookie)
|
||||||
|
|
||||||
@ -114,7 +114,7 @@ func TestAuthHTTP(t *testing.T) {
|
|||||||
assert.True(t, handlerCalled)
|
assert.True(t, handlerCalled)
|
||||||
r.Header.Del(httphdr.Cookie)
|
r.Header.Del(httphdr.Cookie)
|
||||||
|
|
||||||
Context.auth.Close()
|
globalContext.auth.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRealIP(t *testing.T) {
|
func TestRealIP(t *testing.T) {
|
||||||
|
@ -486,9 +486,9 @@ var config = &configuration{
|
|||||||
// configFilePath returns the absolute path to the symlink-evaluated path to the
|
// configFilePath returns the absolute path to the symlink-evaluated path to the
|
||||||
// current config file.
|
// current config file.
|
||||||
func configFilePath() (confPath string) {
|
func configFilePath() (confPath string) {
|
||||||
confPath, err := filepath.EvalSymlinks(Context.confFilePath)
|
confPath, err := filepath.EvalSymlinks(globalContext.confFilePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
confPath = Context.confFilePath
|
confPath = globalContext.confFilePath
|
||||||
logFunc := log.Error
|
logFunc := log.Error
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
logFunc = log.Debug
|
logFunc = log.Debug
|
||||||
@ -498,7 +498,7 @@ func configFilePath() (confPath string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !filepath.IsAbs(confPath) {
|
if !filepath.IsAbs(confPath) {
|
||||||
confPath = filepath.Join(Context.workDir, confPath)
|
confPath = filepath.Join(globalContext.workDir, confPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
return confPath
|
return confPath
|
||||||
@ -530,8 +530,8 @@ func parseConfig() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
migrator := configmigrate.New(&configmigrate.Config{
|
migrator := configmigrate.New(&configmigrate.Config{
|
||||||
WorkingDir: Context.workDir,
|
WorkingDir: globalContext.workDir,
|
||||||
DataDir: Context.getDataDir(),
|
DataDir: globalContext.getDataDir(),
|
||||||
})
|
})
|
||||||
|
|
||||||
var upgraded bool
|
var upgraded bool
|
||||||
@ -644,27 +644,27 @@ func (c *configuration) write() (err error) {
|
|||||||
c.Lock()
|
c.Lock()
|
||||||
defer c.Unlock()
|
defer c.Unlock()
|
||||||
|
|
||||||
if Context.auth != nil {
|
if globalContext.auth != nil {
|
||||||
config.Users = Context.auth.usersList()
|
config.Users = globalContext.auth.usersList()
|
||||||
}
|
}
|
||||||
|
|
||||||
if Context.tls != nil {
|
if globalContext.tls != nil {
|
||||||
tlsConf := tlsConfigSettings{}
|
tlsConf := tlsConfigSettings{}
|
||||||
Context.tls.WriteDiskConfig(&tlsConf)
|
globalContext.tls.WriteDiskConfig(&tlsConf)
|
||||||
config.TLS = tlsConf
|
config.TLS = tlsConf
|
||||||
}
|
}
|
||||||
|
|
||||||
if Context.stats != nil {
|
if globalContext.stats != nil {
|
||||||
statsConf := stats.Config{}
|
statsConf := stats.Config{}
|
||||||
Context.stats.WriteDiskConfig(&statsConf)
|
globalContext.stats.WriteDiskConfig(&statsConf)
|
||||||
config.Stats.Interval = timeutil.Duration(statsConf.Limit)
|
config.Stats.Interval = timeutil.Duration(statsConf.Limit)
|
||||||
config.Stats.Enabled = statsConf.Enabled
|
config.Stats.Enabled = statsConf.Enabled
|
||||||
config.Stats.Ignored = statsConf.Ignored.Values()
|
config.Stats.Ignored = statsConf.Ignored.Values()
|
||||||
}
|
}
|
||||||
|
|
||||||
if Context.queryLog != nil {
|
if globalContext.queryLog != nil {
|
||||||
dc := querylog.Config{}
|
dc := querylog.Config{}
|
||||||
Context.queryLog.WriteDiskConfig(&dc)
|
globalContext.queryLog.WriteDiskConfig(&dc)
|
||||||
config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP
|
config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP
|
||||||
config.QueryLog.Enabled = dc.Enabled
|
config.QueryLog.Enabled = dc.Enabled
|
||||||
config.QueryLog.FileEnabled = dc.FileEnabled
|
config.QueryLog.FileEnabled = dc.FileEnabled
|
||||||
@ -673,14 +673,14 @@ func (c *configuration) write() (err error) {
|
|||||||
config.QueryLog.Ignored = dc.Ignored.Values()
|
config.QueryLog.Ignored = dc.Ignored.Values()
|
||||||
}
|
}
|
||||||
|
|
||||||
if Context.filters != nil {
|
if globalContext.filters != nil {
|
||||||
Context.filters.WriteDiskConfig(config.Filtering)
|
globalContext.filters.WriteDiskConfig(config.Filtering)
|
||||||
config.Filters = config.Filtering.Filters
|
config.Filters = config.Filtering.Filters
|
||||||
config.WhitelistFilters = config.Filtering.WhitelistFilters
|
config.WhitelistFilters = config.Filtering.WhitelistFilters
|
||||||
config.UserRules = config.Filtering.UserRules
|
config.UserRules = config.Filtering.UserRules
|
||||||
}
|
}
|
||||||
|
|
||||||
if s := Context.dnsServer; s != nil {
|
if s := globalContext.dnsServer; s != nil {
|
||||||
c := dnsforward.Config{}
|
c := dnsforward.Config{}
|
||||||
s.WriteDiskConfig(&c)
|
s.WriteDiskConfig(&c)
|
||||||
dns := &config.DNS
|
dns := &config.DNS
|
||||||
@ -695,11 +695,11 @@ func (c *configuration) write() (err error) {
|
|||||||
dns.UpstreamTimeout = timeutil.Duration(s.UpstreamTimeout())
|
dns.UpstreamTimeout = timeutil.Duration(s.UpstreamTimeout())
|
||||||
}
|
}
|
||||||
|
|
||||||
if Context.dhcpServer != nil {
|
if globalContext.dhcpServer != nil {
|
||||||
Context.dhcpServer.WriteDiskConfig(config.DHCP)
|
globalContext.dhcpServer.WriteDiskConfig(config.DHCP)
|
||||||
}
|
}
|
||||||
|
|
||||||
config.Clients.Persistent = Context.clients.forConfig()
|
config.Clients.Persistent = globalContext.clients.forConfig()
|
||||||
|
|
||||||
confPath := configFilePath()
|
confPath := configFilePath()
|
||||||
log.Debug("writing config file %q", confPath)
|
log.Debug("writing config file %q", confPath)
|
||||||
@ -726,14 +726,14 @@ func setContextTLSCipherIDs() (err error) {
|
|||||||
if len(config.TLS.OverrideTLSCiphers) == 0 {
|
if len(config.TLS.OverrideTLSCiphers) == 0 {
|
||||||
log.Info("tls: using default ciphers")
|
log.Info("tls: using default ciphers")
|
||||||
|
|
||||||
Context.tlsCipherIDs = aghtls.SaferCipherSuites()
|
globalContext.tlsCipherIDs = aghtls.SaferCipherSuites()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("tls: overriding ciphers: %s", config.TLS.OverrideTLSCiphers)
|
log.Info("tls: overriding ciphers: %s", config.TLS.OverrideTLSCiphers)
|
||||||
|
|
||||||
Context.tlsCipherIDs, err = aghtls.ParseCiphers(config.TLS.OverrideTLSCiphers)
|
globalContext.tlsCipherIDs, err = aghtls.ParseCiphers(config.TLS.OverrideTLSCiphers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing override ciphers: %w", err)
|
return fmt.Errorf("parsing override ciphers: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -129,10 +129,10 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
|||||||
protectionDisabledUntil *time.Time
|
protectionDisabledUntil *time.Time
|
||||||
protectionEnabled bool
|
protectionEnabled bool
|
||||||
)
|
)
|
||||||
if Context.dnsServer != nil {
|
if globalContext.dnsServer != nil {
|
||||||
fltConf = &dnsforward.Config{}
|
fltConf = &dnsforward.Config{}
|
||||||
Context.dnsServer.WriteDiskConfig(fltConf)
|
globalContext.dnsServer.WriteDiskConfig(fltConf)
|
||||||
protectionEnabled, protectionDisabledUntil = Context.dnsServer.UpdatedProtectionStatus()
|
protectionEnabled, protectionDisabledUntil = globalContext.dnsServer.UpdatedProtectionStatus()
|
||||||
}
|
}
|
||||||
|
|
||||||
var resp statusResponse
|
var resp statusResponse
|
||||||
@ -162,7 +162,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// IsDHCPAvailable field is now false by default for Windows.
|
// IsDHCPAvailable field is now false by default for Windows.
|
||||||
if runtime.GOOS != "windows" {
|
if runtime.GOOS != "windows" {
|
||||||
resp.IsDHCPAvailable = Context.dhcpServer != nil
|
resp.IsDHCPAvailable = globalContext.dhcpServer != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||||
@ -172,7 +172,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
|||||||
// registration of handlers
|
// registration of handlers
|
||||||
// ------------------------
|
// ------------------------
|
||||||
func registerControlHandlers(web *webAPI) {
|
func registerControlHandlers(web *webAPI) {
|
||||||
Context.mux.HandleFunc(
|
globalContext.mux.HandleFunc(
|
||||||
"/control/version.json",
|
"/control/version.json",
|
||||||
postInstall(optionalAuth(web.handleVersionJSON)),
|
postInstall(optionalAuth(web.handleVersionJSON)),
|
||||||
)
|
)
|
||||||
@ -185,19 +185,19 @@ func registerControlHandlers(web *webAPI) {
|
|||||||
httpRegister(http.MethodPut, "/control/profile/update", handlePutProfile)
|
httpRegister(http.MethodPut, "/control/profile/update", handlePutProfile)
|
||||||
|
|
||||||
// No auth is necessary for DoH/DoT configurations
|
// No auth is necessary for DoH/DoT configurations
|
||||||
Context.mux.HandleFunc("/apple/doh.mobileconfig", postInstall(handleMobileConfigDoH))
|
globalContext.mux.HandleFunc("/apple/doh.mobileconfig", postInstall(handleMobileConfigDoH))
|
||||||
Context.mux.HandleFunc("/apple/dot.mobileconfig", postInstall(handleMobileConfigDoT))
|
globalContext.mux.HandleFunc("/apple/dot.mobileconfig", postInstall(handleMobileConfigDoT))
|
||||||
RegisterAuthHandlers()
|
RegisterAuthHandlers()
|
||||||
}
|
}
|
||||||
|
|
||||||
func httpRegister(method, url string, handler http.HandlerFunc) {
|
func httpRegister(method, url string, handler http.HandlerFunc) {
|
||||||
if method == "" {
|
if method == "" {
|
||||||
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
|
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
|
||||||
Context.mux.HandleFunc(url, postInstall(handler))
|
globalContext.mux.HandleFunc(url, postInstall(handler))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.mux.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
|
globalContext.mux.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensure returns a wrapped handler that makes sure that the request has the
|
// ensure returns a wrapped handler that makes sure that the request has the
|
||||||
@ -223,8 +223,8 @@ func ensure(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.controlLock.Lock()
|
globalContext.controlLock.Lock()
|
||||||
defer Context.controlLock.Unlock()
|
defer globalContext.controlLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
handler(w, r)
|
handler(w, r)
|
||||||
@ -293,7 +293,7 @@ func ensureHandler(method string, handler func(http.ResponseWriter, *http.Reques
|
|||||||
// preInstall lets the handler run only if firstRun is true, no redirects
|
// preInstall lets the handler run only if firstRun is true, no redirects
|
||||||
func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
if !Context.firstRun {
|
if !globalContext.firstRun {
|
||||||
// if it's not first run, don't let users access it (for example /install.html when configuration is done)
|
// if it's not first run, don't let users access it (for example /install.html when configuration is done)
|
||||||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||||
return
|
return
|
||||||
@ -320,7 +320,7 @@ func preInstallHandler(handler http.Handler) http.Handler {
|
|||||||
// HTTPS-related headers. If proceed is true, the middleware must continue
|
// HTTPS-related headers. If proceed is true, the middleware must continue
|
||||||
// handling the request.
|
// handling the request.
|
||||||
func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
web := Context.web
|
web := globalContext.web
|
||||||
if web.httpsServer.server == nil {
|
if web.httpsServer.server == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -409,7 +409,7 @@ func httpsURL(u *url.URL, host string, portHTTPS uint16) (redirectURL *url.URL)
|
|||||||
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
path := r.URL.Path
|
path := r.URL.Path
|
||||||
if Context.firstRun && !strings.HasPrefix(path, "/install.") &&
|
if globalContext.firstRun && !strings.HasPrefix(path, "/install.") &&
|
||||||
!strings.HasPrefix(path, "/assets/") {
|
!strings.HasPrefix(path, "/assets/") {
|
||||||
http.Redirect(w, r, "install.html", http.StatusFound)
|
http.Redirect(w, r, "install.html", http.StatusFound)
|
||||||
|
|
||||||
|
@ -428,20 +428,20 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
|||||||
curConfig := &configuration{}
|
curConfig := &configuration{}
|
||||||
copyInstallSettings(curConfig, config)
|
copyInstallSettings(curConfig, config)
|
||||||
|
|
||||||
Context.firstRun = false
|
globalContext.firstRun = false
|
||||||
config.DNS.BindHosts = []netip.Addr{req.DNS.IP}
|
config.DNS.BindHosts = []netip.Addr{req.DNS.IP}
|
||||||
config.DNS.Port = req.DNS.Port
|
config.DNS.Port = req.DNS.Port
|
||||||
config.Filtering.SafeFSPatterns = []string{
|
config.Filtering.SafeFSPatterns = []string{
|
||||||
filepath.Join(Context.workDir, userFilterDataDir, "*"),
|
filepath.Join(globalContext.workDir, userFilterDataDir, "*"),
|
||||||
}
|
}
|
||||||
config.HTTPConfig.Address = netip.AddrPortFrom(req.Web.IP, req.Web.Port)
|
config.HTTPConfig.Address = netip.AddrPortFrom(req.Web.IP, req.Web.Port)
|
||||||
|
|
||||||
u := &webUser{
|
u := &webUser{
|
||||||
Name: req.Username,
|
Name: req.Username,
|
||||||
}
|
}
|
||||||
err = Context.auth.addUser(u, req.Password)
|
err = globalContext.auth.addUser(u, req.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Context.firstRun = true
|
globalContext.firstRun = true
|
||||||
copyInstallSettings(config, curConfig)
|
copyInstallSettings(config, curConfig)
|
||||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "%s", err)
|
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "%s", err)
|
||||||
|
|
||||||
@ -454,7 +454,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
|||||||
// functions potentially restart the HTTPS server.
|
// functions potentially restart the HTTPS server.
|
||||||
err = startMods(web.baseLogger)
|
err = startMods(web.baseLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Context.firstRun = true
|
globalContext.firstRun = true
|
||||||
copyInstallSettings(config, curConfig)
|
copyInstallSettings(config, curConfig)
|
||||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||||
|
|
||||||
@ -463,7 +463,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
|||||||
|
|
||||||
err = config.write()
|
err = config.write()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Context.firstRun = true
|
globalContext.firstRun = true
|
||||||
copyInstallSettings(config, curConfig)
|
copyInstallSettings(config, curConfig)
|
||||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write config: %s", err)
|
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write config: %s", err)
|
||||||
|
|
||||||
@ -528,7 +528,7 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (web *webAPI) registerInstallHandlers() {
|
func (web *webAPI) registerInstallHandlers() {
|
||||||
Context.mux.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses)))
|
globalContext.mux.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses)))
|
||||||
Context.mux.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
|
globalContext.mux.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
|
||||||
Context.mux.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))
|
globalContext.mux.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))
|
||||||
}
|
}
|
||||||
|
@ -165,7 +165,7 @@ func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tlsConf := &tlsConfigSettings{}
|
tlsConf := &tlsConfigSettings{}
|
||||||
Context.tls.WriteDiskConfig(tlsConf)
|
globalContext.tls.WriteDiskConfig(tlsConf)
|
||||||
|
|
||||||
canUpdate := true
|
canUpdate := true
|
||||||
if tlsConfUsesPrivilegedPorts(tlsConf) ||
|
if tlsConfUsesPrivilegedPorts(tlsConf) ||
|
||||||
|
@ -45,9 +45,9 @@ func onConfigModified() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// initDNS updates all the fields of the [Context] needed to initialize the DNS
|
// initDNS updates all the fields of the [globalContext] needed to initialize the DNS
|
||||||
// server and initializes it at last. It also must not be called unless
|
// server and initializes it at last. It also must not be called unless
|
||||||
// [config] and [Context] are initialized. baseLogger must not be nil.
|
// [config] and [globalContext] are initialized. baseLogger must not be nil.
|
||||||
func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error) {
|
func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error) {
|
||||||
anonymizer := config.anonymizer()
|
anonymizer := config.anonymizer()
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
|
|||||||
ConfigModified: onConfigModified,
|
ConfigModified: onConfigModified,
|
||||||
HTTPRegister: httpRegister,
|
HTTPRegister: httpRegister,
|
||||||
Enabled: config.Stats.Enabled,
|
Enabled: config.Stats.Enabled,
|
||||||
ShouldCountClient: Context.clients.shouldCountClient,
|
ShouldCountClient: globalContext.clients.shouldCountClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
engine, err := aghnet.NewIgnoreEngine(config.Stats.Ignored)
|
engine, err := aghnet.NewIgnoreEngine(config.Stats.Ignored)
|
||||||
@ -67,7 +67,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
statsConf.Ignored = engine
|
statsConf.Ignored = engine
|
||||||
Context.stats, err = stats.New(statsConf)
|
globalContext.stats, err = stats.New(statsConf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("init stats: %w", err)
|
return fmt.Errorf("init stats: %w", err)
|
||||||
}
|
}
|
||||||
@ -77,7 +77,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
|
|||||||
Anonymizer: anonymizer,
|
Anonymizer: anonymizer,
|
||||||
ConfigModified: onConfigModified,
|
ConfigModified: onConfigModified,
|
||||||
HTTPRegister: httpRegister,
|
HTTPRegister: httpRegister,
|
||||||
FindClient: Context.clients.findMultiple,
|
FindClient: globalContext.clients.findMultiple,
|
||||||
BaseDir: querylogDir,
|
BaseDir: querylogDir,
|
||||||
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
|
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
|
||||||
RotationIvl: time.Duration(config.QueryLog.Interval),
|
RotationIvl: time.Duration(config.QueryLog.Interval),
|
||||||
@ -92,25 +92,25 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
conf.Ignored = engine
|
conf.Ignored = engine
|
||||||
Context.queryLog, err = querylog.New(conf)
|
globalContext.queryLog, err = querylog.New(conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("init querylog: %w", err)
|
return fmt.Errorf("init querylog: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.filters, err = filtering.New(config.Filtering, nil)
|
globalContext.filters, err = filtering.New(config.Filtering, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Don't wrap the error, since it's informative enough as is.
|
// Don't wrap the error, since it's informative enough as is.
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsConf := &tlsConfigSettings{}
|
tlsConf := &tlsConfigSettings{}
|
||||||
Context.tls.WriteDiskConfig(tlsConf)
|
globalContext.tls.WriteDiskConfig(tlsConf)
|
||||||
|
|
||||||
return initDNSServer(
|
return initDNSServer(
|
||||||
Context.filters,
|
globalContext.filters,
|
||||||
Context.stats,
|
globalContext.stats,
|
||||||
Context.queryLog,
|
globalContext.queryLog,
|
||||||
Context.dhcpServer,
|
globalContext.dhcpServer,
|
||||||
anonymizer,
|
anonymizer,
|
||||||
httpRegister,
|
httpRegister,
|
||||||
tlsConf,
|
tlsConf,
|
||||||
@ -121,7 +121,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
|
|||||||
// initDNSServer initializes the [context.dnsServer]. To only use the internal
|
// initDNSServer initializes the [context.dnsServer]. To only use the internal
|
||||||
// proxy, none of the arguments are required, but tlsConf and l still must not
|
// proxy, none of the arguments are required, but tlsConf and l still must not
|
||||||
// be nil, in other cases all the arguments also must not be nil. It also must
|
// be nil, in other cases all the arguments also must not be nil. It also must
|
||||||
// not be called unless [config] and [Context] are initialized.
|
// not be called unless [config] and [globalContext] are initialized.
|
||||||
//
|
//
|
||||||
// TODO(e.burkov): Use [dnsforward.DNSCreateParams] as a parameter.
|
// TODO(e.burkov): Use [dnsforward.DNSCreateParams] as a parameter.
|
||||||
func initDNSServer(
|
func initDNSServer(
|
||||||
@ -134,7 +134,7 @@ func initDNSServer(
|
|||||||
tlsConf *tlsConfigSettings,
|
tlsConf *tlsConfigSettings,
|
||||||
l *slog.Logger,
|
l *slog.Logger,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
Context.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
|
globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
|
||||||
Logger: l,
|
Logger: l,
|
||||||
DNSFilter: filters,
|
DNSFilter: filters,
|
||||||
Stats: sts,
|
Stats: sts,
|
||||||
@ -142,7 +142,7 @@ func initDNSServer(
|
|||||||
PrivateNets: parseSubnetSet(config.DNS.PrivateNets),
|
PrivateNets: parseSubnetSet(config.DNS.PrivateNets),
|
||||||
Anonymizer: anonymizer,
|
Anonymizer: anonymizer,
|
||||||
DHCPServer: dhcpSrv,
|
DHCPServer: dhcpSrv,
|
||||||
EtcHosts: Context.etcHosts,
|
EtcHosts: globalContext.etcHosts,
|
||||||
LocalDomain: config.DHCP.LocalDomainName,
|
LocalDomain: config.DHCP.LocalDomainName,
|
||||||
})
|
})
|
||||||
defer func() {
|
defer func() {
|
||||||
@ -154,7 +154,7 @@ func initDNSServer(
|
|||||||
return fmt.Errorf("dnsforward.NewServer: %w", err)
|
return fmt.Errorf("dnsforward.NewServer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.clients.clientChecker = Context.dnsServer
|
globalContext.clients.clientChecker = globalContext.dnsServer
|
||||||
|
|
||||||
dnsConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpReg)
|
dnsConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpReg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -163,12 +163,12 @@ func initDNSServer(
|
|||||||
|
|
||||||
// Try to prepare the server with disabled private RDNS resolution if it
|
// Try to prepare the server with disabled private RDNS resolution if it
|
||||||
// failed to prepare as is. See TODO on [dnsforward.PrivateRDNSError].
|
// failed to prepare as is. See TODO on [dnsforward.PrivateRDNSError].
|
||||||
err = Context.dnsServer.Prepare(dnsConf)
|
err = globalContext.dnsServer.Prepare(dnsConf)
|
||||||
if privRDNSErr := (&dnsforward.PrivateRDNSError{}); errors.As(err, &privRDNSErr) {
|
if privRDNSErr := (&dnsforward.PrivateRDNSError{}); errors.As(err, &privRDNSErr) {
|
||||||
log.Info("WARNING: %s; trying to disable private RDNS resolution", err)
|
log.Info("WARNING: %s; trying to disable private RDNS resolution", err)
|
||||||
|
|
||||||
dnsConf.UsePrivateRDNS = false
|
dnsConf.UsePrivateRDNS = false
|
||||||
err = Context.dnsServer.Prepare(dnsConf)
|
err = globalContext.dnsServer.Prepare(dnsConf)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -194,7 +194,7 @@ func parseSubnetSet(nets []netutil.Prefix) (s netutil.SubnetSet) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func isRunning() bool {
|
func isRunning() bool {
|
||||||
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
|
return globalContext.dnsServer != nil && globalContext.dnsServer.IsRunning()
|
||||||
}
|
}
|
||||||
|
|
||||||
func ipsToTCPAddrs(ips []netip.Addr, port uint16) (tcpAddrs []*net.TCPAddr) {
|
func ipsToTCPAddrs(ips []netip.Addr, port uint16) (tcpAddrs []*net.TCPAddr) {
|
||||||
@ -235,7 +235,7 @@ func newServerConfig(
|
|||||||
|
|
||||||
fwdConf := dnsConf.Config
|
fwdConf := dnsConf.Config
|
||||||
fwdConf.FilterHandler = applyAdditionalFiltering
|
fwdConf.FilterHandler = applyAdditionalFiltering
|
||||||
fwdConf.ClientsContainer = &Context.clients
|
fwdConf.ClientsContainer = &globalContext.clients
|
||||||
|
|
||||||
newConf = &dnsforward.ServerConfig{
|
newConf = &dnsforward.ServerConfig{
|
||||||
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
|
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
|
||||||
@ -244,7 +244,7 @@ func newServerConfig(
|
|||||||
TLSConfig: newDNSTLSConfig(tlsConf, hosts),
|
TLSConfig: newDNSTLSConfig(tlsConf, hosts),
|
||||||
TLSAllowUnencryptedDoH: tlsConf.AllowUnencryptedDoH,
|
TLSAllowUnencryptedDoH: tlsConf.AllowUnencryptedDoH,
|
||||||
UpstreamTimeout: time.Duration(dnsConf.UpstreamTimeout),
|
UpstreamTimeout: time.Duration(dnsConf.UpstreamTimeout),
|
||||||
TLSv12Roots: Context.tlsRoots,
|
TLSv12Roots: globalContext.tlsRoots,
|
||||||
ConfigModified: onConfigModified,
|
ConfigModified: onConfigModified,
|
||||||
HTTPRegister: httpReg,
|
HTTPRegister: httpReg,
|
||||||
LocalPTRResolvers: dnsConf.PrivateRDNSResolvers,
|
LocalPTRResolvers: dnsConf.PrivateRDNSResolvers,
|
||||||
@ -259,16 +259,16 @@ func newServerConfig(
|
|||||||
var initialAddresses []netip.Addr
|
var initialAddresses []netip.Addr
|
||||||
// Context.stats may be nil here if initDNSServer is called from
|
// Context.stats may be nil here if initDNSServer is called from
|
||||||
// [cmdlineUpdate].
|
// [cmdlineUpdate].
|
||||||
if sts := Context.stats; sts != nil {
|
if sts := globalContext.stats; sts != nil {
|
||||||
const initialClientsNum = 100
|
const initialClientsNum = 100
|
||||||
initialAddresses = Context.stats.TopClientsIP(initialClientsNum)
|
initialAddresses = globalContext.stats.TopClientsIP(initialClientsNum)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do not set DialContext, PrivateSubnets, and UsePrivateRDNS, because they
|
// Do not set DialContext, PrivateSubnets, and UsePrivateRDNS, because they
|
||||||
// are set by [dnsforward.Server.Prepare].
|
// are set by [dnsforward.Server.Prepare].
|
||||||
newConf.AddrProcConf = &client.DefaultAddrProcConfig{
|
newConf.AddrProcConf = &client.DefaultAddrProcConfig{
|
||||||
Exchanger: Context.dnsServer,
|
Exchanger: globalContext.dnsServer,
|
||||||
AddressUpdater: &Context.clients,
|
AddressUpdater: &globalContext.clients,
|
||||||
InitialAddresses: initialAddresses,
|
InitialAddresses: initialAddresses,
|
||||||
CatchPanics: true,
|
CatchPanics: true,
|
||||||
UseRDNS: clientSrcConf.RDNS,
|
UseRDNS: clientSrcConf.RDNS,
|
||||||
@ -359,7 +359,7 @@ type dnsEncryption struct {
|
|||||||
func getDNSEncryption() (de dnsEncryption) {
|
func getDNSEncryption() (de dnsEncryption) {
|
||||||
tlsConf := tlsConfigSettings{}
|
tlsConf := tlsConfigSettings{}
|
||||||
|
|
||||||
Context.tls.WriteDiskConfig(&tlsConf)
|
globalContext.tls.WriteDiskConfig(&tlsConf)
|
||||||
|
|
||||||
if !tlsConf.Enabled || len(tlsConf.ServerName) == 0 {
|
if !tlsConf.Enabled || len(tlsConf.ServerName) == 0 {
|
||||||
return dnsEncryption{}
|
return dnsEncryption{}
|
||||||
@ -402,7 +402,7 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
|
|||||||
// pref is a prefix for logging messages around the scope.
|
// pref is a prefix for logging messages around the scope.
|
||||||
const pref = "applying filters"
|
const pref = "applying filters"
|
||||||
|
|
||||||
Context.filters.ApplyBlockedServices(setts)
|
globalContext.filters.ApplyBlockedServices(setts)
|
||||||
|
|
||||||
log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID)
|
log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID)
|
||||||
|
|
||||||
@ -412,9 +412,9 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
|
|||||||
|
|
||||||
setts.ClientIP = clientIP
|
setts.ClientIP = clientIP
|
||||||
|
|
||||||
c, ok := Context.clients.storage.Find(clientID)
|
c, ok := globalContext.clients.storage.Find(clientID)
|
||||||
if !ok {
|
if !ok {
|
||||||
c, ok = Context.clients.storage.Find(clientIP.String())
|
c, ok = globalContext.clients.storage.Find(clientIP.String())
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID)
|
log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID)
|
||||||
|
|
||||||
@ -429,7 +429,7 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
|
|||||||
setts.ServicesRules = nil
|
setts.ServicesRules = nil
|
||||||
svcs := c.BlockedServices.IDs
|
svcs := c.BlockedServices.IDs
|
||||||
if !c.BlockedServices.Schedule.Contains(time.Now()) {
|
if !c.BlockedServices.Schedule.Contains(time.Now()) {
|
||||||
Context.filters.ApplyBlockedServicesList(setts, svcs)
|
globalContext.filters.ApplyBlockedServicesList(setts, svcs)
|
||||||
log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs)
|
log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -455,24 +455,24 @@ func startDNSServer() error {
|
|||||||
return fmt.Errorf("unable to start forwarding DNS server: Already running")
|
return fmt.Errorf("unable to start forwarding DNS server: Already running")
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.filters.EnableFilters(false)
|
globalContext.filters.EnableFilters(false)
|
||||||
|
|
||||||
// TODO(s.chzhen): Pass context.
|
// TODO(s.chzhen): Pass context.
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
err := Context.clients.Start(ctx)
|
err := globalContext.clients.Start(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("starting clients container: %w", err)
|
return fmt.Errorf("starting clients container: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = Context.dnsServer.Start()
|
err = globalContext.dnsServer.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("starting dns server: %w", err)
|
return fmt.Errorf("starting dns server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.filters.Start()
|
globalContext.filters.Start()
|
||||||
Context.stats.Start()
|
globalContext.stats.Start()
|
||||||
|
|
||||||
err = Context.queryLog.Start(ctx)
|
err = globalContext.queryLog.Start(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("starting query log: %w", err)
|
return fmt.Errorf("starting query log: %w", err)
|
||||||
}
|
}
|
||||||
@ -482,14 +482,14 @@ func startDNSServer() error {
|
|||||||
|
|
||||||
func reconfigureDNSServer() (err error) {
|
func reconfigureDNSServer() (err error) {
|
||||||
tlsConf := &tlsConfigSettings{}
|
tlsConf := &tlsConfigSettings{}
|
||||||
Context.tls.WriteDiskConfig(tlsConf)
|
globalContext.tls.WriteDiskConfig(tlsConf)
|
||||||
|
|
||||||
newConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpRegister)
|
newConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpRegister)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("generating forwarding dns server config: %w", err)
|
return fmt.Errorf("generating forwarding dns server config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = Context.dnsServer.Reconfigure(newConf)
|
err = globalContext.dnsServer.Reconfigure(newConf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("starting forwarding dns server: %w", err)
|
return fmt.Errorf("starting forwarding dns server: %w", err)
|
||||||
}
|
}
|
||||||
@ -502,12 +502,12 @@ func stopDNSServer() (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = Context.dnsServer.Stop()
|
err = globalContext.dnsServer.Stop()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("stopping forwarding dns server: %w", err)
|
return fmt.Errorf("stopping forwarding dns server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = Context.clients.close(context.TODO())
|
err = globalContext.clients.close(context.TODO())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("closing clients container: %w", err)
|
return fmt.Errorf("closing clients container: %w", err)
|
||||||
}
|
}
|
||||||
@ -519,25 +519,25 @@ func stopDNSServer() (err error) {
|
|||||||
|
|
||||||
func closeDNSServer() {
|
func closeDNSServer() {
|
||||||
// DNS forward module must be closed BEFORE stats or queryLog because it depends on them
|
// DNS forward module must be closed BEFORE stats or queryLog because it depends on them
|
||||||
if Context.dnsServer != nil {
|
if globalContext.dnsServer != nil {
|
||||||
Context.dnsServer.Close()
|
globalContext.dnsServer.Close()
|
||||||
Context.dnsServer = nil
|
globalContext.dnsServer = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if Context.filters != nil {
|
if globalContext.filters != nil {
|
||||||
Context.filters.Close()
|
globalContext.filters.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
if Context.stats != nil {
|
if globalContext.stats != nil {
|
||||||
err := Context.stats.Close()
|
err := globalContext.stats.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("closing stats: %s", err)
|
log.Error("closing stats: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if Context.queryLog != nil {
|
if globalContext.queryLog != nil {
|
||||||
// TODO(s.chzhen): Pass context.
|
// TODO(s.chzhen): Pass context.
|
||||||
err := Context.queryLog.Shutdown(context.TODO())
|
err := globalContext.queryLog.Shutdown(context.TODO())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("closing query log: %s", err)
|
log.Error("closing query log: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -37,14 +37,14 @@ func newStorage(tb testing.TB, clients []*client.Persistent) (s *client.Storage)
|
|||||||
func TestApplyAdditionalFiltering(t *testing.T) {
|
func TestApplyAdditionalFiltering(t *testing.T) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
Context.filters, err = filtering.New(&filtering.Config{
|
globalContext.filters, err = filtering.New(&filtering.Config{
|
||||||
BlockedServices: &filtering.BlockedServices{
|
BlockedServices: &filtering.BlockedServices{
|
||||||
Schedule: schedule.EmptyWeekly(),
|
Schedule: schedule.EmptyWeekly(),
|
||||||
},
|
},
|
||||||
}, nil)
|
}, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
Context.clients.storage = newStorage(t, []*client.Persistent{{
|
globalContext.clients.storage = newStorage(t, []*client.Persistent{{
|
||||||
Name: "default",
|
Name: "default",
|
||||||
ClientIDs: []string{"default"},
|
ClientIDs: []string{"default"},
|
||||||
UseOwnSettings: false,
|
UseOwnSettings: false,
|
||||||
@ -124,7 +124,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
|
|||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
Context.filters, err = filtering.New(&filtering.Config{
|
globalContext.filters, err = filtering.New(&filtering.Config{
|
||||||
BlockedServices: &filtering.BlockedServices{
|
BlockedServices: &filtering.BlockedServices{
|
||||||
Schedule: schedule.EmptyWeekly(),
|
Schedule: schedule.EmptyWeekly(),
|
||||||
IDs: globalBlockedServices,
|
IDs: globalBlockedServices,
|
||||||
@ -132,7 +132,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
|
|||||||
}, nil)
|
}, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
Context.clients.storage = newStorage(t, []*client.Persistent{{
|
globalContext.clients.storage = newStorage(t, []*client.Persistent{{
|
||||||
Name: "default",
|
Name: "default",
|
||||||
ClientIDs: []string{"default"},
|
ClientIDs: []string{"default"},
|
||||||
UseOwnBlockedServices: false,
|
UseOwnBlockedServices: false,
|
||||||
|
@ -91,10 +91,10 @@ func (c *homeContext) getDataDir() string {
|
|||||||
return filepath.Join(c.workDir, dataDir)
|
return filepath.Join(c.workDir, dataDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Context - a global context object
|
// globalContext is a global context object.
|
||||||
//
|
//
|
||||||
// TODO(a.garipov): Refactor.
|
// TODO(a.garipov): Refactor.
|
||||||
var Context homeContext
|
var globalContext homeContext
|
||||||
|
|
||||||
// Main is the entry point
|
// Main is the entry point
|
||||||
func Main(clientBuildFS fs.FS) {
|
func Main(clientBuildFS fs.FS) {
|
||||||
@ -120,8 +120,8 @@ func Main(clientBuildFS fs.FS) {
|
|||||||
log.Info("Received signal %q", sig)
|
log.Info("Received signal %q", sig)
|
||||||
switch sig {
|
switch sig {
|
||||||
case syscall.SIGHUP:
|
case syscall.SIGHUP:
|
||||||
Context.clients.storage.ReloadARP(ctx)
|
globalContext.clients.storage.ReloadARP(ctx)
|
||||||
Context.tls.reload()
|
globalContext.tls.reload()
|
||||||
default:
|
default:
|
||||||
cleanup(ctx)
|
cleanup(ctx)
|
||||||
cleanupAlways()
|
cleanupAlways()
|
||||||
@ -140,13 +140,13 @@ func Main(clientBuildFS fs.FS) {
|
|||||||
run(opts, clientBuildFS, done)
|
run(opts, clientBuildFS, done)
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupContext initializes [Context] fields. It also reads and upgrades
|
// setupContext initializes [globalContext] fields. It also reads and upgrades
|
||||||
// config file if necessary.
|
// config file if necessary.
|
||||||
func setupContext(opts options) (err error) {
|
func setupContext(opts options) (err error) {
|
||||||
Context.firstRun = detectFirstRun()
|
globalContext.firstRun = detectFirstRun()
|
||||||
|
|
||||||
Context.tlsRoots = aghtls.SystemRootCAs()
|
globalContext.tlsRoots = aghtls.SystemRootCAs()
|
||||||
Context.mux = http.NewServeMux()
|
globalContext.mux = http.NewServeMux()
|
||||||
|
|
||||||
if !opts.noEtcHosts {
|
if !opts.noEtcHosts {
|
||||||
err = setupHostsContainer()
|
err = setupHostsContainer()
|
||||||
@ -156,7 +156,7 @@ func setupContext(opts options) (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if Context.firstRun {
|
if globalContext.firstRun {
|
||||||
log.Info("This is the first time AdGuard Home is launched")
|
log.Info("This is the first time AdGuard Home is launched")
|
||||||
checkNetworkPermissions()
|
checkNetworkPermissions()
|
||||||
|
|
||||||
@ -247,7 +247,7 @@ func setupHostsContainer() (err error) {
|
|||||||
return fmt.Errorf("getting default system hosts paths: %w", err)
|
return fmt.Errorf("getting default system hosts paths: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.etcHosts, err = aghnet.NewHostsContainer(osutil.RootDirFS(), hostsWatcher, paths...)
|
globalContext.etcHosts, err = aghnet.NewHostsContainer(osutil.RootDirFS(), hostsWatcher, paths...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
closeErr := hostsWatcher.Close()
|
closeErr := hostsWatcher.Close()
|
||||||
if errors.Is(err, aghnet.ErrNoHostsPaths) {
|
if errors.Is(err, aghnet.ErrNoHostsPaths) {
|
||||||
@ -271,7 +271,7 @@ func setupOpts(opts options) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(opts.pidFile) != 0 && writePIDFile(opts.pidFile) {
|
if len(opts.pidFile) != 0 && writePIDFile(opts.pidFile) {
|
||||||
Context.pidFileName = opts.pidFile
|
globalContext.pidFileName = opts.pidFile
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -286,13 +286,13 @@ func initContextClients(ctx context.Context, logger *slog.Logger) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
//lint:ignore SA1019 Migration is not over.
|
//lint:ignore SA1019 Migration is not over.
|
||||||
config.DHCP.WorkDir = Context.workDir
|
config.DHCP.WorkDir = globalContext.workDir
|
||||||
config.DHCP.DataDir = Context.getDataDir()
|
config.DHCP.DataDir = globalContext.getDataDir()
|
||||||
config.DHCP.HTTPRegister = httpRegister
|
config.DHCP.HTTPRegister = httpRegister
|
||||||
config.DHCP.ConfigModified = onConfigModified
|
config.DHCP.ConfigModified = onConfigModified
|
||||||
|
|
||||||
Context.dhcpServer, err = dhcpd.Create(config.DHCP)
|
globalContext.dhcpServer, err = dhcpd.Create(config.DHCP)
|
||||||
if Context.dhcpServer == nil || err != nil {
|
if globalContext.dhcpServer == nil || err != nil {
|
||||||
// TODO(a.garipov): There are a lot of places in the code right
|
// TODO(a.garipov): There are a lot of places in the code right
|
||||||
// now which assume that the DHCP server can be nil despite this
|
// now which assume that the DHCP server can be nil despite this
|
||||||
// condition. Inspect them and perhaps rewrite them to use
|
// condition. Inspect them and perhaps rewrite them to use
|
||||||
@ -305,12 +305,12 @@ func initContextClients(ctx context.Context, logger *slog.Logger) (err error) {
|
|||||||
arpDB = arpdb.New(logger.With(slogutil.KeyError, "arpdb"))
|
arpDB = arpdb.New(logger.With(slogutil.KeyError, "arpdb"))
|
||||||
}
|
}
|
||||||
|
|
||||||
return Context.clients.Init(
|
return globalContext.clients.Init(
|
||||||
ctx,
|
ctx,
|
||||||
logger,
|
logger,
|
||||||
config.Clients.Persistent,
|
config.Clients.Persistent,
|
||||||
Context.dhcpServer,
|
globalContext.dhcpServer,
|
||||||
Context.etcHosts,
|
globalContext.etcHosts,
|
||||||
arpDB,
|
arpDB,
|
||||||
config.Filtering,
|
config.Filtering,
|
||||||
)
|
)
|
||||||
@ -374,15 +374,15 @@ func setupDNSFilteringConf(
|
|||||||
pcTXTSuffix = `pc.dns.adguard.com.`
|
pcTXTSuffix = `pc.dns.adguard.com.`
|
||||||
)
|
)
|
||||||
|
|
||||||
conf.EtcHosts = Context.etcHosts
|
conf.EtcHosts = globalContext.etcHosts
|
||||||
// TODO(s.chzhen): Use empty interface.
|
// TODO(s.chzhen): Use empty interface.
|
||||||
if Context.etcHosts == nil || !config.DNS.HostsFileEnabled {
|
if globalContext.etcHosts == nil || !config.DNS.HostsFileEnabled {
|
||||||
conf.EtcHosts = nil
|
conf.EtcHosts = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
conf.ConfigModified = onConfigModified
|
conf.ConfigModified = onConfigModified
|
||||||
conf.HTTPRegister = httpRegister
|
conf.HTTPRegister = httpRegister
|
||||||
conf.DataDir = Context.getDataDir()
|
conf.DataDir = globalContext.getDataDir()
|
||||||
conf.Filters = slices.Clone(config.Filters)
|
conf.Filters = slices.Clone(config.Filters)
|
||||||
conf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
|
conf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
|
||||||
conf.UserRules = slices.Clone(config.UserRules)
|
conf.UserRules = slices.Clone(config.UserRules)
|
||||||
@ -560,7 +560,7 @@ func initWeb(
|
|||||||
ReadHeaderTimeout: readHdrTimeout,
|
ReadHeaderTimeout: readHdrTimeout,
|
||||||
WriteTimeout: writeTimeout,
|
WriteTimeout: writeTimeout,
|
||||||
|
|
||||||
firstRun: Context.firstRun,
|
firstRun: globalContext.firstRun,
|
||||||
disableUpdate: disableUpdate,
|
disableUpdate: disableUpdate,
|
||||||
runningAsService: opts.runningAsService,
|
runningAsService: opts.runningAsService,
|
||||||
serveHTTP3: config.DNS.ServeHTTP3,
|
serveHTTP3: config.DNS.ServeHTTP3,
|
||||||
@ -602,7 +602,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
|||||||
|
|
||||||
// Print the first message after logger is configured.
|
// Print the first message after logger is configured.
|
||||||
log.Info(version.Full())
|
log.Info(version.Full())
|
||||||
log.Debug("current working directory is %s", Context.workDir)
|
log.Debug("current working directory is %s", globalContext.workDir)
|
||||||
if opts.runningAsService {
|
if opts.runningAsService {
|
||||||
log.Info("AdGuard Home is running as a service")
|
log.Info("AdGuard Home is running as a service")
|
||||||
}
|
}
|
||||||
@ -632,13 +632,13 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
|||||||
|
|
||||||
confPath := configFilePath()
|
confPath := configFilePath()
|
||||||
|
|
||||||
upd, customURL := newUpdater(ctx, slogLogger, Context.workDir, confPath, execPath, config)
|
upd, customURL := newUpdater(ctx, slogLogger, globalContext.workDir, confPath, execPath, config)
|
||||||
|
|
||||||
// TODO(e.burkov): This could be made earlier, probably as the option's
|
// TODO(e.burkov): This could be made earlier, probably as the option's
|
||||||
// effect.
|
// effect.
|
||||||
cmdlineUpdate(ctx, slogLogger, opts, upd)
|
cmdlineUpdate(ctx, slogLogger, opts, upd)
|
||||||
|
|
||||||
if !Context.firstRun {
|
if !globalContext.firstRun {
|
||||||
// Save the updated config.
|
// Save the updated config.
|
||||||
err = config.write()
|
err = config.write()
|
||||||
fatalOnError(err)
|
fatalOnError(err)
|
||||||
@ -648,33 +648,33 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dataDir := Context.getDataDir()
|
dataDir := globalContext.getDataDir()
|
||||||
err = os.MkdirAll(dataDir, aghos.DefaultPermDir)
|
err = os.MkdirAll(dataDir, aghos.DefaultPermDir)
|
||||||
fatalOnError(errors.Annotate(err, "creating DNS data dir at %s: %w", dataDir))
|
fatalOnError(errors.Annotate(err, "creating DNS data dir at %s: %w", dataDir))
|
||||||
|
|
||||||
GLMode = opts.glinetMode
|
GLMode = opts.glinetMode
|
||||||
|
|
||||||
// Init auth module.
|
// Init auth module.
|
||||||
Context.auth, err = initUsers()
|
globalContext.auth, err = initUsers()
|
||||||
fatalOnError(err)
|
fatalOnError(err)
|
||||||
|
|
||||||
Context.tls, err = newTLSManager(config.TLS, config.DNS.ServePlainDNS)
|
globalContext.tls, err = newTLSManager(config.TLS, config.DNS.ServePlainDNS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("initializing tls: %s", err)
|
log.Error("initializing tls: %s", err)
|
||||||
onConfigModified()
|
onConfigModified()
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.web, err = initWeb(ctx, opts, clientBuildFS, upd, slogLogger, customURL)
|
globalContext.web, err = initWeb(ctx, opts, clientBuildFS, upd, slogLogger, customURL)
|
||||||
fatalOnError(err)
|
fatalOnError(err)
|
||||||
|
|
||||||
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&Context, config)
|
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&globalContext, config)
|
||||||
fatalOnError(err)
|
fatalOnError(err)
|
||||||
|
|
||||||
if !Context.firstRun {
|
if !globalContext.firstRun {
|
||||||
err = initDNS(slogLogger, statsDir, querylogDir)
|
err = initDNS(slogLogger, statsDir, querylogDir)
|
||||||
fatalOnError(err)
|
fatalOnError(err)
|
||||||
|
|
||||||
Context.tls.start()
|
globalContext.tls.start()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
startErr := startDNSServer()
|
startErr := startDNSServer()
|
||||||
@ -684,8 +684,8 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if Context.dhcpServer != nil {
|
if globalContext.dhcpServer != nil {
|
||||||
err = Context.dhcpServer.Start()
|
err = globalContext.dhcpServer.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("starting dhcp server: %s", err)
|
log.Error("starting dhcp server: %s", err)
|
||||||
}
|
}
|
||||||
@ -693,10 +693,10 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !opts.noPermCheck {
|
if !opts.noPermCheck {
|
||||||
checkPermissions(ctx, slogLogger, Context.workDir, confPath, dataDir, statsDir, querylogDir)
|
checkPermissions(ctx, slogLogger, globalContext.workDir, confPath, dataDir, statsDir, querylogDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.web.start(ctx)
|
globalContext.web.start(ctx)
|
||||||
|
|
||||||
// Wait for other goroutines to complete their job.
|
// Wait for other goroutines to complete their job.
|
||||||
<-done
|
<-done
|
||||||
@ -775,7 +775,7 @@ func checkPermissions(
|
|||||||
|
|
||||||
// initUsers initializes context auth module. Clears config users field.
|
// initUsers initializes context auth module. Clears config users field.
|
||||||
func initUsers() (auth *Auth, err error) {
|
func initUsers() (auth *Auth, err error) {
|
||||||
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
|
sessFilename := filepath.Join(globalContext.getDataDir(), "sessions.db")
|
||||||
|
|
||||||
var rateLimiter *authRateLimiter
|
var rateLimiter *authRateLimiter
|
||||||
if config.AuthAttempts > 0 && config.AuthBlockMin > 0 {
|
if config.AuthAttempts > 0 && config.AuthBlockMin > 0 {
|
||||||
@ -810,7 +810,7 @@ func (c *configuration) anonymizer() (ipmut *aghnet.IPMut) {
|
|||||||
// startMods initializes and starts the DNS server after installation.
|
// startMods initializes and starts the DNS server after installation.
|
||||||
// baseLogger must not be nil.
|
// baseLogger must not be nil.
|
||||||
func startMods(baseLogger *slog.Logger) (err error) {
|
func startMods(baseLogger *slog.Logger) (err error) {
|
||||||
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&Context, config)
|
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&globalContext, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -820,7 +820,7 @@ func startMods(baseLogger *slog.Logger) (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.tls.start()
|
globalContext.tls.start()
|
||||||
|
|
||||||
err = startDNSServer()
|
err = startDNSServer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -883,14 +883,14 @@ func writePIDFile(fn string) bool {
|
|||||||
func initConfigFilename(opts options) {
|
func initConfigFilename(opts options) {
|
||||||
confPath := opts.confFilename
|
confPath := opts.confFilename
|
||||||
if confPath == "" {
|
if confPath == "" {
|
||||||
Context.confFilePath = filepath.Join(Context.workDir, "AdGuardHome.yaml")
|
globalContext.confFilePath = filepath.Join(globalContext.workDir, "AdGuardHome.yaml")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("config path overridden to %q from cmdline", confPath)
|
log.Debug("config path overridden to %q from cmdline", confPath)
|
||||||
|
|
||||||
Context.confFilePath = confPath
|
globalContext.confFilePath = confPath
|
||||||
}
|
}
|
||||||
|
|
||||||
// initWorkingDir initializes the workDir. If no command-line arguments are
|
// initWorkingDir initializes the workDir. If no command-line arguments are
|
||||||
@ -904,18 +904,18 @@ func initWorkingDir(opts options) (err error) {
|
|||||||
|
|
||||||
if opts.workDir != "" {
|
if opts.workDir != "" {
|
||||||
// If there is a custom config file, use it's directory as our working dir
|
// If there is a custom config file, use it's directory as our working dir
|
||||||
Context.workDir = opts.workDir
|
globalContext.workDir = opts.workDir
|
||||||
} else {
|
} else {
|
||||||
Context.workDir = filepath.Dir(execPath)
|
globalContext.workDir = filepath.Dir(execPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
workDir, err := filepath.EvalSymlinks(Context.workDir)
|
workDir, err := filepath.EvalSymlinks(globalContext.workDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Don't wrap the error, because it's informative enough as is.
|
// Don't wrap the error, because it's informative enough as is.
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.workDir = workDir
|
globalContext.workDir = workDir
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -924,13 +924,13 @@ func initWorkingDir(opts options) (err error) {
|
|||||||
func cleanup(ctx context.Context) {
|
func cleanup(ctx context.Context) {
|
||||||
log.Info("stopping AdGuard Home")
|
log.Info("stopping AdGuard Home")
|
||||||
|
|
||||||
if Context.web != nil {
|
if globalContext.web != nil {
|
||||||
Context.web.close(ctx)
|
globalContext.web.close(ctx)
|
||||||
Context.web = nil
|
globalContext.web = nil
|
||||||
}
|
}
|
||||||
if Context.auth != nil {
|
if globalContext.auth != nil {
|
||||||
Context.auth.Close()
|
globalContext.auth.Close()
|
||||||
Context.auth = nil
|
globalContext.auth = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := stopDNSServer()
|
err := stopDNSServer()
|
||||||
@ -938,28 +938,28 @@ func cleanup(ctx context.Context) {
|
|||||||
log.Error("stopping dns server: %s", err)
|
log.Error("stopping dns server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if Context.dhcpServer != nil {
|
if globalContext.dhcpServer != nil {
|
||||||
err = Context.dhcpServer.Stop()
|
err = globalContext.dhcpServer.Stop()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("stopping dhcp server: %s", err)
|
log.Error("stopping dhcp server: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if Context.etcHosts != nil {
|
if globalContext.etcHosts != nil {
|
||||||
if err = Context.etcHosts.Close(); err != nil {
|
if err = globalContext.etcHosts.Close(); err != nil {
|
||||||
log.Error("closing hosts container: %s", err)
|
log.Error("closing hosts container: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if Context.tls != nil {
|
if globalContext.tls != nil {
|
||||||
Context.tls = nil
|
globalContext.tls = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// This function is called before application exits
|
// This function is called before application exits
|
||||||
func cleanupAlways() {
|
func cleanupAlways() {
|
||||||
if len(Context.pidFileName) != 0 {
|
if len(globalContext.pidFileName) != 0 {
|
||||||
_ = os.Remove(Context.pidFileName)
|
_ = os.Remove(globalContext.pidFileName)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("stopped")
|
log.Info("stopped")
|
||||||
@ -1007,8 +1007,8 @@ func printWebAddrs(proto, addr string, port uint16) {
|
|||||||
// admin interface. proto is either schemeHTTP or schemeHTTPS.
|
// admin interface. proto is either schemeHTTP or schemeHTTPS.
|
||||||
func printHTTPAddresses(proto string) {
|
func printHTTPAddresses(proto string) {
|
||||||
tlsConf := tlsConfigSettings{}
|
tlsConf := tlsConfigSettings{}
|
||||||
if Context.tls != nil {
|
if globalContext.tls != nil {
|
||||||
Context.tls.WriteDiskConfig(&tlsConf)
|
globalContext.tls.WriteDiskConfig(&tlsConf)
|
||||||
}
|
}
|
||||||
|
|
||||||
port := config.HTTPConfig.Address.Port()
|
port := config.HTTPConfig.Address.Port()
|
||||||
@ -1050,9 +1050,9 @@ func printHTTPAddresses(proto string) {
|
|||||||
|
|
||||||
// detectFirstRun returns true if this is the first run of AdGuard Home.
|
// detectFirstRun returns true if this is the first run of AdGuard Home.
|
||||||
func detectFirstRun() (ok bool) {
|
func detectFirstRun() (ok bool) {
|
||||||
confPath := Context.confFilePath
|
confPath := globalContext.confFilePath
|
||||||
if !filepath.IsAbs(confPath) {
|
if !filepath.IsAbs(confPath) {
|
||||||
confPath = filepath.Join(Context.workDir, Context.confFilePath)
|
confPath = filepath.Join(globalContext.workDir, globalContext.confFilePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := os.Stat(confPath)
|
_, err := os.Stat(confPath)
|
||||||
@ -1105,7 +1105,7 @@ func cmdlineUpdate(ctx context.Context, l *slog.Logger, opts options, upd *updat
|
|||||||
os.Exit(osutil.ExitCodeSuccess)
|
os.Exit(osutil.ExitCodeSuccess)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = upd.Update(Context.firstRun)
|
err = upd.Update(globalContext.firstRun)
|
||||||
fatalOnError(err)
|
fatalOnError(err)
|
||||||
|
|
||||||
err = restartService()
|
err = restartService()
|
||||||
|
@ -17,7 +17,7 @@ func httpClient() (c *http.Client) {
|
|||||||
// Do not use Context.dnsServer.DialContext directly in the struct literal
|
// Do not use Context.dnsServer.DialContext directly in the struct literal
|
||||||
// below, since Context.dnsServer may be nil when this function is called.
|
// below, since Context.dnsServer may be nil when this function is called.
|
||||||
dialContext := func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
dialContext := func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||||
return Context.dnsServer.DialContext(ctx, network, addr)
|
return globalContext.dnsServer.DialContext(ctx, network, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &http.Client{
|
return &http.Client{
|
||||||
@ -27,8 +27,8 @@ func httpClient() (c *http.Client) {
|
|||||||
DialContext: dialContext,
|
DialContext: dialContext,
|
||||||
Proxy: httpProxy,
|
Proxy: httpProxy,
|
||||||
TLSClientConfig: &tls.Config{
|
TLSClientConfig: &tls.Config{
|
||||||
RootCAs: Context.tlsRoots,
|
RootCAs: globalContext.tlsRoots,
|
||||||
CipherSuites: Context.tlsCipherIDs,
|
CipherSuites: globalContext.tlsCipherIDs,
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -66,7 +66,7 @@ func configureLogger(ls *logSettings) (err error) {
|
|||||||
|
|
||||||
logFilePath := ls.File
|
logFilePath := ls.File
|
||||||
if !filepath.IsAbs(logFilePath) {
|
if !filepath.IsAbs(logFilePath) {
|
||||||
logFilePath = filepath.Join(Context.workDir, logFilePath)
|
logFilePath = filepath.Join(globalContext.workDir, logFilePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.SetOutput(&lumberjack.Logger{
|
log.SetOutput(&lumberjack.Logger{
|
||||||
|
@ -19,10 +19,10 @@ func setupDNSIPs(t testing.TB) {
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
prevConfig := config
|
prevConfig := config
|
||||||
prevTLS := Context.tls
|
prevTLS := globalContext.tls
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
config = prevConfig
|
config = prevConfig
|
||||||
Context.tls = prevTLS
|
globalContext.tls = prevTLS
|
||||||
})
|
})
|
||||||
|
|
||||||
config = &configuration{
|
config = &configuration{
|
||||||
@ -32,7 +32,7 @@ func setupDNSIPs(t testing.TB) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.tls = &tlsManager{}
|
globalContext.tls = &tlsManager{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandleMobileConfigDoH(t *testing.T) {
|
func TestHandleMobileConfigDoH(t *testing.T) {
|
||||||
@ -62,10 +62,10 @@ func TestHandleMobileConfigDoH(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("error_no_host", func(t *testing.T) {
|
t.Run("error_no_host", func(t *testing.T) {
|
||||||
oldTLSConf := Context.tls
|
oldTLSConf := globalContext.tls
|
||||||
t.Cleanup(func() { Context.tls = oldTLSConf })
|
t.Cleanup(func() { globalContext.tls = oldTLSConf })
|
||||||
|
|
||||||
Context.tls = &tlsManager{conf: tlsConfigSettings{}}
|
globalContext.tls = &tlsManager{conf: tlsConfigSettings{}}
|
||||||
|
|
||||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
|
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -134,10 +134,10 @@ func TestHandleMobileConfigDoT(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("error_no_host", func(t *testing.T) {
|
t.Run("error_no_host", func(t *testing.T) {
|
||||||
oldTLSConf := Context.tls
|
oldTLSConf := globalContext.tls
|
||||||
t.Cleanup(func() { Context.tls = oldTLSConf })
|
t.Cleanup(func() { globalContext.tls = oldTLSConf })
|
||||||
|
|
||||||
Context.tls = &tlsManager{conf: tlsConfigSettings{}}
|
globalContext.tls = &tlsManager{conf: tlsConfigSettings{}}
|
||||||
|
|
||||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
|
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
@ -47,7 +47,7 @@ type profileJSON struct {
|
|||||||
|
|
||||||
// handleGetProfile is the handler for GET /control/profile endpoint.
|
// handleGetProfile is the handler for GET /control/profile endpoint.
|
||||||
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||||
u := Context.auth.getCurrentUser(r)
|
u := globalContext.auth.getCurrentUser(r)
|
||||||
|
|
||||||
var resp profileJSON
|
var resp profileJSON
|
||||||
func() {
|
func() {
|
||||||
|
@ -112,7 +112,7 @@ func (m *tlsManager) start() {
|
|||||||
// The background context is used because the TLSConfigChanged wraps context
|
// The background context is used because the TLSConfigChanged wraps context
|
||||||
// with timeout on its own and shuts down the server, which handles current
|
// with timeout on its own and shuts down the server, which handles current
|
||||||
// request.
|
// request.
|
||||||
Context.web.tlsConfigChanged(context.Background(), tlsConf)
|
globalContext.web.tlsConfigChanged(context.Background(), tlsConf)
|
||||||
}
|
}
|
||||||
|
|
||||||
// reload updates the configuration and restarts t.
|
// reload updates the configuration and restarts t.
|
||||||
@ -160,7 +160,7 @@ func (m *tlsManager) reload() {
|
|||||||
// The background context is used because the TLSConfigChanged wraps context
|
// The background context is used because the TLSConfigChanged wraps context
|
||||||
// with timeout on its own and shuts down the server, which handles current
|
// with timeout on its own and shuts down the server, which handles current
|
||||||
// request.
|
// request.
|
||||||
Context.web.tlsConfigChanged(context.Background(), tlsConf)
|
globalContext.web.tlsConfigChanged(context.Background(), tlsConf)
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadTLSConf loads and validates the TLS configuration. The returned error is
|
// loadTLSConf loads and validates the TLS configuration. The returned error is
|
||||||
@ -463,7 +463,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
|||||||
// same reason.
|
// same reason.
|
||||||
if restartHTTPS {
|
if restartHTTPS {
|
||||||
go func() {
|
go func() {
|
||||||
Context.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings)
|
globalContext.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -539,7 +539,7 @@ func validateCertChain(certs []*x509.Certificate, srvName string) (err error) {
|
|||||||
|
|
||||||
opts := x509.VerifyOptions{
|
opts := x509.VerifyOptions{
|
||||||
DNSName: srvName,
|
DNSName: srvName,
|
||||||
Roots: Context.tlsRoots,
|
Roots: globalContext.tlsRoots,
|
||||||
Intermediates: pool,
|
Intermediates: pool,
|
||||||
}
|
}
|
||||||
_, err = main.Verify(opts)
|
_, err = main.Verify(opts)
|
||||||
|
@ -129,7 +129,7 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
|
|||||||
clientFS := http.FileServer(http.FS(conf.clientFS))
|
clientFS := http.FileServer(http.FS(conf.clientFS))
|
||||||
|
|
||||||
// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
|
// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
|
||||||
Context.mux.Handle("/", withMiddlewares(clientFS, gziphandler.GzipHandler, optionalAuthHandler, postInstallHandler))
|
globalContext.mux.Handle("/", withMiddlewares(clientFS, gziphandler.GzipHandler, optionalAuthHandler, postInstallHandler))
|
||||||
|
|
||||||
// add handlers for /install paths, we only need them when we're not configured yet
|
// add handlers for /install paths, we only need them when we're not configured yet
|
||||||
if conf.firstRun {
|
if conf.firstRun {
|
||||||
@ -138,7 +138,7 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
|
|||||||
"This is the first launch of AdGuard Home, redirecting everything to /install.html",
|
"This is the first launch of AdGuard Home, redirecting everything to /install.html",
|
||||||
)
|
)
|
||||||
|
|
||||||
Context.mux.Handle("/install.html", preInstallHandler(clientFS))
|
globalContext.mux.Handle("/install.html", preInstallHandler(clientFS))
|
||||||
w.registerInstallHandlers()
|
w.registerInstallHandlers()
|
||||||
} else {
|
} else {
|
||||||
registerControlHandlers(w)
|
registerControlHandlers(w)
|
||||||
@ -154,7 +154,7 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
|
|||||||
//
|
//
|
||||||
// TODO(a.garipov): Adapt for HTTP/3.
|
// TODO(a.garipov): Adapt for HTTP/3.
|
||||||
func webCheckPortAvailable(port uint16) (ok bool) {
|
func webCheckPortAvailable(port uint16) (ok bool) {
|
||||||
if Context.web.httpsServer.server != nil {
|
if globalContext.web.httpsServer.server != nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -224,7 +224,7 @@ func (web *webAPI) start(ctx context.Context) {
|
|||||||
errs := make(chan error, 2)
|
errs := make(chan error, 2)
|
||||||
|
|
||||||
// Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies.
|
// Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies.
|
||||||
hdlr := h2c.NewHandler(withMiddlewares(Context.mux, limitRequestBody), &http2.Server{})
|
hdlr := h2c.NewHandler(withMiddlewares(globalContext.mux, limitRequestBody), &http2.Server{})
|
||||||
|
|
||||||
logger := web.baseLogger.With(loggerKeyServer, "plain")
|
logger := web.baseLogger.With(loggerKeyServer, "plain")
|
||||||
|
|
||||||
@ -307,11 +307,11 @@ func (web *webAPI) tlsServerLoop(ctx context.Context) {
|
|||||||
|
|
||||||
web.httpsServer.server = &http.Server{
|
web.httpsServer.server = &http.Server{
|
||||||
Addr: addr,
|
Addr: addr,
|
||||||
Handler: withMiddlewares(Context.mux, limitRequestBody),
|
Handler: withMiddlewares(globalContext.mux, limitRequestBody),
|
||||||
TLSConfig: &tls.Config{
|
TLSConfig: &tls.Config{
|
||||||
Certificates: []tls.Certificate{web.httpsServer.cert},
|
Certificates: []tls.Certificate{web.httpsServer.cert},
|
||||||
RootCAs: Context.tlsRoots,
|
RootCAs: globalContext.tlsRoots,
|
||||||
CipherSuites: Context.tlsCipherIDs,
|
CipherSuites: globalContext.tlsCipherIDs,
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
},
|
},
|
||||||
ReadTimeout: web.conf.ReadTimeout,
|
ReadTimeout: web.conf.ReadTimeout,
|
||||||
@ -344,11 +344,11 @@ func (web *webAPI) mustStartHTTP3(ctx context.Context, address string) {
|
|||||||
Addr: address,
|
Addr: address,
|
||||||
TLSConfig: &tls.Config{
|
TLSConfig: &tls.Config{
|
||||||
Certificates: []tls.Certificate{web.httpsServer.cert},
|
Certificates: []tls.Certificate{web.httpsServer.cert},
|
||||||
RootCAs: Context.tlsRoots,
|
RootCAs: globalContext.tlsRoots,
|
||||||
CipherSuites: Context.tlsCipherIDs,
|
CipherSuites: globalContext.tlsCipherIDs,
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
},
|
},
|
||||||
Handler: withMiddlewares(Context.mux, limitRequestBody),
|
Handler: withMiddlewares(globalContext.mux, limitRequestBody),
|
||||||
}
|
}
|
||||||
|
|
||||||
web.logger.DebugContext(ctx, "starting http/3 server")
|
web.logger.DebugContext(ctx, "starting http/3 server")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user