Auth proxy: Return standard error type (#29502)

* Rewrite auth proxy tests to use standard lib

Signed-off-by: Arve Knudsen <arve.knudsen@gmail.com>

* Auth proxy: Use standard error type

Signed-off-by: Arve Knudsen <arve.knudsen@gmail.com>
This commit is contained in:
Arve Knudsen 2020-12-02 16:57:16 +01:00 committed by GitHub
parent 2c535a9583
commit 752a424e1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 200 additions and 174 deletions

View File

@ -12,18 +12,41 @@ import (
var header = setting.AuthProxyHeaderName var header = setting.AuthProxyHeaderName
func logUserIn(auth *authproxy.AuthProxy, username string, logger log.Logger, ignoreCache bool) (int64, *authproxy.Error) { func logUserIn(auth *authproxy.AuthProxy, username string, logger log.Logger, ignoreCache bool) (int64, error) {
logger.Debug("Trying to log user in", "username", username, "ignoreCache", ignoreCache) logger.Debug("Trying to log user in", "username", username, "ignoreCache", ignoreCache)
// Try to log in user via various providers // Try to log in user via various providers
id, e := auth.Login(logger, ignoreCache) id, err := auth.Login(logger, ignoreCache)
if e != nil { if err != nil {
logger.Error("Failed to login", "username", username, "message", e.Error(), "error", e.DetailsError, details := err
var e authproxy.Error
if errors.As(err, &e) {
details = e.DetailsError
}
logger.Error("Failed to login", "username", username, "message", err.Error(), "error", details,
"ignoreCache", ignoreCache) "ignoreCache", ignoreCache)
return 0, e return 0, err
} }
return id, nil return id, nil
} }
// handleError calls ctx.Handle with the error message and the underlying error.
// If the error is of type authproxy.Error, its DetailsError is unwrapped and passed to ctx.Handle.
// If a callback is provided, it's called with either err.DetailsError, if err is of type
// authproxy.Error, otherwise err itself.
func handleError(ctx *models.ReqContext, err error, statusCode int, cb func(err error)) {
details := err
var e authproxy.Error
if errors.As(err, &e) {
details = e.DetailsError
}
ctx.Handle(statusCode, err.Error(), details)
if cb != nil {
cb(details)
}
}
func initContextWithAuthProxy(store *remotecache.RemoteCache, ctx *models.ReqContext, orgID int64) bool { func initContextWithAuthProxy(store *remotecache.RemoteCache, ctx *models.ReqContext, orgID int64) bool {
username := ctx.Req.Header.Get(header) username := ctx.Req.Header.Get(header)
auth := authproxy.New(&authproxy.Options{ auth := authproxy.New(&authproxy.Options{
@ -46,18 +69,15 @@ func initContextWithAuthProxy(store *remotecache.RemoteCache, ctx *models.ReqCon
// Check if allowed to continue with this IP // Check if allowed to continue with this IP
if err := auth.IsAllowedIP(); err != nil { if err := auth.IsAllowedIP(); err != nil {
logger.Error( handleError(ctx, err, 407, func(details error) {
"Failed to check whitelisted IP addresses", logger.Error("Failed to check whitelisted IP addresses", "message", err.Error(), "error", details)
"message", err.Error(), })
"error", err.DetailsError,
)
ctx.Handle(407, err.Error(), err.DetailsError)
return true return true
} }
id, e := logUserIn(auth, username, logger, false) id, err := logUserIn(auth, username, logger, false)
if e != nil { if err != nil {
ctx.Handle(407, e.Error(), e.DetailsError) handleError(ctx, err, 407, nil)
return true return true
} }
@ -76,15 +96,16 @@ func initContextWithAuthProxy(store *remotecache.RemoteCache, ctx *models.ReqCon
logger.Error("Got unexpected error when removing user from auth cache", "error", err) logger.Error("Got unexpected error when removing user from auth cache", "error", err)
} }
} }
id, e = logUserIn(auth, username, logger, true) id, err = logUserIn(auth, username, logger, true)
if e != nil { if err != nil {
ctx.Handle(407, e.Error(), e.DetailsError) handleError(ctx, err, 407, nil)
return true return true
} }
user, e = auth.GetSignedUser(id) user, err = auth.GetSignedUser(id)
if e != nil { if err != nil {
ctx.Handle(407, e.Error(), e.DetailsError) handleError(ctx, err, 407, nil)
return true return true
} }
} }
@ -96,14 +117,15 @@ func initContextWithAuthProxy(store *remotecache.RemoteCache, ctx *models.ReqCon
ctx.IsSignedIn = true ctx.IsSignedIn = true
// Remember user data in cache // Remember user data in cache
if e := auth.Remember(id); e != nil { if err := auth.Remember(id); err != nil {
logger.Error( handleError(ctx, err, 500, func(details error) {
"Failed to store user in cache", logger.Error(
"username", username, "Failed to store user in cache",
"message", e.Error(), "username", username,
"error", e.DetailsError, "message", e.Error(),
) "error", details,
ctx.Handle(500, e.Error(), e.DetailsError) )
})
return true return true
} }

View File

@ -62,16 +62,16 @@ type Error struct {
DetailsError error DetailsError error
} }
// newError creates the Error // newError returns an Error.
func newError(message string, err error) *Error { func newError(message string, err error) Error {
return &Error{ return Error{
Message: message, Message: message,
DetailsError: err, DetailsError: err,
} }
} }
// Error returns a Error error string // Error returns the error message.
func (err *Error) Error() string { func (err Error) Error() string {
return err.Message return err.Message
} }
@ -114,7 +114,7 @@ func (auth *AuthProxy) HasHeader() bool {
} }
// IsAllowedIP compares presented IP with the whitelist one // IsAllowedIP compares presented IP with the whitelist one
func (auth *AuthProxy) IsAllowedIP() *Error { func (auth *AuthProxy) IsAllowedIP() error {
ip := auth.ctx.Req.RemoteAddr ip := auth.ctx.Req.RemoteAddr
if len(strings.TrimSpace(auth.whitelistIP)) == 0 { if len(strings.TrimSpace(auth.whitelistIP)) == 0 {
@ -144,11 +144,10 @@ func (auth *AuthProxy) IsAllowedIP() *Error {
} }
} }
err = fmt.Errorf( return newError("proxy authentication required", fmt.Errorf(
"request for user (%s) from %s is not from the authentication proxy", auth.header, "request for user (%s) from %s is not from the authentication proxy", auth.header,
sourceIP, sourceIP,
) ))
return newError("proxy authentication required", err)
} }
func HashCacheKey(key string) string { func HashCacheKey(key string) string {
@ -173,7 +172,7 @@ func (auth *AuthProxy) getKey() string {
} }
// Login logs in user ID by whatever means possible. // Login logs in user ID by whatever means possible.
func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, *Error) { func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error) {
if !ignoreCache { if !ignoreCache {
// Error here means absent cache - we don't need to handle that // Error here means absent cache - we don't need to handle that
id, err := auth.GetUserViaCache(logger) id, err := auth.GetUserViaCache(logger)
@ -229,15 +228,15 @@ func (auth *AuthProxy) RemoveUserFromCache(logger log.Logger) error {
} }
// LoginViaLDAP logs in user via LDAP request // LoginViaLDAP logs in user via LDAP request
func (auth *AuthProxy) LoginViaLDAP() (int64, *Error) { func (auth *AuthProxy) LoginViaLDAP() (int64, error) {
config, err := getLDAPConfig() config, err := getLDAPConfig()
if err != nil { if err != nil {
return 0, newError("failed to get LDAP config", nil) return 0, newError("failed to get LDAP config", err)
} }
extUser, _, err := newLDAP(config.Servers).User(auth.header) extUser, _, err := newLDAP(config.Servers).User(auth.header)
if err != nil { if err != nil {
return 0, newError(err.Error(), nil) return 0, err
} }
// Have to sync grafana and LDAP user during log in // Have to sync grafana and LDAP user during log in
@ -246,9 +245,8 @@ func (auth *AuthProxy) LoginViaLDAP() (int64, *Error) {
SignupAllowed: auth.LDAPAllowSignup, SignupAllowed: auth.LDAPAllowSignup,
ExternalUser: extUser, ExternalUser: extUser,
} }
err = bus.Dispatch(upsert) if err := bus.Dispatch(upsert); err != nil {
if err != nil { return 0, err
return 0, newError(err.Error(), nil)
} }
return upsert.Result.Id, nil return upsert.Result.Id, nil
@ -273,7 +271,7 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) {
extUser.Email = auth.header extUser.Email = auth.header
extUser.Login = auth.header extUser.Login = auth.header
default: default:
return 0, newError("auth proxy header property invalid", nil) return 0, fmt.Errorf("auth proxy header property invalid")
} }
auth.headersIterator(func(field string, header string) { auth.headersIterator(func(field string, header string) {
@ -314,21 +312,21 @@ func (auth *AuthProxy) headersIterator(fn func(field string, header string)) {
} }
// GetSignedUser gets full signed user info. // GetSignedUser gets full signed user info.
func (auth *AuthProxy) GetSignedUser(userID int64) (*models.SignedInUser, *Error) { func (auth *AuthProxy) GetSignedUser(userID int64) (*models.SignedInUser, error) {
query := &models.GetSignedInUserQuery{ query := &models.GetSignedInUserQuery{
OrgId: auth.orgID, OrgId: auth.orgID,
UserId: userID, UserId: userID,
} }
if err := bus.Dispatch(query); err != nil { if err := bus.Dispatch(query); err != nil {
return nil, newError(err.Error(), nil) return nil, err
} }
return query.Result, nil return query.Result, nil
} }
// Remember user in cache // Remember user in cache
func (auth *AuthProxy) Remember(id int64) *Error { func (auth *AuthProxy) Remember(id int64) error {
key := auth.getKey() key := auth.getKey()
// Check if user already in cache // Check if user already in cache
@ -341,7 +339,7 @@ func (auth *AuthProxy) Remember(id int64) *Error {
err := auth.store.Set(key, id, expiration) err := auth.store.Set(key, id, expiration)
if err != nil { if err != nil {
return newError(err.Error(), nil) return err
} }
return nil return nil

View File

@ -13,7 +13,8 @@ import (
"github.com/grafana/grafana/pkg/services/ldap" "github.com/grafana/grafana/pkg/services/ldap"
"github.com/grafana/grafana/pkg/services/multildap" "github.com/grafana/grafana/pkg/services/multildap"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
. "github.com/smartystreets/goconvey/convey" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/macaron.v1" "gopkg.in/macaron.v1"
) )
@ -68,134 +69,142 @@ func prepareMiddleware(t *testing.T, req *http.Request, store *remotecache.Remot
func TestMiddlewareContext(t *testing.T) { func TestMiddlewareContext(t *testing.T) {
logger := log.New("test") logger := log.New("test")
Convey("auth_proxy helper", t, func() { req, err := http.NewRequest("POST", "http://example.com", nil)
req, err := http.NewRequest("POST", "http://example.com", nil) require.NoError(t, err)
So(err, ShouldBeNil) setting.AuthProxyHeaderName = "X-Killa"
setting.AuthProxyHeaderName = "X-Killa" store := remotecache.NewFakeStore(t)
store := remotecache.NewFakeStore(t)
name := "markelog" name := "markelog"
req.Header.Add(setting.AuthProxyHeaderName, name) req.Header.Add(setting.AuthProxyHeaderName, name)
Convey("when the cache only contains the main header", func() { t.Run("When the cache only contains the main header with a simple cache key", func(t *testing.T) {
Convey("with a simple cache key", func() { const id int64 = 33
// Set cache key // Set cache key
key := fmt.Sprintf(CachePrefix, HashCacheKey(name)) key := fmt.Sprintf(CachePrefix, HashCacheKey(name))
err := store.Set(key, int64(33), 0) err := store.Set(key, id, 0)
So(err, ShouldBeNil) require.NoError(t, err)
// Set up the middleware // Set up the middleware
auth := prepareMiddleware(t, req, store) auth := prepareMiddleware(t, req, store)
So(auth.getKey(), ShouldEqual, "auth-proxy-sync-ttl:0a7f3374e9659b10980fd66247b0cf2f") assert.Equal(t, "auth-proxy-sync-ttl:0a7f3374e9659b10980fd66247b0cf2f", auth.getKey())
id, err := auth.Login(logger, false) gotID, err := auth.Login(logger, false)
So(err, ShouldBeNil) require.NoError(t, err)
So(id, ShouldEqual, 33) assert.Equal(t, id, gotID)
}) })
Convey("when the cache key contains additional headers", func() { t.Run("When the cache key contains additional headers", func(t *testing.T) {
setting.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS"} const id int64 = 33
group := "grafana-core-team" setting.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS"}
req.Header.Add("X-WEBAUTH-GROUPS", group) group := "grafana-core-team"
req.Header.Add("X-WEBAUTH-GROUPS", group)
key := fmt.Sprintf(CachePrefix, HashCacheKey(name+"-"+group)) key := fmt.Sprintf(CachePrefix, HashCacheKey(name+"-"+group))
err := store.Set(key, int64(33), 0) err := store.Set(key, id, 0)
So(err, ShouldBeNil) require.NoError(t, err)
auth := prepareMiddleware(t, req, store) auth := prepareMiddleware(t, req, store)
So(auth.getKey(), ShouldEqual, "auth-proxy-sync-ttl:14f69b7023baa0ac98c96b31cec07bc0") assert.Equal(t, "auth-proxy-sync-ttl:14f69b7023baa0ac98c96b31cec07bc0", auth.getKey())
id, err := auth.Login(logger, false) gotID, err := auth.Login(logger, false)
So(err, ShouldBeNil) require.NoError(t, err)
So(id, ShouldEqual, 33) assert.Equal(t, id, gotID)
}) })
}) }
Convey("LDAP", func() { func TestMiddlewareContext_ldap(t *testing.T) {
Convey("logs in via LDAP", func() { logger := log.New("test")
bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error { req, err := http.NewRequest("POST", "http://example.com", nil)
cmd.Result = &models.User{ require.NoError(t, err)
Id: 42, setting.AuthProxyHeaderName = "X-Killa"
}
const headerName = "markelog"
return nil req.Header.Add(setting.AuthProxyHeaderName, headerName)
})
t.Run("Logs in via LDAP", func(t *testing.T) {
isLDAPEnabled = func() bool { const id int64 = 42
return true
} bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{
stub := &fakeMultiLDAP{ Id: id,
ID: 42, }
}
return nil
getLDAPConfig = func() (*ldap.Config, error) { })
config := &ldap.Config{
Servers: []*ldap.ServerConfig{ isLDAPEnabled = func() bool {
{ return true
SearchBaseDNs: []string{"BaseDNHere"}, }
},
}, stub := &fakeMultiLDAP{
} ID: id,
return config, nil }
}
getLDAPConfig = func() (*ldap.Config, error) {
newLDAP = func(servers []*ldap.ServerConfig) multildap.IMultiLDAP { config := &ldap.Config{
return stub Servers: []*ldap.ServerConfig{
} {
SearchBaseDNs: []string{"BaseDNHere"},
defer func() { },
newLDAP = multildap.New },
isLDAPEnabled = ldap.IsEnabled }
getLDAPConfig = ldap.GetConfig return config, nil
}() }
store := remotecache.NewFakeStore(t) newLDAP = func(servers []*ldap.ServerConfig) multildap.IMultiLDAP {
return stub
auth := prepareMiddleware(t, req, store) }
id, err := auth.Login(logger, false) defer func() {
newLDAP = multildap.New
So(err, ShouldBeNil) isLDAPEnabled = ldap.IsEnabled
So(id, ShouldEqual, 42) getLDAPConfig = ldap.GetConfig
So(stub.userCalled, ShouldEqual, true) }()
})
store := remotecache.NewFakeStore(t)
Convey("gets nice error if ldap is enabled but not configured", func() {
isLDAPEnabled = func() bool { auth := prepareMiddleware(t, req, store)
return true
} gotID, err := auth.Login(logger, false)
require.NoError(t, err)
getLDAPConfig = func() (*ldap.Config, error) {
return nil, errors.New("Something went wrong") assert.Equal(t, id, gotID)
} assert.True(t, stub.userCalled)
})
defer func() {
newLDAP = multildap.New t.Run("Gets nice error if ldap is enabled but not configured", func(t *testing.T) {
isLDAPEnabled = ldap.IsEnabled const id int64 = 42
getLDAPConfig = ldap.GetConfig isLDAPEnabled = func() bool {
}() return true
}
store := remotecache.NewFakeStore(t)
getLDAPConfig = func() (*ldap.Config, error) {
auth := prepareMiddleware(t, req, store) return nil, errors.New("something went wrong")
}
stub := &fakeMultiLDAP{
ID: 42, defer func() {
} newLDAP = multildap.New
isLDAPEnabled = ldap.IsEnabled
newLDAP = func(servers []*ldap.ServerConfig) multildap.IMultiLDAP { getLDAPConfig = ldap.GetConfig
return stub }()
}
store := remotecache.NewFakeStore(t)
id, err := auth.Login(logger, false)
auth := prepareMiddleware(t, req, store)
So(err, ShouldNotBeNil)
So(err.Error(), ShouldContainSubstring, "failed to get the user") stub := &fakeMultiLDAP{
So(id, ShouldNotEqual, 42) ID: id,
So(stub.loginCalled, ShouldEqual, false) }
})
}) newLDAP = func(servers []*ldap.ServerConfig) multildap.IMultiLDAP {
return stub
}
gotID, err := auth.Login(logger, false)
require.EqualError(t, err, "failed to get the user")
assert.NotEqual(t, id, gotID)
assert.False(t, stub.loginCalled)
}) })
} }

View File

@ -107,10 +107,7 @@ func GetConfig() (*Config, error) {
loadingMutex.Lock() loadingMutex.Lock()
defer loadingMutex.Unlock() defer loadingMutex.Unlock()
var err error return readConfig(setting.LDAPConfigFile)
config, err = readConfig(setting.LDAPConfigFile)
return config, err
} }
func readConfig(configFile string) (*Config, error) { func readConfig(configFile string) (*Config, error) {