mirror of
https://github.com/grafana/grafana.git
synced 2025-01-27 00:37:04 -06:00
Chore: Remove GetUserByEmail and GetUserByLogin from sqlstore (#55903)
* Chore: Remove GetUserByEmail and GetUserByLogin from sqlstore Rename GetUserProfile to GetProfile * Fix lint * Skip test for mysql * Add missing method to sqlstore mock
This commit is contained in:
parent
a8f43b97a2
commit
a45ef61d25
@ -716,7 +716,7 @@ func TestOrgUsersAPIEndpointWithSetPerms_AccessControl(t *testing.T) {
|
||||
sc := setupHTTPServer(t, true, func(hs *HTTPServer) {
|
||||
hs.tempUserService = tempuserimpl.ProvideService(hs.SQLStore)
|
||||
hs.userService = userimpl.ProvideService(
|
||||
hs.SQLStore, nil, nil, hs.SQLStore.(*sqlstore.SQLStore),
|
||||
hs.SQLStore, nil, setting.NewCfg(), hs.SQLStore.(*sqlstore.SQLStore),
|
||||
)
|
||||
})
|
||||
setInitCtxSignedInViewer(sc.initCtx)
|
||||
|
@ -51,7 +51,7 @@ func (hs *HTTPServer) GetUserByID(c *models.ReqContext) response.Response {
|
||||
func (hs *HTTPServer) getUserUserProfile(c *models.ReqContext, userID int64) response.Response {
|
||||
query := user.GetUserProfileQuery{UserID: userID}
|
||||
|
||||
userProfile, err := hs.userService.GetUserProfile(c.Req.Context(), &query)
|
||||
userProfile, err := hs.userService.GetProfile(c.Req.Context(), &query)
|
||||
if err != nil {
|
||||
if errors.Is(err, user.ErrUserNotFound) {
|
||||
return response.Error(404, user.ErrUserNotFound.Error(), nil)
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/session"
|
||||
"xorm.io/core"
|
||||
)
|
||||
|
||||
type DB interface {
|
||||
@ -13,6 +14,7 @@ type DB interface {
|
||||
WithDbSession(ctx context.Context, callback sqlstore.DBTransactionFunc) error
|
||||
NewSession(ctx context.Context) *sqlstore.DBSession
|
||||
GetDialect() migrator.Dialect
|
||||
GetDBType() core.DbType
|
||||
GetSqlxSession() *session.SessionDB
|
||||
InTransaction(ctx context.Context, fn func(ctx context.Context) error) error
|
||||
}
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/session"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"xorm.io/core"
|
||||
)
|
||||
|
||||
type OrgListResponse []struct {
|
||||
@ -76,6 +77,10 @@ func (m *SQLStoreMock) GetDialect() migrator.Dialect {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *SQLStoreMock) GetDBType() core.DbType {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *SQLStoreMock) HasEditPermissionInFolders(ctx context.Context, query *models.HasEditPermissionInFoldersQuery) error {
|
||||
return m.ExpectedError
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"xorm.io/core"
|
||||
"xorm.io/xorm"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
@ -171,6 +172,10 @@ func (ss *SQLStore) GetDialect() migrator.Dialect {
|
||||
return ss.Dialect
|
||||
}
|
||||
|
||||
func (ss *SQLStore) GetDBType() core.DbType {
|
||||
return ss.engine.Dialect().DBType()
|
||||
}
|
||||
|
||||
func (ss *SQLStore) Bus() bus.Bus {
|
||||
return ss.bus
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/session"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"xorm.io/core"
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
@ -15,6 +16,7 @@ type Store interface {
|
||||
GetDataSourceStats(ctx context.Context, query *models.GetDataSourceStatsQuery) error
|
||||
GetDataSourceAccessStats(ctx context.Context, query *models.GetDataSourceAccessStatsQuery) error
|
||||
GetDialect() migrator.Dialect
|
||||
GetDBType() core.DbType
|
||||
GetSystemStats(ctx context.Context, query *models.GetSystemStatsQuery) error
|
||||
GetOrgByName(name string) (*models.Org, error)
|
||||
CreateOrg(ctx context.Context, cmd *models.CreateOrgCommand) error
|
||||
|
@ -201,87 +201,6 @@ func (ss *SQLStore) GetUserById(ctx context.Context, query *models.GetUserByIdQu
|
||||
})
|
||||
}
|
||||
|
||||
func (ss *SQLStore) GetUserByLogin(ctx context.Context, query *models.GetUserByLoginQuery) error {
|
||||
return ss.WithDbSession(ctx, func(sess *DBSession) error {
|
||||
if query.LoginOrEmail == "" {
|
||||
return user.ErrUserNotFound
|
||||
}
|
||||
|
||||
// Try and find the user by login first.
|
||||
// It's not sufficient to assume that a LoginOrEmail with an "@" is an email.
|
||||
usr := &user.User{}
|
||||
where := "login=?"
|
||||
if ss.Cfg.CaseInsensitiveLogin {
|
||||
where = "LOWER(login)=LOWER(?)"
|
||||
}
|
||||
|
||||
has, err := sess.Where(notServiceAccountFilter(ss)).Where(where, query.LoginOrEmail).Get(usr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !has && strings.Contains(query.LoginOrEmail, "@") {
|
||||
// If the user wasn't found, and it contains an "@" fallback to finding the
|
||||
// user by email.
|
||||
|
||||
where = "email=?"
|
||||
if ss.Cfg.CaseInsensitiveLogin {
|
||||
where = "LOWER(email)=LOWER(?)"
|
||||
}
|
||||
usr = &user.User{}
|
||||
has, err = sess.Where(notServiceAccountFilter(ss)).Where(where, query.LoginOrEmail).Get(usr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !has {
|
||||
return user.ErrUserNotFound
|
||||
}
|
||||
|
||||
if ss.Cfg.CaseInsensitiveLogin {
|
||||
if err := ss.userCaseInsensitiveLoginConflict(ctx, sess, usr.Login, usr.Email); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
query.Result = usr
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (ss *SQLStore) GetUserByEmail(ctx context.Context, query *models.GetUserByEmailQuery) error {
|
||||
return ss.WithDbSession(ctx, func(sess *DBSession) error {
|
||||
if query.Email == "" {
|
||||
return user.ErrUserNotFound
|
||||
}
|
||||
|
||||
usr := &user.User{}
|
||||
where := "email=?"
|
||||
if ss.Cfg.CaseInsensitiveLogin {
|
||||
where = "LOWER(email)=LOWER(?)"
|
||||
}
|
||||
|
||||
has, err := sess.Where(notServiceAccountFilter(ss)).Where(where, query.Email).Get(usr)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !has {
|
||||
return user.ErrUserNotFound
|
||||
}
|
||||
|
||||
if ss.Cfg.CaseInsensitiveLogin {
|
||||
if err := ss.userCaseInsensitiveLoginConflict(ctx, sess, usr.Login, usr.Email); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
query.Result = usr
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (ss *SQLStore) UpdateUser(ctx context.Context, cmd *models.UpdateUserCommand) error {
|
||||
if ss.Cfg.CaseInsensitiveLogin {
|
||||
cmd.Login = strings.ToLower(cmd.Login)
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -90,82 +89,6 @@ func TestIntegrationUserDataAccess(t *testing.T) {
|
||||
Permissions: map[int64]map[string][]string{1: {"users:read": {"global.users:*"}}},
|
||||
}
|
||||
|
||||
t.Run("Testing DB - creates and loads user", func(t *testing.T) {
|
||||
cmd := user.CreateUserCommand{
|
||||
Email: "usertest@test.com",
|
||||
Name: "user name",
|
||||
Login: "user_test_login",
|
||||
}
|
||||
user, err := ss.CreateUser(context.Background(), cmd)
|
||||
require.NoError(t, err)
|
||||
|
||||
query := models.GetUserByIdQuery{Id: user.ID}
|
||||
err = ss.GetUserById(context.Background(), &query)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, query.Result.Email, "usertest@test.com")
|
||||
require.Equal(t, query.Result.Password, "")
|
||||
require.Len(t, query.Result.Rands, 10)
|
||||
require.Len(t, query.Result.Salt, 10)
|
||||
require.False(t, query.Result.IsDisabled)
|
||||
|
||||
query = models.GetUserByIdQuery{Id: user.ID}
|
||||
err = ss.GetUserById(context.Background(), &query)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, query.Result.Email, "usertest@test.com")
|
||||
require.Equal(t, query.Result.Password, "")
|
||||
require.Len(t, query.Result.Rands, 10)
|
||||
require.Len(t, query.Result.Salt, 10)
|
||||
require.False(t, query.Result.IsDisabled)
|
||||
|
||||
t.Run("Get User by email case insensitive", func(t *testing.T) {
|
||||
ss.Cfg.CaseInsensitiveLogin = true
|
||||
query := models.GetUserByEmailQuery{Email: "USERtest@TEST.COM"}
|
||||
err = ss.GetUserByEmail(context.Background(), &query)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, query.Result.Email, "usertest@test.com")
|
||||
require.Equal(t, query.Result.Password, "")
|
||||
require.Len(t, query.Result.Rands, 10)
|
||||
require.Len(t, query.Result.Salt, 10)
|
||||
require.False(t, query.Result.IsDisabled)
|
||||
|
||||
ss.Cfg.CaseInsensitiveLogin = false
|
||||
})
|
||||
|
||||
t.Run("Get User by login - case insensitive", func(t *testing.T) {
|
||||
ss.Cfg.CaseInsensitiveLogin = true
|
||||
|
||||
query := models.GetUserByLoginQuery{LoginOrEmail: "USER_test_login"}
|
||||
err = ss.GetUserByLogin(context.Background(), &query)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, query.Result.Email, "usertest@test.com")
|
||||
require.Equal(t, query.Result.Password, "")
|
||||
require.Len(t, query.Result.Rands, 10)
|
||||
require.Len(t, query.Result.Salt, 10)
|
||||
require.False(t, query.Result.IsDisabled)
|
||||
|
||||
ss.Cfg.CaseInsensitiveLogin = false
|
||||
})
|
||||
|
||||
t.Run("Get User by login - email fallback case insensitive", func(t *testing.T) {
|
||||
ss.Cfg.CaseInsensitiveLogin = true
|
||||
query := models.GetUserByLoginQuery{LoginOrEmail: "USERtest@TEST.COM"}
|
||||
err = ss.GetUserByLogin(context.Background(), &query)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, query.Result.Email, "usertest@test.com")
|
||||
require.Equal(t, query.Result.Password, "")
|
||||
require.Len(t, query.Result.Rands, 10)
|
||||
require.Len(t, query.Result.Salt, 10)
|
||||
require.False(t, query.Result.IsDisabled)
|
||||
|
||||
ss.Cfg.CaseInsensitiveLogin = false
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Testing DB - creates and loads disabled user", func(t *testing.T) {
|
||||
ss = InitTestDB(t)
|
||||
cmd := user.CreateUserCommand{
|
||||
@ -475,90 +398,6 @@ func TestIntegrationUserDataAccess(t *testing.T) {
|
||||
assert.Len(t, query.Result.Users, 2)
|
||||
})
|
||||
|
||||
t.Run("Testing DB - error on case insensitive conflict", func(t *testing.T) {
|
||||
if ss.engine.Dialect().DBType() == migrator.MySQL {
|
||||
t.Skip("Skipping on MySQL due to case insensitive indexes")
|
||||
}
|
||||
|
||||
cmd := user.CreateUserCommand{
|
||||
Email: "confusertest@test.com",
|
||||
Name: "user name",
|
||||
Login: "user_email_conflict",
|
||||
}
|
||||
userEmailConflict, err := ss.CreateUser(context.Background(), cmd)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd = user.CreateUserCommand{
|
||||
Email: "confusertest@TEST.COM",
|
||||
Name: "user name",
|
||||
Login: "user_email_conflict_two",
|
||||
}
|
||||
_, err = ss.CreateUser(context.Background(), cmd)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd = user.CreateUserCommand{
|
||||
Email: "user_test_login_conflict@test.com",
|
||||
Name: "user name",
|
||||
Login: "user_test_login_conflict",
|
||||
}
|
||||
userLoginConflict, err := ss.CreateUser(context.Background(), cmd)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd = user.CreateUserCommand{
|
||||
Email: "user_test_login_conflict_two@test.com",
|
||||
Name: "user name",
|
||||
Login: "user_test_login_CONFLICT",
|
||||
}
|
||||
_, err = ss.CreateUser(context.Background(), cmd)
|
||||
require.NoError(t, err)
|
||||
|
||||
ss.Cfg.CaseInsensitiveLogin = true
|
||||
|
||||
t.Run("GetUserByEmail - email conflict", func(t *testing.T) {
|
||||
query := models.GetUserByEmailQuery{Email: "confusertest@test.com"}
|
||||
err = ss.GetUserByEmail(context.Background(), &query)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("GetUserByEmail - login conflict", func(t *testing.T) {
|
||||
query := models.GetUserByEmailQuery{Email: "user_test_login_conflict@test.com"}
|
||||
err = ss.GetUserByEmail(context.Background(), &query)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("GetUserByID - email conflict", func(t *testing.T) {
|
||||
query := models.GetUserByIdQuery{Id: userEmailConflict.ID}
|
||||
err = ss.GetUserById(context.Background(), &query)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("GetUserByID - login conflict", func(t *testing.T) {
|
||||
query := models.GetUserByIdQuery{Id: userLoginConflict.ID}
|
||||
err = ss.GetUserById(context.Background(), &query)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("GetUserByLogin - email conflict", func(t *testing.T) {
|
||||
query := models.GetUserByLoginQuery{LoginOrEmail: "user_email_conflict_two"}
|
||||
err = ss.GetUserByLogin(context.Background(), &query)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("GetUserByLogin - login conflict", func(t *testing.T) {
|
||||
query := models.GetUserByLoginQuery{LoginOrEmail: "user_test_login_conflict"}
|
||||
err = ss.GetUserByLogin(context.Background(), &query)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("GetUserByLogin - login conflict by email", func(t *testing.T) {
|
||||
query := models.GetUserByLoginQuery{LoginOrEmail: "user_test_login_conflict@test.com"}
|
||||
err = ss.GetUserByLogin(context.Background(), &query)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
ss.Cfg.CaseInsensitiveLogin = false
|
||||
})
|
||||
|
||||
ss = InitTestDB(t)
|
||||
|
||||
t.Run("Testing DB - enable all users", func(t *testing.T) {
|
||||
|
@ -21,5 +21,5 @@ type Service interface {
|
||||
BatchDisableUsers(context.Context, *BatchDisableUsersCommand) error
|
||||
UpdatePermissions(int64, bool) error
|
||||
SetUserHelpFlag(context.Context, *SetUserHelpFlagCommand) error
|
||||
GetUserProfile(context.Context, *GetUserProfileQuery) (UserProfileDTO, error)
|
||||
GetProfile(context.Context, *GetUserProfileQuery) (UserProfileDTO, error)
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package userimpl
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/grafana/grafana/pkg/events"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
@ -20,6 +21,8 @@ type store interface {
|
||||
GetNotServiceAccount(context.Context, int64) (*user.User, error)
|
||||
Delete(context.Context, int64) error
|
||||
CaseInsensitiveLoginConflict(context.Context, string, string) error
|
||||
GetByLogin(context.Context, *user.GetUserByLoginQuery) (*user.User, error)
|
||||
GetByEmail(context.Context, *user.GetUserByEmailQuery) (*user.User, error)
|
||||
}
|
||||
|
||||
type sqlStore struct {
|
||||
@ -145,3 +148,101 @@ func (ss *sqlStore) CaseInsensitiveLoginConflict(ctx context.Context, login, ema
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (ss *sqlStore) GetByLogin(ctx context.Context, query *user.GetUserByLoginQuery) (*user.User, error) {
|
||||
usr := &user.User{}
|
||||
err := ss.db.WithDbSession(ctx, func(sess *sqlstore.DBSession) error {
|
||||
if query.LoginOrEmail == "" {
|
||||
return user.ErrUserNotFound
|
||||
}
|
||||
|
||||
// Try and find the user by login first.
|
||||
// It's not sufficient to assume that a LoginOrEmail with an "@" is an email.
|
||||
where := "login=?"
|
||||
if ss.cfg.CaseInsensitiveLogin {
|
||||
where = "LOWER(login)=LOWER(?)"
|
||||
}
|
||||
|
||||
has, err := sess.Where(ss.notServiceAccountFilter()).Where(where, query.LoginOrEmail).Get(usr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !has && strings.Contains(query.LoginOrEmail, "@") {
|
||||
// If the user wasn't found, and it contains an "@" fallback to finding the
|
||||
// user by email.
|
||||
|
||||
where = "email=?"
|
||||
if ss.cfg.CaseInsensitiveLogin {
|
||||
where = "LOWER(email)=LOWER(?)"
|
||||
}
|
||||
usr = &user.User{}
|
||||
has, err = sess.Where(ss.notServiceAccountFilter()).Where(where, query.LoginOrEmail).Get(usr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !has {
|
||||
return user.ErrUserNotFound
|
||||
}
|
||||
|
||||
if ss.cfg.CaseInsensitiveLogin {
|
||||
if err := ss.userCaseInsensitiveLoginConflict(ctx, sess, usr.Login, usr.Email); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return usr, nil
|
||||
}
|
||||
|
||||
func (ss *sqlStore) GetByEmail(ctx context.Context, query *user.GetUserByEmailQuery) (*user.User, error) {
|
||||
usr := &user.User{}
|
||||
err := ss.db.WithDbSession(ctx, func(sess *sqlstore.DBSession) error {
|
||||
if query.Email == "" {
|
||||
return user.ErrUserNotFound
|
||||
}
|
||||
|
||||
where := "email=?"
|
||||
if ss.cfg.CaseInsensitiveLogin {
|
||||
where = "LOWER(email)=LOWER(?)"
|
||||
}
|
||||
|
||||
has, err := sess.Where(ss.notServiceAccountFilter()).Where(where, query.Email).Get(usr)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !has {
|
||||
return user.ErrUserNotFound
|
||||
}
|
||||
|
||||
if ss.cfg.CaseInsensitiveLogin {
|
||||
if err := ss.userCaseInsensitiveLoginConflict(ctx, sess, usr.Login, usr.Email); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return usr, nil
|
||||
}
|
||||
|
||||
func (ss *sqlStore) userCaseInsensitiveLoginConflict(ctx context.Context, sess *sqlstore.DBSession, login, email string) error {
|
||||
users := make([]user.User, 0)
|
||||
|
||||
if err := sess.Where("LOWER(email)=LOWER(?) OR LOWER(login)=LOWER(?)",
|
||||
email, login).Find(&users); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(users) > 1 {
|
||||
return &user.ErrCaseInsensitiveLoginConflict{Users: users}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
@ -54,4 +55,143 @@ func TestIntegrationUserDataAccess(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Testing DB - creates and loads user", func(t *testing.T) {
|
||||
ss := sqlstore.InitTestDB(t)
|
||||
cmd := user.CreateUserCommand{
|
||||
Email: "usertest@test.com",
|
||||
Name: "user name",
|
||||
Login: "user_test_login",
|
||||
}
|
||||
usr, err := ss.CreateUser(context.Background(), cmd)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := userStore.GetByID(context.Background(), usr.ID)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, result.Email, "usertest@test.com")
|
||||
require.Equal(t, result.Password, "")
|
||||
require.Len(t, result.Rands, 10)
|
||||
require.Len(t, result.Salt, 10)
|
||||
require.False(t, result.IsDisabled)
|
||||
|
||||
result, err = userStore.GetByID(context.Background(), usr.ID)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, result.Email, "usertest@test.com")
|
||||
require.Equal(t, result.Password, "")
|
||||
require.Len(t, result.Rands, 10)
|
||||
require.Len(t, result.Salt, 10)
|
||||
require.False(t, result.IsDisabled)
|
||||
|
||||
t.Run("Get User by email case insensitive", func(t *testing.T) {
|
||||
userStore.cfg.CaseInsensitiveLogin = true
|
||||
query := user.GetUserByEmailQuery{Email: "USERtest@TEST.COM"}
|
||||
result, err := userStore.GetByEmail(context.Background(), &query)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, result.Email, "usertest@test.com")
|
||||
require.Equal(t, result.Password, "")
|
||||
require.Len(t, result.Rands, 10)
|
||||
require.Len(t, result.Salt, 10)
|
||||
require.False(t, result.IsDisabled)
|
||||
|
||||
userStore.cfg.CaseInsensitiveLogin = false
|
||||
})
|
||||
|
||||
t.Run("Testing DB - creates and loads user", func(t *testing.T) {
|
||||
result, err = userStore.GetByID(context.Background(), usr.ID)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, result.Email, "usertest@test.com")
|
||||
require.Equal(t, result.Password, "")
|
||||
require.Len(t, result.Rands, 10)
|
||||
require.Len(t, result.Salt, 10)
|
||||
require.False(t, result.IsDisabled)
|
||||
|
||||
result, err = userStore.GetByID(context.Background(), usr.ID)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, result.Email, "usertest@test.com")
|
||||
require.Equal(t, result.Password, "")
|
||||
require.Len(t, result.Rands, 10)
|
||||
require.Len(t, result.Salt, 10)
|
||||
require.False(t, result.IsDisabled)
|
||||
ss.Cfg.CaseInsensitiveLogin = false
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Testing DB - error on case insensitive conflict", func(t *testing.T) {
|
||||
if ss.GetDBType() == migrator.MySQL {
|
||||
t.Skip("Skipping on MySQL due to case insensitive indexes")
|
||||
}
|
||||
userStore.cfg.CaseInsensitiveLogin = true
|
||||
cmd := user.CreateUserCommand{
|
||||
Email: "confusertest@test.com",
|
||||
Name: "user name",
|
||||
Login: "user_email_conflict",
|
||||
}
|
||||
// userEmailConflict
|
||||
_, err := ss.CreateUser(context.Background(), cmd)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd = user.CreateUserCommand{
|
||||
Email: "confusertest@TEST.COM",
|
||||
Name: "user name",
|
||||
Login: "user_email_conflict_two",
|
||||
}
|
||||
_, err = ss.CreateUser(context.Background(), cmd)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd = user.CreateUserCommand{
|
||||
Email: "user_test_login_conflict@test.com",
|
||||
Name: "user name",
|
||||
Login: "user_test_login_conflict",
|
||||
}
|
||||
// userLoginConflict
|
||||
_, err = ss.CreateUser(context.Background(), cmd)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd = user.CreateUserCommand{
|
||||
Email: "user_test_login_conflict_two@test.com",
|
||||
Name: "user name",
|
||||
Login: "user_test_login_CONFLICT",
|
||||
}
|
||||
_, err = ss.CreateUser(context.Background(), cmd)
|
||||
require.NoError(t, err)
|
||||
|
||||
ss.Cfg.CaseInsensitiveLogin = true
|
||||
|
||||
t.Run("GetByEmail - email conflict", func(t *testing.T) {
|
||||
query := user.GetUserByEmailQuery{Email: "confusertest@test.com"}
|
||||
_, err = userStore.GetByEmail(context.Background(), &query)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("GetByEmail - login conflict", func(t *testing.T) {
|
||||
query := user.GetUserByEmailQuery{Email: "user_test_login_conflict@test.com"}
|
||||
_, err = userStore.GetByEmail(context.Background(), &query)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("GetByLogin - email conflict", func(t *testing.T) {
|
||||
query := user.GetUserByLoginQuery{LoginOrEmail: "user_email_conflict_two"}
|
||||
_, err = userStore.GetByLogin(context.Background(), &query)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("GetByLogin - login conflict", func(t *testing.T) {
|
||||
query := user.GetUserByLoginQuery{LoginOrEmail: "user_test_login_conflict"}
|
||||
_, err = userStore.GetByLogin(context.Background(), &query)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("GetByLogin - login conflict by email", func(t *testing.T) {
|
||||
query := user.GetUserByLoginQuery{LoginOrEmail: "user_test_login_conflict@test.com"}
|
||||
_, err = userStore.GetByLogin(context.Background(), &query)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
ss.Cfg.CaseInsensitiveLogin = false
|
||||
})
|
||||
}
|
||||
|
@ -152,24 +152,12 @@ func (s *Service) GetByID(ctx context.Context, query *user.GetUserByIDQuery) (*u
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// TODO: remove wrapper around sqlstore
|
||||
func (s *Service) GetByLogin(ctx context.Context, query *user.GetUserByLoginQuery) (*user.User, error) {
|
||||
q := models.GetUserByLoginQuery{LoginOrEmail: query.LoginOrEmail}
|
||||
err := s.sqlStore.GetUserByLogin(ctx, &q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.Result, nil
|
||||
return s.store.GetByLogin(ctx, query)
|
||||
}
|
||||
|
||||
// TODO: remove wrapper around sqlstore
|
||||
func (s *Service) GetByEmail(ctx context.Context, query *user.GetUserByEmailQuery) (*user.User, error) {
|
||||
q := models.GetUserByEmailQuery{Email: query.Email}
|
||||
err := s.sqlStore.GetUserByEmail(ctx, &q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.Result, nil
|
||||
return s.store.GetByEmail(ctx, query)
|
||||
}
|
||||
|
||||
// TODO: remove wrapper around sqlstore
|
||||
@ -313,7 +301,7 @@ func (s *Service) SetUserHelpFlag(ctx context.Context, cmd *user.SetUserHelpFlag
|
||||
}
|
||||
|
||||
// TODO: remove wrapper around sqlstore
|
||||
func (s *Service) GetUserProfile(ctx context.Context, query *user.GetUserProfileQuery) (user.UserProfileDTO, error) {
|
||||
func (s *Service) GetProfile(ctx context.Context, query *user.GetUserProfileQuery) (user.UserProfileDTO, error) {
|
||||
q := &models.GetUserProfileQuery{
|
||||
UserId: query.UserID,
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package userimpl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/org/orgtest"
|
||||
@ -77,6 +78,14 @@ func TestUserService(t *testing.T) {
|
||||
err := userService.Delete(context.Background(), &user.DeleteUserCommand{UserID: 1})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("GetByID - email conflict", func(t *testing.T) {
|
||||
userService.cfg.CaseInsensitiveLogin = true
|
||||
userStore.ExpectedError = errors.New("email conflict")
|
||||
query := user.GetUserByIDQuery{}
|
||||
_, err := userService.GetByID(context.Background(), &query)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
type FakeUserStore struct {
|
||||
@ -112,3 +121,11 @@ func (f *FakeUserStore) GetByID(context.Context, int64) (*user.User, error) {
|
||||
func (f *FakeUserStore) CaseInsensitiveLoginConflict(context.Context, string, string) error {
|
||||
return f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeUserStore) GetByLogin(ctx context.Context, query *user.GetUserByLoginQuery) (*user.User, error) {
|
||||
return f.ExpectedUser, f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeUserStore) GetByEmail(ctx context.Context, query *user.GetUserByEmailQuery) (*user.User, error) {
|
||||
return f.ExpectedUser, f.ExpectedError
|
||||
}
|
||||
|
@ -91,6 +91,6 @@ func (f *FakeUserService) SetUserHelpFlag(ctx context.Context, cmd *user.SetUser
|
||||
return f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeUserService) GetUserProfile(ctx context.Context, query *user.GetUserProfileQuery) (user.UserProfileDTO, error) {
|
||||
func (f *FakeUserService) GetProfile(ctx context.Context, query *user.GetUserProfileQuery) (user.UserProfileDTO, error) {
|
||||
return f.ExpectedUSerProfileDTO, f.ExpectedError
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user