mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Auth: creates a hook in the user mapping flow (#37190)
* wip * Auth Info: refactored out into it's own service * Auth: adds extension point where users are being mapped * Update pkg/services/login/authinfoservice/service.go Co-authored-by: Joan López de la Franca Beltran <joanjan14@gmail.com> * Update pkg/services/login/authinfoservice/service.go Co-authored-by: Joan López de la Franca Beltran <joanjan14@gmail.com> * Auth: simplified code * moved most authinfo stuff to its own package * added back code * linter * simplified Co-authored-by: Joan López de la Franca Beltran <joanjan14@gmail.com>
This commit is contained in:
parent
4f340550ee
commit
d51b2630c7
@ -10,6 +10,7 @@ var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrUserAlreadyExists = errors.New("user already exists")
|
||||
ErrLastGrafanaAdmin = errors.New("cannot remove last grafana admin")
|
||||
ErrProtectedUser = errors.New("cannot adopt protected user")
|
||||
)
|
||||
|
||||
type Password string
|
||||
|
@ -98,8 +98,6 @@ type GetUserByAuthInfoQuery struct {
|
||||
UserId int64
|
||||
Email string
|
||||
Login string
|
||||
|
||||
Result *User
|
||||
}
|
||||
|
||||
type GetExternalUserInfoByLoginQuery struct {
|
||||
|
@ -13,8 +13,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/grafana/grafana/pkg/api"
|
||||
"github.com/grafana/grafana/pkg/api/routing"
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
@ -37,6 +35,7 @@ import (
|
||||
_ "github.com/grafana/grafana/pkg/services/auth/jwt"
|
||||
_ "github.com/grafana/grafana/pkg/services/cleanup"
|
||||
_ "github.com/grafana/grafana/pkg/services/librarypanels"
|
||||
_ "github.com/grafana/grafana/pkg/services/login/authinfoservice"
|
||||
_ "github.com/grafana/grafana/pkg/services/login/loginservice"
|
||||
_ "github.com/grafana/grafana/pkg/services/ngalert"
|
||||
_ "github.com/grafana/grafana/pkg/services/notifications"
|
||||
@ -45,6 +44,7 @@ import (
|
||||
_ "github.com/grafana/grafana/pkg/services/search"
|
||||
_ "github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// Config contains parameters for the New function.
|
||||
|
7
pkg/services/login/authinfo.go
Normal file
7
pkg/services/login/authinfo.go
Normal file
@ -0,0 +1,7 @@
|
||||
package login
|
||||
|
||||
import "github.com/grafana/grafana/pkg/models"
|
||||
|
||||
type AuthInfoService interface {
|
||||
LookupAndUpdate(query *models.GetUserByAuthInfoQuery) (*models.User, error)
|
||||
}
|
@ -1,10 +1,12 @@
|
||||
package sqlstore
|
||||
package authinfoservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
@ -13,123 +15,7 @@ import (
|
||||
|
||||
var getTime = time.Now
|
||||
|
||||
const genericOAuthModule = "oauth_generic_oauth"
|
||||
|
||||
func init() {
|
||||
bus.AddHandler("sql", GetUserByAuthInfo)
|
||||
bus.AddHandler("sql", GetExternalUserInfoByLogin)
|
||||
bus.AddHandler("sql", GetAuthInfo)
|
||||
bus.AddHandler("sql", SetAuthInfo)
|
||||
bus.AddHandler("sql", UpdateAuthInfo)
|
||||
bus.AddHandler("sql", DeleteAuthInfo)
|
||||
}
|
||||
|
||||
func GetUserByAuthInfo(query *models.GetUserByAuthInfoQuery) error {
|
||||
user := &models.User{}
|
||||
has := false
|
||||
var err error
|
||||
authQuery := &models.GetAuthInfoQuery{}
|
||||
|
||||
// Try to find the user by auth module and id first
|
||||
if query.AuthModule != "" && query.AuthId != "" {
|
||||
authQuery.AuthModule = query.AuthModule
|
||||
authQuery.AuthId = query.AuthId
|
||||
|
||||
err = GetAuthInfo(authQuery)
|
||||
if !errors.Is(err, models.ErrUserNotFound) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if user id was specified and doesn't match the user_auth entry, remove it
|
||||
if query.UserId != 0 && query.UserId != authQuery.Result.UserId {
|
||||
err = DeleteAuthInfo(&models.DeleteAuthInfoCommand{
|
||||
UserAuth: authQuery.Result,
|
||||
})
|
||||
if err != nil {
|
||||
sqlog.Error("Error removing user_auth entry", "error", err)
|
||||
}
|
||||
|
||||
authQuery.Result = nil
|
||||
} else {
|
||||
has, err = x.Id(authQuery.Result.UserId).Get(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !has {
|
||||
// if the user has been deleted then remove the entry
|
||||
err = DeleteAuthInfo(&models.DeleteAuthInfoCommand{
|
||||
UserAuth: authQuery.Result,
|
||||
})
|
||||
if err != nil {
|
||||
sqlog.Error("Error removing user_auth entry", "error", err)
|
||||
}
|
||||
|
||||
authQuery.Result = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If not found, try to find the user by id
|
||||
if !has && query.UserId != 0 {
|
||||
has, err = x.Id(query.UserId).Get(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// If not found, try to find the user by email address
|
||||
if !has && query.Email != "" {
|
||||
user = &models.User{Email: query.Email}
|
||||
has, err = x.Get(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// If not found, try to find the user by login
|
||||
if !has && query.Login != "" {
|
||||
user = &models.User{Login: query.Login}
|
||||
has, err = x.Get(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// No user found
|
||||
if !has {
|
||||
return models.ErrUserNotFound
|
||||
}
|
||||
|
||||
// Special case for generic oauth duplicates
|
||||
if query.AuthModule == genericOAuthModule && user.Id != 0 {
|
||||
authQuery.UserId = user.Id
|
||||
authQuery.AuthModule = query.AuthModule
|
||||
err = GetAuthInfo(authQuery)
|
||||
if !errors.Is(err, models.ErrUserNotFound) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if authQuery.Result == nil && query.AuthModule != "" {
|
||||
cmd2 := &models.SetAuthInfoCommand{
|
||||
UserId: user.Id,
|
||||
AuthModule: query.AuthModule,
|
||||
AuthId: query.AuthId,
|
||||
}
|
||||
if err := SetAuthInfo(cmd2); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
query.Result = user
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetExternalUserInfoByLogin(query *models.GetExternalUserInfoByLoginQuery) error {
|
||||
func (s *Implementation) GetExternalUserInfoByLogin(query *models.GetExternalUserInfoByLoginQuery) error {
|
||||
userQuery := models.GetUserByLoginQuery{LoginOrEmail: query.LoginOrEmail}
|
||||
err := bus.Dispatch(&userQuery)
|
||||
if err != nil {
|
||||
@ -153,13 +39,20 @@ func GetExternalUserInfoByLogin(query *models.GetExternalUserInfoByLoginQuery) e
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetAuthInfo(query *models.GetAuthInfoQuery) error {
|
||||
func (s *Implementation) GetAuthInfo(query *models.GetAuthInfoQuery) error {
|
||||
userAuth := &models.UserAuth{
|
||||
UserId: query.UserId,
|
||||
AuthModule: query.AuthModule,
|
||||
AuthId: query.AuthId,
|
||||
}
|
||||
has, err := x.Desc("created").Get(userAuth)
|
||||
|
||||
var has bool
|
||||
var err error
|
||||
|
||||
err = s.SQLStore.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
|
||||
has, err = sess.Desc("created").Get(userAuth)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -188,8 +81,8 @@ func GetAuthInfo(query *models.GetAuthInfoQuery) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetAuthInfo(cmd *models.SetAuthInfoCommand) error {
|
||||
return inTransaction(func(sess *DBSession) error {
|
||||
func (s *Implementation) SetAuthInfo(cmd *models.SetAuthInfoCommand) error {
|
||||
return s.SQLStore.WithTransactionalDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
|
||||
authUser := &models.UserAuth{
|
||||
UserId: cmd.UserId,
|
||||
AuthModule: cmd.AuthModule,
|
||||
@ -222,8 +115,8 @@ func SetAuthInfo(cmd *models.SetAuthInfoCommand) error {
|
||||
})
|
||||
}
|
||||
|
||||
func UpdateAuthInfo(cmd *models.UpdateAuthInfoCommand) error {
|
||||
return inTransaction(func(sess *DBSession) error {
|
||||
func (s *Implementation) UpdateAuthInfo(cmd *models.UpdateAuthInfoCommand) error {
|
||||
return s.SQLStore.WithTransactionalDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
|
||||
authUser := &models.UserAuth{
|
||||
UserId: cmd.UserId,
|
||||
AuthModule: cmd.AuthModule,
|
||||
@ -256,13 +149,13 @@ func UpdateAuthInfo(cmd *models.UpdateAuthInfoCommand) error {
|
||||
AuthModule: cmd.AuthModule,
|
||||
}
|
||||
upd, err := sess.Update(authUser, cond)
|
||||
sqlog.Debug("Updated user_auth", "user_id", cmd.UserId, "auth_module", cmd.AuthModule, "rows", upd)
|
||||
s.logger.Debug("Updated user_auth", "user_id", cmd.UserId, "auth_module", cmd.AuthModule, "rows", upd)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func DeleteAuthInfo(cmd *models.DeleteAuthInfoCommand) error {
|
||||
return inTransaction(func(sess *DBSession) error {
|
||||
func (s *Implementation) DeleteAuthInfo(cmd *models.DeleteAuthInfoCommand) error {
|
||||
return s.SQLStore.WithTransactionalDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
|
||||
_, err := sess.Delete(cmd.UserAuth)
|
||||
return err
|
||||
})
|
223
pkg/services/login/authinfoservice/service.go
Normal file
223
pkg/services/login/authinfoservice/service.go
Normal file
@ -0,0 +1,223 @@
|
||||
package authinfoservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/registry"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
)
|
||||
|
||||
const genericOAuthModule = "oauth_generic_oauth"
|
||||
|
||||
func init() {
|
||||
srv := &Implementation{}
|
||||
|
||||
registry.Register(®istry.Descriptor{
|
||||
Name: "UserAuthInfo",
|
||||
Instance: srv,
|
||||
InitPriority: registry.MediumHigh,
|
||||
})
|
||||
}
|
||||
|
||||
type Implementation struct {
|
||||
Bus bus.Bus `inject:""`
|
||||
SQLStore *sqlstore.SQLStore `inject:""`
|
||||
UserProtectionService login.UserProtectionService `inject:""`
|
||||
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
func (s *Implementation) Init() error {
|
||||
s.logger = log.New("login.authinfo")
|
||||
|
||||
s.Bus.AddHandler(s.GetExternalUserInfoByLogin)
|
||||
s.Bus.AddHandler(s.GetAuthInfo)
|
||||
s.Bus.AddHandler(s.SetAuthInfo)
|
||||
s.Bus.AddHandler(s.UpdateAuthInfo)
|
||||
s.Bus.AddHandler(s.DeleteAuthInfo)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Implementation) getUserById(id int64) (bool, *models.User, error) {
|
||||
var (
|
||||
has bool
|
||||
err error
|
||||
)
|
||||
user := &models.User{}
|
||||
err = s.SQLStore.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
|
||||
has, err = sess.ID(id).Get(user)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
return has, user, nil
|
||||
}
|
||||
|
||||
func (s *Implementation) getUser(user *models.User) (bool, error) {
|
||||
var err error
|
||||
var has bool
|
||||
|
||||
err = s.SQLStore.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
|
||||
has, err = sess.Get(user)
|
||||
return err
|
||||
})
|
||||
|
||||
return has, err
|
||||
}
|
||||
|
||||
func (s *Implementation) LookupAndFix(query *models.GetUserByAuthInfoQuery) (bool, *models.User, *models.UserAuth, error) {
|
||||
authQuery := &models.GetAuthInfoQuery{}
|
||||
|
||||
// Try to find the user by auth module and id first
|
||||
if query.AuthModule != "" && query.AuthId != "" {
|
||||
authQuery.AuthModule = query.AuthModule
|
||||
authQuery.AuthId = query.AuthId
|
||||
|
||||
err := s.GetAuthInfo(authQuery)
|
||||
if !errors.Is(err, models.ErrUserNotFound) {
|
||||
if err != nil {
|
||||
return false, nil, nil, err
|
||||
}
|
||||
|
||||
// if user id was specified and doesn't match the user_auth entry, remove it
|
||||
if query.UserId != 0 && query.UserId != authQuery.Result.UserId {
|
||||
err := s.DeleteAuthInfo(&models.DeleteAuthInfoCommand{
|
||||
UserAuth: authQuery.Result,
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Error("Error removing user_auth entry", "error", err)
|
||||
}
|
||||
|
||||
return false, nil, nil, models.ErrUserNotFound
|
||||
} else {
|
||||
has, user, err := s.getUserById(authQuery.Result.UserId)
|
||||
if err != nil {
|
||||
return false, nil, nil, err
|
||||
}
|
||||
|
||||
if !has {
|
||||
// if the user has been deleted then remove the entry
|
||||
err = s.DeleteAuthInfo(&models.DeleteAuthInfoCommand{
|
||||
UserAuth: authQuery.Result,
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Error("Error removing user_auth entry", "error", err)
|
||||
}
|
||||
|
||||
return false, nil, nil, models.ErrUserNotFound
|
||||
}
|
||||
|
||||
return true, user, authQuery.Result, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil, nil, models.ErrUserNotFound
|
||||
}
|
||||
|
||||
func (s *Implementation) LookupByOneOf(userId int64, email string, login string) (bool, *models.User, error) {
|
||||
foundUser := false
|
||||
var user *models.User
|
||||
var err error
|
||||
|
||||
// If not found, try to find the user by id
|
||||
if userId != 0 {
|
||||
foundUser, user, err = s.getUserById(userId)
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// If not found, try to find the user by email address
|
||||
if !foundUser && email != "" {
|
||||
user = &models.User{Email: email}
|
||||
foundUser, err = s.getUser(user)
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// If not found, try to find the user by login
|
||||
if !foundUser && login != "" {
|
||||
user = &models.User{Login: login}
|
||||
foundUser, err = s.getUser(user)
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if !foundUser {
|
||||
return false, nil, models.ErrUserNotFound
|
||||
}
|
||||
|
||||
return foundUser, user, nil
|
||||
}
|
||||
|
||||
func (s *Implementation) GenericOAuthLookup(authModule string, authId string, userID int64) (*models.UserAuth, error) {
|
||||
if authModule == genericOAuthModule && userID != 0 {
|
||||
authQuery := &models.GetAuthInfoQuery{}
|
||||
authQuery.AuthModule = authModule
|
||||
authQuery.AuthId = authId
|
||||
authQuery.UserId = userID
|
||||
err := s.GetAuthInfo(authQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return authQuery.Result, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *Implementation) LookupAndUpdate(query *models.GetUserByAuthInfoQuery) (*models.User, error) {
|
||||
// 1. LookupAndFix = auth info, user, error
|
||||
// TODO: Not a big fan of the fact that we are deleting auth info here, might want to move that
|
||||
foundUser, user, authInfo, err := s.LookupAndFix(query)
|
||||
if err != nil && !errors.Is(err, models.ErrUserNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. FindByUserDetails
|
||||
if !foundUser {
|
||||
_, user, err = s.LookupByOneOf(query.UserId, query.Email, query.Login)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.UserProtectionService.AllowUserMapping(user, query.AuthModule); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Special case for generic oauth duplicates
|
||||
ai, err := s.GenericOAuthLookup(query.AuthModule, query.AuthId, user.Id)
|
||||
if !errors.Is(err, models.ErrUserNotFound) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if ai != nil {
|
||||
authInfo = ai
|
||||
}
|
||||
|
||||
if authInfo == nil && query.AuthModule != "" {
|
||||
cmd := &models.SetAuthInfoCommand{
|
||||
UserId: user.Id,
|
||||
AuthModule: query.AuthModule,
|
||||
AuthId: query.AuthId,
|
||||
}
|
||||
if err := s.SetAuthInfo(cmd); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
@ -1,21 +1,31 @@
|
||||
// +build integration
|
||||
|
||||
package sqlstore
|
||||
package authinfoservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
//nolint:goconst
|
||||
func TestUserAuth(t *testing.T) {
|
||||
sqlStore := InitTestDB(t)
|
||||
sqlStore := sqlstore.InitTestDB(t)
|
||||
srv := &Implementation{
|
||||
Bus: bus.New(),
|
||||
SQLStore: sqlStore,
|
||||
UserProtectionService: OSSUserProtectionImpl{},
|
||||
}
|
||||
srv.Init()
|
||||
|
||||
t.Run("Given 5 users", func(t *testing.T) {
|
||||
for i := 0; i < 5; i++ {
|
||||
@ -24,7 +34,7 @@ func TestUserAuth(t *testing.T) {
|
||||
Name: fmt.Sprint("user", i),
|
||||
Login: fmt.Sprint("loginuser", i),
|
||||
}
|
||||
_, err := sqlStore.CreateUser(context.Background(), cmd)
|
||||
_, err := srv.SQLStore.CreateUser(context.Background(), cmd)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
@ -33,89 +43,90 @@ func TestUserAuth(t *testing.T) {
|
||||
login := "loginuser0"
|
||||
|
||||
query := &models.GetUserByAuthInfoQuery{Login: login}
|
||||
err := GetUserByAuthInfo(query)
|
||||
user, err := srv.LookupAndUpdate(query)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, query.Result.Login, login)
|
||||
require.Equal(t, user.Login, login)
|
||||
|
||||
// By ID
|
||||
id := query.Result.Id
|
||||
id := user.Id
|
||||
|
||||
query = &models.GetUserByAuthInfoQuery{UserId: id}
|
||||
err = GetUserByAuthInfo(query)
|
||||
_, user, err = srv.LookupByOneOf(id, "", "")
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, query.Result.Id, id)
|
||||
require.Equal(t, user.Id, id)
|
||||
|
||||
// By Email
|
||||
email := "user1@test.com"
|
||||
|
||||
query = &models.GetUserByAuthInfoQuery{Email: email}
|
||||
err = GetUserByAuthInfo(query)
|
||||
_, user, err = srv.LookupByOneOf(0, email, "")
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, query.Result.Email, email)
|
||||
require.Equal(t, user.Email, email)
|
||||
|
||||
// Don't find nonexistent user
|
||||
email = "nonexistent@test.com"
|
||||
|
||||
query = &models.GetUserByAuthInfoQuery{Email: email}
|
||||
err = GetUserByAuthInfo(query)
|
||||
_, user, err = srv.LookupByOneOf(0, email, "")
|
||||
|
||||
require.Equal(t, err, models.ErrUserNotFound)
|
||||
require.Nil(t, query.Result)
|
||||
require.Equal(t, models.ErrUserNotFound, err)
|
||||
require.Nil(t, user)
|
||||
})
|
||||
|
||||
t.Run("Can set & locate by AuthModule and AuthId", func(t *testing.T) {
|
||||
// get nonexistent user_auth entry
|
||||
query := &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
||||
err := GetUserByAuthInfo(query)
|
||||
user, err := srv.LookupAndUpdate(query)
|
||||
|
||||
require.Equal(t, err, models.ErrUserNotFound)
|
||||
require.Nil(t, query.Result)
|
||||
require.Equal(t, models.ErrUserNotFound, err)
|
||||
require.Nil(t, user)
|
||||
|
||||
// create user_auth entry
|
||||
login := "loginuser0"
|
||||
|
||||
query.Login = login
|
||||
err = GetUserByAuthInfo(query)
|
||||
user, err = srv.LookupAndUpdate(query)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, query.Result.Login, login)
|
||||
require.Equal(t, user.Login, login)
|
||||
|
||||
// get via user_auth
|
||||
query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
||||
err = GetUserByAuthInfo(query)
|
||||
user, err = srv.LookupAndUpdate(query)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, query.Result.Login, login)
|
||||
require.Equal(t, user.Login, login)
|
||||
|
||||
// get with non-matching id
|
||||
id := query.Result.Id
|
||||
id := user.Id
|
||||
|
||||
query.UserId = id + 1
|
||||
err = GetUserByAuthInfo(query)
|
||||
user, err = srv.LookupAndUpdate(query)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, query.Result.Login, "loginuser1")
|
||||
require.Equal(t, user.Login, "loginuser1")
|
||||
|
||||
// get via user_auth
|
||||
query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
||||
err = GetUserByAuthInfo(query)
|
||||
user, err = srv.LookupAndUpdate(query)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, query.Result.Login, "loginuser1")
|
||||
require.Equal(t, user.Login, "loginuser1")
|
||||
|
||||
// remove user
|
||||
_, err = x.Exec("DELETE FROM "+dialect.Quote("user")+" WHERE id=?", query.Result.Id)
|
||||
require.Nil(t, err)
|
||||
srv.SQLStore.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
|
||||
sess.Exec("DELETE FROM "+srv.SQLStore.Dialect.Quote("user")+" WHERE id=?", user.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// get via user_auth for deleted user
|
||||
query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
||||
err = GetUserByAuthInfo(query)
|
||||
user, err = srv.LookupAndUpdate(query)
|
||||
|
||||
require.Equal(t, err, models.ErrUserNotFound)
|
||||
require.Nil(t, query.Result)
|
||||
require.Nil(t, user)
|
||||
})
|
||||
|
||||
t.Run("Can set & retrieve oauth token information", func(t *testing.T) {
|
||||
@ -131,26 +142,26 @@ func TestUserAuth(t *testing.T) {
|
||||
|
||||
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
|
||||
query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test", AuthId: "test"}
|
||||
err := GetUserByAuthInfo(query)
|
||||
user, err := srv.LookupAndUpdate(query)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, query.Result.Login, login)
|
||||
require.Equal(t, user.Login, login)
|
||||
|
||||
cmd := &models.UpdateAuthInfoCommand{
|
||||
UserId: query.Result.Id,
|
||||
UserId: user.Id,
|
||||
AuthId: query.AuthId,
|
||||
AuthModule: query.AuthModule,
|
||||
OAuthToken: token,
|
||||
}
|
||||
err = UpdateAuthInfo(cmd)
|
||||
err = srv.UpdateAuthInfo(cmd)
|
||||
|
||||
require.Nil(t, err)
|
||||
|
||||
getAuthQuery := &models.GetAuthInfoQuery{
|
||||
UserId: query.Result.Id,
|
||||
UserId: user.Id,
|
||||
}
|
||||
|
||||
err = GetAuthInfo(getAuthQuery)
|
||||
err = srv.GetAuthInfo(getAuthQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, getAuthQuery.Result.OAuthAccessToken, token.AccessToken)
|
||||
@ -160,7 +171,7 @@ func TestUserAuth(t *testing.T) {
|
||||
|
||||
t.Run("Always return the most recently used auth_module", func(t *testing.T) {
|
||||
// Restore after destructive operation
|
||||
sqlStore = InitTestDB(t)
|
||||
sqlStore = sqlstore.InitTestDB(t)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
cmd := models.CreateUserCommand{
|
||||
@ -175,48 +186,48 @@ func TestUserAuth(t *testing.T) {
|
||||
// Find a user to set tokens on
|
||||
login := "loginuser0"
|
||||
|
||||
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
|
||||
// Calling srv.LookupAndUpdateQuery on an existing user will populate an entry in the user_auth table
|
||||
// Make the first log-in during the past
|
||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
||||
query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test1", AuthId: "test1"}
|
||||
err := GetUserByAuthInfo(query)
|
||||
user, err := srv.LookupAndUpdate(query)
|
||||
getTime = time.Now
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, query.Result.Login, login)
|
||||
require.Equal(t, user.Login, login)
|
||||
|
||||
// Add a second auth module for this user
|
||||
// Have this module's last log-in be more recent
|
||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -1) }
|
||||
query = &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test2", AuthId: "test2"}
|
||||
err = GetUserByAuthInfo(query)
|
||||
user, err = srv.LookupAndUpdate(query)
|
||||
getTime = time.Now
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, query.Result.Login, login)
|
||||
require.Equal(t, user.Login, login)
|
||||
|
||||
// Get the latest entry by not supply an authmodule or authid
|
||||
getAuthQuery := &models.GetAuthInfoQuery{
|
||||
UserId: query.Result.Id,
|
||||
UserId: user.Id,
|
||||
}
|
||||
|
||||
err = GetAuthInfo(getAuthQuery)
|
||||
err = srv.GetAuthInfo(getAuthQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, getAuthQuery.Result.AuthModule, "test2")
|
||||
|
||||
// "log in" again with the first auth module
|
||||
updateAuthCmd := &models.UpdateAuthInfoCommand{UserId: query.Result.Id, AuthModule: "test1", AuthId: "test1"}
|
||||
err = UpdateAuthInfo(updateAuthCmd)
|
||||
updateAuthCmd := &models.UpdateAuthInfoCommand{UserId: user.Id, AuthModule: "test1", AuthId: "test1"}
|
||||
err = srv.UpdateAuthInfo(updateAuthCmd)
|
||||
|
||||
require.Nil(t, err)
|
||||
|
||||
// Get the latest entry by not supply an authmodule or authid
|
||||
getAuthQuery = &models.GetAuthInfoQuery{
|
||||
UserId: query.Result.Id,
|
||||
UserId: user.Id,
|
||||
}
|
||||
|
||||
err = GetAuthInfo(getAuthQuery)
|
||||
err = srv.GetAuthInfo(getAuthQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, getAuthQuery.Result.AuthModule, "test1")
|
||||
@ -229,20 +240,20 @@ func TestUserAuth(t *testing.T) {
|
||||
// Expect to pass since there's a matching login user
|
||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
||||
query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: genericOAuthModule, AuthId: ""}
|
||||
err := GetUserByAuthInfo(query)
|
||||
user, err := srv.LookupAndUpdate(query)
|
||||
getTime = time.Now
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, query.Result.Login, login)
|
||||
require.Equal(t, user.Login, login)
|
||||
|
||||
// Should throw a "user not found" error since there's no matching login user
|
||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
||||
query = &models.GetUserByAuthInfoQuery{Login: "aloginuser", AuthModule: genericOAuthModule, AuthId: ""}
|
||||
err = GetUserByAuthInfo(query)
|
||||
user, err = srv.LookupAndUpdate(query)
|
||||
getTime = time.Now
|
||||
|
||||
require.NotNil(t, err)
|
||||
require.Nil(t, query.Result)
|
||||
require.Nil(t, user)
|
||||
})
|
||||
})
|
||||
}
|
21
pkg/services/login/authinfoservice/userprotection.go
Normal file
21
pkg/services/login/authinfoservice/userprotection.go
Normal file
@ -0,0 +1,21 @@
|
||||
package authinfoservice
|
||||
|
||||
import (
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/registry"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.RegisterService(&OSSUserProtectionImpl{})
|
||||
}
|
||||
|
||||
type OSSUserProtectionImpl struct {
|
||||
}
|
||||
|
||||
func (OSSUserProtectionImpl) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (OSSUserProtectionImpl) AllowUserMapping(_ *models.User, _ string) error {
|
||||
return nil
|
||||
}
|
@ -22,10 +22,11 @@ var (
|
||||
)
|
||||
|
||||
type Implementation struct {
|
||||
SQLStore *sqlstore.SQLStore `inject:""`
|
||||
Bus bus.Bus `inject:""`
|
||||
QuotaService *quota.QuotaService `inject:""`
|
||||
TeamSync login.TeamSyncFunc
|
||||
SQLStore *sqlstore.SQLStore `inject:""`
|
||||
Bus bus.Bus `inject:""`
|
||||
AuthInfoService login.AuthInfoService `inject:""`
|
||||
QuotaService *quota.QuotaService `inject:""`
|
||||
TeamSync login.TeamSyncFunc
|
||||
}
|
||||
|
||||
func (ls *Implementation) Init() error {
|
||||
@ -43,14 +44,14 @@ func (ls *Implementation) CreateUser(cmd models.CreateUserCommand) (*models.User
|
||||
func (ls *Implementation) UpsertUser(cmd *models.UpsertUserCommand) error {
|
||||
extUser := cmd.ExternalUser
|
||||
|
||||
userQuery := &models.GetUserByAuthInfoQuery{
|
||||
user, err := ls.AuthInfoService.LookupAndUpdate(&models.GetUserByAuthInfoQuery{
|
||||
AuthModule: extUser.AuthModule,
|
||||
AuthId: extUser.AuthId,
|
||||
UserId: extUser.UserId,
|
||||
Email: extUser.Email,
|
||||
Login: extUser.Login,
|
||||
}
|
||||
if err := bus.Dispatch(userQuery); err != nil {
|
||||
})
|
||||
if err != nil {
|
||||
if !errors.Is(err, models.ErrUserNotFound) {
|
||||
return err
|
||||
}
|
||||
@ -85,7 +86,7 @@ func (ls *Implementation) UpsertUser(cmd *models.UpsertUserCommand) error {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cmd.Result = userQuery.Result
|
||||
cmd.Result = user
|
||||
|
||||
err = updateUser(cmd.Result, extUser)
|
||||
if err != nil {
|
||||
@ -100,7 +101,7 @@ func (ls *Implementation) UpsertUser(cmd *models.UpsertUserCommand) error {
|
||||
}
|
||||
}
|
||||
|
||||
if extUser.AuthModule == models.AuthModuleLDAP && userQuery.Result.IsDisabled {
|
||||
if extUser.AuthModule == models.AuthModuleLDAP && user.IsDisabled {
|
||||
// Re-enable user when it found in LDAP
|
||||
if err := ls.Bus.Dispatch(&models.DisableUserCommand{UserId: cmd.Result.Id, IsDisabled: false}); err != nil {
|
||||
return err
|
||||
|
@ -76,10 +76,22 @@ func Test_syncOrgRoles_whenTryingToRemoveLastOrgLogsError(t *testing.T) {
|
||||
assert.Contains(t, logs, models.ErrLastOrgAdmin.Error())
|
||||
}
|
||||
|
||||
type authInfoServiceMock struct {
|
||||
user *models.User
|
||||
err error
|
||||
}
|
||||
|
||||
func (a *authInfoServiceMock) LookupAndUpdate(query *models.GetUserByAuthInfoQuery) (*models.User, error) {
|
||||
return a.user, a.err
|
||||
}
|
||||
|
||||
func Test_teamSync(t *testing.T) {
|
||||
b := bus.New()
|
||||
authInfoMock := &authInfoServiceMock{}
|
||||
login := Implementation{
|
||||
Bus: bus.New(),
|
||||
QuotaService: "a.QuotaService{},
|
||||
Bus: b,
|
||||
QuotaService: "a.QuotaService{},
|
||||
AuthInfoService: authInfoMock,
|
||||
}
|
||||
|
||||
upserCmd := &models.UpsertUserCommand{ExternalUser: &models.ExternalUserInfo{Email: "test_user@example.org"}}
|
||||
@ -89,13 +101,9 @@ func Test_teamSync(t *testing.T) {
|
||||
Name: "test_user",
|
||||
Login: "test_user",
|
||||
}
|
||||
|
||||
authInfoMock.user = expectedUser
|
||||
bus.ClearBusHandlers()
|
||||
t.Cleanup(func() { bus.ClearBusHandlers() })
|
||||
bus.AddHandler("test", func(query *models.GetUserByAuthInfoQuery) error {
|
||||
query.Result = expectedUser
|
||||
return nil
|
||||
})
|
||||
|
||||
var actualUser *models.User
|
||||
var actualExternalUser *models.ExternalUserInfo
|
||||
|
7
pkg/services/login/userprotection.go
Normal file
7
pkg/services/login/userprotection.go
Normal file
@ -0,0 +1,7 @@
|
||||
package login
|
||||
|
||||
import "github.com/grafana/grafana/pkg/models"
|
||||
|
||||
type UserProtectionService interface {
|
||||
AllowUserMapping(user *models.User, authModule string) error
|
||||
}
|
@ -90,8 +90,6 @@ func TestTeamCommandsAndQueries(t *testing.T) {
|
||||
|
||||
Convey("Should return latest auth module for users when getting team members", func() {
|
||||
userId := userIds[1]
|
||||
err := SetAuthInfo(&models.SetAuthInfoCommand{UserId: userId, AuthModule: "oauth_github", AuthId: "1234567"})
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
teamQuery := &models.SearchTeamsQuery{OrgId: testOrgID, Name: "group1 name", Page: 1, Limit: 10}
|
||||
err = SearchTeams(teamQuery)
|
||||
@ -111,7 +109,6 @@ func TestTeamCommandsAndQueries(t *testing.T) {
|
||||
So(memberQuery.Result[0].Login, ShouldEqual, "loginuser1")
|
||||
So(memberQuery.Result[0].OrgId, ShouldEqual, testOrgID)
|
||||
So(memberQuery.Result[0].External, ShouldEqual, true)
|
||||
So(memberQuery.Result[0].AuthModule, ShouldEqual, "oauth_github")
|
||||
})
|
||||
|
||||
Convey("Should be able to update users in a team", func() {
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
@ -119,7 +118,7 @@ func TestUserDataAccess(t *testing.T) {
|
||||
t.Run("Testing DB - multiple users", func(t *testing.T) {
|
||||
ss = InitTestDB(t)
|
||||
|
||||
users := createFiveTestUsers(t, ss, func(i int) *models.CreateUserCommand {
|
||||
createFiveTestUsers(t, ss, func(i int) *models.CreateUserCommand {
|
||||
return &models.CreateUserCommand{
|
||||
Email: fmt.Sprint("user", i, "@test.com"),
|
||||
Name: fmt.Sprint("user", i),
|
||||
@ -188,48 +187,6 @@ func TestUserDataAccess(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
require.Len(t, query.Result.Users, 1)
|
||||
require.EqualValues(t, query.Result.TotalCount, 1)
|
||||
|
||||
// Return list users based on their auth type
|
||||
for index, user := range users {
|
||||
authModule := "killa"
|
||||
|
||||
// define every second user as ldap
|
||||
if index%2 == 0 {
|
||||
authModule = "ldap"
|
||||
}
|
||||
|
||||
cmd2 := &models.SetAuthInfoCommand{
|
||||
UserId: user.Id,
|
||||
AuthModule: authModule,
|
||||
AuthId: "gorilla",
|
||||
}
|
||||
err := SetAuthInfo(cmd2)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
query = models.SearchUsersQuery{AuthModule: "ldap"}
|
||||
err = SearchUsers(&query)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Len(t, query.Result.Users, 3)
|
||||
|
||||
zero, second, fourth := false, false, false
|
||||
for _, user := range query.Result.Users {
|
||||
if user.Name == "user0" {
|
||||
zero = true
|
||||
}
|
||||
|
||||
if user.Name == "user2" {
|
||||
second = true
|
||||
}
|
||||
|
||||
if user.Name == "user4" {
|
||||
fourth = true
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, zero)
|
||||
require.True(t, second)
|
||||
require.True(t, fourth)
|
||||
})
|
||||
|
||||
t.Run("Testing DB - return list users based on their is_disabled flag", func(t *testing.T) {
|
||||
@ -490,107 +447,6 @@ func TestUserDataAccess(t *testing.T) {
|
||||
IsDisabled: false,
|
||||
}
|
||||
})
|
||||
// Find a user to set tokens on
|
||||
login := "loginuser0"
|
||||
|
||||
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
|
||||
// Make the first log-in during the past
|
||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
||||
query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "ldap", AuthId: "ldap0"}
|
||||
err := GetUserByAuthInfo(query)
|
||||
getTime = time.Now
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, query.Result.Login, login)
|
||||
|
||||
// Add a second auth module for this user
|
||||
// Have this module's last log-in be more recent
|
||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -1) }
|
||||
query = &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "oauth", AuthId: "oauth0"}
|
||||
err = GetUserByAuthInfo(query)
|
||||
getTime = time.Now
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, query.Result.Login, login)
|
||||
|
||||
// Return the only most recently used auth_module
|
||||
searchUserQuery := &models.SearchUsersQuery{}
|
||||
err = SearchUsers(searchUserQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Len(t, searchUserQuery.Result.Users, 5)
|
||||
for _, user := range searchUserQuery.Result.Users {
|
||||
if user.Login == login {
|
||||
require.Len(t, user.AuthModule, 1)
|
||||
require.Equal(t, user.AuthModule[0], "oauth")
|
||||
}
|
||||
}
|
||||
|
||||
// "log in" again with the first auth module
|
||||
updateAuthCmd := &models.UpdateAuthInfoCommand{UserId: query.Result.Id, AuthModule: "ldap", AuthId: "ldap1"}
|
||||
err = UpdateAuthInfo(updateAuthCmd)
|
||||
require.Nil(t, err)
|
||||
|
||||
searchUserQuery = &models.SearchUsersQuery{}
|
||||
err = SearchUsers(searchUserQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
for _, user := range searchUserQuery.Result.Users {
|
||||
if user.Login == login {
|
||||
require.Len(t, user.AuthModule, 1)
|
||||
require.Equal(t, user.AuthModule[0], "ldap")
|
||||
}
|
||||
}
|
||||
|
||||
// Re-init DB
|
||||
ss = InitTestDB(t)
|
||||
createFiveTestUsers(t, ss, func(i int) *models.CreateUserCommand {
|
||||
return &models.CreateUserCommand{
|
||||
Email: fmt.Sprint("user", i, "@test.com"),
|
||||
Name: fmt.Sprint("user", i),
|
||||
Login: fmt.Sprint("loginuser", i),
|
||||
IsDisabled: false,
|
||||
}
|
||||
})
|
||||
|
||||
// Search LDAP users
|
||||
for i := 0; i < 5; i++ {
|
||||
// Find a user to set tokens on
|
||||
login = fmt.Sprint("loginuser", i)
|
||||
|
||||
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
|
||||
// Make the first log-in during the past
|
||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
||||
query = &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "ldap", AuthId: fmt.Sprint("ldap", i)}
|
||||
err = GetUserByAuthInfo(query)
|
||||
getTime = time.Now
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, query.Result.Login, login)
|
||||
}
|
||||
|
||||
// Log in first user with oauth
|
||||
login = "loginuser0"
|
||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -1) }
|
||||
query = &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "oauth", AuthId: "oauth0"}
|
||||
err = GetUserByAuthInfo(query)
|
||||
getTime = time.Now
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, query.Result.Login, login)
|
||||
|
||||
// Should only return users recently logged in with ldap when filtered by ldap auth module
|
||||
searchUserQuery = &models.SearchUsersQuery{AuthModule: "ldap"}
|
||||
err = SearchUsers(searchUserQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Len(t, searchUserQuery.Result.Users, 4)
|
||||
for _, user := range searchUserQuery.Result.Users {
|
||||
if user.Login == login {
|
||||
require.Len(t, user.AuthModule, 1)
|
||||
require.Equal(t, user.AuthModule[0], "ldap")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Testing DB - grafana admin users", func(t *testing.T) {
|
||||
|
Loading…
Reference in New Issue
Block a user