From 8d5b0084f1a0560d3476c264f75f99ddb24e5059 Mon Sep 17 00:00:00 2001 From: Arve Knudsen Date: Wed, 2 Dec 2020 12:44:51 +0100 Subject: [PATCH] Middleware: Simplifications (#29491) * Middleware: Simplify Signed-off-by: Arve Knudsen * middleware: Rename auth_proxy directory to authproxy Signed-off-by: Arve Knudsen --- pkg/infra/remotecache/testing.go | 5 ++- pkg/middleware/auth.go | 2 +- pkg/middleware/auth_proxy.go | 4 +-- pkg/middleware/auth_proxy_test.go | 6 ++-- .../{auth_proxy => authproxy}/auth_proxy.go | 16 +++++----- .../auth_proxy_test.go | 18 +++++------ pkg/middleware/cookie.go | 32 ++++++++++++++++--- pkg/middleware/middleware.go | 16 ---------- pkg/middleware/middleware_test.go | 2 +- pkg/registry/registry.go | 15 +++++++-- 10 files changed, 66 insertions(+), 50 deletions(-) rename pkg/middleware/{auth_proxy => authproxy}/auth_proxy.go (95%) rename pkg/middleware/{auth_proxy => authproxy}/auth_proxy_test.go (93%) diff --git a/pkg/infra/remotecache/testing.go b/pkg/infra/remotecache/testing.go index 43cd449cb68..1333ef1ea04 100644 --- a/pkg/infra/remotecache/testing.go +++ b/pkg/infra/remotecache/testing.go @@ -5,6 +5,7 @@ import ( "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/setting" + "github.com/stretchr/testify/require" ) // NewFakeStore creates store for testing @@ -26,9 +27,7 @@ func NewFakeStore(t *testing.T) *RemoteCache { } err := dc.Init() - if err != nil { - t.Fatalf("failed to init remote cache for test. error: %v", err) - } + require.NoError(t, err, "Failed to init remote cache for test") return dc } diff --git a/pkg/middleware/auth.go b/pkg/middleware/auth.go index c44d7dd9a7c..d24871d31ca 100644 --- a/pkg/middleware/auth.go +++ b/pkg/middleware/auth.go @@ -57,7 +57,7 @@ func notAuthorized(c *models.ReqContext) { // remove any forceLogin=true params redirectTo = removeForceLoginParams(redirectTo) - WriteCookie(c.Resp, "redirect_to", url.QueryEscape(redirectTo), 0, newCookieOptions) + WriteCookie(c.Resp, "redirect_to", url.QueryEscape(redirectTo), 0, nil) c.Redirect(setting.AppSubUrl + "/login") } diff --git a/pkg/middleware/auth_proxy.go b/pkg/middleware/auth_proxy.go index d69fd6dd2de..2365df81852 100644 --- a/pkg/middleware/auth_proxy.go +++ b/pkg/middleware/auth_proxy.go @@ -5,7 +5,7 @@ import ( "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/remotecache" - authproxy "github.com/grafana/grafana/pkg/middleware/auth_proxy" + "github.com/grafana/grafana/pkg/middleware/authproxy" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/setting" ) @@ -45,7 +45,7 @@ func initContextWithAuthProxy(store *remotecache.RemoteCache, ctx *models.ReqCon } // Check if allowed to continue with this IP - if result, err := auth.IsAllowedIP(); !result { + if err := auth.IsAllowedIP(); err != nil { logger.Error( "Failed to check whitelisted IP addresses", "message", err.Error(), diff --git a/pkg/middleware/auth_proxy_test.go b/pkg/middleware/auth_proxy_test.go index 1c678efe535..ee022d7f74c 100644 --- a/pkg/middleware/auth_proxy_test.go +++ b/pkg/middleware/auth_proxy_test.go @@ -8,7 +8,7 @@ import ( "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/remotecache" - authproxy "github.com/grafana/grafana/pkg/middleware/auth_proxy" + "github.com/grafana/grafana/pkg/middleware/authproxy" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/setting" "github.com/stretchr/testify/require" @@ -29,7 +29,7 @@ func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) { cmd.Result = &models.User{Id: userID} return nil } - getSignedUserHandler := func(cmd *models.GetSignedInUserQuery) error { + getUserHandler := func(cmd *models.GetSignedInUserQuery) error { // Simulate that the cached user ID is stale if cmd.UserId != userID { return models.ErrUserNotFound @@ -46,7 +46,7 @@ func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) { origEnabled := setting.AuthProxyEnabled origHeaderProperty := setting.AuthProxyHeaderProperty bus.AddHandler("", upsertHandler) - bus.AddHandler("", getSignedUserHandler) + bus.AddHandler("", getUserHandler) t.Cleanup(func() { setting.AuthProxyHeaderName = origHeaderName setting.AuthProxyEnabled = origEnabled diff --git a/pkg/middleware/auth_proxy/auth_proxy.go b/pkg/middleware/authproxy/auth_proxy.go similarity index 95% rename from pkg/middleware/auth_proxy/auth_proxy.go rename to pkg/middleware/authproxy/auth_proxy.go index 6df587a0579..0e0267721a9 100644 --- a/pkg/middleware/auth_proxy/auth_proxy.go +++ b/pkg/middleware/authproxy/auth_proxy.go @@ -114,11 +114,11 @@ func (auth *AuthProxy) HasHeader() bool { } // IsAllowedIP compares presented IP with the whitelist one -func (auth *AuthProxy) IsAllowedIP() (bool, *Error) { +func (auth *AuthProxy) IsAllowedIP() *Error { ip := auth.ctx.Req.RemoteAddr if len(strings.TrimSpace(auth.whitelistIP)) == 0 { - return true, nil + return nil } proxies := strings.Split(auth.whitelistIP, ",") @@ -126,7 +126,7 @@ func (auth *AuthProxy) IsAllowedIP() (bool, *Error) { for _, proxy := range proxies { result, err := coerceProxyAddress(proxy) if err != nil { - return false, newError("Could not get the network", err) + return newError("could not get the network", err) } proxyObjs = append(proxyObjs, result) @@ -134,13 +134,13 @@ func (auth *AuthProxy) IsAllowedIP() (bool, *Error) { sourceIP, _, err := net.SplitHostPort(ip) if err != nil { - return false, newError("could not parse address", err) + return newError("could not parse address", err) } sourceObj := net.ParseIP(sourceIP) for _, proxyObj := range proxyObjs { if proxyObj.Contains(sourceObj) { - return true, nil + return nil } } @@ -148,7 +148,7 @@ func (auth *AuthProxy) IsAllowedIP() (bool, *Error) { "request for user (%s) from %s is not from the authentication proxy", auth.header, sourceIP, ) - return false, newError("Proxy authentication required", err) + return newError("proxy authentication required", err) } func HashCacheKey(key string) string { @@ -232,7 +232,7 @@ func (auth *AuthProxy) RemoveUserFromCache(logger log.Logger) error { func (auth *AuthProxy) LoginViaLDAP() (int64, *Error) { config, err := getLDAPConfig() if err != nil { - return 0, newError("Failed to get LDAP config", nil) + return 0, newError("failed to get LDAP config", nil) } extUser, _, err := newLDAP(config.Servers).User(auth.header) @@ -273,7 +273,7 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) { extUser.Email = auth.header extUser.Login = auth.header default: - return 0, newError("Auth proxy header property invalid", nil) + return 0, newError("auth proxy header property invalid", nil) } auth.headersIterator(func(field string, header string) { diff --git a/pkg/middleware/auth_proxy/auth_proxy_test.go b/pkg/middleware/authproxy/auth_proxy_test.go similarity index 93% rename from pkg/middleware/auth_proxy/auth_proxy_test.go rename to pkg/middleware/authproxy/auth_proxy_test.go index e4a66a0d830..a6625407af8 100644 --- a/pkg/middleware/auth_proxy/auth_proxy_test.go +++ b/pkg/middleware/authproxy/auth_proxy_test.go @@ -17,31 +17,31 @@ import ( "gopkg.in/macaron.v1" ) -type TestMultiLDAP struct { +type fakeMultiLDAP struct { multildap.MultiLDAP ID int64 userCalled bool loginCalled bool } -func (stub *TestMultiLDAP) Login(query *models.LoginUserQuery) ( +func (m *fakeMultiLDAP) Login(query *models.LoginUserQuery) ( *models.ExternalUserInfo, error, ) { - stub.loginCalled = true + m.loginCalled = true result := &models.ExternalUserInfo{ - UserId: stub.ID, + UserId: m.ID, } return result, nil } -func (stub *TestMultiLDAP) User(login string) ( +func (m *fakeMultiLDAP) User(login string) ( *models.ExternalUserInfo, ldap.ServerConfig, error, ) { - stub.userCalled = true + m.userCalled = true result := &models.ExternalUserInfo{ - UserId: stub.ID, + UserId: m.ID, } return result, ldap.ServerConfig{}, nil } @@ -126,7 +126,7 @@ func TestMiddlewareContext(t *testing.T) { return true } - stub := &TestMultiLDAP{ + stub := &fakeMultiLDAP{ ID: 42, } @@ -181,7 +181,7 @@ func TestMiddlewareContext(t *testing.T) { auth := prepareMiddleware(t, req, store) - stub := &TestMultiLDAP{ + stub := &fakeMultiLDAP{ ID: 42, } diff --git a/pkg/middleware/cookie.go b/pkg/middleware/cookie.go index c81cb0dfa6c..78939632dbc 100644 --- a/pkg/middleware/cookie.go +++ b/pkg/middleware/cookie.go @@ -2,7 +2,10 @@ package middleware import ( "net/http" + "net/url" + "time" + "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/setting" ) @@ -26,14 +29,18 @@ func newCookieOptions() CookieOptions { } } -type GetCookieOptionsFunc func() CookieOptions +type getCookieOptionsFunc func() CookieOptions -func DeleteCookie(w http.ResponseWriter, name string, getCookieOptionsFunc GetCookieOptionsFunc) { - WriteCookie(w, name, "", -1, getCookieOptionsFunc) +func DeleteCookie(w http.ResponseWriter, name string, getCookieOptions getCookieOptionsFunc) { + WriteCookie(w, name, "", -1, getCookieOptions) } -func WriteCookie(w http.ResponseWriter, name string, value string, maxAge int, getCookieOptionsFunc GetCookieOptionsFunc) { - options := getCookieOptionsFunc() +func WriteCookie(w http.ResponseWriter, name string, value string, maxAge int, getCookieOptions getCookieOptionsFunc) { + if getCookieOptions == nil { + getCookieOptions = newCookieOptions + } + + options := getCookieOptions() cookie := http.Cookie{ Name: name, MaxAge: maxAge, @@ -47,3 +54,18 @@ func WriteCookie(w http.ResponseWriter, name string, value string, maxAge int, g } http.SetCookie(w, &cookie) } + +func WriteSessionCookie(ctx *models.ReqContext, value string, maxLifetime time.Duration) { + if setting.Env == setting.Dev { + ctx.Logger.Info("New token", "unhashed token", value) + } + + var maxAge int + if maxLifetime <= 0 { + maxAge = -1 + } else { + maxAge = int(maxLifetime.Seconds()) + } + + WriteCookie(ctx.Resp, setting.LoginCookieName, url.QueryEscape(value), maxAge, nil) +} diff --git a/pkg/middleware/middleware.go b/pkg/middleware/middleware.go index ac82532e52b..6d5dd32fc26 100644 --- a/pkg/middleware/middleware.go +++ b/pkg/middleware/middleware.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "net/url" "strconv" "strings" "time" @@ -274,21 +273,6 @@ func rotateEndOfRequestFunc(ctx *models.ReqContext, authTokenService models.User } } -func WriteSessionCookie(ctx *models.ReqContext, value string, maxLifetime time.Duration) { - if setting.Env == setting.Dev { - ctx.Logger.Info("New token", "unhashed token", value) - } - - var maxAge int - if maxLifetime <= 0 { - maxAge = -1 - } else { - maxAge = int(maxLifetime.Seconds()) - } - - WriteCookie(ctx.Resp, setting.LoginCookieName, url.QueryEscape(value), maxAge, newCookieOptions) -} - func AddDefaultResponseHeaders() macaron.Handler { return func(ctx *macaron.Context) { ctx.Resp.Before(func(w macaron.ResponseWriter) { diff --git a/pkg/middleware/middleware_test.go b/pkg/middleware/middleware_test.go index 7a996f82c3c..f78906485c8 100644 --- a/pkg/middleware/middleware_test.go +++ b/pkg/middleware/middleware_test.go @@ -21,7 +21,7 @@ import ( "github.com/grafana/grafana/pkg/components/gtime" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/remotecache" - authproxy "github.com/grafana/grafana/pkg/middleware/auth_proxy" + "github.com/grafana/grafana/pkg/middleware/authproxy" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/login" diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go index 477079d34e4..dd60fa56f6e 100644 --- a/pkg/registry/registry.go +++ b/pkg/registry/registry.go @@ -17,7 +17,7 @@ type Descriptor struct { var services []*Descriptor func RegisterServiceWithPriority(instance Service, priority Priority) { - services = append(services, &Descriptor{ + Register(&Descriptor{ Name: reflect.TypeOf(instance).Elem().Name(), Instance: instance, InitPriority: priority, @@ -25,7 +25,7 @@ func RegisterServiceWithPriority(instance Service, priority Priority) { } func RegisterService(instance Service) { - services = append(services, &Descriptor{ + Register(&Descriptor{ Name: reflect.TypeOf(instance).Elem().Name(), Instance: instance, InitPriority: Medium, @@ -33,6 +33,17 @@ func RegisterService(instance Service) { } func Register(descriptor *Descriptor) { + if descriptor == nil { + return + } + // Overwrite any existing equivalent service + for i, svc := range services { + if svc.Name == descriptor.Name { + services[i] = descriptor + return + } + } + services = append(services, descriptor) }