User: support setting org and help flags though update function (#86535)

* User: Support setting active org through update function

* User: add support to update help flags through update function
This commit is contained in:
Karl Persson 2024-04-29 08:53:05 +02:00 committed by GitHub
parent 7077a5850e
commit c4cfee8d96
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 179 additions and 373 deletions

View File

@ -338,7 +338,7 @@ func (hs *HTTPServer) applyUserInvite(ctx context.Context, usr *user.User, invit
if setActive {
// set org to active
if err := hs.userService.SetUsingOrg(ctx, &user.SetUsingOrgCommand{OrgID: invite.OrgID, UserID: usr.ID}); err != nil {
if err := hs.userService.Update(ctx, &user.UpdateUserCommand{OrgID: &invite.OrgID, UserID: usr.ID}); err != nil {
return false, response.Error(http.StatusInternalServerError, "Failed to set org as active", err)
}
}

View File

@ -215,9 +215,7 @@ func (hs *HTTPServer) UpdateUserActiveOrg(c *contextmodel.ReqContext) response.R
return response.Error(http.StatusUnauthorized, "Not a valid organization", nil)
}
cmd := user.SetUsingOrgCommand{UserID: userID, OrgID: orgID}
if err := hs.userService.SetUsingOrg(c.Req.Context(), &cmd); err != nil {
if err := hs.userService.Update(c.Req.Context(), &user.UpdateUserCommand{UserID: userID, OrgID: &orgID}); err != nil {
return response.Error(http.StatusInternalServerError, "Failed to change active organization", err)
}
@ -493,9 +491,7 @@ func (hs *HTTPServer) UserSetUsingOrg(c *contextmodel.ReqContext) response.Respo
return response.Error(http.StatusUnauthorized, "Not a valid organization", nil)
}
cmd := user.SetUsingOrgCommand{UserID: userID, OrgID: orgID}
if err := hs.userService.SetUsingOrg(c.Req.Context(), &cmd); err != nil {
if err := hs.userService.Update(c.Req.Context(), &user.UpdateUserCommand{UserID: userID, OrgID: &orgID}); err != nil {
return response.Error(http.StatusInternalServerError, "Failed to change active organization", err)
}
@ -527,8 +523,7 @@ func (hs *HTTPServer) ChangeActiveOrgAndRedirectToHome(c *contextmodel.ReqContex
return
}
cmd := user.SetUsingOrgCommand{UserID: userID, OrgID: orgID}
if err := hs.userService.SetUsingOrg(c.Req.Context(), &cmd); err != nil {
if err := hs.userService.Update(c.Req.Context(), &user.UpdateUserCommand{UserID: userID, OrgID: &orgID}); err != nil {
hs.NotFoundHandler(c)
return
}
@ -606,16 +601,11 @@ func (hs *HTTPServer) SetHelpFlag(c *contextmodel.ReqContext) response.Response
bitmask := &usr.HelpFlags1
bitmask.AddFlag(user.HelpFlags1(flag))
cmd := user.SetUserHelpFlagCommand{
UserID: userID,
HelpFlags1: *bitmask,
}
if err := hs.userService.SetUserHelpFlag(c.Req.Context(), &cmd); err != nil {
if err := hs.userService.Update(c.Req.Context(), &user.UpdateUserCommand{UserID: userID, HelpFlags1: bitmask}); err != nil {
return response.Error(http.StatusInternalServerError, "Failed to update help flag", err)
}
return response.JSON(http.StatusOK, &util.DynMap{"message": "Help flag set", "helpFlags1": cmd.HelpFlags1})
return response.JSON(http.StatusOK, &util.DynMap{"message": "Help flag set", "helpFlags1": *bitmask})
}
// swagger:route GET /user/helpflags/clear signed_in_user clearHelpFlags
@ -633,16 +623,12 @@ func (hs *HTTPServer) ClearHelpFlags(c *contextmodel.ReqContext) response.Respon
return errResponse
}
cmd := user.SetUserHelpFlagCommand{
UserID: userID,
HelpFlags1: user.HelpFlags1(0),
}
if err := hs.userService.SetUserHelpFlag(c.Req.Context(), &cmd); err != nil {
flags := user.HelpFlags1(0)
if err := hs.userService.Update(c.Req.Context(), &user.UpdateUserCommand{UserID: userID, HelpFlags1: &flags}); err != nil {
return response.Error(http.StatusInternalServerError, "Failed to update help flag", err)
}
return response.JSON(http.StatusOK, &util.DynMap{"message": "Help flag set", "helpFlags1": cmd.HelpFlags1})
return response.JSON(http.StatusOK, &util.DynMap{"message": "Help flag set", "helpFlags1": flags})
}
func getUserID(c *contextmodel.ReqContext) (int64, *response.NormalResponse) {

View File

@ -32,8 +32,7 @@ func OrgRedirect(cfg *setting.Cfg, userSvc user.Service) web.Handler {
return
}
cmd := user.SetUsingOrgCommand{UserID: ctx.UserID, OrgID: orgId}
if err := userSvc.SetUsingOrg(ctx.Req.Context(), &cmd); err != nil {
if err := userSvc.Update(ctx.Req.Context(), &user.UpdateUserCommand{UserID: ctx.UserID, OrgID: &orgId}); err != nil {
if ctx.IsApiRequest() {
ctx.JsonApiErr(404, "Not found", nil)
} else {

View File

@ -55,7 +55,7 @@ func TestOrgRedirectMiddleware(t *testing.T) {
middlewareScenario(t, "when setting an invalid org for user", func(t *testing.T, sc *scenarioContext) {
sc.withIdentity(&authn.Identity{})
sc.userService.ExpectedSetUsingOrgError = fmt.Errorf("")
sc.userService.ExpectedError = fmt.Errorf("")
sc.m.Get("/", sc.defaultHandler)
sc.fakeReq("GET", "/?orgId=1").exec()

View File

@ -120,9 +120,9 @@ func (s *OrgSync) SyncOrgRolesHook(ctx context.Context, id *authn.Identity, _ *a
if _, ok := id.OrgRoles[id.OrgID]; !ok {
if len(orgIDs) > 0 {
id.OrgID = orgIDs[0]
return s.userService.SetUsingOrg(ctx, &user.SetUsingOrgCommand{
return s.userService.Update(ctx, &user.UpdateUserCommand{
UserID: userID,
OrgID: id.OrgID,
OrgID: &id.OrgID,
})
}
}
@ -159,8 +159,8 @@ func (s *OrgSync) SetDefaultOrgHook(ctx context.Context, currentIdentity *authn.
return
}
cmd := user.SetUsingOrgCommand{UserID: userID, OrgID: s.cfg.LoginDefaultOrgId}
if svcErr := s.userService.SetUsingOrg(ctx, &cmd); svcErr != nil {
cmd := user.UpdateUserCommand{UserID: userID, OrgID: &s.cfg.LoginDefaultOrgId}
if svcErr := s.userService.Update(ctx, &cmd); svcErr != nil {
ctxLogger.Error("Failed to set default org", "id", currentIdentity.ID, "err", svcErr)
}
}

View File

@ -139,8 +139,8 @@ func TestOrgSync_SetDefaultOrgHook(t *testing.T) {
defaultOrgSetting: 2,
identity: &authn.Identity{ID: authn.MustParseNamespaceID("user:1")},
setupMock: func(userService *usertest.MockService, orgService *orgtest.FakeOrgService) {
userService.On("SetUsingOrg", mock.Anything, mock.MatchedBy(func(cmd *user.SetUsingOrgCommand) bool {
return cmd.UserID == 1 && cmd.OrgID == 2
userService.On("Update", mock.Anything, mock.MatchedBy(func(cmd *user.UpdateUserCommand) bool {
return cmd.UserID == 1 && *cmd.OrgID == 2
})).Return(nil)
},
},
@ -188,7 +188,7 @@ func TestOrgSync_SetDefaultOrgHook(t *testing.T) {
defaultOrgSetting: 2,
identity: &authn.Identity{ID: authn.MustParseNamespaceID("user:1")},
setupMock: func(userService *usertest.MockService, orgService *orgtest.FakeOrgService) {
userService.On("SetUsingOrg", mock.Anything, mock.Anything).Return(fmt.Errorf("error"))
userService.On("Update", mock.Anything, mock.Anything).Return(fmt.Errorf("error"))
},
},
}
@ -217,8 +217,6 @@ func TestOrgSync_SetDefaultOrgHook(t *testing.T) {
}
s.SetDefaultOrgHook(context.Background(), tt.identity, nil, tt.inputErr)
userService.AssertExpectations(t)
})
}
}

View File

@ -39,7 +39,7 @@ type User struct {
Company string
EmailVerified bool
Theme string
HelpFlags1 HelpFlags1
HelpFlags1 HelpFlags1 `xorm:"help_flags1"`
IsDisabled bool
IsAdmin bool
@ -86,9 +86,13 @@ type UpdateUserCommand struct {
IsDisabled *bool `json:"-"`
EmailVerified *bool `json:"-"`
IsGrafanaAdmin *bool `json:"-"`
Password *Password `json:"-"`
// If password is included it will be validated, hashed and updated for user.
Password *Password `json:"-"`
// If old password is included it will be validated against users current password.
OldPassword *Password `json:"-"`
// If OrgID is included update current org for user
OrgID *int64 `json:"-"`
HelpFlags1 *HelpFlags1 `json:"-"`
}
type UpdateUserLastSeenAtCommand struct {
@ -96,11 +100,6 @@ type UpdateUserLastSeenAtCommand struct {
OrgID int64
}
type SetUsingOrgCommand struct {
UserID int64
OrgID int64
}
type SearchUsersQuery struct {
SignedInUser identity.Requester
OrgID int64 `xorm:"org_id"`
@ -179,11 +178,6 @@ type BatchDisableUsersCommand struct {
IsDisabled bool
}
type SetUserHelpFlagCommand struct {
HelpFlags1 HelpFlags1
UserID int64 `xorm:"user_id"`
}
type GetSignedInUserQuery struct {
UserID int64 `xorm:"user_id"`
Login string

View File

@ -17,12 +17,10 @@ type Service interface {
GetByEmail(context.Context, *GetUserByEmailQuery) (*User, error)
Update(context.Context, *UpdateUserCommand) error
UpdateLastSeenAt(context.Context, *UpdateUserLastSeenAtCommand) error
SetUsingOrg(context.Context, *SetUsingOrgCommand) error
GetSignedInUserWithCacheCtx(context.Context, *GetSignedInUserQuery) (*SignedInUser, error)
GetSignedInUser(context.Context, *GetSignedInUserQuery) (*SignedInUser, error)
Search(context.Context, *SearchUsersQuery) (*SearchUserQueryResult, error)
BatchDisableUsers(context.Context, *BatchDisableUsersCommand) error
SetUserHelpFlag(context.Context, *SetUserHelpFlagCommand) error
GetProfile(context.Context, *GetUserProfileQuery) (*UserProfileDTO, error)
}

View File

@ -19,20 +19,16 @@ import (
type store interface {
Insert(context.Context, *user.User) (int64, error)
Get(context.Context, *user.User) (*user.User, error)
GetByID(context.Context, int64) (*user.User, error)
GetNotServiceAccount(context.Context, int64) (*user.User, error)
GetByLogin(context.Context, *user.GetUserByLoginQuery) (*user.User, error)
GetByEmail(context.Context, *user.GetUserByEmailQuery) (*user.User, error)
Delete(context.Context, int64) error
LoginConflict(ctx context.Context, login, email string) error
CaseInsensitiveLoginConflict(context.Context, string, string) error
GetByLogin(context.Context, *user.GetUserByLoginQuery) (*user.User, error)
GetByEmail(context.Context, *user.GetUserByEmailQuery) (*user.User, error)
Update(context.Context, *user.UpdateUserCommand) error
UpdateLastSeenAt(context.Context, *user.UpdateUserLastSeenAtCommand) error
GetSignedInUser(context.Context, *user.GetSignedInUserQuery) (*user.SignedInUser, error)
UpdateUser(context.Context, *user.User) error
GetProfile(context.Context, *user.GetUserProfileQuery) (*user.UserProfileDTO, error)
SetHelpFlag(context.Context, *user.SetUserHelpFlagCommand) error
BatchDisableUsers(context.Context, *user.BatchDisableUsersCommand) error
Search(context.Context, *user.SearchUsersQuery) (*user.SearchUserQueryResult, error)
Count(ctx context.Context) (int64, error)
@ -87,30 +83,6 @@ func (ss *sqlStore) Insert(ctx context.Context, cmd *user.User) (int64, error) {
return cmd.ID, nil
}
func (ss *sqlStore) Get(ctx context.Context, usr *user.User) (*user.User, error) {
ret := &user.User{}
err := ss.db.WithDbSession(ctx, func(sess *db.Session) error {
// enforcement of lowercase due to forcement of caseinsensitive login
login := strings.ToLower(usr.Login)
email := strings.ToLower(usr.Email)
where := "email=? OR login=?"
exists, err := sess.Where(where, email, login).Get(ret)
if !exists {
return user.ErrUserNotFound
}
if err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
return ret, nil
}
func (ss *sqlStore) Delete(ctx context.Context, userID int64) error {
err := ss.db.WithDbSession(ctx, func(sess *db.Session) error {
var rawSQL = "DELETE FROM " + ss.dialect.Quote("user") + " WHERE id = ?"
@ -123,21 +95,6 @@ func (ss *sqlStore) Delete(ctx context.Context, userID int64) error {
return nil
}
func (ss *sqlStore) GetNotServiceAccount(ctx context.Context, userID int64) (*user.User, error) {
usr := user.User{ID: userID}
err := ss.db.WithDbSession(ctx, func(sess *db.Session) error {
has, err := sess.Where(ss.notServiceAccountFilter()).Get(&usr)
if err != nil {
return err
}
if !has {
return user.ErrUserNotFound
}
return nil
})
return &usr, err
}
func (ss *sqlStore) GetByID(ctx context.Context, userID int64) (*user.User, error) {
var usr user.User
@ -254,12 +211,12 @@ func (ss *sqlStore) GetByEmail(ctx context.Context, query *user.GetUserByEmailQu
// sensitive.
func (ss *sqlStore) LoginConflict(ctx context.Context, login, email string) error {
err := ss.db.WithDbSession(ctx, func(sess *db.Session) error {
return ss.loginConflict(ctx, sess, login, email)
return ss.loginConflict(sess, login, email)
})
return err
}
func (ss *sqlStore) loginConflict(ctx context.Context, sess *db.Session, login, email string) error {
func (ss *sqlStore) loginConflict(sess *db.Session, login, email string) error {
users := make([]user.User, 0)
where := "LOWER(email)=LOWER(?) OR LOWER(login)=LOWER(?)"
login = strings.ToLower(login)
@ -289,52 +246,49 @@ func (ss *sqlStore) Update(ctx context.Context, cmd *user.UpdateUserCommand) err
cmd.Email = strings.ToLower(cmd.Email)
return ss.db.WithTransactionalDbSession(ctx, func(sess *db.Session) error {
user := user.User{
usr := user.User{
Name: cmd.Name,
Theme: cmd.Theme,
Email: strings.ToLower(cmd.Email),
Login: strings.ToLower(cmd.Login),
Theme: cmd.Theme,
Updated: time.Now(),
}
q := sess.ID(cmd.UserID).Where(ss.notServiceAccountFilter())
if cmd.Password != nil {
user.Password = *cmd.Password
}
setOptional(cmd.OrgID, func(v int64) { usr.OrgID = v })
setOptional(cmd.Password, func(v user.Password) { usr.Password = v })
setOptional(cmd.IsDisabled, func(v bool) {
q = q.UseBool("is_disabled")
usr.IsDisabled = v
})
setOptional(cmd.EmailVerified, func(v bool) {
q = q.UseBool("email_verified")
usr.EmailVerified = v
})
setOptional(cmd.IsGrafanaAdmin, func(v bool) {
q = q.UseBool("is_admin")
usr.IsAdmin = v
})
setOptional(cmd.HelpFlags1, func(v user.HelpFlags1) { usr.HelpFlags1 = *cmd.HelpFlags1 })
if cmd.IsDisabled != nil {
sess.UseBool("is_disabled")
user.IsDisabled = *cmd.IsDisabled
}
if cmd.EmailVerified != nil {
q.UseBool("email_verified")
user.EmailVerified = *cmd.EmailVerified
}
if cmd.IsGrafanaAdmin != nil {
q.UseBool("is_admin")
user.IsAdmin = *cmd.IsGrafanaAdmin
}
if _, err := q.Update(&user); err != nil {
if _, err := q.Update(&usr); err != nil {
return err
}
if cmd.IsGrafanaAdmin != nil && !*cmd.IsGrafanaAdmin {
// validate that after update there is at least one server admin
if err := validateOneAdminLeft(ctx, sess); err != nil {
if err := validateOneAdminLeft(sess); err != nil {
return err
}
}
sess.PublishAfterCommit(&events.UserUpdated{
Timestamp: user.Created,
Id: user.ID,
Name: user.Name,
Login: user.Login,
Email: user.Email,
Timestamp: usr.Created,
Id: usr.ID,
Name: usr.Name,
Login: usr.Login,
Email: usr.Email,
})
return nil
@ -412,13 +366,6 @@ func (ss *sqlStore) GetSignedInUser(ctx context.Context, query *user.GetSignedIn
return &signedInUser, err
}
func (ss *sqlStore) UpdateUser(ctx context.Context, user *user.User) error {
return ss.db.WithTransactionalDbSession(ctx, func(sess *db.Session) error {
_, err := sess.ID(user.ID).Update(user)
return err
})
}
func (ss *sqlStore) GetProfile(ctx context.Context, query *user.GetUserProfileQuery) (*user.UserProfileDTO, error) {
var usr user.User
var userProfile user.UserProfileDTO
@ -450,19 +397,6 @@ func (ss *sqlStore) GetProfile(ctx context.Context, query *user.GetUserProfileQu
return &userProfile, err
}
func (ss *sqlStore) SetHelpFlag(ctx context.Context, cmd *user.SetUserHelpFlagCommand) error {
return ss.db.WithTransactionalDbSession(ctx, func(sess *db.Session) error {
user := user.User{
ID: cmd.UserID,
HelpFlags1: cmd.HelpFlags1,
Updated: time.Now(),
}
_, err := sess.ID(cmd.UserID).Cols("help_flags1").Update(&user)
return err
})
}
func (ss *sqlStore) Count(ctx context.Context) (int64, error) {
type result struct {
Count int64
@ -501,7 +435,7 @@ func (ss *sqlStore) CountUserAccountsWithEmptyRole(ctx context.Context) (int64,
}
// validateOneAdminLeft validate that there is an admin user left
func validateOneAdminLeft(ctx context.Context, sess *db.Session) error {
func validateOneAdminLeft(sess *db.Session) error {
count, err := sess.Where("is_admin=?", true).Count(&user.User{})
if err != nil {
return err
@ -673,3 +607,9 @@ func (ss *sqlStore) getAnyUserType(ctx context.Context, userID int64) (*user.Use
})
return &usr, err
}
func setOptional[T any](v *T, add func(v T)) {
if v != nil {
add(*v)
}
}

View File

@ -27,81 +27,6 @@ func TestMain(m *testing.M) {
testsuite.Run(m)
}
func TestIntegrationUserGet(t *testing.T) {
testCases := []struct {
name string
wantErr error
searchLogin string
searchEmail string
}{
{
name: "user found non exact",
wantErr: nil,
searchLogin: "test",
searchEmail: "Test@email.com",
},
{
name: "user found exact",
wantErr: nil,
searchLogin: "test",
searchEmail: "test@email.com",
},
{
name: "user found exact - case insensitive",
wantErr: nil,
searchLogin: "Test",
searchEmail: "Test@email.com",
},
{
name: "user not found - case insensitive",
wantErr: user.ErrUserNotFound,
searchLogin: "Test_login",
searchEmail: "Test*@email.com",
},
}
if testing.Short() {
t.Skip("skipping integration test")
}
ss, cfg := db.InitTestDBWithCfg(t)
userStore := ProvideStore(ss, cfg)
_, errUser := userStore.Insert(context.Background(),
&user.User{
Email: "test@email.com",
Name: "test",
Login: "test",
Created: time.Now(),
Updated: time.Now(),
},
)
require.NoError(t, errUser)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if db.IsTestDbMySQL() {
t.Skip("mysql is always case insensitive")
}
usr, err := userStore.Get(context.Background(),
&user.User{
Email: tc.searchEmail,
Login: tc.searchLogin,
},
)
if tc.wantErr != nil {
require.Error(t, err)
require.Nil(t, usr)
} else {
require.NoError(t, err)
require.NotNil(t, usr)
require.NotEmpty(t, usr.UID)
}
})
}
}
func TestIntegrationUserDataAccess(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
@ -120,12 +45,8 @@ func TestIntegrationUserDataAccess(t *testing.T) {
}
t.Run("user not found", func(t *testing.T) {
_, err := userStore.Get(context.Background(),
&user.User{
Email: "test@email.com",
Name: "test1",
Login: "test1",
},
_, err := userStore.GetByEmail(context.Background(),
&user.GetUserByEmailQuery{Email: "test@email.com"},
)
require.Error(t, err, user.ErrUserNotFound)
})
@ -143,6 +64,13 @@ func TestIntegrationUserDataAccess(t *testing.T) {
require.NoError(t, err)
})
t.Run("get user", func(t *testing.T) {
_, err := userStore.GetByEmail(context.Background(),
&user.GetUserByEmailQuery{Email: "test@email.com"},
)
require.NoError(t, err)
})
t.Run("insert user (with known UID)", func(t *testing.T) {
ctx := context.Background()
id, err := userStore.Insert(ctx,
@ -169,17 +97,6 @@ func TestIntegrationUserDataAccess(t *testing.T) {
require.Equal(t, "abcd", siu.UserUID)
})
t.Run("get user", func(t *testing.T) {
_, err := userStore.Get(context.Background(),
&user.User{
Email: "test@email.com",
Name: "test1",
Login: "test1",
},
)
require.NoError(t, err)
})
t.Run("Testing DB - creates and loads user", func(t *testing.T) {
ss := db.InitTestDB(t)
_, usrSvc := createOrgAndUserSvc(t, ss, cfg)
@ -458,15 +375,6 @@ func TestIntegrationUserDataAccess(t *testing.T) {
}
})
t.Run("update user", func(t *testing.T) {
err := userStore.UpdateUser(context.Background(), &user.User{ID: 1, Name: "testtestest", Login: "loginloginlogin"})
require.NoError(t, err)
result, err := userStore.GetByID(context.Background(), 1)
require.NoError(t, err)
assert.Equal(t, result.Name, "testtestest")
assert.Equal(t, result.Login, "loginloginlogin")
})
t.Run("Testing DB - grafana admin users", func(t *testing.T) {
ss := db.InitTestDB(t)
_, usrSvc := createOrgAndUserSvc(t, ss, cfg)
@ -483,7 +391,7 @@ func TestIntegrationUserDataAccess(t *testing.T) {
UserID: usr.ID,
IsGrafanaAdmin: boolPtr(false),
})
require.ErrorIs(t, user.ErrLastGrafanaAdmin, err)
require.ErrorIs(t, err, user.ErrLastGrafanaAdmin)
usr, err = userStore.GetByID(context.Background(), usr.ID)
require.NoError(t, err)
@ -518,9 +426,28 @@ func TestIntegrationUserDataAccess(t *testing.T) {
require.NoError(t, err)
})
t.Run("SetHelpFlag", func(t *testing.T) {
err := userStore.SetHelpFlag(context.Background(), &user.SetUserHelpFlagCommand{UserID: 1, HelpFlags1: user.HelpFlags1(1)})
t.Run("Update HelpFlags", func(t *testing.T) {
id, err := userStore.Insert(context.Background(), &user.User{
Email: "help@test.com",
Name: "help",
Login: "help",
Updated: time.Now(),
Created: time.Now(),
LastSeenAt: time.Now(),
})
require.NoError(t, err)
original, err := userStore.GetByID(context.Background(), id)
require.NoError(t, err)
helpflags := user.HelpFlags1(1)
err = userStore.Update(context.Background(), &user.UpdateUserCommand{UserID: id, HelpFlags1: &helpflags})
require.NoError(t, err)
got, err := userStore.GetByID(context.Background(), id)
require.NoError(t, err)
original.HelpFlags1 = helpflags
assertEqualUser(t, original, got)
})
t.Run("Testing DB - return list users based on their is_disabled flag", func(t *testing.T) {
@ -1013,6 +940,18 @@ func TestMetricsUsage(t *testing.T) {
})
}
func assertEqualUser(t *testing.T, expected, got *user.User) {
// zero out time fields
expected.Updated = time.Time{}
expected.Created = time.Time{}
expected.LastSeenAt = time.Time{}
got.Updated = time.Time{}
got.Created = time.Time{}
got.LastSeenAt = time.Time{}
assert.Equal(t, expected, got)
}
func createOrgAndUserSvc(t *testing.T, store db.DB, cfg *setting.Cfg) (org.Service, user.Service) {
t.Helper()

View File

@ -202,7 +202,7 @@ func (s *Service) Create(ctx context.Context, cmd *user.CreateUserCommand) (*use
}
func (s *Service) Delete(ctx context.Context, cmd *user.DeleteUserCommand) error {
_, err := s.store.GetNotServiceAccount(ctx, cmd.UserID)
_, err := s.store.GetByID(ctx, cmd.UserID)
if err != nil {
return err
}
@ -251,6 +251,24 @@ func (s *Service) Update(ctx context.Context, cmd *user.UpdateUserCommand) error
cmd.Password = &hashed
}
if cmd.OrgID != nil {
orgs, err := s.orgService.GetUserOrgList(ctx, &org.GetUserOrgListQuery{UserID: cmd.UserID})
if err != nil {
return err
}
valid := false
for _, org := range orgs {
if org.OrgID == *cmd.OrgID {
valid = true
}
}
if !valid {
return fmt.Errorf("user does not belong to org")
}
}
return s.store.Update(ctx, cmd)
}
@ -275,28 +293,6 @@ func shouldUpdateLastSeen(t time.Time) bool {
return time.Since(t) > time.Minute*5
}
func (s *Service) SetUsingOrg(ctx context.Context, cmd *user.SetUsingOrgCommand) error {
getOrgsForUserCmd := &org.GetUserOrgListQuery{UserID: cmd.UserID}
orgsForUser, err := s.orgService.GetUserOrgList(ctx, getOrgsForUserCmd)
if err != nil {
return err
}
valid := false
for _, other := range orgsForUser {
if other.OrgID == cmd.OrgID {
valid = true
}
}
if !valid {
return fmt.Errorf("user does not belong to org")
}
return s.store.UpdateUser(ctx, &user.User{
ID: cmd.UserID,
OrgID: cmd.OrgID,
})
}
func (s *Service) GetSignedInUserWithCacheCtx(ctx context.Context, query *user.GetSignedInUserQuery) (*user.SignedInUser, error) {
var signedInUser *user.SignedInUser
@ -350,10 +346,6 @@ func (s *Service) BatchDisableUsers(ctx context.Context, cmd *user.BatchDisableU
return s.store.BatchDisableUsers(ctx, cmd)
}
func (s *Service) SetUserHelpFlag(ctx context.Context, cmd *user.SetUserHelpFlagCommand) error {
return s.store.SetHelpFlag(ctx, cmd)
}
func (s *Service) GetProfile(ctx context.Context, query *user.GetUserProfileQuery) (*user.UserProfileDTO, error) {
result, err := s.store.GetProfile(ctx, query)
return result, err

View File

@ -126,58 +126,82 @@ func TestUserService(t *testing.T) {
assert.Equal(t, query2.OrgID, result2.OrgID)
})
t.Run("Can set using org", func(t *testing.T) {
cmd := user.SetUsingOrgCommand{UserID: 2, OrgID: 1}
orgService.ExpectedUserOrgDTO = []*org.UserOrgDTO{{OrgID: 1}}
t.Run("SignedInUserQuery with a different org", func(t *testing.T) {
query := user.GetSignedInUserQuery{UserID: 2}
userStore.ExpectedSignedInUser = &user.SignedInUser{
OrgID: 1,
Email: "ac2@test.com",
Name: "ac2 name",
Login: "ac2",
OrgName: "ac1@test.com",
}
userStore.ExpectedError = nil
err := userService.SetUsingOrg(context.Background(), &cmd)
queryResult, err := userService.GetSignedInUser(context.Background(), &query)
require.NoError(t, err)
t.Run("SignedInUserQuery with a different org", func(t *testing.T) {
query := user.GetSignedInUserQuery{UserID: 2}
userStore.ExpectedSignedInUser = &user.SignedInUser{
OrgID: 1,
Email: "ac2@test.com",
Name: "ac2 name",
Login: "ac2",
OrgName: "ac1@test.com",
}
queryResult, err := userService.GetSignedInUser(context.Background(), &query)
require.NoError(t, err)
require.EqualValues(t, queryResult.OrgID, 1)
require.Equal(t, queryResult.Email, "ac2@test.com")
require.Equal(t, queryResult.Name, "ac2 name")
require.Equal(t, queryResult.Login, "ac2")
require.Equal(t, queryResult.OrgName, "ac1@test.com")
})
require.EqualValues(t, queryResult.OrgID, 1)
require.Equal(t, queryResult.Email, "ac2@test.com")
require.Equal(t, queryResult.Name, "ac2 name")
require.Equal(t, queryResult.Login, "ac2")
require.Equal(t, queryResult.OrgName, "ac1@test.com")
})
}
func TestService_Update(t *testing.T) {
t.Run("should return error if old password does not match stored password", func(t *testing.T) {
stored, err := user.Password("test").Hash("salt")
require.NoError(t, err)
service := &Service{store: &FakeUserStore{ExpectedUser: &user.User{Password: stored, Salt: "salt"}}}
setup := func(opts ...func(svc *Service)) *Service {
service := &Service{store: &FakeUserStore{}}
for _, o := range opts {
o(service)
}
return service
}
err = service.Update(context.Background(), &user.UpdateUserCommand{
OldPassword: passwordPtr("test123"),
t.Run("should return error if old password does not match stored password", func(t *testing.T) {
service := setup(func(svc *Service) {
stored, err := user.Password("test").Hash("salt")
require.NoError(t, err)
svc.store = &FakeUserStore{ExpectedUser: &user.User{Password: stored, Salt: "salt"}}
})
err := service.Update(context.Background(), &user.UpdateUserCommand{
OldPassword: passwordPtr("test123"),
})
assert.ErrorIs(t, err, user.ErrPasswordMissmatch)
})
t.Run("should return error new password is not valid", func(t *testing.T) {
stored, err := user.Password("test").Hash("salt")
require.NoError(t, err)
service := &Service{cfg: setting.NewCfg(), store: &FakeUserStore{ExpectedUser: &user.User{Password: stored, Salt: "salt"}}}
service := setup(func(svc *Service) {
stored, err := user.Password("test").Hash("salt")
require.NoError(t, err)
svc.cfg = setting.NewCfg()
svc.store = &FakeUserStore{ExpectedUser: &user.User{Password: stored, Salt: "salt"}}
})
err = service.Update(context.Background(), &user.UpdateUserCommand{
err := service.Update(context.Background(), &user.UpdateUserCommand{
OldPassword: passwordPtr("test"),
Password: passwordPtr("asd"),
})
require.ErrorIs(t, err, user.ErrPasswordTooShort)
})
t.Run("Can set using org", func(t *testing.T) {
orgID := int64(1)
service := setup(func(svc *Service) {
svc.orgService = &orgtest.FakeOrgService{ExpectedUserOrgDTO: []*org.UserOrgDTO{{OrgID: orgID}}}
})
err := service.Update(context.Background(), &user.UpdateUserCommand{UserID: 2, OrgID: &orgID})
require.NoError(t, err)
})
t.Run("Cannot set using org when user is not member of it", func(t *testing.T) {
orgID := int64(1)
service := setup(func(svc *Service) {
svc.orgService = &orgtest.FakeOrgService{ExpectedUserOrgDTO: []*org.UserOrgDTO{{OrgID: 2}}}
})
err := service.Update(context.Background(), &user.UpdateUserCommand{UserID: 2, OrgID: &orgID})
require.Error(t, err)
})
}
func TestMetrics(t *testing.T) {
@ -220,10 +244,6 @@ func newUserStoreFake() *FakeUserStore {
return &FakeUserStore{}
}
func (f *FakeUserStore) Get(ctx context.Context, query *user.User) (*user.User, error) {
return f.ExpectedUser, f.ExpectedError
}
func (f *FakeUserStore) Insert(ctx context.Context, query *user.User) (int64, error) {
return 0, f.ExpectedError
}
@ -232,10 +252,6 @@ func (f *FakeUserStore) Delete(ctx context.Context, userID int64) error {
return f.ExpectedDeleteUserError
}
func (f *FakeUserStore) GetNotServiceAccount(ctx context.Context, userID int64) (*user.User, error) {
return f.ExpectedUser, f.ExpectedError
}
func (f *FakeUserStore) GetByID(context.Context, int64) (*user.User, error) {
return f.ExpectedUser, f.ExpectedError
}
@ -268,22 +284,10 @@ func (f *FakeUserStore) GetSignedInUser(ctx context.Context, query *user.GetSign
return f.ExpectedSignedInUser, f.ExpectedError
}
func (f *FakeUserStore) UpdateUser(ctx context.Context, user *user.User) error {
return f.ExpectedError
}
func (f *FakeUserStore) GetProfile(ctx context.Context, query *user.GetUserProfileQuery) (*user.UserProfileDTO, error) {
return f.ExpectedUserProfile, f.ExpectedError
}
func (f *FakeUserStore) SetHelpFlag(ctx context.Context, cmd *user.SetUserHelpFlagCommand) error {
return f.ExpectedError
}
func (f *FakeUserStore) UpdatePermissions(ctx context.Context, userID int64, isAdmin bool) error {
return f.ExpectedError
}
func (f *FakeUserStore) BatchDisableUsers(ctx context.Context, cmd *user.BatchDisableUsersCommand) error {
return f.ExpectedError
}

View File

@ -71,10 +71,6 @@ func (f *FakeUserService) UpdateLastSeenAt(ctx context.Context, cmd *user.Update
return f.ExpectedError
}
func (f *FakeUserService) SetUsingOrg(ctx context.Context, cmd *user.SetUsingOrgCommand) error {
return f.ExpectedSetUsingOrgError
}
func (f *FakeUserService) GetSignedInUserWithCacheCtx(ctx context.Context, query *user.GetSignedInUserQuery) (*user.SignedInUser, error) {
return f.GetSignedInUser(ctx, query)
}
@ -104,10 +100,6 @@ func (f *FakeUserService) BatchDisableUsers(ctx context.Context, cmd *user.Batch
return f.ExpectedError
}
func (f *FakeUserService) SetUserHelpFlag(ctx context.Context, cmd *user.SetUserHelpFlagCommand) error {
return f.ExpectedError
}
func (f *FakeUserService) GetProfile(ctx context.Context, query *user.GetUserProfileQuery) (*user.UserProfileDTO, error) {
if f.ExpectedUserProfileDTO != nil {
return f.ExpectedUserProfileDTO, f.ExpectedError

View File

@ -340,42 +340,6 @@ func (_m *MockService) Search(_a0 context.Context, _a1 *user.SearchUsersQuery) (
return r0, r1
}
// SetUserHelpFlag provides a mock function with given fields: _a0, _a1
func (_m *MockService) SetUserHelpFlag(_a0 context.Context, _a1 *user.SetUserHelpFlagCommand) error {
ret := _m.Called(_a0, _a1)
if len(ret) == 0 {
panic("no return value specified for SetUserHelpFlag")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *user.SetUserHelpFlagCommand) error); ok {
r0 = rf(_a0, _a1)
} else {
r0 = ret.Error(0)
}
return r0
}
// SetUsingOrg provides a mock function with given fields: _a0, _a1
func (_m *MockService) SetUsingOrg(_a0 context.Context, _a1 *user.SetUsingOrgCommand) error {
ret := _m.Called(_a0, _a1)
if len(ret) == 0 {
panic("no return value specified for SetUsingOrg")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *user.SetUsingOrgCommand) error); ok {
r0 = rf(_a0, _a1)
} else {
r0 = ret.Error(0)
}
return r0
}
// Update provides a mock function with given fields: _a0, _a1
func (_m *MockService) Update(_a0 context.Context, _a1 *user.UpdateUserCommand) error {
ret := _m.Called(_a0, _a1)