Chore: Remove GetUserByEmail and GetUserByLogin from sqlstore (#55903)

* Chore: Remove GetUserByEmail and GetUserByLogin from sqlstore
 Rename GetUserProfile to GetProfile

* Fix lint

* Skip test for mysql

* Add missing method to sqlstore mock
This commit is contained in:
idafurjes 2022-09-28 13:18:19 +02:00 committed by GitHub
parent a8f43b97a2
commit a45ef61d25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 279 additions and 261 deletions

View File

@ -716,7 +716,7 @@ func TestOrgUsersAPIEndpointWithSetPerms_AccessControl(t *testing.T) {
sc := setupHTTPServer(t, true, func(hs *HTTPServer) {
hs.tempUserService = tempuserimpl.ProvideService(hs.SQLStore)
hs.userService = userimpl.ProvideService(
hs.SQLStore, nil, nil, hs.SQLStore.(*sqlstore.SQLStore),
hs.SQLStore, nil, setting.NewCfg(), hs.SQLStore.(*sqlstore.SQLStore),
)
})
setInitCtxSignedInViewer(sc.initCtx)

View File

@ -51,7 +51,7 @@ func (hs *HTTPServer) GetUserByID(c *models.ReqContext) response.Response {
func (hs *HTTPServer) getUserUserProfile(c *models.ReqContext, userID int64) response.Response {
query := user.GetUserProfileQuery{UserID: userID}
userProfile, err := hs.userService.GetUserProfile(c.Req.Context(), &query)
userProfile, err := hs.userService.GetProfile(c.Req.Context(), &query)
if err != nil {
if errors.Is(err, user.ErrUserNotFound) {
return response.Error(404, user.ErrUserNotFound.Error(), nil)

View File

@ -6,6 +6,7 @@ import (
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
"github.com/grafana/grafana/pkg/services/sqlstore/session"
"xorm.io/core"
)
type DB interface {
@ -13,6 +14,7 @@ type DB interface {
WithDbSession(ctx context.Context, callback sqlstore.DBTransactionFunc) error
NewSession(ctx context.Context) *sqlstore.DBSession
GetDialect() migrator.Dialect
GetDBType() core.DbType
GetSqlxSession() *session.SessionDB
InTransaction(ctx context.Context, fn func(ctx context.Context) error) error
}

View File

@ -9,6 +9,7 @@ import (
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
"github.com/grafana/grafana/pkg/services/sqlstore/session"
"github.com/grafana/grafana/pkg/services/user"
"xorm.io/core"
)
type OrgListResponse []struct {
@ -76,6 +77,10 @@ func (m *SQLStoreMock) GetDialect() migrator.Dialect {
return nil
}
func (m *SQLStoreMock) GetDBType() core.DbType {
return ""
}
func (m *SQLStoreMock) HasEditPermissionInFolders(ctx context.Context, query *models.HasEditPermissionInFoldersQuery) error {
return m.ExpectedError
}

View File

@ -16,6 +16,7 @@ import (
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"github.com/prometheus/client_golang/prometheus"
"xorm.io/core"
"xorm.io/xorm"
"github.com/grafana/grafana/pkg/bus"
@ -171,6 +172,10 @@ func (ss *SQLStore) GetDialect() migrator.Dialect {
return ss.Dialect
}
func (ss *SQLStore) GetDBType() core.DbType {
return ss.engine.Dialect().DBType()
}
func (ss *SQLStore) Bus() bus.Bus {
return ss.bus
}

View File

@ -7,6 +7,7 @@ import (
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
"github.com/grafana/grafana/pkg/services/sqlstore/session"
"github.com/grafana/grafana/pkg/services/user"
"xorm.io/core"
)
type Store interface {
@ -15,6 +16,7 @@ type Store interface {
GetDataSourceStats(ctx context.Context, query *models.GetDataSourceStatsQuery) error
GetDataSourceAccessStats(ctx context.Context, query *models.GetDataSourceAccessStatsQuery) error
GetDialect() migrator.Dialect
GetDBType() core.DbType
GetSystemStats(ctx context.Context, query *models.GetSystemStatsQuery) error
GetOrgByName(name string) (*models.Org, error)
CreateOrg(ctx context.Context, cmd *models.CreateOrgCommand) error

View File

@ -201,87 +201,6 @@ func (ss *SQLStore) GetUserById(ctx context.Context, query *models.GetUserByIdQu
})
}
func (ss *SQLStore) GetUserByLogin(ctx context.Context, query *models.GetUserByLoginQuery) error {
return ss.WithDbSession(ctx, func(sess *DBSession) error {
if query.LoginOrEmail == "" {
return user.ErrUserNotFound
}
// Try and find the user by login first.
// It's not sufficient to assume that a LoginOrEmail with an "@" is an email.
usr := &user.User{}
where := "login=?"
if ss.Cfg.CaseInsensitiveLogin {
where = "LOWER(login)=LOWER(?)"
}
has, err := sess.Where(notServiceAccountFilter(ss)).Where(where, query.LoginOrEmail).Get(usr)
if err != nil {
return err
}
if !has && strings.Contains(query.LoginOrEmail, "@") {
// If the user wasn't found, and it contains an "@" fallback to finding the
// user by email.
where = "email=?"
if ss.Cfg.CaseInsensitiveLogin {
where = "LOWER(email)=LOWER(?)"
}
usr = &user.User{}
has, err = sess.Where(notServiceAccountFilter(ss)).Where(where, query.LoginOrEmail).Get(usr)
}
if err != nil {
return err
} else if !has {
return user.ErrUserNotFound
}
if ss.Cfg.CaseInsensitiveLogin {
if err := ss.userCaseInsensitiveLoginConflict(ctx, sess, usr.Login, usr.Email); err != nil {
return err
}
}
query.Result = usr
return nil
})
}
func (ss *SQLStore) GetUserByEmail(ctx context.Context, query *models.GetUserByEmailQuery) error {
return ss.WithDbSession(ctx, func(sess *DBSession) error {
if query.Email == "" {
return user.ErrUserNotFound
}
usr := &user.User{}
where := "email=?"
if ss.Cfg.CaseInsensitiveLogin {
where = "LOWER(email)=LOWER(?)"
}
has, err := sess.Where(notServiceAccountFilter(ss)).Where(where, query.Email).Get(usr)
if err != nil {
return err
} else if !has {
return user.ErrUserNotFound
}
if ss.Cfg.CaseInsensitiveLogin {
if err := ss.userCaseInsensitiveLoginConflict(ctx, sess, usr.Login, usr.Email); err != nil {
return err
}
}
query.Result = usr
return nil
})
}
func (ss *SQLStore) UpdateUser(ctx context.Context, cmd *models.UpdateUserCommand) error {
if ss.Cfg.CaseInsensitiveLogin {
cmd.Login = strings.ToLower(cmd.Login)

View File

@ -7,7 +7,6 @@ import (
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
"github.com/grafana/grafana/pkg/services/user"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -90,82 +89,6 @@ func TestIntegrationUserDataAccess(t *testing.T) {
Permissions: map[int64]map[string][]string{1: {"users:read": {"global.users:*"}}},
}
t.Run("Testing DB - creates and loads user", func(t *testing.T) {
cmd := user.CreateUserCommand{
Email: "usertest@test.com",
Name: "user name",
Login: "user_test_login",
}
user, err := ss.CreateUser(context.Background(), cmd)
require.NoError(t, err)
query := models.GetUserByIdQuery{Id: user.ID}
err = ss.GetUserById(context.Background(), &query)
require.Nil(t, err)
require.Equal(t, query.Result.Email, "usertest@test.com")
require.Equal(t, query.Result.Password, "")
require.Len(t, query.Result.Rands, 10)
require.Len(t, query.Result.Salt, 10)
require.False(t, query.Result.IsDisabled)
query = models.GetUserByIdQuery{Id: user.ID}
err = ss.GetUserById(context.Background(), &query)
require.Nil(t, err)
require.Equal(t, query.Result.Email, "usertest@test.com")
require.Equal(t, query.Result.Password, "")
require.Len(t, query.Result.Rands, 10)
require.Len(t, query.Result.Salt, 10)
require.False(t, query.Result.IsDisabled)
t.Run("Get User by email case insensitive", func(t *testing.T) {
ss.Cfg.CaseInsensitiveLogin = true
query := models.GetUserByEmailQuery{Email: "USERtest@TEST.COM"}
err = ss.GetUserByEmail(context.Background(), &query)
require.Nil(t, err)
require.Equal(t, query.Result.Email, "usertest@test.com")
require.Equal(t, query.Result.Password, "")
require.Len(t, query.Result.Rands, 10)
require.Len(t, query.Result.Salt, 10)
require.False(t, query.Result.IsDisabled)
ss.Cfg.CaseInsensitiveLogin = false
})
t.Run("Get User by login - case insensitive", func(t *testing.T) {
ss.Cfg.CaseInsensitiveLogin = true
query := models.GetUserByLoginQuery{LoginOrEmail: "USER_test_login"}
err = ss.GetUserByLogin(context.Background(), &query)
require.Nil(t, err)
require.Equal(t, query.Result.Email, "usertest@test.com")
require.Equal(t, query.Result.Password, "")
require.Len(t, query.Result.Rands, 10)
require.Len(t, query.Result.Salt, 10)
require.False(t, query.Result.IsDisabled)
ss.Cfg.CaseInsensitiveLogin = false
})
t.Run("Get User by login - email fallback case insensitive", func(t *testing.T) {
ss.Cfg.CaseInsensitiveLogin = true
query := models.GetUserByLoginQuery{LoginOrEmail: "USERtest@TEST.COM"}
err = ss.GetUserByLogin(context.Background(), &query)
require.Nil(t, err)
require.Equal(t, query.Result.Email, "usertest@test.com")
require.Equal(t, query.Result.Password, "")
require.Len(t, query.Result.Rands, 10)
require.Len(t, query.Result.Salt, 10)
require.False(t, query.Result.IsDisabled)
ss.Cfg.CaseInsensitiveLogin = false
})
})
t.Run("Testing DB - creates and loads disabled user", func(t *testing.T) {
ss = InitTestDB(t)
cmd := user.CreateUserCommand{
@ -475,90 +398,6 @@ func TestIntegrationUserDataAccess(t *testing.T) {
assert.Len(t, query.Result.Users, 2)
})
t.Run("Testing DB - error on case insensitive conflict", func(t *testing.T) {
if ss.engine.Dialect().DBType() == migrator.MySQL {
t.Skip("Skipping on MySQL due to case insensitive indexes")
}
cmd := user.CreateUserCommand{
Email: "confusertest@test.com",
Name: "user name",
Login: "user_email_conflict",
}
userEmailConflict, err := ss.CreateUser(context.Background(), cmd)
require.NoError(t, err)
cmd = user.CreateUserCommand{
Email: "confusertest@TEST.COM",
Name: "user name",
Login: "user_email_conflict_two",
}
_, err = ss.CreateUser(context.Background(), cmd)
require.NoError(t, err)
cmd = user.CreateUserCommand{
Email: "user_test_login_conflict@test.com",
Name: "user name",
Login: "user_test_login_conflict",
}
userLoginConflict, err := ss.CreateUser(context.Background(), cmd)
require.NoError(t, err)
cmd = user.CreateUserCommand{
Email: "user_test_login_conflict_two@test.com",
Name: "user name",
Login: "user_test_login_CONFLICT",
}
_, err = ss.CreateUser(context.Background(), cmd)
require.NoError(t, err)
ss.Cfg.CaseInsensitiveLogin = true
t.Run("GetUserByEmail - email conflict", func(t *testing.T) {
query := models.GetUserByEmailQuery{Email: "confusertest@test.com"}
err = ss.GetUserByEmail(context.Background(), &query)
require.Error(t, err)
})
t.Run("GetUserByEmail - login conflict", func(t *testing.T) {
query := models.GetUserByEmailQuery{Email: "user_test_login_conflict@test.com"}
err = ss.GetUserByEmail(context.Background(), &query)
require.Error(t, err)
})
t.Run("GetUserByID - email conflict", func(t *testing.T) {
query := models.GetUserByIdQuery{Id: userEmailConflict.ID}
err = ss.GetUserById(context.Background(), &query)
require.Error(t, err)
})
t.Run("GetUserByID - login conflict", func(t *testing.T) {
query := models.GetUserByIdQuery{Id: userLoginConflict.ID}
err = ss.GetUserById(context.Background(), &query)
require.Error(t, err)
})
t.Run("GetUserByLogin - email conflict", func(t *testing.T) {
query := models.GetUserByLoginQuery{LoginOrEmail: "user_email_conflict_two"}
err = ss.GetUserByLogin(context.Background(), &query)
require.Error(t, err)
})
t.Run("GetUserByLogin - login conflict", func(t *testing.T) {
query := models.GetUserByLoginQuery{LoginOrEmail: "user_test_login_conflict"}
err = ss.GetUserByLogin(context.Background(), &query)
require.Error(t, err)
})
t.Run("GetUserByLogin - login conflict by email", func(t *testing.T) {
query := models.GetUserByLoginQuery{LoginOrEmail: "user_test_login_conflict@test.com"}
err = ss.GetUserByLogin(context.Background(), &query)
require.Error(t, err)
})
ss.Cfg.CaseInsensitiveLogin = false
})
ss = InitTestDB(t)
t.Run("Testing DB - enable all users", func(t *testing.T) {

View File

@ -21,5 +21,5 @@ type Service interface {
BatchDisableUsers(context.Context, *BatchDisableUsersCommand) error
UpdatePermissions(int64, bool) error
SetUserHelpFlag(context.Context, *SetUserHelpFlagCommand) error
GetUserProfile(context.Context, *GetUserProfileQuery) (UserProfileDTO, error)
GetProfile(context.Context, *GetUserProfileQuery) (UserProfileDTO, error)
}

View File

@ -3,6 +3,7 @@ package userimpl
import (
"context"
"fmt"
"strings"
"github.com/grafana/grafana/pkg/events"
"github.com/grafana/grafana/pkg/infra/log"
@ -20,6 +21,8 @@ type store interface {
GetNotServiceAccount(context.Context, int64) (*user.User, error)
Delete(context.Context, int64) error
CaseInsensitiveLoginConflict(context.Context, string, string) error
GetByLogin(context.Context, *user.GetUserByLoginQuery) (*user.User, error)
GetByEmail(context.Context, *user.GetUserByEmailQuery) (*user.User, error)
}
type sqlStore struct {
@ -145,3 +148,101 @@ func (ss *sqlStore) CaseInsensitiveLoginConflict(ctx context.Context, login, ema
})
return err
}
func (ss *sqlStore) GetByLogin(ctx context.Context, query *user.GetUserByLoginQuery) (*user.User, error) {
usr := &user.User{}
err := ss.db.WithDbSession(ctx, func(sess *sqlstore.DBSession) error {
if query.LoginOrEmail == "" {
return user.ErrUserNotFound
}
// Try and find the user by login first.
// It's not sufficient to assume that a LoginOrEmail with an "@" is an email.
where := "login=?"
if ss.cfg.CaseInsensitiveLogin {
where = "LOWER(login)=LOWER(?)"
}
has, err := sess.Where(ss.notServiceAccountFilter()).Where(where, query.LoginOrEmail).Get(usr)
if err != nil {
return err
}
if !has && strings.Contains(query.LoginOrEmail, "@") {
// If the user wasn't found, and it contains an "@" fallback to finding the
// user by email.
where = "email=?"
if ss.cfg.CaseInsensitiveLogin {
where = "LOWER(email)=LOWER(?)"
}
usr = &user.User{}
has, err = sess.Where(ss.notServiceAccountFilter()).Where(where, query.LoginOrEmail).Get(usr)
}
if err != nil {
return err
} else if !has {
return user.ErrUserNotFound
}
if ss.cfg.CaseInsensitiveLogin {
if err := ss.userCaseInsensitiveLoginConflict(ctx, sess, usr.Login, usr.Email); err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
return usr, nil
}
func (ss *sqlStore) GetByEmail(ctx context.Context, query *user.GetUserByEmailQuery) (*user.User, error) {
usr := &user.User{}
err := ss.db.WithDbSession(ctx, func(sess *sqlstore.DBSession) error {
if query.Email == "" {
return user.ErrUserNotFound
}
where := "email=?"
if ss.cfg.CaseInsensitiveLogin {
where = "LOWER(email)=LOWER(?)"
}
has, err := sess.Where(ss.notServiceAccountFilter()).Where(where, query.Email).Get(usr)
if err != nil {
return err
} else if !has {
return user.ErrUserNotFound
}
if ss.cfg.CaseInsensitiveLogin {
if err := ss.userCaseInsensitiveLoginConflict(ctx, sess, usr.Login, usr.Email); err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
return usr, nil
}
func (ss *sqlStore) userCaseInsensitiveLoginConflict(ctx context.Context, sess *sqlstore.DBSession, login, email string) error {
users := make([]user.User, 0)
if err := sess.Where("LOWER(email)=LOWER(?) OR LOWER(login)=LOWER(?)",
email, login).Find(&users); err != nil {
return err
}
if len(users) > 1 {
return &user.ErrCaseInsensitiveLoginConflict{Users: users}
}
return nil
}

View File

@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
)
@ -54,4 +55,143 @@ func TestIntegrationUserDataAccess(t *testing.T) {
)
require.NoError(t, err)
})
t.Run("Testing DB - creates and loads user", func(t *testing.T) {
ss := sqlstore.InitTestDB(t)
cmd := user.CreateUserCommand{
Email: "usertest@test.com",
Name: "user name",
Login: "user_test_login",
}
usr, err := ss.CreateUser(context.Background(), cmd)
require.NoError(t, err)
result, err := userStore.GetByID(context.Background(), usr.ID)
require.Nil(t, err)
require.Equal(t, result.Email, "usertest@test.com")
require.Equal(t, result.Password, "")
require.Len(t, result.Rands, 10)
require.Len(t, result.Salt, 10)
require.False(t, result.IsDisabled)
result, err = userStore.GetByID(context.Background(), usr.ID)
require.Nil(t, err)
require.Equal(t, result.Email, "usertest@test.com")
require.Equal(t, result.Password, "")
require.Len(t, result.Rands, 10)
require.Len(t, result.Salt, 10)
require.False(t, result.IsDisabled)
t.Run("Get User by email case insensitive", func(t *testing.T) {
userStore.cfg.CaseInsensitiveLogin = true
query := user.GetUserByEmailQuery{Email: "USERtest@TEST.COM"}
result, err := userStore.GetByEmail(context.Background(), &query)
require.Nil(t, err)
require.Equal(t, result.Email, "usertest@test.com")
require.Equal(t, result.Password, "")
require.Len(t, result.Rands, 10)
require.Len(t, result.Salt, 10)
require.False(t, result.IsDisabled)
userStore.cfg.CaseInsensitiveLogin = false
})
t.Run("Testing DB - creates and loads user", func(t *testing.T) {
result, err = userStore.GetByID(context.Background(), usr.ID)
require.Nil(t, err)
require.Equal(t, result.Email, "usertest@test.com")
require.Equal(t, result.Password, "")
require.Len(t, result.Rands, 10)
require.Len(t, result.Salt, 10)
require.False(t, result.IsDisabled)
result, err = userStore.GetByID(context.Background(), usr.ID)
require.Nil(t, err)
require.Equal(t, result.Email, "usertest@test.com")
require.Equal(t, result.Password, "")
require.Len(t, result.Rands, 10)
require.Len(t, result.Salt, 10)
require.False(t, result.IsDisabled)
ss.Cfg.CaseInsensitiveLogin = false
})
})
t.Run("Testing DB - error on case insensitive conflict", func(t *testing.T) {
if ss.GetDBType() == migrator.MySQL {
t.Skip("Skipping on MySQL due to case insensitive indexes")
}
userStore.cfg.CaseInsensitiveLogin = true
cmd := user.CreateUserCommand{
Email: "confusertest@test.com",
Name: "user name",
Login: "user_email_conflict",
}
// userEmailConflict
_, err := ss.CreateUser(context.Background(), cmd)
require.NoError(t, err)
cmd = user.CreateUserCommand{
Email: "confusertest@TEST.COM",
Name: "user name",
Login: "user_email_conflict_two",
}
_, err = ss.CreateUser(context.Background(), cmd)
require.NoError(t, err)
cmd = user.CreateUserCommand{
Email: "user_test_login_conflict@test.com",
Name: "user name",
Login: "user_test_login_conflict",
}
// userLoginConflict
_, err = ss.CreateUser(context.Background(), cmd)
require.NoError(t, err)
cmd = user.CreateUserCommand{
Email: "user_test_login_conflict_two@test.com",
Name: "user name",
Login: "user_test_login_CONFLICT",
}
_, err = ss.CreateUser(context.Background(), cmd)
require.NoError(t, err)
ss.Cfg.CaseInsensitiveLogin = true
t.Run("GetByEmail - email conflict", func(t *testing.T) {
query := user.GetUserByEmailQuery{Email: "confusertest@test.com"}
_, err = userStore.GetByEmail(context.Background(), &query)
require.Error(t, err)
})
t.Run("GetByEmail - login conflict", func(t *testing.T) {
query := user.GetUserByEmailQuery{Email: "user_test_login_conflict@test.com"}
_, err = userStore.GetByEmail(context.Background(), &query)
require.Error(t, err)
})
t.Run("GetByLogin - email conflict", func(t *testing.T) {
query := user.GetUserByLoginQuery{LoginOrEmail: "user_email_conflict_two"}
_, err = userStore.GetByLogin(context.Background(), &query)
require.Error(t, err)
})
t.Run("GetByLogin - login conflict", func(t *testing.T) {
query := user.GetUserByLoginQuery{LoginOrEmail: "user_test_login_conflict"}
_, err = userStore.GetByLogin(context.Background(), &query)
require.Error(t, err)
})
t.Run("GetByLogin - login conflict by email", func(t *testing.T) {
query := user.GetUserByLoginQuery{LoginOrEmail: "user_test_login_conflict@test.com"}
_, err = userStore.GetByLogin(context.Background(), &query)
require.Error(t, err)
})
ss.Cfg.CaseInsensitiveLogin = false
})
}

View File

@ -152,24 +152,12 @@ func (s *Service) GetByID(ctx context.Context, query *user.GetUserByIDQuery) (*u
return user, nil
}
// TODO: remove wrapper around sqlstore
func (s *Service) GetByLogin(ctx context.Context, query *user.GetUserByLoginQuery) (*user.User, error) {
q := models.GetUserByLoginQuery{LoginOrEmail: query.LoginOrEmail}
err := s.sqlStore.GetUserByLogin(ctx, &q)
if err != nil {
return nil, err
}
return q.Result, nil
return s.store.GetByLogin(ctx, query)
}
// TODO: remove wrapper around sqlstore
func (s *Service) GetByEmail(ctx context.Context, query *user.GetUserByEmailQuery) (*user.User, error) {
q := models.GetUserByEmailQuery{Email: query.Email}
err := s.sqlStore.GetUserByEmail(ctx, &q)
if err != nil {
return nil, err
}
return q.Result, nil
return s.store.GetByEmail(ctx, query)
}
// TODO: remove wrapper around sqlstore
@ -313,7 +301,7 @@ func (s *Service) SetUserHelpFlag(ctx context.Context, cmd *user.SetUserHelpFlag
}
// TODO: remove wrapper around sqlstore
func (s *Service) GetUserProfile(ctx context.Context, query *user.GetUserProfileQuery) (user.UserProfileDTO, error) {
func (s *Service) GetProfile(ctx context.Context, query *user.GetUserProfileQuery) (user.UserProfileDTO, error) {
q := &models.GetUserProfileQuery{
UserId: query.UserID,
}

View File

@ -2,6 +2,7 @@ package userimpl
import (
"context"
"errors"
"testing"
"github.com/grafana/grafana/pkg/services/org/orgtest"
@ -77,6 +78,14 @@ func TestUserService(t *testing.T) {
err := userService.Delete(context.Background(), &user.DeleteUserCommand{UserID: 1})
require.NoError(t, err)
})
t.Run("GetByID - email conflict", func(t *testing.T) {
userService.cfg.CaseInsensitiveLogin = true
userStore.ExpectedError = errors.New("email conflict")
query := user.GetUserByIDQuery{}
_, err := userService.GetByID(context.Background(), &query)
require.Error(t, err)
})
}
type FakeUserStore struct {
@ -112,3 +121,11 @@ func (f *FakeUserStore) GetByID(context.Context, int64) (*user.User, error) {
func (f *FakeUserStore) CaseInsensitiveLoginConflict(context.Context, string, string) error {
return f.ExpectedError
}
func (f *FakeUserStore) GetByLogin(ctx context.Context, query *user.GetUserByLoginQuery) (*user.User, error) {
return f.ExpectedUser, f.ExpectedError
}
func (f *FakeUserStore) GetByEmail(ctx context.Context, query *user.GetUserByEmailQuery) (*user.User, error) {
return f.ExpectedUser, f.ExpectedError
}

View File

@ -91,6 +91,6 @@ func (f *FakeUserService) SetUserHelpFlag(ctx context.Context, cmd *user.SetUser
return f.ExpectedError
}
func (f *FakeUserService) GetUserProfile(ctx context.Context, query *user.GetUserProfileQuery) (user.UserProfileDTO, error) {
func (f *FakeUserService) GetProfile(ctx context.Context, query *user.GetUserProfileQuery) (user.UserProfileDTO, error) {
return f.ExpectedUSerProfileDTO, f.ExpectedError
}